1#![allow(unsafe_op_in_unsafe_fn)]
12
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use anyhow::Context;
18use derive_more::Display;
19use derive_more::From;
20use derive_more::TryInto;
21use enum_as_inner::EnumAsInner;
22use hyperactor::ActorRef;
23use hyperactor::Bind;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::Named;
27use hyperactor::RefClient;
28use hyperactor::Unbind;
29use hyperactor::reference::ActorId;
30use monarch_types::SerializablePyErr;
31use ndslice::Slice;
32use pyo3::exceptions::PyValueError;
33use pyo3::prelude::*;
34use pyo3::types::PyBytes;
35use pyo3::types::PyDict;
36use pyo3::types::PyTuple;
37use serde::Deserialize;
38use serde::Serialize;
39use thiserror::Error;
40use torch_sys::BorrowError;
41use torch_sys::Device;
42use torch_sys::Layout;
43use torch_sys::ScalarType;
44use torch_sys::call_op::CallOpError;
45use torch_sys_cuda::nccl::NcclConfig;
46use torch_sys_cuda::nccl::ReduceOp;
47use torch_sys_cuda::nccl::UniqueId;
48
49use crate::controller::ControllerActor;
50use crate::controller::Seq;
51use crate::wire_value::WireValue;
52
/// Identifier of a stream, exposed to Python as a frozen class.
///
/// The type is a thin wrapper around a `u64`: it is `Copy`, hashable,
/// ordered by `id` (derived `Ord`), serde-serializable, and convertible from
/// a bare `u64` through the derived `From`.
// NOTE(review): presumably names a worker-side execution stream (see
// `WorkerMessage::CreateStream`) — confirm against the worker implementation.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    /// Numeric id; readable from Python as the `id` attribute.
    #[pyo3(get)]
    pub id: u64,
}
74
75#[pyo3::pymethods]
76impl StreamRef {
77 #[new]
78 #[pyo3(signature = (*, id))]
79 fn new(id: u64) -> Self {
80 Self { id }
81 }
82
83 fn __repr__(&self) -> String {
84 format!("StreamRef({})", self.id)
85 }
86
87 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
89 Ok(match op {
90 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
91 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
92 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
93 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
94 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
95 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
96 })
97 }
98
99 fn __hash__(&self) -> PyResult<u64> {
100 Ok(self.id)
101 }
102}
103
/// A numeric reference handle, exposed to Python as a frozen class.
///
/// Like `StreamRef`, this is a `Copy` wrapper around a `u64` that is
/// hashable, ordered by `id`, serde-serializable, and convertible from a bare
/// `u64` through the derived `From`.
// NOTE(review): presumably names a value (tensor, mesh, pipe, ...) held by
// the worker — confirm against the worker implementation.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    /// Numeric id; readable from Python as the `id` attribute.
    #[pyo3(get)]
    pub id: u64,
}
129
130#[pyo3::pymethods]
131impl Ref {
132 #[new]
133 fn new(id: u64) -> Self {
134 Self { id }
135 }
136
137 #[getter]
138 fn r#ref(&self) -> u64 {
139 self.id
140 }
141
142 fn __repr__(&self) -> String {
143 format!("Ref({})", self.id)
144 }
145
146 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
148 Ok(match op {
149 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
150 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
151 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
152 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
153 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
154 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
155 })
156 }
157
158 fn __hash__(&self) -> PyResult<u64> {
159 Ok(self.id)
160 }
161
162 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
163 let kwargs = PyDict::new(py);
164 kwargs.set_item("id", self.id).unwrap();
165
166 PyTuple::new(
167 py,
168 vec![
169 PyTuple::empty(py).unbind().into_any(),
170 kwargs.unbind().into_any(),
171 ],
172 )
173 }
174}
175
176impl Ref {
177 pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
180 let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
181 if let Ok(ref_) = obj.extract::<Ref>() {
182 return Ok(ref_);
183 }
184 if let Ok(func) = obj.getattr(attr_name) {
185 if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
186 return Ok(val.into());
187 }
188 }
189 Err(PyValueError::new_err("Could not convert object to Ref"))
190 }
191}
192
193impl Display for Ref {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "r{}", self.id)
196 }
197}
198
/// A dotted path naming a Python function, exposed to Python as a frozen
/// class.
///
/// `FunctionPath::resolve` turns the path into the Python object it names
/// and requires the path to contain at least one `.`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    /// The dotted path string; readable from Python as the `path` attribute.
    #[pyo3(get)]
    pub path: String,
}
211
212impl fmt::Display for FunctionPath {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 write!(f, "<function \"{}\">", self.path)
215 }
216}
217
218impl<T: Into<String>> From<T> for FunctionPath {
219 fn from(val: T) -> Self {
220 Self { path: val.into() }
221 }
222}
223
224#[pyo3::pymethods]
225impl FunctionPath {
226 #[new]
227 #[pyo3(signature = (*, path))]
228 pub fn new(path: String) -> Self {
229 Self { path }
230 }
231
232 fn __repr__(&self) -> String {
233 self.path.clone()
234 }
235
236 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
237 let (start, rest) = self.path.split_once(".").with_context(|| {
238 format!(
239 "invalid function path {}: paths must be fully qualified",
240 self.path
241 )
242 })?;
243 if start == "torch" {
244 let mut cur = py.import("torch")?.into_any();
245 for p in rest.split(".") {
246 cur = cur.getattr(p)?;
247 }
248 Ok(cur)
249 } else {
250 let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
251 format!(
252 "invalid function path {}: paths must be fully qualified",
253 self.path
254 )
255 })?;
256 let module = PyModule::import(py, module_fqn)?;
257 let mut function = module.getattr(function_name)?;
258 if function.hasattr("_remote_impl")? {
259 function = function.getattr("_remote_impl")?;
260 }
261 Ok(function.downcast_into()?)
262 }
263 }
264}
265
/// An opaque byte payload, exposed to Python as a frozen class.
///
/// `Cloudpickle::resolve` passes the bytes to Python's `cloudpickle.loads`.
/// The derived `From` allows construction directly from a `Vec<u8>`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    /// The raw payload; (de)serialized through `serde_bytes`.
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}
278
279impl fmt::Display for Cloudpickle {
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 write!(f, "<cloud-pickle>")
282 }
283}
284
285#[pyo3::pymethods]
286impl Cloudpickle {
287 #[new]
288 #[pyo3(signature = (*, bytes))]
289 pub fn new(bytes: Vec<u8>) -> Self {
290 Self { bytes }
291 }
292
293 fn __repr__(&self) -> String {
294 format!("Cloudpickle(bytes={:?})", self.bytes)
295 }
296
297 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
298 let module = PyModule::import(py, "cloudpickle")?;
299 let loads = module.getattr("loads")?;
300 loads.call1((PyBytes::new(py, &self.bytes),))
301 }
302}
303
/// A function description that can be resolved to a Python object: either a
/// pickled payload or a dotted path.
///
/// Both variants are `#[pyo3(transparent)]`, so extraction from Python
/// operates on the wrapped `Cloudpickle` / `FunctionPath` object itself. The
/// derived `Display` forwards to the wrapped value's `Display`.
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// A payload resolved through `cloudpickle.loads`.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// A dotted path resolved by import and attribute lookup.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
321
322impl<'py> IntoPyObject<'py> for ResolvableFunction {
323 type Target = PyAny;
324 type Output = Bound<'py, Self::Target>;
325 type Error = PyErr;
326
327 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
328 Ok(match self {
329 Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
330 Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
331 })
332 }
333}
334
335impl ResolvableFunction {
336 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
337 match self {
338 Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
339 Self::FunctionPath(func) => func.resolve(py),
340 }
341 }
342
343 pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
344 match self {
345 Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
346 ["torch", "ops", namespace, op_name, "default"] => {
347 Some((format!("{}::{}", namespace, op_name), String::new()))
348 }
349 ["torch", "ops", namespace, op_name, overload] => {
350 Some((format!("{}::{}", namespace, op_name), overload.to_string()))
351 }
352 _ => None,
353 },
354 _ => None,
355 }
356 }
357
358 pub fn panic_if_requested(&self) {
361 match self {
362 Self::FunctionPath(func) => {
363 if func.path == "__test_panic" {
364 panic!("__test_panic called");
365 }
366 }
367 _ => (),
368 }
369 }
370
371 pub fn supports_pytree_args(&self) -> bool {
372 match self {
373 Self::Cloudpickle(_) => true,
374 Self::FunctionPath(_) => self.as_torch_op().is_none(),
375 }
376 }
377}
378
379impl<T: Into<String>> From<T> for ResolvableFunction {
380 fn from(val: T) -> Self {
381 FunctionPath::from(val).into()
382 }
383}
384
/// Payload of `WorkerMessage::CallFunction`.
// NOTE(review): field meanings below are inferred from names and types; the
// handler is not in this file — confirm against the worker implementation.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence number associated with this invocation.
    pub seq: Seq,
    /// One entry per result; `None` presumably means that result is not kept.
    pub results: Vec<Option<Ref>>,
    /// References presumably mutated by the call.
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// Positional arguments.
    pub args: Vec<WireValue>,
    /// Keyword arguments, keyed by parameter name.
    pub kwargs: HashMap<String, WireValue>,
    /// Stream presumably used to run the call.
    pub stream: StreamRef,
    /// References to remote process groups, presumably made available to the
    /// call.
    pub remote_process_groups: Vec<Ref>,
}
404
/// Parameters shared by `WorkerMessage::SendResultOfActorCall` and, via
/// `ActorMethodParams`, `WorkerMessage::CallActorMethod`.
// NOTE(review): field meanings below are inferred from names and types —
// confirm against the worker implementation.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence number associated with this call.
    pub seq: Seq,
    /// A `(String, usize)` pair presumably identifying the broker that holds
    /// the actor message.
    pub broker_id: (String, usize),
    /// References presumably passed to the actor as local state.
    pub local_state: Vec<Ref>,
    /// References presumably mutated by the call.
    pub mutates: Vec<Ref>,
    /// Stream presumably used to run the call.
    pub stream: StreamRef,
}
418
/// Payload of `WorkerMessage::CallActorMethod`: an actor call plus the
/// references for its results.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// One entry per result; `None` presumably means that result is not kept
    /// (TODO confirm).
    pub results: Vec<Option<Ref>>,
    /// The underlying call parameters.
    pub call: ActorCallParams,
}
424
/// The kind of reduction requested by `WorkerMessage::Reduce`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// NOTE(review): presumably stacks the inputs rather than combining them
    /// element-wise — confirm against the worker implementation.
    Stack,
    /// Reduce using the given NCCL `ReduceOp`.
    ReduceOp(ReduceOp),
}
433
/// A tensor description (size, dtype, layout, device), exposed to Python as
/// the frozen class `TensorFactory`.
// NOTE(review): presumably used by the worker to allocate tensors of the
// described shape — confirm against the worker implementation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    /// Size of each dimension.
    pub size: Vec<i64>,
    /// Element type; serialized through `torch_sys::ScalarTypeDef`.
    #[serde(with = "torch_sys::ScalarTypeDef")]
    pub dtype: ScalarType,
    /// Memory layout; serialized through `torch_sys::LayoutDef`.
    #[serde(with = "torch_sys::LayoutDef")]
    pub layout: Layout,
    /// Target device.
    pub device: Device,
}
448
449#[pyo3::pymethods]
450impl Factory {
451 #[new]
452 #[pyo3(signature = (*, size, dtype, layout, device))]
453 pub fn new(
454 py: Python<'_>,
455 size: Vec<i64>,
456 dtype: PyObject,
457 layout: PyObject,
458 device: PyObject,
459 ) -> PyResult<Self> {
460 Ok(Self {
463 size,
464 dtype: dtype.extract::<ScalarType>(py)?,
465 layout: layout.extract::<Layout>(py)?,
466 device: device.extract::<Device>(py)?,
467 })
468 }
469
470 #[staticmethod]
471 pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
472 Self::new(
473 obj.py(),
474 obj.getattr("size")?.extract()?,
475 obj.getattr("dtype")?.unbind(),
476 obj.getattr("layout")?.unbind(),
477 obj.getattr("device")?.unbind(),
478 )
479 }
480
481 #[getter]
482 fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
483 PyTuple::new(py, self.size.iter())
484 }
485
486 #[getter]
487 fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
488 self.dtype.into_pyobject(py)
489 }
490
491 #[getter]
492 fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
493 self.layout.into_pyobject(py)
494 }
495
496 #[getter]
497 fn device(&self) -> String {
498 self.device.to_string()
499 }
500}
501
/// Option carried by `WorkerMessage::CreateStream`, exposed to Python as an
/// enum comparable by equality and by integer value (`eq`, `eq_int`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// NOTE(review): presumably reuse the device's default stream — confirm.
    UseDefaultStream,
    /// NOTE(review): presumably allocate a fresh stream — confirm.
    CreateNewStream,
}
515
/// An error paired with the sequence number it is associated with.
///
/// The `Display` implementation in this file prints only the wrapped error,
/// not the sequence number.
#[derive(Debug, Named)]
pub struct SeqError {
    /// Sequence number associated with the error.
    pub seq: Seq,
    /// The underlying error.
    pub error: anyhow::Error,
}
527
528impl Display for SeqError {
529 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 self.error.fmt(f)
531 }
532}
533
/// Error produced while calling a function on the worker.
///
/// Specific failure kinds are not separate variants; they are built by the
/// variant-styled constructor functions on this type and all end up in
/// `Error`.
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
    /// A failure of this call itself; displays as the wrapped error.
    /// Convertible from `anyhow::Error` via `#[from]`.
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    /// The call depended on an input that had already failed; carries the
    /// shared error of that input.
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
544
545impl CallFunctionError {
546 #[allow(non_snake_case)]
548 pub fn RefNotFound(r: Ref) -> Self {
549 Self::Error(anyhow::anyhow!("ref not found: {}", r))
550 }
551
552 #[allow(non_snake_case)]
553 pub fn InvalidRemoteFunction(msg: String) -> Self {
554 Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
555 }
556
557 #[allow(non_snake_case)]
558 pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
559 Self::Error(anyhow::anyhow!(
560 "unsupported arg type for {} function: {}",
561 function_type,
562 arg_type
563 ))
564 }
565
566 #[allow(non_snake_case)]
567 pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
568 Self::Error(anyhow::anyhow!("remote function failed: {}", err))
569 }
570
571 #[allow(non_snake_case)]
572 pub fn BorrowError(err: BorrowError) -> Self {
573 Self::Error(anyhow::anyhow!("borrow failed: {}", err))
574 }
575
576 #[allow(non_snake_case)]
577 pub fn OperatorFailed(err: CallOpError) -> Self {
578 Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
579 }
580
581 #[allow(non_snake_case)]
582 pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
583 Self::Error(anyhow::anyhow!(
584 "unexpected number of returns from op, expected {}, got {}",
585 expected,
586 actual
587 ))
588 }
589
590 #[allow(non_snake_case)]
591 pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
592 Self::Error(anyhow::anyhow!(
593 "expected only a single arg (and no kwargs) when no function is given: {}, {}",
594 args,
595 kwargs
596 ))
597 }
598
599 #[allow(non_snake_case)]
600 pub fn Anyhow(err: anyhow::Error) -> Self {
601 Self::Error(err)
602 }
603}
604
605impl From<SerializablePyErr> for CallFunctionError {
606 fn from(v: SerializablePyErr) -> CallFunctionError {
607 CallFunctionError::Error(v.into())
608 }
609}
610
611impl From<BorrowError> for CallFunctionError {
612 fn from(v: BorrowError) -> CallFunctionError {
613 CallFunctionError::Error(v.into())
614 }
615}
616
617impl From<CallOpError> for CallFunctionError {
618 fn from(v: CallOpError) -> CallFunctionError {
619 CallFunctionError::Error(v.into())
620 }
621}
622
/// Messages accepted by the worker actor.
///
/// The enum is serde-serializable and carries the hyperactor derives
/// (`Handler`, `HandleClient`, `RefClient`, `Named`, `Bind`, `Unbind`) plus
/// `EnumAsInner` accessors.
// NOTE(review): the per-variant and per-field descriptions below are inferred
// from names and types only; the handler that gives them meaning is not in
// this file — confirm each against the worker implementation.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Carries an NCCL `UniqueId`, presumably to initialize the backend
    /// network.
    BackendNetworkInit(UniqueId),

    /// Presumably initializes point-to-point communication between the two
    /// streams.
    BackendNetworkPointToPointInit {
        /// Source stream.
        from_stream: StreamRef,
        /// Destination stream.
        to_stream: StreamRef,
    },

    /// Call a function; see `CallFunctionParams`.
    CallFunction(CallFunctionParams),

    /// A group of nested messages, presumably processed in order.
    CommandGroup(Vec<WorkerMessage>),

    /// Presumably creates a stream under the given id.
    CreateStream {
        /// Id of the stream.
        id: StreamRef,
        /// See `StreamCreationMode`.
        stream_creation: StreamCreationMode,
    },

    /// Presumably creates a device mesh.
    CreateDeviceMesh {
        /// Reference under which the mesh is presumably stored.
        result: Ref,
        /// Dimension names (presumably one per dimension of `ranks`).
        names: Vec<String>,
        /// Ranks participating in the mesh, as an `ndslice::Slice`.
        ranks: Slice,
    },

    /// Presumably creates a process group spanning some dimensions of a
    /// device mesh.
    CreateRemoteProcessGroup {
        /// Reference under which the process group is presumably stored.
        result: Ref,
        /// The device mesh to build on.
        device_mesh: Ref,
        /// Names of the mesh dimensions involved.
        dims: Vec<String>,
    },

    /// Presumably makes a tensor owned by one stream usable on another.
    BorrowCreate {
        /// Reference under which the borrowed tensor is presumably stored.
        result: Ref,
        /// Id of the borrow; shared with the `BorrowFirstUse`,
        /// `BorrowLastUse` and `BorrowDrop` messages.
        borrow: u64,
        /// The tensor being borrowed.
        tensor: Ref,
        /// Stream the tensor comes from.
        from_stream: StreamRef,
        /// Stream the tensor is lent to.
        to_stream: StreamRef,
    },

    /// Presumably marks the first use of a borrow.
    BorrowFirstUse {
        /// Id of the borrow.
        borrow: u64,
    },

    /// Presumably marks the last use of a borrow.
    BorrowLastUse {
        /// Id of the borrow.
        borrow: u64,
    },

    /// Presumably releases a borrow.
    BorrowDrop {
        /// Id of the borrow.
        borrow: u64,
    },

    /// Presumably deletes the listed references.
    DeleteRefs(Vec<Ref>),

    /// Presumably asks the worker to report its status.
    RequestStatus {
        /// Sequence number the request refers to.
        seq: Seq,
        /// NOTE(review): meaning not evident from this file — confirm.
        controller: bool,
    },

    /// Presumably performs a reduction of a tensor across a mesh.
    Reduce {
        /// Reference under which the result is presumably stored.
        result: Ref,
        /// Input tensor.
        tensor: Ref,
        /// Description (size, dtype, layout, device) of a tensor involved in
        /// the reduction.
        factory: Factory,
        /// Device mesh to reduce over.
        mesh: Ref,
        /// Stream presumably used to run the operation.
        stream: StreamRef,
        /// Names of the mesh dimensions to reduce over.
        dims: Vec<String>,
        /// Kind of reduction.
        reduction: Reduction,
        /// NOTE(review): presumably scatters the result across ranks —
        /// confirm.
        scatter: bool,
        /// NOTE(review): presumably reduces into the input tensor — confirm.
        in_place: bool,
        /// Optional pre-existing output tensor.
        out: Option<Ref>,
    },

    /// Presumably creates a communicator from a subset of mesh dimensions.
    SplitComm {
        /// Names of the mesh dimensions involved.
        dims: Vec<String>,
        /// The device mesh.
        device_mesh: Ref,
        /// Stream presumably associated with the new communicator.
        stream: StreamRef,
        /// Optional NCCL configuration.
        config: Option<NcclConfig>,
    },

    /// Presumably creates a communicator for a remote process group.
    SplitCommForProcessGroup {
        /// The process group.
        remote_process_group: Ref,
        /// Stream presumably associated with the new communicator.
        stream: StreamRef,
        /// Optional NCCL configuration.
        config: Option<NcclConfig>,
    },

    /// Presumably sends a tensor from one set of ranks to another.
    SendTensor {
        /// Reference under which the received tensor is presumably stored.
        result: Ref,
        /// Sending ranks.
        from_ranks: Slice,
        /// Receiving ranks.
        to_ranks: Slice,
        /// The tensor to send.
        tensor: Ref,
        /// Description (size, dtype, layout, device) of the tensor.
        factory: Factory,
        /// Stream on the sending side.
        from_stream: StreamRef,
        /// Stream on the receiving side.
        to_stream: StreamRef,
    },

    /// Presumably creates a pipe backed by a function.
    CreatePipe {
        /// Reference under which the pipe is presumably stored.
        result: Ref,
        /// NOTE(review): presumably a name identifying the pipe — confirm.
        key: String,
        /// Function backing the pipe.
        function: ResolvableFunction,
        /// NOTE(review): presumably a bound on queued messages — confirm.
        max_messages: i64,
        /// Device mesh associated with the pipe.
        mesh: Ref,
        /// Positional arguments.
        args: Vec<WireValue>,
        /// Keyword arguments, keyed by parameter name.
        kwargs: HashMap<String, WireValue>,
    },

    /// Presumably computes a value and sends it to a destination.
    SendValue {
        /// Sequence number associated with this message.
        seq: Seq,
        /// Optional destination reference.
        /// NOTE(review): meaning of `None` is not evident here — confirm.
        destination: Option<Ref>,
        /// References presumably mutated by the call.
        mutates: Vec<Ref>,
        /// Optional function to apply. The message of
        /// `CallFunctionError::TooManyArgsForValue` suggests that without a
        /// function only a single arg and no kwargs are expected — confirm.
        function: Option<ResolvableFunction>,
        /// Positional arguments.
        args: Vec<WireValue>,
        /// Keyword arguments, keyed by parameter name.
        kwargs: HashMap<String, WireValue>,
        /// Stream presumably used to run the operation.
        stream: StreamRef,
    },

    /// See `ActorCallParams`.
    SendResultOfActorCall(ActorCallParams),
    /// See `ActorMethodParams`.
    CallActorMethod(ActorMethodParams),
    /// Presumably receives values from a pipe.
    PipeRecv {
        /// Sequence number associated with this message.
        seq: Seq,
        /// One entry per result; `None` presumably means that result is not
        /// kept.
        results: Vec<Option<Ref>>,
        /// The pipe to receive from.
        pipe: Ref,
        /// Stream presumably used to run the operation.
        stream: StreamRef,
    },

    /// Presumably asks the worker to exit.
    Exit {
        /// Optional error: an optional actor id plus a message string.
        error: Option<(Option<ActorId>, String)>,
    },

    /// Presumably defines (part of) a recording of commands that can later be
    /// invoked with `CallRecording`.
    DefineRecording {
        /// Reference under which the recording is presumably stored.
        result: Ref,
        /// Number of results (presumably of the recording).
        nresults: usize,
        /// Number of formal arguments (presumably of the recording).
        nformals: usize,
        /// Messages contained in this definition.
        commands: Vec<WorkerMessage>,
        /// NOTE(review): presumably the number of `DefineRecording` messages
        /// that together make up the recording — confirm.
        ntotal_messages: usize,
        /// NOTE(review): presumably this message's position among them —
        /// confirm.
        index: usize,
    },

    /// Presumably binds a formal argument inside a recording.
    RecordingFormal {
        /// Reference the formal is presumably bound to.
        result: Ref,
        /// Index of the argument.
        argument_index: usize,
        /// Stream presumably associated with the value.
        stream: StreamRef,
    },

    /// Presumably marks a value as a result of a recording.
    RecordingResult {
        /// Reference of the value.
        result: Ref,
        /// Index of the output.
        output_index: usize,
        /// Stream presumably associated with the value.
        stream: StreamRef,
    },

    /// Presumably invokes a previously defined recording.
    CallRecording {
        /// Sequence number associated with this invocation.
        seq: Seq,
        /// The recording to invoke.
        recording: Ref,
        /// References for the results.
        results: Vec<Ref>,
        /// References supplied as actual arguments.
        actuals: Vec<Ref>,
    },

    /// Per its name, intended for unit tests only: presumably stores a value
    /// under a reference.
    SetRefUnitTestsOnly {
        /// The reference to set.
        reference: Ref,
        /// The value to store.
        value: WireValue,
        /// Stream presumably associated with the value.
        stream: StreamRef,
    },

    /// Per its name, intended for unit tests only: presumably reads back the
    /// value stored under a reference.
    GetRefUnitTestsOnly {
        /// The reference to read.
        value: Ref,
        /// Stream presumably associated with the value.
        stream: StreamRef,
        /// Port for the reply, which is `None`, `Some(Ok(value))` or
        /// `Some(Err(message))`.
        #[reply]
        response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
924
/// Serializable parameters describing a worker.
// NOTE(review): presumably passed when spawning the worker actor; field
// meanings are inferred from names and types — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Total number of workers (presumably in the same job).
    pub world_size: usize,

    /// This worker's rank.
    pub rank: usize,

    /// Optional device index.
    /// NOTE(review): meaning of `None` is not evident here — confirm.
    pub device_index: Option<i8>,

    /// Reference to the controller actor.
    pub controller_actor: ActorRef<ControllerActor>,
}
941
// Declares `WorkerActor` through hyperactor's `alias!` macro, associating it
// with `WorkerMessage` (with `cast = true`).
// NOTE(review): presumably yields a type usable as `ActorRef<WorkerActor>`
// for any actor handling `WorkerMessage` — confirm against hyperactor's docs.
hyperactor::alias!(
    WorkerActor,
    WorkerMessage { cast = true },
);