#![allow(unsafe_op_in_unsafe_fn)]
// NOTE(review): the reason for these module-wide lint allowances is not visible in this
// file (code generated by the derive/attribute macros used below is a plausible source).
// Document or narrow them once confirmed.
#![allow(unused_assignments)]
17
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
22use anyhow::Context;
23use derive_more::Display;
24use derive_more::From;
25use derive_more::TryInto;
26use enum_as_inner::EnumAsInner;
27use hyperactor as reference;
28use hyperactor::Bind;
29use hyperactor::HandleClient;
30use hyperactor::Handler;
31use hyperactor::RefClient;
32use hyperactor::Unbind;
33use monarch_types::ReduceOp;
34use monarch_types::SerializablePyErr;
35use monarch_types::UniqueId;
36use monarch_types::py_global;
37use ndslice::Slice;
38use pyo3::exceptions::PyValueError;
39use pyo3::prelude::*;
40use pyo3::types::PyBytes;
41use pyo3::types::PyDict;
42use pyo3::types::PyTuple;
43use serde::Deserialize;
44use serde::Serialize;
45use thiserror::Error;
46use torch_sys2::BorrowError;
47use torch_sys2::Device;
48use torch_sys2::Layout;
49use torch_sys2::ScalarType;
50use typeuri::Named;
51
52use crate::controller::ControllerActor;
53use crate::controller::Seq;
54use crate::wire_value::WireValue;
55
56#[derive(
57 Serialize,
58 Deserialize,
59 Debug,
60 Clone,
61 Hash,
62 PartialEq,
63 Eq,
64 Copy,
65 PartialOrd,
66 Ord,
67 From
68)]
69#[pyo3::pyclass(
70 frozen,
71 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
72)]
73pub struct StreamRef {
74 #[pyo3(get)]
75 pub id: u64,
76}
77
78#[pyo3::pymethods]
79impl StreamRef {
80 #[new]
81 #[pyo3(signature = (*, id))]
82 fn new(id: u64) -> Self {
83 Self { id }
84 }
85
86 fn __repr__(&self) -> String {
87 format!("StreamRef({})", self.id)
88 }
89
90 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
92 Ok(match op {
93 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
94 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
95 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
96 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
97 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
98 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
99 })
100 }
101
102 fn __hash__(&self) -> PyResult<u64> {
103 Ok(self.id)
104 }
105}
106
107#[derive(
112 Serialize,
113 Deserialize,
114 Debug,
115 Clone,
116 Hash,
117 PartialEq,
118 Eq,
119 Copy,
120 PartialOrd,
121 Ord,
122 From
123)]
124#[pyo3::pyclass(
125 frozen,
126 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
127)]
128pub struct Ref {
129 #[pyo3(get)]
130 pub id: u64,
131}
132
133#[pyo3::pymethods]
134impl Ref {
135 #[new]
136 fn new(id: u64) -> Self {
137 Self { id }
138 }
139
140 #[getter]
141 fn r#ref(&self) -> u64 {
142 self.id
143 }
144
145 fn __repr__(&self) -> String {
146 format!("Ref({})", self.id)
147 }
148
149 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
151 Ok(match op {
152 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
153 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
154 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
155 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
156 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
157 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
158 })
159 }
160
161 fn __hash__(&self) -> PyResult<u64> {
162 Ok(self.id)
163 }
164
165 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
166 let kwargs = PyDict::new(py);
167 kwargs.set_item("id", self.id).unwrap();
168
169 PyTuple::new(
170 py,
171 vec![
172 PyTuple::empty(py).unbind().into_any(),
173 kwargs.unbind().into_any(),
174 ],
175 )
176 }
177}
178
179impl Ref {
180 pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
183 let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
184 if let Ok(ref_) = obj.extract::<Ref>() {
185 return Ok(ref_);
186 }
187 if let Ok(func) = obj.getattr(attr_name) {
188 if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
189 return Ok(val.into());
190 }
191 }
192 Err(PyValueError::new_err("Could not convert object to Ref"))
193 }
194}
195
196impl Display for Ref {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 write!(f, "r{}", self.id)
199 }
200}
201
202#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
206#[pyo3::pyclass(
207 frozen,
208 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
209)]
210pub struct FunctionPath {
211 #[pyo3(get)]
212 pub path: String,
213}
214
215impl fmt::Display for FunctionPath {
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 write!(f, "<function \"{}\">", self.path)
218 }
219}
220
221impl<T: Into<String>> From<T> for FunctionPath {
222 fn from(val: T) -> Self {
223 Self { path: val.into() }
224 }
225}
226
227#[pyo3::pymethods]
228impl FunctionPath {
229 #[new]
230 #[pyo3(signature = (*, path))]
231 pub fn new(path: String) -> Self {
232 Self { path }
233 }
234
235 fn __repr__(&self) -> String {
236 self.path.clone()
237 }
238
239 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
240 let (start, rest) = self.path.split_once(".").with_context(|| {
241 format!(
242 "invalid function path {}: paths must be fully qualified",
243 self.path
244 )
245 })?;
246 if start == "torch" {
247 let mut cur = py.import("torch")?.into_any();
248 for p in rest.split(".") {
249 cur = cur.getattr(p)?;
250 }
251 Ok(cur)
252 } else {
253 let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
254 format!(
255 "invalid function path {}: paths must be fully qualified",
256 self.path
257 )
258 })?;
259 let module = PyModule::import(py, module_fqn)?;
260 let mut function = module.getattr(function_name)?;
261 if function.hasattr("_remote_impl")? {
262 function = function.getattr("_remote_impl")?;
263 }
264 Ok(function.downcast_into()?)
265 }
266 }
267}
268
269#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
273#[pyo3::pyclass(
274 frozen,
275 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
276)]
277pub struct Cloudpickle {
278 #[serde(with = "serde_bytes")]
279 bytes: Vec<u8>,
280}
281
282impl fmt::Display for Cloudpickle {
283 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284 write!(f, "<cloud-pickle>")
285 }
286}
287
288py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
289
290#[pyo3::pymethods]
291impl Cloudpickle {
292 #[new]
293 #[pyo3(signature = (*, bytes))]
294 pub fn new(bytes: Vec<u8>) -> Self {
295 Self { bytes }
296 }
297
298 fn __repr__(&self) -> String {
299 format!("Cloudpickle(bytes={:?})", self.bytes)
300 }
301
302 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303 let module = PyModule::import(py, "cloudpickle")?;
304 let loads = module.getattr("loads")?;
305 loads.call1((PyBytes::new(py, &self.bytes),))
306 }
307}
308
309impl Cloudpickle {
310 pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
311 let py = obj.py();
312 let dumps = cloudpickle_dumps(py);
313 let bytes_obj = dumps.call1((obj,))?;
314 let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
315 Ok(Self { bytes })
316 }
317}
318
/// A function that can be resolved on the worker. It is either a cloudpickled
/// callable or a dotted import path.
///
/// When extracted from Python through the derived `FromPyObject`, variants are tried
/// in declaration order, so `Cloudpickle` is attempted first. With serde's default
/// enum representation, declaration order also fixes the encoded variant index.
/// Do not reorder the variants casually.
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// A callable serialized with cloudpickle.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// A callable named by its fully-qualified import path.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
336
337#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
338pub struct ArgsKwargs {
339 payload: Cloudpickle,
340}
341
342impl ArgsKwargs {
343 pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
344 let py = args.py();
346 let tuple = PyTuple::new(py, vec![args, kwargs])?;
347 let payload = Cloudpickle::dumps(tuple.into_any())?;
348 Ok(Self { payload })
349 }
350
351 pub fn from_wire_values(
352 args: Vec<WireValue>,
353 kwargs: HashMap<String, WireValue>,
354 ) -> PyResult<Self> {
355 Python::attach(|py| {
356 let py_args: Vec<Bound<'_, PyAny>> = args
358 .into_iter()
359 .map(|v| v.into_pyobject(py))
360 .collect::<PyResult<_>>()?;
361 let args_tuple = PyTuple::new(py, py_args)?;
362
363 let kwargs_dict = PyDict::new(py);
365 for (k, v) in kwargs {
366 kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
367 }
368
369 Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
370 })
371 }
372
373 pub fn to_python<'py>(
374 &self,
375 py: Python<'py>,
376 ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
377 let tuple = self.payload.resolve(py)?;
378 let tuple = tuple.downcast::<PyTuple>()?;
379
380 let args = tuple.get_item(0)?;
382 let args_tuple = args.downcast::<PyTuple>()?;
383
384 let kwargs = tuple.get_item(1)?;
386 let kwargs_dict = kwargs.downcast::<PyDict>()?;
387
388 Ok((args_tuple.clone(), kwargs_dict.clone()))
389 }
390}
391
392impl<'py> IntoPyObject<'py> for ResolvableFunction {
393 type Target = PyAny;
394 type Output = Bound<'py, Self::Target>;
395 type Error = PyErr;
396
397 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
398 Ok(match self {
399 Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
400 Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
401 })
402 }
403}
404
405impl ResolvableFunction {
406 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
407 match self {
408 Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
409 Self::FunctionPath(func) => func.resolve(py),
410 }
411 }
412
413 pub fn panic_if_requested(&self) {
416 match self {
417 Self::FunctionPath(func) if func.path == "__test_panic" => {
418 panic!("__test_panic called");
419 }
420 _ => (),
421 }
422 }
423}
424
425impl<T: Into<String>> From<T> for ResolvableFunction {
426 fn from(val: T) -> Self {
427 FunctionPath::from(val).into()
428 }
429}
430
/// Parameters for `WorkerMessage::CallFunction`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence number identifying this invocation.
    pub seq: Seq,
    /// Refs to bind to the function's results. A `None` entry presumably marks a
    /// result that is not kept (TODO confirm).
    pub results: Vec<Option<Ref>>,
    /// Refs the call is declared to mutate.
    pub mutates: Vec<Ref>,
    /// The function to invoke.
    pub function: ResolvableFunction,
    /// Pickled positional and keyword arguments.
    pub args_kwargs: ArgsKwargs,
    /// Stream on which the call runs.
    pub stream: StreamRef,
    /// Refs of remote process groups used by the call. These are presumably the
    /// groups created via `WorkerMessage::CreateRemoteProcessGroup` (TODO confirm).
    pub remote_process_groups: Vec<Ref>,
}
448
/// Parameters for a call routed through an actor. Carried directly by
/// `WorkerMessage::SendResultOfActorCall`, and inside [`ActorMethodParams`] by
/// `WorkerMessage::CallActorMethod`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence number identifying this call.
    pub seq: Seq,
    /// Broker identifier as a `(String, usize)` pair. Presumably a name plus an
    /// index (TODO confirm).
    pub broker_id: (String, usize),
    /// Refs passed along as local state for the call.
    pub local_state: Vec<Ref>,
    /// Refs the call is declared to mutate.
    pub mutates: Vec<Ref>,
    /// Stream on which the call runs.
    pub stream: StreamRef,
}
462
/// Parameters for `WorkerMessage::CallActorMethod`: an actor call plus the refs
/// that receive its results.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// Refs to bind to the results. A `None` entry presumably marks a result that
    /// is not kept (TODO confirm).
    pub results: Vec<Option<Ref>>,
    /// The underlying actor call.
    pub call: ActorCallParams,
}
468
/// How `WorkerMessage::Reduce` combines the tensor over the requested mesh dims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// Stack the contributions instead of reducing them. The exact stacking
    /// semantics are not visible here (TODO confirm).
    Stack,
    /// Combine the contributions with the given reduce operation.
    ReduceOp(ReduceOp),
}
477
478#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
479#[pyo3::pyclass(
480 frozen,
481 name = "TensorFactory",
482 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
483)]
484pub struct Factory {
485 pub size: Vec<i64>,
486 #[serde(with = "torch_sys2::ScalarTypeDef")]
487 pub dtype: ScalarType,
488 #[serde(with = "torch_sys2::LayoutDef")]
489 pub layout: Layout,
490 pub device: Device,
491}
492
493#[pyo3::pymethods]
494impl Factory {
495 #[new]
496 #[pyo3(signature = (*, size, dtype, layout, device))]
497 pub fn new(
498 py: Python<'_>,
499 size: Vec<i64>,
500 dtype: Py<PyAny>,
501 layout: Py<PyAny>,
502 device: Py<PyAny>,
503 ) -> PyResult<Self> {
504 Ok(Self {
507 size,
508 dtype: dtype.extract::<ScalarType>(py)?,
509 layout: layout.extract::<Layout>(py)?,
510 device: device.extract::<Device>(py)?,
511 })
512 }
513
514 #[staticmethod]
515 pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
516 Self::new(
517 obj.py(),
518 obj.getattr("size")?.extract()?,
519 obj.getattr("dtype")?.unbind(),
520 obj.getattr("layout")?.unbind(),
521 obj.getattr("device")?.unbind(),
522 )
523 }
524
525 #[getter]
526 fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
527 PyTuple::new(py, self.size.iter())
528 }
529
530 #[getter]
531 fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
532 self.dtype.into_pyobject(py)
533 }
534
535 #[getter]
536 fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
537 self.layout.into_pyobject(py)
538 }
539
540 #[getter]
541 fn device(&self) -> String {
542 self.device.to_string()
543 }
544}
545
/// How `WorkerMessage::CreateStream` obtains the stream it registers.
///
/// Exposed to Python as an `eq`/`eq_int` enum. Declaration order fixes the integer
/// discriminants used for int comparison, so do not reorder the variants.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Reuse the default stream.
    UseDefaultStream,
    /// Create a fresh stream.
    CreateNewStream,
}
559
560#[derive(Debug, Named)]
567pub struct SeqError {
568 pub seq: Seq,
569 pub error: anyhow::Error,
570}
571
572impl Display for SeqError {
573 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
574 self.error.fmt(f)
575 }
576}
577
578#[derive(Error, Debug, Named)]
582pub enum CallFunctionError {
583 #[error("{0}")]
584 Error(#[from] anyhow::Error),
585 #[error("Computation depended on an input that failed with error: {0}")]
586 DependentError(Arc<SeqError>),
587}
588
589impl CallFunctionError {
590 #[allow(non_snake_case)]
592 pub fn RefNotFound(r: Ref) -> Self {
593 Self::Error(anyhow::anyhow!("ref not found: {}", r))
594 }
595
596 #[allow(non_snake_case)]
597 pub fn InvalidRemoteFunction(msg: String) -> Self {
598 Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
599 }
600
601 #[allow(non_snake_case)]
602 pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
603 Self::Error(anyhow::anyhow!(
604 "unsupported arg type for {} function: {}",
605 function_type,
606 arg_type
607 ))
608 }
609
610 #[allow(non_snake_case)]
611 pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
612 Self::Error(anyhow::anyhow!("remote function failed: {}", err))
613 }
614
615 #[allow(non_snake_case)]
616 pub fn BorrowError(err: BorrowError) -> Self {
617 Self::Error(anyhow::anyhow!("borrow failed: {}", err))
618 }
619
620 #[allow(non_snake_case)]
621 pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
622 Self::Error(anyhow::anyhow!(
623 "unexpected number of returns from op, expected {}, got {}",
624 expected,
625 actual
626 ))
627 }
628
629 #[allow(non_snake_case)]
630 pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
631 Self::Error(anyhow::anyhow!(
632 "expected only a single arg (and no kwargs) when no function is given: {}, {}",
633 args,
634 kwargs
635 ))
636 }
637
638 #[allow(non_snake_case)]
639 pub fn Anyhow(err: anyhow::Error) -> Self {
640 Self::Error(err)
641 }
642}
643
644impl From<SerializablePyErr> for CallFunctionError {
645 fn from(v: SerializablePyErr) -> CallFunctionError {
646 CallFunctionError::Error(v.into())
647 }
648}
649
650impl From<BorrowError> for CallFunctionError {
651 fn from(v: BorrowError) -> CallFunctionError {
652 CallFunctionError::Error(v.into())
653 }
654}
655
/// Messages accepted by the tensor worker actor.
///
/// The derives generate the actor messaging plumbing (`Handler`, `HandleClient`,
/// `RefClient`) and `Bind`/`Unbind` support. The `hyperactor::behavior!` invocation
/// at the end of this file declares these messages with `cast = true`.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Initializes the backend communication network with the given unique id.
    /// Presumably a collective-communication id shared by all workers (TODO confirm).
    BackendNetworkInit(UniqueId),

    /// Initializes point-to-point backend communication between two streams.
    BackendNetworkPointToPointInit {
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Invokes a function. See [`CallFunctionParams`].
    CallFunction(CallFunctionParams),

    /// A batch of messages. Presumably handled in order (TODO confirm).
    CommandGroup(Vec<WorkerMessage>),

    /// Registers a stream under `id`.
    CreateStream {
        /// Id under which the stream is registered.
        id: StreamRef,
        /// Whether to reuse the default stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Creates a device mesh and binds it to `result`.
    CreateDeviceMesh {
        result: Ref,
        /// Presumably one name per mesh dimension (TODO confirm).
        names: Vec<String>,
        ranks: Slice,
    },

    /// Creates a process group over `dims` of `device_mesh` and binds it to `result`.
    CreateRemoteProcessGroup {
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    },

    /// Starts borrow `borrow` of `tensor`, moving it from `from_stream` to
    /// `to_stream` and exposing it as `result`. The exact borrow semantics are not
    /// visible here (TODO confirm).
    BorrowCreate {
        /// Ref under which the borrowed tensor is exposed.
        result: Ref,
        /// Borrow identifier, used by the later `Borrow*` messages.
        borrow: u64,
        /// Tensor being borrowed.
        tensor: Ref,
        /// Stream that owns the tensor.
        from_stream: StreamRef,
        /// Stream that borrows it.
        to_stream: StreamRef,
    },

    /// Signals the first use of borrow `borrow`.
    BorrowFirstUse {
        borrow: u64,
    },

    /// Signals the last use of borrow `borrow`.
    BorrowLastUse {
        borrow: u64,
    },

    /// Drops borrow `borrow`.
    BorrowDrop {
        borrow: u64,
    },

    /// Deletes the given refs.
    DeleteRefs(Vec<Ref>),

    /// Requests a status report for `seq`.
    /// NOTE(review): the meaning of `controller` is not visible here; document it.
    RequestStatus {
        seq: Seq,
        controller: bool,
    },

    /// Reduces `tensor` over `dims` of `mesh` on `stream`, binding the output to `result`.
    Reduce {
        /// Ref bound to the reduction output.
        result: Ref,
        /// Input tensor.
        tensor: Ref,
        /// Tensor metadata; see [`Factory`].
        factory: Factory,
        /// Device mesh to reduce over.
        mesh: Ref,
        /// Stream on which the reduction runs.
        stream: StreamRef,
        /// Mesh dimensions to reduce over.
        dims: Vec<String>,
        /// How values are combined; see [`Reduction`].
        reduction: Reduction,
        /// Presumably selects a reduce-scatter instead of a full reduce (TODO confirm).
        scatter: bool,
        /// Presumably writes the result into `tensor` in place (TODO confirm).
        in_place: bool,
        /// Optional existing ref to write the output into.
        out: Option<Ref>,
    },

    /// Splits a communicator over `dims` of `device_mesh` for use on `stream`.
    SplitComm {
        /// Mesh dimensions the communicator spans.
        dims: Vec<String>,
        /// Device mesh the communicator is derived from.
        device_mesh: Ref,
        /// Stream the communicator is for.
        stream: StreamRef,
    },

    /// Splits a communicator for `remote_process_group`, for use on `stream`.
    SplitCommForProcessGroup {
        /// Process group the communicator is for.
        remote_process_group: Ref,
        /// Stream the communicator is for.
        stream: StreamRef,
    },

    /// Sends `tensor` from `from_ranks` to `to_ranks`, binding the received value to
    /// `result`.
    SendTensor {
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Computes a value and sends it to `destination`. If `function` is given, it is
    /// called on `args_kwargs`. If `function` is `None`, exactly one positional
    /// argument and no kwargs are expected (see
    /// `CallFunctionError::TooManyArgsForValue`).
    SendValue {
        seq: Seq,
        /// Where to send the value. The meaning of `None` is not visible here (TODO confirm).
        destination: Option<Ref>,
        /// Refs the computation is declared to mutate.
        mutates: Vec<Ref>,
        /// Function to apply, or `None` to send the single argument as-is.
        function: Option<ResolvableFunction>,
        /// Pickled arguments.
        args_kwargs: ArgsKwargs,
        stream: StreamRef,
    },

    /// Sends the result of an actor call. See [`ActorCallParams`].
    SendResultOfActorCall(ActorCallParams),
    /// Calls an actor method and binds its results. See [`ActorMethodParams`].
    CallActorMethod(ActorMethodParams),
    /// Receives values from `pipe` into `results`.
    PipeRecv {
        seq: Seq,
        /// Refs to bind the received values to. `None` entries are presumably discarded.
        results: Vec<Option<Ref>>,
        /// Pipe to receive from.
        pipe: Ref,
        /// Stream on which the receive runs.
        stream: StreamRef,
    },

    /// Asks the worker to exit.
    Exit {
        /// If set, an error message, optionally paired with the address of the
        /// actor it relates to.
        error: Option<(Option<reference::ActorAddr>, String)>,
    },

    /// Defines a recording of `commands` and binds it to `result`.
    DefineRecording {
        /// Ref the recording is bound to.
        result: Ref,
        /// Number of results the recording produces.
        nresults: usize,
        /// Number of formal arguments the recording takes.
        nformals: usize,
        /// The recorded messages.
        commands: Vec<WorkerMessage>,
        /// Presumably the total number of messages a large recording is split
        /// across (TODO confirm).
        ntotal_messages: usize,
        /// Presumably this message's position within that split (TODO confirm).
        index: usize,
    },

    /// Within a recording: binds formal argument `argument_index` to `result`.
    RecordingFormal {
        /// Ref the formal argument is bound to.
        result: Ref,
        /// Index of the formal argument.
        argument_index: usize,
        /// Stream the formal is associated with.
        stream: StreamRef,
    },

    /// Within a recording: marks `result` as recording output `output_index`.
    RecordingResult {
        /// Ref holding the output value.
        result: Ref,
        /// Index of the output.
        output_index: usize,
        /// Stream the result is associated with.
        stream: StreamRef,
    },

    /// Runs `recording` with `actuals` as arguments, binding its outputs to `results`.
    CallRecording {
        /// Sequence number of this invocation.
        seq: Seq,
        /// The recording to run.
        recording: Ref,
        /// Refs to bind the recording's outputs to.
        results: Vec<Ref>,
        /// Actual arguments, matched to the recording's formals.
        actuals: Vec<Ref>,
    },

    /// Test-only: sets `reference` to `value` on `stream`.
    SetRefUnitTestsOnly {
        reference: Ref,
        /// Value to store.
        value: WireValue,
        /// Stream on which to store it.
        stream: StreamRef,
    },

    /// Test-only: reads `value` on `stream` and replies on `response_port`.
    GetRefUnitTestsOnly {
        value: Ref,
        /// Stream on which to read it.
        stream: StreamRef,
        /// Reply port. The outer `Option` is presumably `None` when the ref is
        /// unknown. The inner `Result` carries the value or an error string.
        #[reply]
        response_port: reference::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
938
/// Parameters used to construct the worker actor.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Total number of workers.
    pub world_size: usize,

    /// This worker's rank, presumably in `0..world_size` (TODO confirm).
    pub rank: usize,

    /// Device index to use, if any. The meaning of `None` is not visible here
    /// (TODO confirm).
    pub device_index: Option<i8>,

    /// Reference to the controller actor, presumably the one this worker reports to.
    pub controller_actor: reference::ActorRef<ControllerActor>,
}
// Registers `WorkerParams` with `wirevalue`, presumably so it can be carried as a
// typed wire value (confirm against `wirevalue::register_type!`).
wirevalue::register_type!(WorkerParams);
956
// Declares the `WorkerActor` behavior: the message type it accepts (`WorkerMessage`),
// marked `cast = true`. Presumably this lets the messages be cast to a group of
// workers (confirm against `hyperactor::behavior!`).
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);