monarch_tensor_worker/
lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![feature(assert_matches)]
#![feature(duration_constructors)]
#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674 lands, `pyo3::pymethods`
// trigger unsafe-op-in-unsafe-fn warnings.
#![allow(unsafe_op_in_unsafe_fn)]

//! A `hyperactor`-based implementation of a PyTorch worker actor.
//!
//! The worker is responsible for executing PyTorch operations on a local
//! device. It assumes it has exclusive access to device resources, and manages
//! concurrency internally via device-specific constructs (CUDA streams,
//! threads, etc.).
//!
//! This is a port of `monarch/python/controller/worker.py`, but it has gaps
//! due to drift that still needs to be reconciled, mainly:
//! - support for record and replay
//! - debugger support
//! - general drift in existing messages

pub mod bootstrap;
mod borrow;
mod comm;
pub mod device_mesh;
pub mod pipe;
pub mod py_pipe;
pub mod stream;
pub mod test_util;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use async_trait::async_trait;
use borrow::Borrow;
use comm::CommMessageClient;
use comm::CommParams;
use comm::NcclCommActor;
use derive_more::TryInto;
use device_mesh::DeviceMesh;
use futures::future::try_join_all;
use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::actor::ActorHandle;
use hyperactor::context;
use hyperactor::reference::ActorId;
use hyperactor_mesh::comm::multicast::CastInfo;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

#[derive(Debug)]
struct RemoteProcessGroupState {
    device_mesh_ref: Ref,
    dims: SortedVec<String>,
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}

impl RemoteProcessGroupState {
    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
        Self {
            device_mesh_ref,
            dims,
            comms: HashMap::new(),
        }
    }
}

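/// Assembly state for a recording, which arrives as a sequence of chunked
/// `DefineRecording` messages.
///
/// Illustrative sketch (hedged: `send_define_recording` is a stand-in for
/// delivering the corresponding `WorkerMessage::DefineRecording`; see the
/// `define_recording` handler below) of a recording split into three chunks:
///
/// ```ignore
/// for (index, commands) in chunks.into_iter().enumerate() {
///     // Chunks must arrive with consecutive indices; the final chunk
///     // (index == ntotal_messages - 1) runs the accumulated commands and
///     // flips the state to CompleteRecording.
///     send_define_recording(recording_ref, commands, /* ntotal_messages */ 3, index);
/// }
/// ```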
#[derive(Debug)]
enum Recording {
    // In the process of receiving DefineRecording messages for this
    // recording.
    PartialRecording {
        // The index of the last DefineRecording message received.
        last_index: usize,
        // The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    // The recording is ready to be run.
    CompleteRecording {
        // The list of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}

/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
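///
/// # Example
///
/// A minimal sketch (marked `ignore`; it mirrors the unit tests at the bottom
/// of this file) of spawning a worker on a local proc and sending it a first
/// command:
///
/// ```ignore
/// let proc = Proc::local();
/// let (client, controller_ref, _controller_rx) = proc.attach_actor("controller")?;
/// let worker = proc
///     .spawn::<WorkerActor>(
///         "worker",
///         WorkerParams {
///             world_size: 1,
///             rank: 0,
///             device_index: None, // CPU-only; use Some(i) to bind CUDA device i
///             controller_actor: controller_ref,
///         },
///     )
///     .await?;
/// worker
///     .command_group(
///         &client,
///         vec![WorkerMessage::CreateStream {
///             id: 1.into(),
///             stream_creation: StreamCreationMode::UseDefaultStream,
///         }],
///     )
///     .await?;
/// ```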
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    device: Option<CudaDevice>,
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps device mesh refs to the device mesh, along with a map from
    /// (stream, dim names) pairs to the concrete communicator actor that
    /// represents those dimensions for that stream.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // A map of comms for this mesh for a given pair of stream and dims.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    borrows: HashMap<u64, Borrow>,
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    /// Pipes created for the worker.
    pipes: HashMap<Ref, ActorHandle<PipeActor>>,
    /// Remember the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    defining_recording: Option<Ref>,
    respond_with_python_message: bool,
}

impl WorkerActor {
    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
        self.streams
            .get(&stream)
            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
    }

    async fn maybe_add_stream_to_recording(
        &mut self,
        cx: &impl context::Actor,
        stream: StreamRef,
    ) -> Result<()> {
        // If we're defining a recording, add the stream to the list of streams that
        // this recording uses, and call define_recording on the stream.
        if let Some(defining_recording) = self.defining_recording {
            let recording = self.recordings.get_mut(&defining_recording).unwrap();
            let fut = match recording {
                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
                Recording::CompleteRecording { streams } => {
                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
                        Ok(self
                            .try_get_stream(stream)?
                            .define_recording(cx, defining_recording))
                    })
                }
            }
            .transpose()?;
            match fut {
                Some(fut) => fut.await,
                None => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Params = WorkerParams;

    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            pipes: HashMap::new(),
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }

    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
}

#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}

/// Message that assigns the worker its rank (derived from its cast point) and
/// initializes the worker's Python environment.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}

#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }

    async fn backend_network_point_to_point_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        if !self.streams.contains_key(&from_stream) {
            bail!("invalid from_stream id: {:#?}", from_stream);
        }
        if !self.streams.contains_key(&to_stream) {
            bail!("invalid to_stream id: {:#?}", to_stream);
        }
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call BackendNetworkPointToPointInit before BackendNetworkInit")?;
        let comm = global_comm.split_all(cx, None).await?;
        self.send_recv_comms
            .insert((from_stream, to_stream), Arc::new(comm));
        Ok(())
    }

    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?}, stream {:#?}", params.stream)
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }

    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }

    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
            cx,
            StreamParams {
                world_size: self.world_size,
                rank: self.rank,
                creation_mode,
                id: result,
                device: self.device,
                controller_actor: self.controller_actor.clone(),
                respond_with_python_message: self.respond_with_python_message,
            },
        )
        .await?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }

    async fn create_device_mesh(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    ) -> Result<()> {
        self.device_meshes.insert(
            result,
            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
        );
        Ok(())
    }

    async fn create_remote_process_group(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    ) -> Result<()> {
        self.device_meshes
            .get(&device_mesh)
            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
        match self.remote_process_groups.entry(result) {
            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
                device_mesh,
                SortedVec::from_unsorted(dims),
            )),
            Entry::Occupied(ent) => bail!("remote process group {:?} already created", ent.key()),
        };
        Ok(())
    }
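
    // Note (illustrative): creating a remote process group only registers its
    // state here. The NCCL communicator backing it is allocated lazily, per
    // stream, by a later `SplitCommForProcessGroup`, and `CallFunction` then
    // resolves it through `params.remote_process_groups`.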

    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }

    async fn borrow_first_use(
        &mut self,
        cx: &hyperactor::Context<Self>,
        borrow: u64,
    ) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.first_use(cx).await?;
        Ok(())
    }

    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.last_use(cx).await?;
        Ok(())
    }

    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow_id)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;

        borrow.drop(cx).await?;
        self.borrows.remove(&borrow_id);
        Ok(())
    }

    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
        // Fan the delete message out to all streams and check for errors.
        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|s| s.delete_refs(cx, refs.clone())),
        )
        .await?;
        Ok(())
    }

    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }

    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Sort for stable indexing.
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }

    async fn create_pipe(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        // TODO(agallagher): This is used in the python impl to name the socket
        // path to use for comms, but we don't currently use a named socket.
        _key: String,
        function: ResolvableFunction,
        max_messages: i64,
        device_mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        tracing::debug!("creating pipe {}", result);
        let args: Vec<PyTree<RValue>> = args
            .into_iter()
            .map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
            .collect();
        let kwargs: HashMap<_, PyTree<RValue>> = kwargs
            .into_iter()
            .map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
            .collect();
        let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
            CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
        })?;
        // TODO(agallagher): Fix error propagation. (When the pipe is read from
        // the pipes dict, if it had an error it should cause a dependent error
        // in send_value, not an actor error as it does now.)
        let pipe = PipeActor::spawn(
            cx,
            PipeParams {
                function,
                max_messages,
                ranks: device_mesh.0.ranks(),
                sizes: device_mesh.0.sizes(),
                args,
                kwargs,
            },
        )
        .await?;
        tracing::debug!("created pipe {}", result);

        self.pipes.insert(result, pipe);
        Ok(())
    }

    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

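        // Worked example (illustrative): with from_ranks = {0, 1} and
        // to_ranks = {2, 3}, rank 0 sits at position 0 of from_ranks, so it
        // sends to rank 2 (to_rank = Some(2), from_rank = None); rank 2 sits
        // at position 0 of to_ranks, so it receives from rank 0
        // (to_rank = None, from_rank = Some(0)).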
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with the sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }

    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

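        // Example (illustrative): launching workers with
        //   MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE=71
        //   MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE=72
        // lets a supervisor distinguish self-caused failures from peer-caused
        // ones by exit status; both default to 1 when unset or unparsable.
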
        // Exit the worker process if there is an error.
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }

    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;

        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let pipe = if let Some(destination) = destination {
            let pipe = self
                .pipes
                .get(&destination)
                .ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
                .port();
            Some(pipe)
        } else {
            None
        };
        // Resolve the value on the stream, then send the value to the pipe if provided,
        // or back to the controller if not.
        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
                pipe,
            )
            .await
    }

    async fn send_result_of_actor_call(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorCallParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?;
        stream
            .send_result_of_actor_call(cx, cx.self_id().clone(), params)
            .await?;
        Ok(())
    }

    async fn call_actor_method(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorMethodParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.call.stream)?;
        stream.call_actor_method(cx, params).await?;
        Ok(())
    }

    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }
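
    // Worked example (illustrative): for a 2x4 mesh with dims ["x", "y"],
    // SplitComm on dims = ["y"] splits the global communicator so that each
    // worker ends up in the 4-rank communicator spanning its "y" dimension,
    // stored in the mesh's comm map under (stream_ref, ["y"]).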

    async fn split_comm_for_process_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        remote_process_group_ref: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        ensure!(
            self.streams.contains_key(&stream_ref),
            "invalid stream id: {:#?}",
            stream_ref
        );
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitCommForProcessGroup before BackendNetworkInit")?;
        let state = self
            .remote_process_groups
            .get_mut(&remote_process_group_ref)
            .with_context(|| {
                format!(
                    "invalid remote process group id: {:#?}",
                    remote_process_group_ref
                )
            })?;
        match self.device_meshes.get_mut(&state.device_mesh_ref) {
            Some((device_mesh, _)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let entry = match state.comms.entry(stream_ref) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => bail!(
                        "comm already exists for remote process group {:#?} on stream {:#?}",
                        remote_process_group_ref,
                        stream_ref,
                    ),
                };
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                entry.insert(Arc::new(split_comm));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn pipe_recv(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: Ref,
        stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream).await?;

        // Get a port for the pipe.
        let pipe = self
            .pipes
            .get(&pipe)
            .ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
        let pipe = pipe.port();
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;
        // Push the result into the stream.
        stream.set_value(cx, seq, results, pipe).await
    }

    async fn set_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        stream.set_ref_unit_tests_only(cx, reference, value).await
    }

    async fn get_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        ref_id: Ref,
        stream: StreamRef,
    ) -> Result<Option<Result<WireValue, String>>> {
        let stream = self.try_get_stream(stream)?;
        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
    }

    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "got DefineRecording message with index = {:?}, but \
                            last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands.into_iter());
                }
            },
        };

        if index < ntotal_messages - 1 {
            return Ok(());
        }
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }

    async fn recording_formal(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        argument_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_formal(cx, result, argument_index)
            .await
    }

    async fn recording_result(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        output_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_result(cx, result, output_index)
            .await
    }

    async fn call_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_none());
        let recording_ref = recording;
        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
            "could not find recording: {:#?}",
            recording
        ))?;
        match recording {
            Recording::PartialRecording { .. } => {
                bail!("cannot call recording because it is incomplete")
            }
            Recording::CompleteRecording { streams } => try_join_all(
                streams
                    .iter()
                    .map(|stream| self.try_get_stream(*stream))
                    .collect::<Result<Vec<_>>>()?
                    .into_iter()
                    .map(|stream| {
                        stream.call_recording(
                            cx,
                            seq,
                            recording_ref,
                            results.clone(),
                            actuals.clone(),
                        )
                    }),
            )
            .await
            .map(|_| ()),
        }
    }
}

1170#[cfg(test)]
1171mod tests {
1172    use std::assert_matches::assert_matches;
1173    use std::process::Stdio;
1174
1175    use anyhow::Result;
1176    use hyperactor::Instance;
1177    use hyperactor::Named;
1178    use hyperactor::WorldId;
1179    use hyperactor::actor::ActorStatus;
1180    use hyperactor::channel::ChannelAddr;
1181    use hyperactor::id;
1182    use hyperactor::mailbox::open_port;
1183    use hyperactor::proc::Proc;
1184    use hyperactor_multiprocess::System;
1185    use hyperactor_multiprocess::proc_actor::Environment;
1186    use hyperactor_multiprocess::proc_actor::ProcActor;
1187    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
1188    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
1189    use hyperactor_multiprocess::system_actor::Shape;
1190    use hyperactor_multiprocess::system_actor::SystemMessageClient;
1191    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
1192    use hyperactor_multiprocess::system_actor::WorldStatus;
1193    use monarch_messages::controller::ControllerMessage;
1194    use monarch_messages::controller::WorkerError;
1195    use monarch_messages::worker::WorkerMessageClient;
1196    use monarch_types::PickledPyObject;
1197    use monarch_types::PyTree;
1198    use pyo3::IntoPyObjectExt;
1199    use pyo3::Python;
1200    use pyo3::prelude::*;
1201    use pyo3::types::PyList;
1202    use pyo3::types::PyString;
1203    use rand::Rng;
1204    use rand::distributions::Alphanumeric;
1205    use timed_test::async_timed_test;
1206    use tokio::io::BufReader;
1207    use tokio::process::Command;
1208    use tokio_retry::Retry;
1209    use tokio_retry::strategy::FixedInterval;
1210    use torch_sys::Device;
1211    use torch_sys::DeviceIndex;
1212    use torch_sys::MemoryFormat;
1213
1214    use super::*;
1215    use crate::test_util::test_setup;
1216
1217    #[async_timed_test(timeout_secs = 60)]
1218    async fn basic_worker() -> Result<()> {
1219        test_setup()?;
1220
1221        let proc = Proc::local();
1222        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1223
1224        let worker_handle = proc
1225            .spawn::<WorkerActor>(
1226                "worker",
1227                WorkerParams {
1228                    world_size: 1,
1229                    rank: 0,
1230                    device_index: None,
1231                    controller_actor: controller_ref,
1232                },
1233            )
1234            .await
1235            .unwrap();
1236        worker_handle
1237            .command_group(
1238                &client,
1239                vec![
1240                    WorkerMessage::CreateStream {
1241                        id: 1.into(),
1242                        stream_creation: StreamCreationMode::UseDefaultStream,
1243                    },
1244                    WorkerMessage::CallFunction(CallFunctionParams {
1245                        seq: 0.into(),
1246                        results: vec![Some(0.into())],
1247                        mutates: vec![],
1248                        function: "torch.ops.aten.ones.default".into(),
1249                        args: vec![WireValue::IntList(vec![2, 3])],
1250                        kwargs: HashMap::new(),
1251                        stream: 1.into(),
1252                        remote_process_groups: vec![],
1253                    }),
1254                    WorkerMessage::CallFunction(CallFunctionParams {
1255                        seq: 2.into(),
1256                        results: vec![Some(Ref { id: 2 })],
1257                        mutates: vec![0.into()],
1258                        function: "torch.ops.aten.sub_.Scalar".into(),
1259                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
1260                        kwargs: HashMap::new(),
1261                        stream: 1.into(),
1262                        remote_process_groups: vec![],
1263                    }),
1264                    WorkerMessage::CallFunction(CallFunctionParams {
1265                        seq: 3.into(),
1266                        results: vec![Some(Ref { id: 3 })],
1267                        mutates: vec![],
1268                        function: "torch.ops.aten.zeros.default".into(),
1269                        args: vec![WireValue::IntList(vec![2, 3])],
1270                        kwargs: HashMap::new(),
1271                        stream: 1.into(),
1272                        remote_process_groups: vec![],
1273                    }),
1274                    WorkerMessage::CallFunction(CallFunctionParams {
1275                        seq: 4.into(),
1276                        results: vec![Some(Ref { id: 4 })],
1277                        mutates: vec![],
1278                        function: "torch.ops.aten.allclose.default".into(),
1279                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
1280                        kwargs: HashMap::new(),
1281                        stream: 1.into(),
1282                        remote_process_groups: vec![],
1283                    }),
1284                ],
1285            )
1286            .await
1287            .unwrap();
1288
1289        let result: bool = worker_handle
1290            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1291            .await
1292            .unwrap()
1293            .unwrap()
1294            .unwrap()
1295            .try_into()
1296            .unwrap();
1297        worker_handle.drain_and_stop().unwrap();
1298        worker_handle.await;
1299        let error_responses = controller_rx.drain();
1300        assert!(
1301            error_responses.is_empty(),
1302            "Expected no error responses, got: {:#?}",
1303            error_responses
1304        );
1305        assert!(result);
1306
1307        Ok(())
1308    }
1309
1310    #[async_timed_test(timeout_secs = 60)]
1311    async fn error_sends_response() -> Result<()> {
1312        test_setup()?;
1313
1314        let proc = Proc::local();
1315        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1316
1317        let worker_handle = proc
1318            .spawn::<WorkerActor>(
1319                "worker",
1320                WorkerParams {
1321                    world_size: 1,
1322                    rank: 0,
1323                    device_index: None,
1324                    controller_actor: controller_ref,
1325                },
1326            )
1327            .await
1328            .unwrap();
1329        worker_handle
1330            .command_group(
1331                &client,
1332                vec![
1333                    WorkerMessage::CreateStream {
1334                        id: 1.into(),
1335                        stream_creation: StreamCreationMode::UseDefaultStream,
1336                    },
1337                    WorkerMessage::CallFunction(CallFunctionParams {
1338                        seq: 0.into(),
1339                        results: vec![Some(0.into())],
1340                        mutates: vec![],
1341                        function: "torch.ops.aten.rand.default".into(),
1342                        args: vec![],
1343                        kwargs: HashMap::new(),
1344                        stream: 1.into(),
1345                        remote_process_groups: vec![],
1346                    }),
1347                    WorkerMessage::Exit { error: None },
1348                ],
1349            )
1350            .await
1351            .unwrap();
1352
1353        worker_handle.drain_and_stop().unwrap();
1354        worker_handle.await;
1355        let response_message = controller_rx.recv().await.unwrap();
1356        match response_message {
1357            ControllerMessage::RemoteFunctionFailed {
1358                seq,
1359                error: WorkerError { backtrace: msg, .. },
1360            } => {
1361                assert_eq!(seq, 0.into());
1362                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
1363            }
1364            _ => panic!("unexpected response {:#?}", response_message),
1365        }
1366
1367        Ok(())
1368    }
1369
1370    #[async_timed_test(timeout_secs = 60)]
1371    async fn mutated_refs_are_updated_with_error() -> Result<()> {
1372        test_setup()?;
1373
1374        let proc = Proc::local();
1375        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1376
1377        let worker_handle = proc
1378            .spawn::<WorkerActor>(
1379                "worker",
1380                WorkerParams {
1381                    world_size: 1,
1382                    rank: 0,
1383                    device_index: None,
1384                    controller_actor: controller_ref,
1385                },
1386            )
1387            .await
1388            .unwrap();
1389        worker_handle
1390            .command_group(
1391                &client,
1392                vec![
1393                    WorkerMessage::CreateStream {
1394                        id: 1.into(),
1395                        stream_creation: StreamCreationMode::UseDefaultStream,
1396                    },
1397                    WorkerMessage::SetRefUnitTestsOnly {
1398                        reference: 0.into(),
1399                        value: WireValue::Int(1),
1400                        stream: 1.into(),
1401                    },
1402                    WorkerMessage::CallFunction(CallFunctionParams {
1403                        seq: 0.into(),
1404                        results: vec![Some(Ref { id: 2 })],
1405                        mutates: vec![0.into()],
1406                        function: "i.dont.exist".into(),
1407                        args: vec![],
1408                        kwargs: HashMap::new(),
1409                        stream: 1.into(),
1410                        remote_process_groups: vec![],
1411                    }),
1412                ],
1413            )
1414            .await
1415            .unwrap();
1416
1417        let result = worker_handle
1418            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1419            .await?;
1420
1421        // Stop/drain worker before asserts to avoid hangs.
1422        worker_handle.drain_and_stop().unwrap();
1423        worker_handle.await;
1424
1425        let mutated_ref = result
1426            .context("no such ref")?
1427            .err()
1428            .context("expected error")?;
1429        assert!(mutated_ref.contains("failed to resolve function"));
1430
1431        let responses = controller_rx.drain();
1432        assert_eq!(
1433            responses.len(),
1434            1,
1435            "Expected one response, got: {:#?}",
1436            responses
1437        );
1438        Ok(())
1439    }
1440
1441    #[async_timed_test(timeout_secs = 60)]
1442    async fn accessing_errored_dependency() -> Result<()> {
1443        test_setup()?;
1444
1445        let proc = Proc::local();
1446        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1447
1448        let worker_handle = proc
1449            .spawn::<WorkerActor>(
1450                "worker",
1451                WorkerParams {
1452                    world_size: 1,
1453                    rank: 0,
1454                    device_index: None,
1455                    controller_actor: controller_ref,
1456                },
1457            )
1458            .await
1459            .unwrap();
1460        worker_handle
1461            .command_group(
1462                &client,
1463                vec![
1464                    WorkerMessage::CreateStream {
1465                        id: 1.into(),
1466                        stream_creation: StreamCreationMode::UseDefaultStream,
1467                    },
1468                    WorkerMessage::CallFunction(CallFunctionParams {
1469                        seq: 0.into(),
1470                        results: vec![Some(0.into())],
1471                        mutates: vec![],
1472                        function: "i.dont.exist".into(),
1473                        args: vec![],
1474                        kwargs: HashMap::new(),
1475                        stream: 1.into(),
1476                        remote_process_groups: vec![],
1477                    }),
1478                    WorkerMessage::CallFunction(CallFunctionParams {
1479                        seq: 1.into(),
1480                        results: vec![Some(1.into())],
1481                        mutates: vec![],
1482                        function: "torch.ops.aten.sub_.Scalar".into(),
1483                        args: vec![WireValue::Ref(0.into())],
1484                        kwargs: HashMap::new(),
1485                        stream: 1.into(),
1486                        remote_process_groups: vec![],
1487                    }),
1488                    WorkerMessage::Exit { error: None },
1489                ],
1490            )
1491            .await
1492            .unwrap();
1493
1494        worker_handle.drain_and_stop().unwrap();
1495        worker_handle.await;
1496
1497        let responses = controller_rx.drain();
1498        assert_eq!(
1499            responses.len(),
1500            1,
1501            "Expected one response, got: {:#?}",
1502            responses
1503        );
1504
1505        match &responses[0] {
1506            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1507                assert_eq!(seq, &0.into())
1508            }
1509            _ => panic!("unexpected response {:#?}", responses[0]),
1510        };
1511        Ok(())
1512    }
1513
1514    #[async_timed_test(timeout_secs = 60)]
1515    async fn py_remote_function_calls() -> Result<()> {
1516        test_setup()?;
1517
1518        let proc = Proc::local();
1519        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1520
1521        let worker_handle = proc
1522            .spawn::<WorkerActor>(
1523                "worker",
1524                WorkerParams {
1525                    world_size: 1,
1526                    rank: 0,
1527                    device_index: None,
1528                    controller_actor: controller_ref,
1529                },
1530            )
1531            .await
1532            .unwrap();
        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    mesh_ref,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args: vec![split_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args: vec![sort_list.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args: vec![mesh_ref.into(), dim.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args: vec![scalar.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args: vec![layout.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args: vec![none.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Verify that a function that returns `None` matches up
                    // with a result list containing only a `None` slot.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args: vec![device.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args: vec![memory_format.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Test that list arguments can be passed correctly.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(
            ScalarType::Float,
            worker_handle
                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_eq!(
            Layout::Strided,
            worker_handle
                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
        assert_eq!(
            device,
            worker_handle
                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::MemoryFormat(MemoryFormat::Contiguous),
        );

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }

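    // DeleteRefs should drop values from their owning streams: a deleted ref
    // can no longer be fetched, while untouched refs remain readable.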
    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }

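    // RequestStatus should be answered only after all previously issued
    // operations on every stream have completed; the controller gets back a
    // single Status message carrying the next (requested + 1) sequence number.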
    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        for i in 0..100 {
            // Call the same function, alternating between the two streams.
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }

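    // NCCL backend initialization across two workers on separate CUDA devices;
    // both sides must be handed the same `UniqueId` for the rendezvous to
    // succeed.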
    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn::<WorkerActor>(
                "worker0",
                WorkerParams {
                    world_size: 2,
                    rank: 0,
                    device_index: Some(0),
                    controller_actor: controller_ref.clone(),
                },
            )
            .await
            .unwrap();
        let worker_handle2 = proc
            .spawn::<WorkerActor>(
                "worker1",
                WorkerParams {
                    world_size: 2,
                    rank: 1,
                    device_index: Some(1),
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        worker_handle1.drain_and_stop().unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop().unwrap();
        worker_handle2.await;
    }

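    // SendValue with no destination pipe reports values back to the controller
    // as FetchResult messages, optionally applying a function (here
    // `torch.ops.aten.var_mean.default`) to the value first.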
    #[async_timed_test(timeout_secs = 60)]
    async fn send_value() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: Some("torch.ops.aten.var_mean.default".into()),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );

        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, value } => {
                assert_eq!(seq, 2.into());
                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
                assert_eq!(value.leaves().len(), 2);
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, .. } => {
                assert_eq!(seq, 1.into())
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        Ok(())
    }

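    // A failing SendValue (unresolvable function) should surface as an err
    // FetchResult, and refs in its `mutates` list should have the error
    // propagated to their later uses.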
    #[async_timed_test(timeout_secs = 60)]
    async fn send_value_err_result() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let ref_arg: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![Ref { id: 2 }],
                        function: Some("non.existent.function".into()),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![ref_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 2.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 1.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        Ok(())
    }

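    // Round-trips a tensor through a pipe: SendValue pushes the value into the
    // pipe, PipeRecv pulls it back out, and `torch.equal` checks the received
    // tensor against the original.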
    #[async_timed_test(timeout_secs = 60)]
    async fn pipe_send_recv() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
            PickledPyObject,
            PickledPyObject,
            PickledPyObject,
        ) = Python::with_gil(|py| {
            PyResult::Ok((
                PyList::new(py, [2, 3])?.into_any().try_into()?,
                Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
                Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
            ))
        })?;

        handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CreateDeviceMesh {
                        result: 1.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    // Create a tensor value which we'll send through the pipe.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(2.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreatePipe {
                        result: 3.into(),
                        key: "unused".into(),
                        function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
                        max_messages: 1,
                        mesh: 1.into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: Some(3.into()),
                        mutates: vec![],
                        function: Some(
                            "monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
                        ),
                        args: vec![resolve_value_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                    },
                    WorkerMessage::PipeRecv {
                        seq: 2.into(),
                        results: vec![Some(4.into())],
                        pipe: 3.into(),
                        stream: 0.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(5.into())],
                        mutates: vec![],
                        function: "torch.equal".into(),
                        args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let matches: bool = handle
            .get_ref_unit_tests_only(&client, 5.into(), 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert!(matches);

        handle.drain_and_stop()?;
        assert_matches!(handle.await, ActorStatus::Stopped);

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            0,
            "Expected no responses, got: {:#?}",
            responses
        );

        Ok(())
    }

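    /// Generate a random abstract unix-socket channel address so that
    /// concurrently running tests don't collide on a fixed one. A sketch of
    /// the intended usage (mirroring `remote_process_group` below):
    /// ```ignore
    /// let system_addr = get_random_channel_addr(); // e.g. "unix!@x9K..."
    /// let _system = System::serve(system_addr, timeout, timeout).await?;
    /// ```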
    fn get_random_channel_addr() -> ChannelAddr {
        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

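    /// Poll the system actor until `world` reports `WorldStatus::Live`,
    /// retrying once per second for up to 100 attempts before giving up.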
    async fn ensure_world_ready(client: &Instance<()>, world: WorldId) -> Result<()> {
        tracing::info!("checking whether world {world} is ready");
        let retry_strategy = FixedInterval::from_millis(1000).take(100);
        Retry::spawn(retry_strategy, async || {
            let snapshot = SYSTEM_ACTOR_REF
                .snapshot(&client, SystemSnapshotFilter::default())
                .await?;
            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
            tracing::info!("world status: {:?}", world_snapshot.status);
            match world_snapshot.status {
                WorldStatus::Live => Ok(()),
                _ => Err(anyhow!("world is not live")),
            }
        })
        .await?;
        Ok(())
    }

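    // Full multi-process integration test: serve a system, launch one worker
    // process per rank via the `MONARCH_TENSOR_WORKER_EXE` binary, initialize
    // the NCCL backend network, then create and exercise a remote process
    // group from a Python remote function on both ranks.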
    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        // Spin up a system to manage the test setup.
        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        // Create a fake controller for the workers to talk to.
        let client = System::new(system_addr.clone()).attach().await?;
        let (handle, mut controller_rx) = client.open_port::<ControllerMessage>();
        handle.bind_to(ControllerMessage::port());
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.self_id().clone());

        // Create the worker world
        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

        // Bootstrap a proc for each worker
        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

        // Wait for procs to initialize
        ensure_world_ready(&client, id!(world)).await?;

        // Spawn workers on each proc
        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
}