monarch_tensor_worker/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
14#![allow(unsafe_op_in_unsafe_fn)]
15
16//! A `hyperactor`-based implementation of a PyTorch worker actor.
17//!
18//! The worker is responsible for executing PyTorch operations on a local
19//! device. It assumes it has exclusive access to device resources, and manages
20//! concurrency internally via device-specific constructs (CUDA stream, threads,
21//! etc.).
22//!
23//! This is a port of `monarch/python/controller/worker.py` but does have gaps due
24//! to drift that needs to be reconciled.
25//! This mainly includes:
26//! - Support for record and replay
27//! - debugger support
//! - general drift in existing messages
29
30mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor::Actor;
55use hyperactor::ActorRef;
56use hyperactor::Bind;
57use hyperactor::Handler;
58use hyperactor::RemoteSpawn;
59use hyperactor::Unbind;
60use hyperactor::actor::ActorHandle;
61use hyperactor::context;
62use hyperactor::reference::ActorId;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use ndslice::Slice;
84use pyo3::Python;
85use pyo3::types::PyAnyMethods;
86use serde::Deserialize;
87use serde::Serialize;
88use sorted_vec::SortedVec;
89use stream::StreamActor;
90use stream::StreamMessageClient;
91use stream::StreamParams;
92use torch_sys_cuda::nccl::ReduceOp;
93use torch_sys_cuda::nccl::UniqueId;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
/// Book-keeping for a process group declared via `CreateRemoteProcessGroup`.
/// The NCCL communicators themselves are allocated lazily, one per stream,
/// by `split_comm_for_process_group`.
#[derive(Debug)]
struct RemoteProcessGroupState {
    // The device mesh this process group was declared against.
    device_mesh_ref: Ref,
    // Mesh dimension names the group spans (sorted for stable keying).
    dims: SortedVec<String>,
    // Lazily-created communicators, keyed by the stream that uses the group.
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
108
109impl RemoteProcessGroupState {
110    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111        Self {
112            device_mesh_ref,
113            dims,
114            comms: HashMap::new(),
115        }
116    }
117}
118
/// Lifecycle of a recording as it is assembled from one or more
/// `DefineRecording` messages and then made runnable.
#[derive(Debug)]
enum Recording {
    // In the process of receiving DefineRecording messages for this
    // recording.
    PartialRecording {
        // The index of the last DefineRecording message received.
        last_index: usize,
        // The list of commands seen so far for this recording.
        commands: Vec<WorkerMessage>,
    },

    // The recording is ready to be run.
    CompleteRecording {
        // The list of streams on which this recording is defined.
        streams: HashSet<StreamRef>,
    },
}
136
/// A PyTorch runtime instance, operating on a single accelerator device,
/// controlled via hyperactor messaging.
///
/// Generally this is a thin multiplexer over a set of [`Stream`]s that do the
/// real work.
///
/// See [`WorkerMessage`] for what it can do!
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    /// The CUDA device this worker drives; `None` for non-CUDA workers.
    device: Option<CudaDevice>,
    /// Live stream actors, keyed by their worker-assigned stream ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Maps streams to the device mesh and a map of dim names to the concrete
    /// communicator actor that represents the dimension for that stream.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            // A map for comms for this mesh for a given pair of stream and dims.
            // The `usize` is the rank count of the split communicator.
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    /// In-flight borrows, keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    /// The global NCCL communicator; set by `BackendNetworkInit`.
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    /// Remember the process groups "created" via `CreateRemoteProcessGroup` for
    /// subsequent `CallFunction` calls, as this is where the actual allocation
    /// will happen.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// The comm actor for each pair of streams that need to send/recv tensors.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    /// Recordings being defined or ready to run, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// Set while `DefineRecording` messages for a recording are being
    /// received; cleared when the recording completes.
    defining_recording: Option<Ref>,
    /// When true, results are sent back as python messages (enabled by
    /// `AssignRankMessage`).
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
183        self.streams
184            .get(&stream)
185            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
186    }
187
188    async fn maybe_add_stream_to_recording(
189        &mut self,
190        cx: &impl context::Actor,
191        stream: StreamRef,
192    ) -> Result<()> {
193        // If we're defining a recording, add the stream to the list of streams that
194        // this recording uses, and call define_recording on the stream.
195        if let Some(defining_recording) = self.defining_recording {
196            let recording = self.recordings.get_mut(&defining_recording).unwrap();
197            let fut = match recording {
198                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
199                Recording::CompleteRecording { streams } => {
200                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
201                        Ok(self
202                            .try_get_stream(stream)?
203                            .define_recording(cx, defining_recording))
204                    })
205                }
206            }
207            .transpose()?;
208            match fut {
209                Some(fut) => fut.await,
210                None => Ok(()),
211            }
212        } else {
213            Ok(())
214        }
215    }
216}
217
218impl Actor for WorkerActor {}
219
220#[async_trait]
221impl RemoteSpawn for WorkerActor {
222    type Params = WorkerParams;
223
224    async fn new(
225        WorkerParams {
226            world_size,
227            rank,
228            device_index,
229            controller_actor,
230        }: Self::Params,
231    ) -> Result<Self> {
232        Python::with_gil(|py| {
233            py.import("monarch.safe_torch").unwrap();
234        });
235        Ok(Self {
236            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
237            streams: HashMap::new(),
238            device_meshes: HashMap::new(),
239            world_size,
240            rank,
241            borrows: HashMap::new(),
242            comm: None,
243            controller_actor,
244            remote_process_groups: HashMap::new(),
245            send_recv_comms: HashMap::new(),
246            recordings: HashMap::new(),
247            defining_recording: None,
248            respond_with_python_message: false,
249        })
250    }
251
252    // TODO: Exit the worker directly on any worker actor errors, with error exit code.
253}
254
#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    /// Adopt the rank implied by this worker's position in the cast, enable
    /// python-message responses, and initialize the python-side environment
    /// for this proc.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            // NOTE(review): failures here panic the actor; presumably fatal
            // by design, since the worker cannot run without this setup —
            // confirm before converting to recoverable errors.
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}
275
/// Broadcast (cast) to every worker to assign it its rank within the worker
/// mesh. Handling this also switches the worker into python-message response
/// mode and initializes the python-side environment.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
283
#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    /// Dispatch an incoming `WorkerMessage` to the corresponding
    /// `WorkerMessageHandler` method.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}
294
295#[async_trait]
296impl WorkerMessageHandler for WorkerActor {
    /// Initialize the worker's global NCCL communicator from `unique_id`,
    /// sanity-check it with a warm-up all-reduce, then split one communicator
    /// off per stream (in a deterministic order to avoid NCCL deadlocks) and
    /// hand each to its stream actor.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::new(CommParams::New {
            device,
            unique_id,
            world_size: self.world_size.try_into().unwrap(),
            rank: self.rank.try_into().unwrap(),
        })
        .await?
        .spawn(cx)?;

        // Warm-up collective: all-reduce a single zero scalar to make sure
        // the communicator actually works before handing out splits.
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // TODO: this blocks forward progress of the actor loop while we
        // wait for the streams to catch up. Once we have a way of spawning
        // tasks that the actor system can monitor in a non-blocking way, we
        // should remove this.

        // We need to be careful to initialize the streams in a consistent order
        // across all workers to avoid NCCL deadlocks. Use the refs to provide
        // this order, as a stream's ref will be the same across all workers.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            // Do the split in this event loop, to provide a deterministic
            // order.
            splits.push(comm.split_all(cx).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }
357
358    async fn backend_network_point_to_point_init(
359        &mut self,
360        cx: &hyperactor::Context<Self>,
361        from_stream: StreamRef,
362        to_stream: StreamRef,
363    ) -> Result<()> {
364        if !self.streams.contains_key(&from_stream) {
365            bail!("invalid from_stream id: {:#?}", from_stream);
366        }
367        if !self.streams.contains_key(&to_stream) {
368            bail!("invalid to_stream id: {:#?}", to_stream);
369        }
370        let global_comm = self
371            .comm
372            .as_ref()
373            .context("tried to call Reduce before BackendNetworkInit")?;
374        let comm = global_comm.split_all(cx).await?;
375        self.send_recv_comms
376            .insert((from_stream, to_stream), Arc::new(comm));
377        Ok(())
378    }
379
    /// Execute a function call on the target stream, passing along the
    /// device-mesh table and the resolved remote process groups the call
    /// mentions.
    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        // Snapshot of all device meshes so the stream can resolve mesh refs.
        let device_meshes = self
            .device_meshes
            .iter()
            .map(|(k, v)| (k.clone(), v.0.clone()))
            .collect();

        // Resolve each remote process group named in the call into its
        // (mesh, dims, comm) triple. The comm must have been created earlier
        // for this stream (by `SplitCommForProcessGroup`).
        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }
425
426    async fn command_group(
427        &mut self,
428        cx: &hyperactor::Context<Self>,
429        params: Vec<WorkerMessage>,
430    ) -> Result<()> {
431        for msg in params {
432            WorkerMessageHandler::handle(self, cx, msg).await?;
433        }
434        Ok(())
435    }
436
437    async fn create_stream(
438        &mut self,
439        cx: &hyperactor::Context<Self>,
440        result: StreamRef,
441        creation_mode: StreamCreationMode,
442    ) -> Result<()> {
443        let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
444            world_size: self.world_size,
445            rank: self.rank,
446            creation_mode,
447            id: result,
448            device: self.device,
449            controller_actor: self.controller_actor.clone(),
450            respond_with_python_message: self.respond_with_python_message,
451        })
452        .spawn(cx)?;
453        self.streams.insert(result, Arc::new(handle));
454        Ok(())
455    }
456
457    async fn create_device_mesh(
458        &mut self,
459        _cx: &hyperactor::Context<Self>,
460        result: Ref,
461        names: Vec<String>,
462        ranks: Slice,
463    ) -> Result<()> {
464        self.device_meshes.insert(
465            result,
466            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
467        );
468        Ok(())
469    }
470
471    async fn create_remote_process_group(
472        &mut self,
473        _cx: &hyperactor::Context<Self>,
474        result: Ref,
475        device_mesh: Ref,
476        dims: Vec<String>,
477    ) -> Result<()> {
478        self.device_meshes
479            .get(&device_mesh)
480            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
481        match self.remote_process_groups.entry(result) {
482            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
483                device_mesh,
484                SortedVec::from_unsorted(dims),
485            )),
486            Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
487        };
488        Ok(())
489    }
490
491    async fn borrow_create(
492        &mut self,
493        cx: &hyperactor::Context<Self>,
494        result: Ref,
495        borrow_id: u64,
496        tensor_ref: Ref,
497        from_stream: StreamRef,
498        to_stream: StreamRef,
499    ) -> Result<()> {
500        self.maybe_add_stream_to_recording(cx, from_stream).await?;
501        self.maybe_add_stream_to_recording(cx, to_stream).await?;
502        let from_stream = self.try_get_stream(from_stream)?.clone();
503        let to_stream = self.try_get_stream(to_stream)?.clone();
504
505        let borrow =
506            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
507        self.borrows.insert(borrow_id, borrow);
508        Ok(())
509    }
510
511    async fn borrow_first_use(
512        &mut self,
513        cx: &hyperactor::Context<Self>,
514        borrow: u64,
515    ) -> Result<()> {
516        let borrow = self
517            .borrows
518            .get_mut(&borrow)
519            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
520
521        borrow.first_use(cx).await?;
522        Ok(())
523    }
524
525    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
526        let borrow = self
527            .borrows
528            .get_mut(&borrow)
529            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
530
531        borrow.last_use(cx).await?;
532        Ok(())
533    }
534
535    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
536        let borrow = self
537            .borrows
538            .get_mut(&borrow_id)
539            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
540
541        borrow.drop(cx).await?;
542        self.borrows.remove(&borrow_id);
543        Ok(())
544    }
545
546    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
547        // Fan the delete message to all streams.
548        // Check for errors.
549        // TODO: this blocks forward progress of the the actor loop while we
550        // wait for the streams to catch up. Once we have a way of spawning
551        // tasks that the actor system can monitor in a non-blocking way, we
552        // should remove this.
553        let _: Vec<()> = try_join_all(
554            self.streams
555                .values()
556                .map(|s| s.delete_refs(cx, refs.clone())),
557        )
558        .await?;
559        Ok(())
560    }
561
562    async fn request_status(
563        &mut self,
564        cx: &hyperactor::Context<Self>,
565        seq: Seq,
566        controller: bool,
567    ) -> Result<()> {
568        // TODO: this blocks forward progress of the the actor loop while we
569        // wait for the streams to catch up. Once we have a way of spawning
570        // tasks that the actor system can monitor in a non-blocking way, we
571        // should remove this.
572        let _: Vec<()> = try_join_all(
573            self.streams
574                .values()
575                .map(|stream| stream.request_status(cx)),
576        )
577        .await?;
578
579        ControllerMessageClient::status(
580            &self.controller_actor,
581            cx,
582            seq.next(),
583            cx.self_id().clone(),
584            controller,
585        )
586        .await?;
587        Ok(())
588    }
589
    /// Perform a collective reduction of `local_tensor` across the mesh
    /// dimensions `dims` of `source_mesh`, on `stream_ref`, storing the
    /// result under `result` (writing into `out` when provided). Requires a
    /// communicator previously created for this (stream, dims) pair by
    /// `SplitComm`.
    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Sort for stable indexing.
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        // `size` is the rank count of the split communicator; passed along
        // so the stream knows the collective's group size.
        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        // Delegate the actual reduction to the stream actor.
        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }
637
    /// Send `tensor` from each rank in `from_ranks` to the rank at the same
    /// position in `to_ranks`, over the comm created for
    /// `(from_stream, to_stream)` by `BackendNetworkPointToPointInit`.
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        // If this rank is a sender, `to_rank` is its destination rank;
        // `None` means this rank does not send.
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        // If this rank is a receiver, `from_rank` is its source rank;
        // `None` means this rank does not receive.
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        // Run the op on the receiving stream when this rank only receives;
        // otherwise on the sending stream (which also covers the
        // same-stream case). Sending and receiving on different streams at
        // once is unsupported.
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
691
    /// Drain and stop all streams, then stop the worker. If `error` is set,
    /// the whole process exits with a non-zero code that distinguishes a
    /// failure of this actor from a failure of a peer; both codes are
    /// overridable via `MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE` /
    /// `MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE` (default 1).
    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        // Stop each stream actor and wait for it to finish. The worker's
        // stream table is expected to hold the only strong reference.
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

        // Exit the worker process if there is an error.
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                // Distinguish "this worker failed" from "a peer failed".
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        // Clean shutdown: stop this actor without killing the process.
        cx.stop()?;
        Ok(())
    }
741
742    async fn send_value(
743        &mut self,
744        cx: &hyperactor::Context<Self>,
745        seq: Seq,
746        destination: Option<Ref>,
747        mutates: Vec<Ref>,
748        function: Option<ResolvableFunction>,
749        args_kwargs: ArgsKwargs,
750        stream: StreamRef,
751    ) -> Result<()> {
752        // Resolve the stream.
753        let stream = self.try_get_stream(stream)?;
754
755        let device_meshes = if function.is_none() {
756            HashMap::new()
757        } else {
758            self.device_meshes
759                .iter()
760                .map(|(k, v)| (k.clone(), v.0.clone()))
761                .collect()
762        };
763
764        if destination.is_some() {
765            panic!("send_value with pipe destination is no longer implemented")
766        }
767
768        // Resolve the value on the stream, then send the value back to the controller.
769        stream
770            .send_value(
771                cx,
772                seq,
773                cx.self_id().clone(),
774                mutates,
775                function,
776                args_kwargs,
777                device_meshes,
778            )
779            .await
780    }
781
782    async fn send_result_of_actor_call(
783        &mut self,
784        cx: &hyperactor::Context<Self>,
785        params: ActorCallParams,
786    ) -> Result<()> {
787        let stream = self.try_get_stream(params.stream)?;
788        stream.send_result_of_actor_call(cx, params).await?;
789        Ok(())
790    }
791    async fn call_actor_method(
792        &mut self,
793        cx: &hyperactor::Context<Self>,
794        params: ActorMethodParams,
795    ) -> Result<()> {
796        let stream = self.try_get_stream(params.call.stream)?;
797        stream.call_actor_method(cx, params).await?;
798        Ok(())
799    }
    /// Split a communicator off the global comm for the `(stream, dims)`
    /// pair of `device_mesh`. Ranks whose worker does not hold the mesh
    /// still participate in the collective `commSplit` call with an empty
    /// rank list.
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                // This rank is in the group to be split off. Split a new
                // communicator for it off from the global communicator.
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                // Sorted so the comm-map key is stable regardless of the
                // order dims were supplied in.
                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not in the group to be split off. We still need to
                // participate in the commSplit call, however.
                global_comm.split_from(cx, vec![]).await?;
            }
        }
        Ok(())
    }
848
849    async fn split_comm_for_process_group(
850        &mut self,
851        cx: &hyperactor::Context<Self>,
852        remote_process_group_ref: Ref,
853        stream_ref: StreamRef,
854    ) -> Result<()> {
855        ensure!(
856            self.streams.contains_key(&stream_ref),
857            "invalid stream id: {:#?}",
858            stream_ref
859        );
860        let global_comm = self
861            .comm
862            .as_ref()
863            .context("tried to call SplitComm before BackendNetworkInit")?;
864        let state = self
865            .remote_process_groups
866            .get_mut(&remote_process_group_ref)
867            .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
868        match self.device_meshes.get_mut(&state.device_mesh_ref) {
869            Some((device_mesh, _)) => {
870                // This rank is in the group to be split off. Split a new
871                // communicator for it off from the global communicator.
872                let entry = match state.comms.entry(stream_ref) {
873                    Entry::Vacant(entry) => entry,
874                    Entry::Occupied(_) => bail!(
875                        "comm already exists for remote process group {:#?} on stream {:#?}",
876                        remote_process_group_ref,
877                        stream_ref,
878                    ),
879                };
880                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
881                let split_comm = global_comm
882                    .split_from(
883                        cx,
884                        ranks_for_group
885                            .into_iter()
886                            .map(|v| v.clone().try_into())
887                            .collect::<Result<Vec<_>, _>>()?,
888                    )
889                    .await?
890                    .context("split comm should include self rank")?;
891                entry.insert(Arc::new(split_comm));
892            }
893            None => {
894                // This rank is not in the group to be split off. We still need to
895                // participate in the commSplit call, however.
896                global_comm.split_from(cx, vec![]).await?;
897            }
898        }
899        Ok(())
900    }
901
    /// Pipe support was removed from this worker; a `PipeRecv` message
    /// reaching it is treated as a fatal error rather than silently ignored.
    /// All parameters are intentionally unused.
    async fn pipe_recv(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _seq: Seq,
        _results: Vec<Option<Ref>>,
        _pipe: Ref,
        _stream: StreamRef,
    ) -> Result<()> {
        panic!("pipe_recv is no longer implemented")
    }
912
913    async fn set_ref_unit_tests_only(
914        &mut self,
915        cx: &hyperactor::Context<Self>,
916        reference: Ref,
917        value: WireValue,
918        stream: StreamRef,
919    ) -> Result<()> {
920        let stream = self.try_get_stream(stream)?;
921
922        stream.set_ref_unit_tests_only(cx, reference, value).await
923    }
924
925    async fn get_ref_unit_tests_only(
926        &mut self,
927        cx: &hyperactor::Context<Self>,
928        ref_id: Ref,
929        stream: StreamRef,
930    ) -> Result<Option<Result<WireValue, String>>> {
931        let stream = self.try_get_stream(stream)?;
932        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
933    }
934
935    async fn define_recording(
936        &mut self,
937        cx: &hyperactor::Context<Self>,
938        result: Ref,
939        _nresults: usize,
940        _nformals: usize,
941        commands: Vec<WorkerMessage>,
942        ntotal_messages: usize,
943        index: usize,
944    ) -> Result<()> {
945        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
946            bail!("already defining a different recording");
947        }
948        self.defining_recording = Some(result);
949
950        match self.recordings.entry(result) {
951            Entry::Vacant(entry) => {
952                ensure!(
953                    index == 0,
954                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
955                    index
956                );
957                entry.insert(Recording::PartialRecording {
958                    last_index: 0,
959                    commands,
960                });
961            }
962            Entry::Occupied(mut entry) => match entry.get_mut() {
963                Recording::CompleteRecording { .. } => {
964                    bail!("got DefineRecording message for already complete recording")
965                }
966                Recording::PartialRecording {
967                    last_index,
968                    commands: existing_commands,
969                } => {
970                    ensure!(
971                        index == *last_index + 1,
972                        "Got DefineRecording message with index = {:?}, but \
973                            last seen index for recording is {:?}",
974                        index,
975                        last_index
976                    );
977                    *last_index = index;
978                    existing_commands.extend(commands);
979                }
980            },
981        };
982
983        if index < ntotal_messages - 1 {
984            return Ok(());
985        }
986        let commands = match self.recordings.remove(&result).unwrap() {
987            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
988            Recording::PartialRecording { commands, .. } => {
989                self.recordings.insert(
990                    result,
991                    Recording::CompleteRecording {
992                        streams: HashSet::new(),
993                    },
994                );
995                commands
996            }
997        };
998
999        for command in commands {
1000            WorkerMessageHandler::handle(self, cx, command).await?;
1001        }
1002
1003        match self.recordings.get(&result).unwrap() {
1004            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
1005            Recording::CompleteRecording { streams, .. } => {
1006                for stream in streams {
1007                    self.try_get_stream(*stream)?
1008                        .finalize_recording(cx, result)
1009                        .await?;
1010                }
1011            }
1012        }
1013
1014        self.defining_recording = None;
1015        Ok(())
1016    }
1017
1018    async fn recording_formal(
1019        &mut self,
1020        cx: &hyperactor::Context<Self>,
1021        result: Ref,
1022        argument_index: usize,
1023        stream: StreamRef,
1024    ) -> Result<()> {
1025        ensure!(self.defining_recording.is_some());
1026        self.maybe_add_stream_to_recording(cx, stream).await?;
1027        self.try_get_stream(stream)?
1028            .recording_formal(cx, result, argument_index)
1029            .await
1030    }
1031
1032    async fn recording_result(
1033        &mut self,
1034        cx: &hyperactor::Context<Self>,
1035        result: Ref,
1036        output_index: usize,
1037        stream: StreamRef,
1038    ) -> Result<()> {
1039        ensure!(self.defining_recording.is_some());
1040        self.maybe_add_stream_to_recording(cx, stream).await?;
1041        self.try_get_stream(stream)?
1042            .recording_result(cx, result, output_index)
1043            .await
1044    }
1045
1046    async fn call_recording(
1047        &mut self,
1048        cx: &hyperactor::Context<Self>,
1049        seq: Seq,
1050        recording: Ref,
1051        results: Vec<Ref>,
1052        actuals: Vec<Ref>,
1053    ) -> Result<()> {
1054        ensure!(self.defining_recording.is_none());
1055        let recording_ref = recording;
1056        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1057            "could not find recording: {:#?}",
1058            recording
1059        ))?;
1060        match recording {
1061            Recording::PartialRecording { .. } => {
1062                bail!("cannot call recording because it is incomplete")
1063            }
1064            Recording::CompleteRecording { streams } => try_join_all(
1065                streams
1066                    .iter()
1067                    .map(|stream| self.try_get_stream(*stream))
1068                    .collect::<Result<Vec<_>>>()?
1069                    .into_iter()
1070                    .map(|stream| {
1071                        stream.call_recording(
1072                            cx,
1073                            seq,
1074                            recording_ref,
1075                            results.clone(),
1076                            actuals.clone(),
1077                        )
1078                    }),
1079            )
1080            .await
1081            .map(|_| ()),
1082        }
1083    }
1084}
1085
1086#[cfg(test)]
1087mod tests {
1088    use std::assert_matches::assert_matches;
1089
1090    use anyhow::Result;
1091    use hyperactor::RemoteSpawn;
1092    use hyperactor::channel::ChannelAddr;
1093    use hyperactor::proc::Proc;
1094    use monarch_messages::controller::ControllerMessage;
1095    use monarch_messages::controller::WorkerError;
1096    use monarch_messages::worker::WorkerMessageClient;
1097    use monarch_types::PickledPyObject;
1098    use pyo3::Python;
1099    use pyo3::prelude::*;
1100    use pyo3::types::PyList;
1101    use pyo3::types::PyString;
1102    use rand::Rng;
1103    use rand::distributions::Alphanumeric;
1104    use timed_test::async_timed_test;
1105
1106    use super::*;
1107    use crate::test_util::test_setup;
1108
    // End-to-end smoke test: create a stream, build ones(2, 3), subtract 1
    // in place, build zeros(2, 3), and verify via `allclose` that the two
    // tensors match. The controller must receive no error responses.
    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                })
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // ref 0 = torch.ones(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 0 -= 1 (in-place op, so ref 0 is listed in `mutates`)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 3 = torch.zeros(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 4 = allclose(ref 0, ref 3) — expected to be true
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        // Fetch the allclose result before shutting the worker down.
        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }
1214
    // A remotely failing call must surface to the controller as a
    // `RemoteFunctionFailed` message carrying the originating seq and the
    // error text.
    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                })
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // aten::rand requires a `size` argument; omitting it
                    // triggers the remote failure under test.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }
1274
    // When a call that `mutates` a ref fails, the mutated ref must be
    // poisoned with the error so later reads observe the failure instead of
    // a stale value.
    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                })
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Seed ref 0 with a value so there is something to poison.
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    // Unresolvable function that claims to mutate ref 0.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        // Stop/drain worker before asserts to avoid hangs.
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        // Ref 0 should now hold the resolution error, not Int(1).
        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }
1345
    // A failed call poisons its result ref; a downstream call consuming that
    // ref should NOT produce a second controller error — only the root cause
    // (seq 0) is reported.
    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                })
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Fails to resolve; poisons ref 0.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Consumes the poisoned ref 0; must not report again.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into())],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        // Exactly one failure report, for the root-cause seq 0.
        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }
1421
    // Exercises non-aten ("py remote") function calls: stdlib/builtins
    // functions, pickled Python arguments (strings, lists, None, torch
    // layouts/dtypes/devices/memory formats), multi-result unpacking, and
    // ref-list arguments.
    #[async_timed_test(timeout_secs = 60)]
    async fn py_remote_function_calls() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                })
                .await
                .unwrap(),
            )
            .unwrap();
        // Build the pickled Python argument values up front under the GIL.
        let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // (ref 0, ref 2) = os.path.split("/fbs/fbc/foo/bar")
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![split_arg.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 4 = sorted(...)[0]; the trailing `None`s discard
                    // the remaining list elements.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![sort_list.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    // NOTE(review): `seq: 2` is reused here (also used by the
                    // `builtins.sorted` call above) — presumably a copy-paste
                    // oversight; confirm whether seqs must be unique.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 7: round-trip a torch dtype argument.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![scalar.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 8: round-trip a torch layout argument.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![layout.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 9: round-trip a pickled `None` argument.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![none.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Verify that a function that returns `None` matches up with an
                    // empty result list.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 10: round-trip a torch.device argument.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![device.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 11: round-trip a torch memory-format argument.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![memory_format.into()],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // Test that a list of refs can be passed correctly.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 13 = stack([ref 12, ref 12])
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::RefList(vec![12.into(), 12.into()])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        // os.path.split head / tail.
        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // First element of the sorted list.
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // mesh_rank along dim "x".
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        // The remaining round-trip refs only need to exist without error.
        worker_handle
            .get_ref_unit_tests_only(&client, 7.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap();

        worker_handle
            .get_ref_unit_tests_only(&client, 8.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap();

        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        worker_handle
            .get_ref_unit_tests_only(&client, 10.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        worker_handle
            .get_ref_unit_tests_only(&client, 11.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }
1720
1721    #[async_timed_test(timeout_secs = 60)]
1722    async fn delete_refs() -> Result<()> {
1723        test_setup()?;
1724
1725        let proc = Proc::local();
1726        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1727
1728        let worker_handle = proc
1729            .spawn(
1730                "worker",
1731                WorkerActor::new(WorkerParams {
1732                    world_size: 1,
1733                    rank: 0,
1734                    device_index: None,
1735                    controller_actor: controller_ref,
1736                })
1737                .await
1738                .unwrap(),
1739            )
1740            .unwrap();
1741        worker_handle
1742            .command_group(
1743                &client,
1744                vec![
1745                    WorkerMessage::CreateStream {
1746                        id: 0.into(),
1747                        stream_creation: StreamCreationMode::CreateNewStream,
1748                    },
1749                    WorkerMessage::CreateStream {
1750                        id: 1.into(),
1751                        stream_creation: StreamCreationMode::CreateNewStream,
1752                    },
1753                    WorkerMessage::SetRefUnitTestsOnly {
1754                        reference: Ref { id: 2 },
1755                        value: WireValue::Bool(false),
1756                        stream: 0.into(),
1757                    },
1758                    WorkerMessage::SetRefUnitTestsOnly {
1759                        reference: Ref { id: 3 },
1760                        value: WireValue::Bool(true),
1761                        stream: 0.into(),
1762                    },
1763                    WorkerMessage::SetRefUnitTestsOnly {
1764                        reference: Ref { id: 4 },
1765                        value: WireValue::Int(0),
1766                        stream: 1.into(),
1767                    },
1768                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1769                ],
1770            )
1771            .await
1772            .unwrap();
1773
1774        let result: bool = worker_handle
1775            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1776            .await
1777            .unwrap()
1778            .unwrap()
1779            .unwrap()
1780            .try_into()
1781            .unwrap();
1782        let fail_result = worker_handle
1783            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1784            .await
1785            .unwrap();
1786
1787        worker_handle.drain_and_stop().unwrap();
1788        worker_handle.await;
1789
1790        assert!(result, "should be able to get a non-deleted ref");
1791        assert!(fail_result.is_none(), "should fail to get a deleted ref");
1792
1793        Ok(())
1794    }
1795
1796    #[async_timed_test(timeout_secs = 60)]
1797    async fn request_status() -> Result<()> {
1798        test_setup()?;
1799
1800        let proc = Proc::local();
1801        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1802
1803        let worker_handle = proc
1804            .spawn(
1805                "worker",
1806                WorkerActor::new(WorkerParams {
1807                    world_size: 1,
1808                    rank: 0,
1809                    device_index: None,
1810                    controller_actor: controller_ref,
1811                })
1812                .await
1813                .unwrap(),
1814            )
1815            .unwrap();
1816        worker_handle
1817            .command_group(
1818                &client,
1819                vec![
1820                    WorkerMessage::CreateStream {
1821                        id: 0.into(),
1822                        stream_creation: StreamCreationMode::CreateNewStream,
1823                    },
1824                    WorkerMessage::CreateStream {
1825                        id: 1.into(),
1826                        stream_creation: StreamCreationMode::CreateNewStream,
1827                    },
1828                ],
1829            )
1830            .await
1831            .unwrap();
1832
1833        for i in 0..100 {
1834            // call alternating functions on this stream.
1835            worker_handle
1836                .call_function(
1837                    &client,
1838                    CallFunctionParams {
1839                        seq: i.into(),
1840                        results: vec![Some(Ref { id: i + 2 })],
1841                        mutates: vec![],
1842                        function: "torch.ops.aten.ones.default".into(),
1843                        args_kwargs: ArgsKwargs::from_wire_values(
1844                            vec![WireValue::IntList(vec![2, 3])],
1845                            HashMap::new(),
1846                        )
1847                        .unwrap(),
1848                        stream: (i % 2).into(),
1849                        remote_process_groups: vec![],
1850                    },
1851                )
1852                .await
1853                .unwrap();
1854        }
1855
1856        worker_handle
1857            .request_status(&client, 100.into(), false)
1858            .await
1859            .unwrap();
1860
1861        worker_handle.drain_and_stop().unwrap();
1862        worker_handle.await;
1863
1864        let mut responses = controller_rx.drain();
1865        assert_eq!(
1866            responses.len(),
1867            1,
1868            "Expected one response, got: {:#?}",
1869            responses
1870        );
1871
1872        let response = responses.pop().unwrap();
1873        match response {
1874            ControllerMessage::Status { seq, .. } => {
1875                assert_eq!(seq, 101.into())
1876            }
1877            _ => panic!("unexpected response {:#?}", response),
1878        };
1879
1880        Ok(())
1881    }
1882
1883    #[async_timed_test(timeout_secs = 60)]
1884    async fn backend_network_init() {
1885        test_setup().unwrap();
1886        let proc = Proc::local();
1887        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1888
1889        let worker_handle1 = proc
1890            .spawn(
1891                "worker0",
1892                WorkerActor::new(WorkerParams {
1893                    world_size: 2,
1894                    rank: 0,
1895                    device_index: Some(0),
1896                    controller_actor: controller_ref.clone(),
1897                })
1898                .await
1899                .unwrap(),
1900            )
1901            .unwrap();
1902        let worker_handle2 = proc
1903            .spawn(
1904                "worker1",
1905                WorkerActor::new(WorkerParams {
1906                    world_size: 2,
1907                    rank: 1,
1908                    device_index: Some(1),
1909                    controller_actor: controller_ref,
1910                })
1911                .await
1912                .unwrap(),
1913            )
1914            .unwrap();
1915
1916        let unique_id = UniqueId::new().unwrap();
1917        worker_handle1
1918            .backend_network_init(&client, unique_id.clone())
1919            .await
1920            .unwrap();
1921        worker_handle2
1922            .backend_network_init(&client, unique_id)
1923            .await
1924            .unwrap();
1925
1926        worker_handle1.drain_and_stop().unwrap();
1927        worker_handle1.await;
1928        worker_handle2.drain_and_stop().unwrap();
1929        worker_handle2.await;
1930    }
1931
1932    #[allow(dead_code)]
1933    fn get_random_channel_addr() -> ChannelAddr {
1934        let random_string = rand::thread_rng()
1935            .sample_iter(&Alphanumeric)
1936            .take(24)
1937            .map(char::from)
1938            .collect::<String>();
1939        format!("unix!@{random_string}").parse().unwrap()
1940    }
1941}