// monarch_messages/worker.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
11#![allow(unsafe_op_in_unsafe_fn)]
12// EnumAsInner generates code that triggers a false positive
13// unused_assignments lint on struct variant fields. #[allow] on the
14// enum itself doesn't propagate into derive-macro-generated code, so
15// the suppression must be at module scope.
16#![allow(unused_assignments)]
17
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
22use anyhow::Context;
23use derive_more::Display;
24use derive_more::From;
25use derive_more::TryInto;
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32use hyperactor::reference;
33use monarch_types::SerializablePyErr;
34use monarch_types::py_global;
35use ndslice::Slice;
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38use pyo3::types::PyBytes;
39use pyo3::types::PyDict;
40use pyo3::types::PyTuple;
41use serde::Deserialize;
42use serde::Serialize;
43use thiserror::Error;
44use torch_sys_cuda::nccl::ReduceOp;
45use torch_sys_cuda::nccl::UniqueId;
46use torch_sys2::BorrowError;
47use torch_sys2::Device;
48use torch_sys2::Layout;
49use torch_sys2::ScalarType;
50use typeuri::Named;
51
52use crate::controller::ControllerActor;
53use crate::controller::Seq;
54use crate::wire_value::WireValue;
55
/// A reference to a stream on a worker (see [`WorkerMessage::CreateStream`]).
/// Wraps a plain `u64` id so it is cheap to copy, hash, and order, and is
/// exposed to Python as a frozen pyclass.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    /// Numeric id of the stream; exposed read-only to Python.
    #[pyo3(get)]
    pub id: u64,
}
77
78#[pyo3::pymethods]
79impl StreamRef {
80    #[new]
81    #[pyo3(signature = (*, id))]
82    fn new(id: u64) -> Self {
83        Self { id }
84    }
85
86    fn __repr__(&self) -> String {
87        format!("StreamRef({})", self.id)
88    }
89
90    // TODO: Upgrade pyo3 to use eq, ord on pyclass
91    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
92        Ok(match op {
93            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
94            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
95            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
96            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
97            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
98            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
99        })
100    }
101
102    fn __hash__(&self) -> PyResult<u64> {
103        Ok(self.id)
104    }
105}
106
// TODO: The Python implementation uses `Ref` to describe any worker value that
// can be referenced by the controller, including: tensors, streams, pipes,
// device meshes. We might be able to more explicitly type these, as they are
// not generally interchangeable.
/// A controller-visible reference to a value held on a worker (tensor, pipe,
/// device mesh, ...). Wraps a plain `u64` id.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    /// Numeric id of the referenced value; exposed read-only to Python.
    #[pyo3(get)]
    pub id: u64,
}
132
133#[pyo3::pymethods]
134impl Ref {
135    #[new]
136    fn new(id: u64) -> Self {
137        Self { id }
138    }
139
140    #[getter]
141    fn r#ref(&self) -> u64 {
142        self.id
143    }
144
145    fn __repr__(&self) -> String {
146        format!("Ref({})", self.id)
147    }
148
149    // TODO: Upgrade pyo3 to use eq, ord on pyclass
150    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
151        Ok(match op {
152            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
153            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
154            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
155            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
156            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
157            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
158        })
159    }
160
161    fn __hash__(&self) -> PyResult<u64> {
162        Ok(self.id)
163    }
164
165    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
166        let kwargs = PyDict::new(py);
167        kwargs.set_item("id", self.id).unwrap();
168
169        PyTuple::new(
170            py,
171            vec![
172                PyTuple::empty(py).unbind().into_any(),
173                kwargs.unbind().into_any(),
174            ],
175        )
176    }
177}
178
179impl Ref {
180    // This is a function on ref instead of impl FromPyObject due to a bug in pyo3
181    // https://github.com/PyO3/pyo3/issues/4337
182    pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
183        let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
184        if let Ok(ref_) = obj.extract::<Ref>() {
185            return Ok(ref_);
186        }
187        if let Ok(func) = obj.getattr(attr_name) {
188            if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
189                return Ok(val.into());
190            }
191        }
192        Err(PyValueError::new_err("Could not convert object to Ref"))
193    }
194}
195
196impl Display for Ref {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(f, "r{}", self.id)
199    }
200}
201
/// Identifies a CallFunction target. Can either be a torch op or a Python
/// global reference.
// TODO: do some validation on the namespace/opname/overload
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    /// Fully qualified dotted path to the target, e.g. `torch.x.y` or
    /// `some.module.function`; see `resolve` for the lookup rules.
    #[pyo3(get)]
    pub path: String,
}
214
215impl fmt::Display for FunctionPath {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        write!(f, "<function \"{}\">", self.path)
218    }
219}
220
221impl<T: Into<String>> From<T> for FunctionPath {
222    fn from(val: T) -> Self {
223        Self { path: val.into() }
224    }
225}
226
227#[pyo3::pymethods]
228impl FunctionPath {
229    #[new]
230    #[pyo3(signature = (*, path))]
231    pub fn new(path: String) -> Self {
232        Self { path }
233    }
234
235    fn __repr__(&self) -> String {
236        self.path.clone()
237    }
238
239    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
240        let (start, rest) = self.path.split_once(".").with_context(|| {
241            format!(
242                "invalid function path {}: paths must be fully qualified",
243                self.path
244            )
245        })?;
246        if start == "torch" {
247            let mut cur = py.import("torch")?.into_any();
248            for p in rest.split(".") {
249                cur = cur.getattr(p)?;
250            }
251            Ok(cur)
252        } else {
253            let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
254                format!(
255                    "invalid function path {}: paths must be fully qualified",
256                    self.path
257                )
258            })?;
259            let module = PyModule::import(py, module_fqn)?;
260            let mut function = module.getattr(function_name)?;
261            if function.hasattr("_remote_impl")? {
262                function = function.getattr("_remote_impl")?;
263            }
264            Ok(function.downcast_into()?)
265        }
266    }
267}
268
/// Identifies a CallFunction target as an opaque cloudpickle payload: the
/// callable itself, serialized with `cloudpickle.dumps` (see `dumps` /
/// `resolve`).
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    /// Raw pickle bytes; `serde_bytes` keeps the serde encoding compact.
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}
281
282impl fmt::Display for Cloudpickle {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        write!(f, "<cloud-pickle>")
285    }
286}
287
// Accessor for the Python `cloudpickle.dumps` callable (see `py_global!` in
// monarch_types for the lookup/caching behavior).
py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
289
290#[pyo3::pymethods]
291impl Cloudpickle {
292    #[new]
293    #[pyo3(signature = (*, bytes))]
294    pub fn new(bytes: Vec<u8>) -> Self {
295        Self { bytes }
296    }
297
298    fn __repr__(&self) -> String {
299        format!("Cloudpickle(bytes={:?})", self.bytes)
300    }
301
302    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303        let module = PyModule::import(py, "cloudpickle")?;
304        let loads = module.getattr("loads")?;
305        loads.call1((PyBytes::new(py, &self.bytes),))
306    }
307}
308
309impl Cloudpickle {
310    pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
311        let py = obj.py();
312        let dumps = cloudpickle_dumps(py);
313        let bytes_obj = dumps.call1((obj,))?;
314        let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
315        Ok(Self { bytes })
316    }
317}
318
/// A function a worker can resolve to a live Python callable: either a pickled
/// function object ([`Cloudpickle`]) or a fully qualified dotted path
/// ([`FunctionPath`]).
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// A callable serialized with cloudpickle.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// A fully qualified dotted path to a callable.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
336
/// An (args, kwargs) pair for a Python call, carried as a single cloudpickled
/// 2-tuple so it can travel through Rust without per-value conversion.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
pub struct ArgsKwargs {
    /// Cloudpickled `(args, kwargs)` tuple; see `from_python` / `to_python`.
    payload: Cloudpickle,
}
341
342impl ArgsKwargs {
343    pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
344        // Create tuple (args, kwargs), then cloudpickle it
345        let py = args.py();
346        let tuple = PyTuple::new(py, vec![args, kwargs])?;
347        let payload = Cloudpickle::dumps(tuple.into_any())?;
348        Ok(Self { payload })
349    }
350
351    pub fn from_wire_values(
352        args: Vec<WireValue>,
353        kwargs: HashMap<String, WireValue>,
354    ) -> PyResult<Self> {
355        Python::attach(|py| {
356            // Convert WireValue args to Python objects
357            let py_args: Vec<Bound<'_, PyAny>> = args
358                .into_iter()
359                .map(|v| v.into_pyobject(py))
360                .collect::<PyResult<_>>()?;
361            let args_tuple = PyTuple::new(py, py_args)?;
362
363            // Convert WireValue kwargs to Python dict
364            let kwargs_dict = PyDict::new(py);
365            for (k, v) in kwargs {
366                kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
367            }
368
369            Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
370        })
371    }
372
373    pub fn to_python<'py>(
374        &self,
375        py: Python<'py>,
376    ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
377        let tuple = self.payload.resolve(py)?;
378        let tuple = tuple.downcast::<PyTuple>()?;
379
380        // Extract args (first element)
381        let args = tuple.get_item(0)?;
382        let args_tuple = args.downcast::<PyTuple>()?;
383
384        // Extract kwargs (second element)
385        let kwargs = tuple.get_item(1)?;
386        let kwargs_dict = kwargs.downcast::<PyDict>()?;
387
388        Ok((args_tuple.clone(), kwargs_dict.clone()))
389    }
390}
391
392impl<'py> IntoPyObject<'py> for ResolvableFunction {
393    type Target = PyAny;
394    type Output = Bound<'py, Self::Target>;
395    type Error = PyErr;
396
397    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
398        Ok(match self {
399            Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
400            Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
401        })
402    }
403}
404
405impl ResolvableFunction {
406    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
407        match self {
408            Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
409            Self::FunctionPath(func) => func.resolve(py),
410        }
411    }
412
413    /// For testing: this is a special remote function path that induces a panic
414    /// when called.
415    pub fn panic_if_requested(&self) {
416        match self {
417            Self::FunctionPath(func) => {
418                if func.path == "__test_panic" {
419                    panic!("__test_panic called");
420                }
421            }
422            _ => (),
423        }
424    }
425}
426
427impl<T: Into<String>> From<T> for ResolvableFunction {
428    fn from(val: T) -> Self {
429        FunctionPath::from(val).into()
430    }
431}
432
/// Parameters for [`WorkerMessage::CallFunction`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    /// The references of the results to set.
    pub results: Vec<Option<Ref>>,
    /// The references of the mutates to set.
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// The arguments and keyword arguments to the function.
    pub args_kwargs: ArgsKwargs,
    /// The stream to call the function on.
    pub stream: StreamRef,
    /// The process groups to execute the function on.
    pub remote_process_groups: Vec<Ref>,
}
450
/// Parameters describing a call into a Python actor from the worker.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence ID of the invocation.
    pub seq: Seq,
    // The BrokerId, carried as a raw (name, index) pair because we do not
    // depend on hyperactor in messages.
    pub broker_id: (String, usize),
    /// Referenceable objects to pass to the actor as LocalState,
    /// these will be put into the PythonMessage
    /// during its unpickling.
    pub local_state: Vec<Ref>,
    /// Tensors that will be mutated by the call.
    pub mutates: Vec<Ref>,
    /// The stream the call is scheduled on.
    pub stream: StreamRef,
}
464
/// Parameters for calling a method on a Python actor and binding its results.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// The references of the results to set.
    pub results: Vec<Option<Ref>>,
    /// The underlying actor-call description.
    pub call: ActorCallParams,
}
470
/// Type of reduction for [`WorkerMessage::Reduce`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// A gather, concat'ing the values along the reduction dimension.
    Stack,
    /// A NCCL reduction type (e.g. sum/prod/min/max).
    ReduceOp(ReduceOp),
}
479
/// Metadata sufficient to construct a fresh tensor of a given size, dtype,
/// layout, and device — used e.g. as a stand-in when the referenced tensor is
/// unavailable after an earlier error (see [`WorkerMessage::Reduce`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    /// Tensor sizes, one entry per dimension.
    pub size: Vec<i64>,
    /// Element type; serialized via the remote-derive shim `ScalarTypeDef`.
    #[serde(with = "torch_sys2::ScalarTypeDef")]
    pub dtype: ScalarType,
    /// Memory layout; serialized via the remote-derive shim `LayoutDef`.
    #[serde(with = "torch_sys2::LayoutDef")]
    pub layout: Layout,
    /// Device the tensor should be created on.
    pub device: Device,
}
494
495#[pyo3::pymethods]
496impl Factory {
497    #[new]
498    #[pyo3(signature = (*, size, dtype, layout, device))]
499    pub fn new(
500        py: Python<'_>,
501        size: Vec<i64>,
502        dtype: Py<PyAny>,
503        layout: Py<PyAny>,
504        device: Py<PyAny>,
505    ) -> PyResult<Self> {
506        // TODO: Add some validation around dtype / layout. We should have pyre types on
507        // the python side to help in the short term.
508        Ok(Self {
509            size,
510            dtype: dtype.extract::<ScalarType>(py)?,
511            layout: layout.extract::<Layout>(py)?,
512            device: device.extract::<Device>(py)?,
513        })
514    }
515
516    #[staticmethod]
517    pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
518        Self::new(
519            obj.py(),
520            obj.getattr("size")?.extract()?,
521            obj.getattr("dtype")?.unbind(),
522            obj.getattr("layout")?.unbind(),
523            obj.getattr("device")?.unbind(),
524        )
525    }
526
527    #[getter]
528    fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
529        PyTuple::new(py, self.size.iter())
530    }
531
532    #[getter]
533    fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
534        self.dtype.into_pyobject(py)
535    }
536
537    #[getter]
538    fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
539        self.layout.into_pyobject(py)
540    }
541
542    #[getter]
543    fn device(&self) -> String {
544        self.device.to_string()
545    }
546}
547
/// Controls what CUDA stream an actor will use.
// Exposed to Python as an int-comparable enum (`eq`, `eq_int`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Use the default stream for the current device.
    UseDefaultStream,
    /// Create a new stream for this actor.
    CreateNewStream,
}
561
/// An error associated with a seq number that failed to execute.
/// Any defined value that has an error value will have an associated
/// SeqError that is the root cause of why that value has an error.
/// A value may have this error because it was directly defined by the
/// action associated with the sequence number, or if it was defined by
/// another action that depended on the failing one.
#[derive(Debug, Named)]
pub struct SeqError {
    /// Sequence number of the action that failed.
    pub seq: Seq,
    /// The underlying failure.
    pub error: anyhow::Error,
}
573
574impl Display for SeqError {
575    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
576        self.error.fmt(f)
577    }
578}
579
/// When a worker runs any function, it may not succeed either because the function itself
/// failed (Error) or because an input to the function already had an error value
/// DependentError.
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
    /// The function itself failed; wraps the underlying error directly.
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    /// An input was already poisoned by an earlier failure; carries the
    /// shared root-cause [`SeqError`].
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
590
591impl CallFunctionError {
592    // Static functions for backward compatibility with existing enum cases
593    #[allow(non_snake_case)]
594    pub fn RefNotFound(r: Ref) -> Self {
595        Self::Error(anyhow::anyhow!("ref not found: {}", r))
596    }
597
598    #[allow(non_snake_case)]
599    pub fn InvalidRemoteFunction(msg: String) -> Self {
600        Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
601    }
602
603    #[allow(non_snake_case)]
604    pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
605        Self::Error(anyhow::anyhow!(
606            "unsupported arg type for {} function: {}",
607            function_type,
608            arg_type
609        ))
610    }
611
612    #[allow(non_snake_case)]
613    pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
614        Self::Error(anyhow::anyhow!("remote function failed: {}", err))
615    }
616
617    #[allow(non_snake_case)]
618    pub fn BorrowError(err: BorrowError) -> Self {
619        Self::Error(anyhow::anyhow!("borrow failed: {}", err))
620    }
621
622    #[allow(non_snake_case)]
623    pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
624        Self::Error(anyhow::anyhow!(
625            "unexpected number of returns from op, expected {}, got {}",
626            expected,
627            actual
628        ))
629    }
630
631    #[allow(non_snake_case)]
632    pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
633        Self::Error(anyhow::anyhow!(
634            "expected only a single arg (and no kwargs) when no function is given: {}, {}",
635            args,
636            kwargs
637        ))
638    }
639
640    #[allow(non_snake_case)]
641    pub fn Anyhow(err: anyhow::Error) -> Self {
642        Self::Error(err)
643    }
644}
645
646impl From<SerializablePyErr> for CallFunctionError {
647    fn from(v: SerializablePyErr) -> CallFunctionError {
648        CallFunctionError::Error(v.into())
649    }
650}
651
652impl From<BorrowError> for CallFunctionError {
653    fn from(v: BorrowError) -> CallFunctionError {
654        CallFunctionError::Error(v.into())
655    }
656}
657
/// Worker messages. These define the observable behavior of the worker, so the
/// documentation here describes the contract each message obligates the worker
/// to fulfill. Variant order is preserved as-is since serialization may depend
/// on it.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initialize backend network state.
    BackendNetworkInit(UniqueId),

    /// Initialize backend network state for point-to-point communication.
    BackendNetworkPointToPointInit {
        /// Stream on the sending side of the point-to-point link.
        from_stream: StreamRef,
        /// Stream on the receiving side of the point-to-point link.
        to_stream: StreamRef,
    },

    /// Call a function, either a torch op or a Python `remote_function`.
    CallFunction(CallFunctionParams),

    /// Groups commands together; these commands will be executed in order by
    /// the worker.
    CommandGroup(Vec<WorkerMessage>),

    /// Create a [`Stream`] on the worker with the provided id. Commands will
    /// generally be scheduled onto streams to run; different streams can
    /// execute concurrently with one another.
    CreateStream {
        /// Id of the stream to create.
        id: StreamRef,
        /// Whether to use the default device stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a [`DeviceMesh`] on the worker, which can be used to schedule
    /// efficient inter-worker communication.
    CreateDeviceMesh {
        /// Ref to assign to the created mesh.
        result: Ref,
        /// Names of the mesh dimensions.
        names: Vec<String>,
        /// The ranks that make up the mesh.
        ranks: Slice,
    },

    /// Create a PyTorch distributed process group on the worker, which can be
    /// used to schedule collectives in UDFs using monarch communicators.
    CreateRemoteProcessGroup {
        /// Ref to assign to the created process group.
        result: Ref,
        /// Device mesh the process group is created over.
        device_mesh: Ref,
        /// The mesh dimensions spanned by the process group.
        dims: Vec<String>,
    },

    /// Create a borrow of a tensor from one stream to another.
    ///
    /// Borrows allows streams to access tensors on another stream. The runtime
    /// will insert appropriate synchronization to ensure that cross-stream
    /// usage is safe.
    BorrowCreate {
        /// Ref of the resulting borrowed tensor
        result: Ref,
        /// Id for the borrow
        borrow: u64,
        /// Tensor to borrow
        tensor: Ref,
        /// Stream to borrow from
        from_stream: StreamRef,
        /// Stream to borrow to
        to_stream: StreamRef,
    },

    /// First use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowFirstUse {
        /// Id of the borrow, as given to [`WorkerMessage::BorrowCreate`].
        borrow: u64,
    },

    /// Last use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowLastUse {
        /// Id of the borrow, as given to [`WorkerMessage::BorrowCreate`].
        borrow: u64,
    },

    /// Drop the borrow and free the resources associated with it.
    BorrowDrop {
        /// Id of the borrow, as given to [`WorkerMessage::BorrowCreate`].
        borrow: u64,
    },

    /// Delete these refs from the worker state.
    DeleteRefs(Vec<Ref>),

    /// A [`ControllerMessage::Status`] will be sent to the controller
    /// when all streams have processed all the messages sent before this one.
    RequestStatus {
        /// Sequence number to report back in the status message.
        seq: Seq,
        // NOTE(review): the meaning of `controller` is not evident from this
        // file — confirm against the worker's RequestStatus handler.
        controller: bool,
    },

    /// Perform a reduction operation, using an efficient communication backend.
    /// Only NCCL is supported for now.
    Reduce {
        /// Where to store the result of the reduction.
        result: Ref,
        /// The tensor to reduce.
        tensor: Ref,
        /// Tensor metadata for `tensor` that can be used to construct a
        /// fresh tensor of appropriate size/shape. We use this if
        /// `tensor` isn't accessible for some reason (like a previous
        /// error on the worker).
        factory: Factory,
        /// The device mesh on which to perform the reduction.
        mesh: Ref,
        /// The stream to call the reduction on.
        stream: StreamRef,
        /// The dimensions of the device mesh to reduce over. The members of
        /// these dimension will form the members of the reduction collective.
        dims: Vec<String>,
        /// What kind of reduction to perform.
        reduction: Reduction,
        /// If `true`, the reduced result will be evenly split across the tensors
        /// of `dim`.
        scatter: bool,
        /// If `true`, the reduction will be performed in-place on `tensor`.
        in_place: bool,
        /// Pre-existing tensor that should be used as the output for the reduction.
        out: Option<Ref>,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitComm {
        /// The device mesh dimensions along which the constructed communicator
        /// should be able to exchange data.
        dims: Vec<String>,
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        device_mesh: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    /// Create a new communicator on each rank in `ranks`, capable of
    /// communicating with its peers along the specified dimensions.
    SplitCommForProcessGroup {
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        remote_process_group: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    /// Send a tensor from one slice of ranks to another.
    SendTensor {
        /// Ref to assign to the received tensor.
        result: Ref,
        /// Ranks the tensor is sent from.
        from_ranks: Slice,
        /// Ranks the tensor is sent to.
        to_ranks: Slice,
        /// The tensor to send.
        tensor: Ref,
        /// Metadata used to construct a stand-in tensor if `tensor` is
        /// unavailable (e.g. after an earlier error on the worker).
        factory: Factory,
        /// Stream the send is scheduled on.
        from_stream: StreamRef,
        /// Stream the receive is scheduled on.
        to_stream: StreamRef,
    },

    /// Retrieve a value and send it to a pipe or back to the controller.
    SendValue {
        /// Sequence ID of the invocation.
        seq: Seq,
        /// Pipe to send value to.  If `None`, value is sent to controller.
        destination: Option<Ref>,
        /// The references of the mutates to set.
        mutates: Vec<Ref>,
        /// Function to resolve the value to retrieve.  If `None`, then `args_kwargs`
        /// must contain the value as the only element in args with no kwargs.
        function: Option<ResolvableFunction>,
        /// Arguments (and keyword arguments) for `function`, or the value itself.
        args_kwargs: ArgsKwargs,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Perform the actor call described by [`ActorCallParams`] and send back
    /// the result of the call.
    SendResultOfActorCall(ActorCallParams),
    /// Call a method on a Python actor, binding its results to refs
    /// (see [`ActorMethodParams`]).
    CallActorMethod(ActorMethodParams),
    /// Receive a value from a pipe and bind it to result refs.
    PipeRecv {
        /// Sequence ID of the invocation.
        seq: Seq,
        /// Result refs.
        results: Vec<Option<Ref>>,
        /// Pipe to receive value from.
        pipe: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Finish processing all messages previously sent to this worker and stop
    /// the actor loop. Any streams will also be drained.
    Exit {
        /// Optional error reason if the exit is the result of an error, including
        /// - optional actor id to indicate the source of the error
        /// - error message or stacktrace
        /// The worker process will be stopped if the error is provided.
        error: Option<(Option<reference::ActorId>, String)>,
    },

    /// Defines (part of) a new recording on the worker. This is a list of commands
    /// representing the execution of a function that was defined using
    /// monarch.compile. If there are too many commands to send in a single
    /// DefineRecording message, the commands may be chunked into `ntotal_messages`,
    /// with the `index` field indicating how to order the DefineRecording messages
    /// for a single recording.
    DefineRecording {
        /// The ref associated with this recording that will be used to
        /// call it in the future.
        result: Ref,
        /// The number of output tensors.
        nresults: usize,
        /// The number of input tensors.
        nformals: usize,
        /// The list of commands to run.
        commands: Vec<WorkerMessage>,
        /// How many total DefineRecording messages make up this recording.
        ntotal_messages: usize,
        /// This DefineRecording message's index in the set of messages
        /// that make up this recording.
        index: usize,
    },

    /// Defines an input tensor for a recording.
    RecordingFormal {
        /// The ref that will be used to pass the input tensor to the
        /// recording.
        result: Ref,
        /// The index of the input tensor in the list of input tensors.
        argument_index: usize,
        /// The stream that this input tensor will be used on.
        stream: StreamRef,
    },

    /// Defines an output tensor for a recording.
    RecordingResult {
        /// The ref that will be used to store the output tensor.
        result: Ref,
        /// The index of the output tensor in the list of output tensors.
        output_index: usize,
        /// The stream that this output tensor will come from.
        stream: StreamRef,
    },

    /// Calls a recording that was previously defined using
    /// DefineRecording.
    CallRecording {
        /// The sequence number of the invocation.
        seq: Seq,
        /// The ref of the recording to call.
        recording: Ref,
        /// The list of refs where the result tensors from the recording
        /// will be stored.
        results: Vec<Ref>,
        /// The list of refs of input tensors to the recording.
        actuals: Vec<Ref>,
    },

    /// Test-only: directly bind `reference` to `value` on the given stream.
    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to set it with.
        value: WireValue,
        /// The stream to set it on.
        stream: StreamRef,
    },

    /// Test-only: read a value back from the worker.
    GetRefUnitTestsOnly {
        /// The value to retrieve, expected to be a bool.
        value: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
        #[reply]
        response_port: reference::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
939
/// The parameters to spawn a worker actor.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Global world size for this job.
    pub world_size: usize,

    /// Rank of the worker within the global world.
    pub rank: usize,

    /// Local cuda device that this worker represents. If None, we won't do CUDA
    /// synchronization.
    pub device_index: Option<i8>,

    /// Actor Ref for the controller that the worker is associated with.
    pub controller_actor: reference::ActorRef<ControllerActor>,
}
// Register `WorkerParams` with the wirevalue type registry so it can be
// encoded/decoded as a wire value.
wirevalue::register_type!(WorkerParams);
957
// Declare that `WorkerActor` handles `WorkerMessage`; `cast = true` opts the
// message into cast delivery (NOTE(review): confirm exact semantics against
// `hyperactor::behavior!`).
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);