1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Context as _;
19use anyhow::Result;
20use anyhow::anyhow;
21use anyhow::bail;
22use anyhow::ensure;
23use async_trait::async_trait;
24use hyperactor as reference;
25use hyperactor::Actor;
26use hyperactor::Context;
27use hyperactor::Endpoint as _;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::handle;
34use hyperactor::id::Label;
35use hyperactor::mailbox::OncePortHandle;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::proc::Proc;
38use monarch_hyperactor::actor::PythonMessage;
39use monarch_hyperactor::actor::PythonMessageKind;
40use monarch_hyperactor::local_state_broker::BrokerId;
41use monarch_hyperactor::local_state_broker::LocalState;
42use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
43use monarch_hyperactor::pickle::pickle;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::ArgsKwargs;
50use monarch_messages::worker::CallFunctionError;
51use monarch_messages::worker::CallFunctionParams;
52use monarch_messages::worker::SeqError;
53use monarch_messages::worker::StreamRef;
54use monarch_types::PyTree;
55use monarch_types::SerializablePyErr;
56use monarch_types::TryIntoPyObjectUnsafe;
57use pyo3::prelude::*;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys_cuda::cuda::Event;
62use torch_sys_cuda::cuda::Stream;
63use torch_sys2::CloneUnsafe;
64use torch_sys2::CudaDevice;
65use torch_sys2::TensorCell;
66use torch_sys2::deep_clone;
67use torch_sys2::factory_empty;
68use torch_sys2::factory_zeros;
69use tracing_subscriber::fmt::Subscriber;
70use typeuri::Named;
71
72use crate::ControllerActor;
73use crate::DeviceMesh;
74use crate::Factory;
75use crate::Reduction;
76use crate::Ref;
77use crate::ResolvableFunction;
78use crate::StreamCreationMode;
79use crate::WireValue;
80use crate::comm::CommMessage;
81use crate::comm::CommMessageClient;
82use crate::comm::NcclCommActor;
83
/// Outcome of resolving a tensor on a stream: either the tensor cell itself or
/// a shared [`SeqError`] recorded in its place.
pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
85
thread_local! {
    /// Controller actor reference for the current thread. `StreamActor::init`
    /// attempts to set it once; a second `set` on the same thread is ignored.
    pub static CONTROLLER_ACTOR_REF: OnceCell<reference::ActorRef<ControllerActor>> = const { OnceCell::new() };
    /// The `Proc` the stream actor was initialised in, set once per thread by
    /// `StreamActor::init`.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    /// Singleton actor id derived in `StreamActor::init` from the actor's label
    /// (falling back to `"stream"`) and its proc id.
    pub static ROOT_ACTOR_ID: OnceCell<reference::ActorId> = const { OnceCell::new() };
}
92
93fn pickle_python_result(
94 py: Python<'_>,
95 result: Bound<'_, PyAny>,
96 worker_rank: usize,
97) -> Result<PythonMessage, anyhow::Error> {
98 let mut state = pickle(py, result.unbind(), false, false)
99 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
100 let inner = state
101 .take_inner()
102 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
103 Ok(PythonMessage::new_from_buf(
104 PythonMessageKind::Result {
105 rank: Some(worker_rank),
106 },
107 inner.take_buffer(),
108 ))
109}
110
111#[derive(Debug)]
112struct Recording {
113 messages: Vec<StreamMessage>,
114}
115
116impl Recording {
117 fn new() -> Self {
118 Self {
119 messages: Vec::new(),
120 }
121 }
122}
123
/// State of the recording a stream is currently working with, if any.
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// Messages are being captured into a recording instead of executed.
    Defining {
        /// Ref under which the recording being defined is stored.
        recording: Ref,
        /// Borrow ids created inside the recording and not yet dropped; used to
        /// reject duplicate creates and drops of unknown borrows.
        defined_borrows: HashSet<u64>,
    },
    /// A recording is being executed.
    /// NOTE(review): the code that enters this state is outside the visible chunk.
    Running,
}
134
/// Messages handled by [`StreamActor`].
///
/// While a recording is being defined, most variants are captured into the
/// recording instead of being executed (see the individual handlers).
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Call a function described by `CallFunctionParams`. The two maps supply
    /// the device meshes and remote process groups that arguments may refer to.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Start a borrow: the handler looks up `tensor`, records an event on the
    /// CUDA stream (if any) and posts both on `first_use_sender`.
    BorrowCreate {
        /// Id of the borrow.
        borrow: u64,
        /// Ref of the tensor being lent out.
        tensor: Ref,
        /// Port on which the event and tensor (or error) are posted.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Receive a borrowed tensor: the handler waits on the received event and
    /// binds the tensor (or error) to `result`.
    BorrowFirstUse {
        /// Id of the borrow.
        borrow: u64,
        /// Ref under which the borrowed tensor is stored.
        result: Ref,
        /// Receiving end matching a `BorrowCreate::first_use_sender`.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Finish using a borrowed tensor: the handler removes `result` from the
    /// environment, records an event and posts both on `last_use_sender`.
    BorrowLastUse {
        /// Id of the borrow.
        borrow: u64,
        /// Ref of the borrowed tensor being given back.
        result: Ref,
        /// Port on which the event and tensor (or error) are posted.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// End a borrow: the handler receives the last-use event and makes the
    /// CUDA stream (if any) wait on it.
    BorrowDrop {
        /// Id of the borrow.
        borrow: u64,
        /// Receiving end matching a `BorrowLastUse::last_use_sender`.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Remove the given refs from the stream's environment.
    DeleteRefs(Vec<Ref>),

    /// Status probe with a reply port. The handler performs no work beyond
    /// rejecting the request while a recording is being defined.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Install the communicator actor used by this stream.
    InitComm(ActorHandle<NcclCommActor>),

    /// Run a collective over `comm` on `local_tensor` and bind the output to
    /// `result`.
    Reduce {
        /// Communicator actor that performs the collective.
        comm: Arc<ActorHandle<NcclCommActor>>,
        /// Size of the leading dimension added when gathering (stack, no scatter).
        dim_size: i64,
        /// Ref under which the output tensor is stored.
        result: Ref,
        /// Ref of the input tensor.
        local_tensor: Ref,
        /// Describes a stand-in tensor used when an input ref holds an error,
        /// and the shape/dtype of freshly allocated outputs.
        factory: Factory,
        /// Which reduction to perform.
        reduction: Reduction,
        /// Whether the result is scattered across ranks.
        scatter: bool,
        /// Whether the input tensor is reused as the output.
        in_place: bool,
        /// Optional ref of a pre-existing output tensor.
        out: Option<Ref>,
    },

    /// Send a tensor between ranks, or copy it locally when `from_rank ==
    /// to_rank`.
    /// NOTE(review): only the first part of the handler is visible in this chunk.
    SendTensor {
        /// Ref under which the received (or copied) tensor is stored.
        result: Ref,
        /// Sending rank, if this worker participates as a receiver.
        from_rank: Option<usize>,
        /// Receiving rank, if this worker participates as a sender.
        to_rank: Option<usize>,
        /// Ref of the tensor to send.
        tensor: Ref,
        /// Describes a stand-in tensor used when `tensor` holds an error.
        factory: Factory,
        /// Communicator actor; required unless the copy is local.
        comm: Option<Arc<ActorHandle<NcclCommActor>>>,
    },

    /// Compute a value and send it back to the controller.
    /// NOTE(review): the handler is outside the visible chunk — confirm.
    SendValue {
        seq: Seq,
        worker_actor_id: reference::ActorAddr,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    /// Begin defining the recording stored under `recording`.
    /// NOTE(review): handler not visible here.
    DefineRecording {
        recording: Ref,
    },

    /// Finish defining the recording stored under `recording`.
    /// NOTE(review): handler not visible here.
    FinalizeRecording {
        recording: Ref,
    },

    /// Execute a previously defined recording.
    /// NOTE(review): handler not visible here; `actuals`/`results` presumably
    /// pair up with `RecordingFormal`/`RecordingResult` indices — confirm.
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Declares `result` as formal argument number `argument_index` of the
    /// recording being defined.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Declares `result` as output number `output_index` of the recording
    /// being defined.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Test-only: bind a wire value to a ref.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Test-only: bind a tensor (or error) to a ref.
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Test-only: read back the value bound to a ref.
    GetRefUnitTestsOnly(
        Ref,
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Test-only: read back the tensor (or error) bound to a ref.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// NOTE(review): handler not visible here.
    SendResultOfActorCall(ActorCallParams),
    /// NOTE(review): handler not visible here.
    CallActorMethod(ActorMethodParams),
}
255
256impl StreamMessage {
257 fn clone_for_recording(&self) -> Self {
258 match self {
259 StreamMessage::RecordingFormal {
260 result,
261 argument_index,
262 } => StreamMessage::RecordingFormal {
263 result: *result,
264 argument_index: *argument_index,
265 },
266 StreamMessage::RecordingResult {
267 result,
268 output_index,
269 } => StreamMessage::RecordingResult {
270 result: *result,
271 output_index: *output_index,
272 },
273 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
274 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
275 StreamMessage::CallFunction(
276 params.clone(),
277 device_meshes.clone(),
278 remote_process_groups.clone(),
279 )
280 }
281 StreamMessage::BorrowCreate {
282 borrow,
283 tensor,
284 first_use_sender,
285 } => StreamMessage::BorrowCreate {
286 borrow: *borrow,
287 tensor: *tensor,
288 first_use_sender: first_use_sender.clone(),
289 },
290 StreamMessage::BorrowFirstUse {
291 borrow,
292 result,
293 first_use_receiver,
294 } => StreamMessage::BorrowFirstUse {
295 borrow: *borrow,
296 result: *result,
297 first_use_receiver: first_use_receiver.clone(),
298 },
299 StreamMessage::BorrowLastUse {
300 borrow,
301 result,
302 last_use_sender,
303 } => StreamMessage::BorrowLastUse {
304 borrow: *borrow,
305 result: *result,
306 last_use_sender: last_use_sender.clone(),
307 },
308 StreamMessage::BorrowDrop {
309 borrow,
310 last_use_receiver,
311 } => StreamMessage::BorrowDrop {
312 borrow: *borrow,
313 last_use_receiver: last_use_receiver.clone(),
314 },
315 StreamMessage::Reduce {
316 comm,
317 dim_size,
318 result,
319 local_tensor,
320 factory,
321 reduction,
322 scatter,
323 in_place,
324 out,
325 } => StreamMessage::Reduce {
326 comm: comm.clone(),
327 dim_size: *dim_size,
328 result: *result,
329 local_tensor: *local_tensor,
330 factory: factory.clone(),
331 reduction: reduction.clone(),
332 scatter: *scatter,
333 in_place: *in_place,
334 out: out.clone(),
335 },
336 StreamMessage::SendTensor {
337 result,
338 from_rank,
339 to_rank,
340 tensor,
341 factory,
342 comm,
343 } => StreamMessage::SendTensor {
344 result: *result,
345 from_rank: *from_rank,
346 to_rank: *to_rank,
347 tensor: *tensor,
348 factory: factory.clone(),
349 comm: comm.clone(),
350 },
351 other => panic!(
352 "StreamMessage variant not supported in recording: {:?}",
353 other
354 ),
355 }
356 }
357
358 fn get_defined_refs(&self) -> HashSet<Ref> {
360 match self {
361 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
362 StreamMessage::CallFunction(params, ..) => {
363 params.results.iter().filter_map(|&ref_| ref_).collect()
364 }
365 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
366 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
367 StreamMessage::SendTensor {
368 result, from_rank, ..
369 } => {
370 if from_rank.is_some() {
371 HashSet::from([*result])
372 } else {
373 HashSet::new()
374 }
375 }
376 _ => HashSet::new(),
378 }
379 }
380
381 fn get_mutated_refs(&self) -> HashSet<Ref> {
383 match self {
384 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
385 StreamMessage::Reduce {
386 out,
387 in_place,
388 local_tensor,
389 ..
390 } => {
391 if *in_place {
392 HashSet::from([*local_tensor])
393 } else if let Some(out) = out {
394 HashSet::from([*out])
395 } else {
396 HashSet::new()
397 }
398 }
399 _ => HashSet::new(),
401 }
402 }
403}
404
/// Actor that executes [`StreamMessage`]s against an environment mapping refs
/// to Python objects, optionally bound to a CUDA stream.
#[derive(Debug)]
pub struct StreamActor {
    /// World size copied from `StreamParams`; not read in the visible code.
    _world_size: usize,
    /// Rank of this worker; attached to pickled Python results.
    rank: usize,
    /// Value bound to each ref: a Python object, or the error recorded for it.
    env: HashMap<Ref, Result<Py<PyAny>, Arc<SeqError>>>,
    /// Whether the CUDA stream is the device's current stream or a new one.
    creation_mode: StreamCreationMode,
    /// Lazily initialised CUDA stream; holds `None` when `device` is `None`.
    cuda_stream: OnceLock<Option<Stream>>,
    /// CUDA device this stream runs on, if any.
    device: Option<CudaDevice>,
    /// Communicator actor installed by `InitComm`.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Controller that is notified of remote function failures and results.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Python process-group objects by ref, consulted when resolving
    /// function arguments.
    remote_process_groups: HashMap<Ref, Py<PyAny>>,
    /// Recordings stored by ref.
    recordings: HashMap<Ref, Recording>,
    /// Recording currently being defined or run, if any.
    active_recording: Option<RecordingState>,
    /// Copied from `StreamParams`.
    /// NOTE(review): not read in the visible code — confirm its use.
    respond_with_python_message: bool,
    /// Most recent error created by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
441
/// Parameters used to construct a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// Number of workers; stored on the actor but not read in the visible code.
    pub world_size: usize,
    /// Rank of the worker that owns the stream.
    pub rank: usize,
    /// Whether to use the device's current CUDA stream or create a new one.
    pub creation_mode: StreamCreationMode,
    /// Reference identifying the stream; ignored by `StreamActor::new`.
    pub id: StreamRef,
    /// CUDA device for the stream; `None` means no CUDA stream is used.
    pub device: Option<CudaDevice>,
    /// Controller actor the stream reports to.
    pub controller_actor: reference::ActorRef<ControllerActor>,
    /// Stored on the actor as-is.
    /// NOTE(review): its effect is not visible in this chunk.
    pub respond_with_python_message: bool,
}
458
459impl StreamActor {
460 pub fn new(
461 StreamParams {
462 world_size,
463 rank,
464 id: _,
465 device,
466 controller_actor,
467 creation_mode,
468 respond_with_python_message,
469 }: StreamParams,
470 ) -> Self {
471 Self {
472 _world_size: world_size,
473 rank,
474 env: HashMap::new(),
475 creation_mode,
476 cuda_stream: OnceLock::new(),
477 device,
478 comm: None,
479 controller_actor,
480 remote_process_groups: HashMap::new(),
481 recordings: HashMap::new(),
482 active_recording: None,
483 respond_with_python_message,
484 last_seq_error: None,
485 }
486 }
487}
488
489#[async_trait]
490impl Actor for StreamActor {
491 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
492 CONTROLLER_ACTOR_REF.with(
496 |controller_actor_ref: &OnceCell<reference::ActorRef<ControllerActor>>| {
497 controller_actor_ref.set(self.controller_actor.clone()).ok()
498 },
499 );
500 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
501 ROOT_ACTOR_ID.with(|root_actor_id: &OnceCell<reference::ActorId>| {
502 let root_label = cx
503 .self_addr()
504 .label()
505 .cloned()
506 .unwrap_or_else(|| Label::new("stream").unwrap());
507 root_actor_id
508 .set(reference::ActorId::singleton(
509 root_label,
510 cx.self_addr().proc_addr().id().clone(),
511 ))
512 .ok()
513 });
514 if let Some(stream) = self.cuda_stream() {
516 Stream::set_current_stream(stream);
517 }
518 Ok(())
519 }
520
    /// Runs the actor's server `future` on a dedicated OS thread named
    /// `worker-stream`, driven by its own Tokio runtime, and returns a Tokio
    /// task that resolves to the future's output.
    ///
    /// # Panics
    /// The spawned thread panics if the runtime cannot be built or if the
    /// result cannot be sent back. The returned task panics if the sending
    /// side is dropped without a result (for example when the thread failed
    /// to spawn, since the spawn result is not checked).
    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        // One-shot channel that carries the future's output back to the
        // caller's runtime.
        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
        let builder = std::thread::Builder::new().name("worker-stream".to_string());
        let _thread_handle = builder.spawn(move || {
            // A multi-thread runtime with a single worker thread.
            let rt = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(async {
                tokio::task::block_in_place(|| {
                    // Attach to the Python interpreter on this thread, then
                    // immediately detach while the future runs.
                    // NOTE(review): presumably this initialises per-thread
                    // Python state without holding the GIL for the lifetime
                    // of the server loop — confirm.
                    Python::attach(|py| {
                        py.detach(|| {
                            let result = Handle::current().block_on(future);
                            if join_tx.send(result).is_err() {
                                panic!("could not send join result")
                            }
                        })
                    })
                })
            });
            // Give any remaining tasks a very long grace period to finish.
            rt.shutdown_timeout(Duration::from_weeks(1));
            result
        });

        tokio::spawn(async move { join_rx.await.unwrap() })
    }
573}
574
575#[derive(Debug)]
577enum PyArg {
578 Object(Py<PyAny>),
579}
580
581impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
583 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
584 match self {
585 PyArg::Object(obj) => Ok(obj.clone_ref(py).into_bound(py)),
586 }
587 }
588}
589
590impl StreamActor {
    /// Converts a `TensorCell` into a Python object, attaching to the Python
    /// interpreter for the duration of the conversion.
    ///
    /// # Panics
    /// Panics if converting the tensor into a Python object fails.
    fn tensor_to_pyobject(tensor_cell: TensorCell) -> Py<PyAny> {
        Python::attach(|py| {
            // SAFETY: NOTE(review) — the original justification is not visible
            // in this chunk. `get_unchecked` reads the cell without a checked
            // borrow and `clone_unsafe` produces another handle to the tensor;
            // presumably the caller guarantees no conflicting access while the
            // Python object is alive. Confirm before relying on this.
            let tensor = unsafe {
                tensor_cell.get_unchecked().clone_unsafe()
            };
            tensor.into_pyobject(py).unwrap().unbind()
        })
    }
604
605 fn pyobject_to_tensor(py: Python<'_>, pyobj: &Py<PyAny>) -> PyResult<TensorCell> {
609 use torch_sys2::Tensor;
610 let tensor = pyobj.bind(py).extract::<Tensor>()?;
611 Ok(TensorCell::new(tensor))
613 }
614
615 fn cuda_stream(&self) -> Option<&Stream> {
616 self.cuda_stream
617 .get_or_init(|| {
618 self.device.map(|device| match self.creation_mode {
619 StreamCreationMode::UseDefaultStream => {
620 Stream::get_current_stream_on_device(device)
621 }
622 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
623 })
624 })
625 .as_ref()
626 }
627
628 fn ref_to_pyobject(&self, ref_: &Ref) -> Result<Py<PyAny>, CallFunctionError> {
629 let pyobject = self
630 .env
631 .get(ref_)
632 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
633 match pyobject {
634 Ok(val) => Ok(val.clone()),
635 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
636 }
637 }
638
639 async fn report_seq_error(
640 &mut self,
641 cx: &Context<'_, Self>,
642 seq: Seq,
643 error: CallFunctionError,
644 ) -> Result<Arc<SeqError>, anyhow::Error> {
645 match error {
646 CallFunctionError::DependentError(root) => Ok(root),
647 CallFunctionError::Error(e) => {
648 if self.active_recording.is_none() {
649 let worker_error = WorkerError {
650 backtrace: format!("{e}"),
651 worker_actor_id: cx.self_addr().clone(),
652 };
653 tracing::info!("Propagating remote function error to client: {worker_error}");
654 self.controller_actor
655 .remote_function_failed(cx, seq, worker_error)
656 .await?
657 }
658 let err = Arc::new(SeqError { seq, error: e });
659 self.last_seq_error = Some(err.clone());
660 Ok(err)
661 }
662 }
663 }
664
665 async fn try_define<F>(
666 &mut self,
667 cx: &Context<'_, Self>,
668 seq: Seq,
669 result_refs: Vec<Option<Ref>>,
670 mutates: &Vec<Ref>,
671 f: F,
672 ) -> Result<()>
673 where
674 F: AsyncFnOnce(&mut Self) -> Result<Vec<Py<PyAny>>, CallFunctionError>,
675 {
676 let actual_results = f(self).await;
677 let op_results = actual_results.and_then(|actual_results| {
680 if result_refs.len() == actual_results.len() {
681 Ok(actual_results
682 .into_iter()
683 .zip(result_refs.iter())
684 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
685 .collect::<Vec<(Ref, Py<PyAny>)>>())
686 } else {
687 Err(CallFunctionError::UnexpectedNumberOfReturns(
688 result_refs.len(),
689 actual_results.len(),
690 ))
691 }
692 });
693
694 match op_results {
697 Ok(op_results) => {
698 for (ref_, pyobject) in op_results.into_iter() {
699 let prev = self.env.insert(ref_, Ok(pyobject));
700 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
701 }
702 }
703 Err(err) => {
704 let err = self.report_seq_error(cx, seq, err).await?;
705 for ref_ in result_refs {
706 match ref_ {
707 Some(ref_) => {
708 let prev = self.env.insert(ref_, Err(err.clone()));
709 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
710 }
711 None => {}
712 }
713 }
714 for ref_ in mutates {
715 self.env.insert(*ref_, Err(err.clone()));
716 }
717 }
718 }
719 Ok(())
720 }
721
    /// Resolves `function` and its arguments and calls it in Python.
    ///
    /// Arguments are unpacked from `args_kwargs`. Every leaf that is a `Ref`
    /// is replaced by, in order of precedence: the matching device mesh, the
    /// matching remote process group, or the object bound to the ref in the
    /// environment. When `function` is `None`, the first positional argument
    /// is returned instead of calling anything.
    ///
    /// `_cx` and `_mutates` are currently unused.
    ///
    /// # Errors
    /// Returns a `CallFunctionError` when arguments cannot be unpacked, the
    /// function cannot be resolved, a ref is missing or bound to an error, or
    /// the Python call raises.
    ///
    /// # Panics
    /// Panics if a requested remote process group is not already known to
    /// this actor, or if `function` is `None` and there are no positional
    /// arguments.
    fn call_python_fn<'py>(
        &mut self,
        py: Python<'py>,
        _cx: &Context<Self>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        _mutates: &[Ref],
        device_meshes: HashMap<Ref, DeviceMesh>,
        remote_process_groups: HashMap<
            Ref,
            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
        >,
    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
        let (args_tuple, kwargs_dict) = args_kwargs
            .to_python(py)
            .map_err(|e| CallFunctionError::Error(e.into()))?;
        // Resolve the callable, if one was given.
        let function = function
            .map(|function| {
                function.resolve(py).map_err(|e| {
                    CallFunctionError::InvalidRemoteFunction(format!(
                        "failed to resolve function {}: {}",
                        function,
                        SerializablePyErr::from(py, &e)
                    ))
                })
            })
            .transpose()?;

        // Map each requested process-group ref to the Python object this
        // actor already holds for it; creating one on demand is not supported.
        let remote_process_groups = remote_process_groups
            .into_iter()
            .map(|(gref, (_mesh, _dims, _comm))| {
                let group = match self.remote_process_groups.entry(gref) {
                    Entry::Occupied(ent) => ent.get().clone_ref(py),
                    Entry::Vacant(_ent) => {
                        panic!("no longer implemented");
                    }
                };
                PyResult::Ok((gref, group))
            })
            .collect::<Result<HashMap<_, _>, _>>()
            .map_err(SerializablePyErr::from_fn(py))?;

        // Replaces every `Ref` leaf of an argument tree with the object it
        // stands for; non-ref leaves are passed through unchanged.
        let resolve = |val: Bound<'py, PyAny>| {
            val.extract::<PyTree<Py<PyAny>>>()
                .map_err(SerializablePyErr::from_fn(py))?
                .try_into_map(|obj| {
                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
                        if let Some(mesh) = device_meshes.get(&ref_) {
                            PyArg::Object(
                                Py::new(py, mesh.clone())
                                    .map_err(SerializablePyErr::from_fn(py))?
                                    .into(),
                            )
                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
                            PyArg::Object(pg.clone_ref(py))
                        } else {
                            let pyobj = self.ref_to_pyobject(&ref_)?;
                            PyArg::Object(pyobj)
                        }
                    } else {
                        PyArg::Object(obj)
                    })
                })
        };

        let py_args: Vec<PyTree<PyArg>> = args_tuple
            .iter()
            .map(&resolve)
            .collect::<Result<_, CallFunctionError>>()?;

        let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
            .iter()
            .map(|(k, v)| {
                let key = k
                    .extract::<String>()
                    .map_err(SerializablePyErr::from_fn(py))?;
                let value = resolve(v)?;
                Ok((key, value))
            })
            .collect::<Result<_, CallFunctionError>>()?;

        // Tracing output emitted during the call goes to stdout through a
        // subscriber that is only active for the duration of the closure.
        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
        let result: Bound<'_, PyAny> =
            tracing::subscriber::with_default(scoped_subscriber, || {
                // SAFETY: NOTE(review) — the original justification is not
                // visible in this chunk; confirm the contract of
                // `try_to_object_unsafe` before changing these calls.
                let args = unsafe { py_args.try_to_object_unsafe(py) }
                    .map_err(SerializablePyErr::from_fn(py))?;
                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
                    .map_err(SerializablePyErr::from_fn(py))?;

                if let Some(function) = function {
                    function
                        .call(args, Some(kwargs))
                        .map_err(SerializablePyErr::from_fn(py))
                } else {
                    // No function: the value is the first positional argument.
                    Ok(args.get_item(0).unwrap())
                }
            })?;
        Ok(result)
    }
833
834 fn call_python_fn_pytree(
835 &mut self,
836 cx: &hyperactor::Context<Self>,
837 function: ResolvableFunction,
838 args_kwargs: ArgsKwargs,
839 mutates: &[Ref],
840 device_meshes: HashMap<Ref, DeviceMesh>,
841 remote_process_groups: HashMap<
842 Ref,
843 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
844 >,
845 ) -> Result<PyTree<Py<PyAny>>, CallFunctionError> {
846 Python::attach(|py| {
847 let result = self.call_python_fn(
848 py,
849 cx,
850 Some(function),
851 args_kwargs,
852 mutates,
853 device_meshes,
854 remote_process_groups,
855 )?;
856 Ok(PyTree::<Py<PyAny>>::extract_bound(&result)
857 .map_err(SerializablePyErr::from_fn(py))?)
858 })
859 }
860 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
868 let pyobject = self
869 .env
870 .get(&ref_)
871 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
872
873 match pyobject {
874 Ok(val) => Python::attach(|py| {
875 Self::pyobject_to_tensor(py, val)
876 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
877 }),
878 Err(_) => {
879 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
880 Ok(TensorCell::new(t))
881 }
882 }
883 }
884
885 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
886 self.active_recording
887 .as_mut()
888 .and_then(|state| match state {
889 RecordingState::Defining {
890 recording,
891 defined_borrows,
892 } => {
893 match self.recordings.get_mut(recording) {
894 Some(recording) => Some((recording, defined_borrows)),
895 None => panic!("recording not found: {:?}", recording),
897 }
898 }
899 RecordingState::Running => None,
900 })
901 }
902
903 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
904 for ref_ in refs {
905 let rvalue_or_err = self
906 .env
907 .get(ref_)
908 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
909 if let Err(err) = rvalue_or_err {
910 return Ok(Some(err.clone()));
911 }
912 }
913 Ok(None)
914 }
    /// Evaluates `function` (or, when it is `None`, takes the first
    /// argument), pickles the result into a `PythonMessage` tagged with this
    /// worker's rank, and sends it to the controller via `fetch_result`.
    ///
    /// The work runs inside `try_define` with no result refs and no mutated
    /// refs, so a failure is reported through `report_seq_error` but binds
    /// nothing in the environment.
    ///
    /// # Panics
    /// Panics if the pickled message cannot be serialized.
    async fn send_value_python_message(
        &mut self,
        cx: &hyperactor::Context<'_, Self>,
        seq: Seq,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        // Copied out up front because `self` is handed to the closure below.
        let rank = self.rank;
        self.try_define(cx, seq, vec![], &vec![], async |self_| {
            let python_message =
                Python::attach(|py| -> Result<PythonMessage, CallFunctionError> {
                    // The Python call may block, so tell Tokio this worker
                    // thread is about to block.
                    let python_result = tokio::task::block_in_place(|| {
                        self_.call_python_fn(
                            py,
                            cx,
                            function,
                            args_kwargs,
                            &mutates,
                            device_meshes,
                            HashMap::new(),
                        )
                    })?;
                    pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
                })?;
            let ser = wirevalue::Any::serialize(&python_message).unwrap();
            self_
                .controller_actor
                .fetch_result(cx, seq, Ok(ser))
                .await?;
            // Nothing is bound in the environment.
            Ok(vec![])
        })
        .await
    }
950 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
951 let rvalue = self
952 .env
953 .get(&src)
954 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
955 self.env.insert(dest, Python::attach(|_py| rvalue.clone()));
956 Ok(())
957 }
    /// Hands the local state named by `params.local_state` to the local state
    /// broker identified by `params.broker_id`, then waits for and returns the
    /// broker's reply.
    ///
    /// # Errors
    /// Fails if any local-state ref is missing or bound to an error, if the
    /// reply cannot be received, or if the reply itself is an error.
    async fn call_actor(
        &mut self,
        cx: &Context<'_, Self>,
        params: ActorCallParams,
    ) -> Result<Py<PyAny>, CallFunctionError> {
        // Resolve every local-state ref to its Python object while attached
        // to the interpreter.
        // NOTE(review): the error passes through `anyhow::Error` here, so a
        // `DependentError` from `ref_to_pyobject` may lose its variant when
        // converted back below — confirm against `CallFunctionError`'s
        // `From<anyhow::Error>` impl.
        let local_state: Result<Vec<Py<PyAny>>> = Python::attach(|_py| {
            params
                .local_state
                .into_iter()
                .map(|elem| {
                    let pyobj = self.ref_to_pyobject(&elem)?;
                    Ok(pyobj.into_any())
                })
                .collect()
        });

        // One-shot port on which the broker's consumer sends the result back.
        let (send, recv) = cx.open_once_port();
        let state = LocalState {
            response_port: send,
            state: local_state?,
        };
        // The state is keyed by the sequence number.
        // NOTE(review): `as usize` truncates on targets where `usize` is
        // narrower than 64 bits.
        let x: u64 = params.seq.into();
        let message = LocalStateBrokerMessage::Set(x as usize, state);

        let broker = BrokerId::new(params.broker_id).resolve(cx).await;
        broker.post(cx, message);
        let result = recv
            .recv()
            .await
            .map_err(|e| CallFunctionError::Error(e.into()))?;

        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
    }
991}
992
993#[async_trait]
994#[handle(StreamMessage)]
995impl StreamMessageHandler for StreamActor {
996 async fn call_function(
997 &mut self,
998 cx: &Context<Self>,
999 params: CallFunctionParams,
1000 device_meshes: HashMap<Ref, DeviceMesh>,
1001 remote_process_groups: HashMap<
1002 Ref,
1003 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1004 >,
1005 ) -> Result<()> {
1006 if let Some((recording, _)) = self.get_defining_recording() {
1007 recording.messages.push(StreamMessage::CallFunction(
1008 params,
1009 device_meshes,
1010 remote_process_groups,
1011 ));
1012 return Ok(());
1013 }
1014
1015 params.function.panic_if_requested();
1016 self.try_define(
1017 cx,
1018 params.seq,
1019 params.results,
1020 ¶ms.mutates,
1021 async |self| {
1022 tokio::task::block_in_place(|| {
1023 self.call_python_fn_pytree(
1024 cx,
1025 params.function,
1026 params.args_kwargs,
1027 ¶ms.mutates,
1028 device_meshes,
1029 remote_process_groups,
1030 )
1031 .map(|results| results.into_leaves())
1032 })
1033 },
1034 )
1035 .await?;
1036 Ok(())
1037 }
1038
1039 async fn borrow_create(
1040 &mut self,
1041 cx: &Context<Self>,
1042 borrow: u64,
1043 tensor: Ref,
1044 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1045 ) -> Result<()> {
1046 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1047 recording.messages.push(StreamMessage::BorrowCreate {
1048 borrow,
1049 tensor,
1050 first_use_sender,
1051 });
1052 ensure!(
1053 defined_borrows.insert(borrow),
1054 "duplicate borrow create in recording"
1055 );
1056 return Ok(());
1057 }
1058
1059 let pyobj_result = self
1060 .env
1061 .get(&tensor)
1062 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1063
1064 let result = match pyobj_result {
1065 Ok(pyobj) => Python::attach(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
1066 Err(e) => Err(e.clone()),
1067 };
1068
1069 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1070 first_use_sender.post(cx, (event, result));
1071 Ok(())
1072 }
1073
1074 async fn borrow_first_use(
1075 &mut self,
1076 _cx: &Context<Self>,
1077 borrow: u64,
1078 result: Ref,
1079 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1080 ) -> Result<()> {
1081 if let Some((recording, _)) = self.get_defining_recording() {
1082 recording.messages.push(StreamMessage::BorrowFirstUse {
1083 borrow,
1084 result,
1085 first_use_receiver: first_use_receiver.clone(),
1086 });
1087 return Ok(());
1088 }
1089
1090 let (first_use_event, cell) =
1091 first_use_receiver
1092 .lock()
1093 .await
1094 .recv()
1095 .await
1096 .map_err(|err| {
1097 anyhow!(
1098 "failed receiving first use event for borrow {:?}: {:?}",
1099 borrow,
1100 err
1101 )
1102 })?;
1103
1104 if let Some(stream) = self.cuda_stream() {
1105 stream.wait_event(
1106 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1107 );
1108 }
1109 match cell {
1110 Ok(cell) => {
1111 let pyobj = Self::tensor_to_pyobject(cell);
1112 self.env.insert(result, Ok(pyobj));
1113 }
1114 Err(err) => {
1115 self.env.insert(result, Err(err.clone()));
1116 }
1117 }
1118 Ok(())
1119 }
1120
1121 async fn borrow_last_use(
1122 &mut self,
1123 cx: &Context<Self>,
1124 borrow: u64,
1125 result: Ref,
1126 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1127 ) -> Result<()> {
1128 if let Some((recording, _)) = self.get_defining_recording() {
1129 recording.messages.push(StreamMessage::BorrowLastUse {
1130 borrow,
1131 result,
1132 last_use_sender,
1133 });
1134 return Ok(());
1135 }
1136
1137 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1138 let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
1139 "Invalid reference for borrow_last_use: {result:#?}"
1140 ))?;
1141 let tensor = match pyobj_or_err {
1142 Ok(pyobj) => Ok(Python::attach(|py| {
1143 Self::pyobject_to_tensor(py, &pyobj).unwrap()
1144 })),
1145 Err(e) => Err(e),
1146 };
1147
1148 last_use_sender.post(cx, (event, tensor));
1149 Ok(())
1150 }
1151
1152 async fn borrow_drop(
1153 &mut self,
1154 _cx: &Context<Self>,
1155 borrow: u64,
1156 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1157 ) -> Result<()> {
1158 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1159 recording.messages.push(StreamMessage::BorrowDrop {
1160 borrow,
1161 last_use_receiver: last_use_receiver.clone(),
1162 });
1163 ensure!(
1164 defined_borrows.remove(&borrow),
1165 "borrow drop for borrow not defined in recording"
1166 );
1167 return Ok(());
1168 }
1169
1170 let (last_use_event, _cell) =
1174 last_use_receiver.lock().await.recv().await.map_err(|err| {
1175 anyhow!(
1176 "failed receiving last use event for borrow {:?}: {:?}",
1177 borrow,
1178 err
1179 )
1180 })?;
1181
1182 if let Some(stream) = self.cuda_stream() {
1183 stream.wait_event(
1184 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1185 );
1186 }
1187 Ok(())
1189 }
1190
1191 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1192 if let Some((recording, _)) = self.get_defining_recording() {
1193 recording.messages.push(StreamMessage::DeleteRefs(refs));
1194 return Ok(());
1195 }
1196
1197 for ref_ in refs.iter() {
1198 self.env.remove(ref_);
1199 }
1200 Ok(())
1201 }
1202
1203 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1204 if self.get_defining_recording().is_some() {
1205 bail!("request_status not allowed in recording");
1206 }
1207
1208 Ok(())
1209 }
1210
1211 async fn init_comm(
1212 &mut self,
1213 _cx: &Context<Self>,
1214 comm: ActorHandle<NcclCommActor>,
1215 ) -> Result<()> {
1216 if self.get_defining_recording().is_some() {
1217 bail!("init_comm not allowed in recording");
1218 }
1219
1220 self.comm = Some(comm);
1221 Ok(())
1222 }
1223
1224 async fn reduce(
1225 &mut self,
1226 cx: &Context<Self>,
1227 comm: Arc<ActorHandle<NcclCommActor>>,
1228 dim_size: i64,
1229 result: Ref,
1230 local_tensor: Ref,
1231 factory: Factory,
1232 reduction: Reduction,
1233 scatter: bool,
1234 in_place: bool,
1235 out: Option<Ref>,
1236 ) -> Result<()> {
1237 if let Some((recording, _)) = self.get_defining_recording() {
1238 recording.messages.push(StreamMessage::Reduce {
1239 comm,
1240 dim_size,
1241 result,
1242 local_tensor,
1243 factory,
1244 reduction,
1245 scatter,
1246 in_place,
1247 out,
1248 });
1249 return Ok(());
1250 }
1251
1252 let stream = self
1253 .cuda_stream()
1254 .expect("reductions not yet supported for non-CUDA workers")
1255 .clone();
1256 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1257 let out_cell = out
1258 .map(|out| self.get_or_fake_on_err(out, &factory))
1259 .transpose()?;
1260 let output_cell = match reduction {
1261 Reduction::Stack => {
1262 if scatter {
1263 let output_cell = if in_place {
1264 input_cell.clone()
1265 } else {
1266 out_cell.unwrap_or({
1267 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1268 let cloned = deep_clone(&borrow);
1269 TensorCell::new(cloned)
1270 })
1271 };
1272 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1273 .await?;
1274 output_cell
1275 } else {
1276 ensure!(
1277 !in_place,
1278 "in-place, non-scatter not supported for stack reduce"
1279 );
1280
1281 let output_cell = out_cell.unwrap_or({
1282 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1284 let output =
1285 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1286 TensorCell::new(output)
1287 });
1288
1289 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1290 .await?;
1291 output_cell
1292 }
1293 }
1294 Reduction::ReduceOp(op) => {
1295 if scatter {
1296 ensure!(!in_place, "in-place, scatter not supported for reduce");
1297
1298 let output_cell = out_cell.unwrap_or({
1299 let output = factory_empty(
1300 &factory.size[1..],
1301 factory.dtype,
1302 factory.layout,
1303 factory.device,
1304 );
1305 TensorCell::new(output)
1306 });
1307 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1308 .await?;
1309 output_cell
1310 } else {
1311 let output_cell = if in_place {
1312 input_cell.clone()
1313 } else {
1314 out_cell.map_or(
1315 {
1316 let borrow =
1317 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1318 let cloned = deep_clone(&borrow);
1319 Ok(TensorCell::new(cloned))
1320 },
1321 |out_cell| -> Result<_, anyhow::Error> {
1322 let mut out_borrow =
1323 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1324 let in_borrow =
1325 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1326 out_borrow.copy_(&in_borrow);
1327 drop(out_borrow);
1328 Ok(out_cell)
1329 },
1330 )?
1331 };
1332
1333 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1334 output_cell
1335 }
1336 }
1337 };
1338
1339 let pyobj = Self::tensor_to_pyobject(output_cell);
1340 self.env.insert(result, Ok(pyobj));
1341 Ok(())
1342 }
1343
1344 async fn send_tensor(
1345 &mut self,
1346 cx: &Context<Self>,
1347 result: Ref,
1348 from_rank: Option<usize>,
1349 to_rank: Option<usize>,
1350 tensor: Ref,
1351 factory: Factory,
1352 comm: Option<Arc<ActorHandle<NcclCommActor>>>,
1353 ) -> Result<()> {
1354 if let Some((recording, _)) = self.get_defining_recording() {
1355 recording.messages.push(StreamMessage::SendTensor {
1356 result,
1357 from_rank,
1358 to_rank,
1359 tensor,
1360 factory,
1361 comm,
1362 });
1363 return Ok(());
1364 }
1365
1366 if to_rank.is_none() && from_rank.is_none() {
1367 bail!("tried to send tensor without a to/from rank");
1368 }
1369
1370 if from_rank == to_rank {
1372 let input_cell: &std::result::Result<Py<PyAny>, Arc<SeqError>> = self
1373 .env
1374 .get(&tensor)
1375 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1376 let output_cell: Result<Py<PyAny>, Arc<SeqError>> = match input_cell {
1377 Ok(pyobj) => {
1378 Python::attach(|py| -> Result<Py<PyAny>, Arc<SeqError>> {
1379 let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
1380 let borrow = input_tensor.try_borrow().unwrap();
1385 let cloned = deep_clone(&borrow);
1386 let cloned_cell = TensorCell::new(cloned);
1387 Ok(Self::tensor_to_pyobject(cloned_cell))
1388 })
1389 }
1390 Err(err) => Err(err.clone()),
1391 };
1392 self.env.insert(result, output_cell);
1393 return Ok(());
1394 }
1395
1396 let comm = comm.context("send_tensor requires backend comm")?;
1397
1398 let mut messages = Vec::new();
1399
1400 if let Some(to_rank) = to_rank {
1401 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1402 messages.push(CommMessage::Send(
1403 input_cell,
1404 to_rank.try_into().unwrap(),
1405 self.cuda_stream()
1406 .expect("tried to send_tensor on non-cuda stream")
1407 .clone(),
1408 cx.open_once_port().0,
1409 ));
1410 }
1411
1412 if let Some(from_rank) = from_rank {
1413 let output_cell = TensorCell::new(factory_empty(
1414 &factory.size,
1415 factory.dtype,
1416 factory.layout,
1417 factory.device,
1418 ));
1419 messages.push(CommMessage::Recv(
1420 output_cell.clone(),
1421 from_rank.try_into().unwrap(),
1422 self.cuda_stream()
1423 .expect("tried to send_tensor on non-cuda stream")
1424 .clone(),
1425 cx.open_once_port().0,
1426 ));
1427 let pyobj = Self::tensor_to_pyobject(output_cell);
1428 self.env.insert(result, Ok(pyobj));
1429 }
1430
1431 comm.group(
1432 cx,
1433 messages,
1434 self.cuda_stream()
1435 .expect("tried to send_tensor on non-cuda stream")
1436 .clone(),
1437 )
1438 .await?;
1439 Ok(())
1440 }
1441
/// Evaluate a value on this stream and report the outcome to the controller.
///
/// When `respond_with_python_message` is set, the whole request is delegated to
/// `send_value_python_message`. Otherwise:
/// * with `function` present, it is invoked through `call_python_fn_pytree`;
/// * with no function, `args_kwargs` must contain exactly one positional argument and no
///   keyword arguments; that argument is extracted as a `PyTree` and passed through
///   `try_into_map`, where objects that parse as a `Ref` are resolved with
///   `ref_to_pyobject` and all others are kept as-is.
///
/// On failure the error is passed to `report_seq_error`, stored as the value of every
/// ref in `mutates`, and sent to the controller via `fetch_result` as a `WorkerError`.
///
/// NOTE(review): on this (non-python-message) path a *successful* evaluation reaches the
/// `unreachable!` below, because the python-message case has already returned at the top
/// of the function. The panic message's premise looks stale — confirm whether successful
/// values are intentionally unsupported here.
async fn send_value(
    &mut self,
    cx: &Context<Self>,
    seq: Seq,
    worker_actor_id: reference::ActorAddr,
    mutates: Vec<Ref>,
    function: Option<ResolvableFunction>,
    args_kwargs: ArgsKwargs,
    device_meshes: HashMap<Ref, DeviceMesh>,
) -> Result<()> {
    if self.respond_with_python_message {
        return self
            .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
            .await;
    }

    let result = if let Some(function) = function {
        // The Python call is synchronous; `block_in_place` lets the runtime know this
        // worker thread is about to block.
        tokio::task::block_in_place(|| {
            self.call_python_fn_pytree(
                cx,
                function,
                args_kwargs,
                &mutates,
                device_meshes,
                HashMap::new(),
            )
        })
    } else {
        Python::attach(|py| {
            let (args, kwargs) = args_kwargs
                .to_python(py)
                .map_err(|e| CallFunctionError::Error(e.into()))?;
            // Without a function, only a single positional argument is accepted.
            match (args.len(), kwargs.len()) {
                (1, 0) => {
                    let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
                    arg.extract::<PyTree<Py<PyAny>>>()
                        .map_err(SerializablePyErr::from_fn(py))?
                        .try_into_map(|obj| {
                            let bound_obj = obj.bind(py);
                            // Refs are swapped for the object they name in `env`.
                            if let Ok(ref_) = Ref::from_py_object(bound_obj) {
                                self.ref_to_pyobject(&ref_)
                            } else {
                                Ok(obj)
                            }
                        })
                }
                _ => Err(CallFunctionError::TooManyArgsForValue(
                    format!("args with {} elements", args.len()),
                    format!("kwargs with {} elements", kwargs.len()),
                )),
            }
        })
    };

    let value = match result {
        Ok(pyobject) => Ok(pyobject),
        Err(err) => {
            let err = self.report_seq_error(cx, seq, err).await?;
            // Every ref this call was declared to mutate now carries the error.
            for ref_ in mutates {
                self.env.insert(ref_, Err(err.clone()));
            }
            Err(WorkerError {
                backtrace: format!("{:?}", err),
                worker_actor_id,
            })
        }
    };

    let result = match value {
        // See the NOTE(review) in the function docs: this arm panics on success.
        Ok(_value) => {
            unreachable!(
                "send_value should return early when respond_with_python_message is true"
            )
        }
        Err(e) => Err(e),
    };
    self.controller_actor.fetch_result(cx, seq, result).await?;

    Ok(())
}
1529
1530 async fn send_result_of_actor_call(
1531 &mut self,
1532 cx: &Context<Self>,
1533 params: ActorCallParams,
1534 ) -> anyhow::Result<()> {
1535 let seq = params.seq;
1536 let mutates = params.mutates.clone();
1537 self.try_define(cx, seq, vec![], &mutates, async |self| {
1538 let value = self.call_actor(cx, params).await?;
1539 let result =
1540 Python::attach(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1541 let result = wirevalue::Any::serialize(&result).unwrap();
1542 self.controller_actor
1543 .fetch_result(cx, seq, Ok(result))
1544 .await?;
1545 Ok(vec![])
1546 })
1547 .await
1548 }
1549
1550 async fn call_actor_method(
1551 &mut self,
1552 cx: &Context<Self>,
1553 params: ActorMethodParams,
1554 ) -> anyhow::Result<()> {
1555 let seq = params.call.seq;
1556 let mutates = params.call.mutates.clone();
1557 self.try_define(cx, seq, params.results, &mutates, async |self| {
1558 let result = self.call_actor(cx, params.call).await?;
1559 let result = Python::attach(|py| {
1560 PyTree::<Py<PyAny>>::extract_bound(&result.into_bound(py))
1561 .map_err(SerializablePyErr::from_fn(py))
1562 })?;
1563 Ok(result.into_leaves())
1564 })
1565 .await
1566 }
1567
1568 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1569 if self.active_recording.is_some() {
1570 bail!("different recording already active");
1571 }
1572 match self.recordings.entry(recording) {
1573 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1574 Entry::Vacant(entry) => entry.insert(Recording::new()),
1575 };
1576 self.active_recording = Some(RecordingState::Defining {
1577 recording,
1578 defined_borrows: HashSet::new(),
1579 });
1580 Ok(())
1581 }
1582
1583 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1584 match self.active_recording {
1585 Some(RecordingState::Defining {
1586 recording: active_recording,
1587 ref defined_borrows,
1588 }) if active_recording == recording => {
1589 ensure!(
1590 defined_borrows.is_empty(),
1591 "all borrows created within recording must be dropped within recording"
1592 );
1593 self.active_recording = None;
1594 }
1595 _ => bail!("cannot finalize recording that isn't active"),
1596 }
1597 Ok(())
1598 }
1599
1600 async fn recording_formal(
1601 &mut self,
1602 _cx: &Context<Self>,
1603 result: Ref,
1604 argument_index: usize,
1605 ) -> Result<()> {
1606 match self.get_defining_recording() {
1607 Some((recording, _)) => {
1608 recording.messages.push(StreamMessage::RecordingFormal {
1609 result,
1610 argument_index,
1611 });
1612 }
1613 None => bail!("recording_formal called outside of recording"),
1614 };
1615 Ok(())
1616 }
1617
1618 async fn recording_result(
1619 &mut self,
1620 _cx: &Context<Self>,
1621 result: Ref,
1622 output_index: usize,
1623 ) -> Result<()> {
1624 match self.get_defining_recording() {
1625 Some((recording, _)) => {
1626 recording.messages.push(StreamMessage::RecordingResult {
1627 result,
1628 output_index,
1629 });
1630 }
1631 None => bail!("recording_result called outside of recording"),
1632 };
1633 Ok(())
1634 }
1635
/// Replay a previously defined recording.
///
/// `actuals` supplies the refs bound to the recording's formals (by `argument_index`)
/// and `results` names the refs that receive the recording's outputs (by
/// `output_index`). Each recorded message is replayed in order; the first error seen is
/// remembered, reported once to the controller via `remote_function_failed`, and at the
/// end written into every ref in `results` and every mutated ref that outlives the
/// recording. Refs defined inside the recording are deleted before returning.
///
/// # Errors
/// Fails if a recording is active, if `recording` is unknown, if a formal/result index
/// is out of range for `actuals`/`results`, or if replaying a message fails.
async fn call_recording(
    &mut self,
    cx: &Context<Self>,
    seq: Seq,
    recording: Ref,
    results: Vec<Ref>,
    actuals: Vec<Ref>,
) -> Result<()> {
    if self.active_recording.is_some() {
        bail!("cannot call recording while another recording is active");
    }

    // Work on a copy of the recorded messages so `self` is free to be mutated below.
    let messages = match self.recordings.get(&recording) {
        Some(recording) => recording
            .messages
            .iter()
            .map(|message| message.clone_for_recording())
            .collect::<Vec<_>>(),
        None => bail!("recording {:?} not found", recording),
    };

    self.active_recording = Some(RecordingState::Running);

    // First error encountered during this replay, if any.
    let mut error: Option<Arc<SeqError>> = None;
    // Refs defined by replayed messages and not yet deleted; cleaned up after the loop.
    let mut all_defined_refs = HashSet::new();
    // Mutated refs that are not recording-local, with formals mapped to their actuals.
    let mut all_mutated_refs = HashSet::new();
    // Formal ref -> actual ref supplied by the caller.
    let mut formal_to_actual_refs = HashMap::new();
    // Reset so that any error observed after a message is attributable to this replay.
    self.last_seq_error = None;
    for message in messages.into_iter() {
        let defined_refs = message.get_defined_refs();
        all_defined_refs.extend(defined_refs.clone());

        let mutated_refs_with_formals = message.get_mutated_refs();
        all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
            match formal_to_actual_refs.get(ref_) {
                // A mutated formal is recorded as a mutation of the caller's actual.
                Some(actual_ref) => Some(*actual_ref),
                None => {
                    // Refs defined inside the recording are skipped here.
                    if all_defined_refs.contains(ref_) {
                        None
                    } else {
                        Some(*ref_)
                    }
                }
            }
        }));

        match message {
            StreamMessage::RecordingFormal {
                result: formal_ref,
                argument_index,
            } => match actuals.get(argument_index) {
                None => bail!("recording_formal called with too few arguments"),
                Some(actual_ref) => {
                    formal_to_actual_refs.insert(formal_ref, *actual_ref);
                    self.define_ref(formal_ref, *actual_ref)?;
                }
            },
            StreamMessage::RecordingResult {
                result: result_ref,
                output_index,
            } => match results.get(output_index) {
                None => bail!("recording_result called with too few results"),
                Some(actual_result_ref) => {
                    self.define_ref(*actual_result_ref, result_ref)?;
                }
            },
            StreamMessage::DeleteRefs(ref refs) => {
                // Already-deleted refs must not be deleted again in the final cleanup.
                for ref_ in refs {
                    all_defined_refs.remove(ref_);
                }
                StreamMessageHandler::handle(self, cx, message).await?;
            }
            // After a failure, function calls are not executed; their outputs and
            // mutated refs are set to the recorded error instead.
            StreamMessage::CallFunction { .. } if error.is_some() => {
                let error = error.clone().unwrap();
                for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
                    self.env.insert(*ref_, Err(error.clone()));
                }
            }
            StreamMessage::BorrowLastUse { ref result, .. } => {
                all_defined_refs.remove(result);
                StreamMessageHandler::handle(self, cx, message).await?;
            }
            StreamMessage::Reduce {
                local_tensor,
                ref out,
                ..
            } => {
                // Pick up an error carried by the inputs; the message is still handled.
                if error.is_none() {
                    let inputs_to_check = [Some(local_tensor), out.clone()]
                        .iter()
                        .filter_map(|r| *r)
                        .collect::<Vec<_>>();
                    error = self.get_first_error(inputs_to_check.as_slice())?;
                }
                StreamMessageHandler::handle(self, cx, message).await?;
            }
            StreamMessage::SendTensor {
                ref tensor,
                ref to_rank,
                ..
            } => {
                // Only the sending side inspects its input tensor for an error.
                if to_rank.is_some() && error.is_none() {
                    error = self.get_first_error(&[*tensor])?;
                }
                StreamMessageHandler::handle(self, cx, message).await?;
            }
            _ => {
                StreamMessageHandler::handle(self, cx, message).await?;
            }
        };

        // `take()` always clears `last_seq_error`; the controller is notified only for
        // the first error of the replay.
        match (&error, self.last_seq_error.take()) {
            (None, Some(seq_err)) => {
                self.controller_actor
                    .remote_function_failed(
                        cx,
                        seq,
                        WorkerError {
                            backtrace: format!("recording failed: {}", &seq_err),
                            worker_actor_id: cx.self_addr().clone(),
                        },
                    )
                    .await?;
                error = Some(seq_err)
            }
            _ => {}
        }
    }

    // Drop every ref that was defined by the recording and not deleted during replay.
    StreamMessageHandler::handle(
        self,
        cx,
        StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
    )
    .await?;

    // On failure, results and externally visible mutated refs all carry the error.
    if error.is_some() {
        for ref_ in results.iter().chain(all_mutated_refs.iter()) {
            self.env.insert(*ref_, Err(error.clone().unwrap()));
        }
    }

    self.active_recording = None;
    Ok(())
}
1824
1825 async fn set_ref_unit_tests_only(
1826 &mut self,
1827 _cx: &Context<Self>,
1828 reference: Ref,
1829 value: WireValue,
1830 ) -> Result<()> {
1831 let pyobj =
1832 Python::attach(|py| -> PyResult<Py<PyAny>> { Ok(value.into_pyobject(py)?.unbind()) })?;
1833 self.env.insert(reference, Ok(pyobj));
1834 Ok(())
1835 }
1836
1837 async fn set_tensor_ref_unit_tests_only(
1838 &mut self,
1839 _cx: &Context<Self>,
1840 reference: Ref,
1841 tensor_result: TensorCellResult,
1842 ) -> Result<()> {
1843 match tensor_result {
1844 Ok(tensor_cell) => {
1845 let pyobj = Self::tensor_to_pyobject(tensor_cell);
1846 self.env.insert(reference, Ok(pyobj));
1847 }
1848 Err(err) => {
1849 self.env.insert(reference, Err(err));
1850 }
1851 }
1852 Ok(())
1853 }
1854
1855 async fn get_ref_unit_tests_only(
1856 &mut self,
1857 _cx: &Context<Self>,
1858 reference: Ref,
1859 ) -> Result<Option<Result<WireValue, String>>> {
1860 use pyo3::types::PyBool;
1861 use pyo3::types::PyFloat;
1862 use pyo3::types::PyInt;
1863 use pyo3::types::PyList;
1864 use pyo3::types::PyNone;
1865 use pyo3::types::PyString;
1866 fn pyobject_to_wire(
1868 value: Result<Py<PyAny>, Arc<SeqError>>,
1869 ) -> Result<WireValue, Arc<SeqError>> {
1870 let pyobj = value?;
1871 Python::attach(|py| {
1872 let bound = pyobj.bind(py);
1873 if bound.is_instance_of::<PyBool>() {
1875 Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1876 } else if bound.is_instance_of::<PyInt>() {
1877 Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1878 } else if bound.is_instance_of::<PyList>() {
1879 if let Ok(val) = bound.extract::<Vec<i64>>() {
1880 Ok(WireValue::IntList(val))
1881 } else {
1882 Ok(WireValue::String(format!(
1883 "unsupported list type: {:?}",
1884 bound
1885 )))
1886 }
1887 } else if bound.is_instance_of::<PyFloat>() {
1888 Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1889 } else if bound.is_instance_of::<PyString>() {
1890 Ok(WireValue::String(bound.extract::<String>().unwrap()))
1891 } else if bound.is_instance_of::<PyNone>() {
1892 Ok(WireValue::None(()))
1893 } else {
1894 Ok(WireValue::String(format!(
1895 "unsupported pyobject type: {:?}",
1896 bound
1897 )))
1898 }
1899 })
1900 }
1901 Ok(self.env.get(&reference).map(|pyobj| {
1902 pyobject_to_wire(Python::attach(|_py| pyobj.clone())).map_err(|err| err.to_string())
1903 }))
1904 }
1905
1906 async fn get_tensor_ref_unit_tests_only(
1907 &mut self,
1908 _cx: &Context<Self>,
1909 reference: Ref,
1910 ) -> Result<Option<TensorCellResult>> {
1911 match self.env.get(&reference) {
1912 Some(Ok(pyobj)) => Python::attach(|py| match Self::pyobject_to_tensor(py, pyobj) {
1913 Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1914 Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1915 }),
1916 Some(Err(err)) => Ok(Some(Err(err.clone()))),
1917 None => Ok(None),
1918 }
1919 }
1920}
1921
/// Unit tests for the stream actor's recording and borrow bookkeeping.
///
/// Most tests drive the actor into an invalid sequence of calls and then check that it
/// ends up in `ActorStatus::Failed` with the expected message.
#[cfg(all(test, fbcode_build))]
mod tests {
    use hyperactor::actor::ActorStatus;
    use hyperactor::context;
    use hyperactor::supervision::ActorSupervisionEvent;
    use monarch_messages::controller::ControllerMessage;
    use monarch_messages::worker::StreamCreationMode;
    use monarch_types::PickledPyObject;
    use monarch_types::UniqueId;
    use pyo3::IntoPyObjectExt;
    use timed_test::async_timed_test;
    use tokio::sync::watch;
    use torch_sys_cuda::nccl::UniqueIdExt;
    use torch_sys2::factory_float_tensor;
    use torch_sys2::testing::allclose;

    use super::*;
    use crate::comm::CommParams;
    use crate::test_util;

    /// Wrap `err` in a `SeqError` with sequence number 0.
    // Not referenced by the tests in this module.
    #[allow(dead_code)]
    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
        Arc::new(SeqError {
            seq: 0.into(),
            error: err,
        })
    }

    /// Everything a test needs to talk to a freshly spawned `StreamActor`.
    struct TestSetup {
        /// Proc hosting the actors under test.
        proc: Proc,
        /// Handle to the stream actor being exercised.
        stream_actor: ActorHandle<StreamActor>,
        /// Client instance used as the calling context for every request.
        client: Instance<()>,
        /// Receiver for supervision events; not read by the tests in this module.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        /// Receiver for messages sent to the attached controller; not read here.
        #[allow(dead_code)]
        controller_rx: PortReceiver<ControllerMessage>,
        /// Reference to the attached controller actor; not read here.
        #[allow(dead_code)]
        controller_actor: reference::ActorRef<ControllerActor>,
        /// Next unused ref, handed out by `next_ref`.
        next_ref: Ref,
    }

    impl TestSetup {
        /// Build a setup with a world size of 1.
        async fn new() -> Result<Self> {
            Self::new_with_world_size(1).await
        }

        /// Build a setup: isolated proc, attached controller, client, supervision port,
        /// and a rank-0 stream actor on CUDA device 0 using the default stream.
        async fn new_with_world_size(world_size: usize) -> Result<Self> {
            test_util::test_setup()?;

            let proc = Proc::isolated();
            let (_, controller_actor, controller_rx) =
                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
            let (client, _handle) = proc.client("client")?;
            let (supervision_tx, supervision_rx) = client.open_port();
            proc.set_supervision_coordinator(supervision_tx)?;
            let stream_actor = proc.spawn(
                "stream",
                StreamActor::new(StreamParams {
                    world_size,
                    rank: 0,
                    creation_mode: StreamCreationMode::UseDefaultStream,
                    id: 0.into(),
                    device: Some(CudaDevice::new(0.into())),
                    controller_actor: controller_actor.clone(),
                    respond_with_python_message: false,
                }),
            )?;

            Ok(Self {
                proc,
                stream_actor,
                client,
                supervision_rx,
                controller_rx,
                controller_actor,
                next_ref: 0.into(),
            })
        }

        /// Return the current ref and advance the counter by one.
        fn next_ref(&mut self) -> Ref {
            let ref_ = self.next_ref;
            self.next_ref = Ref {
                id: self.next_ref.id + 1,
            };
            ref_
        }

        /// Store a CUDA float tensor built from `data` under `reference`.
        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
            self.stream_actor
                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
                .await
        }

        /// Compare the tensor stored under `reference` against `data` with `allclose`.
        ///
        /// Panics if the ref is missing or holds an error.
        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
            let actual = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // The getter hands back a CPU tensor, so the expectation is built on CPU too.
            allclose(
                &factory_float_tensor(data, "cpu".parse().unwrap()),
                &actual.borrow(),
            )
            .unwrap()
        }

        /// Assert that `reference` holds exactly the given error (same `Arc`).
        // Not referenced by the tests in this module.
        #[allow(dead_code)]
        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
            let result_error = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap_err();

            assert!(Arc::ptr_eq(&result_error, &error));
        }
    }

    /// Wait until the actor reports `ActorStatus::Failed` and check that the failure
    /// message contains `expected_msg`.
    async fn assert_actor_failed_with_msg(
        status_rx: &mut watch::Receiver<ActorStatus>,
        expected_msg: String,
    ) {
        status_rx
            .wait_for(|s| matches!(s, ActorStatus::Failed(_)))
            .await
            .unwrap();
        let status = status_rx.borrow().clone();
        if let ActorStatus::Failed(msg) = status {
            let actual = msg.to_string();
            // Print both sides so a mismatch can be diagnosed from the test output.
            assert!(
                actual.contains(&expected_msg),
                "actor failure message {:?} did not contain {:?}",
                actual,
                expected_msg
            );
        } else {
            panic!("expected ActorStatus::Failed, got {:?}", status);
        }
    }

    /// Assert that none of `refs` is present in the stream actor's environment.
    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
        for ref_ in refs {
            assert!(
                test_setup
                    .stream_actor
                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
                    .await
                    .unwrap()
                    .is_none()
            );
        }
    }

    /// Ask the stream actor to send the value of `reference` via `send_value`, passing
    /// the pickled ref as the single positional argument and no function.
    // Not referenced by the tests in this module.
    #[allow(dead_code)]
    async fn fetch_result(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
    ) {
        let ref_to_send = Python::attach(|py| {
            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
        });

        stream_actor
            .send_value(
                cx,
                seq,
                stream_actor.actor_addr().clone(),
                Vec::new(),
                None,
                ArgsKwargs::from_wire_values(
                    vec![WireValue::PyObject(ref_to_send)],
                    HashMap::new(),
                )
                .unwrap(),
                HashMap::new(),
            )
            .await
            .unwrap()
    }

    /// Fetch `reference` and expect the controller to receive a failed `FetchResult`
    /// for `seq` whose backtrace contains `expected_backtrace`.
    // Not referenced by the tests in this module.
    #[allow(dead_code)]
    async fn check_fetch_result_error(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
        controller_rx: &mut PortReceiver<ControllerMessage>,
        expected_backtrace: &str,
    ) {
        fetch_result(cx, stream_actor, seq, reference).await;

        let controller_msg = controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::FetchResult {
                seq: actual_seq,
                value: Err(err),
            } => {
                assert_eq!(actual_seq, seq);
                assert!(
                    err.backtrace.contains(expected_backtrace),
                    "backtrace did not contain {:?}: {:?}",
                    expected_backtrace,
                    err.backtrace
                );
            }
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };
    }

    /// Fetch `reference` and expect the controller to receive a successful
    /// `FetchResult` for `seq`.
    // Not referenced by the tests in this module.
    #[allow(dead_code)]
    async fn check_fetch_result_value(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
        controller_rx: &mut PortReceiver<ControllerMessage>,
    ) {
        fetch_result(cx, stream_actor, seq, reference).await;

        let controller_msg = controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::FetchResult {
                value: Ok(_),
                seq: actual_seq,
            } => assert_eq!(seq, actual_seq),
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };
    }

    /// Defining a second recording while one is still being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "different recording already active".into(),
        )
        .await;
        Ok(())
    }

    /// Re-defining a recording that was already finalized fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_already_defined() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "already defined".into(),
        )
        .await;
        Ok(())
    }

    /// Finalizing a recording other than the one being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_finalize_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "cannot finalize recording that isn't active".into(),
        )
        .await;
        Ok(())
    }

    /// `recording_formal` with no recording being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_formal called outside of recording".into(),
        )
        .await;
        Ok(())
    }

    /// `recording_result` with no recording being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_result called outside of recording".into(),
        )
        .await;
        Ok(())
    }

    /// Calling a recording while one is still being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "cannot call recording while another recording is active".into(),
        )
        .await;
        Ok(())
    }

    /// Calling a recording that was never defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_not_found() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(&mut test_setup.stream_actor.status(), "not found".into())
            .await;
        Ok(())
    }

    /// A recording with one formal, called with no actuals, fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_too_few_arguments() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_formal called with too few arguments".into(),
        )
        .await;
        Ok(())
    }

    /// A recording with one result, called with no result refs, fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_too_few_results() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_result called with too few results".into(),
        )
        .await;
        Ok(())
    }

    /// A recording that returns its two formals in swapped positions yields the
    /// actuals swapped, and leaves no formal refs behind.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_basic_call_recording() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        // formal0 is bound to argument 1, formal1 to argument 0.
        let formal0_ref = 1.into();
        let formal0_index = 1;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
            .await?;

        let formal1_ref = 2.into();
        let formal1_index = 0;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
            .await?;

        // Output 0 is formal0 and output 1 is formal1.
        let result0_ref = formal0_ref;
        let result0_index = 0;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result0_ref, result0_index)
            .await?;

        let result1_ref = formal1_ref;
        let result1_index = 1;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result1_ref, result1_index)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        let actual0_ref = 3.into();
        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;

        let actual1_ref = 4.into();
        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;

        let actual_result0_ref = 5.into();
        let actual_result1_ref = 6.into();
        test_setup
            .stream_actor
            .call_recording(
                &test_setup.client,
                0.into(),
                0.into(),
                vec![actual_result0_ref, actual_result1_ref],
                vec![actual0_ref, actual1_ref],
            )
            .await?;

        // Output 0 came from argument 1, output 1 from argument 0.
        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
        assert!(
            test_setup
                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
                .await
        );

        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
        Ok(())
    }

    /// `request_status` while a recording is being defined errors and fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_request_status_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .request_status(&test_setup.client)
            .await
            .expect_err("request_status should have failed");
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "request_status not allowed in recording".into(),
        )
        .await;
        Ok(())
    }

    /// `init_comm` while a recording is being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_init_comm_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let dummy_comm = test_setup.proc.spawn(
            "comm",
            NcclCommActor::new(CommParams::New {
                device: CudaDevice::new(0.into()),
                unique_id: UniqueId::new_nccl()?,
                world_size: 1,
                rank: 0,
            })
            .await
            .unwrap(),
        )?;

        test_setup
            .stream_actor
            .init_comm(&test_setup.client, dummy_comm)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "init_comm not allowed in recording".into(),
        )
        .await;
        Ok(())
    }

    /// Creating the same borrow id twice inside a recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let tensor_ref = test_setup.next_ref();
        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();

        test_setup
            .stream_actor
            .borrow_create(
                &test_setup.client,
                borrow_id,
                tensor_ref,
                first_use_sender.clone(),
            )
            .await?;

        test_setup
            .stream_actor
            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "duplicate borrow create in recording".into(),
        )
        .await;

        Ok(())
    }

    /// Dropping a borrow that was never created inside the recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();

        test_setup
            .stream_actor
            .borrow_drop(
                &test_setup.client,
                borrow_id,
                Arc::new(Mutex::new(last_use_receiver)),
            )
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "borrow drop for borrow not defined in recording".into(),
        )
        .await;

        Ok(())
    }

    /// Finalizing a recording while a borrow created in it is still live fails the
    /// actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let tensor_ref = test_setup.next_ref();
        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();

        // The sender is not needed afterwards, so it is moved rather than cloned.
        test_setup
            .stream_actor
            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "all borrows created within recording must be dropped within recording".into(),
        )
        .await;

        Ok(())
    }
}