monarch_messages/
worker.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
11#![allow(unsafe_op_in_unsafe_fn)]
12
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use anyhow::Context;
18use derive_more::Display;
19use derive_more::From;
20use derive_more::TryInto;
21use enum_as_inner::EnumAsInner;
22use hyperactor::ActorRef;
23use hyperactor::Bind;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::RefClient;
27use hyperactor::Unbind;
28use hyperactor::reference::ActorId;
29use monarch_types::SerializablePyErr;
30use monarch_types::py_global;
31use ndslice::Slice;
32use pyo3::exceptions::PyValueError;
33use pyo3::prelude::*;
34use pyo3::types::PyBytes;
35use pyo3::types::PyDict;
36use pyo3::types::PyTuple;
37use serde::Deserialize;
38use serde::Serialize;
39use thiserror::Error;
40use torch_sys_cuda::nccl::ReduceOp;
41use torch_sys_cuda::nccl::UniqueId;
42use torch_sys2::BorrowError;
43use torch_sys2::Device;
44use torch_sys2::Layout;
45use torch_sys2::ScalarType;
46use typeuri::Named;
47
48use crate::controller::ControllerActor;
49use crate::controller::Seq;
50use crate::wire_value::WireValue;
51
/// A reference to a stream on a worker. Streams are identified purely by a
/// numeric id: equality, ordering and hashing all delegate to `id`.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    /// The id of the stream this reference denotes.
    #[pyo3(get)]
    pub id: u64,
}
73
74#[pyo3::pymethods]
75impl StreamRef {
76    #[new]
77    #[pyo3(signature = (*, id))]
78    fn new(id: u64) -> Self {
79        Self { id }
80    }
81
82    fn __repr__(&self) -> String {
83        format!("StreamRef({})", self.id)
84    }
85
86    // TODO: Upgrade pyo3 to use eq, ord on pyclass
87    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
88        Ok(match op {
89            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
90            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
91            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
92            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
93            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
94            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
95        })
96    }
97
98    fn __hash__(&self) -> PyResult<u64> {
99        Ok(self.id)
100    }
101}
102
// TODO: The Python implementation uses `Ref` to describe any worker value that
// can be referenced by the controller, including: tensors, streams, pipes,
// device meshes. We might be able to more explicitly type these, as they are
// not generally interchangeable.
/// A controller-visible reference to a value held on a worker (tensor, stream,
/// pipe, device mesh, ...). Equality, ordering and hashing delegate to `id`.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    /// The numeric id of the referenced value.
    #[pyo3(get)]
    pub id: u64,
}
128
129#[pyo3::pymethods]
130impl Ref {
131    #[new]
132    fn new(id: u64) -> Self {
133        Self { id }
134    }
135
136    #[getter]
137    fn r#ref(&self) -> u64 {
138        self.id
139    }
140
141    fn __repr__(&self) -> String {
142        format!("Ref({})", self.id)
143    }
144
145    // TODO: Upgrade pyo3 to use eq, ord on pyclass
146    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
147        Ok(match op {
148            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
149            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
150            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
151            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
152            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
153            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
154        })
155    }
156
157    fn __hash__(&self) -> PyResult<u64> {
158        Ok(self.id)
159    }
160
161    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
162        let kwargs = PyDict::new(py);
163        kwargs.set_item("id", self.id).unwrap();
164
165        PyTuple::new(
166            py,
167            vec![
168                PyTuple::empty(py).unbind().into_any(),
169                kwargs.unbind().into_any(),
170            ],
171        )
172    }
173}
174
175impl Ref {
176    // This is a function on ref instead of impl FromPyObject due to a bug in pyo3
177    // https://github.com/PyO3/pyo3/issues/4337
178    pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
179        let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
180        if let Ok(ref_) = obj.extract::<Ref>() {
181            return Ok(ref_);
182        }
183        if let Ok(func) = obj.getattr(attr_name) {
184            if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
185                return Ok(val.into());
186            }
187        }
188        Err(PyValueError::new_err("Could not convert object to Ref"))
189    }
190}
191
192impl Display for Ref {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        write!(f, "r{}", self.id)
195    }
196}
197
/// Identifies a CallFunction target. Can either be a torch op or a Python
/// global reference.
// TODO: do some validation on the namespace/opname/overload
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    /// Fully-qualified dotted path, e.g. `torch.ops.aten.add` or
    /// `some.module.function` (see [`FunctionPath::resolve`]).
    #[pyo3(get)]
    pub path: String,
}
210
211impl fmt::Display for FunctionPath {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        write!(f, "<function \"{}\">", self.path)
214    }
215}
216
217impl<T: Into<String>> From<T> for FunctionPath {
218    fn from(val: T) -> Self {
219        Self { path: val.into() }
220    }
221}
222
223#[pyo3::pymethods]
224impl FunctionPath {
225    #[new]
226    #[pyo3(signature = (*, path))]
227    pub fn new(path: String) -> Self {
228        Self { path }
229    }
230
231    fn __repr__(&self) -> String {
232        self.path.clone()
233    }
234
235    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
236        let (start, rest) = self.path.split_once(".").with_context(|| {
237            format!(
238                "invalid function path {}: paths must be fully qualified",
239                self.path
240            )
241        })?;
242        if start == "torch" {
243            let mut cur = py.import("torch")?.into_any();
244            for p in rest.split(".") {
245                cur = cur.getattr(p)?;
246            }
247            Ok(cur)
248        } else {
249            let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
250                format!(
251                    "invalid function path {}: paths must be fully qualified",
252                    self.path
253                )
254            })?;
255            let module = PyModule::import(py, module_fqn)?;
256            let mut function = module.getattr(function_name)?;
257            if function.hasattr("_remote_impl")? {
258                function = function.getattr("_remote_impl")?;
259            }
260            Ok(function.downcast_into()?)
261        }
262    }
263}
264
/// Identifies a CallFunction target: an arbitrary Python callable serialized
/// with `cloudpickle`. This is the alternative to [`FunctionPath`] for targets
/// that cannot be named by a fully-qualified path; it is resolved back to a
/// live object via `cloudpickle.loads`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    // Raw pickle payload; `serde_bytes` serializes it as a byte string rather
    // than a sequence of integers.
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}
277
278impl fmt::Display for Cloudpickle {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        write!(f, "<cloud-pickle>")
281    }
282}
283
// Accessor for `cloudpickle.dumps`, generated by `py_global!`
// (see `monarch_types`); used by `Cloudpickle::dumps` below.
py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
285
286#[pyo3::pymethods]
287impl Cloudpickle {
288    #[new]
289    #[pyo3(signature = (*, bytes))]
290    pub fn new(bytes: Vec<u8>) -> Self {
291        Self { bytes }
292    }
293
294    fn __repr__(&self) -> String {
295        format!("Cloudpickle(bytes={:?})", self.bytes)
296    }
297
298    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
299        let module = PyModule::import(py, "cloudpickle")?;
300        let loads = module.getattr("loads")?;
301        loads.call1((PyBytes::new(py, &self.bytes),))
302    }
303}
304
305impl Cloudpickle {
306    pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
307        let py = obj.py();
308        let dumps = cloudpickle_dumps(py);
309        let bytes_obj = dumps.call1((obj,))?;
310        let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
311        Ok(Self { bytes })
312    }
313}
314
/// A CallFunction target that the worker can resolve to a live Python object:
/// either a cloudpickled callable or a fully-qualified function path.
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// An arbitrary pickled callable.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// A dotted path to a torch op or importable Python function.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
332
/// A function's `(args, kwargs)` pair, stored as a single cloudpickled
/// payload. Constructed and unpacked by the methods on this type.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
pub struct ArgsKwargs {
    // Cloudpickled 2-tuple `(args_tuple, kwargs_dict)`.
    payload: Cloudpickle,
}
337
338impl ArgsKwargs {
339    pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
340        // Create tuple (args, kwargs), then cloudpickle it
341        let py = args.py();
342        let tuple = PyTuple::new(py, vec![args, kwargs])?;
343        let payload = Cloudpickle::dumps(tuple.into_any())?;
344        Ok(Self { payload })
345    }
346
347    pub fn from_wire_values(
348        args: Vec<WireValue>,
349        kwargs: HashMap<String, WireValue>,
350    ) -> PyResult<Self> {
351        Python::with_gil(|py| {
352            // Convert WireValue args to Python objects
353            let py_args: Vec<Bound<'_, PyAny>> = args
354                .into_iter()
355                .map(|v| v.into_pyobject(py))
356                .collect::<PyResult<_>>()?;
357            let args_tuple = PyTuple::new(py, py_args)?;
358
359            // Convert WireValue kwargs to Python dict
360            let kwargs_dict = PyDict::new(py);
361            for (k, v) in kwargs {
362                kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
363            }
364
365            Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
366        })
367    }
368
369    pub fn to_python<'py>(
370        &self,
371        py: Python<'py>,
372    ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
373        let tuple = self.payload.resolve(py)?;
374        let tuple = tuple.downcast::<PyTuple>()?;
375
376        // Extract args (first element)
377        let args = tuple.get_item(0)?;
378        let args_tuple = args.downcast::<PyTuple>()?;
379
380        // Extract kwargs (second element)
381        let kwargs = tuple.get_item(1)?;
382        let kwargs_dict = kwargs.downcast::<PyDict>()?;
383
384        Ok((args_tuple.clone(), kwargs_dict.clone()))
385    }
386}
387
388impl<'py> IntoPyObject<'py> for ResolvableFunction {
389    type Target = PyAny;
390    type Output = Bound<'py, Self::Target>;
391    type Error = PyErr;
392
393    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
394        Ok(match self {
395            Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
396            Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
397        })
398    }
399}
400
401impl ResolvableFunction {
402    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
403        match self {
404            Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
405            Self::FunctionPath(func) => func.resolve(py),
406        }
407    }
408
409    /// For testing: this is a special remote function path that induces a panic
410    /// when called.
411    pub fn panic_if_requested(&self) {
412        match self {
413            Self::FunctionPath(func) => {
414                if func.path == "__test_panic" {
415                    panic!("__test_panic called");
416                }
417            }
418            _ => (),
419        }
420    }
421}
422
423impl<T: Into<String>> From<T> for ResolvableFunction {
424    fn from(val: T) -> Self {
425        FunctionPath::from(val).into()
426    }
427}
428
/// Parameters for [`WorkerMessage::CallFunction`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    /// The references of the results to set. `None` entries have no
    /// controller-visible ref (presumably unused outputs — TODO confirm).
    pub results: Vec<Option<Ref>>,
    /// The references of the mutates to set.
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// The arguments and keyword arguments to the function.
    pub args_kwargs: ArgsKwargs,
    /// The stream to call the function on.
    pub stream: StreamRef,
    /// The process groups to execute the function on.
    pub remote_process_groups: Vec<Ref>,
}
446
/// Parameters for a call routed to a Python actor
/// (see [`WorkerMessage::SendResultOfActorCall`] / [`ActorMethodParams`]).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    /// The broker's `(name, rank)` pair. This is the BrokerId, kept as a raw
    /// tuple because messages must not depend on hyperactor.
    pub broker_id: (String, usize),
    /// Referenceable objects to pass to the actor as LocalState,
    /// these will be put into the PythonMessage
    /// during its unpickling.
    pub local_state: Vec<Ref>,
    /// Tensors that will be mutated by the call.
    pub mutates: Vec<Ref>,
    /// The stream the call is ordered on.
    pub stream: StreamRef,
}
460
/// Parameters for [`WorkerMessage::CallActorMethod`]: an actor call plus the
/// refs that receive its results.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// Refs to bind the call's results to (`None` entries presumably mean a
    /// discarded result — TODO confirm against the handler).
    pub results: Vec<Option<Ref>>,
    /// The underlying actor call.
    pub call: ActorCallParams,
}
466
/// Type of reduction for [`WorkerMessage::Reduce`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// A gather, concat'ing the values along the reduction dimension.
    Stack,
    /// A NCCL reduction type (sum, prod, min, max, ...).
    ReduceOp(ReduceOp),
}
475
/// Tensor metadata (size, dtype, layout, device) sufficient to construct a
/// fresh tensor of a given shape — exposed to Python as `TensorFactory`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    /// Tensor dimensions.
    pub size: Vec<i64>,
    /// Element type; serialized via the remote `ScalarTypeDef` definition.
    #[serde(with = "torch_sys2::ScalarTypeDef")]
    pub dtype: ScalarType,
    /// Memory layout; serialized via the remote `LayoutDef` definition.
    #[serde(with = "torch_sys2::LayoutDef")]
    pub layout: Layout,
    /// Target device.
    pub device: Device,
}
490
#[pyo3::pymethods]
impl Factory {
    /// Python constructor (keyword-only): extracts the Python-side `dtype`,
    /// `layout` and `device` objects into their Rust equivalents.
    #[new]
    #[pyo3(signature = (*, size, dtype, layout, device))]
    pub fn new(
        py: Python<'_>,
        size: Vec<i64>,
        dtype: PyObject,
        layout: PyObject,
        device: PyObject,
    ) -> PyResult<Self> {
        // TODO: Add some validation around dtype / layout. We should have pyre types on
        // the python side to help in the short term.
        Ok(Self {
            size,
            dtype: dtype.extract::<ScalarType>(py)?,
            layout: layout.extract::<Layout>(py)?,
            device: device.extract::<Device>(py)?,
        })
    }

    /// Build a `Factory` from any Python object exposing `size`, `dtype`,
    /// `layout` and `device` attributes (e.g. a tensor-like object).
    #[staticmethod]
    pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
        Self::new(
            obj.py(),
            obj.getattr("size")?.extract()?,
            obj.getattr("dtype")?.unbind(),
            obj.getattr("layout")?.unbind(),
            obj.getattr("device")?.unbind(),
        )
    }

    /// Python getter: the size as a tuple of ints.
    #[getter]
    fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
        PyTuple::new(py, self.size.iter())
    }

    /// Python getter: the dtype converted back to its Python representation.
    #[getter]
    fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.dtype.into_pyobject(py)
    }

    /// Python getter: the layout converted back to its Python representation.
    #[getter]
    fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        self.layout.into_pyobject(py)
    }

    /// Python getter: the device rendered via its `Display` impl.
    #[getter]
    fn device(&self) -> String {
        self.device.to_string()
    }
}
543
/// Controls what CUDA stream an actor will use. See
/// [`WorkerMessage::CreateStream`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Use the default stream for the current device.
    UseDefaultStream,
    /// Create a new stream for this actor.
    CreateNewStream,
}
557
/// An error associated with a seq number that failed to execute.
/// Any defined value that has an error value will have an associated
/// SeqError that is the root cause of why that value has an error.
/// A value may have this error because it was directly defined by the
/// action associated with the sequence number, or if it was defined by
/// another action that depended on the failing one.
#[derive(Debug, Named)]
pub struct SeqError {
    /// The sequence number whose action failed.
    pub seq: Seq,
    /// The root-cause error.
    pub error: anyhow::Error,
}
569
570impl Display for SeqError {
571    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
572        self.error.fmt(f)
573    }
574}
575
/// When a worker runs any function, it may not succeed either because the function itself
/// failed (Error) or because an input to the function already had an error value
/// DependentError.
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
    /// The function invocation itself failed.
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    /// An input to the function already carried an error from an earlier
    /// failed action; the shared [`SeqError`] is that root cause.
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
586
587impl CallFunctionError {
588    // Static functions for backward compatibility with existing enum cases
589    #[allow(non_snake_case)]
590    pub fn RefNotFound(r: Ref) -> Self {
591        Self::Error(anyhow::anyhow!("ref not found: {}", r))
592    }
593
594    #[allow(non_snake_case)]
595    pub fn InvalidRemoteFunction(msg: String) -> Self {
596        Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
597    }
598
599    #[allow(non_snake_case)]
600    pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
601        Self::Error(anyhow::anyhow!(
602            "unsupported arg type for {} function: {}",
603            function_type,
604            arg_type
605        ))
606    }
607
608    #[allow(non_snake_case)]
609    pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
610        Self::Error(anyhow::anyhow!("remote function failed: {}", err))
611    }
612
613    #[allow(non_snake_case)]
614    pub fn BorrowError(err: BorrowError) -> Self {
615        Self::Error(anyhow::anyhow!("borrow failed: {}", err))
616    }
617
618    #[allow(non_snake_case)]
619    pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
620        Self::Error(anyhow::anyhow!(
621            "unexpected number of returns from op, expected {}, got {}",
622            expected,
623            actual
624        ))
625    }
626
627    #[allow(non_snake_case)]
628    pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
629        Self::Error(anyhow::anyhow!(
630            "expected only a single arg (and no kwargs) when no function is given: {}, {}",
631            args,
632            kwargs
633        ))
634    }
635
636    #[allow(non_snake_case)]
637    pub fn Anyhow(err: anyhow::Error) -> Self {
638        Self::Error(err)
639    }
640}
641
642impl From<SerializablePyErr> for CallFunctionError {
643    fn from(v: SerializablePyErr) -> CallFunctionError {
644        CallFunctionError::Error(v.into())
645    }
646}
647
648impl From<BorrowError> for CallFunctionError {
649    fn from(v: BorrowError) -> CallFunctionError {
650        CallFunctionError::Error(v.into())
651    }
652}
653
/// Worker messages. These define the observable behavior of the worker, so the
/// documentation here describes the intended effect of each message when it is
/// handled, not implementation details of any particular worker.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initialize backend network state.
    BackendNetworkInit(UniqueId),

    /// Initialize backend network state for point-to-point communication.
    BackendNetworkPointToPointInit {
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Call a function, either a torch op or a Python `remote_function`.
    CallFunction(CallFunctionParams),

    /// Groups commands together; these commands will be executed in order by
    /// the worker.
    CommandGroup(Vec<WorkerMessage>),

    /// Create a [`Stream`] on the worker with the provided id. Commands will
    /// generally be scheduled onto streams to run; different streams can
    /// execute concurrently with one another.
    CreateStream {
        /// Id of the stream to create.
        id: StreamRef,
        /// Whether to use the default device stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a [`DeviceMesh`] on the worker, which can be used to schedule
    /// efficient inter-worker communication.
    CreateDeviceMesh {
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    },

    /// Create a PyTorch distributed process group on the worker, which can be
    /// used to schedule collectives in UDFs using monarch communicators.
    CreateRemoteProcessGroup {
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    },

    /// Create a borrow of a tensor from one stream to another.
    ///
    /// Borrows allows streams to access tensors on another stream. The runtime
    /// will insert appropriate synchronization to ensure that cross-stream
    /// usage is safe.
    BorrowCreate {
        /// Ref of the resulting borrowed tensor
        result: Ref,
        /// Id for the borrow
        borrow: u64,
        /// Tensor to borrow
        tensor: Ref,
        /// Stream to borrow from
        from_stream: StreamRef,
        /// Stream to borrow to
        to_stream: StreamRef,
    },

    /// First use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowFirstUse {
        borrow: u64,
    },

    /// Last use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowLastUse {
        borrow: u64,
    },

    /// Drop the borrow and free the resources associated with it.
    BorrowDrop {
        borrow: u64,
    },

    /// Delete these refs from the worker state.
    DeleteRefs(Vec<Ref>),

    /// A [`ControllerMessage::Status`] will be sent to the controller
    /// when all streams have processed all the messages sent before this one.
    RequestStatus {
        seq: Seq,
        controller: bool,
    },

    /// Perform a reduction operation, using an efficient communication backend.
    /// Only NCCL is supported for now.
    Reduce {
        /// Where to store the result of the reduction.
        result: Ref,
        /// The tensor to reduce.
        tensor: Ref,
        /// Tensor metadata for `tensor` that can be used to construct a
        /// fresh tensor of appropriate size/shape. We use this if
        /// `tensor` isn't accessible for some reason (like a previous
        /// error on the worker).
        factory: Factory,
        /// The device mesh on which to perform the reduction.
        mesh: Ref,
        /// The stream to call the reduction on.
        stream: StreamRef,
        /// The dimensions of the device mesh to reduce over. The members of
        /// these dimensions will form the members of the reduction collective.
        dims: Vec<String>,
        /// What kind of reduction to perform.
        reduction: Reduction,
        /// If `true`, the reduced result will be evenly split across the tensors
        /// of `dim`.
        scatter: bool,
        /// If `true`, the reduction will be performed in-place on `tensor`.
        in_place: bool,
        /// Pre-existing tensor that should be used as the output for the reduction.
        out: Option<Ref>,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitComm {
        /// The device mesh dimensions along which the constructed communicator
        /// should be able to exchange data.
        dims: Vec<String>,
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        device_mesh: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitCommForProcessGroup {
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        remote_process_group: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    /// Send a tensor between ranks. Field roles are inferred from the names;
    /// confirm against the worker's handler before relying on them.
    SendTensor {
        /// Ref for the received tensor on the destination.
        result: Ref,
        /// Ranks sending the tensor.
        from_ranks: Slice,
        /// Ranks receiving the tensor.
        to_ranks: Slice,
        /// The tensor to send.
        tensor: Ref,
        /// Metadata to construct a stand-in tensor if `tensor` is unavailable.
        factory: Factory,
        /// Stream on the sending side.
        from_stream: StreamRef,
        /// Stream on the receiving side.
        to_stream: StreamRef,
    },

    SendValue {
        seq: Seq,
        /// Pipe to send value to.  If `None`, value is sent to controller.
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        /// Function to resolve the value to retrieve.  If `None`, then `args_kwargs`
        /// must contain the value as the only element in args with no kwargs.
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Call an actor and route the call's result back (see [`ActorCallParams`]).
    SendResultOfActorCall(ActorCallParams),
    /// Call a method on an actor (see [`ActorMethodParams`]).
    CallActorMethod(ActorMethodParams),
    PipeRecv {
        seq: Seq,
        /// Result refs.
        results: Vec<Option<Ref>>,
        /// Pipe to receive value from.
        pipe: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Finish processing all messages previously sent to this worker and stop
    /// the actor loop. Any streams will also be drained.
    Exit {
        /// Optional error reason if the exit is the result of an error, including
        /// - optional actor id to indicate the source of the error
        /// - error message or stacktrace
        /// The worker process will be stopped if the error is provided.
        error: Option<(Option<ActorId>, String)>,
    },

    /// Defines (part of) a new recording on the worker. This is a list of commands
    /// representing the execution of a function that was defined using
    /// monarch.compile. If there are too many commands to send in a single
    /// DefineRecording message, the commands may be chunked into `ntotal_messages`,
    /// with the `index` field indicating how to order the DefineRecording messages
    /// for a single recording.
    DefineRecording {
        /// The ref associated with this recording that will be used to
        /// call it in the future.
        result: Ref,
        /// The number of output tensors.
        nresults: usize,
        /// The number of input tensors.
        nformals: usize,
        /// The list of commands to run.
        commands: Vec<WorkerMessage>,
        /// How many total DefineRecording messages make up this recording.
        ntotal_messages: usize,
        /// This DefineRecording message's index in the set of messages
        /// that make up this recording.
        index: usize,
    },

    /// Defines an input tensor for a recording.
    RecordingFormal {
        /// The ref that will be used to pass the input tensor to the
        /// recording.
        result: Ref,
        /// The index of the input tensor in the list of input tensors.
        argument_index: usize,
        /// The stream that this input tensor will be used on.
        stream: StreamRef,
    },

    /// Defines an output tensor for a recording.
    RecordingResult {
        /// The ref that will be used to store the output tensor.
        result: Ref,
        /// The index of the output tensor in the list of output tensors.
        output_index: usize,
        /// The stream that this output tensor will come from.
        stream: StreamRef,
    },

    /// Calls a recording that was previously defined using
    /// DefineRecording.
    CallRecording {
        /// The sequence number of the invocation.
        seq: Seq,
        /// The ref of the recording to call.
        recording: Ref,
        /// The list of refs where the result tensors from the recording
        /// will be stored.
        results: Vec<Ref>,
        /// The list of refs of input tensors to the recording.
        actuals: Vec<Ref>,
    },

    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to set it with.
        value: WireValue,
        /// The stream to set it on.
        stream: StreamRef,
    },

    GetRefUnitTestsOnly {
        /// The value to retrieve, expected to be a bool.
        value: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
        #[reply]
        response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
935
/// The parameters to spawn a worker actor.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Global world size for this job.
    pub world_size: usize,

    /// Rank of the worker within the global world.
    pub rank: usize,

    /// Local cuda device that this worker represents. If None, we won't do CUDA
    /// synchronization.
    pub device_index: Option<i8>,

    /// Actor Ref for the controller that the worker is associated with.
    pub controller_actor: ActorRef<ControllerActor>,
}
// Register `WorkerParams` with the wire-value type registry so it can be
// (de)serialized by name over the wire.
wirevalue::register_type!(WorkerParams);
953
// Declare the message behavior of `WorkerActor`: it handles `WorkerMessage`,
// with `cast = true` enabling cast delivery (see `hyperactor::behavior!`).
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);