1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::Mailbox;
36use hyperactor::mailbox::OncePortHandle;
37use hyperactor::mailbox::PortReceiver;
38use hyperactor::proc::Proc;
39use monarch_hyperactor::actor::PythonMessage;
40use monarch_hyperactor::actor::PythonMessageKind;
41use monarch_hyperactor::local_state_broker::BrokerId;
42use monarch_hyperactor::local_state_broker::LocalState;
43use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::CallFunctionError;
50use monarch_messages::worker::CallFunctionParams;
51use monarch_messages::worker::SeqError;
52use monarch_messages::worker::StreamRef;
53use monarch_types::PyTree;
54use monarch_types::SerializablePyErr;
55use monarch_types::TryIntoPyObjectUnsafe;
56use pyo3::prelude::*;
57use pyo3::types::PyTuple;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys::BorrowType;
62use torch_sys::CudaDevice;
63use torch_sys::MultiBorrow;
64use torch_sys::RValue;
65use torch_sys::TensorCell;
66use torch_sys::deep_clone;
67use torch_sys::factory_empty;
68use torch_sys::factory_zeros;
69use torch_sys_cuda::cuda::Event;
70use torch_sys_cuda::cuda::Stream;
71use tracing_subscriber::fmt::Subscriber;
72
73use crate::ControllerActor;
74use crate::DeviceMesh;
75use crate::Factory;
76use crate::Reduction;
77use crate::Ref;
78use crate::ResolvableFunction;
79use crate::StreamCreationMode;
80use crate::WireValue;
81use crate::comm::CommBackend;
82use crate::comm::CommMessage;
83use crate::comm::CommMessageClient;
84use crate::comm::NcclCommActor;
85use crate::pipe::PipeMessage;
86
87pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
88
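// Per-thread handles to the controller, owning proc, and root actor id, set in
// `StreamActor::init`. Each stream actor runs its loop on a dedicated thread
// (see `spawn_server_task`), so these are thread-local rather than global;
// presumably code invoked on that thread reads them back without extra plumbing.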
89thread_local! {
91 pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
92 pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
93 pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
94}
95
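/// Pickle a Python result with `monarch._src.actor.actor_mesh._pickle` and wrap
/// the resulting bytes in a [`PythonMessage`] tagged with this worker's rank.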
96fn pickle_python_result(
97 py: Python<'_>,
98 result: Bound<'_, PyAny>,
99 worker_actor_id: ActorId,
100) -> Result<PythonMessage, anyhow::Error> {
    let pickle = py
        .import("monarch._src.actor.actor_mesh")
        .and_then(|module| module.getattr("_pickle"))
        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
    let data: Vec<u8> = pickle
        .call1((result,))
        .and_then(|buf| buf.extract::<Vec<u8>>())
        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
111 Ok(PythonMessage::new_from_buf(
112 PythonMessageKind::Result {
113 rank: Some(worker_actor_id.rank()),
114 },
115 data,
116 ))
117}
118
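/// A sequence of stream messages captured between `DefineRecording` and
/// `FinalizeRecording`, to be replayed later by `CallRecording`.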
119#[derive(Debug)]
120struct Recording {
121 messages: Vec<StreamMessage>,
122}
123
124impl Recording {
125 fn new() -> Self {
126 Self {
127 messages: Vec::new(),
128 }
129 }
130}
131
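/// Whether the stream is currently capturing messages into a recording
/// (`Defining`) or replaying one (`Running`).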
132#[derive(Debug, PartialEq)]
133enum RecordingState {
134 Defining {
135 recording: Ref,
136 defined_borrows: HashSet<u64>,
139 },
140 Running,
141}
142
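/// Messages handled by [`StreamActor`]. Fields marked `#[reply]` carry a
/// once-port on which the response is sent back to the caller.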
143#[derive(Handler, HandleClient, Debug, Named)]
146pub enum StreamMessage {
147 CallFunction(
148 CallFunctionParams,
149 HashMap<Ref, DeviceMesh>,
150 HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
151 ),
152
153 BorrowCreate {
154 borrow: u64,
156 tensor: Ref,
158 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
161 },
162
163 BorrowFirstUse {
164 borrow: u64,
166 result: Ref,
168 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
171 },
172
173 BorrowLastUse {
174 borrow: u64,
176 result: Ref,
178 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
180 },
181
182 BorrowDrop {
183 borrow: u64,
184 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
186 },
187
188 DeleteRefs(Vec<Ref>),
189
190 RequestStatus(#[reply] OncePortHandle<()>),
191
192 InitComm(ActorHandle<NcclCommActor>),
193
194 Reduce {
195 comm: Arc<ActorHandle<NcclCommActor>>,
196 dim_size: i64,
197 result: Ref,
198 local_tensor: Ref,
199 factory: Factory,
200 reduction: Reduction,
201 scatter: bool,
202 in_place: bool,
203 out: Option<Ref>,
204 },
205
206 SendTensor {
207 result: Ref,
208 from_rank: Option<usize>,
209 to_rank: Option<usize>,
210 tensor: Ref,
211 factory: Factory,
212 comm: Arc<ActorHandle<NcclCommActor>>,
213 },
214
215 SendValue {
216 seq: Seq,
217 worker_actor_id: ActorId,
218 mutates: Vec<Ref>,
219 function: Option<ResolvableFunction>,
220 args: Vec<WireValue>,
221 kwargs: HashMap<String, WireValue>,
222 device_meshes: HashMap<Ref, DeviceMesh>,
223 pipe: Option<PortHandle<PipeMessage>>,
224 },
225
226 SetValue {
227 seq: Seq,
228 results: Vec<Option<Ref>>,
229 pipe: PortHandle<PipeMessage>,
230 },
231
232 DefineRecording {
233 recording: Ref,
234 },
235
236 FinalizeRecording {
237 recording: Ref,
238 },
239
240 CallRecording {
241 seq: Seq,
242 recording: Ref,
243 results: Vec<Ref>,
244 actuals: Vec<Ref>,
245 },
246
247 RecordingFormal {
248 result: Ref,
249 argument_index: usize,
250 },
251
252 RecordingResult {
253 result: Ref,
254 output_index: usize,
255 },
256
257 SetRefUnitTestsOnly(Ref, WireValue),
258
259 SetTensorRefUnitTestsOnly(Ref, TensorCellResult),
260
    GetRefUnitTestsOnly(
        Ref,
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),
265
266 GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
267
268 SendResultOfActorCall(ActorId, ActorCallParams),
269 CallActorMethod(ActorMethodParams),
270}
271
272impl StreamMessage {
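    /// Clone the variants that may appear inside a recording; panics on any
    /// variant that cannot be recorded (e.g. `RequestStatus` or `InitComm`).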
273 fn clone_for_recording(&self) -> Self {
274 match self {
275 StreamMessage::RecordingFormal {
276 result,
277 argument_index,
278 } => StreamMessage::RecordingFormal {
279 result: *result,
280 argument_index: *argument_index,
281 },
282 StreamMessage::RecordingResult {
283 result,
284 output_index,
285 } => StreamMessage::RecordingResult {
286 result: *result,
287 output_index: *output_index,
288 },
289 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
290 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
291 StreamMessage::CallFunction(
292 params.clone(),
293 device_meshes.clone(),
294 remote_process_groups.clone(),
295 )
296 }
297 StreamMessage::BorrowCreate {
298 borrow,
299 tensor,
300 first_use_sender,
301 } => StreamMessage::BorrowCreate {
302 borrow: *borrow,
303 tensor: *tensor,
304 first_use_sender: first_use_sender.clone(),
305 },
306 StreamMessage::BorrowFirstUse {
307 borrow,
308 result,
309 first_use_receiver,
310 } => StreamMessage::BorrowFirstUse {
311 borrow: *borrow,
312 result: *result,
313 first_use_receiver: first_use_receiver.clone(),
314 },
315 StreamMessage::BorrowLastUse {
316 borrow,
317 result,
318 last_use_sender,
319 } => StreamMessage::BorrowLastUse {
320 borrow: *borrow,
321 result: *result,
322 last_use_sender: last_use_sender.clone(),
323 },
324 StreamMessage::BorrowDrop {
325 borrow,
326 last_use_receiver,
327 } => StreamMessage::BorrowDrop {
328 borrow: *borrow,
329 last_use_receiver: last_use_receiver.clone(),
330 },
331 StreamMessage::Reduce {
332 comm,
333 dim_size,
334 result,
335 local_tensor,
336 factory,
337 reduction,
338 scatter,
339 in_place,
340 out,
341 } => StreamMessage::Reduce {
342 comm: comm.clone(),
343 dim_size: *dim_size,
344 result: *result,
345 local_tensor: *local_tensor,
346 factory: factory.clone(),
347 reduction: reduction.clone(),
348 scatter: *scatter,
349 in_place: *in_place,
350 out: out.clone(),
351 },
352 StreamMessage::SendTensor {
353 result,
354 from_rank,
355 to_rank,
356 tensor,
357 factory,
358 comm,
359 } => StreamMessage::SendTensor {
360 result: *result,
361 from_rank: *from_rank,
362 to_rank: *to_rank,
363 tensor: *tensor,
364 factory: factory.clone(),
365 comm: comm.clone(),
366 },
367 StreamMessage::SetValue { seq, results, pipe } => StreamMessage::SetValue {
368 seq: seq.clone(),
369 results: results.clone(),
370 pipe: pipe.clone(),
371 },
372 other => panic!(
373 "StreamMessage variant not supported in recording: {:?}",
374 other
375 ),
376 }
377 }
378
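    /// Refs that this message defines (writes into the stream's env) when it
    /// executes.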
379 fn get_defined_refs(&self) -> HashSet<Ref> {
381 match self {
382 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
383 StreamMessage::CallFunction(params, ..) => {
384 params.results.iter().filter_map(|&ref_| ref_).collect()
385 }
386 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
387 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
388 StreamMessage::SendTensor {
389 result, from_rank, ..
390 } => {
391 if from_rank.is_some() {
392 HashSet::from([*result])
393 } else {
394 HashSet::new()
395 }
396 }
397 StreamMessage::SetValue { results, .. } => {
398 results.iter().filter_map(|&ref_| ref_).collect()
399 }
400 _ => HashSet::new(),
402 }
403 }
404
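    /// Refs whose existing values this message mutates in place.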
405 fn get_mutated_refs(&self) -> HashSet<Ref> {
407 match self {
408 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
409 StreamMessage::Reduce {
410 out,
411 in_place,
412 local_tensor,
413 ..
414 } => {
415 if *in_place {
416 HashSet::from([*local_tensor])
417 } else if let Some(out) = out {
418 HashSet::from([*out])
419 } else {
420 HashSet::new()
421 }
422 }
423 _ => HashSet::new(),
425 }
426 }
427}
428
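/// A single stream of execution for tensor operations on one worker rank.
/// `env` maps each [`Ref`] to either its computed value or the error that
/// poisoned it, so later consumers can surface dependent failures.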
429#[derive(Debug)]
438pub struct StreamActor {
439 world_size: usize,
440 rank: usize,
441 env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
445 creation_mode: StreamCreationMode,
447 cuda_stream: OnceLock<Option<Stream>>,
453 device: Option<CudaDevice>,
455 comm: Option<ActorHandle<NcclCommActor>>,
457 controller_actor: ActorRef<ControllerActor>,
459 remote_process_groups: HashMap<Ref, PyObject>,
460 recordings: HashMap<Ref, Recording>,
461 active_recording: Option<RecordingState>,
462 respond_with_python_message: bool,
463 last_seq_error: Option<Arc<SeqError>>,
464}
465
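/// Parameters used to spawn a [`StreamActor`].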
466#[derive(Debug, Clone)]
468pub struct StreamParams {
469 pub world_size: usize,
470 pub rank: usize,
471 pub creation_mode: StreamCreationMode,
473 pub id: StreamRef,
475 pub device: Option<CudaDevice>,
478 pub controller_actor: ActorRef<ControllerActor>,
480 pub respond_with_python_message: bool,
481}
482
483#[async_trait]
484impl Actor for StreamActor {
485 type Params = StreamParams;
486 async fn new(
487 StreamParams {
488 world_size,
489 rank,
490 id: _,
491 device,
492 controller_actor,
493 creation_mode,
494 respond_with_python_message,
495 }: Self::Params,
496 ) -> Result<Self> {
497 Ok(Self {
498 world_size,
499 rank,
500 env: HashMap::new(),
501 creation_mode,
502 cuda_stream: OnceLock::new(),
503 device,
504 comm: None,
505 controller_actor,
506 remote_process_groups: HashMap::new(),
507 recordings: HashMap::new(),
508 active_recording: None,
509 respond_with_python_message,
510 last_seq_error: None,
511 })
512 }
513
514 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
515 CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
519 controller_actor_ref.set(self.controller_actor.clone()).ok()
520 });
521 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
522 ROOT_ACTOR_ID.with(|root_actor_id| {
523 root_actor_id
524 .set(ActorId::root(
525 cx.self_id().proc_id().clone(),
526 cx.self_id().name().to_string(),
527 ))
528 .ok()
529 });
530 if let Some(stream) = self.cuda_stream() {
532 Stream::set_current_stream(stream);
533 }
534 Ok(())
535 }
536
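    /// Run this actor's message loop on a dedicated OS thread with its own
    /// single-worker Tokio runtime, blocking on the future while the GIL is
    /// released, presumably so long-running Python/torch work cannot stall the
    /// shared runtime. The returned handle resolves when the loop finishes.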
537 fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
543 where
544 F: Future + Send + 'static,
545 F::Output: Send + 'static,
546 {
547 let (join_tx, join_rx) = tokio::sync::oneshot::channel();
548 let builder = std::thread::Builder::new().name("worker-stream".to_string());
556 let _thread_handle = builder.spawn(move || {
557 let rt = tokio::runtime::Builder::new_multi_thread()
561 .worker_threads(1)
562 .enable_io()
563 .build()
564 .unwrap();
565 let result = rt.block_on(async {
566 tokio::task::block_in_place(|| {
567 Python::with_gil(|py| {
572 py.allow_threads(|| {
573 let result = Handle::current().block_on(future);
574 if join_tx.send(result).is_err() {
575 panic!("could not send join result")
576 }
577 })
578 })
579 })
580 });
            // Avoid the nightly-only `Duration::from_weeks`: shut down after at
            // most one week, expressed in seconds.
            rt.shutdown_timeout(Duration::from_secs(60 * 60 * 24 * 7));
582 result
583 });
584
585 tokio::spawn(async move { join_rx.await.unwrap() })
588 }
589}
590
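/// An argument destined for Python: a torch `RValue`, a borrowed
/// [`DeviceMesh`], or an already-materialized Python object.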
591#[derive(Debug)]
593enum PyArg<'a> {
594 RValue(RValue),
595 DeviceMesh(&'a DeviceMesh),
596 PyObject(PyObject),
597}
598
599impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
601 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
602 match self {
603 PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
606 PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
607 PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
608 }
609 }
610}
611
612impl StreamActor {
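    /// Lazily create or adopt the CUDA stream for this actor, depending on
    /// `creation_mode`; `None` when no device was configured.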
613 fn cuda_stream(&self) -> Option<&Stream> {
614 self.cuda_stream
615 .get_or_init(|| {
616 self.device.map(|device| match self.creation_mode {
617 StreamCreationMode::UseDefaultStream => {
618 Stream::get_current_stream_on_device(device)
619 }
620 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
621 })
622 })
623 .as_ref()
624 }
625
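    /// Resolve a [`Ref`] from the env, surfacing a stored error as a
    /// `DependentError`.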
626 fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
627 let rvalue = self
628 .env
629 .get(ref_)
630 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
631 match rvalue {
632 Ok(val) => Ok(val.clone()),
633 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
634 }
635 }
636
637 fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
638 let ret = match value {
639 WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
640 WireValue::RefList(val) => {
642 let mut ret = Vec::with_capacity(val.len());
643 for v in val {
644 match self.ref_to_rvalue(&v) {
645 Ok(RValue::Tensor(t)) => ret.push(t),
646 Err(err) => {
647 return Err(err);
648 }
649 Ok(val) => {
650 return Err(CallFunctionError::UnsupportedArgType(
651 "wire_to_rvalue".into(),
652 format!("RefList([{:?}])", val),
653 ));
654 }
655 }
656 }
657 RValue::TensorList(ret)
658 }
659 WireValue::Int(val) => RValue::Int(val),
660 WireValue::IntList(val) => RValue::IntList(val),
661 WireValue::Double(val) => RValue::Double(val),
662 WireValue::Bool(val) => RValue::Bool(val),
663 WireValue::String(val) => RValue::String(val),
664 WireValue::Device(val) => RValue::Device(val),
665 WireValue::Layout(val) => RValue::Layout(val),
666 WireValue::ScalarType(val) => RValue::ScalarType(val),
667 WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
668 WireValue::PyObject(val) => RValue::PyObject(val),
669 WireValue::None(()) => RValue::None,
670 WireValue::IValue(val) => RValue::Opaque(val.into()),
671 };
672 Ok(ret)
673 }
674
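    /// Convert a [`CallFunctionError`] into a [`SeqError`]. Dependent errors are
    /// returned as their root cause without re-reporting; fresh errors are
    /// reported to the controller (unless inside a recording) and remembered in
    /// `last_seq_error`.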
675 async fn report_seq_error(
676 &mut self,
677 cx: &Context<'_, Self>,
678 seq: Seq,
679 error: CallFunctionError,
680 ) -> Result<Arc<SeqError>, anyhow::Error> {
681 match error {
682 CallFunctionError::DependentError(root) => Ok(root),
683 CallFunctionError::Error(e) => {
684 if self.active_recording.is_none() {
685 let worker_error = WorkerError {
686 backtrace: format!("{e}"),
687 worker_actor_id: cx.self_id().clone(),
688 };
689 tracing::info!("Propagating remote function error to client: {worker_error}");
690 self.controller_actor
691 .remote_function_failed(cx, seq, worker_error)
692 .await?
693 }
694 let err = Arc::new(SeqError { seq, error: e });
695 self.last_seq_error = Some(err.clone());
696 Ok(err)
697 }
698 }
699 }
700
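    /// Run `f` and bind its results to `result_refs`. On failure, the error is
    /// reported once and then stored under every result and mutated ref so
    /// later uses fail with a dependent error.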
701 async fn try_define<F>(
702 &mut self,
703 cx: &Context<'_, Self>,
704 seq: Seq,
705 result_refs: Vec<Option<Ref>>,
706 mutates: &Vec<Ref>,
707 f: F,
708 ) -> Result<()>
709 where
710 F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
711 {
712 let actual_results = f(self).await;
713 let op_results = actual_results.and_then(|actual_results| {
716 if result_refs.len() == actual_results.len() {
717 Ok(actual_results
718 .into_iter()
719 .zip(result_refs.iter())
720 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
721 .collect::<Vec<(Ref, RValue)>>())
722 } else {
723 Err(CallFunctionError::UnexpectedNumberOfReturns(
724 result_refs.len(),
725 actual_results.len(),
726 ))
727 }
728 });
729
730 match op_results {
733 Ok(op_results) => {
734 for (ref_, rvalue) in op_results.into_iter() {
735 let prev = self.env.insert(ref_, Ok(rvalue));
736 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
737 }
738 }
739 Err(err) => {
740 let err = self.report_seq_error(cx, seq, err).await?;
741 for ref_ in result_refs {
742 match ref_ {
743 Some(ref_) => {
744 let prev = self.env.insert(ref_, Err(err.clone()));
745 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
746 }
747 None => {}
748 }
749 }
750 for ref_ in mutates {
751 self.env.insert(*ref_, Err(err.clone()));
752 }
753 }
754 }
755 Ok(())
756 }
757
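    /// Call a torch operator by name/overload after converting wire values to
    /// rvalues; an empty result list is normalized to `[RValue::None]`.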
758 fn call_torch_op(
759 &self,
760 op: String,
761 overload: String,
762 args: Vec<WireValue>,
763 kwargs: HashMap<String, WireValue>,
764 ) -> Result<Vec<RValue>, CallFunctionError> {
765 let args = args
766 .into_iter()
767 .map(|arg| self.wire_to_rvalue(arg))
768 .collect::<Result<Vec<_>, _>>()?;
769 let kwargs = kwargs
770 .into_iter()
771 .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
772 .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
773
774 let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
775
776 Ok(if results.is_empty() {
780 vec![RValue::None]
781 } else {
782 results
783 })
784 }
785
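    /// Resolve and invoke a Python function: wire args are unpickled and refs
    /// materialized (rvalues, device meshes, remote process groups), tensor
    /// arguments are guarded by a [`MultiBorrow`] (mutably for `mutates`), and
    /// the call runs under a scoped stdout tracing subscriber.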
786 fn call_python_fn<'py>(
787 &mut self,
788 py: Python<'py>,
789 cx: &Context<Self>,
790 function: Option<ResolvableFunction>,
791 args: Vec<WireValue>,
792 kwargs: HashMap<String, WireValue>,
793 mutates: &[Ref],
794 device_meshes: HashMap<Ref, DeviceMesh>,
795 remote_process_groups: HashMap<
796 Ref,
797 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
798 >,
799 ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
800 let function = function
801 .map(|function| {
802 function.resolve(py).map_err(|e| {
803 CallFunctionError::InvalidRemoteFunction(format!(
804 "failed to resolve function {}: {}",
805 function,
806 SerializablePyErr::from(py, &e)
807 ))
808 })
809 })
810 .transpose()?;
811
812 let remote_process_groups = remote_process_groups
813 .into_iter()
814 .map(|(gref, (mesh, dims, comm))| {
815 let group = match self.remote_process_groups.entry(gref) {
816 Entry::Occupied(ent) => ent.get().clone_ref(py),
817 Entry::Vacant(ent) => {
818 torch_sys::backend::ensure_init_process_group(
821 py,
822 self.world_size,
823 self.rank,
824 )?;
825
826 let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
829 let group_size = ranks.len();
830 let backend = CommBackend::new(
831 comm,
832 Mailbox::new_detached(cx.self_id().clone()),
833 self.rank,
834 group_size,
835 self.world_size,
836 );
837 ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
838 .clone_ref(py)
839 }
840 };
841 PyResult::Ok((gref, group))
842 })
843 .collect::<Result<HashMap<_, _>, _>>()
844 .map_err(SerializablePyErr::from_fn(py))?;
845
846 let mut multiborrow = MultiBorrow::new();
850
851 let resolve = |val: WireValue| {
852 val.into_py_object()
853 .map_err(|e| {
854 CallFunctionError::UnsupportedArgType(
855 format!("{:?}", function),
856 format!("{:?}", e),
857 )
858 })?
859 .unpickle(py)
860 .map_err(SerializablePyErr::from_fn(py))?
861 .extract::<PyTree<PyObject>>()
862 .map_err(SerializablePyErr::from_fn(py))?
863 .try_into_map(|obj| {
864 Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
865 if let Some(mesh) = device_meshes.get(&ref_) {
866 PyArg::DeviceMesh(mesh)
867 } else if let Some(pg) = remote_process_groups.get(&ref_) {
868 PyArg::PyObject(pg.clone_ref(py))
869 } else {
870 let rval = self.ref_to_rvalue(&ref_)?;
871 PyArg::RValue(rval)
872 }
873 } else {
874 PyArg::PyObject(obj)
875 })
876 })
877 };
878
879 let py_args: Vec<PyTree<PyArg>> = args
881 .into_iter()
882 .map(resolve)
883 .collect::<Result<_, CallFunctionError>>()?;
884 let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
885 .into_iter()
886 .map(|(k, object)| Ok((k, resolve(object)?)))
887 .collect::<Result<_, CallFunctionError>>()?;
888
889 py_args
891 .iter()
892 .chain(py_kwargs.values())
893 .flat_map(|o| o.iter())
894 .for_each(|arg| {
895 if let PyArg::RValue(rval) = arg {
896 multiborrow.add(rval, BorrowType::Shared);
897 }
898 });
899
900 let mutates: Vec<_> = mutates
902 .iter()
903 .map(|r| self.ref_to_rvalue(r))
904 .collect::<Result<_, CallFunctionError>>()?;
905 mutates
906 .iter()
907 .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
908
909 let _borrow = multiborrow.borrow()?;
911
912 let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
915 let result: Bound<'_, PyAny> =
916 tracing::subscriber::with_default(scoped_subscriber, || {
917 let args = unsafe { py_args.try_to_object_unsafe(py) }
926 .map_err(SerializablePyErr::from_fn(py))?;
927 let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
929 .map_err(SerializablePyErr::from_fn(py))?;
930
931 if let Some(function) = function {
932 function
933 .call(args, Some(kwargs))
934 .map_err(SerializablePyErr::from_fn(py))
935 } else {
936 Ok(args.get_item(0).unwrap())
937 }
938 })?;
939 Ok(result)
940 }
941
942 fn call_python_fn_pytree(
943 &mut self,
944 cx: &hyperactor::Context<Self>,
945 function: ResolvableFunction,
946 args: Vec<WireValue>,
947 kwargs: HashMap<String, WireValue>,
948 mutates: &[Ref],
949 device_meshes: HashMap<Ref, DeviceMesh>,
950 remote_process_groups: HashMap<
951 Ref,
952 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
953 >,
954 ) -> Result<PyTree<RValue>, CallFunctionError> {
955 Python::with_gil(|py| {
956 let result = self.call_python_fn(
957 py,
958 cx,
959 Some(function),
960 args,
961 kwargs,
962 mutates,
963 device_meshes,
964 remote_process_groups,
965 )?;
966 Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
967 })
968 }
969 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
977 let rvalue = self
978 .env
979 .get(&ref_)
980 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
981
982 match rvalue {
983 Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
984 Err(_) => {
985 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
986 Ok(TensorCell::new(t))
987 }
988 }
989 }
990
991 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
992 self.active_recording
993 .as_mut()
994 .and_then(|state| match state {
995 RecordingState::Defining {
996 recording,
997 defined_borrows,
998 } => {
999 match self.recordings.get_mut(recording) {
1000 Some(recording) => Some((recording, defined_borrows)),
1001 None => panic!("recording not found: {:?}", recording),
1003 }
1004 }
1005 RecordingState::Running => None,
1006 })
1007 }
1008
1009 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
1010 for ref_ in refs {
1011 let rvalue_or_err = self
1012 .env
1013 .get(ref_)
1014 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1015 if let Err(err) = rvalue_or_err {
1016 return Ok(Some(err.clone()));
1017 }
1018 }
1019 Ok(None)
1020 }
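    /// Evaluate `function` (if any) over the resolved args and send the pickled
    /// Python result to the controller via `fetch_result`.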
1021 async fn send_value_python_message(
1022 &mut self,
1023 cx: &hyperactor::Context<'_, Self>,
1024 seq: Seq,
1025 worker_actor_id: ActorId,
1026 mutates: Vec<Ref>,
1027 function: Option<ResolvableFunction>,
1028 args: Vec<WireValue>,
1029 kwargs: HashMap<String, WireValue>,
1030 device_meshes: HashMap<Ref, DeviceMesh>,
1031 ) -> Result<()> {
1032 self.try_define(cx, seq, vec![], &vec![], async |self_| {
1033 let python_message =
1034 Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1035 let python_result = tokio::task::block_in_place(|| {
1036 self_.call_python_fn(
1037 py,
1038 cx,
1039 function,
1040 args,
1041 kwargs,
1042 &mutates,
1043 device_meshes,
1044 HashMap::new(),
1045 )
1046 })?;
1047 pickle_python_result(py, python_result, worker_actor_id)
1048 .map_err(CallFunctionError::Error)
1049 })?;
1050 let ser = Serialized::serialize(&python_message).unwrap();
1051 self_
1052 .controller_actor
1053 .fetch_result(cx, seq, Ok(ser))
1054 .await?;
1055 Ok(vec![])
1056 })
1057 .await
1058 }
1059 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1060 let rvalue = self
1061 .env
1062 .get(&src)
1063 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1064 self.env.insert(dest, rvalue.clone());
1065 Ok(())
1066 }
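    /// Ship the call's local state to the broker identified by
    /// `params.broker_id` and await the pickled response on a once-port.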
1067 async fn call_actor(
1068 &mut self,
1069 cx: &Context<'_, Self>,
1070 params: ActorCallParams,
1071 ) -> Result<PyObject, CallFunctionError> {
1072 let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1073 params
1074 .local_state
1075 .into_iter()
1076 .map(|elem| {
1077 unsafe {
1079 let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1080 Ok(x)
1081 }
1082 })
1083 .collect()
1084 });
1085
1086 let (send, recv) = cx.open_once_port();
1087 let state = LocalState {
1088 response_port: send,
1089 state: local_state?,
1090 };
1091 let x: u64 = params.seq.into();
1092 let message = LocalStateBrokerMessage::Set(x as usize, state);
1093
1094 let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1095 broker
1096 .send(message)
1097 .map_err(|e| CallFunctionError::Error(e.into()))?;
1098 let result = recv
1099 .recv()
1100 .await
1101 .map_err(|e| CallFunctionError::Error(e.into()))?;
1102
1103 result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1104 }
1105}
1106
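/// Handlers for [`StreamMessage`]. While a recording is being defined, most
/// handlers capture the message into the active recording instead of executing
/// it immediately.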
1107#[async_trait]
1108#[forward(StreamMessage)]
1109impl StreamMessageHandler for StreamActor {
1110 async fn call_function(
1111 &mut self,
1112 cx: &Context<Self>,
1113 params: CallFunctionParams,
1114 device_meshes: HashMap<Ref, DeviceMesh>,
1115 remote_process_groups: HashMap<
1116 Ref,
1117 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1118 >,
1119 ) -> Result<()> {
1120 if let Some((recording, _)) = self.get_defining_recording() {
1121 recording.messages.push(StreamMessage::CallFunction(
1122 params,
1123 device_meshes,
1124 remote_process_groups,
1125 ));
1126 return Ok(());
1127 }
1128
1129 params.function.panic_if_requested();
1130 self.try_define(
1131 cx,
1132 params.seq,
1133 params.results,
1134 ¶ms.mutates,
1135 async |self| {
1136 tokio::task::block_in_place(|| match params.function.as_torch_op() {
1137 Some((op, overload)) => {
1138 self.call_torch_op(op, overload, params.args, params.kwargs)
1139 }
1140 _ => self
1141 .call_python_fn_pytree(
1142 cx,
1143 params.function,
1144 params.args,
1145 params.kwargs,
1146 ¶ms.mutates,
1147 device_meshes,
1148 remote_process_groups,
1149 )
1150 .map(|results| results.into_leaves()),
1151 })
1152 },
1153 )
1154 .await?;
1155 Ok(())
1156 }
1157
1158 async fn borrow_create(
1159 &mut self,
1160 _cx: &Context<Self>,
1161 borrow: u64,
1162 tensor: Ref,
1163 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1164 ) -> Result<()> {
1165 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1166 recording.messages.push(StreamMessage::BorrowCreate {
1167 borrow,
1168 tensor,
1169 first_use_sender,
1170 });
1171 ensure!(
1172 defined_borrows.insert(borrow),
1173 "duplicate borrow create in recording"
1174 );
1175 return Ok(());
1176 }
1177
1178 let rvalue_result = self
1179 .env
1180 .get(&tensor)
1181 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1182
1183 let result = match rvalue_result {
1184 Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1185 Err(e) => Err(e.clone()),
1186 };
1187
1188 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1189 first_use_sender.send((event, result)).map_err(|err| {
1190 anyhow!(
1191 "failed sending first use event for borrow {:?}: {:?}",
1192 borrow,
1193 err
1194 )
1195 })
1196 }
1197
1198 async fn borrow_first_use(
1199 &mut self,
1200 _cx: &Context<Self>,
1201 borrow: u64,
1202 result: Ref,
1203 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1204 ) -> Result<()> {
1205 if let Some((recording, _)) = self.get_defining_recording() {
1206 recording.messages.push(StreamMessage::BorrowFirstUse {
1207 borrow,
1208 result,
1209 first_use_receiver: first_use_receiver.clone(),
1210 });
1211 return Ok(());
1212 }
1213
1214 let (first_use_event, cell) =
1215 first_use_receiver
1216 .lock()
1217 .await
1218 .recv()
1219 .await
1220 .map_err(|err| {
1221 anyhow!(
1222 "failed receiving first use event for borrow {:?}: {:?}",
1223 borrow,
1224 err
1225 )
1226 })?;
1227
1228 if let Some(stream) = self.cuda_stream() {
1229 stream.wait_event(
1230 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1231 );
1232 }
1233 match cell {
1234 Ok(cell) => {
1235 self.env.insert(result, Ok(cell.into()));
1236 }
1237 Err(err) => {
1238 self.env.insert(result, Err(err.clone()));
1239 }
1240 }
1241 Ok(())
1242 }
1243
1244 async fn borrow_last_use(
1245 &mut self,
1246 _cx: &Context<Self>,
1247 borrow: u64,
1248 result: Ref,
1249 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1250 ) -> Result<()> {
1251 if let Some((recording, _)) = self.get_defining_recording() {
1252 recording.messages.push(StreamMessage::BorrowLastUse {
1253 borrow,
1254 result,
1255 last_use_sender,
1256 });
1257 return Ok(());
1258 }
1259
1260 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1261 let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1262 "Invalid reference for borrow_last_use: {result:#?}"
1263 ))?;
1264 let tensor = match rvalue_or_err {
1265 Ok(RValue::Tensor(t)) => Ok(t),
1266 Err(e) => Err(e),
1267 _ => bail!("invalid rvalue type for borrow_last_use"),
1268 };
1269
1270 last_use_sender.send((event, tensor)).map_err(|err| {
1271 anyhow!(
1272 "failed sending last use event for borrow {:?}: {:?}",
1273 borrow,
1274 err
1275 )
1276 })
1277 }
1278
1279 async fn borrow_drop(
1280 &mut self,
1281 _cx: &Context<Self>,
1282 borrow: u64,
1283 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1284 ) -> Result<()> {
1285 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1286 recording.messages.push(StreamMessage::BorrowDrop {
1287 borrow,
1288 last_use_receiver: last_use_receiver.clone(),
1289 });
1290 ensure!(
1291 defined_borrows.remove(&borrow),
1292 "borrow drop for borrow not defined in recording"
1293 );
1294 return Ok(());
1295 }
1296
1297 let (last_use_event, _cell) =
1301 last_use_receiver.lock().await.recv().await.map_err(|err| {
1302 anyhow!(
1303 "failed receiving last use event for borrow {:?}: {:?}",
1304 borrow,
1305 err
1306 )
1307 })?;
1308
1309 if let Some(stream) = self.cuda_stream() {
1310 stream.wait_event(
1311 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1312 );
1313 }
1314 Ok(())
1316 }
1317
1318 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1319 if let Some((recording, _)) = self.get_defining_recording() {
1320 recording.messages.push(StreamMessage::DeleteRefs(refs));
1321 return Ok(());
1322 }
1323
1324 for ref_ in refs.iter() {
1325 self.env.remove(ref_);
1326 }
1327 Ok(())
1328 }
1329
1330 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1331 if self.get_defining_recording().is_some() {
1332 bail!("request_status not allowed in recording");
1333 }
1334
1335 Ok(())
1336 }
1337
1338 async fn init_comm(
1339 &mut self,
1340 _cx: &Context<Self>,
1341 comm: ActorHandle<NcclCommActor>,
1342 ) -> Result<()> {
1343 if self.get_defining_recording().is_some() {
1344 bail!("init_comm not allowed in recording");
1345 }
1346
1347 self.comm = Some(comm);
1348 Ok(())
1349 }
1350
1351 async fn reduce(
1352 &mut self,
1353 cx: &Context<Self>,
1354 comm: Arc<ActorHandle<NcclCommActor>>,
1355 dim_size: i64,
1356 result: Ref,
1357 local_tensor: Ref,
1358 factory: Factory,
1359 reduction: Reduction,
1360 scatter: bool,
1361 in_place: bool,
1362 out: Option<Ref>,
1363 ) -> Result<()> {
1364 if let Some((recording, _)) = self.get_defining_recording() {
1365 recording.messages.push(StreamMessage::Reduce {
1366 comm,
1367 dim_size,
1368 result,
1369 local_tensor,
1370 factory,
1371 reduction,
1372 scatter,
1373 in_place,
1374 out,
1375 });
1376 return Ok(());
1377 }
1378
1379 let stream = self
1380 .cuda_stream()
1381 .expect("reductions not yet supported for non-CUDA workers")
1382 .clone();
1383 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1384 let out_cell = out
1385 .map(|out| self.get_or_fake_on_err(out, &factory))
1386 .transpose()?;
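        // Dispatch to the collective implied by (reduction, scatter):
        // Stack + scatter -> all_to_all_single, Stack -> all_gather_into_tensor,
        // ReduceOp + scatter -> reduce_scatter_tensor, ReduceOp -> all_reduce.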
1387 let output_cell = match reduction {
1388 Reduction::Stack => {
1389 if scatter {
1390 let output_cell = if in_place {
1391 input_cell.clone()
1392 } else {
1393 out_cell.unwrap_or({
1394 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1395 let cloned = deep_clone(&borrow);
1396 TensorCell::new(cloned)
1397 })
1398 };
1399 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1400 .await?;
1401 output_cell
1402 } else {
1403 ensure!(
1404 !in_place,
1405 "in-place, non-scatter not supported for stack reduce"
1406 );
1407
1408 let output_cell = out_cell.unwrap_or({
1409 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1411 let output =
1412 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1413 TensorCell::new(output)
1414 });
1415
1416 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1417 .await?;
1418 output_cell
1419 }
1420 }
1421 Reduction::ReduceOp(op) => {
1422 if scatter {
1423 ensure!(!in_place, "in-place, scatter not supported for reduce");
1424
1425 let output_cell = out_cell.unwrap_or({
1426 let output = factory_empty(
1427 &factory.size[1..],
1428 factory.dtype,
1429 factory.layout,
1430 factory.device,
1431 );
1432 TensorCell::new(output)
1433 });
1434 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1435 .await?;
1436 output_cell
1437 } else {
1438 let output_cell = if in_place {
1439 input_cell.clone()
1440 } else {
1441 out_cell.map_or(
1442 {
1443 let borrow =
1444 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1445 let cloned = deep_clone(&borrow);
1446 Ok(TensorCell::new(cloned))
1447 },
1448 |out_cell| -> Result<_, anyhow::Error> {
1449 let mut out_borrow =
1450 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1451 let in_borrow =
1452 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1453 out_borrow.copy_(&in_borrow);
1454 drop(out_borrow);
1455 Ok(out_cell)
1456 },
1457 )?
1458 };
1459
1460 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1461 output_cell
1462 }
1463 }
1464 };
1465
1466 self.env.insert(result, Ok(output_cell.into()));
1467 Ok(())
1468 }
1469
1470 async fn send_tensor(
1471 &mut self,
1472 cx: &Context<Self>,
1473 result: Ref,
1474 from_rank: Option<usize>,
1475 to_rank: Option<usize>,
1476 tensor: Ref,
1477 factory: Factory,
1478 comm: Arc<ActorHandle<NcclCommActor>>,
1479 ) -> Result<()> {
1480 if let Some((recording, _)) = self.get_defining_recording() {
1481 recording.messages.push(StreamMessage::SendTensor {
1482 result,
1483 from_rank,
1484 to_rank,
1485 tensor,
1486 factory,
1487 comm,
1488 });
1489 return Ok(());
1490 }
1491
1492 if to_rank.is_none() && from_rank.is_none() {
1493 bail!("tried to send tensor without a to/from rank");
1494 }
1495
1496 if from_rank == to_rank {
1498 let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1499 .env
1500 .get(&tensor)
1501 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1502 let output_cell = match input_cell {
1503 Ok(RValue::Tensor(input_cell)) => {
1504 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1509 let cloned = deep_clone(&borrow);
1510 Ok(RValue::Tensor(TensorCell::new(cloned)))
1511 }
1512 Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1513 Err(err) => Err(err.clone()),
1514 };
1515 self.env.insert(result, output_cell);
1516 return Ok(());
1517 }
1518
1519 let mut messages = Vec::new();
1520
1521 if let Some(to_rank) = to_rank {
1522 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1523 messages.push(CommMessage::Send(
1524 input_cell,
1525 to_rank.try_into().unwrap(),
1526 self.cuda_stream()
1527 .expect("tried to send_tensor on non-cuda stream")
1528 .clone(),
1529 cx.open_once_port().0,
1530 ));
1531 }
1532
1533 if let Some(from_rank) = from_rank {
1534 let output_cell = TensorCell::new(factory_empty(
1535 &factory.size,
1536 factory.dtype,
1537 factory.layout,
1538 factory.device,
1539 ));
1540 messages.push(CommMessage::Recv(
1541 output_cell.clone(),
1542 from_rank.try_into().unwrap(),
1543 self.cuda_stream()
1544 .expect("tried to send_tensor on non-cuda stream")
1545 .clone(),
1546 cx.open_once_port().0,
1547 ));
1548 self.env.insert(result, Ok(output_cell.into()));
1549 }
1550
1551 comm.group(
1552 cx,
1553 messages,
1554 self.cuda_stream()
1555 .expect("tried to send_tensor on non-cuda stream")
1556 .clone(),
1557 )
1558 .await?;
1559 Ok(())
1560 }
1561
1562 async fn send_value(
1563 &mut self,
1564 cx: &Context<Self>,
1565 seq: Seq,
1566 worker_actor_id: ActorId,
1567 mutates: Vec<Ref>,
1568 function: Option<ResolvableFunction>,
1569 args: Vec<WireValue>,
1570 kwargs: HashMap<String, WireValue>,
1571 device_meshes: HashMap<Ref, DeviceMesh>,
1572 pipe: Option<PortHandle<PipeMessage>>,
1573 ) -> Result<()> {
1574 if self.respond_with_python_message && pipe.is_none() {
1575 return self
1576 .send_value_python_message(
1577 cx,
1578 seq,
1579 worker_actor_id,
1580 mutates,
1581 function,
1582 args,
1583 kwargs,
1584 device_meshes,
1585 )
1586 .await;
1587 }
1588 let result = if let Some(function) = function {
1589 match function.as_torch_op() {
1591 Some((op, overload)) => {
1592 self.call_torch_op(op, overload, args, kwargs)
1593 .map(|rvalues| {
1594 if rvalues.len() == 1 {
1595 Ok(rvalues[0].clone().into())
1596 } else {
1597 Python::with_gil(|py| {
1599 Ok((|| {
1600 let py_rvalues = rvalues
1601 .into_iter()
1602 .map(|rvalue| unsafe {
1604 rvalue.try_to_object_unsafe(py)
1605 })
1606 .collect::<Result<Vec<_>, _>>()?;
1607 PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1608 })()
1609 .map_err(SerializablePyErr::from_fn(py))?)
1610 })
1611 }
1612 })?
1613 }
1614 _ => tokio::task::block_in_place(|| {
1617 self.call_python_fn_pytree(
1618 cx,
1619 function,
1620 args,
1621 kwargs,
1622 &mutates,
1623 device_meshes,
1624 HashMap::new(),
1625 )
1626 }),
1627 }
1628 } else {
1629 match (args.len(), kwargs.len()) {
1632 (1, 0) => Python::with_gil(|py| {
1633 let arg = args[0]
1634 .as_py_object()
1635 .ok_or_else(|| {
1636 CallFunctionError::UnsupportedArgType(
1637 "send_value".to_string(),
1638 "expected a PyObject as the first arg".to_string(),
1639 )
1640 })?
1641 .unpickle(py)
1642 .map_err(SerializablePyErr::from_fn(py))?;
1643 arg.extract::<PyTree<PyObject>>()
1644 .map_err(SerializablePyErr::from_fn(py))?
1645 .try_into_map(|obj| {
1646 let bound_obj = obj.bind(py);
1647 if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1648 self.ref_to_rvalue(&ref_)
1649 } else {
1650 Ok(bound_obj
1651 .extract::<RValue>()
1652 .map_err(SerializablePyErr::from_fn(py))?)
1653 }
1654 })
1655 }),
1656 _ => Err(CallFunctionError::TooManyArgsForValue(
1657 format!("{:?}", args),
1658 format!("{:?}", kwargs),
1659 )),
1660 }
1661 };
1662
1663 let value = match result {
1664 Ok(rvalue) => {
1665 Ok(rvalue.into_map(|rval| match rval {
1670 RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1671 RValue::TensorList(tensors) => RValue::TensorList(
1672 tensors
1673 .into_iter()
1674 .map(|tensor| tensor.try_cpu().unwrap())
1675 .collect(),
1676 ),
1677 rval => rval,
1678 }))
1679 }
1680 Err(err) => {
1681 let err = self.report_seq_error(cx, seq, err).await?;
1682 for ref_ in mutates {
1683 self.env.insert(ref_, Err(err.clone()));
1684 }
1685 Err(WorkerError {
1686 backtrace: format!("{:?}", err),
1687 worker_actor_id,
1688 })
1689 }
1690 };
1691
1692 if let Some(pipe) = pipe {
1694 pipe.send(PipeMessage::SendValue(value))?;
1695 } else {
1696 let result = match value {
1697 Ok(value) => Ok(Serialized::serialize_anon(&value).map_err(anyhow::Error::from)?),
1698 Err(e) => Err(e),
1699 };
1700 self.controller_actor.fetch_result(cx, seq, result).await?;
1701 }
1702
1703 Ok(())
1704 }
1705
1706 async fn send_result_of_actor_call(
1707 &mut self,
1708 cx: &Context<Self>,
1709 worker_actor_id: ActorId,
1710 params: ActorCallParams,
1711 ) -> anyhow::Result<()> {
1712 let seq = params.seq;
1713 let mutates = params.mutates.clone();
1714 self.try_define(cx, seq, vec![], &mutates, async |self| {
1715 let value = self.call_actor(cx, params).await?;
1716 let result = Python::with_gil(|py| {
1717 pickle_python_result(py, value.into_bound(py), worker_actor_id)
1718 })?;
1719 let result = Serialized::serialize(&result).unwrap();
1720 self.controller_actor
1721 .fetch_result(cx, seq, Ok(result))
1722 .await?;
1723 Ok(vec![])
1724 })
1725 .await
1726 }
1727
1728 async fn call_actor_method(
1729 &mut self,
1730 cx: &Context<Self>,
1731 params: ActorMethodParams,
1732 ) -> anyhow::Result<()> {
1733 let seq = params.call.seq;
1734 let mutates = params.call.mutates.clone();
1735 self.try_define(cx, seq, params.results, &mutates, async |self| {
1736 let result = self.call_actor(cx, params.call).await?;
1737 let result = Python::with_gil(|py| {
1738 PyTree::<RValue>::extract_bound(&result.into_bound(py))
1739 .map_err(SerializablePyErr::from_fn(py))
1740 })?;
1741 Ok(result.into_leaves())
1742 })
1743 .await
1744 }
1745
1746 async fn set_value(
1747 &mut self,
1748 cx: &Context<Self>,
1749 seq: Seq,
1750 results: Vec<Option<Ref>>,
1751 pipe: PortHandle<PipeMessage>,
1752 ) -> Result<()> {
1753 if let Some((recording, _)) = self.get_defining_recording() {
1754 recording
1755 .messages
1756 .push(StreamMessage::SetValue { seq, results, pipe });
1757 return Ok(());
1758 }
1759
1760 self.try_define(cx, seq, results, &vec![], async |self| {
1761 let (tx, rx) = cx.open_once_port();
1762 pipe.send(PipeMessage::RecvValue(tx))
1763 .map_err(anyhow::Error::from)
1764 .map_err(CallFunctionError::from)?;
1765 let value = rx.recv().await.map_err(anyhow::Error::from)?;
1766 Ok(value.into_leaves())
1767 })
1768 .await
1769 }
1770
1771 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1772 if self.active_recording.is_some() {
1773 bail!("different recording already active");
1774 }
1775 match self.recordings.entry(recording) {
1776 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1777 Entry::Vacant(entry) => entry.insert(Recording::new()),
1778 };
1779 self.active_recording = Some(RecordingState::Defining {
1780 recording,
1781 defined_borrows: HashSet::new(),
1782 });
1783 Ok(())
1784 }
1785
1786 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1787 match self.active_recording {
1788 Some(RecordingState::Defining {
1789 recording: active_recording,
1790 ref defined_borrows,
1791 }) if active_recording == recording => {
1792 ensure!(
1793 defined_borrows.is_empty(),
1794 "all borrows created within recording must be dropped within recording"
1795 );
1796 self.active_recording = None;
1797 }
1798 _ => bail!("cannot finalize recording that isn't active"),
1799 }
1800 Ok(())
1801 }
1802
1803 async fn recording_formal(
1804 &mut self,
1805 _cx: &Context<Self>,
1806 result: Ref,
1807 argument_index: usize,
1808 ) -> Result<()> {
1809 match self.get_defining_recording() {
1810 Some((recording, _)) => {
1811 recording.messages.push(StreamMessage::RecordingFormal {
1812 result,
1813 argument_index,
1814 });
1815 }
1816 None => bail!("recording_formal called outside of recording"),
1817 };
1818 Ok(())
1819 }
1820
1821 async fn recording_result(
1822 &mut self,
1823 _cx: &Context<Self>,
1824 result: Ref,
1825 output_index: usize,
1826 ) -> Result<()> {
1827 match self.get_defining_recording() {
1828 Some((recording, _)) => {
1829 recording.messages.push(StreamMessage::RecordingResult {
1830 result,
1831 output_index,
1832 });
1833 }
1834 None => bail!("recording_result called outside of recording"),
1835 };
1836 Ok(())
1837 }
1838
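    /// Replay a previously defined recording: formals are bound to `actuals`,
    /// recorded messages are re-handled one by one, the first error observed is
    /// reported once and then propagated to `results` and any refs the recording
    /// mutates, and refs defined during replay are deleted at the end.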
1839 async fn call_recording(
1840 &mut self,
1841 cx: &Context<Self>,
1842 seq: Seq,
1843 recording: Ref,
1844 results: Vec<Ref>,
1845 actuals: Vec<Ref>,
1846 ) -> Result<()> {
1847 if self.active_recording.is_some() {
1848 bail!("cannot call recording while another recording is active");
1849 }
1850
1851 let messages = match self.recordings.get(&recording) {
1852 Some(recording) => recording
1853 .messages
1854 .iter()
1855 .map(|message| message.clone_for_recording())
1856 .collect::<Vec<_>>(),
1857 None => bail!("recording {:?} not found", recording),
1858 };
1859
1860 self.active_recording = Some(RecordingState::Running);
1861
1862 let mut error: Option<Arc<SeqError>> = None;
1867 let mut all_defined_refs = HashSet::new();
1870 let mut all_mutated_refs = HashSet::new();
1873 let mut formal_to_actual_refs = HashMap::new();
1879 self.last_seq_error = None;
1881 for message in messages.into_iter() {
1882 let defined_refs = message.get_defined_refs();
1883 all_defined_refs.extend(defined_refs.clone());
1884
1885 let mutated_refs_with_formals = message.get_mutated_refs();
1886 all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1887 match formal_to_actual_refs.get(ref_) {
1888 Some(actual_ref) => Some(*actual_ref),
1889 None => {
1890 if all_defined_refs.contains(ref_) {
1891 None
1892 } else {
1893 Some(*ref_)
1894 }
1895 }
1896 }
1897 }));
1898
1899 match message {
1900 StreamMessage::RecordingFormal {
1901 result: formal_ref,
1902 argument_index,
1903 } => match actuals.get(argument_index) {
1904 None => bail!("recording_formal called with too few arguments"),
1905 Some(actual_ref) => {
1906 formal_to_actual_refs.insert(formal_ref, *actual_ref);
1907 self.define_ref(formal_ref, *actual_ref)?;
1908 }
1909 },
1910 StreamMessage::RecordingResult {
1911 result: result_ref,
1912 output_index,
1913 } => match results.get(output_index) {
1914 None => bail!("recording_result called with too few results"),
1915 Some(actual_result_ref) => {
1916 self.define_ref(*actual_result_ref, result_ref)?;
1917 }
1918 },
1919 StreamMessage::DeleteRefs(ref refs) => {
1920 for ref_ in refs {
1921 all_defined_refs.remove(ref_);
1922 }
1923 StreamMessageHandler::handle(self, cx, message).await?;
1924 }
1925 StreamMessage::CallFunction { .. } if error.is_some() => {
1926 let error = error.clone().unwrap();
1932 for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1933 self.env.insert(*ref_, Err(error.clone()));
1934 }
1935 }
1936 StreamMessage::BorrowLastUse { ref result, .. } => {
1937 all_defined_refs.remove(result);
1938 StreamMessageHandler::handle(self, cx, message).await?;
1939 }
1940 StreamMessage::Reduce {
1941 local_tensor,
1942 ref out,
1943 ..
1944 } => {
1945 if error.is_none() {
1949 let inputs_to_check = [Some(local_tensor), out.clone()]
1950 .iter()
1951 .filter_map(|r| *r)
1952 .collect::<Vec<_>>();
1953 error = self.get_first_error(inputs_to_check.as_slice())?;
1954 }
1955 StreamMessageHandler::handle(self, cx, message).await?;
1956 }
1957 StreamMessage::SendTensor {
1958 ref tensor,
1959 ref to_rank,
1960 ..
1961 } => {
1962 if to_rank.is_some() && error.is_none() {
1967 error = self.get_first_error(&[*tensor])?;
1968 }
1969 StreamMessageHandler::handle(self, cx, message).await?;
1970 }
1971 _ => {
1972 StreamMessageHandler::handle(self, cx, message).await?;
1973 }
1974 };
1975
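            // If this message produced a fresh SeqError and no failure has been
            // reported for this recording call yet, notify the controller once
            // and remember the error for the remaining messages.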
1976 match (&error, self.last_seq_error.take()) {
1984 (None, Some(seq_err)) => {
1985 self.controller_actor
1987 .remote_function_failed(
1988 cx,
1989 seq,
1990 WorkerError {
1991 backtrace: format!("recording failed: {}", &seq_err),
1992 worker_actor_id: cx.self_id().clone(),
1993 },
1994 )
1995 .await?;
1996 error = Some(seq_err)
1997 }
1998 _ => {}
1999 }
2000 }
2005
2006 StreamMessageHandler::handle(
2010 self,
2011 cx,
2012 StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
2013 )
2014 .await?;
2015
2016 if error.is_some() {
2019 for ref_ in results.iter().chain(all_mutated_refs.iter()) {
2020 self.env.insert(*ref_, Err(error.clone().unwrap()));
2021 }
2022 }
2023
2024 self.active_recording = None;
2025 Ok(())
2026 }
2027
2028 async fn set_ref_unit_tests_only(
2029 &mut self,
2030 _cx: &Context<Self>,
2031 reference: Ref,
2032 value: WireValue,
2033 ) -> Result<()> {
2034 self.env
2035 .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
2036 Ok(())
2037 }
2038
2039 async fn set_tensor_ref_unit_tests_only(
2040 &mut self,
2041 _cx: &Context<Self>,
2042 reference: Ref,
2043 tensor_result: TensorCellResult,
2044 ) -> Result<()> {
2045 match tensor_result {
2046 Ok(tensor_cell) => {
2047 self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
2048 }
2049 Err(err) => {
2050 self.env.insert(reference, Err(err));
2051 }
2052 }
2053 Ok(())
2054 }
2055
2056 async fn get_ref_unit_tests_only(
2057 &mut self,
2058 _cx: &Context<Self>,
2059 reference: Ref,
2060 ) -> Result<Option<Result<WireValue, String>>> {
2061 fn rvalue_to_wire(
2063 value: Result<RValue, Arc<SeqError>>,
2064 ) -> Result<WireValue, Arc<SeqError>> {
2065 Ok(match value? {
2066 RValue::Int(val) => WireValue::Int(val),
2067 RValue::IntList(val) => WireValue::IntList(val),
2068 RValue::Double(val) => WireValue::Double(val),
2069 RValue::Bool(val) => WireValue::Bool(val),
2070 RValue::String(val) => WireValue::String(val),
2071 RValue::Layout(val) => WireValue::Layout(val),
2072 RValue::Device(val) => WireValue::Device(val),
2073 RValue::ScalarType(val) => WireValue::ScalarType(val),
2074 RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2075 RValue::None => WireValue::None(()),
2076 other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2077 })
2078 }
2079 Ok(self
2080 .env
2081 .get(&reference)
2082 .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2083 }
2084
2085 async fn get_tensor_ref_unit_tests_only(
2086 &mut self,
2087 _cx: &Context<Self>,
2088 reference: Ref,
2089 ) -> Result<Option<TensorCellResult>> {
2090 match self.env.get(&reference) {
2091 Some(Ok(rvalue)) => match rvalue {
2092 RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2093 other => bail!("expected tensor, got {:?}", other),
2094 },
2095 Some(Err(err)) => Ok(Some(Err(err.clone()))),
2096 None => Ok(None),
2097 }
2098 }
2099}
2100
2101#[cfg(test)]
2102mod tests {
2103 use hyperactor::actor::ActorStatus;
2104 use hyperactor::cap;
2105 use hyperactor::supervision::ActorSupervisionEvent;
2106 use monarch_messages::controller::ControllerMessage;
2107 use monarch_messages::worker::StreamCreationMode;
2108 use monarch_types::PickledPyObject;
2109 use pyo3::IntoPyObjectExt;
2110 use timed_test::async_timed_test;
2111 use torch_sys::factory_float_tensor;
2112 use torch_sys::testing::allclose;
2113 use torch_sys_cuda::nccl::UniqueId;
2114
2115 use super::*;
2116 use crate::comm::CommParams;
2117 use crate::test_util;
2118
2119 fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2120 Arc::new(SeqError {
2121 seq: 0.into(),
2122 error: err,
2123 })
2124 }
2125
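    /// Common scaffolding for stream-actor tests: a local proc, a spawned
    /// [`StreamActor`], a client mailbox, and ports for controller and
    /// supervision messages.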
2126 struct TestSetup {
2127 proc: Proc,
2128 stream_actor: ActorHandle<StreamActor>,
2129 client: Mailbox,
2130 #[allow(dead_code)]
2133 supervision_rx: PortReceiver<ActorSupervisionEvent>,
2134 controller_rx: PortReceiver<ControllerMessage>,
2135 controller_actor: ActorRef<ControllerActor>,
2136 next_ref: Ref,
2137 }
2138
2139 impl TestSetup {
2140 async fn new() -> Result<Self> {
2141 Self::new_with_world_size(1).await
2142 }
2143
2144 async fn new_with_world_size(world_size: usize) -> Result<Self> {
2145 test_util::test_setup()?;
2146
2147 let proc = Proc::local();
2148 let (_, controller_actor, controller_rx) =
2149 proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2150 let client = proc.attach("client")?;
2151 let (supervision_tx, supervision_rx) = client.open_port();
2152 proc.set_supervision_coordinator(supervision_tx)?;
2153 let stream_actor = proc
2154 .spawn::<StreamActor>(
2155 "stream",
2156 StreamParams {
2157 world_size,
2158 rank: 0,
2159 creation_mode: StreamCreationMode::UseDefaultStream,
2160 id: 0.into(),
2161 device: Some(CudaDevice::new(0.into())),
2162 controller_actor: controller_actor.clone(),
2163 respond_with_python_message: false,
2164 },
2165 )
2166 .await?;
2167
2168 Ok(Self {
2169 proc,
2170 stream_actor,
2171 client,
2172 supervision_rx,
2173 controller_rx,
2174 controller_actor,
2175 next_ref: 0.into(),
2176 })
2177 }
2178
2179 fn next_ref(&mut self) -> Ref {
2180 let ref_ = self.next_ref;
2181 self.next_ref = Ref {
2182 id: self.next_ref.id + 1,
2183 };
2184 ref_
2185 }
2186
2187 async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2188 let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2189 self.stream_actor
2190 .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2191 .await
2192 }
2193
2194 async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2195 let actual = self
2196 .stream_actor
2197 .get_tensor_ref_unit_tests_only(&self.client, reference)
2198 .await
2199 .unwrap()
2200 .unwrap()
2201 .unwrap();
2202 let x = allclose(
2203 &factory_float_tensor(data, "cpu".try_into().unwrap()),
2204 &actual.borrow(),
2205 )
2206 .unwrap();
2207 x
2208 }
2209
2210 async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2211 let result_error = self
2212 .stream_actor
2213 .get_tensor_ref_unit_tests_only(&self.client, reference)
2214 .await
2215 .unwrap()
2216 .unwrap()
2217 .unwrap_err();
2218
2219 assert!(Arc::ptr_eq(&result_error, &error));
2220 }
2221 }
2222
2223 async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2224 loop {
2225 let status = proc
2226 .ledger_snapshot()
2227 .roots
2228 .get(actor_id)
2229 .unwrap()
2230 .status
2231 .clone();
2232 if let ActorStatus::Failed(msg) = status {
2233 assert!(msg.contains(&expected_msg));
2234 break;
2235 } else {
2236 tokio::task::yield_now().await;
2237 }
2238 }
2239 }
2240
2241 async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2242 for ref_ in refs {
2243 assert!(
2244 test_setup
2245 .stream_actor
2246 .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2247 .await
2248 .unwrap()
2249 .is_none()
2250 );
2251 }
2252 }
2253
2254 async fn fetch_result(
2255 caps: &impl cap::CanSend,
2256 stream_actor: ActorHandle<StreamActor>,
2257 seq: Seq,
2258 reference: Ref,
2259 ) {
2260 let ref_to_send = Python::with_gil(|py| {
2261 PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2262 });
2263
2264 stream_actor
2265 .send_value(
2266 caps,
2267 seq,
2268 stream_actor.actor_id().clone(),
2269 Vec::new(),
2270 None,
2271 vec![WireValue::PyObject(ref_to_send)],
2272 HashMap::new(),
2273 HashMap::new(),
2274 None,
2275 )
2276 .await
2277 .unwrap()
2278 }
2279
2280 async fn check_fetch_result_error(
2281 caps: &impl cap::CanSend,
2282 stream_actor: ActorHandle<StreamActor>,
2283 seq: Seq,
2284 reference: Ref,
2285 controller_rx: &mut PortReceiver<ControllerMessage>,
2286 expected_backtrace: &str,
2287 ) {
2288 fetch_result(caps, stream_actor, seq, reference).await;
2289
2290 let controller_msg = controller_rx.recv().await.unwrap();
2291 match controller_msg {
2292 ControllerMessage::FetchResult {
2293 seq: actual_seq,
2294 value: Err(err),
2295 } => {
2296 assert_eq!(actual_seq, seq);
2297 assert!(
2298 err.backtrace.contains(expected_backtrace),
2299 "backtrace did not contain {:?}: {:?}",
2300 expected_backtrace,
2301 err.backtrace
2302 );
2303 }
2304 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2305 };
2306 }
2307
2308 async fn check_fetch_result_value(
2309 caps: &impl cap::CanSend,
2310 stream_actor: ActorHandle<StreamActor>,
2311 seq: Seq,
2312 reference: Ref,
2313 controller_rx: &mut PortReceiver<ControllerMessage>,
2314 ) {
2315 fetch_result(caps, stream_actor, seq, reference).await;
2316
2317 let controller_msg = controller_rx.recv().await.unwrap();
2318 match controller_msg {
2319 ControllerMessage::FetchResult {
2320 value: Ok(_),
2321 seq: actual_seq,
2322 } => assert_eq!(seq, actual_seq),
2323 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2324 };
2325 }
2326
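    // Defining a second recording while another is still being defined must fail the stream actor.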
2327 #[async_timed_test(timeout_secs = 60)]
2328 async fn test_define_recording_other_recording_active() -> Result<()> {
2329 let test_setup = TestSetup::new().await?;
2330 test_setup
2331 .stream_actor
2332 .define_recording(&test_setup.client, 0.into())
2333 .await?;
2334 test_setup
2335 .stream_actor
2336 .define_recording(&test_setup.client, 1.into())
2337 .await?;
2338 assert_actor_failed_with_msg(
2339 &test_setup.proc,
2340 test_setup.stream_actor.actor_id(),
2341 "different recording already active".into(),
2342 )
2343 .await;
2344 Ok(())
2345 }
2346
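    // Re-defining a recording id that was already defined and finalized must fail the stream actor.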
2347 #[async_timed_test(timeout_secs = 60)]
2348 async fn test_define_recording_already_defined() -> Result<()> {
2349 let test_setup = TestSetup::new().await?;
2350 test_setup
2351 .stream_actor
2352 .define_recording(&test_setup.client, 0.into())
2353 .await?;
2354 test_setup
2355 .stream_actor
2356 .finalize_recording(&test_setup.client, 0.into())
2357 .await?;
2358 test_setup
2359 .stream_actor
2360 .define_recording(&test_setup.client, 0.into())
2361 .await?;
2362 assert_actor_failed_with_msg(
2363 &test_setup.proc,
2364 test_setup.stream_actor.actor_id(),
2365 "already defined".into(),
2366 )
2367 .await;
2368 Ok(())
2369 }
2370
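    // Finalizing a recording id other than the currently active one must fail the stream actor.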
2371 #[async_timed_test(timeout_secs = 60)]
2372 async fn test_finalize_recording_other_recording_active() -> Result<()> {
2373 let test_setup = TestSetup::new().await?;
2374 test_setup
2375 .stream_actor
2376 .define_recording(&test_setup.client, 0.into())
2377 .await?;
2378 test_setup
2379 .stream_actor
2380 .finalize_recording(&test_setup.client, 1.into())
2381 .await?;
2382 assert_actor_failed_with_msg(
2383 &test_setup.proc,
2384 test_setup.stream_actor.actor_id(),
2385 "cannot finalize recording that isn't active".into(),
2386 )
2387 .await;
2388 Ok(())
2389 }
2390
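    // `recording_formal` is only valid while a recording is being defined.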
2391 #[async_timed_test(timeout_secs = 60)]
2392 async fn test_recording_formal_outside_recording() -> Result<()> {
2393 let test_setup = TestSetup::new().await?;
2394 test_setup
2395 .stream_actor
2396 .recording_formal(&test_setup.client, 0.into(), 0)
2397 .await?;
2398 assert_actor_failed_with_msg(
2399 &test_setup.proc,
2400 test_setup.stream_actor.actor_id(),
2401 "recording_formal called outside of recording".into(),
2402 )
2403 .await;
2404 Ok(())
2405 }
2406
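    // `recording_result` is only valid while a recording is being defined.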
2407 #[async_timed_test(timeout_secs = 60)]
2408 async fn test_recording_result_outside_recording() -> Result<()> {
2409 let test_setup = TestSetup::new().await?;
2410 test_setup
2411 .stream_actor
2412 .recording_result(&test_setup.client, 0.into(), 0)
2413 .await?;
2414 assert_actor_failed_with_msg(
2415 &test_setup.proc,
2416 test_setup.stream_actor.actor_id(),
2417 "recording_result called outside of recording".into(),
2418 )
2419 .await;
2420 Ok(())
2421 }
2422
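    // Calling a recording while another recording is still being defined must fail the stream actor.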
2423 #[async_timed_test(timeout_secs = 60)]
2424 async fn test_call_recording_other_recording_active() -> Result<()> {
2425 let test_setup = TestSetup::new().await?;
2426 test_setup
2427 .stream_actor
2428 .define_recording(&test_setup.client, 0.into())
2429 .await?;
2430 test_setup
2431 .stream_actor
2432 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2433 .await?;
2434 assert_actor_failed_with_msg(
2435 &test_setup.proc,
2436 test_setup.stream_actor.actor_id(),
2437 "cannot call recording while another recording is active".into(),
2438 )
2439 .await;
2440 Ok(())
2441 }
2442
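    // Calling a recording id that was never defined must fail the stream actor.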
2443 #[async_timed_test(timeout_secs = 60)]
2444 async fn test_call_recording_not_found() -> Result<()> {
2445 let test_setup = TestSetup::new().await?;
2446 test_setup
2447 .stream_actor
2448 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2449 .await?;
2450 assert_actor_failed_with_msg(
2451 &test_setup.proc,
2452 test_setup.stream_actor.actor_id(),
2453 "not found".into(),
2454 )
2455 .await;
2456 Ok(())
2457 }
2458
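    // Calling a recording with fewer actuals than its formal indices require must fail
    // with "too few arguments".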
2459 #[async_timed_test(timeout_secs = 60)]
2460 async fn test_recording_formal_too_few_arguments() -> Result<()> {
2461 let test_setup = TestSetup::new().await?;
2462
2463 test_setup
2464 .stream_actor
2465 .define_recording(&test_setup.client, 0.into())
2466 .await?;
2467
2468 test_setup
2469 .stream_actor
2470 .recording_formal(&test_setup.client, 1.into(), 0)
2471 .await?;
2472
2473 test_setup
2474 .stream_actor
2475 .finalize_recording(&test_setup.client, 0.into())
2476 .await?;
2477
2478 test_setup
2479 .stream_actor
2480 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2481 .await?;
2482
2483 assert_actor_failed_with_msg(
2484 &test_setup.proc,
2485 test_setup.stream_actor.actor_id(),
2486 "recording_formal called with too few arguments".into(),
2487 )
2488 .await;
2489 Ok(())
2490 }
2491
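    // Calling a recording with fewer result slots than its result indices require must fail
    // with "too few results".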
2492 #[async_timed_test(timeout_secs = 60)]
2493 async fn test_recording_result_too_few_results() -> Result<()> {
2494 let test_setup = TestSetup::new().await?;
2495
2496 test_setup
2497 .stream_actor
2498 .define_recording(&test_setup.client, 0.into())
2499 .await?;
2500
2501 test_setup
2502 .stream_actor
2503 .recording_result(&test_setup.client, 1.into(), 0)
2504 .await?;
2505
2506 test_setup
2507 .stream_actor
2508 .finalize_recording(&test_setup.client, 0.into())
2509 .await?;
2510
2511 test_setup
2512 .stream_actor
2513 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2514 .await?;
2515
2516 assert_actor_failed_with_msg(
2517 &test_setup.proc,
2518 test_setup.stream_actor.actor_id(),
2519 "recording_result called with too few results".into(),
2520 )
2521 .await;
2522 Ok(())
2523 }
2524
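    // Records a recording that simply returns its two formals (with crossed argument/result
    // indices), calls it with real tensors, checks the outputs, and verifies the formal refs
    // are cleaned up afterwards.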
2525 #[async_timed_test(timeout_secs = 60)]
2526 async fn test_basic_call_recording() -> Result<()> {
2527 let mut test_setup = TestSetup::new().await?;
2528
2529 test_setup
2533 .stream_actor
2534 .define_recording(&test_setup.client, 0.into())
2535 .await?;
2536
2537 let formal0_ref = 1.into();
2538 let formal0_index = 1;
2539 test_setup
2540 .stream_actor
2541 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2542 .await?;
2543
2544 let formal1_ref = 2.into();
2545 let formal1_index = 0;
2546 test_setup
2547 .stream_actor
2548 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2549 .await?;
2550
2551 let result0_ref = formal0_ref;
2552 let result0_index = 0;
2553 test_setup
2554 .stream_actor
2555 .recording_result(&test_setup.client, result0_ref, result0_index)
2556 .await?;
2557
2558 let result1_ref = formal1_ref;
2559 let result1_index = 1;
2560 test_setup
2561 .stream_actor
2562 .recording_result(&test_setup.client, result1_ref, result1_index)
2563 .await?;
2564
2565 test_setup
2566 .stream_actor
2567 .finalize_recording(&test_setup.client, 0.into())
2568 .await?;
2569
2570 let actual0_ref = 3.into();
2571 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2572
2573 let actual1_ref = 4.into();
2574 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2575
2576 let actual_result0_ref = 5.into();
2579 let actual_result1_ref = 6.into();
2580 test_setup
2581 .stream_actor
2582 .call_recording(
2583 &test_setup.client,
2584 0.into(),
2585 0.into(),
2586 vec![actual_result0_ref, actual_result1_ref],
2587 vec![actual0_ref, actual1_ref],
2588 )
2589 .await?;
2590
2591 assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2593 assert!(
2594 test_setup
2595 .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2596 .await
2597 );
2598
2599 assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2602 Ok(())
2603 }
2604
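    // `request_status` is rejected while a recording is being defined.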
2605 #[async_timed_test(timeout_secs = 60)]
2606 async fn test_request_status_in_recording() -> Result<()> {
2607 let test_setup = TestSetup::new().await?;
2608 test_setup
2609 .stream_actor
2610 .define_recording(&test_setup.client, 0.into())
2611 .await?;
2612 test_setup
2613 .stream_actor
2614 .request_status(&test_setup.client)
2615 .await
2616 .expect_err("request_status should have failed");
2617 assert_actor_failed_with_msg(
2618 &test_setup.proc,
2619 test_setup.stream_actor.actor_id(),
2620 "request_status not allowed in recording".into(),
2621 )
2622 .await;
2623 Ok(())
2624 }
2625
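    // `init_comm` is rejected while a recording is being defined.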
2626 #[async_timed_test(timeout_secs = 60)]
2627 async fn test_init_comm_in_recording() -> Result<()> {
2628 let test_setup = TestSetup::new().await?;
2629 test_setup
2630 .stream_actor
2631 .define_recording(&test_setup.client, 0.into())
2632 .await?;
2633
2634 let dummy_comm = test_setup
2635 .proc
2636 .spawn::<NcclCommActor>(
2637 "comm",
2638 CommParams::New {
2639 device: CudaDevice::new(0.into()),
2640 unique_id: UniqueId::new()?,
2641 world_size: 1,
2642 rank: 0,
2643 },
2644 )
2645 .await?;
2646
2647 test_setup
2648 .stream_actor
2649 .init_comm(&test_setup.client, dummy_comm)
2650 .await?;
2651 assert_actor_failed_with_msg(
2652 &test_setup.proc,
2653 test_setup.stream_actor.actor_id(),
2654 "init_comm not allowed in recording".into(),
2655 )
2656 .await;
2657 Ok(())
2658 }
2659
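    // Records torch-op calls, including an in-place mutation of a captured tensor, then checks
    // one successful replay, a failing replay that poisons the captured tensor and its dependents,
    // and that the failure is reported to the controller and persists on later replays.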
2660 #[async_timed_test(timeout_secs = 60)]
2661 async fn test_call_function_in_recording() -> Result<()> {
2662 let mut test_setup = TestSetup::new().await?;
2663
2664 test_setup
2671 .stream_actor
2672 .define_recording(&test_setup.client, 0.into())
2673 .await?;
2674
2675 let formal0_ref = test_setup.next_ref();
2676 let formal0_index = 0;
2677 test_setup
2678 .stream_actor
2679 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2680 .await?;
2681
2682 let formal1_ref = test_setup.next_ref();
2683 let formal1_index = 1;
2684 test_setup
2685 .stream_actor
2686 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2687 .await?;
2688
2689 let captured_ref = test_setup.next_ref();
2690 let result_captured_ref = test_setup.next_ref();
2691 let add_one_function =
2692 ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2693 let add_tensors_function =
2694 ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2695
2696 let add_result_ref_0 = test_setup.next_ref();
2697 test_setup
2698 .stream_actor
2699 .call_function(
2700 &test_setup.client,
2701 CallFunctionParams {
2702 seq: 100.into(),
2703 function: add_tensors_function.clone(),
2704 args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2705 kwargs: HashMap::new(),
2706 results: vec![Some(add_result_ref_0)],
2707 mutates: vec![],
2708 stream: 0.into(),
2709 remote_process_groups: Vec::new(),
2710 },
2711 HashMap::new(),
2712 HashMap::new(),
2713 )
2714 .await?;
2715
2716 test_setup
2717 .stream_actor
2718 .call_function(
2719 &test_setup.client,
2720 CallFunctionParams {
2721 seq: 101.into(),
2722 function: add_one_function,
2723 args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2724 kwargs: HashMap::new(),
2725 results: vec![Some(result_captured_ref)],
2726 mutates: vec![captured_ref],
2727 stream: 0.into(),
2728 remote_process_groups: Vec::new(),
2729 },
2730 HashMap::new(),
2731 HashMap::new(),
2732 )
2733 .await?;
2734
2735 let add_result_ref_1 = test_setup.next_ref();
2736 test_setup
2737 .stream_actor
2738 .call_function(
2739 &test_setup.client,
2740 CallFunctionParams {
2741 seq: 102.into(),
2742 function: add_tensors_function,
2743 args: vec![
2744 WireValue::Ref(add_result_ref_0),
2745 WireValue::Ref(captured_ref),
2746 ],
2747 kwargs: HashMap::new(),
2748 results: vec![Some(add_result_ref_1)],
2749 mutates: vec![],
2750 stream: 0.into(),
2751 remote_process_groups: Vec::new(),
2752 },
2753 HashMap::new(),
2754 HashMap::new(),
2755 )
2756 .await?;
2757
2758 test_setup
2759 .stream_actor
2760 .recording_result(&test_setup.client, add_result_ref_1, 0)
2761 .await?;
2762
2763 test_setup
2764 .stream_actor
2765 .delete_refs(
2766 &test_setup.client,
2767 vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2768 )
2769 .await?;
2770
2771 test_setup
2772 .stream_actor
2773 .finalize_recording(&test_setup.client, 0.into())
2774 .await?;
2775
2776 let actual0_ref = test_setup.next_ref();
2777 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2778
2779 let actual1_ref = test_setup.next_ref();
2780 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2781
2782 test_setup
2783 .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2784 .await?;
2785
2786 let actual_result_ref = test_setup.next_ref();
2787 test_setup
2788 .stream_actor
2789 .call_recording(
2790 &test_setup.client,
2791 0.into(),
2792 0.into(),
2793 vec![actual_result_ref],
2794 vec![actual0_ref, actual1_ref],
2795 )
2796 .await?;
2797
2798 assert!(
2799 test_setup
2800 .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2801 .await
2802 );
2803
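        // Shrink the second input so the recorded add fails with a torch operator error on replay.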
2804 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2806
2807 let actual_result_ref = test_setup.next_ref();
2808 test_setup
2809 .stream_actor
2810 .call_recording(
2811 &test_setup.client,
2812 1.into(),
2813 0.into(),
2814 vec![actual_result_ref],
2815 vec![actual0_ref, actual1_ref],
2816 )
2817 .await?;
2818
2819 for ref_ in [actual0_ref, actual1_ref] {
2821 let _ = test_setup
2822 .stream_actor
2823 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2824 .await?
2825 .unwrap()
2826 .unwrap();
2827 }
2828
2829 for ref_ in [captured_ref, actual_result_ref] {
2830 let result_error = test_setup
2831 .stream_actor
2832 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2833 .await?
2834 .unwrap()
2835 .unwrap_err();
2836 let error_str = result_error.to_string();
2838 assert!(
2839 error_str.contains("torch operator error"),
2840                 "Error should contain 'torch operator error': {}",
2841 error_str
2842 );
2843 }
2844
2845 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2846 match controller_msg {
2847 ControllerMessage::RemoteFunctionFailed { seq, error } => {
2848 assert_eq!(seq, 1.into());
2849 assert!(
2850 error.backtrace.contains("torch operator error"),
2851 "Unexpected WorkerError: {:?}",
2852 error
2853 );
2854 }
2855 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2856 };
2857
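        // Restore a valid input shape; the recording still fails because the captured tensor
        // now holds the earlier error.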
2858 test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2860
2861 let actual_result_ref = test_setup.next_ref();
2865 test_setup
2866 .stream_actor
2867 .call_recording(
2868 &test_setup.client,
2869 2.into(),
2870 0.into(),
2871 vec![actual_result_ref],
2872 vec![actual0_ref, actual1_ref],
2873 )
2874 .await?;
2875
2876 for ref_ in [actual0_ref, actual1_ref] {
2878 let _ = test_setup
2879 .stream_actor
2880 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2881 .await?
2882 .unwrap()
2883 .unwrap();
2884 }
2885
2886 for ref_ in [captured_ref, actual_result_ref] {
2887 let result_error = test_setup
2888 .stream_actor
2889 .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2890 .await?
2891 .unwrap()
2892 .unwrap_err();
2893 let error_str = result_error.to_string();
2895 assert!(
2896 error_str.contains("torch operator error"),
2897                 "Error should contain 'torch operator error': {}",
2898 error_str
2899 );
2900 }
2901
2902 check_fetch_result_error(
2906 &test_setup.client,
2907 test_setup.stream_actor.clone(),
2908 3.into(),
2909 captured_ref,
2910 &mut test_setup.controller_rx,
2911 "torch operator error",
2912 )
2913 .await;
2914
2915 Ok(())
2916 }
2917
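    // Creating the same borrow id twice within a recording must fail the stream actor.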
2918 #[async_timed_test(timeout_secs = 60)]
2919 async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2920 let mut test_setup = TestSetup::new().await?;
2921 test_setup
2922 .stream_actor
2923 .define_recording(&test_setup.client, 0.into())
2924 .await?;
2925
2926 let borrow_id = 1;
2927 let tensor_ref = test_setup.next_ref();
2928 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2929
2930 test_setup
2931 .stream_actor
2932 .borrow_create(
2933 &test_setup.client,
2934 borrow_id,
2935 tensor_ref,
2936 first_use_sender.clone(),
2937 )
2938 .await?;
2939
2940 test_setup
2941 .stream_actor
2942 .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2943 .await?;
2944
2945 assert_actor_failed_with_msg(
2946 &test_setup.proc,
2947 test_setup.stream_actor.actor_id(),
2948 "duplicate borrow create in recording".into(),
2949 )
2950 .await;
2951
2952 Ok(())
2953 }
2954
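    // Dropping a borrow that was not created within the recording must fail the stream actor.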
2955 #[async_timed_test(timeout_secs = 60)]
2956 async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2957 let test_setup = TestSetup::new().await?;
2958 test_setup
2959 .stream_actor
2960 .define_recording(&test_setup.client, 0.into())
2961 .await?;
2962
2963 let borrow_id = 1;
2964 let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2965
2966 test_setup
2967 .stream_actor
2968 .borrow_drop(
2969 &test_setup.client,
2970 borrow_id,
2971 Arc::new(Mutex::new(last_use_receiver)),
2972 )
2973 .await?;
2974
2975 assert_actor_failed_with_msg(
2976 &test_setup.proc,
2977 test_setup.stream_actor.actor_id(),
2978 "borrow drop for borrow not defined in recording".into(),
2979 )
2980 .await;
2981
2982 Ok(())
2983 }
2984
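    // A borrow created within a recording must be dropped before the recording is finalized.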
2985 #[async_timed_test(timeout_secs = 60)]
2986 async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2987 let mut test_setup = TestSetup::new().await?;
2988 test_setup
2989 .stream_actor
2990 .define_recording(&test_setup.client, 0.into())
2991 .await?;
2992
2993 let borrow_id = 1;
2994 let tensor_ref = test_setup.next_ref();
2995 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2996
2997 test_setup
2998 .stream_actor
2999 .borrow_create(
3000 &test_setup.client,
3001 borrow_id,
3002 tensor_ref,
3003 first_use_sender.clone(),
3004 )
3005 .await?;
3006
3007 test_setup
3009 .stream_actor
3010 .finalize_recording(&test_setup.client, 0.into())
3011 .await?;
3012
3013 assert_actor_failed_with_msg(
3014 &test_setup.proc,
3015 test_setup.stream_actor.actor_id(),
3016 "all borrows created within recording must be dropped within recording".into(),
3017 )
3018 .await;
3019
3020 Ok(())
3021 }
3022
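    // Records matching lender/borrower recordings on two streams that share a tensor via a
    // borrow, then checks a successful replay, a borrower-side failure reported to the
    // controller, and propagation of a poisoned lender input into the borrower's result.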
3023 #[async_timed_test(timeout_secs = 60)]
3024 async fn test_borrow_in_recording() -> Result<()> {
3025 let mut test_setup = TestSetup::new().await?;
3026
3027 let borrower_stream = test_setup
3028 .proc
3029 .spawn::<StreamActor>(
3030 "stream1",
3031 StreamParams {
3032 world_size: 1,
3033 rank: 0,
3034 creation_mode: StreamCreationMode::CreateNewStream,
3035 id: 1.into(),
3036 device: Some(CudaDevice::new(0.into())),
3037 controller_actor: test_setup.controller_actor.clone(),
3038 respond_with_python_message: false,
3039 },
3040 )
3041 .await?;
3042
3043 let lender_stream = test_setup.stream_actor.clone();
3044
3045 let borrow_id = 1;
3046 let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
3047 let (last_use_sender, last_use_receiver) = test_setup.client.open_port();
3048
3049 lender_stream
3051 .define_recording(&test_setup.client, 0.into())
3052 .await?;
3053
3054 let formal_ref = test_setup.next_ref();
3055 lender_stream
3056 .recording_formal(&test_setup.client, formal_ref, 0)
3057 .await?;
3058
3059 lender_stream
3060 .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
3061 .await?;
3062
3063 lender_stream
3064 .borrow_drop(
3065 &test_setup.client,
3066 borrow_id,
3067 Arc::new(Mutex::new(last_use_receiver)),
3068 )
3069 .await?;
3070
3071 lender_stream
3072 .finalize_recording(&test_setup.client, 0.into())
3073 .await?;
3074
3075 let borrower_tensor_ref = test_setup.next_ref();
3076 let borrower_tensor = TensorCell::new(factory_float_tensor(
3077 &[1.0, 2.0, 3.0],
3078 "cuda".try_into().unwrap(),
3079 ));
3080
3081 borrower_stream
3082 .set_tensor_ref_unit_tests_only(
3083 &test_setup.client,
3084 borrower_tensor_ref,
3085 Ok(borrower_tensor.clone()),
3086 )
3087 .await?;
3088
3089 borrower_stream
3091 .define_recording(&test_setup.client, 0.into())
3092 .await?;
3093
3094 let borrowed_ref = test_setup.next_ref();
3095
3096 borrower_stream
3097 .borrow_first_use(
3098 &test_setup.client,
3099 borrow_id,
3100 borrowed_ref,
3101 Arc::new(Mutex::new(first_use_receiver)),
3102 )
3103 .await?;
3104
3105 let result_ref = test_setup.next_ref();
3106 borrower_stream
3107 .call_function(
3108 &test_setup.client,
3109 CallFunctionParams {
3110 seq: 100.into(),
3111 function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
3112 args: vec![
3113 WireValue::Ref(borrowed_ref),
3114 WireValue::Ref(borrower_tensor_ref),
3115 ],
3116 kwargs: HashMap::new(),
3117 results: vec![Some(result_ref)],
3118 mutates: vec![],
3119 stream: 1.into(),
3120 remote_process_groups: Vec::new(),
3121 },
3122 HashMap::new(),
3123 HashMap::new(),
3124 )
3125 .await?;
3126
3127 borrower_stream
3128 .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
3129 .await?;
3130
3131 borrower_stream
3132 .recording_result(&test_setup.client, result_ref, 0)
3133 .await?;
3134
3135 borrower_stream
3136 .finalize_recording(&test_setup.client, 0.into())
3137 .await?;
3138
3139 let input_tensor_ref = test_setup.next_ref();
3141 test_setup
3142 .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
3143 .await?;
3144
3145 let result_tensor_ref = test_setup.next_ref();
3146
3147 let lender_future = lender_stream.call_recording(
3148 &test_setup.client,
3149 0.into(),
3150 0.into(),
3151 vec![],
3152 vec![input_tensor_ref],
3153 );
3154
3155 let borrower_future = borrower_stream.call_recording(
3156 &test_setup.client,
3157 0.into(),
3158 0.into(),
3159 vec![result_tensor_ref],
3160 vec![],
3161 );
3162
3163 tokio::try_join!(lender_future, borrower_future)?;
3164
3165 let result_tensor = borrower_stream
3166 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3167 .await?
3168 .unwrap()
3169 .unwrap();
3170
3171 let expected_tensor = TensorCell::new(factory_float_tensor(
3172 &[5.0, 7.0, 9.0],
3173 "cpu".try_into().unwrap(),
3174 ));
3175 assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());
3176
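        // Swap in a mismatched borrower-side tensor so the next replay fails on the borrower stream.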
3177 let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
3179 &[1.0, 2.0],
3180 "cuda".try_into().unwrap(),
3181 ));
3182 borrower_stream
3183 .set_tensor_ref_unit_tests_only(
3184 &test_setup.client,
3185 borrower_tensor_ref,
3186 Ok(invalid_borrower_tensor.clone()),
3187 )
3188 .await?;
3189
3190 let lender_future = lender_stream.call_recording(
3192 &test_setup.client,
3193 1.into(),
3194 0.into(),
3195 vec![],
3196 vec![input_tensor_ref],
3197 );
3198
3199 let borrower_future = borrower_stream.call_recording(
3200 &test_setup.client,
3201 1.into(),
3202 0.into(),
3203 vec![result_tensor_ref],
3204 vec![],
3205 );
3206
3207 tokio::try_join!(lender_future, borrower_future)?;
3208
3209 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
3211 match controller_msg {
3212 ControllerMessage::RemoteFunctionFailed { seq, error } => {
3213 assert_eq!(seq, 1.into());
3214 assert!(
3215 error.backtrace.contains("recording failed"),
3216 "Unexpected WorkerError: {:?}",
3217 error
3218 );
3219 assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
3220 }
3221 _ => panic!("Unexpected controller message: {:?}", controller_msg),
3222 };
3223
3224 check_fetch_result_value(
3226 &test_setup.client,
3227 lender_stream.clone(),
3228 2.into(),
3229 input_tensor_ref,
3230 &mut test_setup.controller_rx,
3231 )
3232 .await;
3233
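        // Poison the lender-side input and verify the error flows through the borrow into the
        // borrower's result.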
3234 let input_error = fake_seq_error(anyhow!("input error"));
3236 lender_stream
3237 .set_tensor_ref_unit_tests_only(
3238 &test_setup.client,
3239 input_tensor_ref,
3240 Err(input_error.clone()),
3241 )
3242 .await?;
3243
3244 let lender_future = lender_stream.call_recording(
3245 &test_setup.client,
3246 3.into(),
3247 0.into(),
3248 vec![],
3249 vec![input_tensor_ref],
3250 );
3251
3252 let borrower_future = borrower_stream.call_recording(
3253 &test_setup.client,
3254 3.into(),
3255 0.into(),
3256 vec![result_tensor_ref],
3257 vec![],
3258 );
3259
3260 tokio::try_join!(lender_future, borrower_future)?;
3261
3262 let result_error = borrower_stream
3264 .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3265 .await?
3266 .unwrap()
3267 .unwrap_err();
3268
3269 let error_str = result_error.to_string();
3271 assert!(
3272 error_str.contains("input error"),
3273 "Error should contain input error: {}",
3274 error_str
3275 );
3276
3277 let input_error_str = input_error.to_string();
3280 assert!(
3281 error_str.contains(&input_error_str),
3282 "Error should contain the original error: {}",
3283 error_str
3284 );
3285
3286 check_fetch_result_error(
3288 &test_setup.client,
3289 lender_stream,
3290 4.into(),
3291 input_tensor_ref,
3292 &mut test_setup.controller_rx,
3293 "input error",
3294 )
3295 .await;
3296
3297 check_fetch_result_error(
3299 &test_setup.client,
3300 borrower_stream,
3301 5.into(),
3302 result_tensor_ref,
3303 &mut test_setup.controller_rx,
3304 "input error",
3305 )
3306 .await;
3307
3308 Ok(())
3309 }
3310
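    // Records a chain of reduce ops over three formals on a single-rank comm, replays it with
    // valid inputs, then poisons each input in turn and checks which refs pick up the
    // dependent error.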
3311 #[async_timed_test(timeout_secs = 60)]
3312 async fn test_reduce_in_recording() -> Result<()> {
3313 let mut test_setup = TestSetup::new().await?;
3314 let recording_ref = test_setup.next_ref();
3315
3316 let comm = Arc::new(
3317 test_setup
3318 .proc
3319 .spawn::<NcclCommActor>(
3320 "comm",
3321 CommParams::New {
3322 device: CudaDevice::new(0.into()),
3323 unique_id: UniqueId::new()?,
3324 world_size: 1,
3325 rank: 0,
3326 },
3327 )
3328 .await?,
3329 );
3330
3331 let factory = Factory {
3332 size: vec![3],
3333 dtype: torch_sys::ScalarType::Float,
3334 layout: torch_sys::Layout::Strided,
3335 device: "cuda".try_into().unwrap(),
3336 };
3337
3338 let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3339
3340 test_setup
3341 .stream_actor
3342 .define_recording(&test_setup.client, recording_ref)
3343 .await?;
3344
3345 let formal_tensor_ref_0 = test_setup.next_ref();
3346 let formal_tensor_ref_1 = test_setup.next_ref();
3347 let formal_tensor_ref_2 = test_setup.next_ref();
3348
3349 test_setup
3350 .stream_actor
3351 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3352 .await?;
3353 test_setup
3354 .stream_actor
3355 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3356 .await?;
3357 test_setup
3358 .stream_actor
3359 .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3360 .await?;
3361
3362 let intermediate_tensor_ref_0 = test_setup.next_ref();
3363
3364 test_setup
3366 .stream_actor
3367 .reduce(
3368 &test_setup.client,
3369 comm.clone(),
3370 1,
3371 intermediate_tensor_ref_0,
3372 formal_tensor_ref_0,
3373 factory.clone(),
3374 reduction.clone(),
3375 false,
3376 true,
3377 None,
3378 )
3379 .await?;
3380
3381 let intermediate_tensor_ref_1 = test_setup.next_ref();
3383 test_setup
3384 .stream_actor
3385 .reduce(
3386 &test_setup.client,
3387 comm.clone(),
3388 1,
3389 intermediate_tensor_ref_1,
3390 formal_tensor_ref_1,
3391 factory.clone(),
3392 reduction.clone(),
3393 false,
3394 false,
3395 None,
3396 )
3397 .await?;
3398
3399 let intermediate_tensor_ref_2 = test_setup.next_ref();
3400
3401 test_setup
3403 .stream_actor
3404 .reduce(
3405 &test_setup.client,
3406 comm.clone(),
3407 1,
3408 intermediate_tensor_ref_2,
3409 intermediate_tensor_ref_1,
3410 factory.clone(),
3411 reduction.clone(),
3412 false,
3413 false,
3414 Some(formal_tensor_ref_2),
3415 )
3416 .await?;
3417
3418 test_setup
3419 .stream_actor
3420 .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3421 .await?;
3422
3423 test_setup
3424 .stream_actor
3425 .finalize_recording(&test_setup.client, recording_ref)
3426 .await?;
3427
3428 let input_tensor_ref_0 = test_setup.next_ref();
3429 let input_tensor_ref_1 = test_setup.next_ref();
3430 let input_tensor_ref_2 = test_setup.next_ref();
3431
3432 test_setup
3433 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3434 .await?;
3435
3436 test_setup
3437 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3438 .await?;
3439
3440 test_setup
3441 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3442 .await?;
3443
3444 let output_ref = test_setup.next_ref();
3445
3446 test_setup
3447 .stream_actor
3448 .call_recording(
3449 &test_setup.client,
3450 0.into(),
3451 recording_ref,
3452 vec![output_ref],
3453 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3454 )
3455 .await?;
3456
3457 assert!(
3459 test_setup
3460 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3461 .await
3462 );
3463 for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3465 assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3466 }
3467
3468 let input_error = fake_seq_error(anyhow!("input error"));
3470 test_setup
3471 .stream_actor
3472 .set_tensor_ref_unit_tests_only(
3473 &test_setup.client,
3474 input_tensor_ref_0,
3475 Err(input_error.clone()),
3476 )
3477 .await?;
3478
3479 test_setup
3480 .stream_actor
3481 .call_recording(
3482 &test_setup.client,
3483 1.into(),
3484 recording_ref,
3485 vec![output_ref],
3486 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3487 )
3488 .await?;
3489
3490 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3492 test_setup
3493 .validate_dependent_error(ref_, input_error.clone())
3494 .await;
3495 }
3496
3497 assert!(
3499 test_setup
3500 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3501 .await
3502 );
3503
3504 check_fetch_result_value(
3506 &test_setup.client,
3507 test_setup.stream_actor.clone(),
3508 2.into(),
3509 input_tensor_ref_1,
3510 &mut test_setup.controller_rx,
3511 )
3512 .await;
3513
3514 test_setup
3516 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3517 .await?;
3518 test_setup
3519 .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3520 .await?;
3521
3522 test_setup
3524 .stream_actor
3525 .set_tensor_ref_unit_tests_only(
3526 &test_setup.client,
3527 input_tensor_ref_1,
3528 Err(input_error.clone()),
3529 )
3530 .await?;
3531
3532 test_setup
3533 .stream_actor
3534 .call_recording(
3535 &test_setup.client,
3536 3.into(),
3537 recording_ref,
3538 vec![output_ref],
3539 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3540 )
3541 .await?;
3542
3543 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3546 test_setup
3547 .validate_dependent_error(ref_, input_error.clone())
3548 .await;
3549 }
3550
3551 check_fetch_result_error(
3553 &test_setup.client,
3554 test_setup.stream_actor.clone(),
3555 4.into(),
3556 input_tensor_ref_1,
3557 &mut test_setup.controller_rx,
3558 "input error",
3559 )
3560 .await;
3561
3562 test_setup
3564 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3565 .await?;
3566 test_setup
3567 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3568 .await?;
3569
3570 test_setup
3572 .stream_actor
3573 .set_tensor_ref_unit_tests_only(
3574 &test_setup.client,
3575 input_tensor_ref_2,
3576 Err(input_error.clone()),
3577 )
3578 .await?;
3579
3580 test_setup
3581 .stream_actor
3582 .call_recording(
3583 &test_setup.client,
3584 5.into(),
3585 recording_ref,
3586 vec![output_ref],
3587 vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3588 )
3589 .await?;
3590
3591 assert!(
3593 test_setup
3594 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3595 .await
3596 );
3597
3598 for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3601 test_setup
3602 .validate_dependent_error(ref_, input_error.clone())
3603 .await;
3604 }
3605
3606 check_fetch_result_value(
3608 &test_setup.client,
3609 test_setup.stream_actor.clone(),
3610 6.into(),
3611 input_tensor_ref_1,
3612 &mut test_setup.controller_rx,
3613 )
3614 .await;
3615
3616 Ok(())
3617 }
3618
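    // Records matched send/recv recordings across two streams and ranks (plus a rank-0
    // self-send), replays them successfully, then poisons the send-side inputs and checks
    // error propagation to the dependent results.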
3619 #[async_timed_test(timeout_secs = 60)]
3620 async fn test_send_tensor_in_recording() -> Result<()> {
3621 let mut test_setup = TestSetup::new_with_world_size(2).await?;
3622 let recording_ref = test_setup.next_ref();
3623
3624 let unique_id = UniqueId::new()?;
3625 let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3626 "comm0",
3627 CommParams::New {
3628 device: CudaDevice::new(0.into()),
3629 unique_id: unique_id.clone(),
3630 world_size: 2,
3631 rank: 0,
3632 },
3633 );
3634 let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3635 "comm1",
3636 CommParams::New {
3637 device: CudaDevice::new(1.into()),
3638 unique_id,
3639 world_size: 2,
3640 rank: 1,
3641 },
3642 );
3643 let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3644 let comm0 = Arc::new(comm0);
3645 let comm1 = Arc::new(comm1);
3646
3647 let factory = Factory {
3648 size: vec![3],
3649 dtype: torch_sys::ScalarType::Float,
3650 layout: torch_sys::Layout::Strided,
3651 device: "cuda".try_into().unwrap(),
3652 };
3653
3654 let send_stream = test_setup.stream_actor.clone();
3655 let recv_stream = test_setup
3656 .proc
3657 .spawn::<StreamActor>(
3658 "recv_stream",
3659 StreamParams {
3660 world_size: 2,
3661 rank: 1,
3662 creation_mode: StreamCreationMode::CreateNewStream,
3663 id: 1.into(),
3664 device: Some(CudaDevice::new(1.into())),
3665 controller_actor: test_setup.controller_actor.clone(),
3666 respond_with_python_message: false,
3667 },
3668 )
3669 .await?;
3670
3671 send_stream
3672 .define_recording(&test_setup.client, recording_ref)
3673 .await?;
3674 recv_stream
3675 .define_recording(&test_setup.client, recording_ref)
3676 .await?;
3677
3678 let formal_tensor_ref_0 = test_setup.next_ref();
3679 let formal_tensor_ref_1 = test_setup.next_ref();
3680
3681 send_stream
3682 .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3683 .await?;
3684 send_stream
3685 .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3686 .await?;
3687
3688 let _ref = test_setup.next_ref();
3689 send_stream
3690 .send_tensor(
3691 &test_setup.client,
3692 _ref,
3693 None,
3694 Some(1),
3695 formal_tensor_ref_0,
3696 factory.clone(),
3697 comm0.clone(),
3698 )
3699 .await?;
3700
3701 let result_ref_0 = test_setup.next_ref();
3702 let _ref = test_setup.next_ref();
3703 recv_stream
3704 .send_tensor(
3705 &test_setup.client,
3706 result_ref_0,
3707 Some(0),
3708 None,
3709 _ref,
3710 factory.clone(),
3711 comm1,
3712 )
3713 .await?;
3714
3715 let result_ref_1 = test_setup.next_ref();
3716 send_stream
3717 .send_tensor(
3718 &test_setup.client,
3719 result_ref_1,
3720 Some(0),
3721 Some(0),
3722 formal_tensor_ref_1,
3723 factory.clone(),
3724 comm0,
3725 )
3726 .await?;
3727
3728 send_stream
3729 .recording_result(&test_setup.client, result_ref_1, 0)
3730 .await?;
3731 recv_stream
3732 .recording_result(&test_setup.client, result_ref_0, 0)
3733 .await?;
3734
3735 send_stream
3736 .finalize_recording(&test_setup.client, recording_ref)
3737 .await?;
3738 recv_stream
3739 .finalize_recording(&test_setup.client, recording_ref)
3740 .await?;
3741
3742 let input_tensor_ref_0 = test_setup.next_ref();
3743 let input_tensor_ref_1 = test_setup.next_ref();
3744 test_setup
3745 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3746 .await?;
3747 test_setup
3748 .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3749 .await?;
3750
3751 let actual_result_ref_0 = test_setup.next_ref();
3752 let actual_result_ref_1 = test_setup.next_ref();
3753 let send_fut = send_stream.call_recording(
3754 &test_setup.client,
3755 0.into(),
3756 recording_ref,
3757 vec![actual_result_ref_1],
3758 vec![input_tensor_ref_0, input_tensor_ref_1],
3759 );
3760 let recv_fut = recv_stream.call_recording(
3761 &test_setup.client,
3762 0.into(),
3763 recording_ref,
3764 vec![actual_result_ref_0],
3765 vec![],
3766 );
3767 tokio::try_join!(send_fut, recv_fut)?;
3768
3769 assert!(
3770 test_setup
3771 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3772 .await
3773 );
3774 assert!(
3775 test_setup
3776 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3777 .await
3778 );
3779 assert!(
3780 test_setup
3781 .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3782 .await
3783 );
3784
3785 let actual_result_0 = recv_stream
3786 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3787 .await
3788 .unwrap()
3789 .unwrap()
3790 .unwrap();
3791 assert!(allclose(
3792 &actual_result_0.borrow(),
3793 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3794 )?);
3795
3796 check_fetch_result_value(
3798 &test_setup.client,
3799 send_stream.clone(),
3800 1.into(),
3801 actual_result_ref_1,
3802 &mut test_setup.controller_rx,
3803 )
3804 .await;
3805 check_fetch_result_value(
3806 &test_setup.client,
3807 recv_stream.clone(),
3808 2.into(),
3809 actual_result_ref_0,
3810 &mut test_setup.controller_rx,
3811 )
3812 .await;
3813
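        // Poison the first send-side input; the receive side should still produce a tensor while
        // the send-side result inherits the error.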
3814 let input_error = fake_seq_error(anyhow!("input error"));
3815 send_stream
3816 .set_tensor_ref_unit_tests_only(
3817 &test_setup.client,
3818 input_tensor_ref_0,
3819 Err(input_error.clone()),
3820 )
3821 .await?;
3822
3823 let send_fut = send_stream.call_recording(
3824 &test_setup.client,
3825 3.into(),
3826 recording_ref,
3827 vec![actual_result_ref_1],
3828 vec![input_tensor_ref_0, input_tensor_ref_1],
3829 );
3830 let recv_fut = recv_stream.call_recording(
3831 &test_setup.client,
3832 3.into(),
3833 recording_ref,
3834 vec![actual_result_ref_0],
3835 vec![],
3836 );
3837 tokio::try_join!(send_fut, recv_fut)?;
3838
3839 let _ = recv_stream
3841 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3842 .await
3843 .unwrap()
3844 .unwrap()
3845 .unwrap();
3846
3847 test_setup
3848 .validate_dependent_error(actual_result_ref_1, input_error.clone())
3849 .await;
3850
3851 assert!(
3853 test_setup
3854 .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3855 .await
3856 );
3857
3858 check_fetch_result_error(
3860 &test_setup.client,
3861 send_stream.clone(),
3862 4.into(),
3863 actual_result_ref_1,
3864 &mut test_setup.controller_rx,
3865 "input error",
3866 )
3867 .await;
3868 check_fetch_result_value(
3869 &test_setup.client,
3870 recv_stream.clone(),
3871 5.into(),
3872 actual_result_ref_0,
3873 &mut test_setup.controller_rx,
3874 )
3875 .await;
3876
3877 test_setup
3878 .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3879 .await?;
3880 send_stream
3881 .set_tensor_ref_unit_tests_only(
3882 &test_setup.client,
3883 input_tensor_ref_1,
3884 Err(input_error.clone()),
3885 )
3886 .await?;
3887
3888 let send_fut = send_stream.call_recording(
3889 &test_setup.client,
3890 6.into(),
3891 recording_ref,
3892 vec![actual_result_ref_1],
3893 vec![input_tensor_ref_0, input_tensor_ref_1],
3894 );
3895 let recv_fut = recv_stream.call_recording(
3896 &test_setup.client,
3897 6.into(),
3898 recording_ref,
3899 vec![actual_result_ref_0],
3900 vec![],
3901 );
3902 tokio::try_join!(send_fut, recv_fut)?;
3903
3904 let actual_result_0 = recv_stream
3905 .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3906 .await
3907 .unwrap()
3908 .unwrap()
3909 .unwrap();
3910 assert!(allclose(
3911 &actual_result_0.borrow(),
3912 &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3913 )?);
3914
3915 assert!(
3916 test_setup
3917 .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3918 .await
3919 );
3920
3921 test_setup
3922 .validate_dependent_error(actual_result_ref_1, input_error)
3923 .await;
3924
3925 check_fetch_result_error(
3927 &test_setup.client,
3928 send_stream.clone(),
3929 7.into(),
3930 actual_result_ref_1,
3931 &mut test_setup.controller_rx,
3932 "input error",
3933 )
3934 .await;
3935 check_fetch_result_value(
3936 &test_setup.client,
3937 recv_stream.clone(),
3938 8.into(),
3939 actual_result_ref_0,
3940 &mut test_setup.controller_rx,
3941 )
3942 .await;
3943
3944 Ok(())
3945 }
3946
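    // Records a `set_value` that pulls a tensor from a pipe, replays it successfully, then drops
    // the pipe and checks that the next replay reports a send error to the controller.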
3947 #[async_timed_test(timeout_secs = 60)]
3948 async fn test_set_value_in_recording_valid_pipe() -> Result<()> {
3949 let mut test_setup = TestSetup::new().await?;
3950
3951 let (pipe_tx, mut pipe_rx) = test_setup.client.open_port();
3952
3953 let recording_ref = test_setup.next_ref();
3954 test_setup
3955 .stream_actor
3956 .define_recording(&test_setup.client, recording_ref)
3957 .await?;
3958
3959 let result_ref_0 = test_setup.next_ref();
3960
3961 test_setup
3962 .stream_actor
3963 .set_value(
3964 &test_setup.client,
3965 0.into(),
3966 vec![Some(result_ref_0)],
3967 pipe_tx,
3968 )
3969 .await?;
3970
3971 test_setup
3972 .stream_actor
3973 .recording_result(&test_setup.client, result_ref_0, 0)
3974 .await?;
3975
3976 test_setup
3977 .stream_actor
3978 .finalize_recording(&test_setup.client, recording_ref)
3979 .await?;
3980
3981 let real_result_ref = test_setup.next_ref();
3982 let recording_fut = test_setup.stream_actor.call_recording(
3983 &test_setup.client,
3984 0.into(),
3985 recording_ref,
3986 vec![real_result_ref],
3987 vec![],
3988 );
3989
3990 let pipe_fut = async {
3991 let msg = pipe_rx.recv().await.unwrap();
3992 match msg {
3993 PipeMessage::RecvValue(tx) => {
3994 tx.send(PyTree::from(RValue::Tensor(TensorCell::new(
3995 factory_float_tensor(&[1.0, 2.0, 3.0], "cuda".try_into().unwrap()),
3996 ))))
3997 .unwrap();
3998 }
3999 _ => panic!("Unexpected message"),
4000 }
4001 Ok(())
4002 };
4003
4004 tokio::try_join!(recording_fut, pipe_fut)?;
4005
4006 assert!(test_setup.allclose(real_result_ref, &[1.0, 2.0, 3.0]).await);
4007
4008 drop(pipe_rx);
4010
4011 let real_result_ref = test_setup.next_ref();
4012 test_setup
4013 .stream_actor
4014 .call_recording(
4015 &test_setup.client,
4016 1.into(),
4017 recording_ref,
4018 vec![real_result_ref],
4019 vec![],
4020 )
4021 .await?;
4022
4023 let real_result_err = test_setup
4024 .stream_actor
4025 .get_tensor_ref_unit_tests_only(&test_setup.client, real_result_ref)
4026 .await?
4027 .unwrap()
4028 .unwrap_err();
4029 let error_str = real_result_err.to_string();
4031 assert!(
4032 error_str.contains("send error"),
4033 "Error should contain 'send error': {}",
4034 error_str
4035 );
4036
4037 let controller_msg = test_setup.controller_rx.recv().await.unwrap();
4038 match controller_msg {
4039 ControllerMessage::RemoteFunctionFailed { seq, error } => {
4040 assert_eq!(seq, 1.into());
4041 assert!(
4042 error.backtrace.contains("send error"),
4043 "Unexpected WorkerError: {:?}",
4044 error
4045 );
4046 }
4047 _ => panic!("Unexpected controller message: {:?}", controller_msg),
4048 };
4049
4050 Ok(())
4051 }
4052}