1#![allow(unsafe_op_in_unsafe_fn)]
12
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use anyhow::Context;
18use derive_more::Display;
19use derive_more::From;
20use derive_more::TryInto;
21use enum_as_inner::EnumAsInner;
22use hyperactor::ActorRef;
23use hyperactor::Bind;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::Named;
27use hyperactor::RefClient;
28use hyperactor::Unbind;
29use hyperactor::reference::ActorId;
30use monarch_types::SerializablePyErr;
31use ndslice::Slice;
32use pyo3::exceptions::PyValueError;
33use pyo3::prelude::*;
34use pyo3::types::PyBytes;
35use pyo3::types::PyDict;
36use pyo3::types::PyTuple;
37use serde::Deserialize;
38use serde::Serialize;
39use thiserror::Error;
40use torch_sys::BorrowError;
41use torch_sys::Device;
42use torch_sys::Layout;
43use torch_sys::ScalarType;
44use torch_sys::call_op::CallOpError;
45use torch_sys_cuda::nccl::NcclConfig;
46use torch_sys_cuda::nccl::ReduceOp;
47use torch_sys_cuda::nccl::UniqueId;
48
49use crate::controller::ControllerActor;
50use crate::controller::Seq;
51use crate::wire_value::WireValue;
52
/// Identifier for a stream, wrapping a plain `u64`.
///
/// Used as the `stream`/`from_stream`/`to_stream` fields throughout `WorkerMessage`.
/// Exposed to Python as a frozen `StreamRef` class in
/// `monarch._rust_bindings.monarch_extension.tensor_worker` with a read-only `id`.
/// Equality, ordering and hashing (both the Rust derives and the Python methods)
/// are all by `id`.
///
/// NOTE(review): this type is serde-serialized inside worker messages; changing the
/// field layout may affect the wire format — confirm before editing.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct StreamRef {
    /// Numeric stream id (readable from Python as `.id`).
    #[pyo3(get)]
    pub id: u64,
}
74
75#[pyo3::pymethods]
76impl StreamRef {
77 #[new]
78 #[pyo3(signature = (*, id))]
79 fn new(id: u64) -> Self {
80 Self { id }
81 }
82
83 fn __repr__(&self) -> String {
84 format!("StreamRef({})", self.id)
85 }
86
87 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
89 Ok(match op {
90 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
91 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
92 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
93 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
94 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
95 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
96 })
97 }
98
99 fn __hash__(&self) -> PyResult<u64> {
100 Ok(self.id)
101 }
102}
103
/// Identifier wrapping a plain `u64`, used by `WorkerMessage` and the parameter structs
/// to name worker-side values (e.g. tensors, device meshes, pipes, recordings).
///
/// Displays as `r<id>`. Exposed to Python as a frozen `Ref` class in
/// `monarch._rust_bindings.monarch_extension.tensor_worker` with a read-only `id`.
/// Equality, ordering and hashing are all by `id`.
///
/// NOTE(review): serde-serialized inside worker messages; changing the field layout may
/// affect the wire format — confirm before editing.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    Hash,
    PartialEq,
    Eq,
    Copy,
    PartialOrd,
    Ord,
    From
)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Ref {
    /// Numeric ref id (readable from Python as `.id`).
    #[pyo3(get)]
    pub id: u64,
}
129
130#[pyo3::pymethods]
131impl Ref {
132 #[new]
133 fn new(id: u64) -> Self {
134 Self { id }
135 }
136
137 #[getter]
138 fn r#ref(&self) -> u64 {
139 self.id
140 }
141
142 fn __repr__(&self) -> String {
143 format!("Ref({})", self.id)
144 }
145
146 fn __richcmp__(&self, other: PyRef<Self>, op: pyo3::class::basic::CompareOp) -> PyResult<bool> {
148 Ok(match op {
149 pyo3::class::basic::CompareOp::Eq => self.id == other.id,
150 pyo3::class::basic::CompareOp::Ne => self.id != other.id,
151 pyo3::class::basic::CompareOp::Lt => self.id < other.id,
152 pyo3::class::basic::CompareOp::Le => self.id <= other.id,
153 pyo3::class::basic::CompareOp::Gt => self.id > other.id,
154 pyo3::class::basic::CompareOp::Ge => self.id >= other.id,
155 })
156 }
157
158 fn __hash__(&self) -> PyResult<u64> {
159 Ok(self.id)
160 }
161
162 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
163 let kwargs = PyDict::new(py);
164 kwargs.set_item("id", self.id).unwrap();
165
166 PyTuple::new(
167 py,
168 vec![
169 PyTuple::empty(py).unbind().into_any(),
170 kwargs.unbind().into_any(),
171 ],
172 )
173 }
174}
175
176impl Ref {
177 pub fn from_py_object(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
180 let attr_name = pyo3::intern!(obj.py(), "__monarch_ref__");
181 if let Ok(ref_) = obj.extract::<Ref>() {
182 return Ok(ref_);
183 }
184 if let Ok(func) = obj.getattr(attr_name) {
185 if let Ok(Ok(val)) = func.call0().map(|val| val.extract::<u64>()) {
186 return Ok(val.into());
187 }
188 }
189 Err(PyValueError::new_err("Could not convert object to Ref"))
190 }
191}
192
193impl Display for Ref {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "r{}", self.id)
196 }
197}
198
/// A function identified by its fully-qualified, dot-separated Python path
/// (e.g. `torch.ops.<namespace>.<op>.<overload>` or `some.module.function`).
///
/// The path is only turned into a Python object on demand via `resolve`.
/// Exposed to Python as a frozen `FunctionPath` class with a read-only `path`.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct FunctionPath {
    /// Dotted path; `resolve` rejects paths that contain no `.`.
    #[pyo3(get)]
    pub path: String,
}
211
212impl fmt::Display for FunctionPath {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 write!(f, "<function \"{}\">", self.path)
215 }
216}
217
218impl<T: Into<String>> From<T> for FunctionPath {
219 fn from(val: T) -> Self {
220 Self { path: val.into() }
221 }
222}
223
224#[pyo3::pymethods]
225impl FunctionPath {
226 #[new]
227 #[pyo3(signature = (*, path))]
228 pub fn new(path: String) -> Self {
229 Self { path }
230 }
231
232 fn __repr__(&self) -> String {
233 self.path.clone()
234 }
235
236 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
237 let (start, rest) = self.path.split_once(".").with_context(|| {
238 format!(
239 "invalid function path {}: paths must be fully qualified",
240 self.path
241 )
242 })?;
243 if start == "torch" {
244 let mut cur = py.import("torch")?.into_any();
245 for p in rest.split(".") {
246 cur = cur.getattr(p)?;
247 }
248 Ok(cur)
249 } else {
250 let (module_fqn, function_name) = self.path.rsplit_once(".").with_context(|| {
251 format!(
252 "invalid function path {}: paths must be fully qualified",
253 self.path
254 )
255 })?;
256 let module = PyModule::import(py, module_fqn)?;
257 let mut function = module.getattr(function_name)?;
258 if function.hasattr("_remote_impl")? {
259 function = function.getattr("_remote_impl")?;
260 }
261 Ok(function.downcast_into()?)
262 }
263 }
264}
265
/// A function shipped as cloudpickle-serialized bytes.
///
/// `resolve` turns it back into a Python object with `cloudpickle.loads`.
/// Exposed to Python as a frozen `Cloudpickle` class.
#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)]
#[pyo3::pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Cloudpickle {
    /// Raw pickle payload; `serde_bytes` serializes it as a byte string rather than a
    /// sequence of individual integers.
    #[serde(with = "serde_bytes")]
    bytes: Vec<u8>,
}
278
279impl fmt::Display for Cloudpickle {
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 write!(f, "<cloud-pickle>")
282 }
283}
284
285#[pyo3::pymethods]
286impl Cloudpickle {
287 #[new]
288 #[pyo3(signature = (*, bytes))]
289 pub fn new(bytes: Vec<u8>) -> Self {
290 Self { bytes }
291 }
292
293 fn __repr__(&self) -> String {
294 format!("Cloudpickle(bytes={:?})", self.bytes)
295 }
296
297 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
298 let module = PyModule::import(py, "cloudpickle")?;
299 let loads = module.getattr("loads")?;
300 loads.call1((PyBytes::new(py, &self.bytes),))
301 }
302}
303
/// A function a worker can execute: either cloudpickled bytes or a dotted path.
///
/// When extracting from Python (`FromPyObject`), the variants are tried in declaration
/// order. A `Cloudpickle` instance is checked first, then a `FunctionPath` instance.
/// The derived `Display` forwards to the wrapped value.
///
/// NOTE(review): serde-serialized inside worker messages; reordering variants may affect
/// the wire format — confirm before editing.
#[derive(
    PartialEq,
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    From,
    FromPyObject,
    Display
)]
pub enum ResolvableFunction {
    /// Function carried as cloudpickle bytes.
    #[pyo3(transparent)]
    Cloudpickle(Cloudpickle),
    /// Function identified by a fully-qualified Python path.
    #[pyo3(transparent)]
    FunctionPath(FunctionPath),
}
321
322impl<'py> IntoPyObject<'py> for ResolvableFunction {
323 type Target = PyAny;
324 type Output = Bound<'py, Self::Target>;
325 type Error = PyErr;
326
327 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
328 Ok(match self {
329 Self::Cloudpickle(func) => func.into_pyobject(py)?.into_any(),
330 Self::FunctionPath(func) => func.into_pyobject(py)?.into_any(),
331 })
332 }
333}
334
335impl ResolvableFunction {
336 pub fn resolve<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
337 match self {
338 Self::Cloudpickle(func) => Ok(func.resolve(py)?.into_any()),
339 Self::FunctionPath(func) => func.resolve(py),
340 }
341 }
342
343 pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
344 match self {
345 Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
346 ["torch", "ops", namespace, op_name, "default"] => {
347 Some((format!("{}::{}", namespace, op_name), String::new()))
348 }
349 ["torch", "ops", namespace, op_name, overload] => {
350 Some((format!("{}::{}", namespace, op_name), overload.to_string()))
351 }
352 _ => None,
353 },
354 _ => None,
355 }
356 }
357
358 pub fn panic_if_requested(&self) {
361 match self {
362 Self::FunctionPath(func) => {
363 if func.path == "__test_panic" {
364 panic!("__test_panic called");
365 }
366 }
367 _ => (),
368 }
369 }
370
371 pub fn supports_pytree_args(&self) -> bool {
372 match self {
373 Self::Cloudpickle(_) => true,
374 Self::FunctionPath(_) => self.as_torch_op().is_none(),
375 }
376 }
377}
378
379impl<T: Into<String>> From<T> for ResolvableFunction {
380 fn from(val: T) -> Self {
381 FunctionPath::from(val).into()
382 }
383}
384
/// Payload of `WorkerMessage::CallFunction`.
///
/// NOTE(review): field semantics below are inferred from names and types only —
/// confirm against the worker's handler.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallFunctionParams {
    /// Sequence number of this command.
    pub seq: Seq,
    /// Refs to bind the function's outputs to. Presumably a `None` entry means that
    /// output is not kept — confirm.
    pub results: Vec<Option<Ref>>,
    /// Refs the function may mutate (inferred from the name).
    pub mutates: Vec<Ref>,
    /// The function to call.
    pub function: ResolvableFunction,
    /// Positional arguments.
    pub args: Vec<WireValue>,
    /// Keyword arguments.
    pub kwargs: HashMap<String, WireValue>,
    /// Stream the call is issued on (presumably — confirm).
    pub stream: StreamRef,
    /// Refs naming remote process groups used by the call (presumably — confirm).
    pub remote_process_groups: Vec<Ref>,
}
404
/// Parameters for an actor call, used by `WorkerMessage::SendResultOfActorCall` and
/// embedded in `ActorMethodParams`.
///
/// NOTE(review): field semantics below are inferred from names and types only —
/// confirm against the worker's handler.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorCallParams {
    /// Sequence number of this command.
    pub seq: Seq,
    /// Broker identifier as a `(String, usize)` pair; presumably a name plus an
    /// index — confirm.
    pub broker_id: (String, usize),
    /// Refs presumably made available to the call as local state — confirm.
    pub local_state: Vec<Ref>,
    /// Refs the call may mutate (inferred from the name).
    pub mutates: Vec<Ref>,
    /// Stream the call is issued on (presumably — confirm).
    pub stream: StreamRef,
}
418
/// Payload of `WorkerMessage::CallActorMethod`: an `ActorCallParams` plus result refs.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActorMethodParams {
    /// Refs to bind the outputs to. Presumably a `None` entry means that output is not
    /// kept — confirm.
    pub results: Vec<Option<Ref>>,
    /// The underlying call.
    pub call: ActorCallParams,
}
424
/// How `WorkerMessage::Reduce` combines values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Reduction {
    /// Stack the values. NOTE(review): presumably along a new dimension — confirm.
    Stack,
    /// Combine with the given NCCL reduction operator (`torch_sys_cuda::nccl::ReduceOp`).
    ReduceOp(ReduceOp),
}
433
/// Tensor metadata: `size`, `dtype`, `layout` and `device`.
///
/// Exposed to Python as `TensorFactory`. `dtype` and `layout` are serialized through the
/// serde remote definitions `torch_sys::ScalarTypeDef` / `torch_sys::LayoutDef`.
/// NOTE(review): presumably used to create tensors matching this description
/// (inferred from the name) — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[pyo3::pyclass(
    frozen,
    name = "TensorFactory",
    module = "monarch._rust_bindings.monarch_extension.tensor_worker"
)]
pub struct Factory {
    /// Tensor dimensions.
    pub size: Vec<i64>,
    /// Element type.
    #[serde(with = "torch_sys::ScalarTypeDef")]
    pub dtype: ScalarType,
    /// Memory layout.
    #[serde(with = "torch_sys::LayoutDef")]
    pub layout: Layout,
    /// Target device.
    pub device: Device,
}
448
449#[pyo3::pymethods]
450impl Factory {
451 #[new]
452 #[pyo3(signature = (*, size, dtype, layout, device))]
453 pub fn new(
454 py: Python<'_>,
455 size: Vec<i64>,
456 dtype: PyObject,
457 layout: PyObject,
458 device: PyObject,
459 ) -> PyResult<Self> {
460 Ok(Self {
463 size,
464 dtype: dtype.extract::<ScalarType>(py)?,
465 layout: layout.extract::<Layout>(py)?,
466 device: device.extract::<Device>(py)?,
467 })
468 }
469
470 #[staticmethod]
471 pub fn from_py(obj: Bound<'_, PyAny>) -> PyResult<Self> {
472 Self::new(
473 obj.py(),
474 obj.getattr("size")?.extract()?,
475 obj.getattr("dtype")?.unbind(),
476 obj.getattr("layout")?.unbind(),
477 obj.getattr("device")?.unbind(),
478 )
479 }
480
481 #[getter]
482 fn size<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
483 PyTuple::new(py, self.size.iter())
484 }
485
486 #[getter]
487 fn dtype<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
488 self.dtype.into_pyobject(py)
489 }
490
491 #[getter]
492 fn layout<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
493 self.layout.into_pyobject(py)
494 }
495
496 #[getter]
497 fn device(&self) -> String {
498 self.device.to_string()
499 }
500}
501
/// How `WorkerMessage::CreateStream` obtains its stream.
///
/// Exposed to Python as an int-comparable enum (`eq`, `eq_int`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[pyo3::pyclass(
    module = "monarch._rust_bindings.monarch_extension.tensor_worker",
    eq,
    eq_int
)]
pub enum StreamCreationMode {
    /// Reuse the default stream (inferred from the name).
    UseDefaultStream,
    /// Create a fresh stream (inferred from the name).
    CreateNewStream,
}
515
/// An error paired with the `Seq` of the command it is attributed to.
///
/// Carried (behind an `Arc`) by `CallFunctionError::DependentError`. Its `Display`
/// prints only the wrapped error, not `seq`.
/// NOTE(review): `#[named(register = false)]` presumably opts out of hyperactor's type
/// registration — confirm.
#[derive(Debug, Named)]
#[named(register = false)]
pub struct SeqError {
    /// Sequence number associated with the failure.
    pub seq: Seq,
    /// The underlying error.
    pub error: anyhow::Error,
}
528
529impl Display for SeqError {
530 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531 self.error.fmt(f)
532 }
533}
534
/// Error returned when running a function or operator on the worker fails.
///
/// - `Error` wraps an arbitrary `anyhow::Error`. `#[from]` makes it the target of `?`
///   for `anyhow::Error`, and the constructor helpers in `impl CallFunctionError` all
///   produce this variant.
/// - `DependentError` reports that an input this computation depended on had already
///   failed, carrying that earlier `SeqError`.
#[derive(Error, Debug, Named)]
#[named(register = false)]
pub enum CallFunctionError {
    #[error("{0}")]
    Error(#[from] anyhow::Error),
    #[error("Computation depended on an input that failed with error: {0}")]
    DependentError(Arc<SeqError>),
}
546
547impl CallFunctionError {
548 #[allow(non_snake_case)]
550 pub fn RefNotFound(r: Ref) -> Self {
551 Self::Error(anyhow::anyhow!("ref not found: {}", r))
552 }
553
554 #[allow(non_snake_case)]
555 pub fn InvalidRemoteFunction(msg: String) -> Self {
556 Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
557 }
558
559 #[allow(non_snake_case)]
560 pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
561 Self::Error(anyhow::anyhow!(
562 "unsupported arg type for {} function: {}",
563 function_type,
564 arg_type
565 ))
566 }
567
568 #[allow(non_snake_case)]
569 pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
570 Self::Error(anyhow::anyhow!("remote function failed: {}", err))
571 }
572
573 #[allow(non_snake_case)]
574 pub fn BorrowError(err: BorrowError) -> Self {
575 Self::Error(anyhow::anyhow!("borrow failed: {}", err))
576 }
577
578 #[allow(non_snake_case)]
579 pub fn OperatorFailed(err: CallOpError) -> Self {
580 Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
581 }
582
583 #[allow(non_snake_case)]
584 pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
585 Self::Error(anyhow::anyhow!(
586 "unexpected number of returns from op, expected {}, got {}",
587 expected,
588 actual
589 ))
590 }
591
592 #[allow(non_snake_case)]
593 pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
594 Self::Error(anyhow::anyhow!(
595 "expected only a single arg (and no kwargs) when no function is given: {}, {}",
596 args,
597 kwargs
598 ))
599 }
600
601 #[allow(non_snake_case)]
602 pub fn Anyhow(err: anyhow::Error) -> Self {
603 Self::Error(err)
604 }
605}
606
607impl From<SerializablePyErr> for CallFunctionError {
608 fn from(v: SerializablePyErr) -> CallFunctionError {
609 CallFunctionError::Error(v.into())
610 }
611}
612
613impl From<BorrowError> for CallFunctionError {
614 fn from(v: BorrowError) -> CallFunctionError {
615 CallFunctionError::Error(v.into())
616 }
617}
618
619impl From<CallOpError> for CallFunctionError {
620 fn from(v: CallOpError) -> CallFunctionError {
621 CallFunctionError::Error(v.into())
622 }
623}
624
/// Messages handled by the tensor worker actor (see `behavior!` for `WorkerActor`).
///
/// `Handler`, `HandleClient` and `RefClient` are hyperactor derives that generate the
/// handler trait and client helpers for these messages. `EnumAsInner` adds
/// per-variant accessors. `Ref` fields name worker-side values, and `StreamRef` fields
/// name streams.
///
/// NOTE(review): variant/field docs below are inferred from names and types only —
/// confirm against the handler. This enum is serde-serialized, so reordering variants
/// or fields may change the wire format.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Clone,
    Serialize,
    Deserialize,
    Debug,
    Named,
    EnumAsInner,
    Bind,
    Unbind
)]
pub enum WorkerMessage {
    /// Backend-network initialization, carrying an NCCL `UniqueId`.
    BackendNetworkInit(UniqueId),

    /// Point-to-point backend-network initialization between two streams.
    BackendNetworkPointToPointInit {
        /// First stream of the pair.
        from_stream: StreamRef,
        /// Second stream of the pair.
        to_stream: StreamRef,
    },

    /// Invoke a function; see `CallFunctionParams`.
    CallFunction(CallFunctionParams),

    /// A batch of nested worker messages (presumably processed in order — confirm).
    CommandGroup(Vec<WorkerMessage>),

    /// Create a stream.
    CreateStream {
        /// Id the new stream will be known by.
        id: StreamRef,
        /// Whether to use the default stream or create a new one.
        stream_creation: StreamCreationMode,
    },

    /// Create a device mesh and bind it to `result`.
    CreateDeviceMesh {
        /// Ref for the new mesh.
        result: Ref,
        /// Presumably the names of the mesh dimensions — confirm.
        names: Vec<String>,
        /// Ranks covered by the mesh.
        ranks: Slice,
    },

    /// Create a remote process group over dimensions of a device mesh.
    CreateRemoteProcessGroup {
        /// Ref for the new process group.
        result: Ref,
        /// Mesh the group is built from.
        device_mesh: Ref,
        /// Mesh dimension names the group spans (presumably — confirm).
        dims: Vec<String>,
    },

    /// Create borrow `borrow` of `tensor` from `from_stream` for use on `to_stream`.
    BorrowCreate {
        /// Ref under which the borrowed tensor is made available (presumably — confirm).
        result: Ref,
        /// Borrow id, referenced by the other `Borrow*` messages.
        borrow: u64,
        /// Tensor being borrowed.
        tensor: Ref,
        /// Stream the tensor is borrowed from.
        from_stream: StreamRef,
        /// Stream the tensor is borrowed to.
        to_stream: StreamRef,
    },

    /// Marks the first use of a borrow (inferred from the name).
    BorrowFirstUse {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Marks the last use of a borrow (inferred from the name).
    BorrowLastUse {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Drops a borrow (inferred from the name).
    BorrowDrop {
        /// Borrow id from `BorrowCreate`.
        borrow: u64,
    },

    /// Delete the listed refs.
    DeleteRefs(Vec<Ref>),

    /// Status request tagged with `seq`.
    RequestStatus {
        /// Sequence number the request refers to.
        seq: Seq,
        /// NOTE(review): meaning of this flag is not visible here — confirm.
        controller: bool,
    },

    /// Reduce a tensor across `dims` of `mesh` using `reduction`.
    Reduce {
        /// Ref for the reduced output.
        result: Ref,
        /// Input tensor.
        tensor: Ref,
        /// Metadata describing the tensor.
        factory: Factory,
        /// Mesh to reduce over.
        mesh: Ref,
        /// Stream the reduction is issued on.
        stream: StreamRef,
        /// Mesh dimensions to reduce over.
        dims: Vec<String>,
        /// How values are combined.
        reduction: Reduction,
        /// Presumably scatters the result instead of gathering it — confirm.
        scatter: bool,
        /// Presumably writes into the input tensor — confirm.
        in_place: bool,
        /// Optional pre-existing output ref (presumably — confirm).
        out: Option<Ref>,
    },

    /// Split a communicator over `dims` of `device_mesh` for `stream`.
    SplitComm {
        /// Mesh dimensions the new communicator spans.
        dims: Vec<String>,
        /// Mesh to split.
        device_mesh: Ref,
        /// Stream associated with the communicator.
        stream: StreamRef,
        /// Optional NCCL configuration.
        config: Option<NcclConfig>,
    },

    /// Split a communicator for a previously created remote process group.
    SplitCommForProcessGroup {
        /// Process group from `CreateRemoteProcessGroup`.
        remote_process_group: Ref,
        /// Stream associated with the communicator.
        stream: StreamRef,
        /// Optional NCCL configuration.
        config: Option<NcclConfig>,
    },

    /// Send `tensor` from `from_ranks` to `to_ranks`, binding the received value to `result`.
    SendTensor {
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    },

    /// Create a pipe bound to `result`, driven by `function(*args, **kwargs)`
    /// (presumably — confirm).
    CreatePipe {
        result: Ref,
        key: String,
        function: ResolvableFunction,
        max_messages: i64,
        mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    },

    /// Send a value, optionally computed by `function`, to `destination`.
    SendValue {
        seq: Seq,
        /// Target ref; NOTE(review): meaning of `None` is not visible here — confirm.
        destination: Option<Ref>,
        /// Refs the computation may mutate (inferred from the name).
        mutates: Vec<Ref>,
        /// Function producing the value. When `None`, a single positional arg and no
        /// kwargs are presumably expected (see `CallFunctionError::TooManyArgsForValue`).
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    },

    /// Send the result of an actor call; see `ActorCallParams`.
    SendResultOfActorCall(ActorCallParams),
    /// Call an actor method; see `ActorMethodParams`.
    CallActorMethod(ActorMethodParams),
    /// Receive from a pipe into `results`.
    PipeRecv {
        seq: Seq,
        /// Refs to bind received values to.
        results: Vec<Option<Ref>>,
        /// Pipe to receive from.
        pipe: Ref,
        /// Stream the receive is issued on.
        stream: StreamRef,
    },

    /// Shut down the worker.
    Exit {
        /// Optional error: an optional `ActorId` plus a message string.
        error: Option<(Option<ActorId>, String)>,
    },

    /// Define a recording of worker commands bound to `result`.
    DefineRecording {
        /// Ref for the recording.
        result: Ref,
        /// Number of results the recording produces.
        nresults: usize,
        /// Number of formal arguments the recording takes.
        nformals: usize,
        /// Recorded commands.
        commands: Vec<WorkerMessage>,
        /// Presumably the number of messages the definition is split across — confirm.
        ntotal_messages: usize,
        /// Presumably this message's position within that split — confirm.
        index: usize,
    },

    /// Declare a formal argument of a recording.
    RecordingFormal {
        /// Ref standing for the argument inside the recording.
        result: Ref,
        /// Position of the argument.
        argument_index: usize,
        stream: StreamRef,
    },

    /// Declare an output of a recording.
    RecordingResult {
        /// Ref inside the recording that provides the output.
        result: Ref,
        /// Position of the output.
        output_index: usize,
        stream: StreamRef,
    },

    /// Invoke a recording with `actuals` as its arguments.
    CallRecording {
        seq: Seq,
        /// Recording from `DefineRecording`.
        recording: Ref,
        /// Refs to bind the recording's outputs to.
        results: Vec<Ref>,
        /// Actual arguments, matched to the formals.
        actuals: Vec<Ref>,
    },

    /// Test-only: bind `value` to `reference`.
    SetRefUnitTestsOnly {
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    },

    /// Test-only: read the value bound to `value`.
    GetRefUnitTestsOnly {
        value: Ref,
        stream: StreamRef,
        /// Port the answer is sent on (`#[reply]`). NOTE(review): `None` presumably
        /// means the ref is absent and `Err` carries an error string — confirm.
        #[reply]
        response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
    },
}
926
/// Parameters describing a worker. NOTE(review): presumably the spawn parameters for
/// `WorkerActor` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct WorkerParams {
    /// Total number of workers (inferred from the name).
    pub world_size: usize,

    /// This worker's rank, presumably in `0..world_size` — confirm.
    pub rank: usize,

    /// Device index for this worker. NOTE(review): meaning of `None` is not visible
    /// here — confirm.
    pub device_index: Option<i8>,

    /// Reference to the controller actor (presumably the one this worker reports to).
    pub controller_actor: ActorRef<ControllerActor>,
}
943
// Declares the `WorkerActor` behavior, which accepts `WorkerMessage`.
// NOTE(review): `cast = true` presumably lets the message be cast to many workers at
// once — confirm against hyperactor's `behavior!` documentation.
hyperactor::behavior!(
    WorkerActor,
    WorkerMessage { cast = true },
);