monarch_tensor_worker/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
14#![allow(unsafe_op_in_unsafe_fn)]
15
16//! A `hyperactor`-based implementation of a PyTorch worker actor.
17//!
18//! The worker is responsible for executing PyTorch operations on a local
19//! device. It assumes it has exclusive access to device resources, and manages
20//! concurrency internally via device-specific constructs (CUDA stream, threads,
21//! etc.).
22//!
//! This is a port of `monarch/python/controller/worker.py`, but it has gaps due
//! to drift that still need to be reconciled.
//! These mainly include:
//! - support for record and replay
//! - debugger support
//! - general drift in existing messages
29
30mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor::Actor;
55use hyperactor::ActorRef;
56use hyperactor::Bind;
57use hyperactor::Handler;
58use hyperactor::Named;
59use hyperactor::Unbind;
60use hyperactor::actor::ActorHandle;
61use hyperactor::context;
62use hyperactor::reference::ActorId;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::CallFunctionParams;
73use monarch_messages::worker::Factory;
74use monarch_messages::worker::Reduction;
75use monarch_messages::worker::Ref;
76use monarch_messages::worker::ResolvableFunction;
77use monarch_messages::worker::StreamCreationMode;
78use monarch_messages::worker::StreamRef;
79use monarch_messages::worker::WorkerMessage;
80use monarch_messages::worker::WorkerMessageHandler;
81use monarch_messages::worker::WorkerParams;
82use monarch_types::PyTree;
83use ndslice::Slice;
84use pyo3::Python;
85use pyo3::types::PyAnyMethods;
86use serde::Deserialize;
87use serde::Serialize;
88use sorted_vec::SortedVec;
89use stream::StreamActor;
90use stream::StreamMessageClient;
91use stream::StreamParams;
92use torch_sys::CudaDevice;
93use torch_sys::DeviceIndex;
94use torch_sys::Layout;
95use torch_sys::RValue;
96use torch_sys::ScalarType;
97use torch_sys::TensorCell;
98use torch_sys::factory_zeros;
99use torch_sys_cuda::nccl::NcclConfig;
100use torch_sys_cuda::nccl::ReduceOp;
101use torch_sys_cuda::nccl::UniqueId;
102
/// Bookkeeping for a process group declared via `CreateRemoteProcessGroup`.
/// The actual NCCL communicators are allocated lazily, per stream, by
/// `SplitCommForProcessGroup`.
#[derive(Debug)]
struct RemoteProcessGroupState {
    /// The device mesh this group was declared against.
    device_mesh_ref: Ref,
    /// Mesh dimension names the group spans (sorted for stable keying).
    dims: SortedVec<String>,
    /// Split communicator per stream; populated by `SplitCommForProcessGroup`.
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
109
110impl RemoteProcessGroupState {
111    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
112        Self {
113            device_mesh_ref,
114            dims,
115            comms: HashMap::new(),
116        }
117    }
118}
119
/// Lifecycle of a recording as tracked by the worker: first accumulated from
/// a sequence of DefineRecording messages, then marked complete and runnable.
#[derive(Debug)]
enum Recording {
    // In the process of receiving DefineRecording messages for this
    // recording.
    PartialRecording {
        // The index of the last DefineRecording message received.
        last_index: usize,
        // The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    // The recording is ready to be run.
    CompleteRecording {
        // The list of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}
137
/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    // CUDA device this worker drives; `None` for non-CUDA workers.
    device: Option<CudaDevice>,
    // Stream actors owned by this worker, keyed by their refs.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps streams to the device mesh and a map of dim names to the concrete
    /// communicator actor that represents the dimension for that stream.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // A map for comms for this mesh for a given pair of stream and dims.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    // Total number of workers participating in the job.
    world_size: usize,
    // This worker's rank; can be overwritten later by `AssignRankMessage`.
    rank: usize,
    // Active borrows, keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    // Global NCCL communicator; set by `BackendNetworkInit`.
    comm: Option<ActorHandle<NcclCommActor>>,
    // The controller this worker reports status and results to.
    controller_actor: ActorRef<ControllerActor>,
    /// Remember the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    // Recordings by ref; may still be partial while being defined.
    recordings: HashMap<Ref, Recording>,
    // Ref of the recording currently being defined, if any.
    defining_recording: Option<Ref>,
    // When true, results are sent back as Python messages (enabled by
    // `AssignRankMessage`).
    respond_with_python_message: bool,
}
181
182impl WorkerActor {
183    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
184        self.streams
185            .get(&stream)
186            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
187    }
188
189    async fn maybe_add_stream_to_recording(
190        &mut self,
191        cx: &impl context::Actor,
192        stream: StreamRef,
193    ) -> Result<()> {
194        // If we're defining a recording, add the stream to the list of streams that
195        // this recording uses, and call define_recording on the stream.
196        if let Some(defining_recording) = self.defining_recording {
197            let recording = self.recordings.get_mut(&defining_recording).unwrap();
198            let fut = match recording {
199                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
200                Recording::CompleteRecording { streams } => {
201                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
202                        Ok(self
203                            .try_get_stream(stream)?
204                            .define_recording(cx, defining_recording))
205                    })
206                }
207            }
208            .transpose()?;
209            match fut {
210                Some(fut) => fut.await,
211                None => Ok(()),
212            }
213        } else {
214            Ok(())
215        }
216    }
217}
218
219#[async_trait]
220impl Actor for WorkerActor {
221    type Params = WorkerParams;
222
223    async fn new(
224        WorkerParams {
225            world_size,
226            rank,
227            device_index,
228            controller_actor,
229        }: Self::Params,
230    ) -> Result<Self> {
231        Ok(Self {
232            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
233            streams: HashMap::new(),
234            device_meshes: HashMap::new(),
235            world_size,
236            rank,
237            borrows: HashMap::new(),
238            comm: None,
239            controller_actor,
240            remote_process_groups: HashMap::new(),
241            send_recv_comms: HashMap::new(),
242            recordings: HashMap::new(),
243            defining_recording: None,
244            respond_with_python_message: false,
245        })
246    }
247
248    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
249}
250
251#[async_trait]
252impl Handler<AssignRankMessage> for WorkerActor {
253    async fn handle(
254        &mut self,
255        cx: &hyperactor::Context<Self>,
256        _: AssignRankMessage,
257    ) -> anyhow::Result<()> {
258        let point = cx.cast_point();
259        self.rank = point.rank();
260        self.respond_with_python_message = true;
261        Python::with_gil(|py| {
262            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
263            let p: PyPoint = point.into();
264            mesh_controller
265                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
266                .unwrap();
267        });
268        Ok(())
269    }
270}
271
/// Message that assigns this worker its rank within a cast. On receipt the
/// worker adopts the rank from the cast point and switches to responding
/// with Python messages (see `Handler<AssignRankMessage>` below).
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}
278
#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    /// Entry point for all `WorkerMessage`s: delegate to the
    /// `WorkerMessageHandler` dispatch, which routes each message variant to
    /// the corresponding method on this actor.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}
289
290#[async_trait]
291impl WorkerMessageHandler for WorkerActor {
    /// One-time NCCL setup: build the global communicator for this worker,
    /// run a one-element all_reduce (which appears to act as a startup
    /// barrier/sanity check — the result is discarded), then split one
    /// sub-communicator per stream, in a deterministic order, and hand each
    /// to its stream actor.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }
354
355    async fn backend_network_point_to_point_init(
356        &mut self,
357        cx: &hyperactor::Context<Self>,
358        from_stream: StreamRef,
359        to_stream: StreamRef,
360    ) -> Result<()> {
361        if !self.streams.contains_key(&from_stream) {
362            bail!("invalid from_stream id: {:#?}", from_stream);
363        }
364        if !self.streams.contains_key(&to_stream) {
365            bail!("invalid to_stream id: {:#?}", to_stream);
366        }
367        let global_comm = self
368            .comm
369            .as_ref()
370            .context("tried to call Reduce before BackendNetworkInit")?;
371        let comm = global_comm.split_all(cx, None).await?;
372        self.send_recv_comms
373            .insert((from_stream, to_stream), Arc::new(comm));
374        Ok(())
375    }
376
    /// Dispatch a function call to its target stream, bundling the device
    /// meshes and remote-process-group (mesh, dims, comm) triples the call
    /// may need.
    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        // Torch ops are dispatched with an empty mesh table; everything else
        // gets a clone of every known device mesh.
        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        // Resolve each referenced remote process group. Its comm must already
        // have been split for this stream (SplitCommForProcessGroup).
        // NOTE(review): refs missing from `remote_process_groups` are
        // silently skipped here — confirm that is intended.
        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                // NOTE(review): this error debug-prints the stream actor
                // handle rather than `params.stream`; consider the ref for a
                // clearer message.
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }
425
426    async fn command_group(
427        &mut self,
428        cx: &hyperactor::Context<Self>,
429        params: Vec<WorkerMessage>,
430    ) -> Result<()> {
431        for msg in params {
432            WorkerMessageHandler::handle(self, cx, msg).await?;
433        }
434        Ok(())
435    }
436
437    async fn create_stream(
438        &mut self,
439        cx: &hyperactor::Context<Self>,
440        result: StreamRef,
441        creation_mode: StreamCreationMode,
442    ) -> Result<()> {
443        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
444            cx,
445            StreamParams {
446                world_size: self.world_size,
447                rank: self.rank,
448                creation_mode,
449                id: result,
450                device: self.device,
451                controller_actor: self.controller_actor.clone(),
452                respond_with_python_message: self.respond_with_python_message,
453            },
454        )
455        .await?;
456        self.streams.insert(result, Arc::new(handle));
457        Ok(())
458    }
459
460    async fn create_device_mesh(
461        &mut self,
462        _cx: &hyperactor::Context<Self>,
463        result: Ref,
464        names: Vec<String>,
465        ranks: Slice,
466    ) -> Result<()> {
467        self.device_meshes.insert(
468            result,
469            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
470        );
471        Ok(())
472    }
473
474    async fn create_remote_process_group(
475        &mut self,
476        _cx: &hyperactor::Context<Self>,
477        result: Ref,
478        device_mesh: Ref,
479        dims: Vec<String>,
480    ) -> Result<()> {
481        self.device_meshes
482            .get(&device_mesh)
483            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
484        match self.remote_process_groups.entry(result) {
485            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
486                device_mesh,
487                SortedVec::from_unsorted(dims),
488            )),
489            Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
490        };
491        Ok(())
492    }
493
494    async fn borrow_create(
495        &mut self,
496        cx: &hyperactor::Context<Self>,
497        result: Ref,
498        borrow_id: u64,
499        tensor_ref: Ref,
500        from_stream: StreamRef,
501        to_stream: StreamRef,
502    ) -> Result<()> {
503        self.maybe_add_stream_to_recording(cx, from_stream).await?;
504        self.maybe_add_stream_to_recording(cx, to_stream).await?;
505        let from_stream = self.try_get_stream(from_stream)?.clone();
506        let to_stream = self.try_get_stream(to_stream)?.clone();
507
508        let borrow =
509            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
510        self.borrows.insert(borrow_id, borrow);
511        Ok(())
512    }
513
514    async fn borrow_first_use(
515        &mut self,
516        cx: &hyperactor::Context<Self>,
517        borrow: u64,
518    ) -> Result<()> {
519        let borrow = self
520            .borrows
521            .get_mut(&borrow)
522            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
523
524        borrow.first_use(cx).await?;
525        Ok(())
526    }
527
528    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
529        let borrow = self
530            .borrows
531            .get_mut(&borrow)
532            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
533
534        borrow.last_use(cx).await?;
535        Ok(())
536    }
537
538    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
539        let borrow = self
540            .borrows
541            .get_mut(&borrow_id)
542            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
543
544        borrow.drop(cx).await?;
545        self.borrows.remove(&borrow_id);
546        Ok(())
547    }
548
549    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
550        // Fan the delete message to all streams.
551        // Check for errors.
552        // TODO: this blocks forward progress of the the actor loop while we
553        // wait for the streams to catch up. Once we have a way of spawning
554        // tasks that the actor system can monitor in a non-blocking way, we
555        // should remove this.
556        let _: Vec<()> = try_join_all(
557            self.streams
558                .values()
559                .map(|s| s.delete_refs(cx, refs.clone())),
560        )
561        .await?;
562        Ok(())
563    }
564
565    async fn request_status(
566        &mut self,
567        cx: &hyperactor::Context<Self>,
568        seq: Seq,
569        controller: bool,
570    ) -> Result<()> {
571        // TODO: this blocks forward progress of the the actor loop while we
572        // wait for the streams to catch up. Once we have a way of spawning
573        // tasks that the actor system can monitor in a non-blocking way, we
574        // should remove this.
575        let _: Vec<()> = try_join_all(
576            self.streams
577                .values()
578                .map(|stream| stream.request_status(cx)),
579        )
580        .await?;
581
582        ControllerMessageClient::status(
583            &self.controller_actor,
584            cx,
585            seq.next(),
586            cx.self_id().clone(),
587            controller,
588        )
589        .await?;
590        Ok(())
591    }
592
    /// Run a collective reduction over `dims` of `source_mesh` on the given
    /// stream, using the communicator previously created for that
    /// (stream, dims) pair by `SplitComm`.
    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Sort for stable indexing.
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        // The comm (and the group size recorded with it) must have been
        // created by a prior SplitComm for this exact (stream, dims) key.
        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }
640
    /// Retired message: pipe creation was removed from this worker; any
    /// caller still sending `CreatePipe` hits this panic deliberately.
    async fn create_pipe(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _result: Ref,
        // TODO(agallagher): This is used in the python impl to name the socket
        // path to use for comms, but we don't currently use a named socket.
        _key: String,
        _function: ResolvableFunction,
        _max_messages: i64,
        _device_mesh: Ref,
        _args: Vec<WireValue>,
        _kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        panic!("create_pipe is no longer implemented")
    }
656
    /// Transfer a tensor between two streams using the comm created by
    /// `BackendNetworkPointToPointInit` for this (from_stream, to_stream)
    /// pair.
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        // `to_rank`: our position in `from_ranks` mapped into `to_ranks`
        // (None if this rank is not a sender). `from_rank`: the symmetric
        // mapping for the receive side (None if this rank is not a receiver).
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        // Pick which local stream runs the NCCL op: the receive stream when
        // we only receive, the send stream when we only send (or when both
        // roles share one stream).
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
710
711    async fn exit(
712        &mut self,
713        cx: &hyperactor::Context<Self>,
714        error: Option<(Option<ActorId>, String)>,
715    ) -> Result<()> {
716        for (_, stream) in self.streams.drain() {
717            stream.drain_and_stop()?;
718            Arc::into_inner(stream)
719                .expect("there should be no owners of this stream handle except the worker stream table")
720                .await;
721        }
722
723        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
724            .ok()
725            .and_then(|val| val.parse::<i32>().ok())
726            .unwrap_or(1);
727        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
728            .ok()
729            .and_then(|val| val.parse::<i32>().ok())
730            .unwrap_or(1);
731
732        // Exit the worker process if there is an error.
733        let exit_code = match error {
734            Some((Some(actor_id), reason)) => {
735                tracing::error!(
736                    "stopping the worker, actor {} failed with error: {}",
737                    actor_id,
738                    reason
739                );
740                if *cx.self_id() == actor_id {
741                    self_error_exit_code
742                } else {
743                    peer_error_exit_code
744                }
745            }
746            Some((None, reason)) => {
747                tracing::error!("stopping the worker, reason: {}", reason);
748                1
749            }
750            None => 0,
751        };
752
753        if exit_code != 0 {
754            tracing::info!("stopping the worker process, exit code: {}", exit_code);
755            std::process::exit(exit_code);
756        }
757        cx.stop()?;
758        Ok(())
759    }
760
    /// Resolve a value on a stream — optionally by invoking `function` — and
    /// send the result back to the controller. Pipe destinations are no
    /// longer supported and panic.
    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        // Resolve the stream.
        let stream = self.try_get_stream(stream)?;

        // As in call_function: no function or a torch op means the meshes
        // are not needed, so skip cloning the mesh table.
        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        if destination.is_some() {
            panic!("send_value with pipe destination is no longer implemented")
        }

        // Resolve the value on the stream, then send the value back to the controller.
        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
            )
            .await
    }
802
803    async fn send_result_of_actor_call(
804        &mut self,
805        cx: &hyperactor::Context<Self>,
806        params: ActorCallParams,
807    ) -> Result<()> {
808        let stream = self.try_get_stream(params.stream)?;
809        stream.send_result_of_actor_call(cx, params).await?;
810        Ok(())
811    }
812    async fn call_actor_method(
813        &mut self,
814        cx: &hyperactor::Context<Self>,
815        params: ActorMethodParams,
816    ) -> Result<()> {
817        let stream = self.try_get_stream(params.call.stream)?;
818        stream.call_actor_method(cx, params).await?;
819        Ok(())
820    }
    /// Split a sub-communicator off the global one for a (stream, dims)
    /// slice of a device mesh. Every rank must participate in the split,
    /// even ranks outside the new group (they pass an empty rank list).
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                // Record the group size alongside the comm; `reduce` reads it.
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }
871
872    async fn split_comm_for_process_group(
873        &mut self,
874        cx: &hyperactor::Context<Self>,
875        remote_process_group_ref: Ref,
876        stream_ref: StreamRef,
877        config: Option<NcclConfig>,
878    ) -> Result<()> {
879        ensure!(
880            self.streams.contains_key(&stream_ref),
881            "invalid stream id: {:#?}",
882            stream_ref
883        );
884        let global_comm = self
885            .comm
886            .as_ref()
887            .context("tried to call SplitComm before BackendNetworkInit")?;
888        let state = self
889            .remote_process_groups
890            .get_mut(&remote_process_group_ref)
891            .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
892        match self.device_meshes.get_mut(&state.device_mesh_ref) {
893            Some((device_mesh, _)) => {
894                // This rank is in the group to be split off. Split a new
895                // communicator for it off from the global communicator.
896                let entry = match state.comms.entry(stream_ref) {
897                    Entry::Vacant(entry) => entry,
898                    Entry::Occupied(_) => bail!(
899                        "comm already exists for remote process group {:#?} on stream {:#?}",
900                        remote_process_group_ref,
901                        stream_ref,
902                    ),
903                };
904                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
905                let split_comm = global_comm
906                    .split_from(
907                        cx,
908                        ranks_for_group
909                            .into_iter()
910                            .map(|v| v.clone().try_into())
911                            .collect::<Result<Vec<_>, _>>()?,
912                        config,
913                    )
914                    .await?
915                    .context("split comm should include self rank")?;
916                entry.insert(Arc::new(split_comm));
917            }
918            None => {
919                // This rank is not in the group to be split off. We still need to
920                // participate in the commSplit call, however.
921                global_comm.split_from(cx, vec![], config).await?;
922            }
923        }
924        Ok(())
925    }
926
927    async fn pipe_recv(
928        &mut self,
929        _cx: &hyperactor::Context<Self>,
930        _seq: Seq,
931        _results: Vec<Option<Ref>>,
932        _pipe: Ref,
933        _stream: StreamRef,
934    ) -> Result<()> {
935        panic!("pipe_recv is no longer implemented")
936    }
937
938    async fn set_ref_unit_tests_only(
939        &mut self,
940        cx: &hyperactor::Context<Self>,
941        reference: Ref,
942        value: WireValue,
943        stream: StreamRef,
944    ) -> Result<()> {
945        let stream = self.try_get_stream(stream)?;
946
947        stream.set_ref_unit_tests_only(cx, reference, value).await
948    }
949
950    async fn get_ref_unit_tests_only(
951        &mut self,
952        cx: &hyperactor::Context<Self>,
953        ref_id: Ref,
954        stream: StreamRef,
955    ) -> Result<Option<Result<WireValue, String>>> {
956        let stream = self.try_get_stream(stream)?;
957        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
958    }
959
960    async fn define_recording(
961        &mut self,
962        cx: &hyperactor::Context<Self>,
963        result: Ref,
964        _nresults: usize,
965        _nformals: usize,
966        commands: Vec<WorkerMessage>,
967        ntotal_messages: usize,
968        index: usize,
969    ) -> Result<()> {
970        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
971            bail!("already defining a different recording");
972        }
973        self.defining_recording = Some(result);
974
975        match self.recordings.entry(result) {
976            Entry::Vacant(entry) => {
977                ensure!(
978                    index == 0,
979                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
980                    index
981                );
982                entry.insert(Recording::PartialRecording {
983                    last_index: 0,
984                    commands,
985                });
986            }
987            Entry::Occupied(mut entry) => match entry.get_mut() {
988                Recording::CompleteRecording { .. } => {
989                    bail!("got DefineRecording message for already complete recording")
990                }
991                Recording::PartialRecording {
992                    last_index,
993                    commands: existing_commands,
994                } => {
995                    ensure!(
996                        index == *last_index + 1,
997                        "Got DefineRecording message with index = {:?}, but \
998                            last seen index for recording is {:?}",
999                        index,
1000                        last_index
1001                    );
1002                    *last_index = index;
1003                    existing_commands.extend(commands.into_iter());
1004                }
1005            },
1006        };
1007
1008        if index < ntotal_messages - 1 {
1009            return Ok(());
1010        }
1011        let commands = match self.recordings.remove(&result).unwrap() {
1012            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
1013            Recording::PartialRecording { commands, .. } => {
1014                self.recordings.insert(
1015                    result,
1016                    Recording::CompleteRecording {
1017                        streams: HashSet::new(),
1018                    },
1019                );
1020                commands
1021            }
1022        };
1023
1024        for command in commands {
1025            WorkerMessageHandler::handle(self, cx, command).await?;
1026        }
1027
1028        match self.recordings.get(&result).unwrap() {
1029            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
1030            Recording::CompleteRecording { streams, .. } => {
1031                for stream in streams {
1032                    self.try_get_stream(*stream)?
1033                        .finalize_recording(cx, result)
1034                        .await?;
1035                }
1036            }
1037        }
1038
1039        self.defining_recording = None;
1040        Ok(())
1041    }
1042
1043    async fn recording_formal(
1044        &mut self,
1045        cx: &hyperactor::Context<Self>,
1046        result: Ref,
1047        argument_index: usize,
1048        stream: StreamRef,
1049    ) -> Result<()> {
1050        ensure!(self.defining_recording.is_some());
1051        self.maybe_add_stream_to_recording(cx, stream).await?;
1052        self.try_get_stream(stream)?
1053            .recording_formal(cx, result, argument_index)
1054            .await
1055    }
1056
1057    async fn recording_result(
1058        &mut self,
1059        cx: &hyperactor::Context<Self>,
1060        result: Ref,
1061        output_index: usize,
1062        stream: StreamRef,
1063    ) -> Result<()> {
1064        ensure!(self.defining_recording.is_some());
1065        self.maybe_add_stream_to_recording(cx, stream).await?;
1066        self.try_get_stream(stream)?
1067            .recording_result(cx, result, output_index)
1068            .await
1069    }
1070
1071    async fn call_recording(
1072        &mut self,
1073        cx: &hyperactor::Context<Self>,
1074        seq: Seq,
1075        recording: Ref,
1076        results: Vec<Ref>,
1077        actuals: Vec<Ref>,
1078    ) -> Result<()> {
1079        ensure!(self.defining_recording.is_none());
1080        let recording_ref = recording;
1081        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1082            "could not find recording: {:#?}",
1083            recording
1084        ))?;
1085        match recording {
1086            Recording::PartialRecording { .. } => {
1087                bail!("cannot call recording because it is incomplete")
1088            }
1089            Recording::CompleteRecording { streams } => try_join_all(
1090                streams
1091                    .iter()
1092                    .map(|stream| self.try_get_stream(*stream))
1093                    .collect::<Result<Vec<_>>>()?
1094                    .into_iter()
1095                    .map(|stream| {
1096                        stream.call_recording(
1097                            cx,
1098                            seq,
1099                            recording_ref,
1100                            results.clone(),
1101                            actuals.clone(),
1102                        )
1103                    }),
1104            )
1105            .await
1106            .map(|_| ()),
1107        }
1108    }
1109}
1110
1111#[cfg(test)]
1112mod tests {
1113    use std::assert_matches::assert_matches;
1114    use std::process::Stdio;
1115
1116    use anyhow::Result;
1117    use hyperactor::Instance;
1118    use hyperactor::WorldId;
1119    use hyperactor::actor::ActorStatus;
1120    use hyperactor::channel::ChannelAddr;
1121    use hyperactor::id;
1122    use hyperactor::mailbox::open_port;
1123    use hyperactor::proc::Proc;
1124    use hyperactor_multiprocess::System;
1125    use hyperactor_multiprocess::proc_actor::Environment;
1126    use hyperactor_multiprocess::proc_actor::ProcActor;
1127    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
1128    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
1129    use hyperactor_multiprocess::system_actor::Shape;
1130    use hyperactor_multiprocess::system_actor::SystemMessageClient;
1131    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
1132    use hyperactor_multiprocess::system_actor::WorldStatus;
1133    use monarch_messages::controller::ControllerMessage;
1134    use monarch_messages::controller::WorkerError;
1135    use monarch_messages::worker::WorkerMessageClient;
1136    use monarch_types::PickledPyObject;
1137    use monarch_types::PyTree;
1138    use pyo3::IntoPyObjectExt;
1139    use pyo3::Python;
1140    use pyo3::prelude::*;
1141    use pyo3::types::PyList;
1142    use pyo3::types::PyString;
1143    use rand::Rng;
1144    use rand::distributions::Alphanumeric;
1145    use timed_test::async_timed_test;
1146    use tokio::io::BufReader;
1147    use tokio::process::Command;
1148    use tokio_retry::Retry;
1149    use tokio_retry::strategy::FixedInterval;
1150    use torch_sys::Device;
1151    use torch_sys::DeviceIndex;
1152    use torch_sys::MemoryFormat;
1153
1154    use super::*;
1155    use crate::test_util::test_setup;
1156
    // End-to-end smoke test on a local proc: spawn a worker, create a stream,
    // run a small chain of aten ops, and verify the final comparison result
    // while asserting the controller received no error messages.
    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    // Stream 1 backs all subsequent calls.
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // ref 0 = ones(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // In-place subtract 1 from ref 0 (now all zeros);
                    // `mutates` declares the aliasing.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 3 = zeros(2, 3), the expected value of ref 0.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 4 = allclose(ref 0, ref 3) — checked below.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // Stop the worker before asserting so a failure can't leave it running.
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }
1249
    // A failing remote function call must surface to the controller as a
    // `RemoteFunctionFailed` message carrying the offending seq and an error
    // backtrace.
    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // aten::rand requires a `size` argument; calling it with
                    // no args is the deliberately-triggered failure.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                // The failure must reference the seq of the bad call and
                // carry torch's own argument error.
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }
1309
    // When a call fails, any refs it declared in `mutates` must be poisoned
    // with the error value, and exactly one failure must reach the controller.
    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Seed ref 0 with a valid value so we can observe it
                    // being replaced by an error below.
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    // Unresolvable function that claims to mutate ref 0.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        // Stop/drain worker before asserts to avoid hangs.
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        // Ref 0 should now hold the resolution error, not the Int(1) it was
        // seeded with.
        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }
1380
    // A call that consumes an already-errored ref must not produce a second
    // controller error: only the original failure (seq 0) is reported.
    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // seq 0: unresolvable function — poisons ref 0.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // seq 1: consumes the poisoned ref 0; should propagate
                    // silently rather than report a fresh failure.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        // The single failure must be attributed to seq 0, the original error.
        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }
1453
1454    #[async_timed_test(timeout_secs = 60)]
1455    async fn py_remote_function_calls() -> Result<()> {
1456        test_setup()?;
1457
1458        let proc = Proc::local();
1459        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1460
1461        let worker_handle = proc
1462            .spawn::<WorkerActor>(
1463                "worker",
1464                WorkerParams {
1465                    world_size: 1,
1466                    rank: 0,
1467                    device_index: None,
1468                    controller_actor: controller_ref,
1469                },
1470            )
1471            .await
1472            .unwrap();
1473        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
1474            Python::with_gil(|py| {
1475                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
1476                    .into_any()
1477                    .try_into()?;
1478                let sort_list: PickledPyObject =
1479                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
1480                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
1481                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
1482                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
1483                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
1484                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
1485                let device: PickledPyObject = py
1486                    .import("torch")?
1487                    .getattr("device")?
1488                    .call1(("cuda:1",))?
1489                    .try_into()?;
1490                let memory_format: PickledPyObject = py
1491                    .import("torch")?
1492                    .getattr("contiguous_format")?
1493                    .try_into()?;
1494                PyResult::Ok((
1495                    split_arg,
1496                    sort_list,
1497                    mesh_ref,
1498                    dim,
1499                    layout,
1500                    none,
1501                    scalar,
1502                    device,
1503                    memory_format,
1504                ))
1505            })?;
1506
1507        worker_handle
1508            .command_group(
1509                &client,
1510                vec![
1511                    WorkerMessage::CreateStream {
1512                        id: 1.into(),
1513                        stream_creation: StreamCreationMode::UseDefaultStream,
1514                    },
1515                    WorkerMessage::CallFunction(CallFunctionParams {
1516                        seq: 0.into(),
1517                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
1518                        mutates: vec![],
1519                        function: "os.path.split".into(),
1520                        args: vec![split_arg.into()],
1521                        kwargs: HashMap::new(),
1522                        stream: 1.into(),
1523                        remote_process_groups: vec![],
1524                    }),
1525                    WorkerMessage::CallFunction(CallFunctionParams {
1526                        seq: 2.into(),
1527                        results: vec![Some(4.into()), None, None, None, None],
1528                        mutates: vec![],
1529                        function: "builtins.sorted".into(),
1530                        args: vec![sort_list.into()],
1531                        kwargs: HashMap::new(),
1532                        stream: 1.into(),
1533                        remote_process_groups: vec![],
1534                    }),
1535                    WorkerMessage::CreateDeviceMesh {
1536                        result: 5.into(),
1537                        names: vec!["x".into()],
1538                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
1539                    },
1540                    WorkerMessage::CallFunction(CallFunctionParams {
1541                        seq: 2.into(),
1542                        results: vec![Some(6.into())],
1543                        mutates: vec![],
1544                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
1545                        args: vec![mesh_ref.into(), dim.into()],
1546                        kwargs: HashMap::new(),
1547                        stream: 1.into(),
1548                        remote_process_groups: vec![],
1549                    }),
1550                    WorkerMessage::CallFunction(CallFunctionParams {
1551                        seq: 4.into(),
1552                        results: vec![Some(7.into())],
1553                        mutates: vec![],
1554                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
1555                            .into(),
1556                        args: vec![scalar.into()],
1557                        kwargs: HashMap::new(),
1558                        stream: 1.into(),
1559                        remote_process_groups: vec![],
1560                    }),
1561                    WorkerMessage::CallFunction(CallFunctionParams {
1562                        seq: 5.into(),
1563                        results: vec![Some(8.into())],
1564                        mutates: vec![],
1565                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
1566                        args: vec![layout.into()],
1567                        kwargs: HashMap::new(),
1568                        stream: 1.into(),
1569                        remote_process_groups: vec![],
1570                    }),
1571                    WorkerMessage::CallFunction(CallFunctionParams {
1572                        seq: 6.into(),
1573                        results: vec![Some(9.into())],
1574                        mutates: vec![],
1575                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
1576                        args: vec![none.into()],
1577                        kwargs: HashMap::new(),
1578                        stream: 1.into(),
1579                        remote_process_groups: vec![],
1580                    }),
1581                    // Verify that a function that returns `None` matches up with an
1582                    // empty result list.
1583                    WorkerMessage::CallFunction(CallFunctionParams {
1584                        seq: 7.into(),
1585                        results: vec![None],
1586                        mutates: vec![],
1587                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
1588                        args: vec![],
1589                        kwargs: HashMap::new(),
1590                        stream: 1.into(),
1591                        remote_process_groups: vec![],
1592                    }),
1593                    WorkerMessage::CallFunction(CallFunctionParams {
1594                        seq: 8.into(),
1595                        results: vec![Some(10.into())],
1596                        mutates: vec![],
1597                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
1598                        args: vec![device.into()],
1599                        kwargs: HashMap::new(),
1600                        stream: 1.into(),
1601                        remote_process_groups: vec![],
1602                    }),
1603                    WorkerMessage::CallFunction(CallFunctionParams {
1604                        seq: 9.into(),
1605                        results: vec![Some(11.into())],
1606                        mutates: vec![],
1607                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
1608                            .into(),
1609                        args: vec![memory_format.into()],
1610                        kwargs: HashMap::new(),
1611                        stream: 1.into(),
1612                        remote_process_groups: vec![],
1613                    }),
1614                    // Test that list of tests can be passes correctly
1615                    WorkerMessage::CallFunction(CallFunctionParams {
1616                        seq: 10.into(),
1617                        results: vec![Some(12.into())],
1618                        mutates: vec![],
1619                        function: "torch.ops.aten.ones.default".into(),
1620                        args: vec![WireValue::IntList(vec![2, 3])],
1621                        kwargs: HashMap::new(),
1622                        stream: 1.into(),
1623                        remote_process_groups: vec![],
1624                    }),
1625                    WorkerMessage::CallFunction(CallFunctionParams {
1626                        seq: 11.into(),
1627                        results: vec![Some(13.into())],
1628                        mutates: vec![],
1629                        function: "torch.ops.aten.stack.default".into(),
1630                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
1631                        kwargs: HashMap::new(),
1632                        stream: 1.into(),
1633                        remote_process_groups: vec![],
1634                    }),
1635                ],
1636            )
1637            .await
1638            .unwrap();
1639
1640        let result1: String = worker_handle
1641            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1642            .await
1643            .unwrap()
1644            .unwrap()
1645            .unwrap()
1646            .try_into()
1647            .unwrap();
1648        let result2: String = worker_handle
1649            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
1650            .await
1651            .unwrap()
1652            .unwrap()
1653            .unwrap()
1654            .try_into()
1655            .unwrap();
1656        let result3: i64 = worker_handle
1657            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
1658            .await
1659            .unwrap()
1660            .unwrap()
1661            .unwrap()
1662            .try_into()
1663            .unwrap();
1664        let result4: i64 = worker_handle
1665            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
1666            .await
1667            .unwrap()
1668            .unwrap()
1669            .unwrap()
1670            .try_into()
1671            .unwrap();
1672        assert_eq!(
1673            ScalarType::Float,
1674            worker_handle
1675                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
1676                .await
1677                .unwrap()
1678                .unwrap()
1679                .unwrap()
1680                .try_into()
1681                .unwrap()
1682        );
1683        assert_eq!(
1684            Layout::Strided,
1685            worker_handle
1686                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
1687                .await
1688                .unwrap()
1689                .unwrap()
1690                .unwrap()
1691                .try_into()
1692                .unwrap()
1693        );
1694        assert_matches!(
1695            worker_handle
1696                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
1697                .await
1698                .unwrap()
1699                .unwrap()
1700                .unwrap(),
1701            WireValue::None(()),
1702        );
1703        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
1704        assert_eq!(
1705            device,
1706            worker_handle
1707                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
1708                .await
1709                .unwrap()
1710                .unwrap()
1711                .unwrap()
1712                .try_into()
1713                .unwrap()
1714        );
1715        assert_matches!(
1716            worker_handle
1717                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
1718                .await
1719                .unwrap()
1720                .unwrap()
1721                .unwrap(),
1722            WireValue::MemoryFormat(MemoryFormat::Contiguous),
1723        );
1724
1725        worker_handle.drain_and_stop().unwrap();
1726        worker_handle.await;
1727        let error_responses = controller_rx.drain();
1728        assert!(
1729            error_responses.is_empty(),
1730            "Expected no error responses, got: {:#?}",
1731            error_responses
1732        );
1733
1734        assert_eq!(result1, "/fbs/fbc/foo");
1735        assert_eq!(result2, "bar");
1736        assert_eq!(result3, 1);
1737        assert_eq!(result4, 0);
1738
1739        Ok(())
1740    }
1741
1742    #[async_timed_test(timeout_secs = 60)]
1743    async fn delete_refs() -> Result<()> {
1744        test_setup()?;
1745
1746        let proc = Proc::local();
1747        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1748
1749        let worker_handle = proc
1750            .spawn::<WorkerActor>(
1751                "worker",
1752                WorkerParams {
1753                    world_size: 1,
1754                    rank: 0,
1755                    device_index: None,
1756                    controller_actor: controller_ref,
1757                },
1758            )
1759            .await
1760            .unwrap();
1761        worker_handle
1762            .command_group(
1763                &client,
1764                vec![
1765                    WorkerMessage::CreateStream {
1766                        id: 0.into(),
1767                        stream_creation: StreamCreationMode::CreateNewStream,
1768                    },
1769                    WorkerMessage::CreateStream {
1770                        id: 1.into(),
1771                        stream_creation: StreamCreationMode::CreateNewStream,
1772                    },
1773                    WorkerMessage::SetRefUnitTestsOnly {
1774                        reference: Ref { id: 2 },
1775                        value: WireValue::Bool(false),
1776                        stream: 0.into(),
1777                    },
1778                    WorkerMessage::SetRefUnitTestsOnly {
1779                        reference: Ref { id: 3 },
1780                        value: WireValue::Bool(true),
1781                        stream: 0.into(),
1782                    },
1783                    WorkerMessage::SetRefUnitTestsOnly {
1784                        reference: Ref { id: 4 },
1785                        value: WireValue::Int(0),
1786                        stream: 1.into(),
1787                    },
1788                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1789                ],
1790            )
1791            .await
1792            .unwrap();
1793
1794        let result: bool = worker_handle
1795            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1796            .await
1797            .unwrap()
1798            .unwrap()
1799            .unwrap()
1800            .try_into()
1801            .unwrap();
1802        let fail_result = worker_handle
1803            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1804            .await
1805            .unwrap();
1806
1807        worker_handle.drain_and_stop().unwrap();
1808        worker_handle.await;
1809
1810        assert!(result, "should be able to get a non-deleted ref");
1811        assert!(fail_result.is_none(), "should fail to get a deleted ref");
1812
1813        Ok(())
1814    }
1815
1816    #[async_timed_test(timeout_secs = 60)]
1817    async fn request_status() -> Result<()> {
1818        test_setup()?;
1819
1820        let proc = Proc::local();
1821        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1822
1823        let worker_handle = proc
1824            .spawn::<WorkerActor>(
1825                "worker",
1826                WorkerParams {
1827                    world_size: 1,
1828                    rank: 0,
1829                    device_index: None,
1830                    controller_actor: controller_ref,
1831                },
1832            )
1833            .await
1834            .unwrap();
1835        worker_handle
1836            .command_group(
1837                &client,
1838                vec![
1839                    WorkerMessage::CreateStream {
1840                        id: 0.into(),
1841                        stream_creation: StreamCreationMode::CreateNewStream,
1842                    },
1843                    WorkerMessage::CreateStream {
1844                        id: 1.into(),
1845                        stream_creation: StreamCreationMode::CreateNewStream,
1846                    },
1847                ],
1848            )
1849            .await
1850            .unwrap();
1851
1852        for i in 0..100 {
1853            // call alternating functions on this stream.
1854            worker_handle
1855                .call_function(
1856                    &client,
1857                    CallFunctionParams {
1858                        seq: i.into(),
1859                        results: vec![Some(Ref { id: i + 2 })],
1860                        mutates: vec![],
1861                        function: "torch.ops.aten.ones.default".into(),
1862                        args: vec![WireValue::IntList(vec![2, 3])],
1863                        kwargs: HashMap::new(),
1864                        stream: (i % 2).into(),
1865                        remote_process_groups: vec![],
1866                    },
1867                )
1868                .await
1869                .unwrap();
1870        }
1871
1872        worker_handle
1873            .request_status(&client, 100.into(), false)
1874            .await
1875            .unwrap();
1876
1877        worker_handle.drain_and_stop().unwrap();
1878        worker_handle.await;
1879
1880        let mut responses = controller_rx.drain();
1881        assert_eq!(
1882            responses.len(),
1883            1,
1884            "Expected one response, got: {:#?}",
1885            responses
1886        );
1887
1888        let response = responses.pop().unwrap();
1889        match response {
1890            ControllerMessage::Status { seq, .. } => {
1891                assert_eq!(seq, 101.into())
1892            }
1893            _ => panic!("unexpected response {:#?}", response),
1894        };
1895
1896        Ok(())
1897    }
1898
1899    #[async_timed_test(timeout_secs = 60)]
1900    async fn backend_network_init() {
1901        let proc = Proc::local();
1902        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1903
1904        let worker_handle1 = proc
1905            .spawn::<WorkerActor>(
1906                "worker0",
1907                WorkerParams {
1908                    world_size: 2,
1909                    rank: 0,
1910                    device_index: Some(0),
1911                    controller_actor: controller_ref.clone(),
1912                },
1913            )
1914            .await
1915            .unwrap();
1916        let worker_handle2 = proc
1917            .spawn::<WorkerActor>(
1918                "worker1",
1919                WorkerParams {
1920                    world_size: 2,
1921                    rank: 1,
1922                    device_index: Some(1),
1923                    controller_actor: controller_ref,
1924                },
1925            )
1926            .await
1927            .unwrap();
1928
1929        let unique_id = UniqueId::new().unwrap();
1930        worker_handle1
1931            .backend_network_init(&client, unique_id.clone())
1932            .await
1933            .unwrap();
1934        worker_handle2
1935            .backend_network_init(&client, unique_id)
1936            .await
1937            .unwrap();
1938
1939        worker_handle1.drain_and_stop().unwrap();
1940        worker_handle1.await;
1941        worker_handle2.drain_and_stop().unwrap();
1942        worker_handle2.await;
1943    }
1944
1945    #[async_timed_test(timeout_secs = 60)]
1946    async fn send_value() -> Result<()> {
1947        test_setup()?;
1948
1949        let proc = Proc::local();
1950        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1951
1952        let worker_handle = proc
1953            .spawn::<WorkerActor>(
1954                "worker",
1955                WorkerParams {
1956                    world_size: 1,
1957                    rank: 0,
1958                    device_index: None,
1959                    controller_actor: controller_ref,
1960                },
1961            )
1962            .await
1963            .unwrap();
1964        worker_handle
1965            .command_group(
1966                &client,
1967                vec![
1968                    WorkerMessage::CreateStream {
1969                        id: 1.into(),
1970                        stream_creation: StreamCreationMode::UseDefaultStream,
1971                    },
1972                    WorkerMessage::CallFunction(CallFunctionParams {
1973                        seq: 0.into(),
1974                        results: vec![Some(0.into())],
1975                        mutates: vec![],
1976                        function: "torch.ops.aten.ones.default".into(),
1977                        args: vec![WireValue::IntList(vec![2, 3])],
1978                        kwargs: HashMap::new(),
1979                        stream: 1.into(),
1980                        remote_process_groups: vec![],
1981                    }),
1982                    WorkerMessage::SendValue {
1983                        seq: 1.into(),
1984                        destination: None,
1985                        mutates: vec![],
1986                        function: None,
1987                        args: vec![WireValue::Ref(0.into())],
1988                        kwargs: HashMap::new(),
1989                        stream: 1.into(),
1990                    },
1991                    WorkerMessage::SendValue {
1992                        seq: 2.into(),
1993                        destination: None,
1994                        mutates: vec![],
1995                        function: Some("torch.ops.aten.var_mean.default".into()),
1996                        args: vec![WireValue::Ref(0.into())],
1997                        kwargs: HashMap::new(),
1998                        stream: 1.into(),
1999                    },
2000                    WorkerMessage::Exit { error: None },
2001                ],
2002            )
2003            .await
2004            .unwrap();
2005
2006        worker_handle.drain_and_stop()?;
2007        assert_matches!(worker_handle.await, ActorStatus::Stopped);
2008
2009        let mut responses = controller_rx.drain();
2010        assert_eq!(
2011            responses.len(),
2012            3,
2013            "Expected one response, got: {:#?}",
2014            responses
2015        );
2016
2017        match responses.pop().unwrap() {
2018            ControllerMessage::FetchResult { seq, value } => {
2019                assert_eq!(seq, 2.into());
2020                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
2021                assert_eq!(value.leaves().len(), 2);
2022            }
2023            resp => panic!("unexpected response {:#?}", resp),
2024        };
2025        match responses.pop().unwrap() {
2026            ControllerMessage::FetchResult { seq, .. } => {
2027                assert_eq!(seq, 1.into())
2028            }
2029            resp => panic!("unexpected response {:#?}", resp),
2030        };
2031        Ok(())
2032    }
2033
2034    #[async_timed_test(timeout_secs = 60)]
2035    async fn send_value_err_result() -> Result<()> {
2036        test_setup()?;
2037
2038        let proc = Proc::local();
2039        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
2040
2041        let worker_handle = proc
2042            .spawn::<WorkerActor>(
2043                "worker",
2044                WorkerParams {
2045                    world_size: 1,
2046                    rank: 0,
2047                    device_index: None,
2048                    controller_actor: controller_ref,
2049                },
2050            )
2051            .await
2052            .unwrap();
2053
2054        let ref_arg: PickledPyObject =
2055            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;
2056
2057        worker_handle
2058            .command_group(
2059                &client,
2060                vec![
2061                    WorkerMessage::CreateStream {
2062                        id: 1.into(),
2063                        stream_creation: StreamCreationMode::UseDefaultStream,
2064                    },
2065                    WorkerMessage::SetRefUnitTestsOnly {
2066                        reference: Ref { id: 2 },
2067                        value: WireValue::Bool(false),
2068                        stream: 1.into(),
2069                    },
2070                    WorkerMessage::SendValue {
2071                        seq: 1.into(),
2072                        destination: None,
2073                        mutates: vec![Ref { id: 2 }],
2074                        function: Some("non.existent.function".into()),
2075                        args: vec![],
2076                        kwargs: HashMap::new(),
2077                        stream: 1.into(),
2078                    },
2079                    WorkerMessage::SendValue {
2080                        seq: 2.into(),
2081                        destination: None,
2082                        mutates: vec![],
2083                        function: None,
2084                        args: vec![ref_arg.into()],
2085                        kwargs: HashMap::new(),
2086                        stream: 1.into(),
2087                    },
2088                    WorkerMessage::Exit { error: None },
2089                ],
2090            )
2091            .await
2092            .unwrap();
2093
2094        worker_handle.drain_and_stop()?;
2095        assert_matches!(worker_handle.await, ActorStatus::Stopped);
2096
2097        let mut responses = controller_rx.drain();
2098        assert_eq!(
2099            responses.len(),
2100            3,
2101            "Expected one response, got: {:#?}",
2102            responses
2103        );
2104        match responses.pop() {
2105            Some(ControllerMessage::FetchResult { seq, value }) => {
2106                assert_eq!(seq, 2.into());
2107                assert!(value.is_err());
2108                assert!(
2109                    value
2110                        .unwrap_err()
2111                        .backtrace
2112                        .contains("failed to resolve function")
2113                );
2114            }
2115            _ => panic!("unexpected response {:#?}", responses),
2116        }
2117        match responses.pop() {
2118            Some(ControllerMessage::FetchResult { seq, value }) => {
2119                assert_eq!(seq, 1.into());
2120                assert!(value.is_err());
2121                assert!(
2122                    value
2123                        .unwrap_err()
2124                        .backtrace
2125                        .contains("failed to resolve function")
2126                );
2127            }
2128            _ => panic!("unexpected response {:#?}", responses),
2129        }
2130        Ok(())
2131    }
2132
2133    fn get_random_channel_addr() -> ChannelAddr {
2134        let random_string = rand::thread_rng()
2135            .sample_iter(&Alphanumeric)
2136            .take(24)
2137            .map(char::from)
2138            .collect::<String>();
2139        format!("unix!@{random_string}").parse().unwrap()
2140    }
2141
2142    async fn ensure_world_ready(client: &Instance<()>, world: WorldId) -> Result<()> {
2143        tracing::info!("checking whether world {world} is ready");
2144        let retry_strategy = FixedInterval::from_millis(1000).take(100);
2145        Retry::spawn(retry_strategy, async || {
2146            let snapshot = SYSTEM_ACTOR_REF
2147                .snapshot(&client, SystemSnapshotFilter::default())
2148                .await?;
2149            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
2150            tracing::info!("world status: {:?}", world_snapshot.status);
2151            match world_snapshot.status {
2152                WorldStatus::Live => Ok(()),
2153                _ => Err(anyhow!("world is not live")),
2154            }
2155        })
2156        .await?;
2157        Ok(())
2158    }
2159
    /// End-to-end test: bootstrap two worker OS processes into a two-rank
    /// world, spawn a `WorkerActor` on each, create a remote process group
    /// spanning both ranks, and exercise it through a python helper function.
    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        // Spin up a system to manage the test setup.
        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        // Create a fake controller for the workers to talk to.
        let client = System::new(system_addr.clone()).attach().await?;
        let (_handle, mut controller_rx) = client.bind_actor_port::<ControllerMessage>();
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.self_id().clone());

        // Create the worker world
        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

        // Bootstrap a proc for each worker
        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            // Launch the worker binary (path supplied via the
            // MONARCH_TENSOR_WORKER_EXE env var) and point it at our system
            // address. `kill_on_drop` ensures the child dies with the test.
            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

            // Mirror the child's stdout onto our stderr so its logs show up
            // in the test output.
            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

        // Wait for procs to initialize

        ensure_world_ready(&client, id!(world)).await?;

        // Spawn workers on each proc
        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        // Block until every rank has reported back through the spawn port.
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

        // Pickled python handle to ref 2 (the remote process group created
        // below) so it can be passed to the python test helper.
        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        // Shared sequence for both ranks: init the backend with a common
        // unique id, build a 1-D device mesh over both ranks, derive a remote
        // process group from it, split a communicator for it, then call the
        // python helper with the group as argument (result stored as ref 3).
        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        // Both ranks must receive the same message sequence.
        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        // Fetch ref 3 from rank 0; the double unwrap asserts the helper ran
        // and produced a value rather than an error.
        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
2307}