Skip to main content

monarch_messages/
worker.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674 is available, `pyo3::pymethods`
// expansions trigger unsafe-op-in-unsafe-fn warnings.
11#![allow(unsafe_op_in_unsafe_fn)]
12// EnumAsInner generates code that triggers a false positive
13// unused_assignments lint on struct variant fields. #[allow] on the
14// enum itself doesn't propagate into derive-macro-generated code, so
15// the suppression must be at module scope.
16#![allow(unused_assignments)]
17
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
22use anyhow::Context;
23use derive_more::Display;
24use derive_more::From;
25use derive_more::TryInto;
26use enum_as_inner::EnumAsInner;
27use hyperactor as reference;
28use hyperactor::Bind;
29use hyperactor::HandleClient;
30use hyperactor::Handler;
31use hyperactor::RefClient;
32use hyperactor::Unbind;
33use monarch_types::ReduceOp;
34use monarch_types::SerializablePyErr;
35use monarch_types::UniqueId;
36use monarch_types::py_global;
37use ndslice::Slice;
38use pyo3::exceptions::PyValueError;
39use pyo3::prelude::*;
40use pyo3::types::PyBytes;
41use pyo3::types::PyDict;
42use pyo3::types::PyTuple;
43use serde::Deserialize;
44use serde::Serialize;
45use thiserror::Error;
46use torch_sys2::BorrowError;
47use torch_sys2::Device;
48use torch_sys2::Layout;
49use torch_sys2::ScalarType;
50use typeuri::Named;
51
52use crate::controller::ControllerActor;
53use crate::controller::Seq;
54use crate::wire_value::WireValue;
55
56#[derive(
57    Serialize,
58    Deserialize,
59    Debug,
60    Clone,
61    Hash,
62    PartialEq,
63    Eq,
64    Copy,
65    PartialOrd,
66    Ord,
67    From
68)]
69#[pyo3::pyclass(
70    frozen,
71    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
72)]
73pub struct StreamRef {
74    #[pyo3(get)]
75    pub id: u64,
76}
77
78#[pyo3::pymethods]
79impl StreamRef {
80    #[new]
81    #[pyo3(signature = (*, id))]
82    fn new(id: u64) -> Self {
83        Self { id }
84    }
85
86    fn __repr__(&self) -> String {
87        format!("StreamRef({})", self.id)
88    }
89
90    // TODO: Upgrade pyo3 to use eq, ord on pyclass
91    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
92        Ok(match op {
93            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
94            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
95            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
96            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
97            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
98            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
99        })
100    }
101
102    fn __hash__(&self) -> PyResult<u64> {
103        Ok(self.id)
104    }
105}
106
107// TODO: The Python implementation uses `Ref` to describe any worker value that
108// can be referenced by the controller, including: tensors, streams, pipes,
109// device meshes. We might be able to more explicitly type these, as they are
110// not generally interchangeable.
111#[derive(
112    Serialize,
113    Deserialize,
114    Debug,
115    Clone,
116    Hash,
117    PartialEq,
118    Eq,
119    Copy,
120    PartialOrd,
121    Ord,
122    From
123)]
124#[pyo3::pyclass(
125    frozen,
126    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
127)]
128pub struct Ref {
129    #[pyo3(get)]
130    pub id: u64,
131}
132
133#[pyo3::pymethods]
134impl Ref {
135    #[new]
136    fn new(id: u64) -> Self {
137        Self { id }
138    }
139
140    #[getter]
141    fn r#ref(&self) -> u64 {
142        self.id
143    }
144
145    fn __repr__(&self) -> String {
146        format!("Ref({})", self.id)
147    }
148
149    // TODO: Upgrade pyo3 to use eq, ord on pyclass
150    fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
151        Ok(match op {
152            pyo3::class::basic::CompareOp::Eq => self.id == other.id,
153            pyo3::class::basic::CompareOp::Ne => self.id != other.id,
154            pyo3::class::basic::CompareOp::Lt => self.id < other.id,
155            pyo3::class::basic::CompareOp::Le => self.id <= other.id,
156            pyo3::class::basic::CompareOp::Gt => self.id > other.id,
157            pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
158        })
159    }
160
161    fn __hash__(&self) -> PyResult<u64> {
162        Ok(self.id)
163    }
164
165    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
166        let kwargs = PyDict::new(py);
167        kwargs.set_item("id", self.id).unwrap();
168
169        PyTuple::new(
170            py,
171            vec![
172                PyTuple::empty(py).unbind().into_any(),
173                kwargs.unbind().into_any(),
174            ],
175        )
176    }
177}
178
179impl Ref {
180    // This is a function on ref instead of impl FromPyObject due to a bug in pyo3
181    // https://github.com/PyO3/pyo3/issues/4337
182    pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
183        let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
184        if let Ok(ref_) = obj.extract::<Ref>() {
185            return Ok(ref_);
186        }
187        if let Ok(func) = obj.getattr(attr_name) {
188            if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
189                return Ok(val.into());
190            }
191        }
192        Err(PyValueError::new_err("Could not convert object to Ref"))
193    }
194}
195
196impl Display for Ref {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(f, "r{}", self.id)
199    }
200}
201
202/// Identifies a CallFunction target. Can either be a torch op or a Python
203/// global reference.
204// TODO: do some validation on the namespace/opname/overload
205#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
206#[pyo3::pyclass(
207    frozen,
208    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
209)]
210pub struct FunctionPath {
211    #[pyo3(get)]
212    pub path: String,
213}
214
215impl fmt::Display for FunctionPath {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        write!(f, "<function \"{}\">", self.path)
218    }
219}
220
221impl<T: Into<String>> From<T> for FunctionPath {
222    fn from(val: T) -> Self {
223        Self { path: val.into() }
224    }
225}
226
227#[pyo3::pymethods]
228impl FunctionPath {
229    #[new]
230    #[pyo3(signature = (*, path))]
231    pub fn new(path: String) -> Self {
232        Self { path }
233    }
234
235    fn __repr__(&self) -> String {
236        self.path.clone()
237    }
238
239    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
240        let (start, rest) = self.path.split_once(".").with_context(|| {
241            format!(
242                "invalid function path {}: paths must be fully qualified",
243                self.path
244            )
245        })?;
246        if start == "torch" {
247            let mut cur = py.import("torch")?.into_any();
248            for p in rest.split(".") {
249                cur = cur.getattr(p)?;
250            }
251            Ok(cur)
252        } else {
253            let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
254                format!(
255                    "invalid function path {}: paths must be fully qualified",
256                    self.path
257                )
258            })?;
259            let module = PyModule::import(py, module_fqn)?;
260            let mut function = module.getattr(function_name)?;
261            if function.hasattr("_remote_impl")? {
262                function = function.getattr("_remote_impl")?;
263            }
264            Ok(function.downcast_into()?)
265        }
266    }
267}
268
269/// Identifies a CallFunction target. Can either be a torch op or a Python
270/// global reference.
271// TODO: do some validation on the namespace/opname/overload
272#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
273#[pyo3::pyclass(
274    frozen,
275    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
276)]
277pub struct Cloudpickle {
278    #[serde(with = "serde_bytes")]
279    bytes: Vec<u8>,
280}
281
282impl fmt::Display for Cloudpickle {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        write!(f, "<cloud-pickle>")
285    }
286}
287
// Declares `cloudpickle_dumps(py)`, an accessor for the Python global
// `cloudpickle.dumps`; it is called by `Cloudpickle::dumps`.
// NOTE(review): presumably the lookup is cached across calls — confirm in
// `monarch_types::py_global`.
py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
289
290#[pyo3::pymethods]
291impl Cloudpickle {
292    #[new]
293    #[pyo3(signature = (*, bytes))]
294    pub fn new(bytes: Vec<u8>) -> Self {
295        Self { bytes }
296    }
297
298    fn __repr__(&self) -> String {
299        format!("Cloudpickle(bytes={:?})", self.bytes)
300    }
301
302    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303        let module = PyModule::import(py, "cloudpickle")?;
304        let loads = module.getattr("loads")?;
305        loads.call1((PyBytes::new(py, &self.bytes),))
306    }
307}
308
309impl Cloudpickle {
310    pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
311        let py = obj.py();
312        let dumps = cloudpickle_dumps(py);
313        let bytes_obj = dumps.call1((obj,))?;
314        let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
315        Ok(Self { bytes })
316    }
317}
318
319#[derive(
320    PartialEq,
321    Serialize,
322    Deserialize,
323    Debug,
324    Clone,
325    TryInto,
326    From,
327    FromPyObject,
328    Display
329)]
330pub enum ResolvableFunction {
331    #[pyo3(transparent)]
332    Cloudpickle(Cloudpickle),
333    #[pyo3(transparent)]
334    FunctionPath(FunctionPath),
335}
336
337#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
338pub struct ArgsKwargs {
339    payload: Cloudpickle,
340}
341
342impl ArgsKwargs {
343    pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
344        // Create tuple (args, kwargs), then cloudpickle it
345        let py = args.py();
346        let tuple = PyTuple::new(py, vec![args, kwargs])?;
347        let payload = Cloudpickle::dumps(tuple.into_any())?;
348        Ok(Self { payload })
349    }
350
351    pub fn from_wire_values(
352        args: Vec<WireValue>,
353        kwargs: HashMap<String, WireValue>,
354    ) -> PyResult<Self> {
355        Python::attach(|py| {
356            // Convert WireValue args to Python objects
357            let py_args: Vec<Bound<'_, PyAny>> = args
358                .into_iter()
359                .map(|v| v.into_pyobject(py))
360                .collect::<PyResult<_>>()?;
361            let args_tuple = PyTuple::new(py, py_args)?;
362
363            // Convert WireValue kwargs to Python dict
364            let kwargs_dict = PyDict::new(py);
365            for (k, v) in kwargs {
366                kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
367            }
368
369            Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
370        })
371    }
372
373    pub fn to_python<'py>(
374        &self,
375        py: Python<'py>,
376    ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
377        let tuple = self.payload.resolve(py)?;
378        let tuple = tuple.downcast::<PyTuple>()?;
379
380        // Extract args (first element)
381        let args = tuple.get_item(0)?;
382        let args_tuple = args.downcast::<PyTuple>()?;
383
384        // Extract kwargs (second element)
385        let kwargs = tuple.get_item(1)?;
386        let kwargs_dict = kwargs.downcast::<PyDict>()?;
387
388        Ok((args_tuple.clone(), kwargs_dict.clone()))
389    }
390}
391
392impl<'py> IntoPyObject<'py> for ResolvableFunction {
393    type Target = PyAny;
394    type Output = Bound<'py, Self::Target>;
395    type Error = PyErr;
396
397    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
398        Ok(match self {
399            Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
400            Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
401        })
402    }
403}
404
405impl ResolvableFunction {
406    pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
407        match self {
408            Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
409            Self::FunctionPath(func) => func.resolve(py),
410        }
411    }
412
413    /// For testing: this is a special remote function path that induces a panic
414    /// when called.
415    pub fn panic_if_requested(&self) {
416        match self {
417            Self::FunctionPath(func) if func.path == "__test_panic" => {
418                panic!("__test_panic called");
419            }
420            _ => (),
421        }
422    }
423}
424
425impl<T: Into<String>> From<T> for ResolvableFunction {
426    fn from(val: T) -> Self {
427        FunctionPath::from(val).into()
428    }
429}
430
431#[derive(Serialize, Deserialize, Debug, Clone)]
432pub struct CallFunctionParams {
433    /// Sequence ID of the invocation.
434    pub seq: Seq,
435    /// The references of the results to set.
436    pub results: Vec<Option<Ref>>,
437    /// The references of the mutates to set.
438    pub mutates: Vec<Ref>,
439    /// The function to call.
440    pub function: ResolvableFunction,
441    /// The arguments and keyword arguments to the function.
442    pub args_kwargs: ArgsKwargs,
443    /// The stream to call the function on.
444    pub stream: StreamRef,
445    /// The process groups to execute the function on.
446    pub remote_process_groups: Vec<Ref>,
447}
448
449#[derive(Serialize, Deserialize, Debug, Clone)]
450pub struct ActorCallParams {
451    pub seq: Seq,
452    // The BrokerId but we do not depend on hyperactor in messages.
453    pub broker_id: (String, usize),
454    /// Referenceable objects to pass to the actor as LocalState,
455    /// these will be put into the PythonMessage
456    /// during its unpickling.
457    pub local_state: Vec<Ref>,
458    /// Tensors that will be mutated by the call.
459    pub mutates: Vec<Ref>,
460    pub stream: StreamRef,
461}
462
463#[derive(Serialize, Deserialize, Debug, Clone)]
464pub struct ActorMethodParams {
465    pub results: Vec<Option<Ref>>,
466    pub call: ActorCallParams,
467}
468
469/// Type of reduction for [`WorkerMessage::Reduce`].
470#[derive(Debug, Clone, Serialize, Deserialize)]
471pub enum Reduction {
472    /// A gather, concat'ing the values along the reduction dimension.
473    Stack,
474    /// A NCCL reduction type.
475    ReduceOp(ReduceOp),
476}
477
478#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
479#[pyo3::pyclass(
480    frozen,
481    name = "TensorFactory",
482    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
483)]
484pub struct Factory {
485    pub size: Vec<i64>,
486    #[serde(with = "torch_sys2::ScalarTypeDef")]
487    pub dtype: ScalarType,
488    #[serde(with = "torch_sys2::LayoutDef")]
489    pub layout: Layout,
490    pub device: Device,
491}
492
493#[pyo3::pymethods]
494impl Factory {
495    #[new]
496    #[pyo3(signature = (*, size, dtype, layout, device))]
497    pub fn new(
498        py: Python<'_>,
499        size: Vec<i64>,
500        dtype: Py<PyAny>,
501        layout: Py<PyAny>,
502        device: Py<PyAny>,
503    ) -> PyResult<Self> {
504        // TODO: Add some validation around dtype / layout. We should have pyre types on
505        // the python side to help in the short term.
506        Ok(Self {
507            size,
508            dtype: dtype.extract::<ScalarType>(py)?,
509            layout: layout.extract::<Layout>(py)?,
510            device: device.extract::<Device>(py)?,
511        })
512    }
513
514    #[staticmethod]
515    pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
516        Self::new(
517            obj.py(),
518            obj.getattr("size")?.extract()?,
519            obj.getattr("dtype")?.unbind(),
520            obj.getattr("layout")?.unbind(),
521            obj.getattr("device")?.unbind(),
522        )
523    }
524
525    #[getter]
526    fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
527        PyTuple::new(py, self.size.iter())
528    }
529
530    #[getter]
531    fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
532        self.dtype.into_pyobject(py)
533    }
534
535    #[getter]
536    fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
537        self.layout.into_pyobject(py)
538    }
539
540    #[getter]
541    fn device(&self) -> String {
542        self.device.to_string()
543    }
544}
545
546/// Controls what CUDA stream an actor will use.
547#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
548#[pyo3::pyclass(
549    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
550    eq,
551    eq_int
552)]
553pub enum StreamCreationMode {
554    /// Use the default stream for the current device.
555    UseDefaultStream,
556    /// Create a new stream for this actor.
557    CreateNewStream,
558}
559
560/// An error associated with a seq number that failed to execute.
561/// Any defined value that has an error value will have an assocated
562/// SeqError that is the root cause of why that value has an error.
563/// A value may have this error because it was directly defined by the
564/// action associated with the sequence number, or if it was defined by
565/// another action that dependend on the failing one.
566#[derive(Debug, Named)]
567pub struct SeqError {
568    pub seq: Seq,
569    pub error: anyhow::Error,
570}
571
572impl Display for SeqError {
573    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
574        self.error.fmt(f)
575    }
576}
577
578/// When a worker runs any function, it may not succeed either because the function itself
579/// failed (Error) or because an input to the function already had an error value
580/// DependentError.
581#[derive(Error, Debug, Named)]
582pub enum CallFunctionError {
583    #[error("{0}")]
584    Error(#[from] anyhow::Error),
585    #[error("Computation depended on an input that failed with error: {0}")]
586    DependentError(Arc<SeqError>),
587}
588
589impl CallFunctionError {
590    // Static functions for backward compatibility with existing enum cases
591    #[allow(non_snake_case)]
592    pub fn RefNotFound(r: Ref) -> Self {
593        Self::Error(anyhow::anyhow!("ref not found: {}", r))
594    }
595
596    #[allow(non_snake_case)]
597    pub fn InvalidRemoteFunction(msg: String) -> Self {
598        Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
599    }
600
601    #[allow(non_snake_case)]
602    pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
603        Self::Error(anyhow::anyhow!(
604            "unsupported arg type for {} function: {}",
605            function_type,
606            arg_type
607        ))
608    }
609
610    #[allow(non_snake_case)]
611    pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
612        Self::Error(anyhow::anyhow!("remote function failed: {}", err))
613    }
614
615    #[allow(non_snake_case)]
616    pub fn BorrowError(err: BorrowError) -> Self {
617        Self::Error(anyhow::anyhow!("borrow failed: {}", err))
618    }
619
620    #[allow(non_snake_case)]
621    pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
622        Self::Error(anyhow::anyhow!(
623            "unexpected number of returns from op, expected {}, got {}",
624            expected,
625            actual
626        ))
627    }
628
629    #[allow(non_snake_case)]
630    pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
631        Self::Error(anyhow::anyhow!(
632            "expected only a single arg (and no kwargs) when no function is given: {}, {}",
633            args,
634            kwargs
635        ))
636    }
637
638    #[allow(non_snake_case)]
639    pub fn Anyhow(err: anyhow::Error) -> Self {
640        Self::Error(err)
641    }
642}
643
644impl From<SerializablePyErr> for CallFunctionError {
645    fn from(v: SerializablePyErr) -> CallFunctionError {
646        CallFunctionError::Error(v.into())
647    }
648}
649
650impl From<BorrowError> for CallFunctionError {
651    fn from(v: BorrowError) -> CallFunctionError {
652        CallFunctionError::Error(v.into())
653    }
654}
655
/// Messages handled by the worker actor. These define the observable behavior
/// of the worker; the documentation on each variant describes that behavior.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initialize backend network state.
    BackendNetworkInit(UniqueId),

    /// Initialize backend network state for point-to-point communication.
    BackendNetworkPointToPointInit {
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Call a function, either a torch op or a Python `remote_function`.
    CallFunction(CallFunctionParams),

    /// Groups commands together; these commands will be executed in order by
    /// the worker.
    CommandGroup(Vec<WorkerMessage>),

    /// Create a [`Stream`] on the worker with the provided id. Commands will
    /// generally be scheduled onto streams to run; different streams can
    /// execute concurrently with one another.
    CreateStream {
        /// Id of the stream to create.
        id: StreamRef,
        /// Whether to use the default device stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a [`DeviceMesh`] on the worker, which can be used to schedule
    /// efficient inter-worker communication.
    CreateDeviceMesh {
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    },

    /// Create a PyTorch distributed process group on the worker, which can be
    /// used to schedule collectives in UDFs using monarch communicators.
    CreateRemoteProcessGroup {
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    },

    /// Create a borrow of a tensor from one stream to another.
    ///
    /// Borrows allow streams to access tensors on another stream. The runtime
    /// will insert appropriate synchronization to ensure that cross-stream
    /// usage is safe.
    BorrowCreate {
        /// Ref of the resulting borrowed tensor
        result: Ref,
        /// Id for the borrow
        borrow: u64,
        /// Tensor to borrow
        tensor: Ref,
        /// Stream to borrow from
        from_stream: StreamRef,
        /// Stream to borrow to
        to_stream: StreamRef,
    },

    /// First use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowFirstUse {
        borrow: u64,
    },

    /// Last use of the borrow on the receiving stream. This is a marker for
    /// synchronization.
    BorrowLastUse {
        borrow: u64,
    },

    /// Drop the borrow and free the resources associated with it.
    BorrowDrop {
        borrow: u64,
    },

    /// Delete these refs from the worker state.
    DeleteRefs(Vec<Ref>),

    /// A [`ControllerMessage::Status`] will be sent to the controller
    /// when all streams have processed all the messages sent before this one.
    RequestStatus {
        seq: Seq,
        // NOTE(review): the effect of `controller` is not visible in this
        // file — confirm in the worker implementation.
        controller: bool,
    },

    /// Perform a reduction operation, using an efficient communication backend.
    /// Only NCCL is supported for now.
    Reduce {
        /// Where to store the result of the reduction.
        result: Ref,
        /// The tensor to reduce.
        tensor: Ref,
        /// Tensor metadata for `tensor` that can be used to construct a
        /// fresh tensor of appropriate size/shape. We use this if
        /// `tensor` isn't accessible for some reason (like a previous
        /// error on the worker).
        factory: Factory,
        /// The device mesh on which to perform the reduction.
        mesh: Ref,
        /// The stream to call the reduction on.
        stream: StreamRef,
        /// The dimensions of the device mesh to reduce over. The members of
        /// these dimensions will form the members of the reduction collective.
        dims: Vec<String>,
        /// What kind of reduction to perform.
        reduction: Reduction,
        /// If `true`, the reduced result will be evenly split across the tensors
        /// of `dims`.
        scatter: bool,
        /// If `true`, the reduction will be performed in-place on `tensor`.
        in_place: bool,
        /// Pre-existing tensor that should be used as the output for the reduction.
        out: Option<Ref>,
    },

    /// Create a new communicator for the members of `device_mesh`, capable of
    /// communicating with their peers along the specified dimensions.
    SplitComm {
        /// The device mesh dimensions along which the constructed communicator
        /// should be able to exchange data.
        dims: Vec<String>,
        /// The device mesh associated with the new communicator. One communicator
        /// will be created for every member of the mesh.
        device_mesh: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    /// Create a new communicator for a remote process group.
    SplitCommForProcessGroup {
        /// The remote process group associated with the new communicator.
        // NOTE(review): presumably a ref created by `CreateRemoteProcessGroup`
        // — confirm.
        remote_process_group: Ref,
        /// The stream associated with the communicator. Communicator operations
        /// will be ordered with respect to other operations scheduled on this
        /// stream.
        stream: StreamRef,
    },

    // NOTE(review): undocumented; from the field names this presumably sends
    // `tensor` from the ranks in `from_ranks` to those in `to_ranks`, storing
    // it in `result`, with `factory` describing the tensor — confirm.
    SendTensor {
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    SendValue {
        seq: Seq,
        /// Pipe to send value to.  If `None`, value is sent to controller.
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        /// Function to resolve the value to retrieve.  If `None`, then `args_kwargs`
        /// must contain the value as the only element in args with no kwargs.
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    SendResultOfActorCall(ActorCallParams),
    CallActorMethod(ActorMethodParams),
    PipeRecv {
        seq: Seq,
        /// Result refs.
        results: Vec<Option<Ref>>,
        /// Pipe to receive value from.
        pipe: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
    },

    /// Finish processing all messages previously sent to this worker and stop
    /// the actor loop. Any streams will also be drained.
    Exit {
        /// Optional error reason if the exit is the result of an error, including
        /// - optional actor id to indicate the source of the error
        /// - error message or stacktrace
        ///
        /// The worker process will be stopped if the error is provided.
        error: Option<(Option<reference::ActorAddr>, String)>,
    },

    /// Defines (part of) a new recording on the worker. This is a list of commands
    /// representing the execution of a function that was defined using
    /// monarch.compile. If there are too many commands to send in a single
    /// DefineRecording message, the commands may be chunked into `ntotal_messages`,
    /// with the `index` field indicating how to order the DefineRecording messages
    /// for a single recording.
    DefineRecording {
        /// The ref associated with this recording that will be used to
        /// call it in the future.
        result: Ref,
        /// The number of output tensors.
        nresults: usize,
        /// The number of input tensors.
        nformals: usize,
        /// The list of commands to run.
        commands: Vec<WorkerMessage>,
        /// How many total DefineRecording messages make up this recording.
        ntotal_messages: usize,
        /// This DefineRecording message's index in the set of messages
        /// that make up this recording.
        index: usize,
    },

    /// Defines an input tensor for a recording.
    RecordingFormal {
        /// The ref that will be used to pass the input tensor to the
        /// recording.
        result: Ref,
        /// The index of the input tensor in the list of input tensors.
        argument_index: usize,
        /// The stream that this input tensor will be used on.
        stream: StreamRef,
    },

    /// Defines an output tensor for a recording.
    RecordingResult {
        /// The ref that will be used to store the output tensor.
        result: Ref,
        /// The index of the output tensor in the list of output tensors.
        output_index: usize,
        /// The stream that this output tensor will come from.
        stream: StreamRef,
    },

    /// Calls a recording that was previously defined using
    /// DefineRecording.
    CallRecording {
        /// The sequence number of the invocation.
        seq: Seq,
        /// The ref of the recording to call.
        recording: Ref,
        /// The list of refs where the result tensors from the recording
        /// will be stored.
        results: Vec<Ref>,
        /// The list of refs of input tensors to the recording.
        actuals: Vec<Ref>,
    },

    /// Unit-test helper: set `reference` to `value` on `stream`.
    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to set it with.
        value: WireValue,
        /// The stream to set it on.
        stream: StreamRef,
    },

    /// Unit-test helper: read back the value of a ref; the answer is sent on
    /// `response_port`.
    GetRefUnitTestsOnly {
        /// The value to retrieve, expected to be a bool.
        value: Ref,
        /// The stream to retrieve from.
        stream: StreamRef,
        #[reply]
        response_port: reference::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
938
939/// The parameters to spawn a worker actor.
940#[derive(Debug, Clone, Serialize, Deserialize, Named)]
941pub struct WorkerParams {
942    // Global world size for this job
943    pub world_size: usize,
944
945    // Rank of the worker within the global world
946    pub rank: usize,
947
948    // Local cuda device that this worker represents. If None, we won't do CUDA
949    // synchronization.
950    pub device_index: Option<i8>,
951
952    // Actor Ref for the controller that the worker is associated with.
953    pub controller_actor: reference::ActorRef<ControllerActor>,
954}
955wirevalue::register_type!(WorkerParams);
956
// Declares the `WorkerActor` behavior: the actor handles `WorkerMessage`,
// declared with `cast = true`.
// NOTE(review): `cast = true` presumably allows `WorkerMessage` to be
// delivered via casts as well as direct sends — confirm against the
// `hyperactor::behavior!` documentation.
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);