// Allow unsafe operations directly inside `unsafe fn` bodies without an inner
// `unsafe {}` block (silences the `unsafe_op_in_unsafe_fn` lint for the enclosing
// module/crate).
// NOTE(review): no `unsafe fn` is visible in this file section; the allowance may be
// needed by macro-generated code (pyo3 / derives) — confirm before removing.
#![allow(unsafe_op_in_unsafe_fn)]
12
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use anyhow::Context;
18use derive_more::Display;
19use derive_more::From;
20use derive_more::TryInto;
21use enum_as_inner::EnumAsInner;
22use hyperactor::ActorRef;
23use hyperactor::Bind;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::RefClient;
27use hyperactor::Unbind;
28use hyperactor::reference::ActorId;
29use monarch_types::SerializablePyErr;
30use monarch_types::py_global;
31use ndslice::Slice;
32use pyo3::exceptions::PyValueError;
33use pyo3::prelude::*;
34use pyo3::types::PyBytes;
35use pyo3::types::PyDict;
36use pyo3::types::PyTuple;
37use serde::Deserialize;
38use serde::Serialize;
39use thiserror::Error;
40use torch_sys_cuda::nccl::ReduceOp;
41use torch_sys_cuda::nccl::UniqueId;
42use torch_sys2::BorrowError;
43use torch_sys2::Device;
44use torch_sys2::Layout;
45use torch_sys2::ScalarType;
46use typeuri::Named;
47
48use crate::controller::ControllerActor;
49use crate::controller::Seq;
50use crate::wire_value::WireValue;
51
52#[derive(
53 Serialize,
54 Deserialize,
55 Debug,
56 Clone,
57 Hash,
58 PartialEq,
59 Eq,
60 Copy,
61 PartialOrd,
62 Ord,
63 From
64)]
65#[pyo3::pyclass(
66 frozen,
67 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
68)]
69pub struct StreamRef {
70 #[pyo3(get)]
71 pub id: u64,
72}
73
74#[pyo3::pymethods]
75impl StreamRef {
76 #[new]
77 #[pyo3(signature = (*, id))]
78 fn new(id: u64) -> Self {
79 Self { id }
80 }
81
82 fn __repr__(&self) -> String {
83 format!("StreamRef({})", self.id)
84 }
85
86 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
88 Ok(match op {
89 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
90 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
91 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
92 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
93 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
94 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
95 })
96 }
97
98 fn __hash__(&self) -> PyResult<u64> {
99 Ok(self.id)
100 }
101}
102
103#[derive(
108 Serialize,
109 Deserialize,
110 Debug,
111 Clone,
112 Hash,
113 PartialEq,
114 Eq,
115 Copy,
116 PartialOrd,
117 Ord,
118 From
119)]
120#[pyo3::pyclass(
121 frozen,
122 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
123)]
124pub struct Ref {
125 #[pyo3(get)]
126 pub id: u64,
127}
128
129#[pyo3::pymethods]
130impl Ref {
131 #[new]
132 fn new(id: u64) -> Self {
133 Self { id }
134 }
135
136 #[getter]
137 fn r#ref(&self) -> u64 {
138 self.id
139 }
140
141 fn __repr__(&self) -> String {
142 format!("Ref({})", self.id)
143 }
144
145 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
147 Ok(match op {
148 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
149 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
150 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
151 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
152 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
153 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
154 })
155 }
156
157 fn __hash__(&self) -> PyResult<u64> {
158 Ok(self.id)
159 }
160
161 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
162 let kwargs = PyDict::new(py);
163 kwargs.set_item("id", self.id).unwrap();
164
165 PyTuple::new(
166 py,
167 vec![
168 PyTuple::empty(py).unbind().into_any(),
169 kwargs.unbind().into_any(),
170 ],
171 )
172 }
173}
174
175impl Ref {
176 pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
179 let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
180 if let Ok(ref_) = obj.extract::<Ref>() {
181 return Ok(ref_);
182 }
183 if let Ok(func) = obj.getattr(attr_name) {
184 if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
185 return Ok(val.into());
186 }
187 }
188 Err(PyValueError::new_err("Could not convert object to Ref"))
189 }
190}
191
192impl Display for Ref {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 write!(f, "r{}", self.id)
195 }
196}
197
198#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
202#[pyo3::pyclass(
203 frozen,
204 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
205)]
206pub struct FunctionPath {
207 #[pyo3(get)]
208 pub path: String,
209}
210
211impl fmt::Display for FunctionPath {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 write!(f, "<function \"{}\">", self.path)
214 }
215}
216
217impl<T: Into<String>> From<T> for FunctionPath {
218 fn from(val: T) -> Self {
219 Self { path: val.into() }
220 }
221}
222
223#[pyo3::pymethods]
224impl FunctionPath {
225 #[new]
226 #[pyo3(signature = (*, path))]
227 pub fn new(path: String) -> Self {
228 Self { path }
229 }
230
231 fn __repr__(&self) -> String {
232 self.path.clone()
233 }
234
235 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
236 let (start, rest) = self.path.split_once(".").with_context(|| {
237 format!(
238 "invalid function path {}: paths must be fully qualified",
239 self.path
240 )
241 })?;
242 if start == "torch" {
243 let mut cur = py.import("torch")?.into_any();
244 for p in rest.split(".") {
245 cur = cur.getattr(p)?;
246 }
247 Ok(cur)
248 } else {
249 let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
250 format!(
251 "invalid function path {}: paths must be fully qualified",
252 self.path
253 )
254 })?;
255 let module = PyModule::import(py, module_fqn)?;
256 let mut function = module.getattr(function_name)?;
257 if function.hasattr("_remote_impl")? {
258 function = function.getattr("_remote_impl")?;
259 }
260 Ok(function.downcast_into()?)
261 }
262 }
263}
264
265#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
269#[pyo3::pyclass(
270 frozen,
271 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
272)]
273pub struct Cloudpickle {
274 #[serde(with = "serde_bytes")]
275 bytes: Vec<u8>,
276}
277
278impl fmt::Display for Cloudpickle {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 write!(f, "<cloud-pickle>")
281 }
282}
283
284py_global!(cloudpickle_dumps, "cloudpickle", "dumps");
285
286#[pyo3::pymethods]
287impl Cloudpickle {
288 #[new]
289 #[pyo3(signature = (*, bytes))]
290 pub fn new(bytes: Vec<u8>) -> Self {
291 Self { bytes }
292 }
293
294 fn __repr__(&self) -> String {
295 format!("Cloudpickle(bytes={:?})", self.bytes)
296 }
297
298 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
299 let module = PyModule::import(py, "cloudpickle")?;
300 let loads = module.getattr("loads")?;
301 loads.call1((PyBytes::new(py, &self.bytes),))
302 }
303}
304
305impl Cloudpickle {
306 pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult<Self> {
307 let py = obj.py();
308 let dumps = cloudpickle_dumps(py);
309 let bytes_obj = dumps.call1((obj,))?;
310 let bytes = bytes_obj.downcast::<PyBytes>()?.as_bytes().to_vec();
311 Ok(Self { bytes })
312 }
313}
314
315#[derive(
316 PartialEq,
317 Serialize,
318 Deserialize,
319 Debug,
320 Clone,
321 TryInto,
322 From,
323 FromPyObject,
324 Display
325)]
326pub enum ResolvableFunction {
327 #[pyo3(transparent)]
328 Cloudpickle(Cloudpickle),
329 #[pyo3(transparent)]
330 FunctionPath(FunctionPath),
331}
332
333#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
334pub struct ArgsKwargs {
335 payload: Cloudpickle,
336}
337
338impl ArgsKwargs {
339 pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult<Self> {
340 let py = args.py();
342 let tuple = PyTuple::new(py, vec![args, kwargs])?;
343 let payload = Cloudpickle::dumps(tuple.into_any())?;
344 Ok(Self { payload })
345 }
346
347 pub fn from_wire_values(
348 args: Vec<WireValue>,
349 kwargs: HashMap<String, WireValue>,
350 ) -> PyResult<Self> {
351 Python::with_gil(|py| {
352 let py_args: Vec<Bound<'_, PyAny>> = args
354 .into_iter()
355 .map(|v| v.into_pyobject(py))
356 .collect::<PyResult<_>>()?;
357 let args_tuple = PyTuple::new(py, py_args)?;
358
359 let kwargs_dict = PyDict::new(py);
361 for (k, v) in kwargs {
362 kwargs_dict.set_item(k, v.into_pyobject(py)?)?;
363 }
364
365 Self::from_python(args_tuple.into_any(), kwargs_dict.into_any())
366 })
367 }
368
369 pub fn to_python<'py>(
370 &self,
371 py: Python<'py>,
372 ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> {
373 let tuple = self.payload.resolve(py)?;
374 let tuple = tuple.downcast::<PyTuple>()?;
375
376 let args = tuple.get_item(0)?;
378 let args_tuple = args.downcast::<PyTuple>()?;
379
380 let kwargs = tuple.get_item(1)?;
382 let kwargs_dict = kwargs.downcast::<PyDict>()?;
383
384 Ok((args_tuple.clone(), kwargs_dict.clone()))
385 }
386}
387
388impl<'py> IntoPyObject<'py> for ResolvableFunction {
389 type Target = PyAny;
390 type Output = Bound<'py, Self::Target>;
391 type Error = PyErr;
392
393 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
394 Ok(match self {
395 Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
396 Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
397 })
398 }
399}
400
401impl ResolvableFunction {
402 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
403 match self {
404 Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
405 Self::FunctionPath(func) => func.resolve(py),
406 }
407 }
408
409 pub fn panic_if_requested(&self) {
412 match self {
413 Self::FunctionPath(func) => {
414 if func.path == "__test_panic" {
415 panic!("__test_panic called");
416 }
417 }
418 _ => (),
419 }
420 }
421}
422
423impl<T: Into<String>> From<T> for ResolvableFunction {
424 fn from(val: T) -> Self {
425 FunctionPath::from(val).into()
426 }
427}
428
429#[derive(Serialize, Deserialize, Debug, Clone)]
430pub struct CallFunctionParams {
431 pub seq: Seq,
433 pub results: Vec<Option<Ref>>,
435 pub mutates: Vec<Ref>,
437 pub function: ResolvableFunction,
439 pub args_kwargs: ArgsKwargs,
441 pub stream: StreamRef,
443 pub remote_process_groups: Vec<Ref>,
445}
446
447#[derive(Serialize, Deserialize, Debug, Clone)]
448pub struct ActorCallParams {
449 pub seq: Seq,
450 pub broker_id: (String, usize),
452 pub local_state: Vec<Ref>,
456 pub mutates: Vec<Ref>,
458 pub stream: StreamRef,
459}
460
461#[derive(Serialize, Deserialize, Debug, Clone)]
462pub struct ActorMethodParams {
463 pub results: Vec<Option<Ref>>,
464 pub call: ActorCallParams,
465}
466
467#[derive(Debug, Clone, Serialize, Deserialize)]
469pub enum Reduction {
470 Stack,
472 ReduceOp(ReduceOp),
474}
475
476#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
477#[pyo3::pyclass(
478 frozen,
479 name = "TensorFactory",
480 module = "monarch._rust_bindings.monarch_extension.tensor_worker"
481)]
482pub struct Factory {
483 pub size: Vec<i64>,
484 #[serde(with = "torch_sys2::ScalarTypeDef")]
485 pub dtype: ScalarType,
486 #[serde(with = "torch_sys2::LayoutDef")]
487 pub layout: Layout,
488 pub device: Device,
489}
490
491#[pyo3::pymethods]
492impl Factory {
493 #[new]
494 #[pyo3(signature = (*, size, dtype, layout, device))]
495 pub fn new(
496 py: Python<'_>,
497 size: Vec<i64>,
498 dtype: PyObject,
499 layout: PyObject,
500 device: PyObject,
501 ) -> PyResult<Self> {
502 Ok(Self {
505 size,
506 dtype: dtype.extract::<ScalarType>(py)?,
507 layout: layout.extract::<Layout>(py)?,
508 device: device.extract::<Device>(py)?,
509 })
510 }
511
512 #[staticmethod]
513 pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
514 Self::new(
515 obj.py(),
516 obj.getattr("size")?.extract()?,
517 obj.getattr("dtype")?.unbind(),
518 obj.getattr("layout")?.unbind(),
519 obj.getattr("device")?.unbind(),
520 )
521 }
522
523 #[getter]
524 fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
525 PyTuple::new(py, self.size.iter())
526 }
527
528 #[getter]
529 fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
530 self.dtype.into_pyobject(py)
531 }
532
533 #[getter]
534 fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
535 self.layout.into_pyobject(py)
536 }
537
538 #[getter]
539 fn device(&self) -> String {
540 self.device.to_string()
541 }
542}
543
544#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
546#[pyo3::pyclass(
547 module = "monarch._rust_bindings.monarch_extension.tensor_worker",
548 eq,
549 eq_int
550)]
551pub enum StreamCreationMode {
552 UseDefaultStream,
554 CreateNewStream,
556}
557
558#[derive(Debug, Named)]
565pub struct SeqError {
566 pub seq: Seq,
567 pub error: anyhow::Error,
568}
569
570impl Display for SeqError {
571 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
572 self.error.fmt(f)
573 }
574}
575
576#[derive(Error, Debug, Named)]
580pub enum CallFunctionError {
581 #[error("{0}")]
582 Error(#[from] anyhow::Error),
583 #[error("Computation depended on an input that failed with error: {0}")]
584 DependentError(Arc<SeqError>),
585}
586
587impl CallFunctionError {
588 #[allow(non_snake_case)]
590 pub fn RefNotFound(r: Ref) -> Self {
591 Self::Error(anyhow::anyhow!("ref not found: {}", r))
592 }
593
594 #[allow(non_snake_case)]
595 pub fn InvalidRemoteFunction(msg: String) -> Self {
596 Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
597 }
598
599 #[allow(non_snake_case)]
600 pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
601 Self::Error(anyhow::anyhow!(
602 "unsupported arg type for {} function: {}",
603 function_type,
604 arg_type
605 ))
606 }
607
608 #[allow(non_snake_case)]
609 pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
610 Self::Error(anyhow::anyhow!("remote function failed: {}", err))
611 }
612
613 #[allow(non_snake_case)]
614 pub fn BorrowError(err: BorrowError) -> Self {
615 Self::Error(anyhow::anyhow!("borrow failed: {}", err))
616 }
617
618 #[allow(non_snake_case)]
619 pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
620 Self::Error(anyhow::anyhow!(
621 "unexpected number of returns from op, expected {}, got {}",
622 expected,
623 actual
624 ))
625 }
626
627 #[allow(non_snake_case)]
628 pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
629 Self::Error(anyhow::anyhow!(
630 "expected only a single arg (and no kwargs) when no function is given: {}, {}",
631 args,
632 kwargs
633 ))
634 }
635
636 #[allow(non_snake_case)]
637 pub fn Anyhow(err: anyhow::Error) -> Self {
638 Self::Error(err)
639 }
640}
641
642impl From<SerializablePyErr> for CallFunctionError {
643 fn from(v: SerializablePyErr) -> CallFunctionError {
644 CallFunctionError::Error(v.into())
645 }
646}
647
648impl From<BorrowError> for CallFunctionError {
649 fn from(v: BorrowError) -> CallFunctionError {
650 CallFunctionError::Error(v.into())
651 }
652}
653
654#[derive(
657 Handler,
658 HandleClient,
659 RefClient,
660 Clone,
661 Serialize,
662 Deserialize,
663 Debug,
664 Named,
665 EnumAsInner,
666 Bind,
667 Unbind
668)]
669pub enum WorkerMessage {
670 BackendNetworkInit(UniqueId),
672
673 BackendNetworkPointToPointInit {
675 from_stream: StreamRef,
676 to_stream: StreamRef,
677 },
678
679 CallFunction(CallFunctionParams),
681
682 CommandGroup(Vec<WorkerMessage>),
685
686 CreateStream {
690 id: StreamRef,
692 stream_creation: StreamCreationMode,
694 },
695
696 CreateDeviceMesh {
699 result: Ref,
700 names: Vec<String>,
701 ranks: Slice,
702 },
703
704 CreateRemoteProcessGroup {
707 result: Ref,
708 device_mesh: Ref,
709 dims: Vec<String>,
710 },
711
712 BorrowCreate {
718 result: Ref,
720 borrow: u64,
722 tensor: Ref,
724 from_stream: StreamRef,
726 to_stream: StreamRef,
728 },
729
730 BorrowFirstUse {
733 borrow: u64,
734 },
735
736 BorrowLastUse {
739 borrow: u64,
740 },
741
742 BorrowDrop {
744 borrow: u64,
745 },
746
747 DeleteRefs(Vec<Ref>),
749
750 RequestStatus {
753 seq: Seq,
754 controller: bool,
755 },
756
757 Reduce {
760 result: Ref,
762 tensor: Ref,
764 factory: Factory,
769 mesh: Ref,
771 stream: StreamRef,
773 dims: Vec<String>,
776 reduction: Reduction,
778 scatter: bool,
781 in_place: bool,
783 out: Option<Ref>,
785 },
786
787 SplitComm {
790 dims: Vec<String>,
793 device_mesh: Ref,
796 stream: StreamRef,
800 },
801
802 SplitCommForProcessGroup {
805 remote_process_group: Ref,
808 stream: StreamRef,
812 },
813
814 SendTensor {
815 result: Ref,
816 from_ranks: Slice,
817 to_ranks: Slice,
818 tensor: Ref,
819 factory: Factory,
820 from_stream: StreamRef,
821 to_stream: StreamRef,
822 },
823
824 SendValue {
825 seq: Seq,
826 destination: Option<Ref>,
828 mutates: Vec<Ref>,
829 function: Option<ResolvableFunction>,
832 args_kwargs: ArgsKwargs,
833 stream: StreamRef,
835 },
836
837 SendResultOfActorCall(ActorCallParams),
838 CallActorMethod(ActorMethodParams),
839 PipeRecv {
840 seq: Seq,
841 results: Vec<Option<Ref>>,
843 pipe: Ref,
845 stream: StreamRef,
847 },
848
849 Exit {
852 error: Option<(Option<ActorId>, String)>,
857 },
858
859 DefineRecording {
866 result: Ref,
869 nresults: usize,
871 nformals: usize,
873 commands: Vec<WorkerMessage>,
875 ntotal_messages: usize,
877 index: usize,
880 },
881
882 RecordingFormal {
884 result: Ref,
887 argument_index: usize,
889 stream: StreamRef,
891 },
892
893 RecordingResult {
895 result: Ref,
897 output_index: usize,
899 stream: StreamRef,
901 },
902
903 CallRecording {
906 seq: Seq,
908 recording: Ref,
910 results: Vec<Ref>,
913 actuals: Vec<Ref>,
915 },
916
917 SetRefUnitTestsOnly {
918 reference: Ref,
920 value: WireValue,
922 stream: StreamRef,
924 },
925
926 GetRefUnitTestsOnly {
927 value: Ref,
929 stream: StreamRef,
931 #[reply]
932 response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
933 },
934}
935
936#[derive(Debug, Clone, Serialize, Deserialize, Named)]
938pub struct WorkerParams {
939 pub world_size: usize,
941
942 pub rank: usize,
944
945 pub device_index: Option<i8>,
948
949 pub controller_actor: ActorRef<ControllerActor>,
951}
952wirevalue::register_type!(WorkerParams);
953
954hyperactor::behavior!(
955 WorkerActor,
956 WorkerMessage { cast = true },
957);