1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::Mailbox;
36use hyperactor::mailbox::OncePortHandle;
37use hyperactor::mailbox::PortReceiver;
38use hyperactor::proc::Proc;
39use monarch_hyperactor::actor::PythonMessage;
40use monarch_hyperactor::actor::PythonMessageKind;
41use monarch_hyperactor::buffers::FrozenBuffer;
42use monarch_hyperactor::local_state_broker::BrokerId;
43use monarch_hyperactor::local_state_broker::LocalState;
44use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
45use monarch_messages::controller::ControllerMessageClient;
46use monarch_messages::controller::Seq;
47use monarch_messages::controller::WorkerError;
48use monarch_messages::worker::ActorCallParams;
49use monarch_messages::worker::ActorMethodParams;
50use monarch_messages::worker::CallFunctionError;
51use monarch_messages::worker::CallFunctionParams;
52use monarch_messages::worker::SeqError;
53use monarch_messages::worker::StreamRef;
54use monarch_types::PyTree;
55use monarch_types::SerializablePyErr;
56use monarch_types::TryIntoPyObjectUnsafe;
57use pyo3::prelude::*;
58use pyo3::types::PyTuple;
59use tokio::runtime::Handle;
60use tokio::sync::Mutex;
61use tokio::task::JoinHandle;
62use torch_sys::BorrowType;
63use torch_sys::CudaDevice;
64use torch_sys::MultiBorrow;
65use torch_sys::RValue;
66use torch_sys::TensorCell;
67use torch_sys::deep_clone;
68use torch_sys::factory_empty;
69use torch_sys::factory_zeros;
70use torch_sys_cuda::cuda::Event;
71use torch_sys_cuda::cuda::Stream;
72use tracing_subscriber::fmt::Subscriber;
73
74use crate::ControllerActor;
75use crate::DeviceMesh;
76use crate::Factory;
77use crate::Reduction;
78use crate::Ref;
79use crate::ResolvableFunction;
80use crate::StreamCreationMode;
81use crate::WireValue;
82use crate::comm::CommBackend;
83use crate::comm::CommMessage;
84use crate::comm::CommMessageClient;
85use crate::comm::NcclCommActor;
86use crate::pipe::PipeMessage;
87
88pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
89
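// Thread-local handles populated in `StreamActor::init` so that code running on
// a stream's dedicated worker thread can reach the controller actor, the owning
// proc, and the worker's root actor id.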
90thread_local! {
92 pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
93 pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
94 pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
95}
96
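/// Pickle `result` with `monarch._src.actor.actor_mesh._pickle` and wrap the
/// resulting bytes in a [`PythonMessage`] result tagged with this worker's rank.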
97fn pickle_python_result(
98 py: Python<'_>,
99 result: Bound<'_, PyAny>,
100 worker_actor_id: ActorId,
101) -> Result<PythonMessage, anyhow::Error> {
102 let pickle = py
103 .import("monarch._src.actor.actor_mesh")
104 .unwrap()
105 .getattr("_pickle")
106 .unwrap();
107 let data: FrozenBuffer = pickle
108 .call1((result,))
109 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
110 .extract()
111 .unwrap();
112 Ok(PythonMessage::new_from_buf(
113 PythonMessageKind::Result {
114 rank: Some(worker_actor_id.rank()),
115 },
116 data.inner,
117 ))
118}
119
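/// A captured sequence of [`StreamMessage`]s that can be replayed later by
/// `CallRecording`.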
120#[derive(Debug)]
121struct Recording {
122 messages: Vec<StreamMessage>,
123}
124
125impl Recording {
126 fn new() -> Self {
127 Self {
128 messages: Vec::new(),
129 }
130 }
131}
132
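/// Whether this stream is currently defining a recording (capturing messages and
/// tracking the borrows created inside it) or replaying one.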
133#[derive(Debug, PartialEq)]
134enum RecordingState {
135 Defining {
136 recording: Ref,
137 defined_borrows: HashSet<u64>,
140 },
141 Running,
142}
143
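/// Messages handled by a [`StreamActor`]; the stream executes them in the order
/// received. The `Borrow*` variants hand a tensor between streams through the
/// paired port handle/receiver they carry, along with an optional CUDA event used
/// for cross-stream synchronization. `DefineRecording`, `RecordingFormal`,
/// `RecordingResult` and `FinalizeRecording` capture a reusable message sequence
/// that `CallRecording` later replays against concrete argument and result refs.
///
/// A rough sketch of the recording protocol, as exercised by the tests at the
/// bottom of this file (the ref and seq names are illustrative):
///
/// ```ignore
/// stream.define_recording(&client, recording_ref).await?;
/// stream.recording_formal(&client, formal_ref, 0).await?;
/// // ... captured CallFunction / Borrow* / Reduce messages ...
/// stream.recording_result(&client, result_ref, 0).await?;
/// stream.finalize_recording(&client, recording_ref).await?;
/// stream
///     .call_recording(&client, seq, recording_ref, vec![out_ref], vec![in_ref])
///     .await?;
/// ```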
144#[derive(Handler, HandleClient, Debug, Named)]
147#[named(register = false)]
148pub enum StreamMessage {
149 CallFunction(
150 CallFunctionParams,
151 HashMap<Ref, DeviceMesh>,
152 HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
153 ),
154
155 BorrowCreate {
156 borrow: u64,
158 tensor: Ref,
160 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
163 },
164
165 BorrowFirstUse {
166 borrow: u64,
168 result: Ref,
170 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
173 },
174
175 BorrowLastUse {
176 borrow: u64,
178 result: Ref,
180 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
182 },
183
184 BorrowDrop {
185 borrow: u64,
186 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
188 },
189
190 DeleteRefs(Vec<Ref>),
191
192 RequestStatus(#[reply] OncePortHandle<()>),
193
194 InitComm(ActorHandle<NcclCommActor>),
195
196 Reduce {
197 comm: Arc<ActorHandle<NcclCommActor>>,
198 dim_size: i64,
199 result: Ref,
200 local_tensor: Ref,
201 factory: Factory,
202 reduction: Reduction,
203 scatter: bool,
204 in_place: bool,
205 out: Option<Ref>,
206 },
207
208 SendTensor {
209 result: Ref,
210 from_rank: Option<usize>,
211 to_rank: Option<usize>,
212 tensor: Ref,
213 factory: Factory,
214 comm: Arc<ActorHandle<NcclCommActor>>,
215 },
216
217 SendValue {
218 seq: Seq,
219 worker_actor_id: ActorId,
220 mutates: Vec<Ref>,
221 function: Option<ResolvableFunction>,
222 args: Vec<WireValue>,
223 kwargs: HashMap<String, WireValue>,
224 device_meshes: HashMap<Ref, DeviceMesh>,
225 pipe: Option<PortHandle<PipeMessage>>,
226 },
227
228 SetValue {
229 seq: Seq,
230 results: Vec<Option<Ref>>,
231 pipe: PortHandle<PipeMessage>,
232 },
233
234 DefineRecording {
235 recording: Ref,
236 },
237
238 FinalizeRecording {
239 recording: Ref,
240 },
241
242 CallRecording {
243 seq: Seq,
244 recording: Ref,
245 results: Vec<Ref>,
246 actuals: Vec<Ref>,
247 },
248
249 RecordingFormal {
250 result: Ref,
251 argument_index: usize,
252 },
253
254 RecordingResult {
255 result: Ref,
256 output_index: usize,
257 },
258
259 SetRefUnitTestsOnly(Ref, WireValue),
260
261 SetTensorRefUnitTestsOnly(Ref, TensorCellResult),
262
263 GetRefUnitTestsOnly(
        Ref,
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
266 ),
267
268 GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
269
270 SendResultOfActorCall(ActorId, ActorCallParams),
271 CallActorMethod(ActorMethodParams),
272}
273
274impl StreamMessage {
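    /// Clone this message so it can be stored in a recording and replayed later.
    /// Panics on variants that are not allowed inside a recording.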
275 fn clone_for_recording(&self) -> Self {
276 match self {
277 StreamMessage::RecordingFormal {
278 result,
279 argument_index,
280 } => StreamMessage::RecordingFormal {
281 result: *result,
282 argument_index: *argument_index,
283 },
284 StreamMessage::RecordingResult {
285 result,
286 output_index,
287 } => StreamMessage::RecordingResult {
288 result: *result,
289 output_index: *output_index,
290 },
291 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
292 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
293 StreamMessage::CallFunction(
294 params.clone(),
295 device_meshes.clone(),
296 remote_process_groups.clone(),
297 )
298 }
299 StreamMessage::BorrowCreate {
300 borrow,
301 tensor,
302 first_use_sender,
303 } => StreamMessage::BorrowCreate {
304 borrow: *borrow,
305 tensor: *tensor,
306 first_use_sender: first_use_sender.clone(),
307 },
308 StreamMessage::BorrowFirstUse {
309 borrow,
310 result,
311 first_use_receiver,
312 } => StreamMessage::BorrowFirstUse {
313 borrow: *borrow,
314 result: *result,
315 first_use_receiver: first_use_receiver.clone(),
316 },
317 StreamMessage::BorrowLastUse {
318 borrow,
319 result,
320 last_use_sender,
321 } => StreamMessage::BorrowLastUse {
322 borrow: *borrow,
323 result: *result,
324 last_use_sender: last_use_sender.clone(),
325 },
326 StreamMessage::BorrowDrop {
327 borrow,
328 last_use_receiver,
329 } => StreamMessage::BorrowDrop {
330 borrow: *borrow,
331 last_use_receiver: last_use_receiver.clone(),
332 },
333 StreamMessage::Reduce {
334 comm,
335 dim_size,
336 result,
337 local_tensor,
338 factory,
339 reduction,
340 scatter,
341 in_place,
342 out,
343 } => StreamMessage::Reduce {
344 comm: comm.clone(),
345 dim_size: *dim_size,
346 result: *result,
347 local_tensor: *local_tensor,
348 factory: factory.clone(),
349 reduction: reduction.clone(),
350 scatter: *scatter,
351 in_place: *in_place,
352 out: out.clone(),
353 },
354 StreamMessage::SendTensor {
355 result,
356 from_rank,
357 to_rank,
358 tensor,
359 factory,
360 comm,
361 } => StreamMessage::SendTensor {
362 result: *result,
363 from_rank: *from_rank,
364 to_rank: *to_rank,
365 tensor: *tensor,
366 factory: factory.clone(),
367 comm: comm.clone(),
368 },
369 StreamMessage::SetValue { seq, results, pipe } => StreamMessage::SetValue {
370 seq: seq.clone(),
371 results: results.clone(),
372 pipe: pipe.clone(),
373 },
374 other => panic!(
375 "StreamMessage variant not supported in recording: {:?}",
376 other
377 ),
378 }
379 }
380
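    /// The refs this message defines (inserts into the stream environment) when
    /// it executes.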
381 fn get_defined_refs(&self) -> HashSet<Ref> {
383 match self {
384 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
385 StreamMessage::CallFunction(params, ..) => {
386 params.results.iter().filter_map(|&ref_| ref_).collect()
387 }
388 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
389 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
390 StreamMessage::SendTensor {
391 result, from_rank, ..
392 } => {
393 if from_rank.is_some() {
394 HashSet::from([*result])
395 } else {
396 HashSet::new()
397 }
398 }
399 StreamMessage::SetValue { results, .. } => {
400 results.iter().filter_map(|&ref_| ref_).collect()
401 }
402 _ => HashSet::new(),
404 }
405 }
406
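    /// The refs whose underlying values this message mutates in place.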
407 fn get_mutated_refs(&self) -> HashSet<Ref> {
409 match self {
410 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
411 StreamMessage::Reduce {
412 out,
413 in_place,
414 local_tensor,
415 ..
416 } => {
417 if *in_place {
418 HashSet::from([*local_tensor])
419 } else if let Some(out) = out {
420 HashSet::from([*out])
421 } else {
422 HashSet::new()
423 }
424 }
425 _ => HashSet::new(),
427 }
428 }
429}
430
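/// A stream executes operations in the order it receives them. Values produced
/// by those operations live in `env`, keyed by [`Ref`], either as an [`RValue`]
/// or as the error that made the value unavailable. On CUDA workers each stream
/// lazily creates (or adopts) its own CUDA stream.
///
/// A rough sketch of spawning one, mirroring the tests at the bottom of this
/// file (the proc, client and controller wiring are assumed to exist already):
///
/// ```ignore
/// let stream = proc
///     .spawn::<StreamActor>(
///         "stream",
///         StreamParams {
///             world_size: 1,
///             rank: 0,
///             creation_mode: StreamCreationMode::UseDefaultStream,
///             id: 0.into(),
///             device: Some(CudaDevice::new(0.into())),
///             controller_actor: controller_actor.clone(),
///             respond_with_python_message: false,
///         },
///     )
///     .await?;
/// stream.define_recording(&client, 0.into()).await?;
/// ```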
431#[derive(Debug)]
440pub struct StreamActor {
441 world_size: usize,
442 rank: usize,
443 env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
447 creation_mode: StreamCreationMode,
449 cuda_stream: OnceLock<Option<Stream>>,
455 device: Option<CudaDevice>,
457 comm: Option<ActorHandle<NcclCommActor>>,
459 controller_actor: ActorRef<ControllerActor>,
461 remote_process_groups: HashMap<Ref, PyObject>,
462 recordings: HashMap<Ref, Recording>,
463 active_recording: Option<RecordingState>,
464 respond_with_python_message: bool,
465 last_seq_error: Option<Arc<SeqError>>,
466}
467
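/// Parameters for spawning a [`StreamActor`]. Note that `id` is accepted but is
/// currently discarded by [`StreamActor::new`].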
468#[derive(Debug, Clone)]
470pub struct StreamParams {
471 pub world_size: usize,
472 pub rank: usize,
473 pub creation_mode: StreamCreationMode,
475 pub id: StreamRef,
477 pub device: Option<CudaDevice>,
480 pub controller_actor: ActorRef<ControllerActor>,
482 pub respond_with_python_message: bool,
483}
484
485#[async_trait]
486impl Actor for StreamActor {
487 type Params = StreamParams;
488 async fn new(
489 StreamParams {
490 world_size,
491 rank,
492 id: _,
493 device,
494 controller_actor,
495 creation_mode,
496 respond_with_python_message,
497 }: Self::Params,
498 ) -> Result<Self> {
499 Ok(Self {
500 world_size,
501 rank,
502 env: HashMap::new(),
503 creation_mode,
504 cuda_stream: OnceLock::new(),
505 device,
506 comm: None,
507 controller_actor,
508 remote_process_groups: HashMap::new(),
509 recordings: HashMap::new(),
510 active_recording: None,
511 respond_with_python_message,
512 last_seq_error: None,
513 })
514 }
515
516 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
517 CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
521 controller_actor_ref.set(self.controller_actor.clone()).ok()
522 });
523 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
524 ROOT_ACTOR_ID.with(|root_actor_id| {
525 root_actor_id
526 .set(ActorId::root(
527 cx.self_id().proc_id().clone(),
528 cx.self_id().name().to_string(),
529 ))
530 .ok()
531 });
532 if let Some(stream) = self.cuda_stream() {
534 Stream::set_current_stream(stream);
535 }
536 Ok(())
537 }
538
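    /// Run this actor's server future on a dedicated "worker-stream" OS thread
    /// with its own runtime, bridging the result back to the caller through a
    /// oneshot channel. Handlers do blocking torch/Python work (note the
    /// `block_in_place` calls below) and depend on thread-scoped state such as
    /// the thread-locals above and the current CUDA stream set in `init`, so
    /// every message for a given stream is processed on that one thread.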
539 fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
545 where
546 F: Future + Send + 'static,
547 F::Output: Send + 'static,
548 {
549 let (join_tx, join_rx) = tokio::sync::oneshot::channel();
550 let builder = std::thread::Builder::new().name("worker-stream".to_string());
558 let _thread_handle = builder.spawn(move || {
559 let rt = tokio::runtime::Builder::new_multi_thread()
563 .worker_threads(1)
564 .enable_io()
565 .build()
566 .unwrap();
567 let result = rt.block_on(async {
568 tokio::task::block_in_place(|| {
569 Python::with_gil(|py| {
574 py.allow_threads(|| {
575 let result = Handle::current().block_on(future);
576 if join_tx.send(result).is_err() {
577 panic!("could not send join result")
578 }
579 })
580 })
581 })
582 });
583 rt.shutdown_timeout(Duration::from_weeks(1));
584 result
585 });
586
587 tokio::spawn(async move { join_rx.await.unwrap() })
590 }
591}
592
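/// An argument to a Python call: a torch [`RValue`], a borrowed [`DeviceMesh`],
/// or an arbitrary Python object.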
593#[derive(Debug)]
595enum PyArg<'a> {
596 RValue(RValue),
597 DeviceMesh(&'a DeviceMesh),
598 PyObject(PyObject),
599}
600
601impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
603 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
604 match self {
605 PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
608 PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
609 PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
610 }
611 }
612}
613
614impl StreamActor {
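    /// The CUDA stream this actor issues work on, created lazily according to
    /// `creation_mode`; `None` for workers without a device.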
615 fn cuda_stream(&self) -> Option<&Stream> {
616 self.cuda_stream
617 .get_or_init(|| {
618 self.device.map(|device| match self.creation_mode {
619 StreamCreationMode::UseDefaultStream => {
620 Stream::get_current_stream_on_device(device)
621 }
622 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
623 })
624 })
625 .as_ref()
626 }
627
628 fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
629 let rvalue = self
630 .env
631 .get(ref_)
632 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
633 match rvalue {
634 Ok(val) => Ok(val.clone()),
635 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
636 }
637 }
638
639 fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
640 let ret = match value {
641 WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
642 WireValue::RefList(val) => {
644 let mut ret = Vec::with_capacity(val.len());
645 for v in val {
646 match self.ref_to_rvalue(&v) {
647 Ok(RValue::Tensor(t)) => ret.push(t),
648 Err(err) => {
649 return Err(err);
650 }
651 Ok(val) => {
652 return Err(CallFunctionError::UnsupportedArgType(
653 "wire_to_rvalue".into(),
654 format!("RefList([{:?}])", val),
655 ));
656 }
657 }
658 }
659 RValue::TensorList(ret)
660 }
661 WireValue::Int(val) => RValue::Int(val),
662 WireValue::IntList(val) => RValue::IntList(val),
663 WireValue::Double(val) => RValue::Double(val),
664 WireValue::Bool(val) => RValue::Bool(val),
665 WireValue::String(val) => RValue::String(val),
666 WireValue::Device(val) => RValue::Device(val),
667 WireValue::Layout(val) => RValue::Layout(val),
668 WireValue::ScalarType(val) => RValue::ScalarType(val),
669 WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
670 WireValue::PyObject(val) => RValue::PyObject(val),
671 WireValue::None(()) => RValue::None,
672 WireValue::IValue(val) => RValue::Opaque(val.into()),
673 };
674 Ok(ret)
675 }
676
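    /// Turn a [`CallFunctionError`] into the [`SeqError`] to record for the
    /// affected refs. Dependent errors return their root cause unchanged; new
    /// errors are reported to the controller (unless a recording is active) and
    /// remembered in `last_seq_error`.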
677 async fn report_seq_error(
678 &mut self,
679 cx: &Context<'_, Self>,
680 seq: Seq,
681 error: CallFunctionError,
682 ) -> Result<Arc<SeqError>, anyhow::Error> {
683 match error {
684 CallFunctionError::DependentError(root) => Ok(root),
685 CallFunctionError::Error(e) => {
686 if self.active_recording.is_none() {
687 let worker_error = WorkerError {
688 backtrace: format!("{e}"),
689 worker_actor_id: cx.self_id().clone(),
690 };
691 tracing::info!("Propagating remote function error to client: {worker_error}");
692 self.controller_actor
693 .remote_function_failed(cx, seq, worker_error)
694 .await?
695 }
696 let err = Arc::new(SeqError { seq, error: e });
697 self.last_seq_error = Some(err.clone());
698 Ok(err)
699 }
700 }
701 }
702
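    /// Run `f` and bind its results to `result_refs`. If it fails, report the
    /// error and poison every result and mutated ref with it so downstream uses
    /// observe a dependent error.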
703 async fn try_define<F>(
704 &mut self,
705 cx: &Context<'_, Self>,
706 seq: Seq,
707 result_refs: Vec<Option<Ref>>,
708 mutates: &Vec<Ref>,
709 f: F,
710 ) -> Result<()>
711 where
712 F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
713 {
714 let actual_results = f(self).await;
715 let op_results = actual_results.and_then(|actual_results| {
718 if result_refs.len() == actual_results.len() {
719 Ok(actual_results
720 .into_iter()
721 .zip(result_refs.iter())
722 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
723 .collect::<Vec<(Ref, RValue)>>())
724 } else {
725 Err(CallFunctionError::UnexpectedNumberOfReturns(
726 result_refs.len(),
727 actual_results.len(),
728 ))
729 }
730 });
731
732 match op_results {
735 Ok(op_results) => {
736 for (ref_, rvalue) in op_results.into_iter() {
737 let prev = self.env.insert(ref_, Ok(rvalue));
738 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
739 }
740 }
741 Err(err) => {
742 let err = self.report_seq_error(cx, seq, err).await?;
743 for ref_ in result_refs {
744 match ref_ {
745 Some(ref_) => {
746 let prev = self.env.insert(ref_, Err(err.clone()));
747 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
748 }
749 None => {}
750 }
751 }
752 for ref_ in mutates {
753 self.env.insert(*ref_, Err(err.clone()));
754 }
755 }
756 }
757 Ok(())
758 }
759
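    /// Call a torch operator by name and overload with wire-format arguments.
    /// Ops that return nothing yield a single `RValue::None`.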
760 fn call_torch_op(
761 &self,
762 op: String,
763 overload: String,
764 args: Vec<WireValue>,
765 kwargs: HashMap<String, WireValue>,
766 ) -> Result<Vec<RValue>, CallFunctionError> {
767 let args = args
768 .into_iter()
769 .map(|arg| self.wire_to_rvalue(arg))
770 .collect::<Result<Vec<_>, _>>()?;
771 let kwargs = kwargs
772 .into_iter()
773 .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
774 .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
775
776 let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
777
778 Ok(if results.is_empty() {
782 vec![RValue::None]
783 } else {
784 results
785 })
786 }
787
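    /// Resolve and call a Python function (or, if `function` is `None`, return
    /// the first argument unchanged): wire values are converted to Python
    /// objects, remote process groups are materialized on demand, and every
    /// referenced tensor is borrowed for the duration of the call (shared for
    /// arguments, mutable for `mutates`).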
788 fn call_python_fn<'py>(
789 &mut self,
790 py: Python<'py>,
791 cx: &Context<Self>,
792 function: Option<ResolvableFunction>,
793 args: Vec<WireValue>,
794 kwargs: HashMap<String, WireValue>,
795 mutates: &[Ref],
796 device_meshes: HashMap<Ref, DeviceMesh>,
797 remote_process_groups: HashMap<
798 Ref,
799 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
800 >,
801 ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
802 let function = function
803 .map(|function| {
804 function.resolve(py).map_err(|e| {
805 CallFunctionError::InvalidRemoteFunction(format!(
806 "failed to resolve function {}: {}",
807 function,
808 SerializablePyErr::from(py, &e)
809 ))
810 })
811 })
812 .transpose()?;
813
814 let remote_process_groups = remote_process_groups
815 .into_iter()
816 .map(|(gref, (mesh, dims, comm))| {
817 let group = match self.remote_process_groups.entry(gref) {
818 Entry::Occupied(ent) => ent.get().clone_ref(py),
819 Entry::Vacant(ent) => {
820 torch_sys::backend::ensure_init_process_group(
823 py,
824 self.world_size,
825 self.rank,
826 )?;
827
828 let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
831 let group_size = ranks.len();
832 let backend = CommBackend::new(
833 comm,
834 Mailbox::new_detached(cx.self_id().clone()),
835 self.rank,
836 group_size,
837 self.world_size,
838 );
839 ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
840 .clone_ref(py)
841 }
842 };
843 PyResult::Ok((gref, group))
844 })
845 .collect::<Result<HashMap<_, _>, _>>()
846 .map_err(SerializablePyErr::from_fn(py))?;
847
848 let mut multiborrow = MultiBorrow::new();
852
853 let resolve = |val: WireValue| {
854 val.into_py_object()
855 .map_err(|e| {
856 CallFunctionError::UnsupportedArgType(
857 format!("{:?}", function),
858 format!("{:?}", e),
859 )
860 })?
861 .unpickle(py)
862 .map_err(SerializablePyErr::from_fn(py))?
863 .extract::<PyTree<PyObject>>()
864 .map_err(SerializablePyErr::from_fn(py))?
865 .try_into_map(|obj| {
866 Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
867 if let Some(mesh) = device_meshes.get(&ref_) {
868 PyArg::DeviceMesh(mesh)
869 } else if let Some(pg) = remote_process_groups.get(&ref_) {
870 PyArg::PyObject(pg.clone_ref(py))
871 } else {
872 let rval = self.ref_to_rvalue(&ref_)?;
873 PyArg::RValue(rval)
874 }
875 } else {
876 PyArg::PyObject(obj)
877 })
878 })
879 };
880
881 let py_args: Vec<PyTree<PyArg>> = args
883 .into_iter()
884 .map(resolve)
885 .collect::<Result<_, CallFunctionError>>()?;
886 let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
887 .into_iter()
888 .map(|(k, object)| Ok((k, resolve(object)?)))
889 .collect::<Result<_, CallFunctionError>>()?;
890
891 py_args
893 .iter()
894 .chain(py_kwargs.values())
895 .flat_map(|o| o.iter())
896 .for_each(|arg| {
897 if let PyArg::RValue(rval) = arg {
898 multiborrow.add(rval, BorrowType::Shared);
899 }
900 });
901
902 let mutates: Vec<_> = mutates
904 .iter()
905 .map(|r| self.ref_to_rvalue(r))
906 .collect::<Result<_, CallFunctionError>>()?;
907 mutates
908 .iter()
909 .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
910
911 let _borrow = multiborrow.borrow()?;
913
914 let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
917 let result: Bound<'_, PyAny> =
918 tracing::subscriber::with_default(scoped_subscriber, || {
919 let args = unsafe { py_args.try_to_object_unsafe(py) }
928 .map_err(SerializablePyErr::from_fn(py))?;
929 let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
931 .map_err(SerializablePyErr::from_fn(py))?;
932
933 if let Some(function) = function {
934 function
935 .call(args, Some(kwargs))
936 .map_err(SerializablePyErr::from_fn(py))
937 } else {
938 Ok(args.get_item(0).unwrap())
939 }
940 })?;
941 Ok(result)
942 }
943
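    /// Like `call_python_fn`, but extracts the returned Python object into a
    /// `PyTree<RValue>`.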
944 fn call_python_fn_pytree(
945 &mut self,
946 cx: &hyperactor::Context<Self>,
947 function: ResolvableFunction,
948 args: Vec<WireValue>,
949 kwargs: HashMap<String, WireValue>,
950 mutates: &[Ref],
951 device_meshes: HashMap<Ref, DeviceMesh>,
952 remote_process_groups: HashMap<
953 Ref,
954 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
955 >,
956 ) -> Result<PyTree<RValue>, CallFunctionError> {
957 Python::with_gil(|py| {
958 let result = self.call_python_fn(
959 py,
960 cx,
961 Some(function),
962 args,
963 kwargs,
964 mutates,
965 device_meshes,
966 remote_process_groups,
967 )?;
968 Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
969 })
970 }
971 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
979 let rvalue = self
980 .env
981 .get(&ref_)
982 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
983
984 match rvalue {
985 Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
986 Err(_) => {
987 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
988 Ok(TensorCell::new(t))
989 }
990 }
991 }
992
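    /// If a recording is currently being defined, return it along with the set
    /// of borrows created so far; `None` when idle or while replaying.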
993 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
994 self.active_recording
995 .as_mut()
996 .and_then(|state| match state {
997 RecordingState::Defining {
998 recording,
999 defined_borrows,
1000 } => {
1001 match self.recordings.get_mut(recording) {
1002 Some(recording) => Some((recording, defined_borrows)),
1003 None => panic!("recording not found: {:?}", recording),
1005 }
1006 }
1007 RecordingState::Running => None,
1008 })
1009 }
1010
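    /// Return the first error stored under any of `refs`, if any.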
1011 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
1012 for ref_ in refs {
1013 let rvalue_or_err = self
1014 .env
1015 .get(ref_)
1016 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1017 if let Err(err) = rvalue_or_err {
1018 return Ok(Some(err.clone()));
1019 }
1020 }
1021 Ok(None)
1022 }
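
    /// `SendValue` path used when `respond_with_python_message` is set and no
    /// pipe is supplied: call the function, pickle the result into a
    /// [`PythonMessage`], and deliver it to the controller via `fetch_result`.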
1023 async fn send_value_python_message(
1024 &mut self,
1025 cx: &hyperactor::Context<'_, Self>,
1026 seq: Seq,
1027 worker_actor_id: ActorId,
1028 mutates: Vec<Ref>,
1029 function: Option<ResolvableFunction>,
1030 args: Vec<WireValue>,
1031 kwargs: HashMap<String, WireValue>,
1032 device_meshes: HashMap<Ref, DeviceMesh>,
1033 ) -> Result<()> {
1034 self.try_define(cx, seq, vec![], &vec![], async |self_| {
1035 let python_message =
1036 Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1037 let python_result = tokio::task::block_in_place(|| {
1038 self_.call_python_fn(
1039 py,
1040 cx,
1041 function,
1042 args,
1043 kwargs,
1044 &mutates,
1045 device_meshes,
1046 HashMap::new(),
1047 )
1048 })?;
1049 pickle_python_result(py, python_result, worker_actor_id)
1050 .map_err(CallFunctionError::Error)
1051 })?;
1052 let ser = Serialized::serialize(&python_message).unwrap();
1053 self_
1054 .controller_actor
1055 .fetch_result(cx, seq, Ok(ser))
1056 .await?;
1057 Ok(vec![])
1058 })
1059 .await
1060 }
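
    /// Alias `dest` to the value (or error) currently stored under `src`.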
1061 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1062 let rvalue = self
1063 .env
1064 .get(&src)
1065 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1066 self.env.insert(dest, rvalue.clone());
1067 Ok(())
1068 }
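
    /// Dispatch a Python actor call through the local-state broker: resolve the
    /// referenced local state, register it with the broker keyed by seq, and
    /// await the pickled response on a once port.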
1069 async fn call_actor(
1070 &mut self,
1071 cx: &Context<'_, Self>,
1072 params: ActorCallParams,
1073 ) -> Result<PyObject, CallFunctionError> {
1074 let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1075 params
1076 .local_state
1077 .into_iter()
1078 .map(|elem| {
1079 unsafe {
1081 let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1082 Ok(x)
1083 }
1084 })
1085 .collect()
1086 });
1087
1088 let (send, recv) = cx.open_once_port();
1089 let state = LocalState {
1090 response_port: send,
1091 state: local_state?,
1092 };
1093 let x: u64 = params.seq.into();
1094 let message = LocalStateBrokerMessage::Set(x as usize, state);
1095
1096 let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1097 broker
1098 .send(message)
1099 .map_err(|e| CallFunctionError::Error(e.into()))?;
1100 let result = recv
1101 .recv()
1102 .await
1103 .map_err(|e| CallFunctionError::Error(e.into()))?;
1104
1105 result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1106 }
1107}
1108
1109#[async_trait]
1110#[forward(StreamMessage)]
1111impl StreamMessageHandler for StreamActor {
1112 async fn call_function(
1113 &mut self,
1114 cx: &Context<Self>,
1115 params: CallFunctionParams,
1116 device_meshes: HashMap<Ref, DeviceMesh>,
1117 remote_process_groups: HashMap<
1118 Ref,
1119 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1120 >,
1121 ) -> Result<()> {
1122 if let Some((recording, _)) = self.get_defining_recording() {
1123 recording.messages.push(StreamMessage::CallFunction(
1124 params,
1125 device_meshes,
1126 remote_process_groups,
1127 ));
1128 return Ok(());
1129 }
1130
1131 params.function.panic_if_requested();
1132 self.try_define(
1133 cx,
1134 params.seq,
1135 params.results,
1136 ¶ms.mutates,
1137 async |self| {
1138 tokio::task::block_in_place(|| match params.function.as_torch_op() {
1139 Some((op, overload)) => {
1140 self.call_torch_op(op, overload, params.args, params.kwargs)
1141 }
1142 _ => self
1143 .call_python_fn_pytree(
1144 cx,
1145 params.function,
1146 params.args,
1147 params.kwargs,
1148 ¶ms.mutates,
1149 device_meshes,
1150 remote_process_groups,
1151 )
1152 .map(|results| results.into_leaves()),
1153 })
1154 },
1155 )
1156 .await?;
1157 Ok(())
1158 }
1159
1160 async fn borrow_create(
1161 &mut self,
1162 _cx: &Context<Self>,
1163 borrow: u64,
1164 tensor: Ref,
1165 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1166 ) -> Result<()> {
1167 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1168 recording.messages.push(StreamMessage::BorrowCreate {
1169 borrow,
1170 tensor,
1171 first_use_sender,
1172 });
1173 ensure!(
1174 defined_borrows.insert(borrow),
1175 "duplicate borrow create in recording"
1176 );
1177 return Ok(());
1178 }
1179
1180 let rvalue_result = self
1181 .env
1182 .get(&tensor)
1183 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1184
1185 let result = match rvalue_result {
1186 Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1187 Err(e) => Err(e.clone()),
1188 };
1189
1190 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1191 first_use_sender.send((event, result)).map_err(|err| {
1192 anyhow!(
1193 "failed sending first use event for borrow {:?}: {:?}",
1194 borrow,
1195 err
1196 )
1197 })
1198 }
1199
1200 async fn borrow_first_use(
1201 &mut self,
1202 _cx: &Context<Self>,
1203 borrow: u64,
1204 result: Ref,
1205 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1206 ) -> Result<()> {
1207 if let Some((recording, _)) = self.get_defining_recording() {
1208 recording.messages.push(StreamMessage::BorrowFirstUse {
1209 borrow,
1210 result,
1211 first_use_receiver: first_use_receiver.clone(),
1212 });
1213 return Ok(());
1214 }
1215
1216 let (first_use_event, cell) =
1217 first_use_receiver
1218 .lock()
1219 .await
1220 .recv()
1221 .await
1222 .map_err(|err| {
1223 anyhow!(
1224 "failed receiving first use event for borrow {:?}: {:?}",
1225 borrow,
1226 err
1227 )
1228 })?;
1229
1230 if let Some(stream) = self.cuda_stream() {
1231 stream.wait_event(
1232 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1233 );
1234 }
1235 match cell {
1236 Ok(cell) => {
1237 self.env.insert(result, Ok(cell.into()));
1238 }
1239 Err(err) => {
1240 self.env.insert(result, Err(err.clone()));
1241 }
1242 }
1243 Ok(())
1244 }
1245
1246 async fn borrow_last_use(
1247 &mut self,
1248 _cx: &Context<Self>,
1249 borrow: u64,
1250 result: Ref,
1251 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1252 ) -> Result<()> {
1253 if let Some((recording, _)) = self.get_defining_recording() {
1254 recording.messages.push(StreamMessage::BorrowLastUse {
1255 borrow,
1256 result,
1257 last_use_sender,
1258 });
1259 return Ok(());
1260 }
1261
1262 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1263 let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1264 "Invalid reference for borrow_last_use: {result:#?}"
1265 ))?;
1266 let tensor = match rvalue_or_err {
1267 Ok(RValue::Tensor(t)) => Ok(t),
1268 Err(e) => Err(e),
1269 _ => bail!("invalid rvalue type for borrow_last_use"),
1270 };
1271
1272 last_use_sender.send((event, tensor)).map_err(|err| {
1273 anyhow!(
1274 "failed sending last use event for borrow {:?}: {:?}",
1275 borrow,
1276 err
1277 )
1278 })
1279 }
1280
1281 async fn borrow_drop(
1282 &mut self,
1283 _cx: &Context<Self>,
1284 borrow: u64,
1285 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1286 ) -> Result<()> {
1287 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1288 recording.messages.push(StreamMessage::BorrowDrop {
1289 borrow,
1290 last_use_receiver: last_use_receiver.clone(),
1291 });
1292 ensure!(
1293 defined_borrows.remove(&borrow),
1294 "borrow drop for borrow not defined in recording"
1295 );
1296 return Ok(());
1297 }
1298
1299 let (last_use_event, _cell) =
1303 last_use_receiver.lock().await.recv().await.map_err(|err| {
1304 anyhow!(
1305 "failed receiving last use event for borrow {:?}: {:?}",
1306 borrow,
1307 err
1308 )
1309 })?;
1310
1311 if let Some(stream) = self.cuda_stream() {
1312 stream.wait_event(
1313 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1314 );
1315 }
1316 Ok(())
1318 }
1319
1320 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1321 if let Some((recording, _)) = self.get_defining_recording() {
1322 recording.messages.push(StreamMessage::DeleteRefs(refs));
1323 return Ok(());
1324 }
1325
1326 for ref_ in refs.iter() {
1327 self.env.remove(ref_);
1328 }
1329 Ok(())
1330 }
1331
1332 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1333 if self.get_defining_recording().is_some() {
1334 bail!("request_status not allowed in recording");
1335 }
1336
1337 Ok(())
1338 }
1339
1340 async fn init_comm(
1341 &mut self,
1342 _cx: &Context<Self>,
1343 comm: ActorHandle<NcclCommActor>,
1344 ) -> Result<()> {
1345 if self.get_defining_recording().is_some() {
1346 bail!("init_comm not allowed in recording");
1347 }
1348
1349 self.comm = Some(comm);
1350 Ok(())
1351 }
1352
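    /// Run a collective: `Reduction::Stack` maps to all-to-all (scatter) or
    /// all-gather, `Reduction::ReduceOp` to reduce-scatter (scatter) or
    /// all-reduce, honoring `in_place` and the optional `out` tensor; the output
    /// is bound to `result`.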
1353 async fn reduce(
1354 &mut self,
1355 cx: &Context<Self>,
1356 comm: Arc<ActorHandle<NcclCommActor>>,
1357 dim_size: i64,
1358 result: Ref,
1359 local_tensor: Ref,
1360 factory: Factory,
1361 reduction: Reduction,
1362 scatter: bool,
1363 in_place: bool,
1364 out: Option<Ref>,
1365 ) -> Result<()> {
1366 if let Some((recording, _)) = self.get_defining_recording() {
1367 recording.messages.push(StreamMessage::Reduce {
1368 comm,
1369 dim_size,
1370 result,
1371 local_tensor,
1372 factory,
1373 reduction,
1374 scatter,
1375 in_place,
1376 out,
1377 });
1378 return Ok(());
1379 }
1380
1381 let stream = self
1382 .cuda_stream()
1383 .expect("reductions not yet supported for non-CUDA workers")
1384 .clone();
1385 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1386 let out_cell = out
1387 .map(|out| self.get_or_fake_on_err(out, &factory))
1388 .transpose()?;
1389 let output_cell = match reduction {
1390 Reduction::Stack => {
1391 if scatter {
1392 let output_cell = if in_place {
1393 input_cell.clone()
1394 } else {
1395 out_cell.unwrap_or({
1396 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1397 let cloned = deep_clone(&borrow);
1398 TensorCell::new(cloned)
1399 })
1400 };
1401 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1402 .await?;
1403 output_cell
1404 } else {
1405 ensure!(
1406 !in_place,
1407 "in-place, non-scatter not supported for stack reduce"
1408 );
1409
1410 let output_cell = out_cell.unwrap_or({
1411 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1413 let output =
1414 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1415 TensorCell::new(output)
1416 });
1417
1418 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1419 .await?;
1420 output_cell
1421 }
1422 }
1423 Reduction::ReduceOp(op) => {
1424 if scatter {
1425 ensure!(!in_place, "in-place, scatter not supported for reduce");
1426
1427 let output_cell = out_cell.unwrap_or({
1428 let output = factory_empty(
1429 &factory.size[1..],
1430 factory.dtype,
1431 factory.layout,
1432 factory.device,
1433 );
1434 TensorCell::new(output)
1435 });
1436 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1437 .await?;
1438 output_cell
1439 } else {
1440 let output_cell = if in_place {
1441 input_cell.clone()
1442 } else {
1443 out_cell.map_or(
1444 {
1445 let borrow =
1446 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1447 let cloned = deep_clone(&borrow);
1448 Ok(TensorCell::new(cloned))
1449 },
1450 |out_cell| -> Result<_, anyhow::Error> {
1451 let mut out_borrow =
1452 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1453 let in_borrow =
1454 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1455 out_borrow.copy_(&in_borrow);
1456 drop(out_borrow);
1457 Ok(out_cell)
1458 },
1459 )?
1460 };
1461
1462 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1463 output_cell
1464 }
1465 }
1466 };
1467
1468 self.env.insert(result, Ok(output_cell.into()));
1469 Ok(())
1470 }
1471
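    /// Point-to-point tensor transfer. When `from_rank == to_rank` the tensor is
    /// deep-cloned locally; otherwise NCCL send/recv messages are grouped on this
    /// stream's CUDA stream, with the receiving side binding `result`.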
1472 async fn send_tensor(
1473 &mut self,
1474 cx: &Context<Self>,
1475 result: Ref,
1476 from_rank: Option<usize>,
1477 to_rank: Option<usize>,
1478 tensor: Ref,
1479 factory: Factory,
1480 comm: Arc<ActorHandle<NcclCommActor>>,
1481 ) -> Result<()> {
1482 if let Some((recording, _)) = self.get_defining_recording() {
1483 recording.messages.push(StreamMessage::SendTensor {
1484 result,
1485 from_rank,
1486 to_rank,
1487 tensor,
1488 factory,
1489 comm,
1490 });
1491 return Ok(());
1492 }
1493
1494 if to_rank.is_none() && from_rank.is_none() {
1495 bail!("tried to send tensor without a to/from rank");
1496 }
1497
1498 if from_rank == to_rank {
1500 let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1501 .env
1502 .get(&tensor)
1503 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1504 let output_cell = match input_cell {
1505 Ok(RValue::Tensor(input_cell)) => {
1506 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1511 let cloned = deep_clone(&borrow);
1512 Ok(RValue::Tensor(TensorCell::new(cloned)))
1513 }
1514 Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1515 Err(err) => Err(err.clone()),
1516 };
1517 self.env.insert(result, output_cell);
1518 return Ok(());
1519 }
1520
1521 let mut messages = Vec::new();
1522
1523 if let Some(to_rank) = to_rank {
1524 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1525 messages.push(CommMessage::Send(
1526 input_cell,
1527 to_rank.try_into().unwrap(),
1528 self.cuda_stream()
1529 .expect("tried to send_tensor on non-cuda stream")
1530 .clone(),
1531 cx.open_once_port().0,
1532 ));
1533 }
1534
1535 if let Some(from_rank) = from_rank {
1536 let output_cell = TensorCell::new(factory_empty(
1537 &factory.size,
1538 factory.dtype,
1539 factory.layout,
1540 factory.device,
1541 ));
1542 messages.push(CommMessage::Recv(
1543 output_cell.clone(),
1544 from_rank.try_into().unwrap(),
1545 self.cuda_stream()
1546 .expect("tried to send_tensor on non-cuda stream")
1547 .clone(),
1548 cx.open_once_port().0,
1549 ));
1550 self.env.insert(result, Ok(output_cell.into()));
1551 }
1552
1553 comm.group(
1554 cx,
1555 messages,
1556 self.cuda_stream()
1557 .expect("tried to send_tensor on non-cuda stream")
1558 .clone(),
1559 )
1560 .await?;
1561 Ok(())
1562 }
1563
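    /// Evaluate a function (or unpickle a single ref/value argument) and deliver
    /// the outcome either to `pipe` or to the controller via `fetch_result`;
    /// tensors are moved to CPU first.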
1564 async fn send_value(
1565 &mut self,
1566 cx: &Context<Self>,
1567 seq: Seq,
1568 worker_actor_id: ActorId,
1569 mutates: Vec<Ref>,
1570 function: Option<ResolvableFunction>,
1571 args: Vec<WireValue>,
1572 kwargs: HashMap<String, WireValue>,
1573 device_meshes: HashMap<Ref, DeviceMesh>,
1574 pipe: Option<PortHandle<PipeMessage>>,
1575 ) -> Result<()> {
1576 if self.respond_with_python_message && pipe.is_none() {
1577 return self
1578 .send_value_python_message(
1579 cx,
1580 seq,
1581 worker_actor_id,
1582 mutates,
1583 function,
1584 args,
1585 kwargs,
1586 device_meshes,
1587 )
1588 .await;
1589 }
1590 let result = if let Some(function) = function {
1591 match function.as_torch_op() {
1593 Some((op, overload)) => {
1594 self.call_torch_op(op, overload, args, kwargs)
1595 .map(|rvalues| {
1596 if rvalues.len() == 1 {
1597 Ok(rvalues[0].clone().into())
1598 } else {
1599 Python::with_gil(|py| {
1601 Ok((|| {
1602 let py_rvalues = rvalues
1603 .into_iter()
1604 .map(|rvalue| unsafe {
1606 rvalue.try_to_object_unsafe(py)
1607 })
1608 .collect::<Result<Vec<_>, _>>()?;
1609 PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1610 })()
1611 .map_err(SerializablePyErr::from_fn(py))?)
1612 })
1613 }
1614 })?
1615 }
1616 _ => tokio::task::block_in_place(|| {
1619 self.call_python_fn_pytree(
1620 cx,
1621 function,
1622 args,
1623 kwargs,
1624 &mutates,
1625 device_meshes,
1626 HashMap::new(),
1627 )
1628 }),
1629 }
1630 } else {
1631 match (args.len(), kwargs.len()) {
1634 (1, 0) => Python::with_gil(|py| {
1635 let arg = args[0]
1636 .as_py_object()
1637 .ok_or_else(|| {
1638 CallFunctionError::UnsupportedArgType(
1639 "send_value".to_string(),
1640 "expected a PyObject as the first arg".to_string(),
1641 )
1642 })?
1643 .unpickle(py)
1644 .map_err(SerializablePyErr::from_fn(py))?;
1645 arg.extract::<PyTree<PyObject>>()
1646 .map_err(SerializablePyErr::from_fn(py))?
1647 .try_into_map(|obj| {
1648 let bound_obj = obj.bind(py);
1649 if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1650 self.ref_to_rvalue(&ref_)
1651 } else {
1652 Ok(bound_obj
1653 .extract::<RValue>()
1654 .map_err(SerializablePyErr::from_fn(py))?)
1655 }
1656 })
1657 }),
1658 _ => Err(CallFunctionError::TooManyArgsForValue(
1659 format!("{:?}", args),
1660 format!("{:?}", kwargs),
1661 )),
1662 }
1663 };
1664
1665 let value = match result {
1666 Ok(rvalue) => {
1667 Ok(rvalue.into_map(|rval| match rval {
1672 RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1673 RValue::TensorList(tensors) => RValue::TensorList(
1674 tensors
1675 .into_iter()
1676 .map(|tensor| tensor.try_cpu().unwrap())
1677 .collect(),
1678 ),
1679 rval => rval,
1680 }))
1681 }
1682 Err(err) => {
1683 let err = self.report_seq_error(cx, seq, err).await?;
1684 for ref_ in mutates {
1685 self.env.insert(ref_, Err(err.clone()));
1686 }
1687 Err(WorkerError {
1688 backtrace: format!("{:?}", err),
1689 worker_actor_id,
1690 })
1691 }
1692 };
1693
1694 if let Some(pipe) = pipe {
1696 pipe.send(PipeMessage::SendValue(value))?;
1697 } else {
1698 let result = match value {
1699 Ok(value) => Ok(Serialized::serialize(&value).map_err(anyhow::Error::from)?),
1700 Err(e) => Err(e),
1701 };
1702 self.controller_actor.fetch_result(cx, seq, result).await?;
1703 }
1704
1705 Ok(())
1706 }
1707
1708 async fn send_result_of_actor_call(
1709 &mut self,
1710 cx: &Context<Self>,
1711 worker_actor_id: ActorId,
1712 params: ActorCallParams,
1713 ) -> anyhow::Result<()> {
1714 let seq = params.seq;
1715 let mutates = params.mutates.clone();
1716 self.try_define(cx, seq, vec![], &mutates, async |self| {
1717 let value = self.call_actor(cx, params).await?;
1718 let result = Python::with_gil(|py| {
1719 pickle_python_result(py, value.into_bound(py), worker_actor_id)
1720 })?;
1721 let result = Serialized::serialize(&result).unwrap();
1722 self.controller_actor
1723 .fetch_result(cx, seq, Ok(result))
1724 .await?;
1725 Ok(vec![])
1726 })
1727 .await
1728 }
1729
1730 async fn call_actor_method(
1731 &mut self,
1732 cx: &Context<Self>,
1733 params: ActorMethodParams,
1734 ) -> anyhow::Result<()> {
1735 let seq = params.call.seq;
1736 let mutates = params.call.mutates.clone();
1737 self.try_define(cx, seq, params.results, &mutates, async |self| {
1738 let result = self.call_actor(cx, params.call).await?;
1739 let result = Python::with_gil(|py| {
1740 PyTree::<RValue>::extract_bound(&result.into_bound(py))
1741 .map_err(SerializablePyErr::from_fn(py))
1742 })?;
1743 Ok(result.into_leaves())
1744 })
1745 .await
1746 }
1747
1748 async fn set_value(
1749 &mut self,
1750 cx: &Context<Self>,
1751 seq: Seq,
1752 results: Vec<Option<Ref>>,
1753 pipe: PortHandle<PipeMessage>,
1754 ) -> Result<()> {
1755 if let Some((recording, _)) = self.get_defining_recording() {
1756 recording
1757 .messages
1758 .push(StreamMessage::SetValue { seq, results, pipe });
1759 return Ok(());
1760 }
1761
1762 self.try_define(cx, seq, results, &vec![], async |self| {
1763 let (tx, rx) = cx.open_once_port();
1764 pipe.send(PipeMessage::RecvValue(tx))
1765 .map_err(anyhow::Error::from)
1766 .map_err(CallFunctionError::from)?;
1767 let value = rx.recv().await.map_err(anyhow::Error::from)?;
1768 Ok(value.into_leaves())
1769 })
1770 .await
1771 }
1772
1773 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1774 if self.active_recording.is_some() {
1775 bail!("different recording already active");
1776 }
1777 match self.recordings.entry(recording) {
1778 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1779 Entry::Vacant(entry) => entry.insert(Recording::new()),
1780 };
1781 self.active_recording = Some(RecordingState::Defining {
1782 recording,
1783 defined_borrows: HashSet::new(),
1784 });
1785 Ok(())
1786 }
1787
1788 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1789 match self.active_recording {
1790 Some(RecordingState::Defining {
1791 recording: active_recording,
1792 ref defined_borrows,
1793 }) if active_recording == recording => {
1794 ensure!(
1795 defined_borrows.is_empty(),
1796 "all borrows created within recording must be dropped within recording"
1797 );
1798 self.active_recording = None;
1799 }
1800 _ => bail!("cannot finalize recording that isn't active"),
1801 }
1802 Ok(())
1803 }
1804
1805 async fn recording_formal(
1806 &mut self,
1807 _cx: &Context<Self>,
1808 result: Ref,
1809 argument_index: usize,
1810 ) -> Result<()> {
1811 match self.get_defining_recording() {
1812 Some((recording, _)) => {
1813 recording.messages.push(StreamMessage::RecordingFormal {
1814 result,
1815 argument_index,
1816 });
1817 }
1818 None => bail!("recording_formal called outside of recording"),
1819 };
1820 Ok(())
1821 }
1822
1823 async fn recording_result(
1824 &mut self,
1825 _cx: &Context<Self>,
1826 result: Ref,
1827 output_index: usize,
1828 ) -> Result<()> {
1829 match self.get_defining_recording() {
1830 Some((recording, _)) => {
1831 recording.messages.push(StreamMessage::RecordingResult {
1832 result,
1833 output_index,
1834 });
1835 }
1836 None => bail!("recording_result called outside of recording"),
1837 };
1838 Ok(())
1839 }
1840
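    /// Replay a previously defined recording: bind formals to `actuals`, execute
    /// the captured messages, report the first failure to the controller once,
    /// copy recording outputs into `results`, and delete the refs defined only
    /// inside the replay.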
1841 async fn call_recording(
1842 &mut self,
1843 cx: &Context<Self>,
1844 seq: Seq,
1845 recording: Ref,
1846 results: Vec<Ref>,
1847 actuals: Vec<Ref>,
1848 ) -> Result<()> {
1849 if self.active_recording.is_some() {
1850 bail!("cannot call recording while another recording is active");
1851 }
1852
1853 let messages = match self.recordings.get(&recording) {
1854 Some(recording) => recording
1855 .messages
1856 .iter()
1857 .map(|message| message.clone_for_recording())
1858 .collect::<Vec<_>>(),
1859 None => bail!("recording {:?} not found", recording),
1860 };
1861
1862 self.active_recording = Some(RecordingState::Running);
1863
1864 let mut error: Option<Arc<SeqError>> = None;
1869 let mut all_defined_refs = HashSet::new();
1872 let mut all_mutated_refs = HashSet::new();
1875 let mut formal_to_actual_refs = HashMap::new();
1881 self.last_seq_error = None;
1883 for message in messages.into_iter() {
1884 let defined_refs = message.get_defined_refs();
1885 all_defined_refs.extend(defined_refs.clone());
1886
1887 let mutated_refs_with_formals = message.get_mutated_refs();
1888 all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1889 match formal_to_actual_refs.get(ref_) {
1890 Some(actual_ref) => Some(*actual_ref),
1891 None => {
1892 if all_defined_refs.contains(ref_) {
1893 None
1894 } else {
1895 Some(*ref_)
1896 }
1897 }
1898 }
1899 }));
1900
1901 match message {
1902 StreamMessage::RecordingFormal {
1903 result: formal_ref,
1904 argument_index,
1905 } => match actuals.get(argument_index) {
1906 None => bail!("recording_formal called with too few arguments"),
1907 Some(actual_ref) => {
1908 formal_to_actual_refs.insert(formal_ref, *actual_ref);
1909 self.define_ref(formal_ref, *actual_ref)?;
1910 }
1911 },
1912 StreamMessage::RecordingResult {
1913 result: result_ref,
1914 output_index,
1915 } => match results.get(output_index) {
1916 None => bail!("recording_result called with too few results"),
1917 Some(actual_result_ref) => {
1918 self.define_ref(*actual_result_ref, result_ref)?;
1919 }
1920 },
1921 StreamMessage::DeleteRefs(ref refs) => {
1922 for ref_ in refs {
1923 all_defined_refs.remove(ref_);
1924 }
1925 StreamMessageHandler::handle(self, cx, message).await?;
1926 }
1927 StreamMessage::CallFunction { .. } if error.is_some() => {
1928 let error = error.clone().unwrap();
1934 for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1935 self.env.insert(*ref_, Err(error.clone()));
1936 }
1937 }
1938 StreamMessage::BorrowLastUse { ref result, .. } => {
1939 all_defined_refs.remove(result);
1940 StreamMessageHandler::handle(self, cx, message).await?;
1941 }
1942 StreamMessage::Reduce {
1943 local_tensor,
1944 ref out,
1945 ..
1946 } => {
1947 if error.is_none() {
1951 let inputs_to_check = [Some(local_tensor), out.clone()]
1952 .iter()
1953 .filter_map(|r| *r)
1954 .collect::<Vec<_>>();
1955 error = self.get_first_error(inputs_to_check.as_slice())?;
1956 }
1957 StreamMessageHandler::handle(self, cx, message).await?;
1958 }
1959 StreamMessage::SendTensor {
1960 ref tensor,
1961 ref to_rank,
1962 ..
1963 } => {
1964 if to_rank.is_some() && error.is_none() {
1969 error = self.get_first_error(&[*tensor])?;
1970 }
1971 StreamMessageHandler::handle(self, cx, message).await?;
1972 }
1973 _ => {
1974 StreamMessageHandler::handle(self, cx, message).await?;
1975 }
1976 };
1977
1978 match (&error, self.last_seq_error.take()) {
1986 (None, Some(seq_err)) => {
1987 self.controller_actor
1989 .remote_function_failed(
1990 cx,
1991 seq,
1992 WorkerError {
1993 backtrace: format!("recording failed: {}", &seq_err),
1994 worker_actor_id: cx.self_id().clone(),
1995 },
1996 )
1997 .await?;
1998 error = Some(seq_err)
1999 }
2000 _ => {}
2001 }
2002 }
2007
2008 StreamMessageHandler::handle(
2012 self,
2013 cx,
2014 StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
2015 )
2016 .await?;
2017
2018 if error.is_some() {
2021 for ref_ in results.iter().chain(all_mutated_refs.iter()) {
2022 self.env.insert(*ref_, Err(error.clone().unwrap()));
2023 }
2024 }
2025
2026 self.active_recording = None;
2027 Ok(())
2028 }
2029
2030 async fn set_ref_unit_tests_only(
2031 &mut self,
2032 _cx: &Context<Self>,
2033 reference: Ref,
2034 value: WireValue,
2035 ) -> Result<()> {
2036 self.env
2037 .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
2038 Ok(())
2039 }
2040
2041 async fn set_tensor_ref_unit_tests_only(
2042 &mut self,
2043 _cx: &Context<Self>,
2044 reference: Ref,
2045 tensor_result: TensorCellResult,
2046 ) -> Result<()> {
2047 match tensor_result {
2048 Ok(tensor_cell) => {
2049 self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
2050 }
2051 Err(err) => {
2052 self.env.insert(reference, Err(err));
2053 }
2054 }
2055 Ok(())
2056 }
2057
2058 async fn get_ref_unit_tests_only(
2059 &mut self,
2060 _cx: &Context<Self>,
2061 reference: Ref,
2062 ) -> Result<Option<Result<WireValue, String>>> {
2063 fn rvalue_to_wire(
2065 value: Result<RValue, Arc<SeqError>>,
2066 ) -> Result<WireValue, Arc<SeqError>> {
2067 Ok(match value? {
2068 RValue::Int(val) => WireValue::Int(val),
2069 RValue::IntList(val) => WireValue::IntList(val),
2070 RValue::Double(val) => WireValue::Double(val),
2071 RValue::Bool(val) => WireValue::Bool(val),
2072 RValue::String(val) => WireValue::String(val),
2073 RValue::Layout(val) => WireValue::Layout(val),
2074 RValue::Device(val) => WireValue::Device(val),
2075 RValue::ScalarType(val) => WireValue::ScalarType(val),
2076 RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2077 RValue::None => WireValue::None(()),
2078 other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2079 })
2080 }
2081 Ok(self
2082 .env
2083 .get(&reference)
2084 .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2085 }
2086
2087 async fn get_tensor_ref_unit_tests_only(
2088 &mut self,
2089 _cx: &Context<Self>,
2090 reference: Ref,
2091 ) -> Result<Option<TensorCellResult>> {
2092 match self.env.get(&reference) {
2093 Some(Ok(rvalue)) => match rvalue {
2094 RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2095 other => bail!("expected tensor, got {:?}", other),
2096 },
2097 Some(Err(err)) => Ok(Some(Err(err.clone()))),
2098 None => Ok(None),
2099 }
2100 }
2101}
2102
2103#[cfg(test)]
2104mod tests {
2105 use hyperactor::actor::ActorStatus;
2106 use hyperactor::context;
2107 use hyperactor::supervision::ActorSupervisionEvent;
2108 use monarch_messages::controller::ControllerMessage;
2109 use monarch_messages::worker::StreamCreationMode;
2110 use monarch_types::PickledPyObject;
2111 use pyo3::IntoPyObjectExt;
2112 use timed_test::async_timed_test;
2113 use torch_sys::factory_float_tensor;
2114 use torch_sys::testing::allclose;
2115 use torch_sys_cuda::nccl::UniqueId;
2116
2117 use super::*;
2118 use crate::comm::CommParams;
2119 use crate::test_util;
2120
2121 fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2122 Arc::new(SeqError {
2123 seq: 0.into(),
2124 error: err,
2125 })
2126 }
2127
2128 struct TestSetup {
2129 proc: Proc,
2130 stream_actor: ActorHandle<StreamActor>,
2131 client: Instance<()>,
2132 #[allow(dead_code)]
2135 supervision_rx: PortReceiver<ActorSupervisionEvent>,
2136 controller_rx: PortReceiver<ControllerMessage>,
2137 controller_actor: ActorRef<ControllerActor>,
2138 next_ref: Ref,
2139 }
2140
2141 impl TestSetup {
2142 async fn new() -> Result<Self> {
2143 Self::new_with_world_size(1).await
2144 }
2145
2146 async fn new_with_world_size(world_size: usize) -> Result<Self> {
2147 test_util::test_setup()?;
2148
2149 let proc = Proc::local();
2150 let (_, controller_actor, controller_rx) =
2151 proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2152 let (client, _handle) = proc.instance("client")?;
2153 let (supervision_tx, supervision_rx) = client.open_port();
2154 proc.set_supervision_coordinator(supervision_tx)?;
2155 let stream_actor = proc
2156 .spawn::<StreamActor>(
2157 "stream",
2158 StreamParams {
2159 world_size,
2160 rank: 0,
2161 creation_mode: StreamCreationMode::UseDefaultStream,
2162 id: 0.into(),
2163 device: Some(CudaDevice::new(0.into())),
2164 controller_actor: controller_actor.clone(),
2165 respond_with_python_message: false,
2166 },
2167 )
2168 .await?;
2169
2170 Ok(Self {
2171 proc,
2172 stream_actor,
2173 client,
2174 supervision_rx,
2175 controller_rx,
2176 controller_actor,
2177 next_ref: 0.into(),
2178 })
2179 }
2180
2181 fn next_ref(&mut self) -> Ref {
2182 let ref_ = self.next_ref;
2183 self.next_ref = Ref {
2184 id: self.next_ref.id + 1,
2185 };
2186 ref_
2187 }
2188
2189 async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2190 let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2191 self.stream_actor
2192 .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2193 .await
2194 }
2195
2196 async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2197 let actual = self
2198 .stream_actor
2199 .get_tensor_ref_unit_tests_only(&self.client, reference)
2200 .await
2201 .unwrap()
2202 .unwrap()
2203 .unwrap();
2204
2205 let result = allclose(
2206 &factory_float_tensor(data, "cpu".try_into().unwrap()),
2207 &actual.borrow(),
2208 )
2209 .unwrap();
2210 result
2212 }
2213
2214 async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2215 let result_error = self
2216 .stream_actor
2217 .get_tensor_ref_unit_tests_only(&self.client, reference)
2218 .await
2219 .unwrap()
2220 .unwrap()
2221 .unwrap_err();
2222
2223 assert!(Arc::ptr_eq(&result_error, &error));
2224 }
2225 }
2226
2227 async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2228 loop {
2229 let status = proc
2230 .ledger_snapshot()
2231 .roots
2232 .get(actor_id)
2233 .unwrap()
2234 .status
2235 .clone();
2236 if let ActorStatus::Failed(msg) = status {
2237 assert!(msg.contains(&expected_msg));
2238 break;
2239 } else {
2240 tokio::task::yield_now().await;
2241 }
2242 }
2243 }
2244
2245 async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2246 for ref_ in refs {
2247 assert!(
2248 test_setup
2249 .stream_actor
2250 .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2251 .await
2252 .unwrap()
2253 .is_none()
2254 );
2255 }
2256 }
2257
2258 async fn fetch_result(
2259 cx: &impl context::Actor,
2260 stream_actor: ActorHandle<StreamActor>,
2261 seq: Seq,
2262 reference: Ref,
2263 ) {
2264 let ref_to_send = Python::with_gil(|py| {
2265 PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2266 });
2267
2268 stream_actor
2269 .send_value(
2270 cx,
2271 seq,
2272 stream_actor.actor_id().clone(),
2273 Vec::new(),
2274 None,
2275 vec![WireValue::PyObject(ref_to_send)],
2276 HashMap::new(),
2277 HashMap::new(),
2278 None,
2279 )
2280 .await
2281 .unwrap()
2282 }
2283
2284 async fn check_fetch_result_error(
2285 cx: &impl context::Actor,
2286 stream_actor: ActorHandle<StreamActor>,
2287 seq: Seq,
2288 reference: Ref,
2289 controller_rx: &mut PortReceiver<ControllerMessage>,
2290 expected_backtrace: &str,
2291 ) {
2292 fetch_result(cx, stream_actor, seq, reference).await;
2293
2294 let controller_msg = controller_rx.recv().await.unwrap();
2295 match controller_msg {
2296 ControllerMessage::FetchResult {
2297 seq: actual_seq,
2298 value: Err(err),
2299 } => {
2300 assert_eq!(actual_seq, seq);
2301 assert!(
2302 err.backtrace.contains(expected_backtrace),
2303 "backtrace did not contain {:?}: {:?}",
2304 expected_backtrace,
2305 err.backtrace
2306 );
2307 }
2308 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2309 };
2310 }
2311
2312 async fn check_fetch_result_value(
2313 cx: &impl context::Actor,
2314 stream_actor: ActorHandle<StreamActor>,
2315 seq: Seq,
2316 reference: Ref,
2317 controller_rx: &mut PortReceiver<ControllerMessage>,
2318 ) {
2319 fetch_result(cx, stream_actor, seq, reference).await;
2320
2321 let controller_msg = controller_rx.recv().await.unwrap();
2322 match controller_msg {
2323 ControllerMessage::FetchResult {
2324 value: Ok(_),
2325 seq: actual_seq,
2326 } => assert_eq!(seq, actual_seq),
2327 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2328 };
2329 }
2330
2331 #[async_timed_test(timeout_secs = 60)]
2332 async fn test_define_recording_other_recording_active() -> Result<()> {
2333 let test_setup = TestSetup::new().await?;
2334 test_setup
2335 .stream_actor
2336 .define_recording(&test_setup.client, 0.into())
2337 .await?;
2338 test_setup
2339 .stream_actor
2340 .define_recording(&test_setup.client, 1.into())
2341 .await?;
2342 assert_actor_failed_with_msg(
2343 &test_setup.proc,
2344 test_setup.stream_actor.actor_id(),
2345 "different recording already active".into(),
2346 )
2347 .await;
2348 Ok(())
2349 }
2350
2351 #[async_timed_test(timeout_secs = 60)]
2352 async fn test_define_recording_already_defined() -> Result<()> {
2353 let test_setup = TestSetup::new().await?;
2354 test_setup
2355 .stream_actor
2356 .define_recording(&test_setup.client, 0.into())
2357 .await?;
2358 test_setup
2359 .stream_actor
2360 .finalize_recording(&test_setup.client, 0.into())
2361 .await?;
2362 test_setup
2363 .stream_actor
2364 .define_recording(&test_setup.client, 0.into())
2365 .await?;
2366 assert_actor_failed_with_msg(
2367 &test_setup.proc,
2368 test_setup.stream_actor.actor_id(),
2369 "already defined".into(),
2370 )
2371 .await;
2372 Ok(())
2373 }
2374
2375 #[async_timed_test(timeout_secs = 60)]
2376 async fn test_finalize_recording_other_recording_active() -> Result<()> {
2377 let test_setup = TestSetup::new().await?;
2378 test_setup
2379 .stream_actor
2380 .define_recording(&test_setup.client, 0.into())
2381 .await?;
2382 test_setup
2383 .stream_actor
2384 .finalize_recording(&test_setup.client, 1.into())
2385 .await?;
2386 assert_actor_failed_with_msg(
2387 &test_setup.proc,
2388 test_setup.stream_actor.actor_id(),
2389 "cannot finalize recording that isn't active".into(),
2390 )
2391 .await;
2392 Ok(())
2393 }
2394
2395 #[async_timed_test(timeout_secs = 60)]
2396 async fn test_recording_formal_outside_recording() -> Result<()> {
2397 let test_setup = TestSetup::new().await?;
2398 test_setup
2399 .stream_actor
2400 .recording_formal(&test_setup.client, 0.into(), 0)
2401 .await?;
2402 assert_actor_failed_with_msg(
2403 &test_setup.proc,
2404 test_setup.stream_actor.actor_id(),
2405 "recording_formal called outside of recording".into(),
2406 )
2407 .await;
2408 Ok(())
2409 }
2410
2411 #[async_timed_test(timeout_secs = 60)]
2412 async fn test_recording_result_outside_recording() -> Result<()> {
2413 let test_setup = TestSetup::new().await?;
2414 test_setup
2415 .stream_actor
2416 .recording_result(&test_setup.client, 0.into(), 0)
2417 .await?;
2418 assert_actor_failed_with_msg(
2419 &test_setup.proc,
2420 test_setup.stream_actor.actor_id(),
2421 "recording_result called outside of recording".into(),
2422 )
2423 .await;
2424 Ok(())
2425 }
2426
2427 #[async_timed_test(timeout_secs = 60)]
2428 async fn test_call_recording_other_recording_active() -> Result<()> {
2429 let test_setup = TestSetup::new().await?;
2430 test_setup
2431 .stream_actor
2432 .define_recording(&test_setup.client, 0.into())
2433 .await?;
2434 test_setup
2435 .stream_actor
2436 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2437 .await?;
2438 assert_actor_failed_with_msg(
2439 &test_setup.proc,
2440 test_setup.stream_actor.actor_id(),
2441 "cannot call recording while another recording is active".into(),
2442 )
2443 .await;
2444 Ok(())
2445 }
2446
2447 #[async_timed_test(timeout_secs = 60)]
2448 async fn test_call_recording_not_found() -> Result<()> {
2449 let test_setup = TestSetup::new().await?;
2450 test_setup
2451 .stream_actor
2452 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2453 .await?;
2454 assert_actor_failed_with_msg(
2455 &test_setup.proc,
2456 test_setup.stream_actor.actor_id(),
2457 "not found".into(),
2458 )
2459 .await;
2460 Ok(())
2461 }
2462
2463 #[async_timed_test(timeout_secs = 60)]
2464 async fn test_recording_formal_too_few_arguments() -> Result<()> {
2465 let test_setup = TestSetup::new().await?;
2466
2467 test_setup
2468 .stream_actor
2469 .define_recording(&test_setup.client, 0.into())
2470 .await?;
2471
2472 test_setup
2473 .stream_actor
2474 .recording_formal(&test_setup.client, 1.into(), 0)
2475 .await?;
2476
2477 test_setup
2478 .stream_actor
2479 .finalize_recording(&test_setup.client, 0.into())
2480 .await?;
2481
2482 test_setup
2483 .stream_actor
2484 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2485 .await?;
2486
2487 assert_actor_failed_with_msg(
2488 &test_setup.proc,
2489 test_setup.stream_actor.actor_id(),
2490 "recording_formal called with too few arguments".into(),
2491 )
2492 .await;
2493 Ok(())
2494 }
2495
2496 #[async_timed_test(timeout_secs = 60)]
2497 async fn test_recording_result_too_few_results() -> Result<()> {
2498 let test_setup = TestSetup::new().await?;
2499
2500 test_setup
2501 .stream_actor
2502 .define_recording(&test_setup.client, 0.into())
2503 .await?;
2504
2505 test_setup
2506 .stream_actor
2507 .recording_result(&test_setup.client, 1.into(), 0)
2508 .await?;
2509
2510 test_setup
2511 .stream_actor
2512 .finalize_recording(&test_setup.client, 0.into())
2513 .await?;
2514
2515 test_setup
2516 .stream_actor
2517 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2518 .await?;
2519
2520 assert_actor_failed_with_msg(
2521 &test_setup.proc,
2522 test_setup.stream_actor.actor_id(),
2523 "recording_result called with too few results".into(),
2524 )
2525 .await;
2526 Ok(())
2527 }
2528
2529 #[async_timed_test(timeout_secs = 60)]
2530 async fn test_basic_call_recording() -> Result<()> {
2531 let mut test_setup = TestSetup::new().await?;
2532
2533 test_setup
2537 .stream_actor
2538 .define_recording(&test_setup.client, 0.into())
2539 .await?;
2540
2541 let formal0_ref = 1.into();
2542 let formal0_index = 1;
2543 test_setup
2544 .stream_actor
2545 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2546 .await?;
2547
2548 let formal1_ref = 2.into();
2549 let formal1_index = 0;
2550 test_setup
2551 .stream_actor
2552 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2553 .await?;
2554
2555 let result0_ref = formal0_ref;
2556 let result0_index = 0;
2557 test_setup
2558 .stream_actor
2559 .recording_result(&test_setup.client, result0_ref, result0_index)
2560 .await?;
2561
2562 let result1_ref = formal1_ref;
2563 let result1_index = 1;
2564 test_setup
2565 .stream_actor
2566 .recording_result(&test_setup.client, result1_ref, result1_index)
2567 .await?;
2568
2569 test_setup
2570 .stream_actor
2571 .finalize_recording(&test_setup.client, 0.into())
2572 .await?;
2573
2574 let actual0_ref = 3.into();
2575 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2576
2577 let actual1_ref = 4.into();
2578 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2579
2580 let actual_result0_ref = 5.into();
2583 let actual_result1_ref = 6.into();
2584 test_setup
2585 .stream_actor
2586 .call_recording(
2587 &test_setup.client,
2588 0.into(),
2589 0.into(),
2590 vec![actual_result0_ref, actual_result1_ref],
2591 vec![actual0_ref, actual1_ref],
2592 )
2593 .await?;
2594
2595 assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2597 assert!(
2598 test_setup
2599 .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2600 .await
2601 );
2602
2603 assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2606 Ok(())
2607 }
2608
2609 #[async_timed_test(timeout_secs = 60)]
2610 async fn test_request_status_in_recording() -> Result<()> {
2611 let test_setup = TestSetup::new().await?;
2612 test_setup
2613 .stream_actor
2614 .define_recording(&test_setup.client, 0.into())
2615 .await?;
2616 test_setup
2617 .stream_actor
2618 .request_status(&test_setup.client)
2619 .await
2620 .expect_err("request_status should have failed");
2621 assert_actor_failed_with_msg(
2622 &test_setup.proc,
2623 test_setup.stream_actor.actor_id(),
2624 "request_status not allowed in recording".into(),
2625 )
2626 .await;
2627 Ok(())
2628 }
2629
2630 #[async_timed_test(timeout_secs = 60)]
2631 async fn test_init_comm_in_recording() -> Result<()> {
2632 let test_setup = TestSetup::new().await?;
2633 test_setup
2634 .stream_actor
2635 .define_recording(&test_setup.client, 0.into())
2636 .await?;
2637
2638 let dummy_comm = test_setup
2639 .proc
2640 .spawn::<NcclCommActor>(
2641 "comm",
2642 CommParams::New {
2643 device: CudaDevice::new(0.into()),
2644 unique_id: UniqueId::new()?,
2645 world_size: 1,
2646 rank: 0,
2647 },
2648 )
2649 .await?;
2650
2651 test_setup
2652 .stream_actor
2653 .init_comm(&test_setup.client, dummy_comm)
2654 .await?;
2655 assert_actor_failed_with_msg(
2656 &test_setup.proc,
2657 test_setup.stream_actor.actor_id(),
2658 "init_comm not allowed in recording".into(),
2659 )
2660 .await;
2661 Ok(())
2662 }
2663
2664 #[async_timed_test(timeout_secs = 60)]
2665 async fn test_call_function_in_recording() -> Result<()> {
2666 let mut test_setup = TestSetup::new().await?;
2667
2668 test_setup
2675 .stream_actor
2676 .define_recording(&test_setup.client, 0.into())
2677 .await?;
2678
2679 let formal0_ref = test_setup.next_ref();
2680 let formal0_index = 0;
2681 test_setup
2682 .stream_actor
2683 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2684 .await?;
2685
2686 let formal1_ref = test_setup.next_ref();
2687 let formal1_index = 1;
2688 test_setup
2689 .stream_actor
2690 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2691 .await?;
2692
2693 let captured_ref = test_setup.next_ref();
2694 let result_captured_ref = test_setup.next_ref();
2695 let add_one_function =
2696 ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2697 let add_tensors_function =
2698 ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2699
2700 let add_result_ref_0 = test_setup.next_ref();
2701 test_setup
2702 .stream_actor
2703 .call_function(
2704 &test_setup.client,
2705 CallFunctionParams {
2706 seq: 100.into(),
2707 function: add_tensors_function.clone(),
2708 args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2709 kwargs: HashMap::new(),
2710 results: vec![Some(add_result_ref_0)],
2711 mutates: vec![],
2712 stream: 0.into(),
2713 remote_process_groups: Vec::new(),
2714 },
2715 HashMap::new(),
2716 HashMap::new(),
2717 )
2718 .await?;
2719
2720 test_setup
2721 .stream_actor
2722 .call_function(
2723 &test_setup.client,
2724 CallFunctionParams {
2725 seq: 101.into(),
2726 function: add_one_function,
2727 args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2728 kwargs: HashMap::new(),
2729 results: vec![Some(result_captured_ref)],
2730 mutates: vec![captured_ref],
2731 stream: 0.into(),
2732 remote_process_groups: Vec::new(),
2733 },
2734 HashMap::new(),
2735 HashMap::new(),
2736 )
2737 .await?;
2738
2739 let add_result_ref_1 = test_setup.next_ref();
2740 test_setup
2741 .stream_actor
2742 .call_function(
2743 &test_setup.client,
2744 CallFunctionParams {
2745 seq: 102.into(),
2746 function: add_tensors_function,
2747 args: vec![
2748 WireValue::Ref(add_result_ref_0),
2749 WireValue::Ref(captured_ref),
2750 ],
2751 kwargs: HashMap::new(),
2752 results: vec![Some(add_result_ref_1)],
2753 mutates: vec![],
2754 stream: 0.into(),
2755 remote_process_groups: Vec::new(),
2756 },
2757 HashMap::new(),
2758 HashMap::new(),
2759 )
2760 .await?;
2761
2762 test_setup
2763 .stream_actor
2764 .recording_result(&test_setup.client, add_result_ref_1, 0)
2765 .await?;
2766
2767 test_setup
2768 .stream_actor
2769 .delete_refs(
2770 &test_setup.client,
2771 vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2772 )
2773 .await?;
2774
2775 test_setup
2776 .stream_actor
2777 .finalize_recording(&test_setup.client, 0.into())
2778 .await?;
2779
2780 let actual0_ref = test_setup.next_ref();
2781 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2782
2783 let actual1_ref = test_setup.next_ref();
2784 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2785
2786 test_setup
2787 .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2788 .await?;
2789
2790 let actual_result_ref = test_setup.next_ref();
2791 test_setup
2792 .stream_actor
2793 .call_recording(
2794 &test_setup.client,
2795 0.into(),
2796 0.into(),
2797 vec![actual_result_ref],
2798 vec![actual0_ref, actual1_ref],
2799 )
2800 .await?;
2801
2802 assert!(
2803 test_setup
2804 .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2805 .await
2806 );
2807
2808 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2810
2811 let actual_result_ref = test_setup.next_ref();
2812 test_setup
2813 .stream_actor
2814 .call_recording(
2815 &test_setup.client,
2816 1.into(),
2817 0.into(),
2818 vec![actual_result_ref],
2819 vec![actual0_ref, actual1_ref],
2820 )
2821 .await?;
2822
2823 for ref_ in [actual0_ref, actual1_ref] {
2825 let _ = test_setup
2826 .stream_actor
2827 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2828 .await?
2829 .unwrap()
2830 .unwrap();
2831 }
2832
2833 for ref_ in [captured_ref, actual_result_ref] {
2834 let result_error = test_setup
2835 .stream_actor
2836 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2837 .await?
2838 .unwrap()
2839 .unwrap_err();
2840 let error_str = result_error.to_string();
2842 assert!(
2843 error_str.contains("torch operator error"),
2844 "Error should contain 'torch operator failed': {}",
2845 error_str
2846 );
2847 }
2848
2849 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2850 match controller_msg {
2851 ControllerMessage::RemoteFunctionFailed { seq, error } => {
2852 assert_eq!(seq, 1.into());
2853 assert!(
2854 error.backtrace.contains("torch operator error"),
2855 "Unexpected WorkerError: {:?}",
2856 error
2857 );
2858 }
2859 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2860 };
2861
2862 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2864
2865 let actual_result_ref = test_setup.next_ref();
2869 test_setup
2870 .stream_actor
2871 .call_recording(
2872 &test_setup.client,
2873 2.into(),
2874 0.into(),
2875 vec![actual_result_ref],
2876 vec![actual0_ref, actual1_ref],
2877 )
2878 .await?;
2879
2880 for ref_ in [actual0_ref, actual1_ref] {
2882 let _ = test_setup
2883 .stream_actor
2884 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2885 .await?
2886 .unwrap()
2887 .unwrap();
2888 }
2889
2890 for ref_ in [captured_ref, actual_result_ref] {
2891 let result_error = test_setup
2892 .stream_actor
2893 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2894 .await?
2895 .unwrap()
2896 .unwrap_err();
2897 let error_str = result_error.to_string();
2899 assert!(
2900 error_str.contains("torch operator error"),
2901 "Error should contain input error: {}",
2902 error_str
2903 );
2904 }
2905
2906 check_fetch_result_error(
2910 &test_setup.client,
2911 test_setup.stream_actor.clone(),
2912 3.into(),
2913 captured_ref,
2914 &mut test_setup.controller_rx,
2915 "torch operator error",
2916 )
2917 .await;
2918
2919 Ok(())
2920 }
2921
2922 #[async_timed_test(timeout_secs = 60)]
2923 async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2924 let mut test_setup = TestSetup::new().await?;
2925 test_setup
2926 .stream_actor
2927 .define_recording(&test_setup.client, 0.into())
2928 .await?;
2929
2930 let borrow_id = 1;
2931 let tensor_ref = test_setup.next_ref();
2932 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2933
2934 test_setup
2935 .stream_actor
2936 .borrow_create(
2937 &test_setup.client,
2938 borrow_id,
2939 tensor_ref,
2940 first_use_sender.clone(),
2941 )
2942 .await?;
2943
2944 test_setup
2945 .stream_actor
2946 .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2947 .await?;
2948
2949 assert_actor_failed_with_msg(
2950 &test_setup.proc,
2951 test_setup.stream_actor.actor_id(),
2952 "duplicate borrow create in recording".into(),
2953 )
2954 .await;
2955
2956 Ok(())
2957 }
2958
2959 #[async_timed_test(timeout_secs = 60)]
2960 async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2961 let test_setup = TestSetup::new().await?;
2962 test_setup
2963 .stream_actor
2964 .define_recording(&test_setup.client, 0.into())
2965 .await?;
2966
2967 let borrow_id = 1;
2968 let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2969
2970 test_setup
2971 .stream_actor
2972 .borrow_drop(
2973 &test_setup.client,
2974 borrow_id,
2975 Arc::new(Mutex::new(last_use_receiver)),
2976 )
2977 .await?;
2978
2979 assert_actor_failed_with_msg(
2980 &test_setup.proc,
2981 test_setup.stream_actor.actor_id(),
2982 "borrow drop for borrow not defined in recording".into(),
2983 )
2984 .await;
2985
2986 Ok(())
2987 }
2988
2989 #[async_timed_test(timeout_secs = 60)]
2990 async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2991 let mut test_setup = TestSetup::new().await?;
2992 test_setup
2993 .stream_actor
2994 .define_recording(&test_setup.client, 0.into())
2995 .await?;
2996
2997 let borrow_id = 1;
2998 let tensor_ref = test_setup.next_ref();
2999 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
3000
3001 test_setup
3002 .stream_actor
3003 .borrow_create(
3004 &test_setup.client,
3005 borrow_id,
3006 tensor_ref,
3007 first_use_sender.clone(),
3008 )
3009 .await?;
3010
3011 test_setup
3013 .stream_actor
3014 .finalize_recording(&test_setup.client, 0.into())
3015 .await?;
3016
3017 assert_actor_failed_with_msg(
3018 &test_setup.proc,
3019 test_setup.stream_actor.actor_id(),
3020 "all borrows created within recording must be dropped within recording".into(),
3021 )
3022 .await;
3023
3024 Ok(())
3025 }
3026
3027 #[async_timed_test(timeout_secs = 60)]
3028 async fn test_borrow_in_recording() -> Result<()> {
3029 let mut test_setup = TestSetup::new().await?;
3030
3031 let borrower_stream = test_setup
3032 .proc
3033 .spawn::<StreamActor>(
3034 "stream1",
3035 StreamParams {
3036 world_size: 1,
3037 rank: 0,
3038 creation_mode: StreamCreationMode::CreateNewStream,
3039 id: 1.into(),
3040 device: Some(CudaDevice::new(0.into())),
3041 controller_actor: test_setup.controller_actor.clone(),
3042 respond_with_python_message: false,
3043 },
3044 )
3045 .await?;
3046
3047 let lender_stream = test_setup.stream_actor.clone();
3048
3049 let borrow_id = 1;
3050 let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
3051 let (last_use_sender, last_use_receiver) = test_setup.client.open_port();
3052
3053 lender_stream
3055 .define_recording(&test_setup.client, 0.into())
3056 .await?;
3057
3058 let formal_ref = test_setup.next_ref();
3059 lender_stream
3060 .recording_formal(&test_setup.client, formal_ref, 0)
3061 .await?;
3062
3063 lender_stream
3064 .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
3065 .await?;
3066
3067 lender_stream
3068 .borrow_drop(
3069 &test_setup.client,
3070 borrow_id,
3071 Arc::new(Mutex::new(last_use_receiver)),
3072 )
3073 .await?;
3074
3075 lender_stream
3076 .finalize_recording(&test_setup.client, 0.into())
3077 .await?;
3078
3079 let borrower_tensor_ref = test_setup.next_ref();
3080 let borrower_tensor = TensorCell::new(factory_float_tensor(
3081 &[1.0, 2.0, 3.0],
3082 "cuda".try_into().unwrap(),
3083 ));
3084
3085 borrower_stream
3086 .set_tensor_ref_unit_tests_only(
3087 &test_setup.client,
3088 borrower_tensor_ref,
3089 Ok(borrower_tensor.clone()),
3090 )
3091 .await?;
3092
3093 borrower_stream
3095 .define_recording(&test_setup.client, 0.into())
3096 .await?;
3097
3098 let borrowed_ref = test_setup.next_ref();
3099
3100 borrower_stream
3101 .borrow_first_use(
3102 &test_setup.client,
3103 borrow_id,
3104 borrowed_ref,
3105 Arc::new(Mutex::new(first_use_receiver)),
3106 )
3107 .await?;
3108
3109 let result_ref = test_setup.next_ref();
3110 borrower_stream
3111 .call_function(
3112 &test_setup.client,
3113 CallFunctionParams {
3114 seq: 100.into(),
3115 function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
3116 args: vec![
3117 WireValue::Ref(borrowed_ref),
3118 WireValue::Ref(borrower_tensor_ref),
3119 ],
3120 kwargs: HashMap::new(),
3121 results: vec![Some(result_ref)],
3122 mutates: vec![],
3123 stream: 1.into(),
3124 remote_process_groups: Vec::new(),
3125 },
3126 HashMap::new(),
3127 HashMap::new(),
3128 )
3129 .await?;
3130
3131 borrower_stream
3132 .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
3133 .await?;
3134
3135 borrower_stream
3136 .recording_result(&test_setup.client, result_ref, 0)
3137 .await?;
3138
3139 borrower_stream
3140 .finalize_recording(&test_setup.client, 0.into())
3141 .await?;
3142
3143 let input_tensor_ref = test_setup.next_ref();
3145 test_setup
3146 .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
3147 .await?;
3148
3149 let result_tensor_ref = test_setup.next_ref();
3150
3151 let lender_future = lender_stream.call_recording(
3152 &test_setup.client,
3153 0.into(),
3154 0.into(),
3155 vec![],
3156 vec![input_tensor_ref],
3157 );
3158
3159 let borrower_future = borrower_stream.call_recording(
3160 &test_setup.client,
3161 0.into(),
3162 0.into(),
3163 vec![result_tensor_ref],
3164 vec![],
3165 );
3166
3167 tokio::try_join!(lender_future, borrower_future)?;
3168
3169 let result_tensor = borrower_stream
3170 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3171 .await?
3172 .unwrap()
3173 .unwrap();
3174
3175 let expected_tensor = TensorCell::new(factory_float_tensor(
3176 &[5.0, 7.0, 9.0],
3177 "cpu".try_into().unwrap(),
3178 ));
3179 assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());
3180
3181 let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
3183 &[1.0, 2.0],
3184 "cuda".try_into().unwrap(),
3185 ));
3186 borrower_stream
3187 .set_tensor_ref_unit_tests_only(
3188 &test_setup.client,
3189 borrower_tensor_ref,
3190 Ok(invalid_borrower_tensor.clone()),
3191 )
3192 .await?;
3193
3194 let lender_future = lender_stream.call_recording(
3196 &test_setup.client,
3197 1.into(),
3198 0.into(),
3199 vec![],
3200 vec![input_tensor_ref],
3201 );
3202
3203 let borrower_future = borrower_stream.call_recording(
3204 &test_setup.client,
3205 1.into(),
3206 0.into(),
3207 vec![result_tensor_ref],
3208 vec![],
3209 );
3210
3211 tokio::try_join!(lender_future, borrower_future)?;
3212
3213 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
3215 match controller_msg {
3216 ControllerMessage::RemoteFunctionFailed { seq, error } => {
3217 assert_eq!(seq, 1.into());
3218 assert!(
3219 error.backtrace.contains("recording failed"),
3220 "Unexpected WorkerError: {:?}",
3221 error
3222 );
3223 assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
3224 }
3225 _ => panic!("Unexpected controller message: {:?}", controller_msg),
3226 };
3227
3228 check_fetch_result_value(
3230 &test_setup.client,
3231 lender_stream.clone(),
3232 2.into(),
3233 input_tensor_ref,
3234 &mut test_setup.controller_rx,
3235 )
3236 .await;
3237
3238 let input_error = fake_seq_error(anyhow!("input error"));
3240 lender_stream
3241 .set_tensor_ref_unit_tests_only(
3242 &test_setup.client,
3243 input_tensor_ref,
3244 Err(input_error.clone()),
3245 )
3246 .await?;
3247
3248 let lender_future = lender_stream.call_recording(
3249 &test_setup.client,
3250 3.into(),
3251 0.into(),
3252 vec![],
3253 vec![input_tensor_ref],
3254 );
3255
3256 let borrower_future = borrower_stream.call_recording(
3257 &test_setup.client,
3258 3.into(),
3259 0.into(),
3260 vec![result_tensor_ref],
3261 vec![],
3262 );
3263
3264 tokio::try_join!(lender_future, borrower_future)?;
3265
3266 let result_error = borrower_stream
3268 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3269 .await?
3270 .unwrap()
3271 .unwrap_err();
3272
3273 let error_str = result_error.to_string();
3275 assert!(
3276 error_str.contains("input error"),
3277 "Error should contain input error: {}",
3278 error_str
3279 );
3280
3281 let input_error_str = input_error.to_string();
3284 assert!(
3285 error_str.contains(&input_error_str),
3286 "Error should contain the original error: {}",
3287 error_str
3288 );
3289
3290 check_fetch_result_error(
3292 &test_setup.client,
3293 lender_stream,
3294 4.into(),
3295 input_tensor_ref,
3296 &mut test_setup.controller_rx,
3297 "input error",
3298 )
3299 .await;
3300
3301 check_fetch_result_error(
3303 &test_setup.client,
3304 borrower_stream,
3305 5.into(),
3306 result_tensor_ref,
3307 &mut test_setup.controller_rx,
3308 "input error",
3309 )
3310 .await;
3311
3312 Ok(())
3313 }
3314
3315 #[async_timed_test(timeout_secs = 60)]
3316 async fn test_reduce_in_recording() -> Result<()> {
3317 let mut test_setup = TestSetup::new().await?;
3318 let recording_ref = test_setup.next_ref();
3319
3320 let comm = Arc::new(
3321 test_setup
3322 .proc
3323 .spawn::<NcclCommActor>(
3324 "comm",
3325 CommParams::New {
3326 device: CudaDevice::new(0.into()),
3327 unique_id: UniqueId::new()?,
3328 world_size: 1,
3329 rank: 0,
3330 },
3331 )
3332 .await?,
3333 );
3334
3335 let factory = Factory {
3336 size: vec![3],
3337 dtype: torch_sys::ScalarType::Float,
3338 layout: torch_sys::Layout::Strided,
3339 device: "cuda".try_into().unwrap(),
3340 };
3341
3342 let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3343
3344 test_setup
3345 .stream_actor
3346 .define_recording(&test_setup.client, recording_ref)
3347 .await?;
3348
3349 let formal_tensor_ref_0 = test_setup.next_ref();
3350 let formal_tensor_ref_1 = test_setup.next_ref();
3351 let formal_tensor_ref_2 = test_setup.next_ref();
3352
3353 test_setup
3354 .stream_actor
3355 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3356 .await?;
3357 test_setup
3358 .stream_actor
3359 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3360 .await?;
3361 test_setup
3362 .stream_actor
3363 .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3364 .await?;
3365
3366 let intermediate_tensor_ref_0 = test_setup.next_ref();
3367
3368 test_setup
3370 .stream_actor
3371 .reduce(
3372 &test_setup.client,
3373 comm.clone(),
3374 1,
3375 intermediate_tensor_ref_0,
3376 formal_tensor_ref_0,
3377 factory.clone(),
3378 reduction.clone(),
3379 false,
3380 true,
3381 None,
3382 )
3383 .await?;
3384
3385 let intermediate_tensor_ref_1 = test_setup.next_ref();
3387 test_setup
3388 .stream_actor
3389 .reduce(
3390 &test_setup.client,
3391 comm.clone(),
3392 1,
3393 intermediate_tensor_ref_1,
3394 formal_tensor_ref_1,
3395 factory.clone(),
3396 reduction.clone(),
3397 false,
3398 false,
3399 None,
3400 )
3401 .await?;
3402
3403 let intermediate_tensor_ref_2 = test_setup.next_ref();
3404
3405 test_setup
3407 .stream_actor
3408 .reduce(
3409 &test_setup.client,
3410 comm.clone(),
3411 1,
3412 intermediate_tensor_ref_2,
3413 intermediate_tensor_ref_1,
3414 factory.clone(),
3415 reduction.clone(),
3416 false,
3417 false,
3418 Some(formal_tensor_ref_2),
3419 )
3420 .await?;
3421
3422 test_setup
3423 .stream_actor
3424 .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3425 .await?;
3426
3427 test_setup
3428 .stream_actor
3429 .finalize_recording(&test_setup.client, recording_ref)
3430 .await?;
3431
3432 let input_tensor_ref_0 = test_setup.next_ref();
3433 let input_tensor_ref_1 = test_setup.next_ref();
3434 let input_tensor_ref_2 = test_setup.next_ref();
3435
3436 test_setup
3437 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3438 .await?;
3439
3440 test_setup
3441 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3442 .await?;
3443
3444 test_setup
3445 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3446 .await?;
3447
3448 let output_ref = test_setup.next_ref();
3449
3450 test_setup
3451 .stream_actor
3452 .call_recording(
3453 &test_setup.client,
3454 0.into(),
3455 recording_ref,
3456 vec![output_ref],
3457 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3458 )
3459 .await?;
3460
3461 assert!(
3463 test_setup
3464 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3465 .await
3466 );
3467 for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3469 assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3470 }
3471
3472 let input_error = fake_seq_error(anyhow!("input error"));
3474 test_setup
3475 .stream_actor
3476 .set_tensor_ref_unit_tests_only(
3477 &test_setup.client,
3478 input_tensor_ref_0,
3479 Err(input_error.clone()),
3480 )
3481 .await?;
3482
3483 test_setup
3484 .stream_actor
3485 .call_recording(
3486 &test_setup.client,
3487 1.into(),
3488 recording_ref,
3489 vec![output_ref],
3490 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3491 )
3492 .await?;
3493
3494 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3496 test_setup
3497 .validate_dependent_error(ref_, input_error.clone())
3498 .await;
3499 }
3500
3501 assert!(
3503 test_setup
3504 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3505 .await
3506 );
3507
3508 check_fetch_result_value(
3510 &test_setup.client,
3511 test_setup.stream_actor.clone(),
3512 2.into(),
3513 input_tensor_ref_1,
3514 &mut test_setup.controller_rx,
3515 )
3516 .await;
3517
3518 test_setup
3520 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3521 .await?;
3522 test_setup
3523 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3524 .await?;
3525
3526 test_setup
3528 .stream_actor
3529 .set_tensor_ref_unit_tests_only(
3530 &test_setup.client,
3531 input_tensor_ref_1,
3532 Err(input_error.clone()),
3533 )
3534 .await?;
3535
3536 test_setup
3537 .stream_actor
3538 .call_recording(
3539 &test_setup.client,
3540 3.into(),
3541 recording_ref,
3542 vec![output_ref],
3543 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3544 )
3545 .await?;
3546
3547 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3550 test_setup
3551 .validate_dependent_error(ref_, input_error.clone())
3552 .await;
3553 }
3554
3555 check_fetch_result_error(
3557 &test_setup.client,
3558 test_setup.stream_actor.clone(),
3559 4.into(),
3560 input_tensor_ref_1,
3561 &mut test_setup.controller_rx,
3562 "input error",
3563 )
3564 .await;
3565
3566 test_setup
3568 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3569 .await?;
3570 test_setup
3571 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3572 .await?;
3573
3574 test_setup
3576 .stream_actor
3577 .set_tensor_ref_unit_tests_only(
3578 &test_setup.client,
3579 input_tensor_ref_2,
3580 Err(input_error.clone()),
3581 )
3582 .await?;
3583
3584 test_setup
3585 .stream_actor
3586 .call_recording(
3587 &test_setup.client,
3588 5.into(),
3589 recording_ref,
3590 vec![output_ref],
3591 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3592 )
3593 .await?;
3594
3595 assert!(
3597 test_setup
3598 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3599 .await
3600 );
3601
3602 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3605 test_setup
3606 .validate_dependent_error(ref_, input_error.clone())
3607 .await;
3608 }
3609
3610 check_fetch_result_value(
3612 &test_setup.client,
3613 test_setup.stream_actor.clone(),
3614 6.into(),
3615 input_tensor_ref_1,
3616 &mut test_setup.controller_rx,
3617 )
3618 .await;
3619
3620 Ok(())
3621 }
3622
3623 #[async_timed_test(timeout_secs = 60)]
3624 async fn test_send_tensor_in_recording() -> Result<()> {
3625 let mut test_setup = TestSetup::new_with_world_size(2).await?;
3626 let recording_ref = test_setup.next_ref();
3627
3628 let unique_id = UniqueId::new()?;
3629 let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3630 "comm0",
3631 CommParams::New {
3632 device: CudaDevice::new(0.into()),
3633 unique_id: unique_id.clone(),
3634 world_size: 2,
3635 rank: 0,
3636 },
3637 );
3638 let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3639 "comm1",
3640 CommParams::New {
3641 device: CudaDevice::new(1.into()),
3642 unique_id,
3643 world_size: 2,
3644 rank: 1,
3645 },
3646 );
3647 let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3648 let comm0 = Arc::new(comm0);
3649 let comm1 = Arc::new(comm1);
3650
3651 let factory = Factory {
3652 size: vec![3],
3653 dtype: torch_sys::ScalarType::Float,
3654 layout: torch_sys::Layout::Strided,
3655 device: "cuda".try_into().unwrap(),
3656 };
3657
3658 let send_stream = test_setup.stream_actor.clone();
3659 let recv_stream = test_setup
3660 .proc
3661 .spawn::<StreamActor>(
3662 "recv_stream",
3663 StreamParams {
3664 world_size: 2,
3665 rank: 1,
3666 creation_mode: StreamCreationMode::CreateNewStream,
3667 id: 1.into(),
3668 device: Some(CudaDevice::new(1.into())),
3669 controller_actor: test_setup.controller_actor.clone(),
3670 respond_with_python_message: false,
3671 },
3672 )
3673 .await?;
3674
3675 send_stream
3676 .define_recording(&test_setup.client, recording_ref)
3677 .await?;
3678 recv_stream
3679 .define_recording(&test_setup.client, recording_ref)
3680 .await?;
3681
3682 let formal_tensor_ref_0 = test_setup.next_ref();
3683 let formal_tensor_ref_1 = test_setup.next_ref();
3684
3685 send_stream
3686 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3687 .await?;
3688 send_stream
3689 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3690 .await?;
3691
3692 let _ref = test_setup.next_ref();
3693 send_stream
3694 .send_tensor(
3695 &test_setup.client,
3696 _ref,
3697 None,
3698 Some(1),
3699 formal_tensor_ref_0,
3700 factory.clone(),
3701 comm0.clone(),
3702 )
3703 .await?;
3704
3705 let result_ref_0 = test_setup.next_ref();
3706 let _ref = test_setup.next_ref();
3707 recv_stream
3708 .send_tensor(
3709 &test_setup.client,
3710 result_ref_0,
3711 Some(0),
3712 None,
3713 _ref,
3714 factory.clone(),
3715 comm1,
3716 )
3717 .await?;
3718
3719 let result_ref_1 = test_setup.next_ref();
3720 send_stream
3721 .send_tensor(
3722 &test_setup.client,
3723 result_ref_1,
3724 Some(0),
3725 Some(0),
3726 formal_tensor_ref_1,
3727 factory.clone(),
3728 comm0,
3729 )
3730 .await?;
3731
3732 send_stream
3733 .recording_result(&test_setup.client, result_ref_1, 0)
3734 .await?;
3735 recv_stream
3736 .recording_result(&test_setup.client, result_ref_0, 0)
3737 .await?;
3738
3739 send_stream
3740 .finalize_recording(&test_setup.client, recording_ref)
3741 .await?;
3742 recv_stream
3743 .finalize_recording(&test_setup.client, recording_ref)
3744 .await?;
3745
3746 let input_tensor_ref_0 = test_setup.next_ref();
3747 let input_tensor_ref_1 = test_setup.next_ref();
3748 test_setup
3749 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3750 .await?;
3751 test_setup
3752 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3753 .await?;
3754
3755 let actual_result_ref_0 = test_setup.next_ref();
3756 let actual_result_ref_1 = test_setup.next_ref();
3757 let send_fut = send_stream.call_recording(
3758 &test_setup.client,
3759 0.into(),
3760 recording_ref,
3761 vec![actual_result_ref_1],
3762 vec![input_tensor_ref_0, input_tensor_ref_1],
3763 );
3764 let recv_fut = recv_stream.call_recording(
3765 &test_setup.client,
3766 0.into(),
3767 recording_ref,
3768 vec![actual_result_ref_0],
3769 vec![],
3770 );
3771 tokio::try_join!(send_fut, recv_fut)?;
3772
3773 assert!(
3774 test_setup
3775 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3776 .await
3777 );
3778 assert!(
3779 test_setup
3780 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3781 .await
3782 );
3783 assert!(
3784 test_setup
3785 .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3786 .await
3787 );
3788
3789 let actual_result_0 = recv_stream
3790 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3791 .await
3792 .unwrap()
3793 .unwrap()
3794 .unwrap();
3795 assert!(allclose(
3796 &actual_result_0.borrow(),
3797 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3798 )?);
3799
3800 check_fetch_result_value(
3802 &test_setup.client,
3803 send_stream.clone(),
3804 1.into(),
3805 actual_result_ref_1,
3806 &mut test_setup.controller_rx,
3807 )
3808 .await;
3809 check_fetch_result_value(
3810 &test_setup.client,
3811 recv_stream.clone(),
3812 2.into(),
3813 actual_result_ref_0,
3814 &mut test_setup.controller_rx,
3815 )
3816 .await;
3817
3818 let input_error = fake_seq_error(anyhow!("input error"));
3819 send_stream
3820 .set_tensor_ref_unit_tests_only(
3821 &test_setup.client,
3822 input_tensor_ref_0,
3823 Err(input_error.clone()),
3824 )
3825 .await?;
3826
3827 let send_fut = send_stream.call_recording(
3828 &test_setup.client,
3829 3.into(),
3830 recording_ref,
3831 vec![actual_result_ref_1],
3832 vec![input_tensor_ref_0, input_tensor_ref_1],
3833 );
3834 let recv_fut = recv_stream.call_recording(
3835 &test_setup.client,
3836 3.into(),
3837 recording_ref,
3838 vec![actual_result_ref_0],
3839 vec![],
3840 );
3841 tokio::try_join!(send_fut, recv_fut)?;
3842
3843 let _ = recv_stream
3845 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3846 .await
3847 .unwrap()
3848 .unwrap()
3849 .unwrap();
3850
3851 test_setup
3852 .validate_dependent_error(actual_result_ref_1, input_error.clone())
3853 .await;
3854
3855 assert!(
3857 test_setup
3858 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3859 .await
3860 );
3861
3862 check_fetch_result_error(
3864 &test_setup.client,
3865 send_stream.clone(),
3866 4.into(),
3867 actual_result_ref_1,
3868 &mut test_setup.controller_rx,
3869 "input error",
3870 )
3871 .await;
3872 check_fetch_result_value(
3873 &test_setup.client,
3874 recv_stream.clone(),
3875 5.into(),
3876 actual_result_ref_0,
3877 &mut test_setup.controller_rx,
3878 )
3879 .await;
3880
3881 test_setup
3882 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3883 .await?;
3884 send_stream
3885 .set_tensor_ref_unit_tests_only(
3886 &test_setup.client,
3887 input_tensor_ref_1,
3888 Err(input_error.clone()),
3889 )
3890 .await?;
3891
3892 let send_fut = send_stream.call_recording(
3893 &test_setup.client,
3894 6.into(),
3895 recording_ref,
3896 vec![actual_result_ref_1],
3897 vec![input_tensor_ref_0, input_tensor_ref_1],
3898 );
3899 let recv_fut = recv_stream.call_recording(
3900 &test_setup.client,
3901 6.into(),
3902 recording_ref,
3903 vec![actual_result_ref_0],
3904 vec![],
3905 );
3906 tokio::try_join!(send_fut, recv_fut)?;
3907
3908 let actual_result_0 = recv_stream
3909 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3910 .await
3911 .unwrap()
3912 .unwrap()
3913 .unwrap();
3914 assert!(allclose(
3915 &actual_result_0.borrow(),
3916 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3917 )?);
3918
3919 assert!(
3920 test_setup
3921 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3922 .await
3923 );
3924
3925 test_setup
3926 .validate_dependent_error(actual_result_ref_1, input_error)
3927 .await;
3928
3929 check_fetch_result_error(
3931 &test_setup.client,
3932 send_stream.clone(),
3933 7.into(),
3934 actual_result_ref_1,
3935 &mut test_setup.controller_rx,
3936 "input error",
3937 )
3938 .await;
3939 check_fetch_result_value(
3940 &test_setup.client,
3941 recv_stream.clone(),
3942 8.into(),
3943 actual_result_ref_0,
3944 &mut test_setup.controller_rx,
3945 )
3946 .await;
3947
3948 Ok(())
3949 }
3950
3951 #[async_timed_test(timeout_secs = 60)]
3952 async fn test_set_value_in_recording_valid_pipe() -> Result<()> {
3953 let mut test_setup = TestSetup::new().await?;
3954
3955 let (pipe_tx, mut pipe_rx) = test_setup.client.open_port();
3956
3957 let recording_ref = test_setup.next_ref();
3958 test_setup
3959 .stream_actor
3960 .define_recording(&test_setup.client, recording_ref)
3961 .await?;
3962
3963 let result_ref_0 = test_setup.next_ref();
3964
3965 test_setup
3966 .stream_actor
3967 .set_value(
3968 &test_setup.client,
3969 0.into(),
3970 vec![Some(result_ref_0)],
3971 pipe_tx,
3972 )
3973 .await?;
3974
3975 test_setup
3976 .stream_actor
3977 .recording_result(&test_setup.client, result_ref_0, 0)
3978 .await?;
3979
3980 test_setup
3981 .stream_actor
3982 .finalize_recording(&test_setup.client, recording_ref)
3983 .await?;
3984
3985 let real_result_ref = test_setup.next_ref();
3986 let recording_fut = test_setup.stream_actor.call_recording(
3987 &test_setup.client,
3988 0.into(),
3989 recording_ref,
3990 vec![real_result_ref],
3991 vec![],
3992 );
3993
3994 let pipe_fut = async {
3995 let msg = pipe_rx.recv().await.unwrap();
3996 match msg {
3997 PipeMessage::RecvValue(tx) => {
3998 tx.send(PyTree::from(RValue::Tensor(TensorCell::new(
3999 factory_float_tensor(&[1.0, 2.0, 3.0], "cuda".try_into().unwrap()),
4000 ))))
4001 .unwrap();
4002 }
4003 _ => panic!("Unexpected message"),
4004 }
4005 Ok(())
4006 };
4007
4008 tokio::try_join!(recording_fut, pipe_fut)?;
4009
4010 assert!(test_setup.allclose(real_result_ref, &[1.0, 2.0, 3.0]).await);
4011
4012 drop(pipe_rx);
4014
4015 let real_result_ref = test_setup.next_ref();
4016 test_setup
4017 .stream_actor
4018 .call_recording(
4019 &test_setup.client,
4020 1.into(),
4021 recording_ref,
4022 vec![real_result_ref],
4023 vec![],
4024 )
4025 .await?;
4026
4027 let real_result_err = test_setup
4028 .stream_actor
4029 .get_tensor_ref_unit_tests_only(&test_setup.client, real_result_ref)
4030 .await?
4031 .unwrap()
4032 .unwrap_err();
4033 let error_str = real_result_err.to_string();
4035 assert!(
4036 error_str.contains("send error"),
4037 "Error should contain 'send error': {}",
4038 error_str
4039 );
4040
4041 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
4042 match controller_msg {
4043 ControllerMessage::RemoteFunctionFailed { seq, error } => {
4044 assert_eq!(seq, 1.into());
4045 assert!(
4046 error.backtrace.contains("send error"),
4047 "Unexpected WorkerError: {:?}",
4048 error
4049 );
4050 }
4051 _ => panic!("Unexpected controller message: {:?}", controller_msg),
4052 };
4053
4054 Ok(())
4055 }
4056}