// monarch_tensor_worker/lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// an unsafe-op-in-unsafe-fn warning.
14#![allow(unsafe_op_in_unsafe_fn)]
15
16//! A `hyperactor`-based implementation of a PyTorch worker actor.
17//!
18//! The worker is responsible for executing PyTorch operations on a local
19//! device. It assumes it has exclusive access to device resources, and manages
20//! concurrency internally via device-specific constructs (CUDA stream, threads,
21//! etc.).
22//!
23//! This is a port of `monarch/python/controller/worker.py` but does have gaps due
24//! to drift that needs to be reconciled.
25//! This mainly includes:
26//! - Support for record and replay
27//! - debugger support
//! - general drift in existing messages
29
30mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor::Actor;
55use hyperactor::Bind;
56use hyperactor::Handler;
57use hyperactor::RemoteSpawn;
58use hyperactor::Unbind;
59use hyperactor::actor::ActorHandle;
60use hyperactor::context;
61use hyperactor::reference;
62use hyperactor_config::Flattrs;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use ndslice::Slice;
84use pyo3::Python;
85use pyo3::types::PyAnyMethods;
86use serde::Deserialize;
87use serde::Serialize;
88use sorted_vec::SortedVec;
89use stream::StreamActor;
90use stream::StreamMessageClient;
91use stream::StreamParams;
92use torch_sys_cuda::nccl::ReduceOp;
93use torch_sys_cuda::nccl::UniqueId;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
/// Book-keeping for a process group "created" via `CreateRemoteProcessGroup`:
/// the mesh it was created on, the mesh dims it spans, and the communicators
/// split off for it, per stream.
#[derive(Debug)]
struct RemoteProcessGroupState {
    /// The device mesh this process group was created against.
    device_mesh_ref: Ref,
    /// Mesh dimension names the group spans (kept sorted for stable keying).
    dims: SortedVec<String>,
    /// Communicators split off for this group, keyed by the stream they were
    /// created for (see `split_comm_for_process_group`).
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
108
109impl RemoteProcessGroupState {
110    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111        Self {
112            device_mesh_ref,
113            dims,
114            comms: HashMap::new(),
115        }
116    }
117}
118
/// Lifecycle of a recording held by the worker: it is assembled from a
/// sequence of `DefineRecording` messages and becomes runnable once all of
/// them have arrived.
#[derive(Debug)]
enum Recording {
    // In the process of receiving DefineRecording messages for this
    // recording.
    PartialRecording {
        // The index of the last DefineRecording message received.
        last_index: usize,
        // The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    // The recording is ready to be run.
    CompleteRecording {
        // The list of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}
136
/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    /// CUDA device this worker drives; `None` for a non-CUDA worker.
    device: Option<CudaDevice>,
    /// Stream actors owned by this worker, keyed by their ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps streams to the device mesh and a map of dim names to the concrete
    /// communicator actor that represents the dimension for that stream.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // A map for comms for this mesh for a given pair of stream and dims.
            // The usize is the number of ranks in the split group.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    /// Total number of workers participating in the job.
    world_size: usize,
    /// This worker's rank; may be reassigned by `AssignRankMessage`.
    rank: usize,
    /// Active borrows keyed by borrow id (see `borrow_create`).
    borrows: HashMap<u64, Borrow>,
    /// The global communicator, set by `backend_network_init`.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// The controller this worker reports status and results to.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Remember the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    /// Recordings known to this worker, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// While a recording is being defined, its ref; `None` otherwise.
    defining_recording: Option<Ref>,
    /// When true, results are delivered as Python messages; flipped on by
    /// `AssignRankMessage` handling.
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
183        self.streams
184            .get(&stream)
185            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
186    }
187
188    async fn maybe_add_stream_to_recording(
189        &mut self,
190        cx: &impl context::Actor,
191        stream: StreamRef,
192    ) -> Result<()> {
193        // If we're defining a recording, add the stream to the list of streams that
194        // this recording uses, and call define_recording on the stream.
195        if let Some(defining_recording) = self.defining_recording {
196            let recording = self.recordings.get_mut(&defining_recording).unwrap();
197            let fut = match recording {
198                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
199                Recording::CompleteRecording { streams } => {
200                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
201                        Ok(self
202                            .try_get_stream(stream)?
203                            .define_recording(cx, defining_recording))
204                    })
205                }
206            }
207            .transpose()?;
208            match fut {
209                Some(fut) => fut.await,
210                None => Ok(()),
211            }
212        } else {
213            Ok(())
214        }
215    }
216}
217
218impl Actor for WorkerActor {}
219
#[async_trait]
impl RemoteSpawn for WorkerActor {
    type Params = WorkerParams;

    /// Construct a worker from its spawn parameters. All collections start
    /// empty; communicators and streams are created later by messages.
    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
        _environment: Flattrs,
    ) -> Result<Self> {
        // Eagerly import the torch shim module on the Python side.
        // NOTE(review): the unwrap means a missing module panics during
        // worker construction — presumably intentional fail-fast; confirm.
        Python::attach(|py| {
            py.import("monarch.safe_torch").unwrap();
        });
        Ok(Self {
            // No device index means this worker runs in non-CUDA mode.
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            // Flipped to true when an AssignRankMessage arrives.
            respond_with_python_message: false,
        })
    }

    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
}
255
#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    /// Adopt the rank implied by this actor's position in the cast, switch to
    /// Python-message responses, and initialize the Python-side environment.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        // The rank derived from the cast point overrides the rank supplied at
        // construction time.
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::attach(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            // Hand the point and this proc's id to the Python mesh controller.
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}
276
/// Message that tells a worker to adopt the rank implied by its position in
/// the cast and to initialize the Python-side environment accordingly; see
/// the `Handler<AssignRankMessage>` impl for the behavior.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
284
#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    /// Forward every `WorkerMessage` to the generated `WorkerMessageHandler`
    /// dispatch, which routes to the per-message methods implemented below.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}
295
296#[async_trait]
297impl WorkerMessageHandler for WorkerActor {
    /// Create the global NCCL communicator for this worker, then split off a
    /// dedicated communicator for every stream created so far.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::new(CommParams::New {
            device,
            unique_id,
            world_size: self.world_size.try_into().unwrap(),
            rank: self.rank.try_into().unwrap(),
        })
        .await?
        .spawn(cx)?;

        // All-reduce a single zero — presumably a warm-up/sanity barrier that
        // exercises the fresh communicator end-to-end.
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx).await?);
        }
        // Handing the split comms to the streams can run concurrently: the
        // order-sensitive part (the splits themselves) already happened above.
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }
358
359    async fn backend_network_point_to_point_init(
360        &mut self,
361        cx: &hyperactor::Context<Self>,
362        from_stream: StreamRef,
363        to_stream: StreamRef,
364    ) -> Result<()> {
365        if !self.streams.contains_key(&from_stream) {
366            bail!("invalid from_stream id: {:#?}", from_stream);
367        }
368        if !self.streams.contains_key(&to_stream) {
369            bail!("invalid to_stream id: {:#?}", to_stream);
370        }
371        let global_comm = self
372            .comm
373            .as_ref()
374            .context("tried to call Reduce before BackendNetworkInit")?;
375        let comm = global_comm.split_all(cx).await?;
376        self.send_recv_comms
377            .insert((from_stream, to_stream), Arc::new(comm));
378        Ok(())
379    }
380
    /// Execute a function on the target stream, first resolving the device
    /// meshes and remote-process-group communicators the call references.
    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        // Resolve the stream up front so a bad stream ref fails fast.
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        // Snapshot of every device mesh (without its comm map) for the stream
        // to resolve mesh refs against.
        let device_meshes = self
            .device_meshes
            .iter()
            .map(|(k, v)| (k.clone(), v.0.clone()))
            .collect();

        // For each remote process group referenced by the call, bundle its
        // mesh, dims and the communicator previously split off for this
        // stream (see `split_comm_for_process_group`).
        // NOTE(review): refs not present in `remote_process_groups` are
        // silently skipped here — confirm the stream surfaces an error if it
        // actually needs one of them.
        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }
426
    /// Run a batch of worker messages sequentially, in order, stopping at the
    /// first error.
    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            // Dispatch through the same entry point as directly delivered
            // messages, so batched commands behave identically.
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }
437
    /// Spawn a new `StreamActor` and register it under `result`.
    ///
    /// NOTE(review): inserting under an existing ref would silently replace
    /// the old stream handle — confirm stream refs are never reused.
    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
            world_size: self.world_size,
            rank: self.rank,
            creation_mode,
            id: result,
            device: self.device,
            controller_actor: self.controller_actor.clone(),
            respond_with_python_message: self.respond_with_python_message,
        })
        .spawn(cx)?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }
457
458    async fn create_device_mesh(
459        &mut self,
460        _cx: &hyperactor::Context<Self>,
461        result: Ref,
462        names: Vec<String>,
463        ranks: Slice,
464    ) -> Result<()> {
465        self.device_meshes.insert(
466            result,
467            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
468        );
469        Ok(())
470    }
471
472    async fn create_remote_process_group(
473        &mut self,
474        _cx: &hyperactor::Context<Self>,
475        result: Ref,
476        device_mesh: Ref,
477        dims: Vec<String>,
478    ) -> Result<()> {
479        self.device_meshes
480            .get(&device_mesh)
481            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
482        match self.remote_process_groups.entry(result) {
483            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
484                device_mesh,
485                SortedVec::from_unsorted(dims),
486            )),
487            Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
488        };
489        Ok(())
490    }
491
    /// Create a borrow of `tensor_ref` from `from_stream` into `to_stream`
    /// and register it under `borrow_id` for the later first-use / last-use /
    /// drop messages.
    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        // Both sides of the borrow participate in any recording being defined.
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }
511
512    async fn borrow_first_use(
513        &mut self,
514        cx: &hyperactor::Context<Self>,
515        borrow: u64,
516    ) -> Result<()> {
517        let borrow = self
518            .borrows
519            .get_mut(&borrow)
520            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
521
522        borrow.first_use(cx).await?;
523        Ok(())
524    }
525
526    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
527        let borrow = self
528            .borrows
529            .get_mut(&borrow)
530            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
531
532        borrow.last_use(cx).await?;
533        Ok(())
534    }
535
536    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
537        let borrow = self
538            .borrows
539            .get_mut(&borrow_id)
540            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
541
542        borrow.drop(cx).await?;
543        self.borrows.remove(&borrow_id);
544        Ok(())
545    }
546
547    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
548        // Fan the delete message to all streams.
549        // Check for errors.
550        // TODO: this blocks forward progress of the the actor loop while we
551        // wait for the streams to catch up. Once we have a way of spawning
552        // tasks that the actor system can monitor in a non-blocking way, we
553        // should remove this.
554        let _: Vec<()> = try_join_all(
555            self.streams
556                .values()
557                .map(|s| s.delete_refs(cx, refs.clone())),
558        )
559        .await?;
560        Ok(())
561    }
562
    /// Wait for every stream to drain its pending work, then report status
    /// for `seq` back to the controller.
    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            // NOTE(review): reports `seq.next()` — presumably "everything up
            // to and including `seq` is done"; confirm against controller.
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }
590
    /// Run a reduction over mesh dimension(s) `dims` of `source_mesh` on
    /// `stream_ref`, using the communicator previously split off for exactly
    /// this (stream, dims) pair via `split_comm`.
    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Sort for stable indexing.
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        // `size` is the rank count of the split group, recorded when the comm
        // was created in `split_comm`.
        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }
638
    /// Transfer a tensor between two streams (possibly across ranks), using
    /// the communicator created for this stream pair by
    /// `backend_network_point_to_point_init`.
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        // If this rank appears in `from_ranks`, it sends to the rank at the
        // corresponding position of `to_ranks`; symmetrically for receiving.
        // `None` means this rank does not participate on that side.
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        // Pick which local stream performs the op: the receive stream when we
        // only receive, the send stream when we only send (or both streams
        // are the same). Acting as both sender and receiver across distinct
        // streams is unsupported.
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
692
693    async fn exit(
694        &mut self,
695        cx: &hyperactor::Context<Self>,
696        error: Option<(Option<reference::ActorId>, String)>,
697    ) -> Result<()> {
698        for (_, stream) in self.streams.drain() {
699            stream.drain_and_stop("tensor worker exit cleanup")?;
700            Arc::into_inner(stream)
701                .expect("there should be no owners of this stream handle except the worker stream table")
702                .await;
703        }
704
705        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
706            .ok()
707            .and_then(|val| val.parse::<i32>().ok())
708            .unwrap_or(1);
709        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
710            .ok()
711            .and_then(|val| val.parse::<i32>().ok())
712            .unwrap_or(1);
713
714        // Exit the worker process if there is an error.
715        let exit_code = match error {
716            Some((Some(actor_id), reason)) => {
717                tracing::error!(
718                    "stopping the worker, actor {} failed with error: {}",
719                    actor_id,
720                    reason
721                );
722                if *cx.self_id() == actor_id {
723                    self_error_exit_code
724                } else {
725                    peer_error_exit_code
726                }
727            }
728            Some((None, reason)) => {
729                tracing::error!("stopping the worker, reason: {}", reason);
730                1
731            }
732            None => 0,
733        };
734
735        if exit_code != 0 {
736            tracing::info!("stopping the worker process, exit code: {}", exit_code);
737            std::process::exit(exit_code);
738        }
739        cx.stop("tensor worker exit")?;
740        Ok(())
741    }
742
    /// Resolve a value (optionally by applying `function`) on `stream` and
    /// send it back to the controller. Pipe destinations are no longer
    /// supported and panic.
    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        stream: StreamRef,
    ) -> Result<()> {
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;

        // Device meshes are only needed when a function will be applied.
        let device_meshes = if function.is_none() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        if destination.is_some() {
            panic!("send_value with pipe destination is no longer implemented")
        }

        // Resolve the value on the stream, then send the value back to the controller.
        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args_kwargs,
                device_meshes,
            )
            .await
    }
782
783    async fn send_result_of_actor_call(
784        &mut self,
785        cx: &hyperactor::Context<Self>,
786        params: ActorCallParams,
787    ) -> Result<()> {
788        let stream = self.try_get_stream(params.stream)?;
789        stream.send_result_of_actor_call(cx, params).await?;
790        Ok(())
791    }
792    async fn call_actor_method(
793        &mut self,
794        cx: &hyperactor::Context<Self>,
795        params: ActorMethodParams,
796    ) -> Result<()> {
797        let stream = self.try_get_stream(params.call.stream)?;
798        stream.call_actor_method(cx, params).await?;
799        Ok(())
800    }
    /// Split a communicator for `dims` of `device_mesh` off the global comm
    /// and record it for `stream_ref`. Every rank must participate in the
    /// split call even when it is not part of the new group.
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                // NOTE: direct field access (not `self.try_get_stream`) keeps
                // this borrow disjoint from the mutable `device_meshes` borrow.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                // Record the group size alongside the comm for later `reduce`.
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![]).await?;
            }
        }
        Ok(())
    }
849
850    async fn split_comm_for_process_group(
851        &mut self,
852        cx: &hyperactor::Context<Self>,
853        remote_process_group_ref: Ref,
854        stream_ref: StreamRef,
855    ) -> Result<()> {
856        ensure!(
857            self.streams.contains_key(&stream_ref),
858            "invalid stream id: {:#?}",
859            stream_ref
860        );
861        let global_comm = self
862            .comm
863            .as_ref()
864            .context("tried to call SplitComm before BackendNetworkInit")?;
865        let state = self
866            .remote_process_groups
867            .get_mut(&remote_process_group_ref)
868            .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
869        match self.device_meshes.get_mut(&state.device_mesh_ref) {
870            Some((device_mesh, _)) => {
871                // This rank is in the group to be split off. Split a new
872                // communicator for it off from the global communicator.
873                let entry = match state.comms.entry(stream_ref) {
874                    Entry::Vacant(entry) => entry,
875                    Entry::Occupied(_) => bail!(
876                        "comm already exists for remote process group {:#?} on stream {:#?}",
877                        remote_process_group_ref,
878                        stream_ref,
879                    ),
880                };
881                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
882                let split_comm = global_comm
883                    .split_from(
884                        cx,
885                        ranks_for_group
886                            .into_iter()
887                            .map(|v| v.clone().try_into())
888                            .collect::<Result<Vec<_>, _>>()?,
889                    )
890                    .await?
891                    .context("split comm should include self rank")?;
892                entry.insert(Arc::new(split_comm));
893            }
894            None => {
895                // This rank is not in the group to be split off. We still need to
896                // participate in the commSplit call, however.
897                global_comm.split_from(cx, vec![]).await?;
898            }
899        }
900        Ok(())
901    }
902
903    async fn pipe_recv(
904        &mut self,
905        _cx: &hyperactor::Context<Self>,
906        _seq: Seq,
907        _results: Vec<Option<Ref>>,
908        _pipe: Ref,
909        _stream: StreamRef,
910    ) -> Result<()> {
911        panic!("pipe_recv is no longer implemented")
912    }
913
914    async fn set_ref_unit_tests_only(
915        &mut self,
916        cx: &hyperactor::Context<Self>,
917        reference: Ref,
918        value: WireValue,
919        stream: StreamRef,
920    ) -> Result<()> {
921        let stream = self.try_get_stream(stream)?;
922
923        stream.set_ref_unit_tests_only(cx, reference, value).await
924    }
925
926    async fn get_ref_unit_tests_only(
927        &mut self,
928        cx: &hyperactor::Context<Self>,
929        ref_id: Ref,
930        stream: StreamRef,
931    ) -> Result<Option<Result<WireValue, String>>> {
932        let stream = self.try_get_stream(stream)?;
933        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
934    }
935
936    async fn define_recording(
937        &mut self,
938        cx: &hyperactor::Context<Self>,
939        result: Ref,
940        _nresults: usize,
941        _nformals: usize,
942        commands: Vec<WorkerMessage>,
943        ntotal_messages: usize,
944        index: usize,
945    ) -> Result<()> {
946        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
947            bail!("already defining a different recording");
948        }
949        self.defining_recording = Some(result);
950
951        match self.recordings.entry(result) {
952            Entry::Vacant(entry) => {
953                ensure!(
954                    index == 0,
955                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
956                    index
957                );
958                entry.insert(Recording::PartialRecording {
959                    last_index: 0,
960                    commands,
961                });
962            }
963            Entry::Occupied(mut entry) => match entry.get_mut() {
964                Recording::CompleteRecording { .. } => {
965                    bail!("got DefineRecording message for already complete recording")
966                }
967                Recording::PartialRecording {
968                    last_index,
969                    commands: existing_commands,
970                } => {
971                    ensure!(
972                        index == *last_index + 1,
973                        "Got DefineRecording message with index = {:?}, but \
974                            last seen index for recording is {:?}",
975                        index,
976                        last_index
977                    );
978                    *last_index = index;
979                    existing_commands.extend(commands);
980                }
981            },
982        };
983
984        if index < ntotal_messages - 1 {
985            return Ok(());
986        }
987        let commands = match self.recordings.remove(&result).unwrap() {
988            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
989            Recording::PartialRecording { commands, .. } => {
990                self.recordings.insert(
991                    result,
992                    Recording::CompleteRecording {
993                        streams: HashSet::new(),
994                    },
995                );
996                commands
997            }
998        };
999
1000        for command in commands {
1001            WorkerMessageHandler::handle(self, cx, command).await?;
1002        }
1003
1004        match self.recordings.get(&result).unwrap() {
1005            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
1006            Recording::CompleteRecording { streams, .. } => {
1007                for stream in streams {
1008                    self.try_get_stream(*stream)?
1009                        .finalize_recording(cx, result)
1010                        .await?;
1011                }
1012            }
1013        }
1014
1015        self.defining_recording = None;
1016        Ok(())
1017    }
1018
1019    async fn recording_formal(
1020        &mut self,
1021        cx: &hyperactor::Context<Self>,
1022        result: Ref,
1023        argument_index: usize,
1024        stream: StreamRef,
1025    ) -> Result<()> {
1026        ensure!(self.defining_recording.is_some());
1027        self.maybe_add_stream_to_recording(cx, stream).await?;
1028        self.try_get_stream(stream)?
1029            .recording_formal(cx, result, argument_index)
1030            .await
1031    }
1032
1033    async fn recording_result(
1034        &mut self,
1035        cx: &hyperactor::Context<Self>,
1036        result: Ref,
1037        output_index: usize,
1038        stream: StreamRef,
1039    ) -> Result<()> {
1040        ensure!(self.defining_recording.is_some());
1041        self.maybe_add_stream_to_recording(cx, stream).await?;
1042        self.try_get_stream(stream)?
1043            .recording_result(cx, result, output_index)
1044            .await
1045    }
1046
1047    async fn call_recording(
1048        &mut self,
1049        cx: &hyperactor::Context<Self>,
1050        seq: Seq,
1051        recording: Ref,
1052        results: Vec<Ref>,
1053        actuals: Vec<Ref>,
1054    ) -> Result<()> {
1055        ensure!(self.defining_recording.is_none());
1056        let recording_ref = recording;
1057        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1058            "could not find recording: {:#?}",
1059            recording
1060        ))?;
1061        match recording {
1062            Recording::PartialRecording { .. } => {
1063                bail!("cannot call recording because it is incomplete")
1064            }
1065            Recording::CompleteRecording { streams } => try_join_all(
1066                streams
1067                    .iter()
1068                    .map(|stream| self.try_get_stream(*stream))
1069                    .collect::<Result<Vec<_>>>()?
1070                    .into_iter()
1071                    .map(|stream| {
1072                        stream.call_recording(
1073                            cx,
1074                            seq,
1075                            recording_ref,
1076                            results.clone(),
1077                            actuals.clone(),
1078                        )
1079                    }),
1080            )
1081            .await
1082            .map(|_| ()),
1083        }
1084    }
1085}
1086
1087#[cfg(test)]
1088mod tests {
1089    use std::assert_matches::assert_matches;
1090
1091    use anyhow::Result;
1092    use hyperactor::RemoteSpawn;
1093    use hyperactor::channel::ChannelAddr;
1094    use hyperactor::proc::Proc;
1095    use monarch_messages::controller::ControllerMessage;
1096    use monarch_messages::controller::WorkerError;
1097    use monarch_messages::worker::WorkerMessageClient;
1098    use monarch_types::PickledPyObject;
1099    use pyo3::Python;
1100    use pyo3::prelude::*;
1101    use pyo3::types::PyList;
1102    use pyo3::types::PyString;
1103    use rand::Rng;
1104    use rand::distr::Alphanumeric;
1105    use timed_test::async_timed_test;
1106
1107    use super::*;
1108    use crate::test_util::test_setup;
1109
    /// End-to-end smoke test: spawn a single-rank worker on a local proc,
    /// run a short chain of aten ops on one stream, and read the result
    /// back. Computes `ones(2, 3) - 1`, compares it with `zeros(2, 3)` via
    /// `allclose`, and asserts the comparison is true and that the
    /// controller saw no error responses.
    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // ref 0 = torch.ones(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 0 -= 1 in place (result aliased as ref 2)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 3 = torch.zeros(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 4 = allclose(ref 0, ref 3) — expected true
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        // Read back the allclose result from the worker.
        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // Stop the worker before asserting so a failure can't leave it running.
        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }
1218
    /// Verifies that a failing remote function call is reported back to the
    /// controller: calling `aten.rand` without its required `size` argument
    /// should produce a `RemoteFunctionFailed` message carrying the seq of
    /// the failing call and a backtrace mentioning the missing argument.
    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Invalid call: aten.rand requires a `size` argument, so
                    // this is expected to fail on the worker.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;
        // The failure should surface as a RemoteFunctionFailed message with
        // the failing call's seq and a descriptive backtrace.
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }
1281
    /// Verifies that when a function call fails, the refs it declared in
    /// `mutates` are poisoned with the error: after a call to a nonexistent
    /// function that claims to mutate ref 0, reading ref 0 back should yield
    /// an `Err` mentioning the resolution failure, and the controller should
    /// receive exactly one (error) response.
    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Seed ref 0 with a plain value.
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    // This call fails to resolve, but declares it mutates
                    // ref 0 — so ref 0 should be overwritten with the error.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        // Stop/drain worker before asserts to avoid hangs.
        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        // Ref 0 should now hold the propagated error, not the original Int.
        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }
1355
    /// Verifies error propagation through data dependencies: seq 0 fails
    /// (unresolvable function) producing an errored ref 0; seq 1 consumes
    /// ref 0 and should be skipped rather than reported as a second failure.
    /// Exactly one `RemoteFunctionFailed` (for seq 0) must reach the
    /// controller.
    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // seq 0 fails: the function cannot be resolved, so ref 0
                    // is populated with the error.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // seq 1 consumes the errored ref 0; its own failure must
                    // NOT generate an additional controller response.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into())],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        // Only the root-cause failure (seq 0) should be reported.
        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }
1434
1435    #[async_timed_test(timeout_secs = 60)]
1436    async fn py_remote_function_calls() -> Result<()> {
1437        test_setup()?;
1438
1439        let proc = Proc::local();
1440        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1441
1442        let worker_handle = proc
1443            .spawn(
1444                "worker",
1445                WorkerActor::new(
1446                    WorkerParams {
1447                        world_size: 1,
1448                        rank: 0,
1449                        device_index: None,
1450                        controller_actor: controller_ref,
1451                    },
1452                    Flattrs::default(),
1453                )
1454                .await
1455                .unwrap(),
1456            )
1457            .unwrap();
1458        let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
1459            Python::attach(|py| {
1460                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
1461                    .into_any()
1462                    .try_into()?;
1463                let sort_list: PickledPyObject =
1464                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
1465                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
1466                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
1467                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
1468                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
1469                let device: PickledPyObject = py
1470                    .import("torch")?
1471                    .getattr("device")?
1472                    .call1(("cuda:1",))?
1473                    .try_into()?;
1474                let memory_format: PickledPyObject = py
1475                    .import("torch")?
1476                    .getattr("contiguous_format")?
1477                    .try_into()?;
1478                PyResult::Ok((
1479                    split_arg,
1480                    sort_list,
1481                    dim,
1482                    layout,
1483                    none,
1484                    scalar,
1485                    device,
1486                    memory_format,
1487                ))
1488            })?;
1489
1490        worker_handle
1491            .command_group(
1492                &client,
1493                vec![
1494                    WorkerMessage::CreateStream {
1495                        id: 1.into(),
1496                        stream_creation: StreamCreationMode::UseDefaultStream,
1497                    },
1498                    WorkerMessage::CallFunction(CallFunctionParams {
1499                        seq: 0.into(),
1500                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
1501                        mutates: vec![],
1502                        function: "os.path.split".into(),
1503                        args_kwargs: ArgsKwargs::from_wire_values(
1504                            vec![split_arg.into()],
1505                            HashMap::new(),
1506                        )
1507                        .unwrap(),
1508                        stream: 1.into(),
1509                        remote_process_groups: vec![],
1510                    }),
1511                    WorkerMessage::CallFunction(CallFunctionParams {
1512                        seq: 2.into(),
1513                        results: vec![Some(4.into()), None, None, None, None],
1514                        mutates: vec![],
1515                        function: "builtins.sorted".into(),
1516                        args_kwargs: ArgsKwargs::from_wire_values(
1517                            vec![sort_list.into()],
1518                            HashMap::new(),
1519                        )
1520                        .unwrap(),
1521                        stream: 1.into(),
1522                        remote_process_groups: vec![],
1523                    }),
1524                    WorkerMessage::CreateDeviceMesh {
1525                        result: 5.into(),
1526                        names: vec!["x".into()],
1527                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
1528                    },
1529                    WorkerMessage::CallFunction(CallFunctionParams {
1530                        seq: 2.into(),
1531                        results: vec![Some(6.into())],
1532                        mutates: vec![],
1533                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
1534                        args_kwargs: ArgsKwargs::from_wire_values(
1535                            vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
1536                            HashMap::new(),
1537                        )
1538                        .unwrap(),
1539                        stream: 1.into(),
1540                        remote_process_groups: vec![],
1541                    }),
1542                    WorkerMessage::CallFunction(CallFunctionParams {
1543                        seq: 4.into(),
1544                        results: vec![Some(7.into())],
1545                        mutates: vec![],
1546                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
1547                            .into(),
1548                        args_kwargs: ArgsKwargs::from_wire_values(
1549                            vec![scalar.into()],
1550                            HashMap::new(),
1551                        )
1552                        .unwrap(),
1553                        stream: 1.into(),
1554                        remote_process_groups: vec![],
1555                    }),
1556                    WorkerMessage::CallFunction(CallFunctionParams {
1557                        seq: 5.into(),
1558                        results: vec![Some(8.into())],
1559                        mutates: vec![],
1560                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
1561                        args_kwargs: ArgsKwargs::from_wire_values(
1562                            vec![layout.into()],
1563                            HashMap::new(),
1564                        )
1565                        .unwrap(),
1566                        stream: 1.into(),
1567                        remote_process_groups: vec![],
1568                    }),
1569                    WorkerMessage::CallFunction(CallFunctionParams {
1570                        seq: 6.into(),
1571                        results: vec![Some(9.into())],
1572                        mutates: vec![],
1573                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
1574                        args_kwargs: ArgsKwargs::from_wire_values(
1575                            vec![none.into()],
1576                            HashMap::new(),
1577                        )
1578                        .unwrap(),
1579                        stream: 1.into(),
1580                        remote_process_groups: vec![],
1581                    }),
1582                    // Verify that a function that returns `None` matches up with an
1583                    // empty result list.
1584                    WorkerMessage::CallFunction(CallFunctionParams {
1585                        seq: 7.into(),
1586                        results: vec![None],
1587                        mutates: vec![],
1588                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
1589                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1590                        stream: 1.into(),
1591                        remote_process_groups: vec![],
1592                    }),
1593                    WorkerMessage::CallFunction(CallFunctionParams {
1594                        seq: 8.into(),
1595                        results: vec![Some(10.into())],
1596                        mutates: vec![],
1597                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
1598                        args_kwargs: ArgsKwargs::from_wire_values(
1599                            vec![device.into()],
1600                            HashMap::new(),
1601                        )
1602                        .unwrap(),
1603                        stream: 1.into(),
1604                        remote_process_groups: vec![],
1605                    }),
1606                    WorkerMessage::CallFunction(CallFunctionParams {
1607                        seq: 9.into(),
1608                        results: vec![Some(11.into())],
1609                        mutates: vec![],
1610                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
1611                            .into(),
1612                        args_kwargs: ArgsKwargs::from_wire_values(
1613                            vec![memory_format.into()],
1614                            HashMap::new(),
1615                        )
1616                        .unwrap(),
1617                        stream: 1.into(),
1618                        remote_process_groups: vec![],
1619                    }),
1620                    // Test that list of tests can be passes correctly
1621                    WorkerMessage::CallFunction(CallFunctionParams {
1622                        seq: 10.into(),
1623                        results: vec![Some(12.into())],
1624                        mutates: vec![],
1625                        function: "torch.ops.aten.ones.default".into(),
1626                        args_kwargs: ArgsKwargs::from_wire_values(
1627                            vec![WireValue::IntList(vec![2, 3])],
1628                            HashMap::new(),
1629                        )
1630                        .unwrap(),
1631                        stream: 1.into(),
1632                        remote_process_groups: vec![],
1633                    }),
1634                    WorkerMessage::CallFunction(CallFunctionParams {
1635                        seq: 11.into(),
1636                        results: vec![Some(13.into())],
1637                        mutates: vec![],
1638                        function: "torch.ops.aten.stack.default".into(),
1639                        args_kwargs: ArgsKwargs::from_wire_values(
1640                            vec![WireValue::RefList(vec![12.into(), 12.into()])],
1641                            HashMap::new(),
1642                        )
1643                        .unwrap(),
1644                        stream: 1.into(),
1645                        remote_process_groups: vec![],
1646                    }),
1647                ],
1648            )
1649            .await
1650            .unwrap();
1651
1652        let result1: String = worker_handle
1653            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1654            .await
1655            .unwrap()
1656            .unwrap()
1657            .unwrap()
1658            .try_into()
1659            .unwrap();
1660        let result2: String = worker_handle
1661            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
1662            .await
1663            .unwrap()
1664            .unwrap()
1665            .unwrap()
1666            .try_into()
1667            .unwrap();
1668        let result3: i64 = worker_handle
1669            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
1670            .await
1671            .unwrap()
1672            .unwrap()
1673            .unwrap()
1674            .try_into()
1675            .unwrap();
1676        let result4: i64 = worker_handle
1677            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
1678            .await
1679            .unwrap()
1680            .unwrap()
1681            .unwrap()
1682            .try_into()
1683            .unwrap();
1684        worker_handle
1685            .get_ref_unit_tests_only(&client, 7.into(), 1.into())
1686            .await
1687            .unwrap()
1688            .unwrap()
1689            .unwrap();
1690
1691        worker_handle
1692            .get_ref_unit_tests_only(&client, 8.into(), 1.into())
1693            .await
1694            .unwrap()
1695            .unwrap()
1696            .unwrap();
1697
1698        assert_matches!(
1699            worker_handle
1700                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
1701                .await
1702                .unwrap()
1703                .unwrap()
1704                .unwrap(),
1705            WireValue::None(()),
1706        );
1707        worker_handle
1708            .get_ref_unit_tests_only(&client, 10.into(), 1.into())
1709            .await
1710            .unwrap()
1711            .unwrap()
1712            .unwrap();
1713        worker_handle
1714            .get_ref_unit_tests_only(&client, 11.into(), 1.into())
1715            .await
1716            .unwrap()
1717            .unwrap()
1718            .unwrap();
1719
1720        worker_handle.drain_and_stop("test").unwrap();
1721        worker_handle.await;
1722        let error_responses = controller_rx.drain();
1723        assert!(
1724            error_responses.is_empty(),
1725            "Expected no error responses, got: {:#?}",
1726            error_responses
1727        );
1728
1729        assert_eq!(result1, "/fbs/fbc/foo");
1730        assert_eq!(result2, "bar");
1731        assert_eq!(result3, 1);
1732        assert_eq!(result4, 0);
1733
1734        Ok(())
1735    }
1736
    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        // Single local proc: a client endpoint standing in for the controller,
        // plus one single-rank worker with no CUDA device.
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        // Create two streams, set refs 2 and 3 on stream 0 and ref 4 on
        // stream 1, then delete refs 2 and 4 (one from each stream) in the
        // same command group. Ref 3 must survive the deletion.
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        // The non-deleted ref is still fetchable and holds the value we set.
        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // Fetching a deleted ref should yield None rather than a value.
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        // Shut the worker down cleanly before asserting, so a failed
        // assertion doesn't leave the actor running.
        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }
1814
    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        // Single local proc with a client standing in for the controller; we
        // keep the receiving end to inspect the controller-bound messages.
        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        // Two streams so the work below is interleaved across streams.
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        // Queue 100 function calls, alternating between the two streams, so
        // the status request below has real outstanding work to wait on.
        for i in 0..100 {
            // call alternating functions on this stream.
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        // Ask for status at seq 100 (no controller-side flush).
        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        // Exactly one controller message is expected: the status response.
        // Anything extra (e.g. error reports from the calls) is a failure.
        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                // The worker acknowledges with the next sequence number
                // (requested seq 100 -> reported seq 101).
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }
1904
    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        test_setup().unwrap();
        // Two workers in a world of size 2, pinned to distinct device
        // indices, sharing one proc and one controller client.
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn(
                "worker0",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 2,
                        rank: 0,
                        device_index: Some(0),
                        controller_actor: controller_ref.clone(),
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        let worker_handle2 = proc
            .spawn(
                "worker1",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 2,
                        rank: 1,
                        device_index: Some(1),
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();

        // Both ranks must be initialized with the same unique id so they can
        // rendezvous and form one communicator group.
        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        // Success criterion is simply that both inits complete and both
        // workers shut down cleanly.
        worker_handle1.drain_and_stop("test").unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop("test").unwrap();
        worker_handle2.await;
    }
1959
1960    #[allow(dead_code)]
1961    fn get_random_channel_addr() -> ChannelAddr {
1962        let random_string = rand::rng()
1963            .sample_iter(&Alphanumeric)
1964            .take(24)
1965            .map(char::from)
1966            .collect::<String>();
1967        format!("unix!@{random_string}").parse().unwrap()
1968    }
1969}