monarch_tensor_worker/
lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![feature(assert_matches)]
#![feature(duration_constructors)]
#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674 lands, `pyo3::pymethods`
// triggers unsafe-op-in-unsafe-fn warnings.
#![allow(unsafe_op_in_unsafe_fn)]

//! A `hyperactor`-based implementation of a PyTorch worker actor.
//!
//! The worker is responsible for executing PyTorch operations on a local
//! device. It assumes it has exclusive access to device resources, and manages
//! concurrency internally via device-specific constructs (CUDA streams,
//! threads, etc.).
//!
//! This is a port of `monarch/python/controller/worker.py`, but it has gaps
//! due to drift that still needs to be reconciled, mainly:
//! - support for record and replay
//! - debugger support
//! - general drift in existing messages
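//!
//! A minimal sketch of driving a worker on a local proc (mirroring the tests
//! at the bottom of this file; the ref ids and the op are illustrative):
//!
//! ```ignore
//! let proc = Proc::local();
//! let (client, controller_ref, _controller_rx) = proc.attach_actor("controller")?;
//! let worker = proc
//!     .spawn::<WorkerActor>(
//!         "worker",
//!         WorkerParams {
//!             world_size: 1,
//!             rank: 0,
//!             device_index: None,
//!             controller_actor: controller_ref,
//!         },
//!     )
//!     .await?;
//! worker
//!     .command_group(
//!         &client,
//!         vec![
//!             WorkerMessage::CreateStream {
//!                 id: 1.into(),
//!                 stream_creation: StreamCreationMode::UseDefaultStream,
//!             },
//!             WorkerMessage::CallFunction(CallFunctionParams {
//!                 seq: 0.into(),
//!                 results: vec![Some(0.into())],
//!                 mutates: vec![],
//!                 function: "torch.ops.aten.ones.default".into(),
//!                 args: vec![WireValue::IntList(vec![2, 3])],
//!                 kwargs: HashMap::new(),
//!                 stream: 1.into(),
//!                 remote_process_groups: vec![],
//!             }),
//!             WorkerMessage::Exit { error: None },
//!         ],
//!     )
//!     .await?;
//! ```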

pub mod bootstrap;
mod borrow;
mod comm;
pub mod device_mesh;
pub mod pipe;
pub mod py_pipe;
pub mod stream;
pub mod test_util;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use async_trait::async_trait;
use borrow::Borrow;
use comm::CommMessageClient;
use comm::CommParams;
use comm::NcclCommActor;
use derive_more::TryInto;
use device_mesh::DeviceMesh;
use futures::future::try_join_all;
use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::actor::ActorHandle;
use hyperactor::cap;
use hyperactor::reference::ActorId;
use hyperactor_mesh::comm::multicast::CastInfo;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_hyperactor::shape::PyShape;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Py;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

#[derive(Debug)]
struct RemoteProcessGroupState {
    device_mesh_ref: Ref,
    dims: SortedVec<String>,
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}

impl RemoteProcessGroupState {
    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
        Self {
            device_mesh_ref,
            dims,
            comms: HashMap::new(),
        }
    }
}

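/// State machine for a recording. A recording arrives as a chunked sequence of
/// `DefineRecording` messages (see
/// [`WorkerMessageHandler::define_recording`]): commands accumulate in
/// `PartialRecording` until the final chunk arrives, at which point they are
/// replayed and the entry becomes `CompleteRecording`. An illustrative
/// sequence with `ntotal_messages = 3`:
///
/// ```text
/// DefineRecording { index: 0, .. } -> PartialRecording { last_index: 0, .. }
/// DefineRecording { index: 1, .. } -> PartialRecording { last_index: 1, .. }
/// DefineRecording { index: 2, .. } -> replay commands; CompleteRecording
/// ```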
#[derive(Debug)]
enum Recording {
    // In the process of receiving DefineRecording messages for this
    // recording.
    PartialRecording {
        // The index of the last DefineRecording message received.
        last_index: usize,
        // The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    // The recording is ready to be run.
    CompleteRecording {
        // The list of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}

/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    device: Option<CudaDevice>,
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps device mesh refs to the mesh itself, plus a map from (stream, dims)
    /// pairs to the concrete communicator actor (and its size) for that slice
    /// of the mesh on that stream.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // A map of comms for this mesh for a given pair of stream and dims.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    borrows: HashMap<u64, Borrow>,
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    /// Pipes created for the worker.
    pipes: HashMap<Ref, ActorHandle<PipeActor>>,
    /// Remembers the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    defining_recording: Option<Ref>,
    respond_with_python_message: bool,
}

impl WorkerActor {
    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
        self.streams
            .get(&stream)
            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
    }

    async fn maybe_add_stream_to_recording(
        &mut self,
        caps: &impl cap::CanSend,
        stream: StreamRef,
    ) -> Result<()> {
        // If we're defining a recording, add the stream to the list of streams that
        // this recording uses, and call define_recording on the stream.
        if let Some(defining_recording) = self.defining_recording {
            let recording = self.recordings.get_mut(&defining_recording).unwrap();
            let fut = match recording {
                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
                Recording::CompleteRecording { streams } => {
                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
                        Ok(self
                            .try_get_stream(stream)?
                            .define_recording(caps, defining_recording))
                    })
                }
            }
            .transpose()?;
            match fut {
                Some(fut) => fut.await,
                None => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Params = WorkerParams;

    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            pipes: HashMap::new(),
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }

    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
}

#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let (rank, shape) = cx.cast_info();
        self.rank = rank;
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let shape: PyShape = shape.into();
            let shape: Py<PyShape> = Py::new(py, shape).unwrap();
            let p: PyPoint = PyPoint::new(rank, shape);
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}

/// Message assigning this worker its rank. The rank and shape come from the
/// actor's cast info; receiving it also switches the worker to responding
/// with Python messages.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}

#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }

    async fn backend_network_point_to_point_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        if !self.streams.contains_key(&from_stream) {
            bail!("invalid from_stream id: {:#?}", from_stream);
        }
        if !self.streams.contains_key(&to_stream) {
            bail!("invalid to_stream id: {:#?}", to_stream);
        }
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call BackendNetworkPointToPointInit before BackendNetworkInit")?;
        let comm = global_comm.split_all(cx, None).await?;
        self.send_recv_comms
            .insert((from_stream, to_stream), Arc::new(comm));
        Ok(())
    }

    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }

    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }

    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
            cx,
            StreamParams {
                world_size: self.world_size,
                rank: self.rank,
                creation_mode,
                id: result,
                device: self.device,
                controller_actor: self.controller_actor.clone(),
                respond_with_python_message: self.respond_with_python_message,
            },
        )
        .await?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }

    async fn create_device_mesh(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    ) -> Result<()> {
        self.device_meshes.insert(
            result,
            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
        );
        Ok(())
    }

    async fn create_remote_process_group(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    ) -> Result<()> {
        self.device_meshes
            .get(&device_mesh)
            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
        match self.remote_process_groups.entry(result) {
            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
                device_mesh,
                SortedVec::from_unsorted(dims),
            )),
            Entry::Occupied(ent) => bail!("remote process group {:?} already created", ent.key()),
        };
        Ok(())
    }

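    // The borrow handlers below implement a fixed lifecycle:
    // borrow_create -> borrow_first_use -> borrow_last_use -> borrow_drop;
    // a borrow created here is expected to see its first/last use before
    // borrow_drop finally removes it from `self.borrows`.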
    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }

    async fn borrow_first_use(
        &mut self,
        cx: &hyperactor::Context<Self>,
        borrow: u64,
    ) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.first_use(cx).await?;
        Ok(())
    }

    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.last_use(cx).await?;
        Ok(())
    }

    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow_id)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;

        borrow.drop(cx).await?;
        self.borrows.remove(&borrow_id);
        Ok(())
    }

    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
        // Fan the delete message out to all streams and check for errors.
        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|s| s.delete_refs(cx, refs.clone())),
        )
        .await?;
        Ok(())
    }

    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }

    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Sort for stable indexing.
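        // E.g. (illustrative) dims ["tp", "dp"] and ["dp", "tp"] both become
        // ["dp", "tp"], matching the key under which SplitComm stored the comm.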
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }

    async fn create_pipe(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        // TODO(agallagher): This is used in the python impl to name the socket
        // path to use for comms, but we don't currently use a named socket.
        _key: String,
        function: ResolvableFunction,
        max_messages: i64,
        device_mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        tracing::debug!("creating pipe {}", result);
        let args: Vec<PyTree<RValue>> = args
            .into_iter()
            .map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
            .collect();
        let kwargs: HashMap<_, PyTree<RValue>> = kwargs
            .into_iter()
            .map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
            .collect();
        let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
            CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
        })?;
        tracing::debug!("creating pipe {}: spawning pipe actor", result);
        // TODO(agallagher): Fix error propagation. (When a pipe that previously
        // errored is read from the pipes map, it should surface as a dependent
        // error in send_value, not as an actor error as it does now.)
        let pipe = PipeActor::spawn(
            cx,
            PipeParams {
                function,
                max_messages,
                ranks: device_mesh.0.ranks(),
                sizes: device_mesh.0.sizes(),
                args,
                kwargs,
            },
        )
        .await?;
        tracing::debug!("created pipe {}", result);

        self.pipes.insert(result, pipe);
        Ok(())
    }

    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

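        // The sender at position i of `from_ranks` is paired with the receiver
        // at position i of `to_ranks`. For example (illustrative), with
        // from_ranks = [0, 1] and to_ranks = [2, 3]: on rank 0, `to_rank` is
        // Some(2) and `from_rank` is None (pure sender); on rank 3, `from_rank`
        // is Some(1) and `to_rank` is None (pure receiver).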
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with the sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }

    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
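        // Both exit codes default to 1 when the variable is unset or not an
        // i32. Setting distinct values (e.g., hypothetically,
        // MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE=113 and
        // MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE=114) lets a supervisor
        // tell self-inflicted failures apart from peer-induced ones.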

        // Exit the worker process if there is an error.
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }

    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;

        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let pipe = if let Some(destination) = destination {
            let pipe = self
                .pipes
                .get(&destination)
                .ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
                .port();
            Some(pipe)
        } else {
            None
        };
        // Resolve the value on the stream, then send the value to the pipe if provided,
        // or back to the controller if not.
        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
                pipe,
            )
            .await
    }

    async fn send_result_of_actor_call(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorCallParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?;
        stream
            .send_result_of_actor_call(cx, cx.self_id().clone(), params)
            .await?;
        Ok(())
    }

    async fn call_actor_method(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorMethodParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.call.stream)?;
        stream.call_actor_method(cx, params).await?;
        Ok(())
    }

    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn split_comm_for_process_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        remote_process_group_ref: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        ensure!(
            self.streams.contains_key(&stream_ref),
            "invalid stream id: {:#?}",
            stream_ref
        );
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitCommForProcessGroup before BackendNetworkInit")?;
        let state = self
            .remote_process_groups
            .get_mut(&remote_process_group_ref)
            .with_context(|| {
                format!(
                    "invalid remote process group id: {:#?}",
                    remote_process_group_ref
                )
            })?;
        match self.device_meshes.get_mut(&state.device_mesh_ref) {
            Some((device_mesh, _)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let entry = match state.comms.entry(stream_ref) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => bail!(
                        "comm already exists for remote process group {:#?} on stream {:#?}",
                        remote_process_group_ref,
                        stream_ref,
                    ),
                };
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                entry.insert(Arc::new(split_comm));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn pipe_recv(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: Ref,
        stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream).await?;

        // Get a port for the pipe.
        let pipe = self
            .pipes
            .get(&pipe)
            .ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
        let pipe = pipe.port();
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;
        // Push result into the stream.
        stream.set_value(cx, seq, results, pipe).await
    }

    async fn set_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        stream.set_ref_unit_tests_only(cx, reference, value).await
    }

    async fn get_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        ref_id: Ref,
        stream: StreamRef,
    ) -> Result<Option<Result<WireValue, String>>> {
        let stream = self.try_get_stream(stream)?;
        Ok(stream
            .get_ref_unit_tests_only(cx, ref_id.clone())
            .await?
            .map(|o| Ok(o?)))
    }

    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "Got DefineRecording message with index = {:?}, but \
                            last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands.into_iter());
                }
            },
        };

        if index < ntotal_messages - 1 {
            return Ok(());
        }
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }

    async fn recording_formal(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        argument_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_formal(cx, result, argument_index)
            .await
    }

    async fn recording_result(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        output_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_result(cx, result, output_index)
            .await
    }

    async fn call_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_none());
        let recording_ref = recording;
        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
            "could not find recording: {:#?}",
            recording
        ))?;
        match recording {
            Recording::PartialRecording { .. } => {
                bail!("cannot call recording because it is incomplete")
            }
            Recording::CompleteRecording { streams } => try_join_all(
                streams
                    .iter()
                    .map(|stream| self.try_get_stream(*stream))
                    .collect::<Result<Vec<_>>>()?
                    .into_iter()
                    .map(|stream| {
                        stream.call_recording(
                            cx,
                            seq,
                            recording_ref,
                            results.clone(),
                            actuals.clone(),
                        )
                    }),
            )
            .await
            .map(|_| ()),
        }
    }
}
1176
1177#[cfg(test)]
1178mod tests {
1179    use std::assert_matches::assert_matches;
1180    use std::process::Stdio;
1181
1182    use anyhow::Result;
1183    use hyperactor::Mailbox;
1184    use hyperactor::Named;
1185    use hyperactor::WorldId;
1186    use hyperactor::actor::ActorStatus;
1187    use hyperactor::channel::ChannelAddr;
1188    use hyperactor::id;
1189    use hyperactor::mailbox::open_port;
1190    use hyperactor::proc::Proc;
1191    use hyperactor_multiprocess::System;
1192    use hyperactor_multiprocess::proc_actor::Environment;
1193    use hyperactor_multiprocess::proc_actor::ProcActor;
1194    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
1195    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
1196    use hyperactor_multiprocess::system_actor::Shape;
1197    use hyperactor_multiprocess::system_actor::SystemMessageClient;
1198    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
1199    use hyperactor_multiprocess::system_actor::WorldStatus;
1200    use monarch_messages::controller::ControllerMessage;
1201    use monarch_messages::controller::WorkerError;
1202    use monarch_messages::worker::WorkerMessageClient;
1203    use monarch_types::PickledPyObject;
1204    use monarch_types::PyTree;
1205    use pyo3::IntoPyObjectExt;
1206    use pyo3::Python;
1207    use pyo3::prelude::*;
1208    use pyo3::types::PyList;
1209    use pyo3::types::PyString;
1210    use rand::Rng;
1211    use rand::distributions::Alphanumeric;
1212    use timed_test::async_timed_test;
1213    use tokio::io::BufReader;
1214    use tokio::process::Command;
1215    use tokio_retry::Retry;
1216    use tokio_retry::strategy::FixedInterval;
1217    use torch_sys::Device;
1218    use torch_sys::DeviceIndex;
1219    use torch_sys::MemoryFormat;
1220
1221    use super::*;
1222    use crate::test_util::test_setup;
1223
1224    #[async_timed_test(timeout_secs = 60)]
1225    async fn basic_worker() -> Result<()> {
1226        test_setup()?;
1227
1228        let proc = Proc::local();
1229        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1230
1231        let worker_handle = proc
1232            .spawn::<WorkerActor>(
1233                "worker",
1234                WorkerParams {
1235                    world_size: 1,
1236                    rank: 0,
1237                    device_index: None,
1238                    controller_actor: controller_ref,
1239                },
1240            )
1241            .await
1242            .unwrap();
1243        worker_handle
1244            .command_group(
1245                &client,
1246                vec![
1247                    WorkerMessage::CreateStream {
1248                        id: 1.into(),
1249                        stream_creation: StreamCreationMode::UseDefaultStream,
1250                    },
1251                    WorkerMessage::CallFunction(CallFunctionParams {
1252                        seq: 0.into(),
1253                        results: vec![Some(0.into())],
1254                        mutates: vec![],
1255                        function: "torch.ops.aten.ones.default".into(),
1256                        args: vec![WireValue::IntList(vec![2, 3])],
1257                        kwargs: HashMap::new(),
1258                        stream: 1.into(),
1259                        remote_process_groups: vec![],
1260                    }),
1261                    WorkerMessage::CallFunction(CallFunctionParams {
1262                        seq: 2.into(),
1263                        results: vec![Some(Ref { id: 2 })],
1264                        mutates: vec![0.into()],
1265                        function: "torch.ops.aten.sub_.Scalar".into(),
1266                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
1267                        kwargs: HashMap::new(),
1268                        stream: 1.into(),
1269                        remote_process_groups: vec![],
1270                    }),
1271                    WorkerMessage::CallFunction(CallFunctionParams {
1272                        seq: 3.into(),
1273                        results: vec![Some(Ref { id: 3 })],
1274                        mutates: vec![],
1275                        function: "torch.ops.aten.zeros.default".into(),
1276                        args: vec![WireValue::IntList(vec![2, 3])],
1277                        kwargs: HashMap::new(),
1278                        stream: 1.into(),
1279                        remote_process_groups: vec![],
1280                    }),
1281                    WorkerMessage::CallFunction(CallFunctionParams {
1282                        seq: 4.into(),
1283                        results: vec![Some(Ref { id: 4 })],
1284                        mutates: vec![],
1285                        function: "torch.ops.aten.allclose.default".into(),
1286                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
1287                        kwargs: HashMap::new(),
1288                        stream: 1.into(),
1289                        remote_process_groups: vec![],
1290                    }),
1291                ],
1292            )
1293            .await
1294            .unwrap();
1295
1296        let result: bool = worker_handle
1297            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1298            .await
1299            .unwrap()
1300            .unwrap()
1301            .unwrap()
1302            .try_into()
1303            .unwrap();
1304        worker_handle.drain_and_stop().unwrap();
1305        worker_handle.await;
1306        let error_responses = controller_rx.drain();
1307        assert!(
1308            error_responses.is_empty(),
1309            "Expected no error responses, got: {:#?}",
1310            error_responses
1311        );
1312        assert!(result);
1313
1314        Ok(())
1315    }
1316
1317    #[async_timed_test(timeout_secs = 60)]
1318    async fn error_sends_response() -> Result<()> {
1319        test_setup()?;
1320
1321        let proc = Proc::local();
1322        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1323
1324        let worker_handle = proc
1325            .spawn::<WorkerActor>(
1326                "worker",
1327                WorkerParams {
1328                    world_size: 1,
1329                    rank: 0,
1330                    device_index: None,
1331                    controller_actor: controller_ref,
1332                },
1333            )
1334            .await
1335            .unwrap();
1336        worker_handle
1337            .command_group(
1338                &client,
1339                vec![
1340                    WorkerMessage::CreateStream {
1341                        id: 1.into(),
1342                        stream_creation: StreamCreationMode::UseDefaultStream,
1343                    },
1344                    WorkerMessage::CallFunction(CallFunctionParams {
1345                        seq: 0.into(),
1346                        results: vec![Some(0.into())],
1347                        mutates: vec![],
1348                        function: "torch.ops.aten.rand.default".into(),
1349                        args: vec![],
1350                        kwargs: HashMap::new(),
1351                        stream: 1.into(),
1352                        remote_process_groups: vec![],
1353                    }),
1354                    WorkerMessage::Exit { error: None },
1355                ],
1356            )
1357            .await
1358            .unwrap();
1359
1360        worker_handle.drain_and_stop().unwrap();
1361        worker_handle.await;
1362        let response_message = controller_rx.recv().await.unwrap();
1363        match response_message {
1364            ControllerMessage::RemoteFunctionFailed {
1365                seq,
1366                error: WorkerError { backtrace: msg, .. },
1367            } => {
1368                assert_eq!(seq, 0.into());
1369                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
1370            }
1371            _ => panic!("unexpected response {:#?}", response_message),
1372        }
1373
1374        Ok(())
1375    }
1376
1377    #[async_timed_test(timeout_secs = 60)]
1378    async fn mutated_refs_are_updated_with_error() -> Result<()> {
1379        test_setup()?;
1380
1381        let proc = Proc::local();
1382        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1383
1384        let worker_handle = proc
1385            .spawn::<WorkerActor>(
1386                "worker",
1387                WorkerParams {
1388                    world_size: 1,
1389                    rank: 0,
1390                    device_index: None,
1391                    controller_actor: controller_ref,
1392                },
1393            )
1394            .await
1395            .unwrap();
1396        worker_handle
1397            .command_group(
1398                &client,
1399                vec![
1400                    WorkerMessage::CreateStream {
1401                        id: 1.into(),
1402                        stream_creation: StreamCreationMode::UseDefaultStream,
1403                    },
1404                    WorkerMessage::SetRefUnitTestsOnly {
1405                        reference: 0.into(),
1406                        value: WireValue::Int(1),
1407                        stream: 1.into(),
1408                    },
1409                    WorkerMessage::CallFunction(CallFunctionParams {
1410                        seq: 0.into(),
1411                        results: vec![Some(Ref { id: 2 })],
1412                        mutates: vec![0.into()],
1413                        function: "i.dont.exist".into(),
1414                        args: vec![],
1415                        kwargs: HashMap::new(),
1416                        stream: 1.into(),
1417                        remote_process_groups: vec![],
1418                    }),
1419                ],
1420            )
1421            .await
1422            .unwrap();
1423
1424        let result = worker_handle
1425            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1426            .await?;
1427
1428        // Stop/drain worker before asserts to avoid hangs.
1429        worker_handle.drain_and_stop().unwrap();
1430        worker_handle.await;
1431
1432        let mutated_ref = result
1433            .context("no such ref")?
1434            .err()
1435            .context("expected error")?;
1436        assert!(mutated_ref.contains("failed to resolve function"));
1437
1438        let responses = controller_rx.drain();
1439        assert_eq!(
1440            responses.len(),
1441            1,
1442            "Expected one response, got: {:#?}",
1443            responses
1444        );
1445        Ok(())
1446    }
1447
1448    #[async_timed_test(timeout_secs = 60)]
1449    async fn accessing_errored_dependency() -> Result<()> {
1450        test_setup()?;
1451
1452        let proc = Proc::local();
1453        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1454
1455        let worker_handle = proc
1456            .spawn::<WorkerActor>(
1457                "worker",
1458                WorkerParams {
1459                    world_size: 1,
1460                    rank: 0,
1461                    device_index: None,
1462                    controller_actor: controller_ref,
1463                },
1464            )
1465            .await
1466            .unwrap();
1467        worker_handle
1468            .command_group(
1469                &client,
1470                vec![
1471                    WorkerMessage::CreateStream {
1472                        id: 1.into(),
1473                        stream_creation: StreamCreationMode::UseDefaultStream,
1474                    },
1475                    WorkerMessage::CallFunction(CallFunctionParams {
1476                        seq: 0.into(),
1477                        results: vec![Some(0.into())],
1478                        mutates: vec![],
1479                        function: "i.dont.exist".into(),
1480                        args: vec![],
1481                        kwargs: HashMap::new(),
1482                        stream: 1.into(),
1483                        remote_process_groups: vec![],
1484                    }),
1485                    WorkerMessage::CallFunction(CallFunctionParams {
1486                        seq: 1.into(),
1487                        results: vec![Some(1.into())],
1488                        mutates: vec![],
1489                        function: "torch.ops.aten.sub_.Scalar".into(),
1490                        args: vec![WireValue::Ref(0.into())],
1491                        kwargs: HashMap::new(),
1492                        stream: 1.into(),
1493                        remote_process_groups: vec![],
1494                    }),
1495                    WorkerMessage::Exit { error: None },
1496                ],
1497            )
1498            .await
1499            .unwrap();
1500
1501        worker_handle.drain_and_stop().unwrap();
1502        worker_handle.await;
1503
1504        let responses = controller_rx.drain();
1505        assert_eq!(
1506            responses.len(),
1507            1,
1508            "Expected one response, got: {:#?}",
1509            responses
1510        );
1511
1512        match &responses[0] {
1513            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1514                assert_eq!(seq, &0.into())
1515            }
1516            _ => panic!("unexpected response {:#?}", responses[0]),
1517        };
1518        Ok(())
1519    }
1520
    #[async_timed_test(timeout_secs = 60)]
    async fn py_remote_function_calls() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    mesh_ref,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args: vec![split_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args: vec![sort_list.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args: vec![mesh_ref.into(), dim.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args: vec![scalar.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args: vec![layout.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args: vec![none.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Verify that a function that returns `None` matches up with an
                    // empty result list.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args: vec![device.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args: vec![memory_format.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Test that list arguments (e.g. IntList, RefList) are passed correctly.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(
            ScalarType::Float,
            worker_handle
                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_eq!(
            Layout::Strided,
            worker_handle
                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
        assert_eq!(
            device,
            worker_handle
                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::MemoryFormat(MemoryFormat::Contiguous),
        );

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }

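    // Verify that DeleteRefs removes values from the worker's environment: a
    // surviving ref can still be fetched, while a deleted one cannot.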
    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }

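    // Issue 100 calls spread across two streams, then request status with seq
    // 100. The worker should reply with a single Status message carrying seq
    // 101, i.e. the sequence number following the requested one.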
    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        for i in 0..100 {
            // Call the same function, alternating between the two streams.
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }

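    // Smoke test for NCCL backend initialization: two workers (ranks 0 and 1)
    // share one UniqueId and should both complete backend_network_init without
    // error.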
    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn::<WorkerActor>(
                "worker0",
                WorkerParams {
                    world_size: 2,
                    rank: 0,
                    device_index: Some(0),
                    controller_actor: controller_ref.clone(),
                },
            )
            .await
            .unwrap();
        let worker_handle2 = proc
            .spawn::<WorkerActor>(
                "worker1",
                WorkerParams {
                    world_size: 2,
                    rank: 1,
                    device_index: Some(1),
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        worker_handle1.drain_and_stop().unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop().unwrap();
        worker_handle2.await;
    }

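    // SendValue with no destination should route FetchResult messages to the
    // controller: seq 1 fetches the raw tensor, and seq 2 fetches the result
    // of torch.ops.aten.var_mean.default applied to it (two leaves). A third
    // controller message is drained but not inspected here.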
    #[async_timed_test(timeout_secs = 60)]
    async fn send_value() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: Some("torch.ops.aten.var_mean.default".into()),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );

        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, value } => {
                assert_eq!(seq, 2.into());
                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
                assert_eq!(value.leaves().len(), 2);
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, .. } => {
                assert_eq!(seq, 1.into())
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        Ok(())
    }

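    // SendValue with an unresolvable function should surface the failure as an
    // error-valued FetchResult, and the ref it mutates should be poisoned so
    // that fetching it afterwards reports the same resolution error.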
    #[async_timed_test(timeout_secs = 60)]
    async fn send_value_err_result() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let ref_arg: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![Ref { id: 2 }],
                        function: Some("non.existent.function".into()),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![ref_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 2.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 1.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        Ok(())
    }

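    // Round-trip a value through a pipe: a SendValue whose destination is the
    // pipe created by CreatePipe runs the test resolve_value helper and writes
    // its output into the pipe, PipeRecv reads the result back into ref 4, and
    // torch.equal(ref 2, ref 4) confirms it matches the original ones(2, 3)
    // tensor.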
    #[async_timed_test(timeout_secs = 60)]
    async fn pipe_send_recv() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
            PickledPyObject,
            PickledPyObject,
            PickledPyObject,
        ) = Python::with_gil(|py| {
            PyResult::Ok((
                PyList::new(py, [2, 3])?.into_any().try_into()?,
                Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
                Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
            ))
        })?;

        handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CreateDeviceMesh {
                        result: 1.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    // Create a tensor value which we'll send through the pipe.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(2.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreatePipe {
                        result: 3.into(),
                        key: "unused".into(),
                        function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
                        max_messages: 1,
                        mesh: 1.into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: Some(3.into()),
                        mutates: vec![],
                        function: Some(
                            "monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
                        ),
                        args: vec![resolve_value_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                    },
                    WorkerMessage::PipeRecv {
                        seq: 2.into(),
                        results: vec![Some(4.into())],
                        pipe: 3.into(),
                        stream: 0.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(5.into())],
                        mutates: vec![],
                        function: "torch.equal".into(),
                        args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let matches: bool = handle
            .get_ref_unit_tests_only(&client, 5.into(), 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert!(matches);

        handle.drain_and_stop()?;
        assert_matches!(handle.await, ActorStatus::Stopped);

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            0,
            "Expected no responses, got: {:#?}",
            responses
        );

        Ok(())
    }

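    // Build a random abstract unix-domain channel address, e.g.
    // "unix!@x7Gk..." (24 random alphanumeric characters), so tests don't
    // collide on a fixed socket path.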
    fn get_random_channel_addr() -> ChannelAddr {
        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

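    // Poll the system actor (up to 100 attempts, 1s apart) until the given
    // world reports WorldStatus::Live.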
    async fn ensure_world_ready(client: Mailbox, world: WorldId) -> Result<()> {
        tracing::info!("checking whether world {world} is ready");
        let retry_strategy = FixedInterval::from_millis(1000).take(100);
        Retry::spawn(retry_strategy, async || {
            let snapshot = SYSTEM_ACTOR_REF
                .snapshot(&client, SystemSnapshotFilter::default())
                .await?;
            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
            tracing::info!("world status: {:?}", world_snapshot.status);
            match world_snapshot.status {
                WorldStatus::Live => Ok(()),
                _ => Err(anyhow!("world is not live")),
            }
        })
        .await?;
        Ok(())
    }

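    // End-to-end multi-process flow: bootstrap two worker procs from the
    // MONARCH_TENSOR_WORKER_EXE binary, initialize the NCCL backend, create a
    // device mesh and a remote process group over dim "x", split a
    // communicator for it, and run a Python function that uses the process
    // group on both ranks.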
    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        // Spin up a system to manage the test setup.
        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        // Create a fake controller for the workers to talk to.
        let client = System::new(system_addr.clone()).attach().await?;
        let (handle, mut controller_rx) = client.open_port::<ControllerMessage>();
        handle.bind_to(ControllerMessage::port());
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.actor_id().clone());

        // Create the worker world.
        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

        // Bootstrap a proc for each worker.
        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

        // Wait for procs to initialize.
        ensure_world_ready(client.clone(), id!(world)).await?;

        // Spawn workers on each proc.
        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
}