Skip to main content

monarch_tensor_worker/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674 is available, `pyo3::pymethods`
// triggers unsafe-op-in-unsafe-fn warnings.
14#![allow(unsafe_op_in_unsafe_fn)]
15
16//! A `hyperactor`-based implementation of a PyTorch worker actor.
17//!
18//! The worker is responsible for executing PyTorch operations on a local
19//! device. It assumes it has exclusive access to device resources, and manages
20//! concurrency internally via device-specific constructs (CUDA stream, threads,
21//! etc.).
22//!
23//! This is a port of `monarch/python/controller/worker.py` but does have gaps due
24//! to drift that needs to be reconciled.
25//! This mainly includes:
26//! - Support for record and replay
27//! - debugger support
//! - general drift in existing messages
29
30mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor as reference;
55use hyperactor::Actor;
56use hyperactor::Bind;
57use hyperactor::Handler;
58use hyperactor::RemoteSpawn;
59use hyperactor::Unbind;
60use hyperactor::actor::ActorHandle;
61use hyperactor::context;
62use hyperactor_config::Flattrs;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use monarch_types::ReduceOp;
84use monarch_types::UniqueId;
85use ndslice::Slice;
86use pyo3::Python;
87use pyo3::types::PyAnyMethods;
88use serde::Deserialize;
89use serde::Serialize;
90use sorted_vec::SortedVec;
91use stream::StreamActor;
92use stream::StreamMessageClient;
93use stream::StreamParams;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
/// Book-keeping for a process group registered through
/// `create_remote_process_group`.
#[derive(Debug)]
struct RemoteProcessGroupState {
    /// Ref of the device mesh the process group was created against.
    device_mesh_ref: Ref,
    /// Names of the mesh dimensions the group spans, kept in a `SortedVec`.
    dims: SortedVec<String>,
    /// Communicator actors split off for this group, keyed by the stream they
    /// were split for (filled in by `split_comm_for_process_group`).
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
108
109impl RemoteProcessGroupState {
110    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111        Self {
112            device_mesh_ref,
113            dims,
114            comms: HashMap::new(),
115        }
116    }
117}
118
/// State of a recording tracked by the worker, stored in
/// `WorkerActor::recordings` under the recording's `Ref`.
#[derive(Debug)]
enum Recording {
    /// In the process of receiving DefineRecording messages for this
    /// recording.
    PartialRecording {
        /// The index of the last DefineRecording message received.
        last_index: usize,
        /// The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    /// The recording is ready to be run.
    CompleteRecording {
        /// The set of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}
136
/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
#[derive(Debug)]
#[hyperactor::spawnable]
#[hyperactor::export(
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    /// The CUDA device this worker drives. `None` when no device index was
    /// supplied or the Python runtime reports that CUDA is unavailable.
    device: Option<CudaDevice>,
    /// Stream actors created via `create_stream`, keyed by their ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps each device mesh ref to the mesh itself, together with the
    /// communicators that have been split off for it.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // Comms for this mesh, keyed by (stream, sorted dim names). The
            // `usize` is the number of ranks in the communicator's group.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    /// Total number of workers, as given in `WorkerParams`.
    world_size: usize,
    /// This worker's rank. Initialised from `WorkerParams` and overwritten
    /// when an `AssignRankMessage` is handled.
    rank: usize,
    /// Live borrows, keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    /// The worker-wide communicator, set by `backend_network_init`. Other
    /// communicators are split from it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// The controller that status updates are reported to.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Remember the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    /// Recordings known to this worker, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// The recording currently being defined, if any.
    defining_recording: Option<Ref>,
    /// Forwarded to every stream created by this worker; switched on when an
    /// `AssignRankMessage` is handled.
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182    fn runtime_has_cuda() -> bool {
183        Python::attach(|py| {
184            py.import("torch")
185                .expect("torch must be importable in a worker")
186                .getattr("cuda")
187                .expect("torch.cuda attribute must exist")
188                .call_method0("is_available")
189                .expect("torch.cuda.is_available() must be callable")
190                .extract::<bool>()
191                .expect("torch.cuda.is_available() must return bool")
192        })
193    }
194
195    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
196        self.streams
197            .get(&stream)
198            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
199    }
200
201    async fn maybe_add_stream_to_recording(
202        &mut self,
203        cx: &impl context::Actor,
204        stream: StreamRef,
205    ) -> Result<()> {
206        // If we're defining a recording, add the stream to the list of streams that
207        // this recording uses, and call define_recording on the stream.
208        if let Some(defining_recording) = self.defining_recording {
209            let recording = self.recordings.get_mut(&defining_recording).unwrap();
210            let fut = match recording {
211                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
212                Recording::CompleteRecording { streams } => {
213                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
214                        Ok(self
215                            .try_get_stream(stream)?
216                            .define_recording(cx, defining_recording))
217                    })
218                }
219            }
220            .transpose()?;
221            match fut {
222                Some(fut) => fut.await,
223                None => Ok(()),
224            }
225        } else {
226            Ok(())
227        }
228    }
229}
230
// The impl body is empty: `WorkerActor` overrides nothing from `Actor`.
impl Actor for WorkerActor {}
232
233#[async_trait]
234impl RemoteSpawn for WorkerActor {
235    type Params = WorkerParams;
236
237    async fn new(
238        WorkerParams {
239            world_size,
240            rank,
241            device_index,
242            controller_actor,
243        }: Self::Params,
244        _environment: Flattrs,
245    ) -> Result<Self> {
246        Python::attach(|py| {
247            py.import("monarch.safe_torch").unwrap();
248        });
249        let device = if Self::runtime_has_cuda() {
250            device_index.map(|i| CudaDevice::new(DeviceIndex(i)))
251        } else {
252            None
253        };
254        Ok(Self {
255            device,
256            streams: HashMap::new(),
257            device_meshes: HashMap::new(),
258            world_size,
259            rank,
260            borrows: HashMap::new(),
261            comm: None,
262            controller_actor,
263            remote_process_groups: HashMap::new(),
264            send_recv_comms: HashMap::new(),
265            recordings: HashMap::new(),
266            defining_recording: None,
267            respond_with_python_message: false,
268        })
269    }
270
271    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
272}
273
274#[async_trait]
275impl Handler<AssignRankMessage> for WorkerActor {
276    async fn handle(
277        &mut self,
278        cx: &hyperactor::Context<Self>,
279        _: AssignRankMessage,
280    ) -> anyhow::Result<()> {
281        let point = cx.cast_point();
282        self.rank = point.rank();
283        self.respond_with_python_message = true;
284        Python::attach(|py| {
285            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
286            let p: PyPoint = point.into();
287            mesh_controller
288                .call_method1("_initialize_env", (p, cx.proc().proc_addr().to_string()))
289                .unwrap();
290        });
291        Ok(())
292    }
293}
294
/// Message that tells a worker to take its rank from its position in the
/// cast it was delivered through.
///
/// The `WorkerActor` handler for this message also switches the worker to
/// responding with Python messages and initialises the Python-side
/// environment.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    /// Assign the rank; the message carries no payload.
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
302
303#[async_trait]
304impl Handler<WorkerMessage> for WorkerActor {
305    async fn handle(
306        &mut self,
307        cx: &hyperactor::Context<Self>,
308        message: WorkerMessage,
309    ) -> anyhow::Result<()> {
310        <Self as WorkerMessageHandler>::handle(self, cx, message).await
311    }
312}
313
314#[async_trait]
315impl WorkerMessageHandler for WorkerActor {
    /// Creates the worker-wide communicator and gives every existing stream
    /// its own communicator split from it.
    ///
    /// Does nothing when the worker has no CUDA device.
    ///
    /// # Errors
    ///
    /// Returns an error if creating or spawning the communicator actor, the
    /// initial all-reduce, any split, or any stream's `init_comm` fails.
    ///
    /// # Panics
    ///
    /// Panics if `world_size` or `rank` does not fit the integer type that
    /// `CommParams::New` expects.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let Some(device) = self.device else {
            return Ok(());
        };
        let comm = NcclCommActor::new(CommParams::New {
            device,
            unique_id,
            world_size: self.world_size.try_into().unwrap(),
            rank: self.rank.try_into().unwrap(),
        })
        .await?
        .spawn(cx)?;

        // Run one all-reduce over a single-element zero tensor on the current
        // CUDA stream before any split is issued.
        // NOTE(review): presumably this forces the communicator to be fully
        // set up on all ranks before it is used; confirm.
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        // One split per stream, issued sequentially from this handler.
        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx).await?);
        }
        // Hand the i-th split to the i-th stream (in ref order); the streams
        // may then initialise concurrently.
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits)
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        // Keep the global communicator so later requests can split from it.
        self.comm = Some(comm);

        Ok(())
    }
376
377    async fn backend_network_point_to_point_init(
378        &mut self,
379        cx: &hyperactor::Context<Self>,
380        from_stream: StreamRef,
381        to_stream: StreamRef,
382    ) -> Result<()> {
383        if !self.streams.contains_key(&from_stream) {
384            bail!("invalid from_stream id: {:#?}", from_stream);
385        }
386        if !self.streams.contains_key(&to_stream) {
387            bail!("invalid to_stream id: {:#?}", to_stream);
388        }
389        let Some(global_comm) = self.comm.as_ref() else {
390            return Ok(());
391        };
392        let comm = global_comm.split_all(cx).await?;
393        self.send_recv_comms
394            .insert((from_stream, to_stream), Arc::new(comm));
395        Ok(())
396    }
397
398    async fn call_function(
399        &mut self,
400        cx: &hyperactor::Context<Self>,
401        params: CallFunctionParams,
402    ) -> Result<()> {
403        let stream = self.try_get_stream(params.stream)?.clone();
404        self.maybe_add_stream_to_recording(cx, params.stream)
405            .await?;
406
407        let device_meshes = self
408            .device_meshes
409            .iter()
410            .map(|(k, v)| (k.clone(), v.0.clone()))
411            .collect();
412
413        let mut remote_process_groups = HashMap::new();
414        for remote_process_group_ref in &params.remote_process_groups {
415            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
416                let dims_vec = state.dims.iter().cloned().collect();
417                let (device_mesh, _) = self
418                    .device_meshes
419                    .get(&state.device_mesh_ref)
420                    .ok_or_else(|| {
421                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
422                    })?
423                    .clone();
424                let comm = state.comms
425                    .get(&params.stream)
426                    .ok_or_else(|| {
427                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
428                    })?
429                    .clone();
430                remote_process_groups.insert(
431                    remote_process_group_ref.clone(),
432                    (device_mesh, dims_vec, comm),
433                );
434            }
435        }
436
437        stream
438            .call_function(cx, params, device_meshes, remote_process_groups)
439            .await?;
440
441        Ok(())
442    }
443
444    async fn command_group(
445        &mut self,
446        cx: &hyperactor::Context<Self>,
447        params: Vec<WorkerMessage>,
448    ) -> Result<()> {
449        for msg in params {
450            WorkerMessageHandler::handle(self, cx, msg).await?;
451        }
452        Ok(())
453    }
454
455    async fn create_stream(
456        &mut self,
457        cx: &hyperactor::Context<Self>,
458        result: StreamRef,
459        creation_mode: StreamCreationMode,
460    ) -> Result<()> {
461        let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
462            world_size: self.world_size,
463            rank: self.rank,
464            creation_mode,
465            id: result,
466            device: self.device,
467            controller_actor: self.controller_actor.clone(),
468            respond_with_python_message: self.respond_with_python_message,
469        })
470        .spawn(cx)?;
471        self.streams.insert(result, Arc::new(handle));
472        Ok(())
473    }
474
475    async fn create_device_mesh(
476        &mut self,
477        _cx: &hyperactor::Context<Self>,
478        result: Ref,
479        names: Vec<String>,
480        ranks: Slice,
481    ) -> Result<()> {
482        self.device_meshes.insert(
483            result,
484            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
485        );
486        Ok(())
487    }
488
489    async fn create_remote_process_group(
490        &mut self,
491        _cx: &hyperactor::Context<Self>,
492        result: Ref,
493        device_mesh: Ref,
494        dims: Vec<String>,
495    ) -> Result<()> {
496        self.device_meshes
497            .get(&device_mesh)
498            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
499        match self.remote_process_groups.entry(result) {
500            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
501                device_mesh,
502                SortedVec::from_unsorted(dims),
503            )),
504            Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
505        };
506        Ok(())
507    }
508
509    async fn borrow_create(
510        &mut self,
511        cx: &hyperactor::Context<Self>,
512        result: Ref,
513        borrow_id: u64,
514        tensor_ref: Ref,
515        from_stream: StreamRef,
516        to_stream: StreamRef,
517    ) -> Result<()> {
518        self.maybe_add_stream_to_recording(cx, from_stream).await?;
519        self.maybe_add_stream_to_recording(cx, to_stream).await?;
520        let from_stream = self.try_get_stream(from_stream)?.clone();
521        let to_stream = self.try_get_stream(to_stream)?.clone();
522
523        let borrow =
524            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
525        self.borrows.insert(borrow_id, borrow);
526        Ok(())
527    }
528
529    async fn borrow_first_use(
530        &mut self,
531        cx: &hyperactor::Context<Self>,
532        borrow: u64,
533    ) -> Result<()> {
534        let borrow = self
535            .borrows
536            .get_mut(&borrow)
537            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
538
539        borrow.first_use(cx).await?;
540        Ok(())
541    }
542
543    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
544        let borrow = self
545            .borrows
546            .get_mut(&borrow)
547            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
548
549        borrow.last_use(cx).await?;
550        Ok(())
551    }
552
553    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
554        let borrow = self
555            .borrows
556            .get_mut(&borrow_id)
557            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
558
559        borrow.drop(cx).await?;
560        self.borrows.remove(&borrow_id);
561        Ok(())
562    }
563
564    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
565        // Fan the delete message to all streams.
566        // Check for errors.
567        // TODO: this blocks forward progress of the the actor loop while we
568        // wait for the streams to catch up. Once we have a way of spawning
569        // tasks that the actor system can monitor in a non-blocking way, we
570        // should remove this.
571        let _: Vec<()> = try_join_all(
572            self.streams
573                .values()
574                .map(|s| s.delete_refs(cx, refs.clone())),
575        )
576        .await?;
577        Ok(())
578    }
579
580    async fn request_status(
581        &mut self,
582        cx: &hyperactor::Context<Self>,
583        seq: Seq,
584        controller: bool,
585    ) -> Result<()> {
586        // TODO: this blocks forward progress of the the actor loop while we
587        // wait for the streams to catch up. Once we have a way of spawning
588        // tasks that the actor system can monitor in a non-blocking way, we
589        // should remove this.
590        let _: Vec<()> = try_join_all(
591            self.streams
592                .values()
593                .map(|stream| stream.request_status(cx)),
594        )
595        .await?;
596
597        ControllerMessageClient::status(
598            &self.controller_actor,
599            cx,
600            seq.next(),
601            cx.self_addr().clone(),
602            controller,
603        )
604        .await?;
605        Ok(())
606    }
607
608    async fn reduce(
609        &mut self,
610        cx: &hyperactor::Context<Self>,
611        result: Ref,
612        local_tensor: Ref,
613        factory: Factory,
614        source_mesh: Ref,
615        stream_ref: StreamRef,
616        dims: Vec<String>,
617        reduction: Reduction,
618        scatter: bool,
619        in_place: bool,
620        out: Option<Ref>,
621    ) -> Result<()> {
622        self.maybe_add_stream_to_recording(cx, stream_ref).await?;
623
624        // Sort for stable indexing.
625        let dims = SortedVec::from_unsorted(dims);
626        let stream = self.try_get_stream(stream_ref)?.clone();
627
628        let (_, comm_map) = self
629            .device_meshes
630            .get_mut(&source_mesh)
631            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;
632
633        let (size, comm) = comm_map
634            .get(&(stream_ref, dims.clone()))
635            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
636            .clone();
637
638        stream
639            .reduce(
640                cx,
641                comm,
642                size.try_into()?,
643                result,
644                local_tensor,
645                factory,
646                reduction,
647                scatter,
648                in_place,
649                out,
650            )
651            .await?;
652
653        Ok(())
654    }
655
    /// Moves `tensor` between ranks: the i-th rank of `from_ranks` sends to
    /// the i-th rank of `to_ranks`.
    ///
    /// This worker may be a sender, a receiver, both, or neither. The stream
    /// that does the work is `to_stream` when this worker does not send, and
    /// `from_stream` when it only sends or when both streams are the same.
    ///
    /// # Errors
    ///
    /// Returns an error if the chosen stream is unknown, if a communicator is
    /// needed but none exists for the `(from_stream, to_stream)` pair, or if
    /// the stream actor's `send_tensor` call fails.
    ///
    /// # Panics
    ///
    /// Panics (`unimplemented!`) when this worker both sends and receives and
    /// the two streams differ.
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        // The rank this worker sends to: `Some` only if this worker's rank
        // appears in `from_ranks` and `to_ranks` has an entry at that index.
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        // The rank this worker receives from, found the same way with the
        // roles of the two slices swapped.
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };
        // No communicator is looked up when source and destination rank are
        // equal (including when both are `None`).
        let comm = if from_rank == to_rank {
            None
        } else {
            Some(
                self.send_recv_comms
                    .get(&(from_stream, to_stream))
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "could not find stream to stream comm for: {:#?}",
                            (from_stream, to_stream)
                        )
                    })?
                    .clone(),
            )
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
713
714    async fn exit(
715        &mut self,
716        cx: &hyperactor::Context<Self>,
717        error: Option<(Option<reference::ActorAddr>, String)>,
718    ) -> Result<()> {
719        for (_, stream) in self.streams.drain() {
720            stream.drain_and_stop("tensor worker exit cleanup")?;
721            Arc::into_inner(stream)
722                .expect("there should be no owners of this stream handle except the worker stream table")
723                .await;
724        }
725
726        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
727            .ok()
728            .and_then(|val| val.parse::<i32>().ok())
729            .unwrap_or(1);
730        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
731            .ok()
732            .and_then(|val| val.parse::<i32>().ok())
733            .unwrap_or(1);
734
735        // Exit the worker process if there is an error.
736        let exit_code = match error {
737            Some((Some(actor_id), reason)) => {
738                tracing::error!(
739                    "stopping the worker, actor {} failed with error: {}",
740                    actor_id,
741                    reason
742                );
743                if cx.self_addr() == &actor_id {
744                    self_error_exit_code
745                } else {
746                    peer_error_exit_code
747                }
748            }
749            Some((None, reason)) => {
750                tracing::error!("stopping the worker, reason: {}", reason);
751                1
752            }
753            None => 0,
754        };
755
756        if exit_code != 0 {
757            tracing::info!("stopping the worker process, exit code: {}", exit_code);
758            std::process::exit(exit_code);
759        }
760        cx.stop("tensor worker exit")?;
761        Ok(())
762    }
763
764    async fn send_value(
765        &mut self,
766        cx: &hyperactor::Context<Self>,
767        seq: Seq,
768        destination: Option<Ref>,
769        mutates: Vec<Ref>,
770        function: Option<ResolvableFunction>,
771        args_kwargs: ArgsKwargs,
772        stream: StreamRef,
773    ) -> Result<()> {
774        // Resolve the stream.
775        let stream = self.try_get_stream(stream)?;
776
777        let device_meshes = if function.is_none() {
778            HashMap::new()
779        } else {
780            self.device_meshes
781                .iter()
782                .map(|(k, v)| (k.clone(), v.0.clone()))
783                .collect()
784        };
785
786        if destination.is_some() {
787            panic!("send_value with pipe destination is no longer implemented")
788        }
789
790        // Resolve the value on the stream, then send the value back to the controller.
791        stream
792            .send_value(
793                cx,
794                seq,
795                cx.self_addr().clone(),
796                mutates,
797                function,
798                args_kwargs,
799                device_meshes,
800            )
801            .await
802    }
803
804    async fn send_result_of_actor_call(
805        &mut self,
806        cx: &hyperactor::Context<Self>,
807        params: ActorCallParams,
808    ) -> Result<()> {
809        let stream = self.try_get_stream(params.stream)?;
810        stream.send_result_of_actor_call(cx, params).await?;
811        Ok(())
812    }
813    async fn call_actor_method(
814        &mut self,
815        cx: &hyperactor::Context<Self>,
816        params: ActorMethodParams,
817    ) -> Result<()> {
818        let stream = self.try_get_stream(params.call.stream)?;
819        stream.call_actor_method(cx, params).await?;
820        Ok(())
821    }
    /// Splits a communicator for `dims` of `device_mesh` off the global
    /// communicator and stores it for `stream_ref`.
    ///
    /// Does nothing when no global communicator has been initialised. When
    /// `device_mesh` is not known to this worker, the worker still takes part
    /// in the split with an empty rank list.
    ///
    /// # Errors
    ///
    /// Returns an error if the stream is unknown, if a communicator already
    /// exists for the `(stream, dims)` pair, if the ranks for `dims` cannot
    /// be computed or converted, or if the split fails or does not include
    /// this rank.
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
    ) -> Result<()> {
        let Some(global_comm) = self.comm.as_ref() else {
            return Ok(());
        };
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                // Communicators are keyed by sorted dims.
                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                // Remember the group size next to the communicator; `reduce`
                // reads it back later.
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![]).await?;
            }
        }
        Ok(())
    }
869
870    async fn split_comm_for_process_group(
871        &mut self,
872        cx: &hyperactor::Context<Self>,
873        remote_process_group_ref: Ref,
874        stream_ref: StreamRef,
875    ) -> Result<()> {
876        ensure!(
877            self.streams.contains_key(&stream_ref),
878            "invalid stream id: {:#?}",
879            stream_ref
880        );
881        let Some(global_comm) = self.comm.as_ref() else {
882            return Ok(());
883        };
884        let state = self
885            .remote_process_groups
886            .get_mut(&remote_process_group_ref)
887            .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
888        match self.device_meshes.get_mut(&state.device_mesh_ref) {
889            Some((device_mesh, _)) => {
890                // This rank is in the group to be split off. Split a new
891                // communicator for it off from the global communicator.
892                let entry = match state.comms.entry(stream_ref) {
893                    Entry::Vacant(entry) => entry,
894                    Entry::Occupied(_) => bail!(
895                        "comm already exists for remote process group {:#?} on stream {:#?}",
896                        remote_process_group_ref,
897                        stream_ref,
898                    ),
899                };
900                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
901                let split_comm = global_comm
902                    .split_from(
903                        cx,
904                        ranks_for_group
905                            .into_iter()
906                            .map(|v| v.clone().try_into())
907                            .collect::<Result<Vec<_>, _>>()?,
908                    )
909                    .await?
910                    .context("split comm should include self rank")?;
911                entry.insert(Arc::new(split_comm));
912            }
913            None => {
914                // This rank is not in the group to be split off. We still need to
915                // participate in the commSplit call, however.
916                global_comm.split_from(cx, vec![]).await?;
917            }
918        }
919        Ok(())
920    }
921
922    async fn pipe_recv(
923        &mut self,
924        _cx: &hyperactor::Context<Self>,
925        _seq: Seq,
926        _results: Vec<Option<Ref>>,
927        _pipe: Ref,
928        _stream: StreamRef,
929    ) -> Result<()> {
930        panic!("pipe_recv is no longer implemented")
931    }
932
933    async fn set_ref_unit_tests_only(
934        &mut self,
935        cx: &hyperactor::Context<Self>,
936        reference: Ref,
937        value: WireValue,
938        stream: StreamRef,
939    ) -> Result<()> {
940        let stream = self.try_get_stream(stream)?;
941
942        stream.set_ref_unit_tests_only(cx, reference, value).await
943    }
944
945    async fn get_ref_unit_tests_only(
946        &mut self,
947        cx: &hyperactor::Context<Self>,
948        ref_id: Ref,
949        stream: StreamRef,
950    ) -> Result<Option<Result<WireValue, String>>> {
951        let stream = self.try_get_stream(stream)?;
952        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
953    }
954
955    async fn define_recording(
956        &mut self,
957        cx: &hyperactor::Context<Self>,
958        result: Ref,
959        _nresults: usize,
960        _nformals: usize,
961        commands: Vec<WorkerMessage>,
962        ntotal_messages: usize,
963        index: usize,
964    ) -> Result<()> {
965        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
966            bail!("already defining a different recording");
967        }
968        self.defining_recording = Some(result);
969
970        match self.recordings.entry(result) {
971            Entry::Vacant(entry) => {
972                ensure!(
973                    index == 0,
974                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
975                    index
976                );
977                entry.insert(Recording::PartialRecording {
978                    last_index: 0,
979                    commands,
980                });
981            }
982            Entry::Occupied(mut entry) => match entry.get_mut() {
983                Recording::CompleteRecording { .. } => {
984                    bail!("got DefineRecording message for already complete recording")
985                }
986                Recording::PartialRecording {
987                    last_index,
988                    commands: existing_commands,
989                } => {
990                    ensure!(
991                        index == *last_index + 1,
992                        "Got DefineRecording message with index = {:?}, but \
993                            last seen index for recording is {:?}",
994                        index,
995                        last_index
996                    );
997                    *last_index = index;
998                    existing_commands.extend(commands);
999                }
1000            },
1001        };
1002
1003        if index < ntotal_messages - 1 {
1004            return Ok(());
1005        }
1006        let commands = match self.recordings.remove(&result).unwrap() {
1007            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
1008            Recording::PartialRecording { commands, .. } => {
1009                self.recordings.insert(
1010                    result,
1011                    Recording::CompleteRecording {
1012                        streams: HashSet::new(),
1013                    },
1014                );
1015                commands
1016            }
1017        };
1018
1019        for command in commands {
1020            WorkerMessageHandler::handle(self, cx, command).await?;
1021        }
1022
1023        match self.recordings.get(&result).unwrap() {
1024            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
1025            Recording::CompleteRecording { streams, .. } => {
1026                for stream in streams {
1027                    self.try_get_stream(*stream)?
1028                        .finalize_recording(cx, result)
1029                        .await?;
1030                }
1031            }
1032        }
1033
1034        self.defining_recording = None;
1035        Ok(())
1036    }
1037
1038    async fn recording_formal(
1039        &mut self,
1040        cx: &hyperactor::Context<Self>,
1041        result: Ref,
1042        argument_index: usize,
1043        stream: StreamRef,
1044    ) -> Result<()> {
1045        ensure!(self.defining_recording.is_some());
1046        self.maybe_add_stream_to_recording(cx, stream).await?;
1047        self.try_get_stream(stream)?
1048            .recording_formal(cx, result, argument_index)
1049            .await
1050    }
1051
1052    async fn recording_result(
1053        &mut self,
1054        cx: &hyperactor::Context<Self>,
1055        result: Ref,
1056        output_index: usize,
1057        stream: StreamRef,
1058    ) -> Result<()> {
1059        ensure!(self.defining_recording.is_some());
1060        self.maybe_add_stream_to_recording(cx, stream).await?;
1061        self.try_get_stream(stream)?
1062            .recording_result(cx, result, output_index)
1063            .await
1064    }
1065
1066    async fn call_recording(
1067        &mut self,
1068        cx: &hyperactor::Context<Self>,
1069        seq: Seq,
1070        recording: Ref,
1071        results: Vec<Ref>,
1072        actuals: Vec<Ref>,
1073    ) -> Result<()> {
1074        ensure!(self.defining_recording.is_none());
1075        let recording_ref = recording;
1076        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1077            "could not find recording: {:#?}",
1078            recording
1079        ))?;
1080        match recording {
1081            Recording::PartialRecording { .. } => {
1082                bail!("cannot call recording because it is incomplete")
1083            }
1084            Recording::CompleteRecording { streams } => try_join_all(
1085                streams
1086                    .iter()
1087                    .map(|stream| self.try_get_stream(*stream))
1088                    .collect::<Result<Vec<_>>>()?
1089                    .into_iter()
1090                    .map(|stream| {
1091                        stream.call_recording(
1092                            cx,
1093                            seq,
1094                            recording_ref,
1095                            results.clone(),
1096                            actuals.clone(),
1097                        )
1098                    }),
1099            )
1100            .await
1101            .map(|_| ()),
1102        }
1103    }
1104}
1105
1106#[cfg(all(test, fbcode_build))]
1107mod tests {
1108    use std::assert_matches;
1109
1110    use anyhow::Result;
1111    use hyperactor::RemoteSpawn;
1112    use hyperactor::channel::ChannelAddr;
1113    use hyperactor::proc::Proc;
1114    use monarch_messages::controller::ControllerMessage;
1115    use monarch_messages::controller::WorkerError;
1116    use monarch_messages::worker::WorkerMessageClient;
1117    use monarch_types::PickledPyObject;
1118    use pyo3::Python;
1119    use pyo3::prelude::*;
1120    use pyo3::types::PyList;
1121    use pyo3::types::PyString;
1122    use rand::RngExt as _;
1123    use rand::distr::Alphanumeric;
1124    use timed_test::async_timed_test;
1125    use torch_sys_cuda::nccl::UniqueIdExt;
1126
1127    use super::*;
1128    use crate::test_util::test_setup;
1129
1130    #[async_timed_test(timeout_secs = 60)]
1131    async fn basic_worker() -> Result<()> {
1132        test_setup()?;
1133
1134        let proc = Proc::isolated();
1135        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1136
1137        let worker_handle = proc
1138            .spawn(
1139                "worker",
1140                WorkerActor::new(
1141                    WorkerParams {
1142                        world_size: 1,
1143                        rank: 0,
1144                        device_index: None,
1145                        controller_actor: controller_ref,
1146                    },
1147                    Flattrs::default(),
1148                )
1149                .await
1150                .unwrap(),
1151            )
1152            .unwrap();
1153        worker_handle
1154            .command_group(
1155                &client,
1156                vec![
1157                    WorkerMessage::CreateStream {
1158                        id: 1.into(),
1159                        stream_creation: StreamCreationMode::UseDefaultStream,
1160                    },
1161                    WorkerMessage::CallFunction(CallFunctionParams {
1162                        seq: 0.into(),
1163                        results: vec![Some(0.into())],
1164                        mutates: vec![],
1165                        function: "torch.ops.aten.ones.default".into(),
1166                        args_kwargs: ArgsKwargs::from_wire_values(
1167                            vec![WireValue::IntList(vec![2, 3])],
1168                            HashMap::new(),
1169                        )
1170                        .unwrap(),
1171                        stream: 1.into(),
1172                        remote_process_groups: vec![],
1173                    }),
1174                    WorkerMessage::CallFunction(CallFunctionParams {
1175                        seq: 2.into(),
1176                        results: vec![Some(Ref { id: 2 })],
1177                        mutates: vec![0.into()],
1178                        function: "torch.ops.aten.sub_.Scalar".into(),
1179                        args_kwargs: ArgsKwargs::from_wire_values(
1180                            vec![WireValue::Ref(0.into()), WireValue::Int(1)],
1181                            HashMap::new(),
1182                        )
1183                        .unwrap(),
1184                        stream: 1.into(),
1185                        remote_process_groups: vec![],
1186                    }),
1187                    WorkerMessage::CallFunction(CallFunctionParams {
1188                        seq: 3.into(),
1189                        results: vec![Some(Ref { id: 3 })],
1190                        mutates: vec![],
1191                        function: "torch.ops.aten.zeros.default".into(),
1192                        args_kwargs: ArgsKwargs::from_wire_values(
1193                            vec![WireValue::IntList(vec![2, 3])],
1194                            HashMap::new(),
1195                        )
1196                        .unwrap(),
1197                        stream: 1.into(),
1198                        remote_process_groups: vec![],
1199                    }),
1200                    WorkerMessage::CallFunction(CallFunctionParams {
1201                        seq: 4.into(),
1202                        results: vec![Some(Ref { id: 4 })],
1203                        mutates: vec![],
1204                        function: "torch.ops.aten.allclose.default".into(),
1205                        args_kwargs: ArgsKwargs::from_wire_values(
1206                            vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
1207                            HashMap::new(),
1208                        )
1209                        .unwrap(),
1210                        stream: 1.into(),
1211                        remote_process_groups: vec![],
1212                    }),
1213                ],
1214            )
1215            .await
1216            .unwrap();
1217
1218        let result: bool = worker_handle
1219            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1220            .await
1221            .unwrap()
1222            .unwrap()
1223            .unwrap()
1224            .try_into()
1225            .unwrap();
1226        worker_handle.drain_and_stop("test").unwrap();
1227        worker_handle.await;
1228        let error_responses = controller_rx.drain();
1229        assert!(
1230            error_responses.is_empty(),
1231            "Expected no error responses, got: {:#?}",
1232            error_responses
1233        );
1234        assert!(result);
1235
1236        Ok(())
1237    }
1238
1239    #[async_timed_test(timeout_secs = 60)]
1240    async fn error_sends_response() -> Result<()> {
1241        test_setup()?;
1242
1243        let proc = Proc::isolated();
1244        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1245
1246        let worker_handle = proc
1247            .spawn(
1248                "worker",
1249                WorkerActor::new(
1250                    WorkerParams {
1251                        world_size: 1,
1252                        rank: 0,
1253                        device_index: None,
1254                        controller_actor: controller_ref,
1255                    },
1256                    Flattrs::default(),
1257                )
1258                .await
1259                .unwrap(),
1260            )
1261            .unwrap();
1262        worker_handle
1263            .command_group(
1264                &client,
1265                vec![
1266                    WorkerMessage::CreateStream {
1267                        id: 1.into(),
1268                        stream_creation: StreamCreationMode::UseDefaultStream,
1269                    },
1270                    WorkerMessage::CallFunction(CallFunctionParams {
1271                        seq: 0.into(),
1272                        results: vec![Some(0.into())],
1273                        mutates: vec![],
1274                        function: "torch.ops.aten.rand.default".into(),
1275                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1276                        stream: 1.into(),
1277                        remote_process_groups: vec![],
1278                    }),
1279                    WorkerMessage::Exit { error: None },
1280                ],
1281            )
1282            .await
1283            .unwrap();
1284
1285        worker_handle.drain_and_stop("test").unwrap();
1286        worker_handle.await;
1287        let response_message = controller_rx.recv().await.unwrap();
1288        match response_message {
1289            ControllerMessage::RemoteFunctionFailed {
1290                seq,
1291                error: WorkerError { backtrace: msg, .. },
1292            } => {
1293                assert_eq!(seq, 0.into());
1294                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
1295            }
1296            _ => panic!("unexpected response {:#?}", response_message),
1297        }
1298
1299        Ok(())
1300    }
1301
1302    #[async_timed_test(timeout_secs = 60)]
1303    async fn mutated_refs_are_updated_with_error() -> Result<()> {
1304        test_setup()?;
1305
1306        let proc = Proc::isolated();
1307        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1308
1309        let worker_handle = proc
1310            .spawn(
1311                "worker",
1312                WorkerActor::new(
1313                    WorkerParams {
1314                        world_size: 1,
1315                        rank: 0,
1316                        device_index: None,
1317                        controller_actor: controller_ref,
1318                    },
1319                    Flattrs::default(),
1320                )
1321                .await
1322                .unwrap(),
1323            )
1324            .unwrap();
1325        worker_handle
1326            .command_group(
1327                &client,
1328                vec![
1329                    WorkerMessage::CreateStream {
1330                        id: 1.into(),
1331                        stream_creation: StreamCreationMode::UseDefaultStream,
1332                    },
1333                    WorkerMessage::SetRefUnitTestsOnly {
1334                        reference: 0.into(),
1335                        value: WireValue::Int(1),
1336                        stream: 1.into(),
1337                    },
1338                    WorkerMessage::CallFunction(CallFunctionParams {
1339                        seq: 0.into(),
1340                        results: vec![Some(Ref { id: 2 })],
1341                        mutates: vec![0.into()],
1342                        function: "i.dont.exist".into(),
1343                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1344                        stream: 1.into(),
1345                        remote_process_groups: vec![],
1346                    }),
1347                ],
1348            )
1349            .await
1350            .unwrap();
1351
1352        let result = worker_handle
1353            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1354            .await?;
1355
1356        // Stop/drain worker before asserts to avoid hangs.
1357        worker_handle.drain_and_stop("test").unwrap();
1358        worker_handle.await;
1359
1360        let mutated_ref = result
1361            .context("no such ref")?
1362            .err()
1363            .context("expected error")?;
1364        assert!(mutated_ref.contains("failed to resolve function"));
1365
1366        let responses = controller_rx.drain();
1367        assert_eq!(
1368            responses.len(),
1369            1,
1370            "Expected one response, got: {:#?}",
1371            responses
1372        );
1373        Ok(())
1374    }
1375
1376    #[async_timed_test(timeout_secs = 60)]
1377    async fn accessing_errored_dependency() -> Result<()> {
1378        test_setup()?;
1379
1380        let proc = Proc::isolated();
1381        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1382
1383        let worker_handle = proc
1384            .spawn(
1385                "worker",
1386                WorkerActor::new(
1387                    WorkerParams {
1388                        world_size: 1,
1389                        rank: 0,
1390                        device_index: None,
1391                        controller_actor: controller_ref,
1392                    },
1393                    Flattrs::default(),
1394                )
1395                .await
1396                .unwrap(),
1397            )
1398            .unwrap();
1399        worker_handle
1400            .command_group(
1401                &client,
1402                vec![
1403                    WorkerMessage::CreateStream {
1404                        id: 1.into(),
1405                        stream_creation: StreamCreationMode::UseDefaultStream,
1406                    },
1407                    WorkerMessage::CallFunction(CallFunctionParams {
1408                        seq: 0.into(),
1409                        results: vec![Some(0.into())],
1410                        mutates: vec![],
1411                        function: "i.dont.exist".into(),
1412                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1413                        stream: 1.into(),
1414                        remote_process_groups: vec![],
1415                    }),
1416                    WorkerMessage::CallFunction(CallFunctionParams {
1417                        seq: 1.into(),
1418                        results: vec![Some(1.into())],
1419                        mutates: vec![],
1420                        function: "torch.ops.aten.sub_.Scalar".into(),
1421                        args_kwargs: ArgsKwargs::from_wire_values(
1422                            vec![WireValue::Ref(0.into())],
1423                            HashMap::new(),
1424                        )
1425                        .unwrap(),
1426                        stream: 1.into(),
1427                        remote_process_groups: vec![],
1428                    }),
1429                    WorkerMessage::Exit { error: None },
1430                ],
1431            )
1432            .await
1433            .unwrap();
1434
1435        worker_handle.drain_and_stop("test").unwrap();
1436        worker_handle.await;
1437
1438        let responses = controller_rx.drain();
1439        assert_eq!(
1440            responses.len(),
1441            1,
1442            "Expected one response, got: {:#?}",
1443            responses
1444        );
1445
1446        match &responses[0] {
1447            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1448                assert_eq!(seq, &0.into())
1449            }
1450            _ => panic!("unexpected response {:#?}", responses[0]),
1451        };
1452        Ok(())
1453    }
1454
1455    #[async_timed_test(timeout_secs = 60)]
1456    async fn py_remote_function_calls() -> Result<()> {
1457        test_setup()?;
1458
1459        let proc = Proc::isolated();
1460        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1461
1462        let worker_handle = proc
1463            .spawn(
1464                "worker",
1465                WorkerActor::new(
1466                    WorkerParams {
1467                        world_size: 1,
1468                        rank: 0,
1469                        device_index: None,
1470                        controller_actor: controller_ref,
1471                    },
1472                    Flattrs::default(),
1473                )
1474                .await
1475                .unwrap(),
1476            )
1477            .unwrap();
1478        let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
1479            Python::attach(|py| {
1480                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
1481                    .into_any()
1482                    .try_into()?;
1483                let sort_list: PickledPyObject =
1484                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
1485                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
1486                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
1487                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
1488                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
1489                let device: PickledPyObject = py
1490                    .import("torch")?
1491                    .getattr("device")?
1492                    .call1(("cuda:1",))?
1493                    .try_into()?;
1494                let memory_format: PickledPyObject = py
1495                    .import("torch")?
1496                    .getattr("contiguous_format")?
1497                    .try_into()?;
1498                PyResult::Ok((
1499                    split_arg,
1500                    sort_list,
1501                    dim,
1502                    layout,
1503                    none,
1504                    scalar,
1505                    device,
1506                    memory_format,
1507                ))
1508            })?;
1509
1510        worker_handle
1511            .command_group(
1512                &client,
1513                vec![
1514                    WorkerMessage::CreateStream {
1515                        id: 1.into(),
1516                        stream_creation: StreamCreationMode::UseDefaultStream,
1517                    },
1518                    WorkerMessage::CallFunction(CallFunctionParams {
1519                        seq: 0.into(),
1520                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
1521                        mutates: vec![],
1522                        function: "os.path.split".into(),
1523                        args_kwargs: ArgsKwargs::from_wire_values(
1524                            vec![split_arg.into()],
1525                            HashMap::new(),
1526                        )
1527                        .unwrap(),
1528                        stream: 1.into(),
1529                        remote_process_groups: vec![],
1530                    }),
1531                    WorkerMessage::CallFunction(CallFunctionParams {
1532                        seq: 2.into(),
1533                        results: vec![Some(4.into()), None, None, None, None],
1534                        mutates: vec![],
1535                        function: "builtins.sorted".into(),
1536                        args_kwargs: ArgsKwargs::from_wire_values(
1537                            vec![sort_list.into()],
1538                            HashMap::new(),
1539                        )
1540                        .unwrap(),
1541                        stream: 1.into(),
1542                        remote_process_groups: vec![],
1543                    }),
1544                    WorkerMessage::CreateDeviceMesh {
1545                        result: 5.into(),
1546                        names: vec!["x".into()],
1547                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
1548                    },
1549                    WorkerMessage::CallFunction(CallFunctionParams {
1550                        seq: 2.into(),
1551                        results: vec![Some(6.into())],
1552                        mutates: vec![],
1553                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
1554                        args_kwargs: ArgsKwargs::from_wire_values(
1555                            vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
1556                            HashMap::new(),
1557                        )
1558                        .unwrap(),
1559                        stream: 1.into(),
1560                        remote_process_groups: vec![],
1561                    }),
1562                    WorkerMessage::CallFunction(CallFunctionParams {
1563                        seq: 4.into(),
1564                        results: vec![Some(7.into())],
1565                        mutates: vec![],
1566                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
1567                            .into(),
1568                        args_kwargs: ArgsKwargs::from_wire_values(
1569                            vec![scalar.into()],
1570                            HashMap::new(),
1571                        )
1572                        .unwrap(),
1573                        stream: 1.into(),
1574                        remote_process_groups: vec![],
1575                    }),
1576                    WorkerMessage::CallFunction(CallFunctionParams {
1577                        seq: 5.into(),
1578                        results: vec![Some(8.into())],
1579                        mutates: vec![],
1580                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
1581                        args_kwargs: ArgsKwargs::from_wire_values(
1582                            vec![layout.into()],
1583                            HashMap::new(),
1584                        )
1585                        .unwrap(),
1586                        stream: 1.into(),
1587                        remote_process_groups: vec![],
1588                    }),
1589                    WorkerMessage::CallFunction(CallFunctionParams {
1590                        seq: 6.into(),
1591                        results: vec![Some(9.into())],
1592                        mutates: vec![],
1593                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
1594                        args_kwargs: ArgsKwargs::from_wire_values(
1595                            vec![none.into()],
1596                            HashMap::new(),
1597                        )
1598                        .unwrap(),
1599                        stream: 1.into(),
1600                        remote_process_groups: vec![],
1601                    }),
1602                    // Verify that a function that returns `None` matches up with an
1603                    // empty result list.
1604                    WorkerMessage::CallFunction(CallFunctionParams {
1605                        seq: 7.into(),
1606                        results: vec![None],
1607                        mutates: vec![],
1608                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
1609                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1610                        stream: 1.into(),
1611                        remote_process_groups: vec![],
1612                    }),
1613                    WorkerMessage::CallFunction(CallFunctionParams {
1614                        seq: 8.into(),
1615                        results: vec![Some(10.into())],
1616                        mutates: vec![],
1617                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
1618                        args_kwargs: ArgsKwargs::from_wire_values(
1619                            vec![device.into()],
1620                            HashMap::new(),
1621                        )
1622                        .unwrap(),
1623                        stream: 1.into(),
1624                        remote_process_groups: vec![],
1625                    }),
1626                    WorkerMessage::CallFunction(CallFunctionParams {
1627                        seq: 9.into(),
1628                        results: vec![Some(11.into())],
1629                        mutates: vec![],
1630                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
1631                            .into(),
1632                        args_kwargs: ArgsKwargs::from_wire_values(
1633                            vec![memory_format.into()],
1634                            HashMap::new(),
1635                        )
1636                        .unwrap(),
1637                        stream: 1.into(),
1638                        remote_process_groups: vec![],
1639                    }),
1640                    // Test that list of tests can be passes correctly
1641                    WorkerMessage::CallFunction(CallFunctionParams {
1642                        seq: 10.into(),
1643                        results: vec![Some(12.into())],
1644                        mutates: vec![],
1645                        function: "torch.ops.aten.ones.default".into(),
1646                        args_kwargs: ArgsKwargs::from_wire_values(
1647                            vec![WireValue::IntList(vec![2, 3])],
1648                            HashMap::new(),
1649                        )
1650                        .unwrap(),
1651                        stream: 1.into(),
1652                        remote_process_groups: vec![],
1653                    }),
1654                    WorkerMessage::CallFunction(CallFunctionParams {
1655                        seq: 11.into(),
1656                        results: vec![Some(13.into())],
1657                        mutates: vec![],
1658                        function: "torch.ops.aten.stack.default".into(),
1659                        args_kwargs: ArgsKwargs::from_wire_values(
1660                            vec![WireValue::RefList(vec![12.into(), 12.into()])],
1661                            HashMap::new(),
1662                        )
1663                        .unwrap(),
1664                        stream: 1.into(),
1665                        remote_process_groups: vec![],
1666                    }),
1667                ],
1668            )
1669            .await
1670            .unwrap();
1671
1672        let result1: String = worker_handle
1673            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1674            .await
1675            .unwrap()
1676            .unwrap()
1677            .unwrap()
1678            .try_into()
1679            .unwrap();
1680        let result2: String = worker_handle
1681            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
1682            .await
1683            .unwrap()
1684            .unwrap()
1685            .unwrap()
1686            .try_into()
1687            .unwrap();
1688        let result3: i64 = worker_handle
1689            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
1690            .await
1691            .unwrap()
1692            .unwrap()
1693            .unwrap()
1694            .try_into()
1695            .unwrap();
1696        let result4: i64 = worker_handle
1697            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
1698            .await
1699            .unwrap()
1700            .unwrap()
1701            .unwrap()
1702            .try_into()
1703            .unwrap();
1704        worker_handle
1705            .get_ref_unit_tests_only(&client, 7.into(), 1.into())
1706            .await
1707            .unwrap()
1708            .unwrap()
1709            .unwrap();
1710
1711        worker_handle
1712            .get_ref_unit_tests_only(&client, 8.into(), 1.into())
1713            .await
1714            .unwrap()
1715            .unwrap()
1716            .unwrap();
1717
1718        assert_matches!(
1719            worker_handle
1720                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
1721                .await
1722                .unwrap()
1723                .unwrap()
1724                .unwrap(),
1725            WireValue::None(()),
1726        );
1727        worker_handle
1728            .get_ref_unit_tests_only(&client, 10.into(), 1.into())
1729            .await
1730            .unwrap()
1731            .unwrap()
1732            .unwrap();
1733        worker_handle
1734            .get_ref_unit_tests_only(&client, 11.into(), 1.into())
1735            .await
1736            .unwrap()
1737            .unwrap()
1738            .unwrap();
1739
1740        worker_handle.drain_and_stop("test").unwrap();
1741        worker_handle.await;
1742        let error_responses = controller_rx.drain();
1743        assert!(
1744            error_responses.is_empty(),
1745            "Expected no error responses, got: {:#?}",
1746            error_responses
1747        );
1748
1749        assert_eq!(result1, "/fbs/fbc/foo");
1750        assert_eq!(result2, "bar");
1751        assert_eq!(result3, 1);
1752        assert_eq!(result4, 0);
1753
1754        Ok(())
1755    }
1756
1757    #[async_timed_test(timeout_secs = 60)]
1758    async fn delete_refs() -> Result<()> {
1759        test_setup()?;
1760
1761        let proc = Proc::isolated();
1762        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1763
1764        let worker_handle = proc
1765            .spawn(
1766                "worker",
1767                WorkerActor::new(
1768                    WorkerParams {
1769                        world_size: 1,
1770                        rank: 0,
1771                        device_index: None,
1772                        controller_actor: controller_ref,
1773                    },
1774                    Flattrs::default(),
1775                )
1776                .await
1777                .unwrap(),
1778            )
1779            .unwrap();
1780        worker_handle
1781            .command_group(
1782                &client,
1783                vec![
1784                    WorkerMessage::CreateStream {
1785                        id: 0.into(),
1786                        stream_creation: StreamCreationMode::CreateNewStream,
1787                    },
1788                    WorkerMessage::CreateStream {
1789                        id: 1.into(),
1790                        stream_creation: StreamCreationMode::CreateNewStream,
1791                    },
1792                    WorkerMessage::SetRefUnitTestsOnly {
1793                        reference: Ref { id: 2 },
1794                        value: WireValue::Bool(false),
1795                        stream: 0.into(),
1796                    },
1797                    WorkerMessage::SetRefUnitTestsOnly {
1798                        reference: Ref { id: 3 },
1799                        value: WireValue::Bool(true),
1800                        stream: 0.into(),
1801                    },
1802                    WorkerMessage::SetRefUnitTestsOnly {
1803                        reference: Ref { id: 4 },
1804                        value: WireValue::Int(0),
1805                        stream: 1.into(),
1806                    },
1807                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1808                ],
1809            )
1810            .await
1811            .unwrap();
1812
1813        let result: bool = worker_handle
1814            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1815            .await
1816            .unwrap()
1817            .unwrap()
1818            .unwrap()
1819            .try_into()
1820            .unwrap();
1821        let fail_result = worker_handle
1822            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1823            .await
1824            .unwrap();
1825
1826        worker_handle.drain_and_stop("test").unwrap();
1827        worker_handle.await;
1828
1829        assert!(result, "should be able to get a non-deleted ref");
1830        assert!(fail_result.is_none(), "should fail to get a deleted ref");
1831
1832        Ok(())
1833    }
1834
1835    #[async_timed_test(timeout_secs = 60)]
1836    async fn request_status() -> Result<()> {
1837        test_setup()?;
1838
1839        let proc = Proc::isolated();
1840        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1841
1842        let worker_handle = proc
1843            .spawn(
1844                "worker",
1845                WorkerActor::new(
1846                    WorkerParams {
1847                        world_size: 1,
1848                        rank: 0,
1849                        device_index: None,
1850                        controller_actor: controller_ref,
1851                    },
1852                    Flattrs::default(),
1853                )
1854                .await
1855                .unwrap(),
1856            )
1857            .unwrap();
1858        worker_handle
1859            .command_group(
1860                &client,
1861                vec![
1862                    WorkerMessage::CreateStream {
1863                        id: 0.into(),
1864                        stream_creation: StreamCreationMode::CreateNewStream,
1865                    },
1866                    WorkerMessage::CreateStream {
1867                        id: 1.into(),
1868                        stream_creation: StreamCreationMode::CreateNewStream,
1869                    },
1870                ],
1871            )
1872            .await
1873            .unwrap();
1874
1875        for i in 0..100 {
1876            // call alternating functions on this stream.
1877            worker_handle
1878                .call_function(
1879                    &client,
1880                    CallFunctionParams {
1881                        seq: i.into(),
1882                        results: vec![Some(Ref { id: i + 2 })],
1883                        mutates: vec![],
1884                        function: "torch.ops.aten.ones.default".into(),
1885                        args_kwargs: ArgsKwargs::from_wire_values(
1886                            vec![WireValue::IntList(vec![2, 3])],
1887                            HashMap::new(),
1888                        )
1889                        .unwrap(),
1890                        stream: (i % 2).into(),
1891                        remote_process_groups: vec![],
1892                    },
1893                )
1894                .await
1895                .unwrap();
1896        }
1897
1898        worker_handle
1899            .request_status(&client, 100.into(), false)
1900            .await
1901            .unwrap();
1902
1903        worker_handle.drain_and_stop("test").unwrap();
1904        worker_handle.await;
1905
1906        let mut responses = controller_rx.drain();
1907        assert_eq!(
1908            responses.len(),
1909            1,
1910            "Expected one response, got: {:#?}",
1911            responses
1912        );
1913
1914        let response = responses.pop().unwrap();
1915        match response {
1916            ControllerMessage::Status { seq, .. } => {
1917                assert_eq!(seq, 101.into())
1918            }
1919            _ => panic!("unexpected response {:#?}", response),
1920        };
1921
1922        Ok(())
1923    }
1924
1925    #[async_timed_test(timeout_secs = 60)]
1926    async fn backend_network_init() {
1927        test_setup().unwrap();
1928        let proc = Proc::isolated();
1929        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1930
1931        let worker_handle1 = proc
1932            .spawn(
1933                "worker0",
1934                WorkerActor::new(
1935                    WorkerParams {
1936                        world_size: 2,
1937                        rank: 0,
1938                        device_index: Some(0),
1939                        controller_actor: controller_ref.clone(),
1940                    },
1941                    Flattrs::default(),
1942                )
1943                .await
1944                .unwrap(),
1945            )
1946            .unwrap();
1947        let worker_handle2 = proc
1948            .spawn(
1949                "worker1",
1950                WorkerActor::new(
1951                    WorkerParams {
1952                        world_size: 2,
1953                        rank: 1,
1954                        device_index: Some(1),
1955                        controller_actor: controller_ref,
1956                    },
1957                    Flattrs::default(),
1958                )
1959                .await
1960                .unwrap(),
1961            )
1962            .unwrap();
1963
1964        let unique_id = UniqueId::new_nccl().unwrap();
1965        worker_handle1
1966            .backend_network_init(&client, unique_id.clone())
1967            .await
1968            .unwrap();
1969        worker_handle2
1970            .backend_network_init(&client, unique_id)
1971            .await
1972            .unwrap();
1973
1974        worker_handle1.drain_and_stop("test").unwrap();
1975        worker_handle1.await;
1976        worker_handle2.drain_and_stop("test").unwrap();
1977        worker_handle2.await;
1978    }
1979
1980    #[allow(dead_code)]
1981    fn get_random_channel_addr() -> ChannelAddr {
1982        let random_string = rand::rng()
1983            .sample_iter(&Alphanumeric)
1984            .take(24)
1985            .map(char::from)
1986            .collect::<String>();
1987        format!("unix!@{random_string}").parse().unwrap()
1988    }
1989}