1#![allow(unsafe_op_in_unsafe_fn)]
12#![allow(unused_assignments)]
17
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
22use anyhow::Context;
23use derive_more::Display;
24use derive_more::From;
25use derive_more::TryInto;
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32use hyperactor::reference;
33use monarch_types::SerializablePyErr;
34use monarch_types::py_global;
35use ndslice::Slice;
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38use pyo3::types::PyBytes;
39use pyo3::types::PyDict;
40use pyo3::types::PyTuple;
41use serde::Deserialize;
42use serde::Serialize;
43use thiserror::Error;
44use torch_sys_cuda::nccl::ReduceOp;
45use torch_sys_cuda::nccl::UniqueId;
46use torch_sys2::BorrowError;
47use torch_sys2::Device;
48use torch_sys2::Layout;
49use torch_sys2::ScalarType;
50use typeuri::Named;
51
52use crate::controller::ControllerActor;
53use crate::controller::Seq;
54use crate::wire_value::WireValue;
55
/// Identifier for a stream, exposed to Python as `StreamRef`.
///
/// Wraps a single `u64` id. Equality, hashing and ordering are all derived from
/// that id. Streams are referenced by this type throughout `WorkerMessage`
/// (e.g. `CreateStream { id: StreamRef, .. }`).
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    /// Numeric stream id (read-only `.id` attribute in Python).
    #[pyo3(get)]
    pub id: u64,
}
77
78#[pyo3::pymethods]
79impl StreamRef {
80 #[new]
81 #[pyo3(signature = (*, id))]
82 fn new(id: u64) -> Self {
83 Self { id }
84 }
85
86 fn __repr__(&self) -> String {
87 format!("StreamRef({})", self.id)
88 }
89
90 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
92 Ok(match op {
93 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
94 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
95 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
96 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
97 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
98 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
99 })
100 }
101
102 fn __hash__(&self) -> PyResult<u64> {
103 Ok(self.id)
104 }
105}
106
/// Numeric reference to a value named by worker commands, exposed to Python as `Ref`.
///
/// `WorkerMessage` uses it for tensors, device meshes, process groups, pipes,
/// recordings and results. Equality, hashing and ordering come from `id`.
/// `Display` renders it as `r<id>`.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    /// Numeric ref id (read-only `.id` attribute in Python).
    #[pyo3(get)]
    pub id: u64,
}
132
133#[pyo3::pymethods]
134impl Ref {
135 #[new]
136 fn new(id: u64) -> Self {
137 Self { id }
138 }
139
140 #[getter]
141 fn r#ref(&self) -> u64 {
142 self.id
143 }
144
145 fn __repr__(&self) -> String {
146 format!("Ref({})", self.id)
147 }
148
149 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
151 Ok(match op {
152 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
153 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
154 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
155 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
156 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
157 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
158 })
159 }
160
161 fn __hash__(&self) -> PyResult<u64> {
162 Ok(self.id)
163 }
164
165 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
166 let kwargs = PyDict::new(py);
167 kwargs.set_item("id", self.id).unwrap();
168
169 PyTuple::new(
170 py,
171 vec![
172 PyTuple::empty(py).unbind().into_any(),
173 kwargs.unbind().into_any(),
174 ],
175 )
176 }
177}
178
179impl Ref {
180 pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
183 let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
184 if let Ok(ref_) = obj.extract::<Ref>() {
185 return Ok(ref_);
186 }
187 if let Ok(func) = obj.getattr(attr_name) {
188 if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
189 return Ok(val.into());
190 }
191 }
192 Err(PyValueError::new_err("Could not convert object to Ref"))
193 }
194}
195
196impl Display for Ref {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 write!(f, "r{}", self.id)
199 }
200}
201
/// Fully-qualified, dot-separated path to a Python callable, e.g. `torch.ops.aten.add`.
///
/// The path is resolved to an object lazily, via `FunctionPath::resolve`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    /// The dotted path (read-only `.path` attribute in Python).
    #[pyo3(get)]
    pub path: String,
}
214
215impl fmt::Display for FunctionPath {
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 write!(f, "<function \"{}\">", self.path)
218 }
219}
220
221impl<T: Into<String>> From<T> for FunctionPath {
222 fn from(val: T) -> Self {
223 Self { path: val.into() }
224 }
225}
226
227#[pyo3::pymethods]
228impl FunctionPath {
229 #[new]
230 #[pyo3(signature = (*, path))]
231 pub fn new(path: String) -> Self {
232 Self { path }
233 }
234
235 fn __repr__(&self) -> String {
236 self.path.clone()
237 }
238
239 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
240 let (start, rest) = self.path.split_once(".").with_context(|| {
241 format!(
242 "invalid function path {}: paths must be fully qualified",
243 self.path
244 )
245 })?;
246 if start == "torch" {
247 let mut cur = py.import("torch")?.into_any();
248 for p in rest.split(".") {
249 cur = cur.getattr(p)?;
250 }
251 Ok(cur)
252 } else {
253 let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
254 format!(
255 "invalid function path {}: paths must be fully qualified",
256 self.path
257 )
258 })?;
259 let module = PyModule::import(py, module_fqn)?;
260 let mut function = module.getattr(function_name)?;
261 if function.hasattr("_remote_impl")? {
262 function = function.getattr("_remote_impl")?;
263 }
264 Ok(function.downcast_into()?)
265 }
266 }
267}
268
/// An opaque Python object serialized with `cloudpickle`.
///
/// `Cloudpickle::resolve` deserializes it with `cloudpickle.loads`.
/// `Cloudpickle::dumps` creates one from a live object.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    /// Raw pickle bytes. Serialized through `serde_bytes` so serde treats them
    /// as a byte string rather than a sequence of integers.
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}
281
282impl fmt::Display for Cloudpickle {
283 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284 write!(f, "<cloud-pickle>")
285 }
286}
287
// Defines `cloudpickle_dumps(py)`, which yields Python's `cloudpickle.dumps`.
// It is used by `Cloudpickle::dumps`.
py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
289
290#[pyo3::pymethods]
291impl Cloudpickle {
292 #[new]
293 #[pyo3(signature = (*, bytes))]
294 pub fn new(bytes: Vec<u8>) -> Self {
295 Self { bytes }
296 }
297
298 fn __repr__(&self) -> String {
299 format!("Cloudpickle(bytes={:?})", self.bytes)
300 }
301
302 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303 let module = PyModule::import(py, "cloudpickle")?;
304 let loads = module.getattr("loads")?;
305 loads.call1((PyBytes::new(py, &self.bytes),))
306 }
307}
308
309impl Cloudpickle {
310 pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
311 let py = obj.py();
312 let dumps = cloudpickle_dumps(py);
313 let bytes_obj = dumps.call1((obj,))?;
314 let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
315 Ok(Self { bytes })
316 }
317}
318
/// A Python callable carried in worker commands (see `CallFunctionParams::function`).
///
/// It is shipped either by value (`Cloudpickle`) or by name (`FunctionPath`).
/// When extracted from Python via the derived `FromPyObject`, the variants are
/// tried in declaration order, so keep `Cloudpickle` first.
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// A cloudpickled callable.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// A callable named by its fully-qualified dotted path.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
336
/// Positional and keyword arguments for a call.
///
/// They are stored as a single cloudpickled `(args, kwargs)` tuple
/// (see `ArgsKwargs::from_python` / `ArgsKwargs::to_python`).
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
pub struct ArgsKwargs {
    /// Cloudpickled `(args, kwargs)` 2-tuple.
    payload: Cloudpickle,
}
341
342impl ArgsKwargs {
343 pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
344 let py = args.py();
346 let tuple = PyTuple::new(py, vec![args, kwargs])?;
347 let payload = Cloudpickle::dumps(tuple.into_any())?;
348 Ok(Self { payload })
349 }
350
351 pub fn from_wire_values(
352 args: Vec<WireValue>,
353 kwargs: HashMap<String, WireValue>,
354 ) -> PyResult<Self> {
355 Python::attach(|py| {
356 let py_args: Vec<Bound<'_, PyAny>> = args
358 .into_iter()
359 .map(|v| v.into_pyobject(py))
360 .collect::<PyResult<_>>()?;
361 let args_tuple = PyTuple::new(py, py_args)?;
362
363 let kwargs_dict = PyDict::new(py);
365 for (k, v) in kwargs {
366 kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
367 }
368
369 Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
370 })
371 }
372
373 pub fn to_python<'py>(
374 &self,
375 py: Python<'py>,
376 ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
377 let tuple = self.payload.resolve(py)?;
378 let tuple = tuple.downcast::<PyTuple>()?;
379
380 let args = tuple.get_item(0)?;
382 let args_tuple = args.downcast::<PyTuple>()?;
383
384 let kwargs = tuple.get_item(1)?;
386 let kwargs_dict = kwargs.downcast::<PyDict>()?;
387
388 Ok((args_tuple.clone(), kwargs_dict.clone()))
389 }
390}
391
392impl<'py> IntoPyObject<'py> for ResolvableFunction {
393 type Target = PyAny;
394 type Output = Bound<'py, Self::Target>;
395 type Error = PyErr;
396
397 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
398 Ok(match self {
399 Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
400 Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
401 })
402 }
403}
404
405impl ResolvableFunction {
406 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
407 match self {
408 Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
409 Self::FunctionPath(func) => func.resolve(py),
410 }
411 }
412
413 pub fn panic_if_requested(&self) {
416 match self {
417 Self::FunctionPath(func) => {
418 if func.path == "__test_panic" {
419 panic!("__test_panic called");
420 }
421 }
422 _ => (),
423 }
424 }
425}
426
427impl<T: Into<String>> From<T> for ResolvableFunction {
428 fn from(val: T) -> Self {
429 FunctionPath::from(val).into()
430 }
431}
432
/// Payload of `WorkerMessage::CallFunction`: invoke `function` with `args_kwargs`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence number of this command.
    pub seq: Seq,
    /// Refs to bind the call's outputs to. Presumably a `None` entry means that
    /// output is not kept — TODO confirm.
    pub results: Vec<Option<Ref>>,
    /// Refs the call mutates.
    pub mutates: Vec<Ref>,
    /// The callable to invoke.
    pub function: ResolvableFunction,
    /// Cloudpickled positional/keyword arguments.
    pub args_kwargs: ArgsKwargs,
    /// Stream on which the call is issued.
    pub stream: StreamRef,
    /// Refs to remote process groups used by the call (presumably created via
    /// `CreateRemoteProcessGroup` — TODO confirm).
    pub remote_process_groups: Vec<Ref>,
}
450
/// Parameters for an actor call.
///
/// Used by `SendResultOfActorCall` and, wrapped in `ActorMethodParams`, by `CallActorMethod`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence number of this command.
    pub seq: Seq,
    /// Broker identifier as `(name, index)`. Exact semantics are not visible here — TODO confirm.
    pub broker_id: (String, usize),
    /// Refs passed as local state to the call.
    pub local_state: Vec<Ref>,
    /// Refs the call mutates.
    pub mutates: Vec<Ref>,
    /// Stream on which the call is issued.
    pub stream: StreamRef,
}
464
/// Payload of `WorkerMessage::CallActorMethod`: an actor call plus where to store its outputs.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// Refs to bind outputs to. Presumably `None` means that output is not kept — TODO confirm.
    pub results: Vec<Option<Ref>>,
    /// The call itself.
    pub call: ActorCallParams,
}
470
/// How `WorkerMessage::Reduce` combines values across the mesh.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// Stack the inputs (presumably along a new dimension — TODO confirm).
    Stack,
    /// Apply an NCCL reduce operation (`torch_sys_cuda::nccl::ReduceOp`).
    ReduceOp(ReduceOp),
}
479
/// Tensor metadata (size, dtype, layout, device), exposed to Python as `TensorFactory`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    /// Tensor shape.
    pub size: Vec<i64>,
    /// Element type. Serialized through `torch_sys2::ScalarTypeDef`.
    #[serde(with = "torch_sys2::ScalarTypeDef")]
    pub dtype: ScalarType,
    /// Memory layout. Serialized through `torch_sys2::LayoutDef`.
    #[serde(with = "torch_sys2::LayoutDef")]
    pub layout: Layout,
    /// Target device.
    pub device: Device,
}
494
495#[pyo3::pymethods]
496impl Factory {
497 #[new]
498 #[pyo3(signature = (*, size, dtype, layout, device))]
499 pub fn new(
500 py: Python<'_>,
501 size: Vec<i64>,
502 dtype: Py<PyAny>,
503 layout: Py<PyAny>,
504 device: Py<PyAny>,
505 ) -> PyResult<Self> {
506 Ok(Self {
509 size,
510 dtype: dtype.extract::<ScalarType>(py)?,
511 layout: layout.extract::<Layout>(py)?,
512 device: device.extract::<Device>(py)?,
513 })
514 }
515
516 #[staticmethod]
517 pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
518 Self::new(
519 obj.py(),
520 obj.getattr("size")?.extract()?,
521 obj.getattr("dtype")?.unbind(),
522 obj.getattr("layout")?.unbind(),
523 obj.getattr("device")?.unbind(),
524 )
525 }
526
527 #[getter]
528 fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
529 PyTuple::new(py, self.size.iter())
530 }
531
532 #[getter]
533 fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
534 self.dtype.into_pyobject(py)
535 }
536
537 #[getter]
538 fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
539 self.layout.into_pyobject(py)
540 }
541
542 #[getter]
543 fn device(&self) -> String {
544 self.device.to_string()
545 }
546}
547
/// How `WorkerMessage::CreateStream` obtains the stream it registers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Use the default stream.
    UseDefaultStream,
    /// Create a fresh stream.
    CreateNewStream,
}
561
/// An error tagged with the sequence number of the command that produced it.
///
/// It is shared via `Arc` in `CallFunctionError::DependentError`. `Display`
/// forwards to the wrapped error.
#[derive(Debug, Named)]
pub struct SeqError {
    /// Sequence number of the failing command.
    pub seq: Seq,
    /// The underlying failure.
    pub error: anyhow::Error,
}
573
574impl Display for SeqError {
575 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
576 self.error.fmt(f)
577 }
578}
579
/// Failure while executing a remote function or op on a worker.
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
    /// Any direct failure, wrapped as an `anyhow::Error`. Displays the wrapped
    /// error's top-level message.
    ///
    /// NOTE(review): thiserror's `#[from]` also makes this field the `source()`,
    /// so error-chain printers may repeat the same message.
    /// `#[error(transparent)]` would avoid that — confirm before changing.
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    /// The computation consumed an input whose producing command failed.
    /// The `SeqError` carries that command's seq and error.
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
590
591impl CallFunctionError {
592 #[allow(non_snake_case)]
594 pub fn RefNotFound(r: Ref) -> Self {
595 Self::Error(anyhow::anyhow!("ref not found: {}", r))
596 }
597
598 #[allow(non_snake_case)]
599 pub fn InvalidRemoteFunction(msg: String) -> Self {
600 Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
601 }
602
603 #[allow(non_snake_case)]
604 pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
605 Self::Error(anyhow::anyhow!(
606 "unsupported arg type for {} function: {}",
607 function_type,
608 arg_type
609 ))
610 }
611
612 #[allow(non_snake_case)]
613 pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
614 Self::Error(anyhow::anyhow!("remote function failed: {}", err))
615 }
616
617 #[allow(non_snake_case)]
618 pub fn BorrowError(err: BorrowError) -> Self {
619 Self::Error(anyhow::anyhow!("borrow failed: {}", err))
620 }
621
622 #[allow(non_snake_case)]
623 pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
624 Self::Error(anyhow::anyhow!(
625 "unexpected number of returns from op, expected {}, got {}",
626 expected,
627 actual
628 ))
629 }
630
631 #[allow(non_snake_case)]
632 pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
633 Self::Error(anyhow::anyhow!(
634 "expected only a single arg (and no kwargs) when no function is given: {}, {}",
635 args,
636 kwargs
637 ))
638 }
639
640 #[allow(non_snake_case)]
641 pub fn Anyhow(err: anyhow::Error) -> Self {
642 Self::Error(err)
643 }
644}
645
646impl From<SerializablePyErr> for CallFunctionError {
647 fn from(v: SerializablePyErr) -> CallFunctionError {
648 CallFunctionError::Error(v.into())
649 }
650}
651
652impl From<BorrowError> for CallFunctionError {
653 fn from(v: BorrowError) -> CallFunctionError {
654 CallFunctionError::Error(v.into())
655 }
656}
657
/// Commands handled by the worker actor.
///
/// The `hyperactor::behavior!` declaration in this file binds this enum to
/// `WorkerActor` with `cast = true`.
///
/// NOTE(review): variant order may be significant to position-based serde
/// encodings. Prefer appending new variants over reordering — confirm against
/// the wire format in use.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Carries an NCCL `UniqueId`. Presumably initializes the backend
    /// communication network — TODO confirm.
    BackendNetworkInit(UniqueId),

    /// Presumably sets up point-to-point backend communication between two
    /// streams — TODO confirm.
    BackendNetworkPointToPointInit {
        /// Source stream.
        from_stream: StreamRef,
        /// Destination stream.
        to_stream: StreamRef,
    },

    /// Invoke a function; see `CallFunctionParams`.
    CallFunction(CallFunctionParams),

    /// A batch of nested worker messages (presumably handled in order — TODO confirm).
    CommandGroup(Vec<WorkerMessage>),

    /// Register stream `id`.
    CreateStream {
        /// Ref under which the stream is registered.
        id: StreamRef,
        /// Whether to use the default stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a device mesh and bind it to `result`.
    CreateDeviceMesh {
        /// Ref to bind the mesh to.
        result: Ref,
        /// Presumably the mesh dimension names — TODO confirm.
        names: Vec<String>,
        /// Ranks participating in the mesh.
        ranks: Slice,
    },

    /// Create a remote process group over `dims` of `device_mesh` and bind it to `result`.
    CreateRemoteProcessGroup {
        /// Ref to bind the process group to.
        result: Ref,
        /// The device mesh the group is built on.
        device_mesh: Ref,
        /// Mesh dimensions spanned by the group.
        dims: Vec<String>,
    },

    /// Create borrow `borrow` of `tensor` from `from_stream` for use on
    /// `to_stream`, binding the borrowed value to `result`.
    BorrowCreate {
        /// Ref for the borrowed value.
        result: Ref,
        /// Borrow id, referenced by `BorrowFirstUse`/`BorrowLastUse`/`BorrowDrop`.
        borrow: u64,
        /// The tensor being borrowed.
        tensor: Ref,
        /// Stream that owns `tensor`.
        from_stream: StreamRef,
        /// Stream that borrows it.
        to_stream: StreamRef,
    },

    /// Marks the first use of borrow `borrow`.
    BorrowFirstUse {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Marks the last use of borrow `borrow`.
    BorrowLastUse {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Drops borrow `borrow`.
    BorrowDrop {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Delete the listed refs.
    DeleteRefs(Vec<Ref>),

    /// Status request tagged with `seq`. Semantics of `controller` are not
    /// visible here — TODO confirm.
    RequestStatus {
        /// Sequence number of this request.
        seq: Seq,
        /// Presumably whether the controller (vs. a client) is asking — TODO confirm.
        controller: bool,
    },

    /// Reduce `tensor` across `dims` of `mesh` using `reduction`.
    Reduce {
        /// Ref to bind the reduced value to.
        result: Ref,
        /// Input tensor.
        tensor: Ref,
        /// Metadata (size/dtype/layout/device) of the input tensor.
        factory: Factory,
        /// Device mesh to reduce over.
        mesh: Ref,
        /// Stream on which the reduction is issued.
        stream: StreamRef,
        /// Mesh dimensions to reduce across.
        dims: Vec<String>,
        /// Stack or NCCL reduce-op.
        reduction: Reduction,
        /// Presumably selects a reduce-scatter instead of a full reduce — TODO confirm.
        scatter: bool,
        /// Presumably writes the result into the input tensor — TODO confirm.
        in_place: bool,
        /// Optional output tensor to write into.
        out: Option<Ref>,
    },

    /// Presumably splits a communicator over `dims` of `device_mesh` for `stream` — TODO confirm.
    SplitComm {
        /// Mesh dimensions for the split.
        dims: Vec<String>,
        /// Device mesh to split.
        device_mesh: Ref,
        /// Stream the communicator is used on.
        stream: StreamRef,
    },

    /// Like `SplitComm`, but for a remote process group (presumably — TODO confirm).
    SplitCommForProcessGroup {
        /// Remote process group ref (see `CreateRemoteProcessGroup`).
        remote_process_group: Ref,
        /// Stream the communicator is used on.
        stream: StreamRef,
    },

    /// Send `tensor` from `from_ranks` to `to_ranks`, binding the received value to `result`.
    SendTensor {
        /// Ref for the received tensor.
        result: Ref,
        /// Sending ranks.
        from_ranks: Slice,
        /// Receiving ranks.
        to_ranks: Slice,
        /// Tensor to send.
        tensor: Ref,
        /// Metadata (size/dtype/layout/device) of the tensor.
        factory: Factory,
        /// Stream on the sending side.
        from_stream: StreamRef,
        /// Stream on the receiving side.
        to_stream: StreamRef,
    },

    /// Compute a value (optionally by calling `function`) and send it.
    ///
    /// When `function` is `None`, exactly one positional arg and no kwargs are
    /// expected (see `CallFunctionError::TooManyArgsForValue`).
    SendValue {
        /// Sequence number of this command.
        seq: Seq,
        /// Optional ref to send the value to.
        destination: Option<Ref>,
        /// Refs mutated by the computation.
        mutates: Vec<Ref>,
        /// Optional function to call to produce the value.
        function: Option<ResolvableFunction>,
        /// Arguments (or the value itself when `function` is `None`).
        args_kwargs: ArgsKwargs,
        /// Stream on which the computation is issued.
        stream: StreamRef,
    },

    /// Actor call whose result is sent back; see `ActorCallParams`.
    SendResultOfActorCall(ActorCallParams),
    /// Actor method call with result refs; see `ActorMethodParams`.
    CallActorMethod(ActorMethodParams),
    /// Receive values from `pipe` into `results`.
    PipeRecv {
        /// Sequence number of this command.
        seq: Seq,
        /// Refs to bind received values to.
        results: Vec<Option<Ref>>,
        /// The pipe to receive from.
        pipe: Ref,
        /// Stream on which the receive is issued.
        stream: StreamRef,
    },

    /// Exit the worker.
    Exit {
        /// Optional error: a message, plus optionally the id of the actor it is
        /// attributed to (presumably — TODO confirm).
        error: Option<(Option<reference::ActorId>, String)>,
    },

    /// Define a recording: a list of `commands` bound to `result`.
    DefineRecording {
        /// Ref to bind the recording to.
        result: Ref,
        /// Number of results the recording produces.
        nresults: usize,
        /// Number of formal arguments the recording takes.
        nformals: usize,
        /// Commands making up (this part of) the recording.
        commands: Vec<WorkerMessage>,
        /// Presumably the total number of `DefineRecording` messages when a
        /// recording is split across several — TODO confirm.
        ntotal_messages: usize,
        /// Presumably this message's position among `ntotal_messages` — TODO confirm.
        index: usize,
    },

    /// Inside a recording, bind formal argument `argument_index` to `result`.
    RecordingFormal {
        /// Ref the formal is bound to.
        result: Ref,
        /// Index of the formal argument.
        argument_index: usize,
        /// Stream for the binding.
        stream: StreamRef,
    },

    /// Inside a recording, mark `result` as output `output_index`.
    RecordingResult {
        /// Ref produced as an output.
        result: Ref,
        /// Index of the output.
        output_index: usize,
        /// Stream for the output.
        stream: StreamRef,
    },

    /// Invoke a previously defined recording.
    CallRecording {
        /// Sequence number of this command.
        seq: Seq,
        /// The recording to call.
        recording: Ref,
        /// Refs to bind the recording's results to.
        results: Vec<Ref>,
        /// Actual arguments for the recording's formals.
        actuals: Vec<Ref>,
    },

    /// Unit-test helper: set `reference` to `value` on `stream`.
    SetRefUnitTestsOnly {
        /// Ref to set.
        reference: Ref,
        /// Value to store.
        value: WireValue,
        /// Stream on which the value is stored.
        stream: StreamRef,
    },

    /// Unit-test helper: read ref `value` on `stream` and reply on `response_port`.
    GetRefUnitTestsOnly {
        /// Ref to read.
        value: Ref,
        /// Stream to read from.
        stream: StreamRef,
        /// Reply port carrying `Option<Result<WireValue, String>>`.
        #[reply]
        response_port: reference::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
939
/// Parameters for a worker.
///
/// Presumably these are the spawn-time parameters of `WorkerActor` — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Total number of workers (presumably — TODO confirm).
    pub world_size: usize,

    /// This worker's rank.
    pub rank: usize,

    /// Optional device index (presumably the CUDA device; `None` for no device — TODO confirm).
    pub device_index: Option<i8>,

    /// Reference to the controller actor.
    pub controller_actor: reference::ActorRef<ControllerActor>,
}
// Register `WorkerParams` with `wirevalue`'s type registry.
wirevalue::register_type!(WorkerParams);
957
// Declares the `WorkerActor` behavior over `WorkerMessage`, with `cast = true`.
// See `hyperactor::behavior!` for exactly what this generates.
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);