1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::OncePortHandle;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::proc::Proc;
38use monarch_hyperactor::actor::PythonMessage;
39use monarch_hyperactor::actor::PythonMessageKind;
40use monarch_hyperactor::buffers::FrozenBuffer;
41use monarch_hyperactor::local_state_broker::BrokerId;
42use monarch_hyperactor::local_state_broker::LocalState;
43use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::CallFunctionError;
50use monarch_messages::worker::CallFunctionParams;
51use monarch_messages::worker::SeqError;
52use monarch_messages::worker::StreamRef;
53use monarch_types::PyTree;
54use monarch_types::SerializablePyErr;
55use monarch_types::TryIntoPyObjectUnsafe;
56use pyo3::prelude::*;
57use pyo3::types::PyTuple;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys::BorrowType;
62use torch_sys::CudaDevice;
63use torch_sys::MultiBorrow;
64use torch_sys::RValue;
65use torch_sys::TensorCell;
66use torch_sys::deep_clone;
67use torch_sys::factory_empty;
68use torch_sys::factory_zeros;
69use torch_sys_cuda::cuda::Event;
70use torch_sys_cuda::cuda::Stream;
71use tracing_subscriber::fmt::Subscriber;
72
73use crate::ControllerActor;
74use crate::DeviceMesh;
75use crate::Factory;
76use crate::Reduction;
77use crate::Ref;
78use crate::ResolvableFunction;
79use crate::StreamCreationMode;
80use crate::WireValue;
81use crate::comm::CommBackend;
82use crate::comm::CommMessage;
83use crate::comm::CommMessageClient;
84use crate::comm::NcclCommActor;
85
86pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
87
88thread_local! {
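    // Handles populated in `StreamActor::init` so that code running on this
    // stream's thread can reach the controller actor, the owning proc, and
    // the worker's root actor id.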
90 pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
91 pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
92 pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
93}
94
95fn pickle_python_result(
96 py: Python<'_>,
97 result: Bound<'_, PyAny>,
98 worker_rank: usize,
99) -> Result<PythonMessage, anyhow::Error> {
100 let pickle = py
101 .import("monarch._src.actor.actor_mesh")
102 .unwrap()
103 .getattr("_pickle")
104 .unwrap();
105 let data: FrozenBuffer = pickle
106 .call1((result,))
107 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
108 .extract()
109 .unwrap();
110 Ok(PythonMessage::new_from_buf(
111 PythonMessageKind::Result {
112 rank: Some(worker_rank),
113 },
114 data.inner,
115 ))
116}
117
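/// A sequence of captured [`StreamMessage`]s that can be replayed later via
/// `StreamMessage::CallRecording`.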
118#[derive(Debug)]
119struct Recording {
120 messages: Vec<StreamMessage>,
121}
122
123impl Recording {
124 fn new() -> Self {
125 Self {
126 messages: Vec::new(),
127 }
128 }
129}
130
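/// Whether this stream is currently capturing messages into a recording
/// (`Defining`, which also tracks borrows created inside the recording) or
/// replaying one (`Running`).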
131#[derive(Debug, PartialEq)]
132enum RecordingState {
133 Defining {
134 recording: Ref,
135 defined_borrows: HashSet<u64>,
138 },
139 Running,
140}
141
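/// Messages handled by a [`StreamActor`]. Compute-style messages (function
/// calls, borrows, reductions, sends) are executed in order on the stream, or
/// captured into the active recording when one is being defined.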
142#[derive(Handler, HandleClient, Debug, Named)]
145#[named(register = false)]
146pub enum StreamMessage {
147 CallFunction(
148 CallFunctionParams,
149 HashMap<Ref, DeviceMesh>,
150 HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
151 ),
152
153 BorrowCreate {
154 borrow: u64,
156 tensor: Ref,
158 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
161 },
162
163 BorrowFirstUse {
164 borrow: u64,
166 result: Ref,
168 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
171 },
172
173 BorrowLastUse {
174 borrow: u64,
176 result: Ref,
178 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
180 },
181
182 BorrowDrop {
183 borrow: u64,
184 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
186 },
187
188 DeleteRefs(Vec<Ref>),
189
190 RequestStatus(#[reply] OncePortHandle<()>),
191
192 InitComm(ActorHandle<NcclCommActor>),
193
194 Reduce {
195 comm: Arc<ActorHandle<NcclCommActor>>,
196 dim_size: i64,
197 result: Ref,
198 local_tensor: Ref,
199 factory: Factory,
200 reduction: Reduction,
201 scatter: bool,
202 in_place: bool,
203 out: Option<Ref>,
204 },
205
206 SendTensor {
207 result: Ref,
208 from_rank: Option<usize>,
209 to_rank: Option<usize>,
210 tensor: Ref,
211 factory: Factory,
212 comm: Arc<ActorHandle<NcclCommActor>>,
213 },
214
215 SendValue {
216 seq: Seq,
217 worker_actor_id: ActorId,
218 mutates: Vec<Ref>,
219 function: Option<ResolvableFunction>,
220 args: Vec<WireValue>,
221 kwargs: HashMap<String, WireValue>,
222 device_meshes: HashMap<Ref, DeviceMesh>,
223 },
224
225 DefineRecording {
226 recording: Ref,
227 },
228
229 FinalizeRecording {
230 recording: Ref,
231 },
232
233 CallRecording {
234 seq: Seq,
235 recording: Ref,
236 results: Vec<Ref>,
237 actuals: Vec<Ref>,
238 },
239
240 RecordingFormal {
241 result: Ref,
242 argument_index: usize,
243 },
244
245 RecordingResult {
246 result: Ref,
247 output_index: usize,
248 },
249
250 SetRefUnitTestsOnly(Ref, WireValue),
251
252 SetTensorRefUnitTestsOnly(Ref, TensorCellResult),
253
254 GetRefUnitTestsOnly(
255 Ref, #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
257 ),
258
259 GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
260
261 SendResultOfActorCall(ActorCallParams),
262 CallActorMethod(ActorMethodParams),
263}
264
265impl StreamMessage {
266 fn clone_for_recording(&self) -> Self {
267 match self {
268 StreamMessage::RecordingFormal {
269 result,
270 argument_index,
271 } => StreamMessage::RecordingFormal {
272 result: *result,
273 argument_index: *argument_index,
274 },
275 StreamMessage::RecordingResult {
276 result,
277 output_index,
278 } => StreamMessage::RecordingResult {
279 result: *result,
280 output_index: *output_index,
281 },
282 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
283 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
284 StreamMessage::CallFunction(
285 params.clone(),
286 device_meshes.clone(),
287 remote_process_groups.clone(),
288 )
289 }
290 StreamMessage::BorrowCreate {
291 borrow,
292 tensor,
293 first_use_sender,
294 } => StreamMessage::BorrowCreate {
295 borrow: *borrow,
296 tensor: *tensor,
297 first_use_sender: first_use_sender.clone(),
298 },
299 StreamMessage::BorrowFirstUse {
300 borrow,
301 result,
302 first_use_receiver,
303 } => StreamMessage::BorrowFirstUse {
304 borrow: *borrow,
305 result: *result,
306 first_use_receiver: first_use_receiver.clone(),
307 },
308 StreamMessage::BorrowLastUse {
309 borrow,
310 result,
311 last_use_sender,
312 } => StreamMessage::BorrowLastUse {
313 borrow: *borrow,
314 result: *result,
315 last_use_sender: last_use_sender.clone(),
316 },
317 StreamMessage::BorrowDrop {
318 borrow,
319 last_use_receiver,
320 } => StreamMessage::BorrowDrop {
321 borrow: *borrow,
322 last_use_receiver: last_use_receiver.clone(),
323 },
324 StreamMessage::Reduce {
325 comm,
326 dim_size,
327 result,
328 local_tensor,
329 factory,
330 reduction,
331 scatter,
332 in_place,
333 out,
334 } => StreamMessage::Reduce {
335 comm: comm.clone(),
336 dim_size: *dim_size,
337 result: *result,
338 local_tensor: *local_tensor,
339 factory: factory.clone(),
340 reduction: reduction.clone(),
341 scatter: *scatter,
342 in_place: *in_place,
343 out: out.clone(),
344 },
345 StreamMessage::SendTensor {
346 result,
347 from_rank,
348 to_rank,
349 tensor,
350 factory,
351 comm,
352 } => StreamMessage::SendTensor {
353 result: *result,
354 from_rank: *from_rank,
355 to_rank: *to_rank,
356 tensor: *tensor,
357 factory: factory.clone(),
358 comm: comm.clone(),
359 },
360 other => panic!(
361 "StreamMessage variant not supported in recording: {:?}",
362 other
363 ),
364 }
365 }
366
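    /// Refs that this message defines (writes into the stream's environment)
    /// when it is replayed as part of a recording.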
367 fn get_defined_refs(&self) -> HashSet<Ref> {
369 match self {
370 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
371 StreamMessage::CallFunction(params, ..) => {
372 params.results.iter().filter_map(|&ref_| ref_).collect()
373 }
374 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
375 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
376 StreamMessage::SendTensor {
377 result, from_rank, ..
378 } => {
379 if from_rank.is_some() {
380 HashSet::from([*result])
381 } else {
382 HashSet::new()
383 }
384 }
385 _ => HashSet::new(),
387 }
388 }
389
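    /// Refs whose underlying values this message mutates in place.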
390 fn get_mutated_refs(&self) -> HashSet<Ref> {
392 match self {
393 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
394 StreamMessage::Reduce {
395 out,
396 in_place,
397 local_tensor,
398 ..
399 } => {
400 if *in_place {
401 HashSet::from([*local_tensor])
402 } else if let Some(out) = out {
403 HashSet::from([*out])
404 } else {
405 HashSet::new()
406 }
407 }
408 _ => HashSet::new(),
410 }
411 }
412}
413
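/// An actor that owns one logical stream of execution on a worker: it maps
/// `Ref`s to values (or to the error that poisoned them), runs torch ops and
/// Python functions, participates in NCCL collectives on its optional CUDA
/// stream, and can capture and replay recordings of its own messages.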
414#[derive(Debug)]
423pub struct StreamActor {
424 world_size: usize,
425 rank: usize,
426 env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
430 creation_mode: StreamCreationMode,
432 cuda_stream: OnceLock<Option<Stream>>,
438 device: Option<CudaDevice>,
440 comm: Option<ActorHandle<NcclCommActor>>,
442 controller_actor: ActorRef<ControllerActor>,
444 remote_process_groups: HashMap<Ref, PyObject>,
445 recordings: HashMap<Ref, Recording>,
446 active_recording: Option<RecordingState>,
447 respond_with_python_message: bool,
448 last_seq_error: Option<Arc<SeqError>>,
449}
450
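/// Parameters used to spawn a [`StreamActor`].
///
/// A minimal sketch of spawning a stream on a local proc, mirroring the unit
/// tests at the bottom of this file (the proc and `controller_actor` are
/// assumed to already exist):
///
/// ```ignore
/// let stream_actor = proc
///     .spawn::<StreamActor>(
///         "stream",
///         StreamParams {
///             world_size: 1,
///             rank: 0,
///             creation_mode: StreamCreationMode::UseDefaultStream,
///             id: 0.into(),
///             device: Some(CudaDevice::new(0.into())),
///             controller_actor: controller_actor.clone(),
///             respond_with_python_message: false,
///         },
///     )
///     .await?;
/// ```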
451#[derive(Debug, Clone)]
453pub struct StreamParams {
454 pub world_size: usize,
455 pub rank: usize,
456 pub creation_mode: StreamCreationMode,
458 pub id: StreamRef,
460 pub device: Option<CudaDevice>,
463 pub controller_actor: ActorRef<ControllerActor>,
465 pub respond_with_python_message: bool,
466}
467
468#[async_trait]
469impl Actor for StreamActor {
470 type Params = StreamParams;
471 async fn new(
472 StreamParams {
473 world_size,
474 rank,
475 id: _,
476 device,
477 controller_actor,
478 creation_mode,
479 respond_with_python_message,
480 }: Self::Params,
481 ) -> Result<Self> {
482 Ok(Self {
483 world_size,
484 rank,
485 env: HashMap::new(),
486 creation_mode,
487 cuda_stream: OnceLock::new(),
488 device,
489 comm: None,
490 controller_actor,
491 remote_process_groups: HashMap::new(),
492 recordings: HashMap::new(),
493 active_recording: None,
494 respond_with_python_message,
495 last_seq_error: None,
496 })
497 }
498
499 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
500 CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
504 controller_actor_ref.set(self.controller_actor.clone()).ok()
505 });
506 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
507 ROOT_ACTOR_ID.with(|root_actor_id| {
508 root_actor_id
509 .set(ActorId::root(
510 cx.self_id().proc_id().clone(),
511 cx.self_id().name().to_string(),
512 ))
513 .ok()
514 });
515 if let Some(stream) = self.cuda_stream() {
517 Stream::set_current_stream(stream);
518 }
519 Ok(())
520 }
521
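    /// Run this actor's serving future on a dedicated OS thread with its own
    /// single-worker tokio runtime, forwarding its output back through a
    /// oneshot to the task returned here. Stream handlers block on Python and
    /// torch work, so (presumably) this keeps that blocking work off the
    /// shared hyperactor runtime.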
522 fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
528 where
529 F: Future + Send + 'static,
530 F::Output: Send + 'static,
531 {
532 let (join_tx, join_rx) = tokio::sync::oneshot::channel();
533 let builder = std::thread::Builder::new().name("worker-stream".to_string());
541 let _thread_handle = builder.spawn(move || {
542 let rt = tokio::runtime::Builder::new_multi_thread()
546 .worker_threads(1)
547 .enable_all()
548 .build()
549 .unwrap();
550 let result = rt.block_on(async {
551 tokio::task::block_in_place(|| {
552 Python::with_gil(|py| {
557 py.allow_threads(|| {
558 let result = Handle::current().block_on(future);
559 if join_tx.send(result).is_err() {
560 panic!("could not send join result")
561 }
562 })
563 })
564 })
565 });
566 rt.shutdown_timeout(Duration::from_weeks(1));
567 result
568 });
569
570 tokio::spawn(async move { join_rx.await.unwrap() })
573 }
574}
575
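/// An argument passed into Python: either a resolved `RValue`, a reference to
/// a `DeviceMesh`, or an opaque Python object (e.g. a process group).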
576#[derive(Debug)]
578enum PyArg<'a> {
579 RValue(RValue),
580 DeviceMesh(&'a DeviceMesh),
581 PyObject(PyObject),
582}
583
584impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
586 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
587 match self {
588 PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
591 PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
592 PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
593 }
594 }
595}
596
597impl StreamActor {
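    /// The CUDA stream backing this actor, lazily resolved on first use:
    /// either the device's current stream or a newly created one, depending
    /// on `creation_mode`; `None` when the actor has no CUDA device.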
598 fn cuda_stream(&self) -> Option<&Stream> {
599 self.cuda_stream
600 .get_or_init(|| {
601 self.device.map(|device| match self.creation_mode {
602 StreamCreationMode::UseDefaultStream => {
603 Stream::get_current_stream_on_device(device)
604 }
605 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
606 })
607 })
608 .as_ref()
609 }
610
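    /// Resolve a `Ref` from the environment, converting a poisoned entry into
    /// a `CallFunctionError::DependentError`.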
611 fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
612 let rvalue = self
613 .env
614 .get(ref_)
615 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
616 match rvalue {
617 Ok(val) => Ok(val.clone()),
618 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
619 }
620 }
621
622 fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
623 let ret = match value {
624 WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
625 WireValue::RefList(val) => {
627 let mut ret = Vec::with_capacity(val.len());
628 for v in val {
629 match self.ref_to_rvalue(&v) {
630 Ok(RValue::Tensor(t)) => ret.push(t),
631 Err(err) => {
632 return Err(err);
633 }
634 Ok(val) => {
635 return Err(CallFunctionError::UnsupportedArgType(
636 "wire_to_rvalue".into(),
637 format!("RefList([{:?}])", val),
638 ));
639 }
640 }
641 }
642 RValue::TensorList(ret)
643 }
644 WireValue::Int(val) => RValue::Int(val),
645 WireValue::IntList(val) => RValue::IntList(val),
646 WireValue::Double(val) => RValue::Double(val),
647 WireValue::Bool(val) => RValue::Bool(val),
648 WireValue::String(val) => RValue::String(val),
649 WireValue::Device(val) => RValue::Device(val),
650 WireValue::Layout(val) => RValue::Layout(val),
651 WireValue::ScalarType(val) => RValue::ScalarType(val),
652 WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
653 WireValue::PyObject(val) => RValue::PyObject(val),
654 WireValue::None(()) => RValue::None,
655 WireValue::IValue(val) => RValue::Opaque(val.into()),
656 };
657 Ok(ret)
658 }
659
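    /// Convert a `CallFunctionError` into the `SeqError` that should poison
    /// dependent refs. Dependent errors are unwrapped to their root cause;
    /// fresh errors are reported to the controller (unless a recording is
    /// active) and remembered in `last_seq_error`.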
660 async fn report_seq_error(
661 &mut self,
662 cx: &Context<'_, Self>,
663 seq: Seq,
664 error: CallFunctionError,
665 ) -> Result<Arc<SeqError>, anyhow::Error> {
666 match error {
667 CallFunctionError::DependentError(root) => Ok(root),
668 CallFunctionError::Error(e) => {
669 if self.active_recording.is_none() {
670 let worker_error = WorkerError {
671 backtrace: format!("{e}"),
672 worker_actor_id: cx.self_id().clone(),
673 };
674 tracing::info!("Propagating remote function error to client: {worker_error}");
675 self.controller_actor
676 .remote_function_failed(cx, seq, worker_error)
677 .await?
678 }
679 let err = Arc::new(SeqError { seq, error: e });
680 self.last_seq_error = Some(err.clone());
681 Ok(err)
682 }
683 }
684 }
685
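    /// Run `f` and bind its results to `result_refs`, checking that it
    /// produced exactly as many values as expected. On failure the error is
    /// reported via `report_seq_error` and every result and mutated ref is
    /// poisoned with it, so downstream consumers see a dependent error.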
686 async fn try_define<F>(
687 &mut self,
688 cx: &Context<'_, Self>,
689 seq: Seq,
690 result_refs: Vec<Option<Ref>>,
691 mutates: &Vec<Ref>,
692 f: F,
693 ) -> Result<()>
694 where
695 F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
696 {
697 let actual_results = f(self).await;
698 let op_results = actual_results.and_then(|actual_results| {
701 if result_refs.len() == actual_results.len() {
702 Ok(actual_results
703 .into_iter()
704 .zip(result_refs.iter())
705 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
706 .collect::<Vec<(Ref, RValue)>>())
707 } else {
708 Err(CallFunctionError::UnexpectedNumberOfReturns(
709 result_refs.len(),
710 actual_results.len(),
711 ))
712 }
713 });
714
715 match op_results {
718 Ok(op_results) => {
719 for (ref_, rvalue) in op_results.into_iter() {
720 let prev = self.env.insert(ref_, Ok(rvalue));
721 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
722 }
723 }
724 Err(err) => {
725 let err = self.report_seq_error(cx, seq, err).await?;
726 for ref_ in result_refs {
727 match ref_ {
728 Some(ref_) => {
729 let prev = self.env.insert(ref_, Err(err.clone()));
730 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
731 }
732 None => {}
733 }
734 }
735 for ref_ in mutates {
736 self.env.insert(*ref_, Err(err.clone()));
737 }
738 }
739 }
740 Ok(())
741 }
742
743 fn call_torch_op(
744 &self,
745 op: String,
746 overload: String,
747 args: Vec<WireValue>,
748 kwargs: HashMap<String, WireValue>,
749 ) -> Result<Vec<RValue>, CallFunctionError> {
750 let args = args
751 .into_iter()
752 .map(|arg| self.wire_to_rvalue(arg))
753 .collect::<Result<Vec<_>, _>>()?;
754 let kwargs = kwargs
755 .into_iter()
756 .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
757 .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
758
759 let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
760
761 Ok(if results.is_empty() {
765 vec![RValue::None]
766 } else {
767 results
768 })
769 }
770
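    /// Call a Python function (or, when `function` is `None`, return the
    /// first argument unchanged): resolves the function, materializes remote
    /// process groups, converts `WireValue` args into Python objects
    /// (resolving `Ref`s against the environment), and multi-borrows the
    /// referenced tensors for the duration of the call.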
771 fn call_python_fn<'py>(
772 &mut self,
773 py: Python<'py>,
774 cx: &Context<Self>,
775 function: Option<ResolvableFunction>,
776 args: Vec<WireValue>,
777 kwargs: HashMap<String, WireValue>,
778 mutates: &[Ref],
779 device_meshes: HashMap<Ref, DeviceMesh>,
780 remote_process_groups: HashMap<
781 Ref,
782 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
783 >,
784 ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
785 let function = function
786 .map(|function| {
787 function.resolve(py).map_err(|e| {
788 CallFunctionError::InvalidRemoteFunction(format!(
789 "failed to resolve function {}: {}",
790 function,
791 SerializablePyErr::from(py, &e)
792 ))
793 })
794 })
795 .transpose()?;
796
797 let remote_process_groups = remote_process_groups
798 .into_iter()
799 .map(|(gref, (mesh, dims, comm))| {
800 let group = match self.remote_process_groups.entry(gref) {
801 Entry::Occupied(ent) => ent.get().clone_ref(py),
802 Entry::Vacant(ent) => {
803 torch_sys::backend::ensure_init_process_group(
806 py,
807 self.world_size,
808 self.rank,
809 )?;
810
811 let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
814 let group_size = ranks.len();
815 let (child_instance, _) = cx.child()?;
816 let backend = CommBackend::new(
817 child_instance,
818 comm,
819 self.rank,
820 group_size,
821 self.world_size,
822 );
823 ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
824 .clone_ref(py)
825 }
826 };
827 PyResult::Ok((gref, group))
828 })
829 .collect::<Result<HashMap<_, _>, _>>()
830 .map_err(SerializablePyErr::from_fn(py))?;
831
832 let mut multiborrow = MultiBorrow::new();
836
837 let resolve = |val: WireValue| {
838 val.into_py_object()
839 .map_err(|e| {
840 CallFunctionError::UnsupportedArgType(
841 format!("{:?}", function),
842 format!("{:?}", e),
843 )
844 })?
845 .unpickle(py)
846 .map_err(SerializablePyErr::from_fn(py))?
847 .extract::<PyTree<PyObject>>()
848 .map_err(SerializablePyErr::from_fn(py))?
849 .try_into_map(|obj| {
850 Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
851 if let Some(mesh) = device_meshes.get(&ref_) {
852 PyArg::DeviceMesh(mesh)
853 } else if let Some(pg) = remote_process_groups.get(&ref_) {
854 PyArg::PyObject(pg.clone_ref(py))
855 } else {
856 let rval = self.ref_to_rvalue(&ref_)?;
857 PyArg::RValue(rval)
858 }
859 } else {
860 PyArg::PyObject(obj)
861 })
862 })
863 };
864
865 let py_args: Vec<PyTree<PyArg>> = args
867 .into_iter()
868 .map(resolve)
869 .collect::<Result<_, CallFunctionError>>()?;
870 let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
871 .into_iter()
872 .map(|(k, object)| Ok((k, resolve(object)?)))
873 .collect::<Result<_, CallFunctionError>>()?;
874
875 py_args
877 .iter()
878 .chain(py_kwargs.values())
879 .flat_map(|o| o.iter())
880 .for_each(|arg| {
881 if let PyArg::RValue(rval) = arg {
882 multiborrow.add(rval, BorrowType::Shared);
883 }
884 });
885
886 let mutates: Vec<_> = mutates
888 .iter()
889 .map(|r| self.ref_to_rvalue(r))
890 .collect::<Result<_, CallFunctionError>>()?;
891 mutates
892 .iter()
893 .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
894
895 let _borrow = multiborrow.borrow()?;
897
898 let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
901 let result: Bound<'_, PyAny> =
902 tracing::subscriber::with_default(scoped_subscriber, || {
903 let args = unsafe { py_args.try_to_object_unsafe(py) }
912 .map_err(SerializablePyErr::from_fn(py))?;
913 let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
915 .map_err(SerializablePyErr::from_fn(py))?;
916
917 if let Some(function) = function {
918 function
919 .call(args, Some(kwargs))
920 .map_err(SerializablePyErr::from_fn(py))
921 } else {
922 Ok(args.get_item(0).unwrap())
923 }
924 })?;
925 Ok(result)
926 }
927
928 fn call_python_fn_pytree(
929 &mut self,
930 cx: &hyperactor::Context<Self>,
931 function: ResolvableFunction,
932 args: Vec<WireValue>,
933 kwargs: HashMap<String, WireValue>,
934 mutates: &[Ref],
935 device_meshes: HashMap<Ref, DeviceMesh>,
936 remote_process_groups: HashMap<
937 Ref,
938 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
939 >,
940 ) -> Result<PyTree<RValue>, CallFunctionError> {
941 Python::with_gil(|py| {
942 let result = self.call_python_fn(
943 py,
944 cx,
945 Some(function),
946 args,
947 kwargs,
948 mutates,
949 device_meshes,
950 remote_process_groups,
951 )?;
952 Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
953 })
954 }
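    /// Fetch the tensor behind `ref_`, or fabricate a zero tensor matching
    /// `factory` if that ref is in an error state, so collectives can still
    /// participate uniformly across ranks.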
955 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
963 let rvalue = self
964 .env
965 .get(&ref_)
966 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
967
968 match rvalue {
969 Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
970 Err(_) => {
971 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
972 Ok(TensorCell::new(t))
973 }
974 }
975 }
976
977 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
978 self.active_recording
979 .as_mut()
980 .and_then(|state| match state {
981 RecordingState::Defining {
982 recording,
983 defined_borrows,
984 } => {
985 match self.recordings.get_mut(recording) {
986 Some(recording) => Some((recording, defined_borrows)),
987 None => panic!("recording not found: {:?}", recording),
989 }
990 }
991 RecordingState::Running => None,
992 })
993 }
994
995 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
996 for ref_ in refs {
997 let rvalue_or_err = self
998 .env
999 .get(ref_)
1000 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1001 if let Err(err) = rvalue_or_err {
1002 return Ok(Some(err.clone()));
1003 }
1004 }
1005 Ok(None)
1006 }
1007 async fn send_value_python_message(
1008 &mut self,
1009 cx: &hyperactor::Context<'_, Self>,
1010 seq: Seq,
1011 mutates: Vec<Ref>,
1012 function: Option<ResolvableFunction>,
1013 args: Vec<WireValue>,
1014 kwargs: HashMap<String, WireValue>,
1015 device_meshes: HashMap<Ref, DeviceMesh>,
1016 ) -> Result<()> {
1017 let rank = self.rank;
1018 self.try_define(cx, seq, vec![], &vec![], async |self_| {
1019 let python_message =
1020 Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1021 let python_result = tokio::task::block_in_place(|| {
1022 self_.call_python_fn(
1023 py,
1024 cx,
1025 function,
1026 args,
1027 kwargs,
1028 &mutates,
1029 device_meshes,
1030 HashMap::new(),
1031 )
1032 })?;
1033 pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
1034 })?;
1035 let ser = Serialized::serialize(&python_message).unwrap();
1036 self_
1037 .controller_actor
1038 .fetch_result(cx, seq, Ok(ser))
1039 .await?;
1040 Ok(vec![])
1041 })
1042 .await
1043 }
1044 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1045 let rvalue = self
1046 .env
1047 .get(&src)
1048 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1049 self.env.insert(dest, rvalue.clone());
1050 Ok(())
1051 }
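    /// Deliver a call to a Python actor via the local state broker: the
    /// resolved `local_state` plus a reply port are handed to the broker, and
    /// the Python-side result (or error) is awaited on that port.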
1052 async fn call_actor(
1053 &mut self,
1054 cx: &Context<'_, Self>,
1055 params: ActorCallParams,
1056 ) -> Result<PyObject, CallFunctionError> {
1057 let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1058 params
1059 .local_state
1060 .into_iter()
1061 .map(|elem| {
1062 unsafe {
1064 let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1065 Ok(x)
1066 }
1067 })
1068 .collect()
1069 });
1070
1071 let (send, recv) = cx.open_once_port();
1072 let state = LocalState {
1073 response_port: send,
1074 state: local_state?,
1075 };
1076 let seq_index: u64 = params.seq.into();
1077 let message = LocalStateBrokerMessage::Set(seq_index as usize, state);
1078
1079 let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1080 broker
1081 .send(message)
1082 .map_err(|e| CallFunctionError::Error(e.into()))?;
1083 let result = recv
1084 .recv()
1085 .await
1086 .map_err(|e| CallFunctionError::Error(e.into()))?;
1087
1088 result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1089 }
1090}
1091
1092#[async_trait]
1093#[forward(StreamMessage)]
1094impl StreamMessageHandler for StreamActor {
1095 async fn call_function(
1096 &mut self,
1097 cx: &Context<Self>,
1098 params: CallFunctionParams,
1099 device_meshes: HashMap<Ref, DeviceMesh>,
1100 remote_process_groups: HashMap<
1101 Ref,
1102 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1103 >,
1104 ) -> Result<()> {
1105 if let Some((recording, _)) = self.get_defining_recording() {
1106 recording.messages.push(StreamMessage::CallFunction(
1107 params,
1108 device_meshes,
1109 remote_process_groups,
1110 ));
1111 return Ok(());
1112 }
1113
1114 params.function.panic_if_requested();
1115 self.try_define(
1116 cx,
1117 params.seq,
1118 params.results,
1119 ¶ms.mutates,
1120 async |self| {
1121 tokio::task::block_in_place(|| match params.function.as_torch_op() {
1122 Some((op, overload)) => {
1123 self.call_torch_op(op, overload, params.args, params.kwargs)
1124 }
1125 _ => self
1126 .call_python_fn_pytree(
1127 cx,
1128 params.function,
1129 params.args,
1130 params.kwargs,
1131 ¶ms.mutates,
1132 device_meshes,
1133 remote_process_groups,
1134 )
1135 .map(|results| results.into_leaves()),
1136 })
1137 },
1138 )
1139 .await?;
1140 Ok(())
1141 }
1142
1143 async fn borrow_create(
1144 &mut self,
1145 _cx: &Context<Self>,
1146 borrow: u64,
1147 tensor: Ref,
1148 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1149 ) -> Result<()> {
1150 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1151 recording.messages.push(StreamMessage::BorrowCreate {
1152 borrow,
1153 tensor,
1154 first_use_sender,
1155 });
1156 ensure!(
1157 defined_borrows.insert(borrow),
1158 "duplicate borrow create in recording"
1159 );
1160 return Ok(());
1161 }
1162
1163 let rvalue_result = self
1164 .env
1165 .get(&tensor)
1166 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1167
1168 let result = match rvalue_result {
1169 Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1170 Err(e) => Err(e.clone()),
1171 };
1172
1173 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1174 first_use_sender.send((event, result)).map_err(|err| {
1175 anyhow!(
1176 "failed sending first use event for borrow {:?}: {:?}",
1177 borrow,
1178 err
1179 )
1180 })
1181 }
1182
1183 async fn borrow_first_use(
1184 &mut self,
1185 _cx: &Context<Self>,
1186 borrow: u64,
1187 result: Ref,
1188 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1189 ) -> Result<()> {
1190 if let Some((recording, _)) = self.get_defining_recording() {
1191 recording.messages.push(StreamMessage::BorrowFirstUse {
1192 borrow,
1193 result,
1194 first_use_receiver: first_use_receiver.clone(),
1195 });
1196 return Ok(());
1197 }
1198
1199 let (first_use_event, cell) =
1200 first_use_receiver
1201 .lock()
1202 .await
1203 .recv()
1204 .await
1205 .map_err(|err| {
1206 anyhow!(
1207 "failed receiving first use event for borrow {:?}: {:?}",
1208 borrow,
1209 err
1210 )
1211 })?;
1212
1213 if let Some(stream) = self.cuda_stream() {
1214 stream.wait_event(
1215 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1216 );
1217 }
1218 match cell {
1219 Ok(cell) => {
1220 self.env.insert(result, Ok(cell.into()));
1221 }
1222 Err(err) => {
1223 self.env.insert(result, Err(err.clone()));
1224 }
1225 }
1226 Ok(())
1227 }
1228
1229 async fn borrow_last_use(
1230 &mut self,
1231 _cx: &Context<Self>,
1232 borrow: u64,
1233 result: Ref,
1234 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1235 ) -> Result<()> {
1236 if let Some((recording, _)) = self.get_defining_recording() {
1237 recording.messages.push(StreamMessage::BorrowLastUse {
1238 borrow,
1239 result,
1240 last_use_sender,
1241 });
1242 return Ok(());
1243 }
1244
1245 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1246 let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1247 "Invalid reference for borrow_last_use: {result:#?}"
1248 ))?;
1249 let tensor = match rvalue_or_err {
1250 Ok(RValue::Tensor(t)) => Ok(t),
1251 Err(e) => Err(e),
1252 _ => bail!("invalid rvalue type for borrow_last_use"),
1253 };
1254
1255 last_use_sender.send((event, tensor)).map_err(|err| {
1256 anyhow!(
1257 "failed sending last use event for borrow {:?}: {:?}",
1258 borrow,
1259 err
1260 )
1261 })
1262 }
1263
1264 async fn borrow_drop(
1265 &mut self,
1266 _cx: &Context<Self>,
1267 borrow: u64,
1268 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1269 ) -> Result<()> {
1270 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1271 recording.messages.push(StreamMessage::BorrowDrop {
1272 borrow,
1273 last_use_receiver: last_use_receiver.clone(),
1274 });
1275 ensure!(
1276 defined_borrows.remove(&borrow),
1277 "borrow drop for borrow not defined in recording"
1278 );
1279 return Ok(());
1280 }
1281
1282 let (last_use_event, _cell) =
1286 last_use_receiver.lock().await.recv().await.map_err(|err| {
1287 anyhow!(
1288 "failed receiving last use event for borrow {:?}: {:?}",
1289 borrow,
1290 err
1291 )
1292 })?;
1293
1294 if let Some(stream) = self.cuda_stream() {
1295 stream.wait_event(
1296 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1297 );
1298 }
1299 Ok(())
1301 }
1302
1303 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1304 if let Some((recording, _)) = self.get_defining_recording() {
1305 recording.messages.push(StreamMessage::DeleteRefs(refs));
1306 return Ok(());
1307 }
1308
1309 for ref_ in refs.iter() {
1310 self.env.remove(ref_);
1311 }
1312 Ok(())
1313 }
1314
1315 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1316 if self.get_defining_recording().is_some() {
1317 bail!("request_status not allowed in recording");
1318 }
1319
1320 Ok(())
1321 }
1322
1323 async fn init_comm(
1324 &mut self,
1325 _cx: &Context<Self>,
1326 comm: ActorHandle<NcclCommActor>,
1327 ) -> Result<()> {
1328 if self.get_defining_recording().is_some() {
1329 bail!("init_comm not allowed in recording");
1330 }
1331
1332 self.comm = Some(comm);
1333 Ok(())
1334 }
1335
1336 async fn reduce(
1337 &mut self,
1338 cx: &Context<Self>,
1339 comm: Arc<ActorHandle<NcclCommActor>>,
1340 dim_size: i64,
1341 result: Ref,
1342 local_tensor: Ref,
1343 factory: Factory,
1344 reduction: Reduction,
1345 scatter: bool,
1346 in_place: bool,
1347 out: Option<Ref>,
1348 ) -> Result<()> {
1349 if let Some((recording, _)) = self.get_defining_recording() {
1350 recording.messages.push(StreamMessage::Reduce {
1351 comm,
1352 dim_size,
1353 result,
1354 local_tensor,
1355 factory,
1356 reduction,
1357 scatter,
1358 in_place,
1359 out,
1360 });
1361 return Ok(());
1362 }
1363
1364 let stream = self
1365 .cuda_stream()
1366 .expect("reductions not yet supported for non-CUDA workers")
1367 .clone();
1368 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1369 let out_cell = out
1370 .map(|out| self.get_or_fake_on_err(out, &factory))
1371 .transpose()?;
1372 let output_cell = match reduction {
1373 Reduction::Stack => {
1374 if scatter {
1375 let output_cell = if in_place {
1376 input_cell.clone()
1377 } else {
1378 out_cell.unwrap_or({
1379 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1380 let cloned = deep_clone(&borrow);
1381 TensorCell::new(cloned)
1382 })
1383 };
1384 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1385 .await?;
1386 output_cell
1387 } else {
1388 ensure!(
1389 !in_place,
1390 "in-place, non-scatter not supported for stack reduce"
1391 );
1392
1393 let output_cell = out_cell.unwrap_or({
1394 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1396 let output =
1397 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1398 TensorCell::new(output)
1399 });
1400
1401 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1402 .await?;
1403 output_cell
1404 }
1405 }
1406 Reduction::ReduceOp(op) => {
1407 if scatter {
1408 ensure!(!in_place, "in-place, scatter not supported for reduce");
1409
1410 let output_cell = out_cell.unwrap_or({
1411 let output = factory_empty(
1412 &factory.size[1..],
1413 factory.dtype,
1414 factory.layout,
1415 factory.device,
1416 );
1417 TensorCell::new(output)
1418 });
1419 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1420 .await?;
1421 output_cell
1422 } else {
1423 let output_cell = if in_place {
1424 input_cell.clone()
1425 } else {
1426 out_cell.map_or(
1427 {
1428 let borrow =
1429 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1430 let cloned = deep_clone(&borrow);
1431 Ok(TensorCell::new(cloned))
1432 },
1433 |out_cell| -> Result<_, anyhow::Error> {
1434 let mut out_borrow =
1435 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1436 let in_borrow =
1437 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1438 out_borrow.copy_(&in_borrow);
1439 drop(out_borrow);
1440 Ok(out_cell)
1441 },
1442 )?
1443 };
1444
1445 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1446 output_cell
1447 }
1448 }
1449 };
1450
1451 self.env.insert(result, Ok(output_cell.into()));
1452 Ok(())
1453 }
1454
1455 async fn send_tensor(
1456 &mut self,
1457 cx: &Context<Self>,
1458 result: Ref,
1459 from_rank: Option<usize>,
1460 to_rank: Option<usize>,
1461 tensor: Ref,
1462 factory: Factory,
1463 comm: Arc<ActorHandle<NcclCommActor>>,
1464 ) -> Result<()> {
1465 if let Some((recording, _)) = self.get_defining_recording() {
1466 recording.messages.push(StreamMessage::SendTensor {
1467 result,
1468 from_rank,
1469 to_rank,
1470 tensor,
1471 factory,
1472 comm,
1473 });
1474 return Ok(());
1475 }
1476
1477 if to_rank.is_none() && from_rank.is_none() {
1478 bail!("tried to send tensor without a to/from rank");
1479 }
1480
1481 if from_rank == to_rank {
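        // Sending to ourselves: no collective is needed, just deep-copy the
        // tensor (or propagate its error) into the result ref.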
1483 let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1484 .env
1485 .get(&tensor)
1486 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1487 let output_cell = match input_cell {
1488 Ok(RValue::Tensor(input_cell)) => {
1489 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1494 let cloned = deep_clone(&borrow);
1495 Ok(RValue::Tensor(TensorCell::new(cloned)))
1496 }
1497 Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1498 Err(err) => Err(err.clone()),
1499 };
1500 self.env.insert(result, output_cell);
1501 return Ok(());
1502 }
1503
1504 let mut messages = Vec::new();
1505
1506 if let Some(to_rank) = to_rank {
1507 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1508 messages.push(CommMessage::Send(
1509 input_cell,
1510 to_rank.try_into().unwrap(),
1511 self.cuda_stream()
1512 .expect("tried to send_tensor on non-cuda stream")
1513 .clone(),
1514 cx.open_once_port().0,
1515 ));
1516 }
1517
1518 if let Some(from_rank) = from_rank {
1519 let output_cell = TensorCell::new(factory_empty(
1520 &factory.size,
1521 factory.dtype,
1522 factory.layout,
1523 factory.device,
1524 ));
1525 messages.push(CommMessage::Recv(
1526 output_cell.clone(),
1527 from_rank.try_into().unwrap(),
1528 self.cuda_stream()
1529 .expect("tried to send_tensor on non-cuda stream")
1530 .clone(),
1531 cx.open_once_port().0,
1532 ));
1533 self.env.insert(result, Ok(output_cell.into()));
1534 }
1535
1536 comm.group(
1537 cx,
1538 messages,
1539 self.cuda_stream()
1540 .expect("tried to send_tensor on non-cuda stream")
1541 .clone(),
1542 )
1543 .await?;
1544 Ok(())
1545 }
1546
1547 async fn send_value(
1548 &mut self,
1549 cx: &Context<Self>,
1550 seq: Seq,
1551 worker_actor_id: ActorId,
1552 mutates: Vec<Ref>,
1553 function: Option<ResolvableFunction>,
1554 args: Vec<WireValue>,
1555 kwargs: HashMap<String, WireValue>,
1556 device_meshes: HashMap<Ref, DeviceMesh>,
1557 ) -> Result<()> {
1558 if self.respond_with_python_message {
1559 return self
1560 .send_value_python_message(cx, seq, mutates, function, args, kwargs, device_meshes)
1561 .await;
1562 }
1563 let result = if let Some(function) = function {
1564 match function.as_torch_op() {
1566 Some((op, overload)) => {
1567 self.call_torch_op(op, overload, args, kwargs)
1568 .map(|rvalues| {
1569 if rvalues.len() == 1 {
1570 Ok(rvalues[0].clone().into())
1571 } else {
1572 Python::with_gil(|py| {
1574 Ok((|| {
1575 let py_rvalues = rvalues
1576 .into_iter()
1577 .map(|rvalue| unsafe {
1579 rvalue.try_to_object_unsafe(py)
1580 })
1581 .collect::<Result<Vec<_>, _>>()?;
1582 PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1583 })()
1584 .map_err(SerializablePyErr::from_fn(py))?)
1585 })
1586 }
1587 })?
1588 }
1589 _ => tokio::task::block_in_place(|| {
1592 self.call_python_fn_pytree(
1593 cx,
1594 function,
1595 args,
1596 kwargs,
1597 &mutates,
1598 device_meshes,
1599 HashMap::new(),
1600 )
1601 }),
1602 }
1603 } else {
1604 match (args.len(), kwargs.len()) {
1607 (1, 0) => Python::with_gil(|py| {
1608 let arg = args[0]
1609 .as_py_object()
1610 .ok_or_else(|| {
1611 CallFunctionError::UnsupportedArgType(
1612 "send_value".to_string(),
1613 "expected a PyObject as the first arg".to_string(),
1614 )
1615 })?
1616 .unpickle(py)
1617 .map_err(SerializablePyErr::from_fn(py))?;
1618 arg.extract::<PyTree<PyObject>>()
1619 .map_err(SerializablePyErr::from_fn(py))?
1620 .try_into_map(|obj| {
1621 let bound_obj = obj.bind(py);
1622 if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1623 self.ref_to_rvalue(&ref_)
1624 } else {
1625 Ok(bound_obj
1626 .extract::<RValue>()
1627 .map_err(SerializablePyErr::from_fn(py))?)
1628 }
1629 })
1630 }),
1631 _ => Err(CallFunctionError::TooManyArgsForValue(
1632 format!("{:?}", args),
1633 format!("{:?}", kwargs),
1634 )),
1635 }
1636 };
1637
1638 let value = match result {
1639 Ok(rvalue) => {
1640 Ok(rvalue.into_map(|rval| match rval {
1645 RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1646 RValue::TensorList(tensors) => RValue::TensorList(
1647 tensors
1648 .into_iter()
1649 .map(|tensor| tensor.try_cpu().unwrap())
1650 .collect(),
1651 ),
1652 rval => rval,
1653 }))
1654 }
1655 Err(err) => {
1656 let err = self.report_seq_error(cx, seq, err).await?;
1657 for ref_ in mutates {
1658 self.env.insert(ref_, Err(err.clone()));
1659 }
1660 Err(WorkerError {
1661 backtrace: format!("{:?}", err),
1662 worker_actor_id,
1663 })
1664 }
1665 };
1666
1667 let result = match value {
1669 Ok(value) => Ok(Serialized::serialize(&value).map_err(anyhow::Error::from)?),
1670 Err(e) => Err(e),
1671 };
1672 self.controller_actor.fetch_result(cx, seq, result).await?;
1673
1674 Ok(())
1675 }
1676
1677 async fn send_result_of_actor_call(
1678 &mut self,
1679 cx: &Context<Self>,
1680 params: ActorCallParams,
1681 ) -> anyhow::Result<()> {
1682 let seq = params.seq;
1683 let mutates = params.mutates.clone();
1684 self.try_define(cx, seq, vec![], &mutates, async |self| {
1685 let value = self.call_actor(cx, params).await?;
1686 let result =
1687 Python::with_gil(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1688 let result = Serialized::serialize(&result).unwrap();
1689 self.controller_actor
1690 .fetch_result(cx, seq, Ok(result))
1691 .await?;
1692 Ok(vec![])
1693 })
1694 .await
1695 }
1696
1697 async fn call_actor_method(
1698 &mut self,
1699 cx: &Context<Self>,
1700 params: ActorMethodParams,
1701 ) -> anyhow::Result<()> {
1702 let seq = params.call.seq;
1703 let mutates = params.call.mutates.clone();
1704 self.try_define(cx, seq, params.results, &mutates, async |self| {
1705 let result = self.call_actor(cx, params.call).await?;
1706 let result = Python::with_gil(|py| {
1707 PyTree::<RValue>::extract_bound(&result.into_bound(py))
1708 .map_err(SerializablePyErr::from_fn(py))
1709 })?;
1710 Ok(result.into_leaves())
1711 })
1712 .await
1713 }
1714
1715 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1716 if self.active_recording.is_some() {
1717 bail!("different recording already active");
1718 }
1719 match self.recordings.entry(recording) {
1720 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1721 Entry::Vacant(entry) => entry.insert(Recording::new()),
1722 };
1723 self.active_recording = Some(RecordingState::Defining {
1724 recording,
1725 defined_borrows: HashSet::new(),
1726 });
1727 Ok(())
1728 }
1729
1730 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1731 match self.active_recording {
1732 Some(RecordingState::Defining {
1733 recording: active_recording,
1734 ref defined_borrows,
1735 }) if active_recording == recording => {
1736 ensure!(
1737 defined_borrows.is_empty(),
1738 "all borrows created within recording must be dropped within recording"
1739 );
1740 self.active_recording = None;
1741 }
1742 _ => bail!("cannot finalize recording that isn't active"),
1743 }
1744 Ok(())
1745 }
1746
1747 async fn recording_formal(
1748 &mut self,
1749 _cx: &Context<Self>,
1750 result: Ref,
1751 argument_index: usize,
1752 ) -> Result<()> {
1753 match self.get_defining_recording() {
1754 Some((recording, _)) => {
1755 recording.messages.push(StreamMessage::RecordingFormal {
1756 result,
1757 argument_index,
1758 });
1759 }
1760 None => bail!("recording_formal called outside of recording"),
1761 };
1762 Ok(())
1763 }
1764
1765 async fn recording_result(
1766 &mut self,
1767 _cx: &Context<Self>,
1768 result: Ref,
1769 output_index: usize,
1770 ) -> Result<()> {
1771 match self.get_defining_recording() {
1772 Some((recording, _)) => {
1773 recording.messages.push(StreamMessage::RecordingResult {
1774 result,
1775 output_index,
1776 });
1777 }
1778 None => bail!("recording_result called outside of recording"),
1779 };
1780 Ok(())
1781 }
1782
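    /// Replay a previously defined recording: formals are bound to `actuals`,
    /// the captured messages are re-executed in order, `results` are bound
    /// from the recorded outputs, refs defined during replay are deleted
    /// afterwards, and any failure poisons the results and mutated refs and
    /// is reported to the controller once.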
1783 async fn call_recording(
1784 &mut self,
1785 cx: &Context<Self>,
1786 seq: Seq,
1787 recording: Ref,
1788 results: Vec<Ref>,
1789 actuals: Vec<Ref>,
1790 ) -> Result<()> {
1791 if self.active_recording.is_some() {
1792 bail!("cannot call recording while another recording is active");
1793 }
1794
1795 let messages = match self.recordings.get(&recording) {
1796 Some(recording) => recording
1797 .messages
1798 .iter()
1799 .map(|message| message.clone_for_recording())
1800 .collect::<Vec<_>>(),
1801 None => bail!("recording {:?} not found", recording),
1802 };
1803
1804 self.active_recording = Some(RecordingState::Running);
1805
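        // Track the first error seen while replaying, every ref defined or
        // mutated along the way, and the mapping from recording formals to
        // the actual refs passed in, so failures can be propagated and
        // recording temporaries cleaned up at the end.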
1806 let mut error: Option<Arc<SeqError>> = None;
1811 let mut all_defined_refs = HashSet::new();
1814 let mut all_mutated_refs = HashSet::new();
1817 let mut formal_to_actual_refs = HashMap::new();
1823 self.last_seq_error = None;
1825 for message in messages.into_iter() {
1826 let defined_refs = message.get_defined_refs();
1827 all_defined_refs.extend(defined_refs.clone());
1828
1829 let mutated_refs_with_formals = message.get_mutated_refs();
1830 all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1831 match formal_to_actual_refs.get(ref_) {
1832 Some(actual_ref) => Some(*actual_ref),
1833 None => {
1834 if all_defined_refs.contains(ref_) {
1835 None
1836 } else {
1837 Some(*ref_)
1838 }
1839 }
1840 }
1841 }));
1842
1843 match message {
1844 StreamMessage::RecordingFormal {
1845 result: formal_ref,
1846 argument_index,
1847 } => match actuals.get(argument_index) {
1848 None => bail!("recording_formal called with too few arguments"),
1849 Some(actual_ref) => {
1850 formal_to_actual_refs.insert(formal_ref, *actual_ref);
1851 self.define_ref(formal_ref, *actual_ref)?;
1852 }
1853 },
1854 StreamMessage::RecordingResult {
1855 result: result_ref,
1856 output_index,
1857 } => match results.get(output_index) {
1858 None => bail!("recording_result called with too few results"),
1859 Some(actual_result_ref) => {
1860 self.define_ref(*actual_result_ref, result_ref)?;
1861 }
1862 },
1863 StreamMessage::DeleteRefs(ref refs) => {
1864 for ref_ in refs {
1865 all_defined_refs.remove(ref_);
1866 }
1867 StreamMessageHandler::handle(self, cx, message).await?;
1868 }
1869 StreamMessage::CallFunction { .. } if error.is_some() => {
1870 let error = error.clone().unwrap();
1876 for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1877 self.env.insert(*ref_, Err(error.clone()));
1878 }
1879 }
1880 StreamMessage::BorrowLastUse { ref result, .. } => {
1881 all_defined_refs.remove(result);
1882 StreamMessageHandler::handle(self, cx, message).await?;
1883 }
1884 StreamMessage::Reduce {
1885 local_tensor,
1886 ref out,
1887 ..
1888 } => {
1889 if error.is_none() {
1893 let inputs_to_check = [Some(local_tensor), out.clone()]
1894 .iter()
1895 .filter_map(|r| *r)
1896 .collect::<Vec<_>>();
1897 error = self.get_first_error(inputs_to_check.as_slice())?;
1898 }
1899 StreamMessageHandler::handle(self, cx, message).await?;
1900 }
1901 StreamMessage::SendTensor {
1902 ref tensor,
1903 ref to_rank,
1904 ..
1905 } => {
1906 if to_rank.is_some() && error.is_none() {
1911 error = self.get_first_error(&[*tensor])?;
1912 }
1913 StreamMessageHandler::handle(self, cx, message).await?;
1914 }
1915 _ => {
1916 StreamMessageHandler::handle(self, cx, message).await?;
1917 }
1918 };
1919
1920 match (&error, self.last_seq_error.take()) {
1928 (None, Some(seq_err)) => {
1929 self.controller_actor
1931 .remote_function_failed(
1932 cx,
1933 seq,
1934 WorkerError {
1935 backtrace: format!("recording failed: {}", &seq_err),
1936 worker_actor_id: cx.self_id().clone(),
1937 },
1938 )
1939 .await?;
1940 error = Some(seq_err)
1941 }
1942 _ => {}
1943 }
1944 }
1949
1950 StreamMessageHandler::handle(
1954 self,
1955 cx,
1956 StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1957 )
1958 .await?;
1959
1960 if error.is_some() {
1963 for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1964 self.env.insert(*ref_, Err(error.clone().unwrap()));
1965 }
1966 }
1967
1968 self.active_recording = None;
1969 Ok(())
1970 }
1971
1972 async fn set_ref_unit_tests_only(
1973 &mut self,
1974 _cx: &Context<Self>,
1975 reference: Ref,
1976 value: WireValue,
1977 ) -> Result<()> {
1978 self.env
1979 .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
1980 Ok(())
1981 }
1982
1983 async fn set_tensor_ref_unit_tests_only(
1984 &mut self,
1985 _cx: &Context<Self>,
1986 reference: Ref,
1987 tensor_result: TensorCellResult,
1988 ) -> Result<()> {
1989 match tensor_result {
1990 Ok(tensor_cell) => {
1991 self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
1992 }
1993 Err(err) => {
1994 self.env.insert(reference, Err(err));
1995 }
1996 }
1997 Ok(())
1998 }
1999
2000 async fn get_ref_unit_tests_only(
2001 &mut self,
2002 _cx: &Context<Self>,
2003 reference: Ref,
2004 ) -> Result<Option<Result<WireValue, String>>> {
2005 fn rvalue_to_wire(
2007 value: Result<RValue, Arc<SeqError>>,
2008 ) -> Result<WireValue, Arc<SeqError>> {
2009 Ok(match value? {
2010 RValue::Int(val) => WireValue::Int(val),
2011 RValue::IntList(val) => WireValue::IntList(val),
2012 RValue::Double(val) => WireValue::Double(val),
2013 RValue::Bool(val) => WireValue::Bool(val),
2014 RValue::String(val) => WireValue::String(val),
2015 RValue::Layout(val) => WireValue::Layout(val),
2016 RValue::Device(val) => WireValue::Device(val),
2017 RValue::ScalarType(val) => WireValue::ScalarType(val),
2018 RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2019 RValue::None => WireValue::None(()),
2020 other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2021 })
2022 }
2023 Ok(self
2024 .env
2025 .get(&reference)
2026 .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2027 }
2028
2029 async fn get_tensor_ref_unit_tests_only(
2030 &mut self,
2031 _cx: &Context<Self>,
2032 reference: Ref,
2033 ) -> Result<Option<TensorCellResult>> {
2034 match self.env.get(&reference) {
2035 Some(Ok(rvalue)) => match rvalue {
2036 RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2037 other => bail!("expected tensor, got {:?}", other),
2038 },
2039 Some(Err(err)) => Ok(Some(Err(err.clone()))),
2040 None => Ok(None),
2041 }
2042 }
2043}
2044
2045#[cfg(test)]
2046mod tests {
2047 use hyperactor::actor::ActorStatus;
2048 use hyperactor::context;
2049 use hyperactor::supervision::ActorSupervisionEvent;
2050 use monarch_messages::controller::ControllerMessage;
2051 use monarch_messages::worker::StreamCreationMode;
2052 use monarch_types::PickledPyObject;
2053 use pyo3::IntoPyObjectExt;
2054 use timed_test::async_timed_test;
2055 use torch_sys::factory_float_tensor;
2056 use torch_sys::testing::allclose;
2057 use torch_sys_cuda::nccl::UniqueId;
2058
2059 use super::*;
2060 use crate::comm::CommParams;
2061 use crate::test_util;
2062
2063 fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2064 Arc::new(SeqError {
2065 seq: 0.into(),
2066 error: err,
2067 })
2068 }
2069
2070 struct TestSetup {
2071 proc: Proc,
2072 stream_actor: ActorHandle<StreamActor>,
2073 client: Instance<()>,
2074 #[allow(dead_code)]
2077 supervision_rx: PortReceiver<ActorSupervisionEvent>,
2078 controller_rx: PortReceiver<ControllerMessage>,
2079 controller_actor: ActorRef<ControllerActor>,
2080 next_ref: Ref,
2081 }
2082
2083 impl TestSetup {
2084 async fn new() -> Result<Self> {
2085 Self::new_with_world_size(1).await
2086 }
2087
2088 async fn new_with_world_size(world_size: usize) -> Result<Self> {
2089 test_util::test_setup()?;
2090
2091 let proc = Proc::local();
2092 let (_, controller_actor, controller_rx) =
2093 proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2094 let (client, _handle) = proc.instance("client")?;
2095 let (supervision_tx, supervision_rx) = client.open_port();
2096 proc.set_supervision_coordinator(supervision_tx)?;
2097 let stream_actor = proc
2098 .spawn::<StreamActor>(
2099 "stream",
2100 StreamParams {
2101 world_size,
2102 rank: 0,
2103 creation_mode: StreamCreationMode::UseDefaultStream,
2104 id: 0.into(),
2105 device: Some(CudaDevice::new(0.into())),
2106 controller_actor: controller_actor.clone(),
2107 respond_with_python_message: false,
2108 },
2109 )
2110 .await?;
2111
2112 Ok(Self {
2113 proc,
2114 stream_actor,
2115 client,
2116 supervision_rx,
2117 controller_rx,
2118 controller_actor,
2119 next_ref: 0.into(),
2120 })
2121 }
2122
2123 fn next_ref(&mut self) -> Ref {
2124 let ref_ = self.next_ref;
2125 self.next_ref = Ref {
2126 id: self.next_ref.id + 1,
2127 };
2128 ref_
2129 }
2130
2131 async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2132 let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2133 self.stream_actor
2134 .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2135 .await
2136 }
2137
2138 async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2139 let actual = self
2140 .stream_actor
2141 .get_tensor_ref_unit_tests_only(&self.client, reference)
2142 .await
2143 .unwrap()
2144 .unwrap()
2145 .unwrap();
2146
2147 let result = allclose(
2148 &factory_float_tensor(data, "cpu".try_into().unwrap()),
2149 &actual.borrow(),
2150 )
2151 .unwrap();
2152 result
2154 }
2155
2156 async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2157 let result_error = self
2158 .stream_actor
2159 .get_tensor_ref_unit_tests_only(&self.client, reference)
2160 .await
2161 .unwrap()
2162 .unwrap()
2163 .unwrap_err();
2164
2165 assert!(Arc::ptr_eq(&result_error, &error));
2166 }
2167 }
2168
2169 async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2170 loop {
2171 let status = proc
2172 .ledger_snapshot()
2173 .roots
2174 .get(actor_id)
2175 .unwrap()
2176 .status
2177 .clone();
2178 if let ActorStatus::Failed(msg) = status {
2179 assert!(msg.to_string().contains(&expected_msg));
2180 break;
2181 } else {
2182 tokio::task::yield_now().await;
2183 }
2184 }
2185 }
2186
2187 async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2188 for ref_ in refs {
2189 assert!(
2190 test_setup
2191 .stream_actor
2192 .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2193 .await
2194 .unwrap()
2195 .is_none()
2196 );
2197 }
2198 }
2199
2200 async fn fetch_result(
2201 cx: &impl context::Actor,
2202 stream_actor: ActorHandle<StreamActor>,
2203 seq: Seq,
2204 reference: Ref,
2205 ) {
2206 let ref_to_send = Python::with_gil(|py| {
2207 PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2208 });
2209
2210 stream_actor
2211 .send_value(
2212 cx,
2213 seq,
2214 stream_actor.actor_id().clone(),
2215 Vec::new(),
2216 None,
2217 vec![WireValue::PyObject(ref_to_send)],
2218 HashMap::new(),
2219 HashMap::new(),
2220 )
2221 .await
2222 .unwrap()
2223 }
2224
2225 async fn check_fetch_result_error(
2226 cx: &impl context::Actor,
2227 stream_actor: ActorHandle<StreamActor>,
2228 seq: Seq,
2229 reference: Ref,
2230 controller_rx: &mut PortReceiver<ControllerMessage>,
2231 expected_backtrace: &str,
2232 ) {
2233 fetch_result(cx, stream_actor, seq, reference).await;
2234
2235 let controller_msg = controller_rx.recv().await.unwrap();
2236 match controller_msg {
2237 ControllerMessage::FetchResult {
2238 seq: actual_seq,
2239 value: Err(err),
2240 } => {
2241 assert_eq!(actual_seq, seq);
2242 assert!(
2243 err.backtrace.contains(expected_backtrace),
2244 "backtrace did not contain {:?}: {:?}",
2245 expected_backtrace,
2246 err.backtrace
2247 );
2248 }
2249 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2250 };
2251 }
2252
2253 async fn check_fetch_result_value(
2254 cx: &impl context::Actor,
2255 stream_actor: ActorHandle<StreamActor>,
2256 seq: Seq,
2257 reference: Ref,
2258 controller_rx: &mut PortReceiver<ControllerMessage>,
2259 ) {
2260 fetch_result(cx, stream_actor, seq, reference).await;
2261
2262 let controller_msg = controller_rx.recv().await.unwrap();
2263 match controller_msg {
2264 ControllerMessage::FetchResult {
2265 value: Ok(_),
2266 seq: actual_seq,
2267 } => assert_eq!(seq, actual_seq),
2268 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2269 };
2270 }
2271
2272 #[async_timed_test(timeout_secs = 60)]
2273 async fn test_define_recording_other_recording_active() -> Result<()> {
2274 let test_setup = TestSetup::new().await?;
2275 test_setup
2276 .stream_actor
2277 .define_recording(&test_setup.client, 0.into())
2278 .await?;
2279 test_setup
2280 .stream_actor
2281 .define_recording(&test_setup.client, 1.into())
2282 .await?;
2283 assert_actor_failed_with_msg(
2284 &test_setup.proc,
2285 test_setup.stream_actor.actor_id(),
2286 "different recording already active".into(),
2287 )
2288 .await;
2289 Ok(())
2290 }
2291
2292 #[async_timed_test(timeout_secs = 60)]
2293 async fn test_define_recording_already_defined() -> Result<()> {
2294 let test_setup = TestSetup::new().await?;
2295 test_setup
2296 .stream_actor
2297 .define_recording(&test_setup.client, 0.into())
2298 .await?;
2299 test_setup
2300 .stream_actor
2301 .finalize_recording(&test_setup.client, 0.into())
2302 .await?;
2303 test_setup
2304 .stream_actor
2305 .define_recording(&test_setup.client, 0.into())
2306 .await?;
2307 assert_actor_failed_with_msg(
2308 &test_setup.proc,
2309 test_setup.stream_actor.actor_id(),
2310 "already defined".into(),
2311 )
2312 .await;
2313 Ok(())
2314 }
2315
2316 #[async_timed_test(timeout_secs = 60)]
2317 async fn test_finalize_recording_other_recording_active() -> Result<()> {
2318 let test_setup = TestSetup::new().await?;
2319 test_setup
2320 .stream_actor
2321 .define_recording(&test_setup.client, 0.into())
2322 .await?;
2323 test_setup
2324 .stream_actor
2325 .finalize_recording(&test_setup.client, 1.into())
2326 .await?;
2327 assert_actor_failed_with_msg(
2328 &test_setup.proc,
2329 test_setup.stream_actor.actor_id(),
2330 "cannot finalize recording that isn't active".into(),
2331 )
2332 .await;
2333 Ok(())
2334 }
2335
2336 #[async_timed_test(timeout_secs = 60)]
2337 async fn test_recording_formal_outside_recording() -> Result<()> {
2338 let test_setup = TestSetup::new().await?;
2339 test_setup
2340 .stream_actor
2341 .recording_formal(&test_setup.client, 0.into(), 0)
2342 .await?;
2343 assert_actor_failed_with_msg(
2344 &test_setup.proc,
2345 test_setup.stream_actor.actor_id(),
2346 "recording_formal called outside of recording".into(),
2347 )
2348 .await;
2349 Ok(())
2350 }
2351
2352 #[async_timed_test(timeout_secs = 60)]
2353 async fn test_recording_result_outside_recording() -> Result<()> {
2354 let test_setup = TestSetup::new().await?;
2355 test_setup
2356 .stream_actor
2357 .recording_result(&test_setup.client, 0.into(), 0)
2358 .await?;
2359 assert_actor_failed_with_msg(
2360 &test_setup.proc,
2361 test_setup.stream_actor.actor_id(),
2362 "recording_result called outside of recording".into(),
2363 )
2364 .await;
2365 Ok(())
2366 }
2367
2368 #[async_timed_test(timeout_secs = 60)]
2369 async fn test_call_recording_other_recording_active() -> Result<()> {
2370 let test_setup = TestSetup::new().await?;
2371 test_setup
2372 .stream_actor
2373 .define_recording(&test_setup.client, 0.into())
2374 .await?;
2375 test_setup
2376 .stream_actor
2377 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2378 .await?;
2379 assert_actor_failed_with_msg(
2380 &test_setup.proc,
2381 test_setup.stream_actor.actor_id(),
2382 "cannot call recording while another recording is active".into(),
2383 )
2384 .await;
2385 Ok(())
2386 }
2387
2388 #[async_timed_test(timeout_secs = 60)]
2389 async fn test_call_recording_not_found() -> Result<()> {
2390 let test_setup = TestSetup::new().await?;
2391 test_setup
2392 .stream_actor
2393 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2394 .await?;
2395 assert_actor_failed_with_msg(
2396 &test_setup.proc,
2397 test_setup.stream_actor.actor_id(),
2398 "not found".into(),
2399 )
2400 .await;
2401 Ok(())
2402 }
2403
2404 #[async_timed_test(timeout_secs = 60)]
2405 async fn test_recording_formal_too_few_arguments() -> Result<()> {
2406 let test_setup = TestSetup::new().await?;
2407
2408 test_setup
2409 .stream_actor
2410 .define_recording(&test_setup.client, 0.into())
2411 .await?;
2412
2413 test_setup
2414 .stream_actor
2415 .recording_formal(&test_setup.client, 1.into(), 0)
2416 .await?;
2417
2418 test_setup
2419 .stream_actor
2420 .finalize_recording(&test_setup.client, 0.into())
2421 .await?;
2422
2423 test_setup
2424 .stream_actor
2425 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2426 .await?;
2427
2428 assert_actor_failed_with_msg(
2429 &test_setup.proc,
2430 test_setup.stream_actor.actor_id(),
2431 "recording_formal called with too few arguments".into(),
2432 )
2433 .await;
2434 Ok(())
2435 }
2436
2437 #[async_timed_test(timeout_secs = 60)]
2438 async fn test_recording_result_too_few_results() -> Result<()> {
2439 let test_setup = TestSetup::new().await?;
2440
2441 test_setup
2442 .stream_actor
2443 .define_recording(&test_setup.client, 0.into())
2444 .await?;
2445
2446 test_setup
2447 .stream_actor
2448 .recording_result(&test_setup.client, 1.into(), 0)
2449 .await?;
2450
2451 test_setup
2452 .stream_actor
2453 .finalize_recording(&test_setup.client, 0.into())
2454 .await?;
2455
2456 test_setup
2457 .stream_actor
2458 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2459 .await?;
2460
2461 assert_actor_failed_with_msg(
2462 &test_setup.proc,
2463 test_setup.stream_actor.actor_id(),
2464 "recording_result called with too few results".into(),
2465 )
2466 .await;
2467 Ok(())
2468 }
2469
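/// Records a recording whose formals and results deliberately cross: formal 0 binds
/// actual index 1 and is returned as result 0, while formal 1 binds actual index 0
/// and is returned as result 1. Calling the recording should therefore swap the two
/// inputs, and the formal refs should no longer exist afterwards.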
2470 #[async_timed_test(timeout_secs = 60)]
2471 async fn test_basic_call_recording() -> Result<()> {
2472 let mut test_setup = TestSetup::new().await?;
2473
2474 test_setup
2478 .stream_actor
2479 .define_recording(&test_setup.client, 0.into())
2480 .await?;
2481
2482 let formal0_ref = 1.into();
2483 let formal0_index = 1;
2484 test_setup
2485 .stream_actor
2486 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2487 .await?;
2488
2489 let formal1_ref = 2.into();
2490 let formal1_index = 0;
2491 test_setup
2492 .stream_actor
2493 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2494 .await?;
2495
2496 let result0_ref = formal0_ref;
2497 let result0_index = 0;
2498 test_setup
2499 .stream_actor
2500 .recording_result(&test_setup.client, result0_ref, result0_index)
2501 .await?;
2502
2503 let result1_ref = formal1_ref;
2504 let result1_index = 1;
2505 test_setup
2506 .stream_actor
2507 .recording_result(&test_setup.client, result1_ref, result1_index)
2508 .await?;
2509
2510 test_setup
2511 .stream_actor
2512 .finalize_recording(&test_setup.client, 0.into())
2513 .await?;
2514
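// Set up concrete inputs and call the recording; given the formal/result index
// mapping above, the outputs should be the inputs swapped.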
2515 let actual0_ref = 3.into();
2516 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2517
2518 let actual1_ref = 4.into();
2519 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2520
2521 let actual_result0_ref = 5.into();
2524 let actual_result1_ref = 6.into();
2525 test_setup
2526 .stream_actor
2527 .call_recording(
2528 &test_setup.client,
2529 0.into(),
2530 0.into(),
2531 vec![actual_result0_ref, actual_result1_ref],
2532 vec![actual0_ref, actual1_ref],
2533 )
2534 .await?;
2535
2536 assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2538 assert!(
2539 test_setup
2540 .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2541 .await
2542 );
2543
2544 assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2547 Ok(())
2548 }
2549
2550 #[async_timed_test(timeout_secs = 60)]
2551 async fn test_request_status_in_recording() -> Result<()> {
2552 let test_setup = TestSetup::new().await?;
2553 test_setup
2554 .stream_actor
2555 .define_recording(&test_setup.client, 0.into())
2556 .await?;
2557 test_setup
2558 .stream_actor
2559 .request_status(&test_setup.client)
2560 .await
2561 .expect_err("request_status should have failed");
2562 assert_actor_failed_with_msg(
2563 &test_setup.proc,
2564 test_setup.stream_actor.actor_id(),
2565 "request_status not allowed in recording".into(),
2566 )
2567 .await;
2568 Ok(())
2569 }
2570
2571 #[async_timed_test(timeout_secs = 60)]
2572 async fn test_init_comm_in_recording() -> Result<()> {
2573 let test_setup = TestSetup::new().await?;
2574 test_setup
2575 .stream_actor
2576 .define_recording(&test_setup.client, 0.into())
2577 .await?;
2578
2579 let dummy_comm = test_setup
2580 .proc
2581 .spawn::<NcclCommActor>(
2582 "comm",
2583 CommParams::New {
2584 device: CudaDevice::new(0.into()),
2585 unique_id: UniqueId::new()?,
2586 world_size: 1,
2587 rank: 0,
2588 },
2589 )
2590 .await?;
2591
2592 test_setup
2593 .stream_actor
2594 .init_comm(&test_setup.client, dummy_comm)
2595 .await?;
2596 assert_actor_failed_with_msg(
2597 &test_setup.proc,
2598 test_setup.stream_actor.actor_id(),
2599 "init_comm not allowed in recording".into(),
2600 )
2601 .await;
2602 Ok(())
2603 }
2604
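/// Records a recording that adds two formals, mutates a captured tensor in place,
/// and adds the captured tensor to the first sum. Replaying it with valid inputs
/// succeeds; replaying it with a mismatched input shape should fail the recording,
/// poison the captured tensor and the result with a "torch operator error", and
/// report the failure to the controller.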
2605 #[async_timed_test(timeout_secs = 60)]
2606 async fn test_call_function_in_recording() -> Result<()> {
2607 let mut test_setup = TestSetup::new().await?;
2608
2609 test_setup
2616 .stream_actor
2617 .define_recording(&test_setup.client, 0.into())
2618 .await?;
2619
2620 let formal0_ref = test_setup.next_ref();
2621 let formal0_index = 0;
2622 test_setup
2623 .stream_actor
2624 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2625 .await?;
2626
2627 let formal1_ref = test_setup.next_ref();
2628 let formal1_index = 1;
2629 test_setup
2630 .stream_actor
2631 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2632 .await?;
2633
2634 let captured_ref = test_setup.next_ref();
2635 let result_captured_ref = test_setup.next_ref();
2636 let add_one_function =
2637 ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2638 let add_tensors_function =
2639 ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2640
2641 let add_result_ref_0 = test_setup.next_ref();
2642 test_setup
2643 .stream_actor
2644 .call_function(
2645 &test_setup.client,
2646 CallFunctionParams {
2647 seq: 100.into(),
2648 function: add_tensors_function.clone(),
2649 args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2650 kwargs: HashMap::new(),
2651 results: vec![Some(add_result_ref_0)],
2652 mutates: vec![],
2653 stream: 0.into(),
2654 remote_process_groups: Vec::new(),
2655 },
2656 HashMap::new(),
2657 HashMap::new(),
2658 )
2659 .await?;
2660
2661 test_setup
2662 .stream_actor
2663 .call_function(
2664 &test_setup.client,
2665 CallFunctionParams {
2666 seq: 101.into(),
2667 function: add_one_function,
2668 args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2669 kwargs: HashMap::new(),
2670 results: vec![Some(result_captured_ref)],
2671 mutates: vec![captured_ref],
2672 stream: 0.into(),
2673 remote_process_groups: Vec::new(),
2674 },
2675 HashMap::new(),
2676 HashMap::new(),
2677 )
2678 .await?;
2679
2680 let add_result_ref_1 = test_setup.next_ref();
2681 test_setup
2682 .stream_actor
2683 .call_function(
2684 &test_setup.client,
2685 CallFunctionParams {
2686 seq: 102.into(),
2687 function: add_tensors_function,
2688 args: vec![
2689 WireValue::Ref(add_result_ref_0),
2690 WireValue::Ref(captured_ref),
2691 ],
2692 kwargs: HashMap::new(),
2693 results: vec![Some(add_result_ref_1)],
2694 mutates: vec![],
2695 stream: 0.into(),
2696 remote_process_groups: Vec::new(),
2697 },
2698 HashMap::new(),
2699 HashMap::new(),
2700 )
2701 .await?;
2702
2703 test_setup
2704 .stream_actor
2705 .recording_result(&test_setup.client, add_result_ref_1, 0)
2706 .await?;
2707
2708 test_setup
2709 .stream_actor
2710 .delete_refs(
2711 &test_setup.client,
2712 vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2713 )
2714 .await?;
2715
2716 test_setup
2717 .stream_actor
2718 .finalize_recording(&test_setup.client, 0.into())
2719 .await?;
2720
2721 let actual0_ref = test_setup.next_ref();
2722 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2723
2724 let actual1_ref = test_setup.next_ref();
2725 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2726
2727 test_setup
2728 .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2729 .await?;
2730
2731 let actual_result_ref = test_setup.next_ref();
2732 test_setup
2733 .stream_actor
2734 .call_recording(
2735 &test_setup.client,
2736 0.into(),
2737 0.into(),
2738 vec![actual_result_ref],
2739 vec![actual0_ref, actual1_ref],
2740 )
2741 .await?;
2742
2743 assert!(
2744 test_setup
2745 .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2746 .await
2747 );
2748
2749 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2751
2752 let actual_result_ref = test_setup.next_ref();
2753 test_setup
2754 .stream_actor
2755 .call_recording(
2756 &test_setup.client,
2757 1.into(),
2758 0.into(),
2759 vec![actual_result_ref],
2760 vec![actual0_ref, actual1_ref],
2761 )
2762 .await?;
2763
2764 for ref_ in [actual0_ref, actual1_ref] {
2766 let _ = test_setup
2767 .stream_actor
2768 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2769 .await?
2770 .unwrap()
2771 .unwrap();
2772 }
2773
2774 for ref_ in [captured_ref, actual_result_ref] {
2775 let result_error = test_setup
2776 .stream_actor
2777 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2778 .await?
2779 .unwrap()
2780 .unwrap_err();
2781 let error_str = result_error.to_string();
2783 assert!(
2784 error_str.contains("torch operator error"),
2785 "Error should contain 'torch operator failed': {}",
2786 error_str
2787 );
2788 }
2789
2790 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2791 match controller_msg {
2792 ControllerMessage::RemoteFunctionFailed { seq, error } => {
2793 assert_eq!(seq, 1.into());
2794 assert!(
2795 error.backtrace.contains("torch operator error"),
2796 "Unexpected WorkerError: {:?}",
2797 error
2798 );
2799 }
2800 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2801 };
2802
2803 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2805
2806 let actual_result_ref = test_setup.next_ref();
2810 test_setup
2811 .stream_actor
2812 .call_recording(
2813 &test_setup.client,
2814 2.into(),
2815 0.into(),
2816 vec![actual_result_ref],
2817 vec![actual0_ref, actual1_ref],
2818 )
2819 .await?;
2820
2821 for ref_ in [actual0_ref, actual1_ref] {
2823 let _ = test_setup
2824 .stream_actor
2825 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2826 .await?
2827 .unwrap()
2828 .unwrap();
2829 }
2830
2831 for ref_ in [captured_ref, actual_result_ref] {
2832 let result_error = test_setup
2833 .stream_actor
2834 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2835 .await?
2836 .unwrap()
2837 .unwrap_err();
2838 let error_str = result_error.to_string();
2840 assert!(
2841 error_str.contains("torch operator error"),
2842 "Error should contain input error: {}",
2843 error_str
2844 );
2845 }
2846
2847 check_fetch_result_error(
2851 &test_setup.client,
2852 test_setup.stream_actor.clone(),
2853 3.into(),
2854 captured_ref,
2855 &mut test_setup.controller_rx,
2856 "torch operator error",
2857 )
2858 .await;
2859
2860 Ok(())
2861 }
2862
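/// Creating the same borrow id twice inside a single recording should fail the
/// stream actor with "duplicate borrow create in recording".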
2863 #[async_timed_test(timeout_secs = 60)]
2864 async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2865 let mut test_setup = TestSetup::new().await?;
2866 test_setup
2867 .stream_actor
2868 .define_recording(&test_setup.client, 0.into())
2869 .await?;
2870
2871 let borrow_id = 1;
2872 let tensor_ref = test_setup.next_ref();
2873 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2874
2875 test_setup
2876 .stream_actor
2877 .borrow_create(
2878 &test_setup.client,
2879 borrow_id,
2880 tensor_ref,
2881 first_use_sender.clone(),
2882 )
2883 .await?;
2884
2885 test_setup
2886 .stream_actor
2887 .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2888 .await?;
2889
2890 assert_actor_failed_with_msg(
2891 &test_setup.proc,
2892 test_setup.stream_actor.actor_id(),
2893 "duplicate borrow create in recording".into(),
2894 )
2895 .await;
2896
2897 Ok(())
2898 }
2899
2900 #[async_timed_test(timeout_secs = 60)]
2901 async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2902 let test_setup = TestSetup::new().await?;
2903 test_setup
2904 .stream_actor
2905 .define_recording(&test_setup.client, 0.into())
2906 .await?;
2907
2908 let borrow_id = 1;
2909 let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2910
2911 test_setup
2912 .stream_actor
2913 .borrow_drop(
2914 &test_setup.client,
2915 borrow_id,
2916 Arc::new(Mutex::new(last_use_receiver)),
2917 )
2918 .await?;
2919
2920 assert_actor_failed_with_msg(
2921 &test_setup.proc,
2922 test_setup.stream_actor.actor_id(),
2923 "borrow drop for borrow not defined in recording".into(),
2924 )
2925 .await;
2926
2927 Ok(())
2928 }
2929
2930 #[async_timed_test(timeout_secs = 60)]
2931 async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2932 let mut test_setup = TestSetup::new().await?;
2933 test_setup
2934 .stream_actor
2935 .define_recording(&test_setup.client, 0.into())
2936 .await?;
2937
2938 let borrow_id = 1;
2939 let tensor_ref = test_setup.next_ref();
2940 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2941
2942 test_setup
2943 .stream_actor
2944 .borrow_create(
2945 &test_setup.client,
2946 borrow_id,
2947 tensor_ref,
2948 first_use_sender.clone(),
2949 )
2950 .await?;
2951
2952 test_setup
2954 .stream_actor
2955 .finalize_recording(&test_setup.client, 0.into())
2956 .await?;
2957
2958 assert_actor_failed_with_msg(
2959 &test_setup.proc,
2960 test_setup.stream_actor.actor_id(),
2961 "all borrows created within recording must be dropped within recording".into(),
2962 )
2963 .await;
2964
2965 Ok(())
2966 }
2967
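/// Records a lender recording that lends a formal tensor to a second stream, and a
/// borrower recording that adds the borrowed tensor to a tensor captured on the
/// borrower stream. Verifies the happy path, a failure inside the borrower
/// (mismatched captured tensor shape) that is reported against the borrower stream,
/// and propagation of an error on the lender's input through the borrow into the
/// borrower's result.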
2968 #[async_timed_test(timeout_secs = 60)]
2969 async fn test_borrow_in_recording() -> Result<()> {
2970 let mut test_setup = TestSetup::new().await?;
2971
2972 let borrower_stream = test_setup
2973 .proc
2974 .spawn::<StreamActor>(
2975 "stream1",
2976 StreamParams {
2977 world_size: 1,
2978 rank: 0,
2979 creation_mode: StreamCreationMode::CreateNewStream,
2980 id: 1.into(),
2981 device: Some(CudaDevice::new(0.into())),
2982 controller_actor: test_setup.controller_actor.clone(),
2983 respond_with_python_message: false,
2984 },
2985 )
2986 .await?;
2987
2988 let lender_stream = test_setup.stream_actor.clone();
2989
2990 let borrow_id = 1;
2991 let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
2992 let (last_use_sender, last_use_receiver) = test_setup.client.open_port();
2993
2994 lender_stream
2996 .define_recording(&test_setup.client, 0.into())
2997 .await?;
2998
2999 let formal_ref = test_setup.next_ref();
3000 lender_stream
3001 .recording_formal(&test_setup.client, formal_ref, 0)
3002 .await?;
3003
3004 lender_stream
3005 .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
3006 .await?;
3007
3008 lender_stream
3009 .borrow_drop(
3010 &test_setup.client,
3011 borrow_id,
3012 Arc::new(Mutex::new(last_use_receiver)),
3013 )
3014 .await?;
3015
3016 lender_stream
3017 .finalize_recording(&test_setup.client, 0.into())
3018 .await?;
3019
3020 let borrower_tensor_ref = test_setup.next_ref();
3021 let borrower_tensor = TensorCell::new(factory_float_tensor(
3022 &[1.0, 2.0, 3.0],
3023 "cuda".try_into().unwrap(),
3024 ));
3025
3026 borrower_stream
3027 .set_tensor_ref_unit_tests_only(
3028 &test_setup.client,
3029 borrower_tensor_ref,
3030 Ok(borrower_tensor.clone()),
3031 )
3032 .await?;
3033
3034 borrower_stream
3036 .define_recording(&test_setup.client, 0.into())
3037 .await?;
3038
3039 let borrowed_ref = test_setup.next_ref();
3040
3041 borrower_stream
3042 .borrow_first_use(
3043 &test_setup.client,
3044 borrow_id,
3045 borrowed_ref,
3046 Arc::new(Mutex::new(first_use_receiver)),
3047 )
3048 .await?;
3049
3050 let result_ref = test_setup.next_ref();
3051 borrower_stream
3052 .call_function(
3053 &test_setup.client,
3054 CallFunctionParams {
3055 seq: 100.into(),
3056 function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
3057 args: vec![
3058 WireValue::Ref(borrowed_ref),
3059 WireValue::Ref(borrower_tensor_ref),
3060 ],
3061 kwargs: HashMap::new(),
3062 results: vec![Some(result_ref)],
3063 mutates: vec![],
3064 stream: 1.into(),
3065 remote_process_groups: Vec::new(),
3066 },
3067 HashMap::new(),
3068 HashMap::new(),
3069 )
3070 .await?;
3071
3072 borrower_stream
3073 .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
3074 .await?;
3075
3076 borrower_stream
3077 .recording_result(&test_setup.client, result_ref, 0)
3078 .await?;
3079
3080 borrower_stream
3081 .finalize_recording(&test_setup.client, 0.into())
3082 .await?;
3083
3084 let input_tensor_ref = test_setup.next_ref();
3086 test_setup
3087 .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
3088 .await?;
3089
3090 let result_tensor_ref = test_setup.next_ref();
3091
3092 let lender_future = lender_stream.call_recording(
3093 &test_setup.client,
3094 0.into(),
3095 0.into(),
3096 vec![],
3097 vec![input_tensor_ref],
3098 );
3099
3100 let borrower_future = borrower_stream.call_recording(
3101 &test_setup.client,
3102 0.into(),
3103 0.into(),
3104 vec![result_tensor_ref],
3105 vec![],
3106 );
3107
3108 tokio::try_join!(lender_future, borrower_future)?;
3109
3110 let result_tensor = borrower_stream
3111 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3112 .await?
3113 .unwrap()
3114 .unwrap();
3115
3116 let expected_tensor = TensorCell::new(factory_float_tensor(
3117 &[5.0, 7.0, 9.0],
3118 "cpu".try_into().unwrap(),
3119 ));
3120 assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());
3121
3122 let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
3124 &[1.0, 2.0],
3125 "cuda".try_into().unwrap(),
3126 ));
3127 borrower_stream
3128 .set_tensor_ref_unit_tests_only(
3129 &test_setup.client,
3130 borrower_tensor_ref,
3131 Ok(invalid_borrower_tensor.clone()),
3132 )
3133 .await?;
3134
3135 let lender_future = lender_stream.call_recording(
3137 &test_setup.client,
3138 1.into(),
3139 0.into(),
3140 vec![],
3141 vec![input_tensor_ref],
3142 );
3143
3144 let borrower_future = borrower_stream.call_recording(
3145 &test_setup.client,
3146 1.into(),
3147 0.into(),
3148 vec![result_tensor_ref],
3149 vec![],
3150 );
3151
3152 tokio::try_join!(lender_future, borrower_future)?;
3153
3154 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
3156 match controller_msg {
3157 ControllerMessage::RemoteFunctionFailed { seq, error } => {
3158 assert_eq!(seq, 1.into());
3159 assert!(
3160 error.backtrace.contains("recording failed"),
3161 "Unexpected WorkerError: {:?}",
3162 error
3163 );
3164 assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
3165 }
3166 _ => panic!("Unexpected controller message: {:?}", controller_msg),
3167 };
3168
3169 check_fetch_result_value(
3171 &test_setup.client,
3172 lender_stream.clone(),
3173 2.into(),
3174 input_tensor_ref,
3175 &mut test_setup.controller_rx,
3176 )
3177 .await;
3178
3179 let input_error = fake_seq_error(anyhow!("input error"));
3181 lender_stream
3182 .set_tensor_ref_unit_tests_only(
3183 &test_setup.client,
3184 input_tensor_ref,
3185 Err(input_error.clone()),
3186 )
3187 .await?;
3188
3189 let lender_future = lender_stream.call_recording(
3190 &test_setup.client,
3191 3.into(),
3192 0.into(),
3193 vec![],
3194 vec![input_tensor_ref],
3195 );
3196
3197 let borrower_future = borrower_stream.call_recording(
3198 &test_setup.client,
3199 3.into(),
3200 0.into(),
3201 vec![result_tensor_ref],
3202 vec![],
3203 );
3204
3205 tokio::try_join!(lender_future, borrower_future)?;
3206
3207 let result_error = borrower_stream
3209 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3210 .await?
3211 .unwrap()
3212 .unwrap_err();
3213
3214 let error_str = result_error.to_string();
3216 assert!(
3217 error_str.contains("input error"),
3218 "Error should contain input error: {}",
3219 error_str
3220 );
3221
3222 let input_error_str = input_error.to_string();
3225 assert!(
3226 error_str.contains(&input_error_str),
3227 "Error should contain the original error: {}",
3228 error_str
3229 );
3230
3231 check_fetch_result_error(
3233 &test_setup.client,
3234 lender_stream,
3235 4.into(),
3236 input_tensor_ref,
3237 &mut test_setup.controller_rx,
3238 "input error",
3239 )
3240 .await;
3241
3242 check_fetch_result_error(
3244 &test_setup.client,
3245 borrower_stream,
3246 5.into(),
3247 result_tensor_ref,
3248 &mut test_setup.controller_rx,
3249 "input error",
3250 )
3251 .await;
3252
3253 Ok(())
3254 }
3255
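/// Records a recording containing several single-rank NCCL sum reduces over its
/// formals and replays it. Then, for each input in turn, injects an error and
/// verifies the replay poisons the refs the recording writes to with a dependent
/// error, while inputs the recording does not write to keep their values.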
3256 #[async_timed_test(timeout_secs = 60)]
3257 async fn test_reduce_in_recording() -> Result<()> {
3258 let mut test_setup = TestSetup::new().await?;
3259 let recording_ref = test_setup.next_ref();
3260
3261 let comm = Arc::new(
3262 test_setup
3263 .proc
3264 .spawn::<NcclCommActor>(
3265 "comm",
3266 CommParams::New {
3267 device: CudaDevice::new(0.into()),
3268 unique_id: UniqueId::new()?,
3269 world_size: 1,
3270 rank: 0,
3271 },
3272 )
3273 .await?,
3274 );
3275
3276 let factory = Factory {
3277 size: vec![3],
3278 dtype: torch_sys::ScalarType::Float,
3279 layout: torch_sys::Layout::Strided,
3280 device: "cuda".try_into().unwrap(),
3281 };
3282
3283 let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3284
3285 test_setup
3286 .stream_actor
3287 .define_recording(&test_setup.client, recording_ref)
3288 .await?;
3289
3290 let formal_tensor_ref_0 = test_setup.next_ref();
3291 let formal_tensor_ref_1 = test_setup.next_ref();
3292 let formal_tensor_ref_2 = test_setup.next_ref();
3293
3294 test_setup
3295 .stream_actor
3296 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3297 .await?;
3298 test_setup
3299 .stream_actor
3300 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3301 .await?;
3302 test_setup
3303 .stream_actor
3304 .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3305 .await?;
3306
3307 let intermediate_tensor_ref_0 = test_setup.next_ref();
3308
3309 test_setup
3311 .stream_actor
3312 .reduce(
3313 &test_setup.client,
3314 comm.clone(),
3315 1,
3316 intermediate_tensor_ref_0,
3317 formal_tensor_ref_0,
3318 factory.clone(),
3319 reduction.clone(),
3320 false,
3321 true,
3322 None,
3323 )
3324 .await?;
3325
3326 let intermediate_tensor_ref_1 = test_setup.next_ref();
3328 test_setup
3329 .stream_actor
3330 .reduce(
3331 &test_setup.client,
3332 comm.clone(),
3333 1,
3334 intermediate_tensor_ref_1,
3335 formal_tensor_ref_1,
3336 factory.clone(),
3337 reduction.clone(),
3338 false,
3339 false,
3340 None,
3341 )
3342 .await?;
3343
3344 let intermediate_tensor_ref_2 = test_setup.next_ref();
3345
3346 test_setup
3348 .stream_actor
3349 .reduce(
3350 &test_setup.client,
3351 comm.clone(),
3352 1,
3353 intermediate_tensor_ref_2,
3354 intermediate_tensor_ref_1,
3355 factory.clone(),
3356 reduction.clone(),
3357 false,
3358 false,
3359 Some(formal_tensor_ref_2),
3360 )
3361 .await?;
3362
3363 test_setup
3364 .stream_actor
3365 .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3366 .await?;
3367
3368 test_setup
3369 .stream_actor
3370 .finalize_recording(&test_setup.client, recording_ref)
3371 .await?;
3372
3373 let input_tensor_ref_0 = test_setup.next_ref();
3374 let input_tensor_ref_1 = test_setup.next_ref();
3375 let input_tensor_ref_2 = test_setup.next_ref();
3376
3377 test_setup
3378 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3379 .await?;
3380
3381 test_setup
3382 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3383 .await?;
3384
3385 test_setup
3386 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3387 .await?;
3388
3389 let output_ref = test_setup.next_ref();
3390
3391 test_setup
3392 .stream_actor
3393 .call_recording(
3394 &test_setup.client,
3395 0.into(),
3396 recording_ref,
3397 vec![output_ref],
3398 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3399 )
3400 .await?;
3401
3402 assert!(
3404 test_setup
3405 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3406 .await
3407 );
3408 for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3410 assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3411 }
3412
3413 let input_error = fake_seq_error(anyhow!("input error"));
3415 test_setup
3416 .stream_actor
3417 .set_tensor_ref_unit_tests_only(
3418 &test_setup.client,
3419 input_tensor_ref_0,
3420 Err(input_error.clone()),
3421 )
3422 .await?;
3423
3424 test_setup
3425 .stream_actor
3426 .call_recording(
3427 &test_setup.client,
3428 1.into(),
3429 recording_ref,
3430 vec![output_ref],
3431 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3432 )
3433 .await?;
3434
3435 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3437 test_setup
3438 .validate_dependent_error(ref_, input_error.clone())
3439 .await;
3440 }
3441
3442 assert!(
3444 test_setup
3445 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3446 .await
3447 );
3448
3449 check_fetch_result_value(
3451 &test_setup.client,
3452 test_setup.stream_actor.clone(),
3453 2.into(),
3454 input_tensor_ref_1,
3455 &mut test_setup.controller_rx,
3456 )
3457 .await;
3458
3459 test_setup
3461 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3462 .await?;
3463 test_setup
3464 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3465 .await?;
3466
3467 test_setup
3469 .stream_actor
3470 .set_tensor_ref_unit_tests_only(
3471 &test_setup.client,
3472 input_tensor_ref_1,
3473 Err(input_error.clone()),
3474 )
3475 .await?;
3476
3477 test_setup
3478 .stream_actor
3479 .call_recording(
3480 &test_setup.client,
3481 3.into(),
3482 recording_ref,
3483 vec![output_ref],
3484 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3485 )
3486 .await?;
3487
3488 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3491 test_setup
3492 .validate_dependent_error(ref_, input_error.clone())
3493 .await;
3494 }
3495
3496 check_fetch_result_error(
3498 &test_setup.client,
3499 test_setup.stream_actor.clone(),
3500 4.into(),
3501 input_tensor_ref_1,
3502 &mut test_setup.controller_rx,
3503 "input error",
3504 )
3505 .await;
3506
3507 test_setup
3509 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3510 .await?;
3511 test_setup
3512 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3513 .await?;
3514
3515 test_setup
3517 .stream_actor
3518 .set_tensor_ref_unit_tests_only(
3519 &test_setup.client,
3520 input_tensor_ref_2,
3521 Err(input_error.clone()),
3522 )
3523 .await?;
3524
3525 test_setup
3526 .stream_actor
3527 .call_recording(
3528 &test_setup.client,
3529 5.into(),
3530 recording_ref,
3531 vec![output_ref],
3532 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3533 )
3534 .await?;
3535
3536 assert!(
3538 test_setup
3539 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3540 .await
3541 );
3542
3543 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3546 test_setup
3547 .validate_dependent_error(ref_, input_error.clone())
3548 .await;
3549 }
3550
3551 check_fetch_result_value(
3553 &test_setup.client,
3554 test_setup.stream_actor.clone(),
3555 6.into(),
3556 input_tensor_ref_1,
3557 &mut test_setup.controller_rx,
3558 )
3559 .await;
3560
3561 Ok(())
3562 }
3563
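/// Records a pair of recordings on two streams (ranks 0 and 1) that exchange
/// tensors via `send_tensor`: rank 0 sends one formal to rank 1 and copies the
/// other to itself. Replays them together and checks the received values, then
/// verifies that an error injected into either send-side input poisons the
/// sender's result with a dependent error while the receiving stream still
/// produces a value.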
3564 #[async_timed_test(timeout_secs = 60)]
3565 async fn test_send_tensor_in_recording() -> Result<()> {
3566 let mut test_setup = TestSetup::new_with_world_size(2).await?;
3567 let recording_ref = test_setup.next_ref();
3568
3569 let unique_id = UniqueId::new()?;
3570 let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3571 "comm0",
3572 CommParams::New {
3573 device: CudaDevice::new(0.into()),
3574 unique_id: unique_id.clone(),
3575 world_size: 2,
3576 rank: 0,
3577 },
3578 );
3579 let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3580 "comm1",
3581 CommParams::New {
3582 device: CudaDevice::new(1.into()),
3583 unique_id,
3584 world_size: 2,
3585 rank: 1,
3586 },
3587 );
3588 let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3589 let comm0 = Arc::new(comm0);
3590 let comm1 = Arc::new(comm1);
3591
3592 let factory = Factory {
3593 size: vec![3],
3594 dtype: torch_sys::ScalarType::Float,
3595 layout: torch_sys::Layout::Strided,
3596 device: "cuda".try_into().unwrap(),
3597 };
3598
3599 let send_stream = test_setup.stream_actor.clone();
3600 let recv_stream = test_setup
3601 .proc
3602 .spawn::<StreamActor>(
3603 "recv_stream",
3604 StreamParams {
3605 world_size: 2,
3606 rank: 1,
3607 creation_mode: StreamCreationMode::CreateNewStream,
3608 id: 1.into(),
3609 device: Some(CudaDevice::new(1.into())),
3610 controller_actor: test_setup.controller_actor.clone(),
3611 respond_with_python_message: false,
3612 },
3613 )
3614 .await?;
3615
3616 send_stream
3617 .define_recording(&test_setup.client, recording_ref)
3618 .await?;
3619 recv_stream
3620 .define_recording(&test_setup.client, recording_ref)
3621 .await?;
3622
3623 let formal_tensor_ref_0 = test_setup.next_ref();
3624 let formal_tensor_ref_1 = test_setup.next_ref();
3625
3626 send_stream
3627 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3628 .await?;
3629 send_stream
3630 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3631 .await?;
3632
3633 let _ref = test_setup.next_ref();
3634 send_stream
3635 .send_tensor(
3636 &test_setup.client,
3637 _ref,
3638 None,
3639 Some(1),
3640 formal_tensor_ref_0,
3641 factory.clone(),
3642 comm0.clone(),
3643 )
3644 .await?;
3645
3646 let result_ref_0 = test_setup.next_ref();
3647 let _ref = test_setup.next_ref();
3648 recv_stream
3649 .send_tensor(
3650 &test_setup.client,
3651 result_ref_0,
3652 Some(0),
3653 None,
3654 _ref,
3655 factory.clone(),
3656 comm1,
3657 )
3658 .await?;
3659
3660 let result_ref_1 = test_setup.next_ref();
3661 send_stream
3662 .send_tensor(
3663 &test_setup.client,
3664 result_ref_1,
3665 Some(0),
3666 Some(0),
3667 formal_tensor_ref_1,
3668 factory.clone(),
3669 comm0,
3670 )
3671 .await?;
3672
3673 send_stream
3674 .recording_result(&test_setup.client, result_ref_1, 0)
3675 .await?;
3676 recv_stream
3677 .recording_result(&test_setup.client, result_ref_0, 0)
3678 .await?;
3679
3680 send_stream
3681 .finalize_recording(&test_setup.client, recording_ref)
3682 .await?;
3683 recv_stream
3684 .finalize_recording(&test_setup.client, recording_ref)
3685 .await?;
3686
3687 let input_tensor_ref_0 = test_setup.next_ref();
3688 let input_tensor_ref_1 = test_setup.next_ref();
3689 test_setup
3690 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3691 .await?;
3692 test_setup
3693 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3694 .await?;
3695
3696 let actual_result_ref_0 = test_setup.next_ref();
3697 let actual_result_ref_1 = test_setup.next_ref();
3698 let send_fut = send_stream.call_recording(
3699 &test_setup.client,
3700 0.into(),
3701 recording_ref,
3702 vec![actual_result_ref_1],
3703 vec![input_tensor_ref_0, input_tensor_ref_1],
3704 );
3705 let recv_fut = recv_stream.call_recording(
3706 &test_setup.client,
3707 0.into(),
3708 recording_ref,
3709 vec![actual_result_ref_0],
3710 vec![],
3711 );
3712 tokio::try_join!(send_fut, recv_fut)?;
3713
3714 assert!(
3715 test_setup
3716 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3717 .await
3718 );
3719 assert!(
3720 test_setup
3721 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3722 .await
3723 );
3724 assert!(
3725 test_setup
3726 .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3727 .await
3728 );
3729
3730 let actual_result_0 = recv_stream
3731 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3732 .await
3733 .unwrap()
3734 .unwrap()
3735 .unwrap();
3736 assert!(allclose(
3737 &actual_result_0.borrow(),
3738 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3739 )?);
3740
3741 check_fetch_result_value(
3743 &test_setup.client,
3744 send_stream.clone(),
3745 1.into(),
3746 actual_result_ref_1,
3747 &mut test_setup.controller_rx,
3748 )
3749 .await;
3750 check_fetch_result_value(
3751 &test_setup.client,
3752 recv_stream.clone(),
3753 2.into(),
3754 actual_result_ref_0,
3755 &mut test_setup.controller_rx,
3756 )
3757 .await;
3758
3759 let input_error = fake_seq_error(anyhow!("input error"));
3760 send_stream
3761 .set_tensor_ref_unit_tests_only(
3762 &test_setup.client,
3763 input_tensor_ref_0,
3764 Err(input_error.clone()),
3765 )
3766 .await?;
3767
3768 let send_fut = send_stream.call_recording(
3769 &test_setup.client,
3770 3.into(),
3771 recording_ref,
3772 vec![actual_result_ref_1],
3773 vec![input_tensor_ref_0, input_tensor_ref_1],
3774 );
3775 let recv_fut = recv_stream.call_recording(
3776 &test_setup.client,
3777 3.into(),
3778 recording_ref,
3779 vec![actual_result_ref_0],
3780 vec![],
3781 );
3782 tokio::try_join!(send_fut, recv_fut)?;
3783
3784 let _ = recv_stream
3786 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3787 .await
3788 .unwrap()
3789 .unwrap()
3790 .unwrap();
3791
3792 test_setup
3793 .validate_dependent_error(actual_result_ref_1, input_error.clone())
3794 .await;
3795
3796 assert!(
3798 test_setup
3799 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3800 .await
3801 );
3802
3803 check_fetch_result_error(
3805 &test_setup.client,
3806 send_stream.clone(),
3807 4.into(),
3808 actual_result_ref_1,
3809 &mut test_setup.controller_rx,
3810 "input error",
3811 )
3812 .await;
3813 check_fetch_result_value(
3814 &test_setup.client,
3815 recv_stream.clone(),
3816 5.into(),
3817 actual_result_ref_0,
3818 &mut test_setup.controller_rx,
3819 )
3820 .await;
3821
3822 test_setup
3823 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3824 .await?;
3825 send_stream
3826 .set_tensor_ref_unit_tests_only(
3827 &test_setup.client,
3828 input_tensor_ref_1,
3829 Err(input_error.clone()),
3830 )
3831 .await?;
3832
3833 let send_fut = send_stream.call_recording(
3834 &test_setup.client,
3835 6.into(),
3836 recording_ref,
3837 vec![actual_result_ref_1],
3838 vec![input_tensor_ref_0, input_tensor_ref_1],
3839 );
3840 let recv_fut = recv_stream.call_recording(
3841 &test_setup.client,
3842 6.into(),
3843 recording_ref,
3844 vec![actual_result_ref_0],
3845 vec![],
3846 );
3847 tokio::try_join!(send_fut, recv_fut)?;
3848
3849 let actual_result_0 = recv_stream
3850 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3851 .await
3852 .unwrap()
3853 .unwrap()
3854 .unwrap();
3855 assert!(allclose(
3856 &actual_result_0.borrow(),
3857 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3858 )?);
3859
3860 assert!(
3861 test_setup
3862 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3863 .await
3864 );
3865
3866 test_setup
3867 .validate_dependent_error(actual_result_ref_1, input_error)
3868 .await;
3869
3870 check_fetch_result_error(
3872 &test_setup.client,
3873 send_stream.clone(),
3874 7.into(),
3875 actual_result_ref_1,
3876 &mut test_setup.controller_rx,
3877 "input error",
3878 )
3879 .await;
3880 check_fetch_result_value(
3881 &test_setup.client,
3882 recv_stream.clone(),
3883 8.into(),
3884 actual_result_ref_0,
3885 &mut test_setup.controller_rx,
3886 )
3887 .await;
3888
3889 Ok(())
3890 }
3891}