monarch_messages/worker.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// `unsafe_op_in_unsafe_fn` warnings.
#![allow(unsafe_op_in_unsafe_fn)]

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use anyhow::Context;
use derive_more::Display;
use derive_more::From;
use derive_more::TryInto;
use enum_as_inner::EnumAsInner;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::RefClient;
use hyperactor::Unbind;
use hyperactor::reference::ActorId;
use monarch_types::SerializablePyErr;
use ndslice::Slice;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyDict;
use pyo3::types::PyTuple;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use torch_sys::BorrowError;
use torch_sys::Device;
use torch_sys::Layout;
use torch_sys::ScalarType;
use torch_sys::call_op::CallOpError;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

use crate::controller::ControllerActor;
use crate::controller::Seq;
use crate::wire_value::WireValue;

#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    #[pyo3(get)]
    pub id: u64,
}

#[pyo3::pymethods]
impl StreamRef {
    #[new]
    #[pyo3(signature = (*, id))]
    fn new(id: u64) -> Self {
        Self { id }
    }

    fn __repr__(&self) -> String {
        format!("StreamRef({})", self.id)
    }

    // TODO: Upgrade pyo3 to use eq, ord on pyclass
    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
        Ok(match op {
            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
        })
    }

    fn __hash__(&self) -> PyResult<u64> {
        Ok(self.id)
    }
}

// TODO: The Python implementation uses `Ref` to describe any worker value that
// can be referenced by the controller, including: tensors, streams, pipes,
// device meshes. We might be able to more explicitly type these, as they are
// not generally interchangeable.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    #[pyo3(get)]
    pub id: u64,
}

#[pyo3::pymethods]
impl Ref {
    #[new]
    fn new(id: u64) -> Self {
        Self { id }
    }

    #[getter]
    fn r#ref(&self) -> u64 {
        self.id
    }

    fn __repr__(&self) -> String {
        format!("Ref({})", self.id)
    }

    // TODO: Upgrade pyo3 to use eq, ord on pyclass
    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
        Ok(match op {
            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
        })
    }

    fn __hash__(&self) -> PyResult<u64> {
        Ok(self.id)
    }

    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        let kwargs = PyDict::new(py);
        kwargs.set_item("id", self.id).unwrap();

        PyTuple::new(
            py,
            vec![
                PyTuple::empty(py).unbind().into_any(),
                kwargs.unbind().into_any(),
            ],
        )
    }
}

impl Ref {
    // This is a function on ref instead of impl FromPyObject due to a bug in pyo3
    // https://github.com/PyO3/pyo3/issues/4337
    pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
        if let Ok(ref_) = obj.extract::<Ref>() {
            return Ok(ref_);
        }
        if let Ok(func) = obj.getattr(attr_name) {
            if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
                return Ok(val.into());
            }
        }
        Err(PyValueError::new_err("Could not convert object to Ref"))
    }
}

impl fmt::Display for Ref {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "r{}", self.id)
    }
}
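
// Illustrative addition (not part of the original file): a small test
// sketching `Ref`'s `Display` rendering and derived ordering; the module
// name and id values are hypothetical.
#[cfg(test)]
mod ref_display_tests {
    use super::Ref;

    #[test]
    fn display_and_ordering() {
        // `Display` renders a ref as `r<id>`.
        assert_eq!(Ref { id: 1 }.to_string(), "r1");
        // Ordering is derived from the underlying id.
        assert!(Ref { id: 1 } < Ref { id: 2 });
    }
}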

/// Identifies a CallFunction target. Can either be a torch op or a Python
/// global reference.
// TODO: do some validation on the namespace/opname/overload
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    #[pyo3(get)]
    pub path: String,
}

impl fmt::Display for FunctionPath {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "<function \"{}\">", self.path)
    }
}

impl<T: Into<String>> From<T> for FunctionPath {
    fn from(val: T) -> Self {
        Self { path: val.into() }
    }
}

#[pyo3::pymethods]
impl FunctionPath {
    #[new]
    #[pyo3(signature = (*, path))]
    pub fn new(path: String) -> Self {
        Self { path }
    }

    fn __repr__(&self) -> String {
        self.path.clone()
    }

    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let (start, rest) = self.path.split_once(".").with_context(|| {
            format!(
                "invalid function path {}: paths must be fully qualified",
                self.path
            )
        })?;
        if start == "torch" {
            let mut cur = py.import("torch")?.into_any();
            for p in rest.split(".") {
                cur = cur.getattr(p)?;
            }
            Ok(cur)
        } else {
            let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
                format!(
                    "invalid function path {}: paths must be fully qualified",
                    self.path
                )
            })?;
            let module = PyModule::import(py, module_fqn)?;
            let mut function = module.getattr(function_name)?;
            if function.hasattr("_remote_impl")? {
                function = function.getattr("_remote_impl")?;
            }
            Ok(function.downcast_into()?)
        }
    }
}
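
// Illustrative addition (hypothetical test; assumes an embedded Python
// interpreter is available at test time, e.g. via pyo3's `auto-initialize`
// feature): resolving a fully qualified non-torch path imports the module
// and looks up the named attribute.
#[cfg(test)]
mod function_path_tests {
    use pyo3::prelude::*;

    use super::FunctionPath;

    #[test]
    fn resolve_builtin() -> PyResult<()> {
        Python::with_gil(|py| {
            // "builtins.len" splits into module "builtins" and attr "len".
            let func = FunctionPath::from("builtins.len").resolve(py)?;
            let n: usize = func.call1(("abc",))?.extract()?;
            assert_eq!(n, 3);
            Ok(())
        })
    }
}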

/// Identifies a CallFunction target provided as a cloudpickled Python
/// callable (as opposed to a dotted function path).
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}

impl fmt::Display for Cloudpickle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "<cloud-pickle>")
    }
}

#[pyo3::pymethods]
impl Cloudpickle {
    #[new]
    #[pyo3(signature = (*, bytes))]
    pub fn new(bytes: Vec<u8>) -> Self {
        Self { bytes }
    }

    fn __repr__(&self) -> String {
        format!("Cloudpickle(bytes={:?})", self.bytes)
    }

    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let module = PyModule::import(py, "cloudpickle")?;
        let loads = module.getattr("loads")?;
        loads.call1((PyBytes::new(py, &self.bytes),))
    }
}

#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}

impl<'py> IntoPyObject<'py> for ResolvableFunction {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        Ok(match self {
            Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
            Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
        })
    }
}

impl ResolvableFunction {
    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        match self {
            Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
            Self::FunctionPath(func) => func.resolve(py),
        }
    }

    pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
        match self {
            Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
                ["torch", "ops", namespace, op_name, "default"] => {
                    Some((format!("{}::{}", namespace, op_name), String::new()))
                }
                ["torch", "ops", namespace, op_name, overload] => {
                    Some((format!("{}::{}", namespace, op_name), overload.to_string()))
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// For testing: this is a special remote function path that induces a panic
    /// when called.
    pub fn panic_if_requested(&self) {
        match self {
            Self::FunctionPath(func) => {
                if func.path == "__test_panic" {
                    panic!("__test_panic called");
                }
            }
            _ => (),
        }
    }

    pub fn supports_pytree_args(&self) -> bool {
        match self {
            Self::Cloudpickle(_) => true,
            Self::FunctionPath(_) => self.as_torch_op().is_none(),
        }
    }
}

impl<T: Into<String>> From<T> for ResolvableFunction {
    fn from(val: T) -> Self {
        FunctionPath::from(val).into()
    }
}
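
// Illustrative addition (hypothetical test): how `as_torch_op` recognizes
// torch op paths. `torch.ops.<ns>.<op>.default` yields ("<ns>::<op>", "")
// with an empty overload, an explicit overload is carried through, and
// anything else is not treated as a torch op.
#[cfg(test)]
mod torch_op_parsing_tests {
    use super::ResolvableFunction;

    #[test]
    fn as_torch_op_parsing() {
        let default = ResolvableFunction::from("torch.ops.aten.add.default");
        assert_eq!(
            default.as_torch_op(),
            Some(("aten::add".to_string(), String::new()))
        );
        let overload = ResolvableFunction::from("torch.ops.aten.add.Tensor");
        assert_eq!(
            overload.as_torch_op(),
            Some(("aten::add".to_string(), "Tensor".to_string()))
        );
        // Non-torch paths (and torch paths of the wrong shape) are not ops.
        assert_eq!(ResolvableFunction::from("my.module.fn").as_torch_op(), None);
    }
}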

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    /// The references of the results to set.
    pub results: Vec<Option<Ref>>,
    /// The references of the mutates to set.
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// The arguments to the function.
    pub args: Vec<WireValue>,
    /// The keyword arguments to the function.
    pub kwargs: HashMap<String, WireValue>,
    /// The stream to call the function on.
    pub stream: StreamRef,
    /// The process groups to execute the function on.
    pub remote_process_groups: Vec<Ref>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    pub seq: Seq,
    // The BrokerId, flattened to a (name, index) pair since we do not depend
    // on hyperactor's BrokerId type in messages.
    pub broker_id: (String, usize),
    /// Referenceable objects to pass to the actor as LocalState,
    /// these will be put into the PythonMessage
    /// during its unpickling.
    pub local_state: Vec<Ref>,
    /// Tensors that will be mutated by the call.
    pub mutates: Vec<Ref>,
    pub stream: StreamRef,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    pub results: Vec<Option<Ref>>,
    pub call: ActorCallParams,
}

/// Type of reduction for [`WorkerMessage::Reduce`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// A gather, concat'ing the values along the reduction dimension.
    Stack,
    /// A NCCL reduction type.
    ReduceOp(ReduceOp),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    pub size: Vec<i64>,
    #[serde(with = "torch_sys::ScalarTypeDef")]
    pub dtype: ScalarType,
    #[serde(with = "torch_sys::LayoutDef")]
    pub layout: Layout,
    pub device: Device,
}

#[pyo3::pymethods]
impl Factory {
    #[new]
    #[pyo3(signature = (*, size, dtype, layout, device))]
    pub fn new(
        py: Python<'_>,
        size: Vec<i64>,
        dtype: PyObject,
        layout: PyObject,
        device: PyObject,
    ) -> PyResult<Self> {
        // TODO: Add some validation around dtype / layout. We should have pyre types on
        // the python side to help in the short term.
        Ok(Self {
            size,
            dtype: dtype.extract::<ScalarType>(py)?,
            layout: layout.extract::<Layout>(py)?,
            device: device.extract::<Device>(py)?,
        })
    }

    #[staticmethod]
    pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
        Self::new(
            obj.py(),
            obj.getattr("size")?.extract()?,
            obj.getattr("dtype")?.unbind(),
            obj.getattr("layout")?.unbind(),
            obj.getattr("device")?.unbind(),
        )
    }

    #[getter]
    fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
        PyTuple::new(py, self.size.iter())
    }

    #[getter]
    fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.dtype.into_pyobject(py)
    }

    #[getter]
    fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.layout.into_pyobject(py)
    }

    #[getter]
    fn device(&self) -> String {
        self.device.to_string()
    }
}

/// Controls what CUDA stream an actor will use.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Use the default stream for the current device.
    UseDefaultStream,
    /// Create a new stream for this actor.
    CreateNewStream,
}

/// An error associated with a seq number that failed to execute.
/// Any defined value that has an error value will have an associated
/// SeqError that is the root cause of why that value has an error.
/// A value may have this error because it was directly defined by the
/// action associated with the sequence number, or because it was defined by
/// another action that depended on the failing one.
#[derive(Debug, Named)]
#[named(register = false)]
pub struct SeqError {
    pub seq: Seq,
    pub error: anyhow::Error,
}

impl fmt::Display for SeqError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.error.fmt(f)
    }
}

/// When a worker runs a function, the call may fail either because the function
/// itself failed (`Error`) or because an input to the function already carried
/// an error value (`DependentError`).
#[derive(Error, Debug, Named)]
#[named(register = false)]
pub enum CallFunctionError {
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}

impl CallFunctionError {
    // Static functions for backward compatibility with existing enum cases
    #[allow(non_snake_case)]
    pub fn RefNotFound(r: Ref) -> Self {
        Self::Error(anyhow::anyhow!("ref not found: {}", r))
    }

    #[allow(non_snake_case)]
    pub fn InvalidRemoteFunction(msg: String) -> Self {
        Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
    }

    #[allow(non_snake_case)]
    pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
        Self::Error(anyhow::anyhow!(
            "unsupported arg type for {} function: {}",
            function_type,
            arg_type
        ))
    }

    #[allow(non_snake_case)]
    pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
        Self::Error(anyhow::anyhow!("remote function failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn BorrowError(err: BorrowError) -> Self {
        Self::Error(anyhow::anyhow!("borrow failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn OperatorFailed(err: CallOpError) -> Self {
        Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
    }

    #[allow(non_snake_case)]
    pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
        Self::Error(anyhow::anyhow!(
            "unexpected number of returns from op, expected {}, got {}",
            expected,
            actual
        ))
    }

    #[allow(non_snake_case)]
    pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
        Self::Error(anyhow::anyhow!(
            "expected only a single arg (and no kwargs) when no function is given: {}, {}",
            args,
            kwargs
        ))
    }

    #[allow(non_snake_case)]
    pub fn Anyhow(err: anyhow::Error) -> Self {
        Self::Error(err)
    }
}

impl From<SerializablePyErr> for CallFunctionError {
    fn from(v: SerializablePyErr) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}

impl From<BorrowError> for CallFunctionError {
    fn from(v: BorrowError) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}

impl From<CallOpError> for CallFunctionError {
    fn from(v: CallOpError) -> CallFunctionError {
        CallFunctionError::Error(v.into())
    }
}
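
// Illustrative addition (hypothetical test): the backward-compatibility
// constructors fold into `CallFunctionError::Error`, and `Display` carries
// the formatted message, including `Ref`'s `r<id>` rendering.
#[cfg(test)]
mod call_function_error_tests {
    use super::CallFunctionError;
    use super::Ref;

    #[test]
    fn ref_not_found_message() {
        let err = CallFunctionError::RefNotFound(Ref { id: 42 });
        assert_eq!(err.to_string(), "ref not found: r42");
    }
}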

/// Worker messages. These define the observable behavior of the worker, so the
/// documentation here describes what each message is expected to do.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initialize backend network state.
    BackendNetworkInit(UniqueId),

    /// Initialize backend network state for point-to-point communication.
    BackendNetworkPointToPointInit {
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Call a function, either a torch op or a Python `remote_function`.
    CallFunction(CallFunctionParams),

    /// Groups commands together; these commands will be executed in order by
    /// the worker.
    CommandGroup(Vec<WorkerMessage>),

    /// Create a [`Stream`] on the worker with the provided id. Commands will
    /// generally be scheduled onto streams to run; different streams can
    /// execute concurrently with one another.
    CreateStream {
        /// Id of the stream to create.
        id: StreamRef,
        /// Whether to use the default device stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a [`DeviceMesh`] on the worker, which can be used to schedule
    /// efficient inter-worker communication.
    CreateDeviceMesh {
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    },

    /// Create a PyTorch distributed process group on the worker, which can be
    /// used to schedule collectives in UDFs using monarch communicators.
    CreateRemoteProcessGroup {
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    },

    /// Create a borrow of a tensor from one stream to another.
    ///
    /// Borrows allow streams to access tensors on another stream. The runtime
    /// will insert appropriate synchronization to ensure that cross-stream
    /// usage is safe.
    BorrowCreate {
        /// Ref of the resulting borrowed tensor
        result: Ref,
        /// Id for the borrow
        borrow: u64,
        /// Tensor to borrow
        tensor: Ref,
        /// Stream to borrow from
        from_stream: StreamRef,
        /// Stream to borrow to
        to_stream: StreamRef,
    },

    /// First use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowFirstUse {
        borrow: u64,
    },

    /// Last use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowLastUse {
        borrow: u64,
    },

    /// Drop the borrow and free the resources associated with it.
    BorrowDrop {
        borrow: u64,
    },

    /// Delete these refs from the worker state.
    DeleteRefs(Vec<Ref>),

    /// A [`ControllerMessage::Status`] will be sent to the controller
    /// when all streams have processed all the messages sent before this one.
    RequestStatus {
        seq: Seq,
        controller: bool,
    },

    /// Perform a reduction operation, using an efficient communication backend.
    /// Only NCCL is supported for now.
    Reduce {
        /// Where to store the result of the reduction.
        result: Ref,
        /// The tensor to reduce.
        tensor: Ref,
        /// Tensor metadata for `tensor` that can be used to construct a
        /// fresh tensor of appropriate size/shape. We use this if
        /// `tensor` isn't accessible for some reason (like a previous
        /// error on the worker).
        factory: Factory,
        /// The device mesh on which to perform the reduction.
        mesh: Ref,
        /// The stream to call the reduction on.
        stream: StreamRef,
        /// The dimensions of the device mesh to reduce over. The members of
        /// these dimensions will form the members of the reduction collective.
        dims: Vec<String>,
        /// What kind of reduction to perform.
        reduction: Reduction,
        /// If `true`, the reduced result will be evenly split across the tensors
        /// of `dims`.
        scatter: bool,
        /// If `true`, the reduction will be performed in-place on `tensor`.
        in_place: bool,
        /// Pre-existing tensor that should be used as the output for the reduction.
        out: Option<Ref>,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitComm {
        /// The device mesh dimensions along which the constructed communicator
        /// should be able to exchange data.
        dims: Vec<String>,
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        device_mesh: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
        /// Configuration for the new communicator. If None, we will not pass a
        /// config object to nccl, which means that the created communicator
        /// will inherit its parent's config.
        config: Option<NcclConfig>,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitCommForProcessGroup {
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        remote_process_group: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
        /// Configuration for the new communicator. If None, we will not pass a
        /// config object to nccl, which means that the created communicator
        /// will inherit its parent's config.
        config: Option<NcclConfig>,
    },

    SendTensor {
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    CreatePipe {
        result: Ref,
        key: String,
        function: ResolvableFunction,
        max_messages: i64,
        mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    },

    SendValue {
        seq: Seq,
        /// Pipe to send value to.  If `None`, value is sent to controller.
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        /// Function to resolve the value to retrieve.  If `None`, then `args`
        /// must contain the value as its only element and `kwargs` must be
        /// empty.
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    SendResultOfActorCall(ActorCallParams),
    CallActorMethod(ActorMethodParams),
    PipeRecv {
        seq: Seq,
        /// Result refs.
        results: Vec<Option<Ref>>,
        /// Pipe to receive value from.
        pipe: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Finish processing all messages previously sent to this worker and stop
    /// the actor loop. Any streams will also be drained.
    Exit {
        /// Optional error reason if the exit is the result of an error, including
        /// - optional actor id to indicate the source of the error
        /// - error message or stacktrace
        /// The worker process will be stopped if the error is provided.
        error: Option<(Option<ActorId>, String)>,
    },

    /// Defines (part of) a new recording on the worker. This is a list of commands
    /// representing the execution of a function that was defined using
    /// monarch.compile. If there are too many commands to send in a single
    /// DefineRecording message, the commands may be chunked into `ntotal_messages`,
    /// with the `index` field indicating how to order the DefineRecording messages
    /// for a single recording.
    DefineRecording {
        /// The ref associated with this recording that will be used to
        /// call it in the future.
        result: Ref,
        /// The number of output tensors.
        nresults: usize,
        /// The number of input tensors.
        nformals: usize,
        /// The list of commands to run.
        commands: Vec<WorkerMessage>,
        /// How many total DefineRecording messages make up this recording.
        ntotal_messages: usize,
        /// This DefineRecording message's index in the set of messages
        /// that make up this recording.
        index: usize,
    },

    /// Defines an input tensor for a recording.
    RecordingFormal {
        /// The ref that will be used to pass the input tensor to the
        /// recording.
        result: Ref,
        /// The index of the input tensor in the list of input tensors.
        argument_index: usize,
        /// The stream that this input tensor will be used on.
        stream: StreamRef,
    },

    /// Defines an output tensor for a recording.
    RecordingResult {
        /// The ref that will be used to store the output tensor.
        result: Ref,
        /// The index of the output tensor in the list of output tensors.
        output_index: usize,
        /// The stream that this output tensor will come from.
        stream: StreamRef,
    },

    /// Calls a recording that was previously defined using
    /// DefineRecording.
    CallRecording {
        /// The sequence number of the invocation.
        seq: Seq,
        /// The ref of the recording to call.
        recording: Ref,
        /// The list of refs where the result tensors from the recording
        /// will be stored.
        results: Vec<Ref>,
        /// The list of refs of input tensors to the recording.
        actuals: Vec<Ref>,
    },

    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to set it with.
        value: WireValue,
        /// The stream to set it on.
        stream: StreamRef,
    },

    GetRefUnitTestsOnly {
        /// The value to retrieve, expected to be a bool.
        value: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
        #[reply]
        response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
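
// Illustrative addition (hypothetical test): constructing a couple of simple
// worker messages. Real `Ref`/`StreamRef` ids are allocated by the
// controller; the values here are arbitrary.
#[cfg(test)]
mod worker_message_tests {
    use super::Ref;
    use super::StreamCreationMode;
    use super::StreamRef;
    use super::WorkerMessage;

    #[test]
    fn construct_messages() {
        let create = WorkerMessage::CreateStream {
            id: StreamRef { id: 1 },
            stream_creation: StreamCreationMode::CreateNewStream,
        };
        match &create {
            WorkerMessage::CreateStream { id, .. } => assert_eq!(id.id, 1),
            _ => panic!("unexpected variant"),
        }
        let delete = WorkerMessage::DeleteRefs(vec![Ref { id: 7 }]);
        assert!(matches!(delete, WorkerMessage::DeleteRefs(ref refs) if refs.len() == 1));
    }
}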

/// The parameters to spawn a worker actor.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    // Global world size for this job
    pub world_size: usize,

    // Rank of the worker within the global world
    pub rank: usize,

    // Local cuda device that this worker represents. If None, we won't do CUDA
    // synchronization.
    pub device_index: Option<i8>,

    // Actor Ref for the controller that the worker is associated with.
    pub controller_actor: ActorRef<ControllerActor>,
}

hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);