monarch_messages/worker.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` triggers
// an unsafe-op-in-unsafe-fn warning.
#![allow(unsafe_op_in_unsafe_fn)]

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use anyhow::Context;
use derive_more::Display;
use derive_more::From;
use derive_more::TryInto;
use enum_as_inner::EnumAsInner;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::RefClient;
use hyperactor::Unbind;
use hyperactor::reference::ActorId;
use monarch_types::SerializablePyErr;
use ndslice::Slice;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyDict;
use pyo3::types::PyTuple;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use torch_sys::BorrowError;
use torch_sys::Device;
use torch_sys::Layout;
use torch_sys::ScalarType;
use torch_sys::call_op::CallOpError;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

use crate::controller::ControllerActor;
use crate::controller::Seq;
use crate::wire_value::WireValue;

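/// A reference to a stream on the worker. A minimal Python-side usage sketch
/// (assuming the extension module is importable under the path below):
///
/// ```text
/// from monarch._rust_bindings.monarch_extension.tensor_worker import StreamRef
/// s = StreamRef(id=1)  # keyword-only constructor
/// assert s == StreamRef(id=1)
/// ```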
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    #[pyo3(get)]
    pub id: u64,
}

#[pyo3::pymethods]
impl StreamRef {
    #[new]
    #[pyo3(signature = (*, id))]
    fn new(id: u64) -> Self {
        Self { id }
    }

    fn __repr__(&self) -> String {
        format!("StreamRef({})", self.id)
    }

    // TODO: Upgrade pyo3 to use eq, ord on pyclass
    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
        Ok(match op {
            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
        })
    }

    fn __hash__(&self) -> PyResult<u64> {
        Ok(self.id)
    }
}

// TODO: The Python implementation uses `Ref` to describe any worker value that
// can be referenced by the controller, including: tensors, streams, pipes,
// device meshes. We might be able to more explicitly type these, as they are
// not generally interchangeable.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    #[pyo3(get)]
    pub id: u64,
}

#[pyo3::pymethods]
impl Ref {
    #[new]
    fn new(id: u64) -> Self {
        Self { id }
    }

    #[getter]
    fn r#ref(&self) -> u64 {
        self.id
    }

    fn __repr__(&self) -> String {
        format!("Ref({})", self.id)
    }

    // TODO: Upgrade pyo3 to use eq, ord on pyclass
    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
        Ok(match op {
            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
        })
    }

    fn __hash__(&self) -> PyResult<u64> {
        Ok(self.id)
    }

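    /// Pickling support: returns `((), {"id": self.id})`, so a `Ref`
    /// round-trips through `pickle.dumps`/`pickle.loads` via the constructor
    /// above.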
    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        let kwargs = PyDict::new(py);
        kwargs.set_item("id", self.id).unwrap();

        PyTuple::new(
            py,
            vec![
                PyTuple::empty(py).unbind().into_any(),
                kwargs.unbind().into_any(),
            ],
        )
    }
}

impl Ref {
    // This is a function on ref instead of impl FromPyObject due to a bug in pyo3
    // https://github.com/PyO3/pyo3/issues/4337
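    /// Conversion sketch: accepts either an actual `Ref` or any object that
    /// exposes a `__monarch_ref__()` method returning an integer id. The
    /// Python class below is hypothetical, for illustration only:
    ///
    /// ```text
    /// class Handle:
    ///     def __monarch_ref__(self) -> int:
    ///         return 42
    /// # from_py_object(Handle()) yields Ref { id: 42 }
    /// ```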
    pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
        if let Ok(ref_) = obj.extract::<Ref>() {
            return Ok(ref_);
        }
        if let Ok(func) = obj.getattr(attr_name) {
            if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
                return Ok(val.into());
            }
        }
        Err(PyValueError::new_err("Could not convert object to Ref"))
    }
}

impl fmt::Display for Ref {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "r{}", self.id)
    }
}

/// Identifies a CallFunction target. Can either be a torch op or a Python
/// global reference.
// TODO: do some validation on the namespace/opname/overload
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    #[pyo3(get)]
    pub path: String,
}

impl fmt::Display for FunctionPath {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "<function \"{}\">", self.path)
    }
}

impl<T: Into<String>> From<T> for FunctionPath {
    fn from(val: T) -> Self {
        Self { path: val.into() }
    }
}

#[pyo3::pymethods]
impl FunctionPath {
    #[new]
    #[pyo3(signature = (*, path))]
    pub fn new(path: String) -> Self {
        Self { path }
    }

    fn __repr__(&self) -> String {
        self.path.clone()
    }

    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let (start, rest) = self.path.split_once(".").with_context(|| {
            format!(
                "invalid function path {}: paths must be fully qualified",
                self.path
            )
        })?;
        if start == "torch" {
            let mut cur = py.import("torch")?.into_any();
            for p in rest.split(".") {
                cur = cur.getattr(p)?;
            }
            Ok(cur)
        } else {
            let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
                format!(
                    "invalid function path {}: paths must be fully qualified",
                    self.path
                )
            })?;
            let module = PyModule::import(py, module_fqn)?;
            let mut function = module.getattr(function_name)?;
            if function.hasattr("_remote_impl")? {
                function = function.getattr("_remote_impl")?;
            }
            Ok(function.downcast_into()?)
        }
    }
}
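
// Resolution sketch (illustrative paths; actual availability depends on the
// caller's Python environment):
//
//   "torch.ops.aten.add.default"  ->  attribute walk from the torch module
//   "my_pkg.my_mod.my_fn"         ->  import my_pkg.my_mod, then getattr,
//                                     preferring my_fn._remote_impl if present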

/// Identifies a CallFunction target supplied as a cloudpickled Python
/// callable; it is resolved by unpickling `bytes` with `cloudpickle.loads`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}

impl fmt::Display for Cloudpickle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "<cloud-pickle>")
    }
}

#[pyo3::pymethods]
impl Cloudpickle {
    #[new]
    #[pyo3(signature = (*, bytes))]
    pub fn new(bytes: Vec<u8>) -> Self {
        Self { bytes }
    }

    fn __repr__(&self) -> String {
        format!("Cloudpickle(bytes={:?})", self.bytes)
    }

    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let module = PyModule::import(py, "cloudpickle")?;
        let loads = module.getattr("loads")?;
        loads.call1((PyBytes::new(py, &self.bytes),))
    }
}

#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}

impl<'py> IntoPyObject<'py> for ResolvableFunction {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        Ok(match self {
            Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
            Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
        })
    }
}

impl ResolvableFunction {
    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        match self {
            Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
            Self::FunctionPath(func) => func.resolve(py),
        }
    }

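    /// Maps a `torch.ops.*` function path to an `(op, overload)` pair. A
    /// small sketch of the expected mapping (the second path is illustrative):
    ///
    /// ```ignore
    /// let f: ResolvableFunction = "torch.ops.aten.add.Tensor".into();
    /// assert_eq!(f.as_torch_op(), Some(("aten::add".into(), "Tensor".into())));
    /// let g: ResolvableFunction = "my_mod.my_fn".into();
    /// assert_eq!(g.as_torch_op(), None);
    /// ```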
    pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
        match self {
            Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
                ["torch", "ops", namespace, op_name, "default"] => {
                    Some((format!("{}::{}", namespace, op_name), String::new()))
                }
                ["torch", "ops", namespace, op_name, overload] => {
                    Some((format!("{}::{}", namespace, op_name), overload.to_string()))
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// For testing: this is a special remote function path that induces a panic
    /// when called.
    pub fn panic_if_requested(&self) {
        match self {
            Self::FunctionPath(func) => {
                if func.path == "__test_panic" {
                    panic!("__test_panic called");
                }
            }
            _ => (),
        }
    }

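    /// Whether arguments may be passed as Python pytrees: always true for
    /// cloudpickled callables, and true for function paths only when they do
    /// not name a torch op (torch ops take flat wire values).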
    pub fn supports_pytree_args(&self) -> bool {
        match self {
            Self::Cloudpickle(_) => true,
            Self::FunctionPath(_) => self.as_torch_op().is_none(),
        }
    }
}

impl<T: Into<String>> From<T> for ResolvableFunction {
    fn from(val: T) -> Self {
        FunctionPath::from(val).into()
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    /// The references of the results to set.
    pub results: Vec<Option<Ref>>,
    /// The references of the values the call mutates.
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// The arguments to the function.
    pub args: Vec<WireValue>,
    /// The keyword arguments to the function.
    pub kwargs: HashMap<String, WireValue>,
    /// The stream to call the function on.
    pub stream: StreamRef,
    /// The process groups to execute the function on.
    pub remote_process_groups: Vec<Ref>,
}
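
// Construction sketch (illustrative values; `seq: 0.into()` assumes `Seq`
// converts from an integer):
//
//   CallFunctionParams {
//       seq: 0.into(),
//       results: vec![Some(Ref { id: 10 })],
//       mutates: vec![],
//       function: "torch.ops.aten.ones.default".into(),
//       args: vec![],
//       kwargs: HashMap::new(),
//       stream: StreamRef { id: 1 },
//       remote_process_groups: vec![],
//   }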

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    pub seq: Seq,
    /// The `BrokerId`, kept as a raw tuple so that messages do not depend on
    /// the hyperactor type directly.
    pub broker_id: (String, usize),
    /// Referenceable objects to pass to the actor as `LocalState`;
    /// these will be put into the `PythonMessage` during its unpickling.
    pub local_state: Vec<Ref>,
    /// Tensors that will be mutated by the call.
    pub mutates: Vec<Ref>,
    pub stream: StreamRef,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    pub results: Vec<Option<Ref>>,
    pub call: ActorCallParams,
}

/// Type of reduction for [`WorkerMessage::Reduce`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// A gather, concatenating the values along the reduction dimension.
    Stack,
    /// An NCCL reduction type.
    ReduceOp(ReduceOp),
}
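
// For example (assuming the ReduceOp wrapper exposes a Sum variant, as NCCL does):
//   Reduction::ReduceOp(ReduceOp::Sum)  // all-reduce style sum
//   Reduction::Stack                    // gather/concatenate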
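
/// Metadata sufficient to construct a fresh tensor. A Python-side sketch
/// (assuming `torch` is importable; values are illustrative):
///
/// ```text
/// TensorFactory(size=[2, 3], dtype=torch.float32,
///               layout=torch.strided, device=torch.device("cpu"))
/// ```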
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    pub size: Vec<i64>,
    #[serde(with = "torch_sys::ScalarTypeDef")]
    pub dtype: ScalarType,
    #[serde(with = "torch_sys::LayoutDef")]
    pub layout: Layout,
    pub device: Device,
}

#[pyo3::pymethods]
impl Factory {
    #[new]
    #[pyo3(signature = (*, size, dtype, layout, device))]
    pub fn new(
        py: Python<'_>,
        size: Vec<i64>,
        dtype: PyObject,
        layout: PyObject,
        device: PyObject,
    ) -> PyResult<Self> {
        // TODO: Add some validation around dtype / layout. We should have pyre types on
        // the python side to help in the short term.
        Ok(Self {
            size,
            dtype: dtype.extract::<ScalarType>(py)?,
            layout: layout.extract::<Layout>(py)?,
            device: device.extract::<Device>(py)?,
        })
    }

    #[staticmethod]
    pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
        Self::new(
            obj.py(),
            obj.getattr("size")?.extract()?,
            obj.getattr("dtype")?.unbind(),
            obj.getattr("layout")?.unbind(),
            obj.getattr("device")?.unbind(),
        )
    }

    #[getter]
    fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
        PyTuple::new(py, self.size.iter())
    }

    #[getter]
    fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.dtype.into_pyobject(py)
    }

    #[getter]
    fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.layout.into_pyobject(py)
    }

    #[getter]
    fn device(&self) -> String {
        self.device.to_string()
    }
}

/// Controls what CUDA stream an actor will use.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Use the default stream for the current device.
    UseDefaultStream,
    /// Create a new stream for this actor.
    CreateNewStream,
}

/// An error associated with a seq number that failed to execute.
/// Any defined value that carries an error will have an associated
/// SeqError that is the root cause of why that value is in error.
/// A value may have this error because it was directly defined by the
/// action associated with the sequence number, or because it was defined by
/// another action that depended on the failing one.
#[derive(Debug, Named)]
pub struct SeqError {
    pub seq: Seq,
    pub error: anyhow::Error,
}

impl fmt::Display for SeqError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.error.fmt(f)
    }
}

/// When a worker runs any function, it may not succeed either because the
/// function itself failed (Error) or because an input to the function already
/// carried an error value (DependentError).
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
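
// Propagation sketch (illustrative; assumes `Seq` converts from an integer):
//
//   let root = Arc::new(SeqError { seq: 7.into(), error: anyhow::anyhow!("boom") });
//   let err = CallFunctionError::DependentError(Arc::clone(&root));
//   // Later consumers of seq 7's outputs surface the same root cause.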

impl CallFunctionError {
    // Static functions for backward compatibility with existing enum cases
    #[allow(non_snake_case)]
    pub fn RefNotFound(r: Ref) -> Self {
        Self::Error(anyhow::anyhow!("ref not found: {}", r))
    }

    #[allow(non_snake_case)]
    pub fn InvalidRemoteFunction(msg: String) -> Self {
        Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
    }

    #[allow(non_snake_case)]
    pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
        Self::Error(anyhow::anyhow!(
            "unsupported arg type for {} function: {}",
            function_type,
            arg_type
        ))
    }

    #[allow(non_snake_case)]
    pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
        Self::Error(anyhow::anyhow!("remote function failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn BorrowError(err: BorrowError) -> Self {
        Self::Error(anyhow::anyhow!("borrow failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn OperatorFailed(err: CallOpError) -> Self {
        Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
        Self::Error(anyhow::anyhow!(
            "unexpected number of returns from op, expected {}, got {}",
            expected,
            actual
        ))
    }

    #[allow(non_snake_case)]
    pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
        Self::Error(anyhow::anyhow!(
            "expected only a single arg (and no kwargs) when no function is given: {}, {}",
            args,
            kwargs
        ))
    }

    #[allow(non_snake_case)]
    pub fn Anyhow(err: anyhow::Error) -> Self {
        Self::Error(err)
    }
}

impl From<SerializablePyErr> for CallFunctionError {
    fn from(v: SerializablePyErr) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}

impl From<BorrowError> for CallFunctionError {
    fn from(v: BorrowError) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}

impl From<CallOpError> for CallFunctionError {
    fn from(v: CallOpError) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}

/// Worker messages. These define the observable behavior of the worker, so
/// the documentation here describes that externally visible behavior.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initialize backend network state.
    BackendNetworkInit(UniqueId),

    /// Initialize backend network state for point-to-point communication.
    BackendNetworkPointToPointInit {
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Call a function, either a torch op or a Python `remote_function`.
    CallFunction(CallFunctionParams),

    /// Groups commands together; these commands will be executed in order by
    /// the worker.
    CommandGroup(Vec<WorkerMessage>),

    /// Create a [`Stream`] on the worker with the provided id. Commands will
    /// generally be scheduled onto streams to run; different streams can
    /// execute concurrently with one another.
    CreateStream {
        /// Id of the stream to create.
        id: StreamRef,
        /// Whether to use the default device stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a [`DeviceMesh`] on the worker, which can be used to schedule
    /// efficient inter-worker communication.
    CreateDeviceMesh {
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    },

    /// Create a PyTorch distributed process group on the worker, which can be
    /// used to schedule collectives in UDFs using monarch communicators.
    CreateRemoteProcessGroup {
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    },

    /// Create a borrow of a tensor from one stream to another.
    ///
    /// Borrows allow streams to access tensors on another stream. The runtime
    /// will insert appropriate synchronization to ensure that cross-stream
    /// usage is safe.
    BorrowCreate {
        /// Ref of the resulting borrowed tensor.
        result: Ref,
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Stream to borrow from.
        from_stream: StreamRef,
        /// Stream to borrow to.
        to_stream: StreamRef,
    },

    /// First use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowFirstUse {
        borrow: u64,
    },

    /// Last use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowLastUse {
        borrow: u64,
    },

    /// Drop the borrow and free the resources associated with it.
    BorrowDrop {
        borrow: u64,
    },

    /// Delete these refs from the worker state.
    DeleteRefs(Vec<Ref>),

    /// A [`ControllerMessage::Status`] will be sent to the controller
    /// when all streams have processed all messages sent before this one.
    RequestStatus {
        seq: Seq,
        controller: bool,
    },

    /// Perform a reduction operation, using an efficient communication backend.
    /// Only NCCL is supported for now.
    Reduce {
        /// Where to store the result of the reduction.
        result: Ref,
        /// The tensor to reduce.
        tensor: Ref,
        /// Tensor metadata for `tensor` that can be used to construct a
        /// fresh tensor of appropriate size/shape. We use this if
        /// `tensor` isn't accessible for some reason (like a previous
        /// error on the worker).
        factory: Factory,
        /// The device mesh on which to perform the reduction.
        mesh: Ref,
        /// The stream to call the reduction on.
        stream: StreamRef,
        /// The dimensions of the device mesh to reduce over. The members of
        /// these dimensions will form the members of the reduction collective.
        dims: Vec<String>,
        /// What kind of reduction to perform.
        reduction: Reduction,
        /// If `true`, the reduced result will be evenly split across the tensors
        /// of `dims`.
        scatter: bool,
        /// If `true`, the reduction will be performed in-place on `tensor`.
        in_place: bool,
        /// Pre-existing tensor that should be used as the output for the reduction.
        out: Option<Ref>,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitComm {
        /// The device mesh dimensions along which the constructed communicator
        /// should be able to exchange data.
        dims: Vec<String>,
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        device_mesh: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
        /// Configuration for the new communicator. If None, we will not pass a
        /// config object to nccl, which means that the created communicator
        /// will inherit its parent's config.
        config: Option<NcclConfig>,
    },

    /// Create a new communicator on each rank in the given process group,
    /// capable of communicating with its peers.
    SplitCommForProcessGroup {
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        remote_process_group: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
        /// Configuration for the new communicator. If None, we will not pass a
        /// config object to nccl, which means that the created communicator
        /// will inherit its parent's config.
        config: Option<NcclConfig>,
    },

    SendTensor {
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    CreatePipe {
        result: Ref,
        key: String,
        function: ResolvableFunction,
        max_messages: i64,
        mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    },

    SendValue {
        seq: Seq,
        /// Pipe to send value to.  If `None`, value is sent to controller.
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        /// Function to resolve the value to retrieve.  If `None`, then `args`
        /// must contain the value as its only element and `kwargs` must be
        /// empty.
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    SendResultOfActorCall(ActorCallParams),
    CallActorMethod(ActorMethodParams),
    PipeRecv {
        seq: Seq,
        /// Result refs.
        results: Vec<Option<Ref>>,
        /// Pipe to receive value from.
        pipe: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Finish processing all messages previously sent to this worker and stop
    /// the actor loop. Any streams will also be drained.
    Exit {
        /// Optional error reason if the exit is the result of an error, including:
        /// - an optional actor id indicating the source of the error
        /// - an error message or stacktrace
        ///
        /// The worker process will be stopped if an error is provided.
        error: Option<(Option<ActorId>, String)>,
    },

    /// Defines (part of) a new recording on the worker. This is a list of commands
    /// representing the execution of a function that was defined using
    /// monarch.compile. If there are too many commands to send in a single
    /// DefineRecording message, the commands may be chunked into `ntotal_messages`,
    /// with the `index` field indicating how to order the DefineRecording messages
    /// for a single recording. A chunking sketch follows this variant.
    DefineRecording {
        /// The ref associated with this recording that will be used to
        /// call it in the future.
        result: Ref,
        /// The number of output tensors.
        nresults: usize,
        /// The number of input tensors.
        nformals: usize,
        /// The list of commands to run.
        commands: Vec<WorkerMessage>,
        /// How many total DefineRecording messages make up this recording.
        ntotal_messages: usize,
        /// This DefineRecording message's index in the set of messages
        /// that make up this recording.
        index: usize,
    },
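
    // Chunking sketch: a recording split into three batches arrives as three
    // DefineRecording messages that share `result`, each with
    // ntotal_messages = 3 and index = 0, 1, 2; the worker assembles the
    // recording by ordering the `commands` chunks by index.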

    /// Defines an input tensor for a recording.
    RecordingFormal {
        /// The ref that will be used to pass the input tensor to the
        /// recording.
        result: Ref,
        /// The index of the input tensor in the list of input tensors.
        argument_index: usize,
        /// The stream that this input tensor will be used on.
        stream: StreamRef,
    },

    /// Defines an output tensor for a recording.
    RecordingResult {
        /// The ref that will be used to store the output tensor.
        result: Ref,
        /// The index of the output tensor in the list of output tensors.
        output_index: usize,
        /// The stream that this output tensor will come from.
        stream: StreamRef,
    },

    /// Calls a recording that was previously defined using
    /// DefineRecording.
    CallRecording {
        /// The sequence number of the invocation.
        seq: Seq,
        /// The ref of the recording to call.
        recording: Ref,
        /// The list of refs where the result tensors from the recording
        /// will be stored.
        results: Vec<Ref>,
        /// The list of refs of input tensors to the recording.
        actuals: Vec<Ref>,
    },

    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to set it with.
        value: WireValue,
        /// The stream to set it on.
        stream: StreamRef,
    },

    GetRefUnitTestsOnly {
        /// The value to retrieve, expected to be a bool.
        value: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
        #[reply]
        response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
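
// Construction sketch (illustrative field values):
//
//   let msg = WorkerMessage::CreateStream {
//       id: StreamRef { id: 1 },
//       stream_creation: StreamCreationMode::CreateNewStream,
//   };
//
// Messages can also be batched so the worker executes them in order:
//
//   let group = WorkerMessage::CommandGroup(vec![msg]);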

/// The parameters to spawn a worker actor.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Global world size for this job.
    pub world_size: usize,

    /// Rank of the worker within the global world.
    pub rank: usize,

    /// Local cuda device that this worker represents. If None, we won't do
    /// CUDA synchronization.
    pub device_index: Option<i8>,

    /// Actor ref for the controller that the worker is associated with.
    pub controller_actor: ActorRef<ControllerActor>,
}

hyperactor::alias!(
    WorkerActor,
    WorkerMessage { cast = true },
);