1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::Context;
25use hyperactor::HandleClient;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::PortHandle;
29use hyperactor::actor::ActorHandle;
30use hyperactor::handle;
31use hyperactor::mailbox::OncePortHandle;
32use hyperactor::mailbox::PortReceiver;
33use hyperactor::proc::Proc;
34use hyperactor::reference;
35use monarch_hyperactor::actor::PythonMessage;
36use monarch_hyperactor::actor::PythonMessageKind;
37use monarch_hyperactor::local_state_broker::BrokerId;
38use monarch_hyperactor::local_state_broker::LocalState;
39use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
40use monarch_hyperactor::pickle::pickle;
41use monarch_messages::controller::ControllerMessageClient;
42use monarch_messages::controller::Seq;
43use monarch_messages::controller::WorkerError;
44use monarch_messages::worker::ActorCallParams;
45use monarch_messages::worker::ActorMethodParams;
46use monarch_messages::worker::ArgsKwargs;
47use monarch_messages::worker::CallFunctionError;
48use monarch_messages::worker::CallFunctionParams;
49use monarch_messages::worker::SeqError;
50use monarch_messages::worker::StreamRef;
51use monarch_types::PyTree;
52use monarch_types::SerializablePyErr;
53use monarch_types::TryIntoPyObjectUnsafe;
54use pyo3::prelude::*;
55use tokio::runtime::Handle;
56use tokio::sync::Mutex;
57use tokio::task::JoinHandle;
58use torch_sys_cuda::cuda::Event;
59use torch_sys_cuda::cuda::Stream;
60use torch_sys2::CloneUnsafe;
61use torch_sys2::CudaDevice;
62use torch_sys2::TensorCell;
63use torch_sys2::deep_clone;
64use torch_sys2::factory_empty;
65use torch_sys2::factory_zeros;
66use tracing_subscriber::fmt::Subscriber;
67use typeuri::Named;
68
69use crate::ControllerActor;
70use crate::DeviceMesh;
71use crate::Factory;
72use crate::Reduction;
73use crate::Ref;
74use crate::ResolvableFunction;
75use crate::StreamCreationMode;
76use crate::WireValue;
77use crate::comm::CommMessage;
78use crate::comm::CommMessageClient;
79use crate::comm::NcclCommActor;
80
/// Outcome of producing a tensor on a stream: either the [`TensorCell`] itself or a
/// shared [`SeqError`] describing the failure associated with it.
pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
82
// Per-thread, write-once slots. `StreamActor::init` populates all three on the thread
// it runs on; later `set` attempts on the same thread are ignored there (`.ok()`).
thread_local! {
    /// Reference to the controller actor for the stream running on this thread.
    pub static CONTROLLER_ACTOR_REF: OnceCell<reference::ActorRef<ControllerActor>> = const { OnceCell::new() };
    /// The `Proc` that hosts the stream actor running on this thread.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    /// Root actor id built from the stream actor's proc id and name.
    pub static ROOT_ACTOR_ID: OnceCell<reference::ActorId> = const { OnceCell::new() };
}
89
90fn pickle_python_result(
91 py: Python<'_>,
92 result: Bound<'_, PyAny>,
93 worker_rank: usize,
94) -> Result<PythonMessage, anyhow::Error> {
95 let mut state = pickle(py, result.unbind(), false, false)
96 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
97 let inner = state
98 .take_inner()
99 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
100 Ok(PythonMessage::new_from_buf(
101 PythonMessageKind::Result {
102 rank: Some(worker_rank),
103 },
104 inner.take_buffer(),
105 ))
106}
107
/// A recorded sequence of stream messages, captured in the order they were received
/// while the recording was being defined.
#[derive(Debug)]
struct Recording {
    /// Messages pushed by the handlers while this recording is being defined.
    messages: Vec<StreamMessage>,
}
112
113impl Recording {
114 fn new() -> Self {
115 Self {
116 messages: Vec::new(),
117 }
118 }
119}
120
/// State of the recording currently active on a stream, if any.
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// Messages are being captured into a recording rather than executed.
    Defining {
        /// Key of the recording (in `StreamActor::recordings`) being defined.
        recording: Ref,
        /// Borrow ids created inside this recording and not yet dropped:
        /// `borrow_create` inserts into this set, `borrow_drop` removes from it.
        defined_borrows: HashSet<u64>,
    },
    /// A recording is being run; handlers do not capture messages in this state.
    Running,
}
131
/// Messages handled by [`StreamActor`].
///
/// While a recording is being defined, most handlers push the message onto the
/// recording instead of executing it.
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Call a Python function: the call parameters, the device meshes referenced by
    /// its arguments, and the remote process groups referenced by its arguments.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Start a borrow of `tensor`: the handler sends the tensor (or its stored error)
    /// together with an optional CUDA event through `first_use_sender`.
    BorrowCreate {
        /// Id of the borrow.
        borrow: u64,
        /// The tensor being borrowed.
        tensor: Ref,
        /// Port on which the (event, tensor) pair is sent.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// First use of a borrow: the handler receives the (event, tensor) pair, waits on
    /// the event when this stream has a CUDA stream, and binds the tensor to `result`.
    BorrowFirstUse {
        /// Id of the borrow.
        borrow: u64,
        /// Ref under which the borrowed tensor is stored in this stream.
        result: Ref,
        /// Port from which the (event, tensor) pair is received.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Last use of a borrow: the handler removes `result` from this stream and sends
    /// it, with an optional CUDA event, through `last_use_sender`.
    BorrowLastUse {
        /// Id of the borrow.
        borrow: u64,
        /// Ref of the borrowed tensor in this stream.
        result: Ref,
        /// Port on which the (event, tensor) pair is sent.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// End of a borrow: the handler receives the last-use event and waits on it when
    /// this stream has a CUDA stream.
    BorrowDrop {
        /// Id of the borrow.
        borrow: u64,
        /// Port from which the last-use (event, tensor) pair is received.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Remove the given refs from the stream's environment.
    DeleteRefs(Vec<Ref>),

    /// Status request; rejected with an error while a recording is being defined.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Store the communicator actor handle on the stream.
    InitComm(ActorHandle<NcclCommActor>),

    /// Run a collective over `comm` on `local_tensor` and bind the output to `result`.
    /// `reduction`, `scatter`, `in_place` and `out` select which collective runs and
    /// where the output is written.
    Reduce {
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    },

    /// Transfer `tensor` between ranks over `comm`. At least one of `from_rank` /
    /// `to_rank` must be set; when both are equal the tensor is deep-cloned locally
    /// into `result`.
    SendTensor {
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    // NOTE(review): the handlers for the variants below are not visible in this part
    // of the file; field meanings are taken from their names only.
    SendValue {
        seq: Seq,
        worker_actor_id: reference::ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    DefineRecording {
        recording: Ref,
    },

    FinalizeRecording {
        recording: Ref,
    },

    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    SetRefUnitTestsOnly(Ref, WireValue),

    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    GetRefUnitTestsOnly(
        Ref,
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    SendResultOfActorCall(ActorCallParams),
    CallActorMethod(ActorMethodParams),
}
252
253impl StreamMessage {
254 fn clone_for_recording(&self) -> Self {
255 match self {
256 StreamMessage::RecordingFormal {
257 result,
258 argument_index,
259 } => StreamMessage::RecordingFormal {
260 result: *result,
261 argument_index: *argument_index,
262 },
263 StreamMessage::RecordingResult {
264 result,
265 output_index,
266 } => StreamMessage::RecordingResult {
267 result: *result,
268 output_index: *output_index,
269 },
270 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
271 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
272 StreamMessage::CallFunction(
273 params.clone(),
274 device_meshes.clone(),
275 remote_process_groups.clone(),
276 )
277 }
278 StreamMessage::BorrowCreate {
279 borrow,
280 tensor,
281 first_use_sender,
282 } => StreamMessage::BorrowCreate {
283 borrow: *borrow,
284 tensor: *tensor,
285 first_use_sender: first_use_sender.clone(),
286 },
287 StreamMessage::BorrowFirstUse {
288 borrow,
289 result,
290 first_use_receiver,
291 } => StreamMessage::BorrowFirstUse {
292 borrow: *borrow,
293 result: *result,
294 first_use_receiver: first_use_receiver.clone(),
295 },
296 StreamMessage::BorrowLastUse {
297 borrow,
298 result,
299 last_use_sender,
300 } => StreamMessage::BorrowLastUse {
301 borrow: *borrow,
302 result: *result,
303 last_use_sender: last_use_sender.clone(),
304 },
305 StreamMessage::BorrowDrop {
306 borrow,
307 last_use_receiver,
308 } => StreamMessage::BorrowDrop {
309 borrow: *borrow,
310 last_use_receiver: last_use_receiver.clone(),
311 },
312 StreamMessage::Reduce {
313 comm,
314 dim_size,
315 result,
316 local_tensor,
317 factory,
318 reduction,
319 scatter,
320 in_place,
321 out,
322 } => StreamMessage::Reduce {
323 comm: comm.clone(),
324 dim_size: *dim_size,
325 result: *result,
326 local_tensor: *local_tensor,
327 factory: factory.clone(),
328 reduction: reduction.clone(),
329 scatter: *scatter,
330 in_place: *in_place,
331 out: out.clone(),
332 },
333 StreamMessage::SendTensor {
334 result,
335 from_rank,
336 to_rank,
337 tensor,
338 factory,
339 comm,
340 } => StreamMessage::SendTensor {
341 result: *result,
342 from_rank: *from_rank,
343 to_rank: *to_rank,
344 tensor: *tensor,
345 factory: factory.clone(),
346 comm: comm.clone(),
347 },
348 other => panic!(
349 "StreamMessage variant not supported in recording: {:?}",
350 other
351 ),
352 }
353 }
354
355 fn get_defined_refs(&self) -> HashSet<Ref> {
357 match self {
358 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
359 StreamMessage::CallFunction(params, ..) => {
360 params.results.iter().filter_map(|&ref_| ref_).collect()
361 }
362 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
363 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
364 StreamMessage::SendTensor {
365 result, from_rank, ..
366 } => {
367 if from_rank.is_some() {
368 HashSet::from([*result])
369 } else {
370 HashSet::new()
371 }
372 }
373 _ => HashSet::new(),
375 }
376 }
377
378 fn get_mutated_refs(&self) -> HashSet<Ref> {
380 match self {
381 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
382 StreamMessage::Reduce {
383 out,
384 in_place,
385 local_tensor,
386 ..
387 } => {
388 if *in_place {
389 HashSet::from([*local_tensor])
390 } else if let Some(out) = out {
391 HashSet::from([*out])
392 } else {
393 HashSet::new()
394 }
395 }
396 _ => HashSet::new(),
398 }
399 }
400}
401
/// Actor that executes stream messages ([`StreamMessage`]) for a single worker stream.
#[derive(Debug)]
pub struct StreamActor {
    /// World size supplied at construction; not read in the visible code.
    _world_size: usize,
    /// Rank of this worker; attached to pickled Python results.
    rank: usize,
    /// Values bound to refs on this stream: a Python object, or the error that
    /// stands in for the value.
    env: HashMap<Ref, Result<Py<PyAny>, Arc<SeqError>>>,
    /// Selects between the device's current stream and a newly created one.
    creation_mode: StreamCreationMode,
    /// Lazily initialised CUDA stream; holds `None` when `device` is `None`.
    cuda_stream: OnceLock<Option<Stream>>,
    /// CUDA device for this stream, if any.
    device: Option<CudaDevice>,
    /// Communicator actor handle, set by the `InitComm` message.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Controller that receives fetched results and failure reports.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Python process-group objects keyed by ref, looked up when resolving arguments.
    remote_process_groups: HashMap<Ref, Py<PyAny>>,
    /// Recordings defined on this stream, keyed by ref.
    recordings: HashMap<Ref, Recording>,
    /// The recording currently being defined or run, if any.
    active_recording: Option<RecordingState>,
    // NOTE(review): not read in the visible part of the file; meaning inferred from the name.
    respond_with_python_message: bool,
    /// Most recent error recorded by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
438
/// Construction parameters for [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// Stored on the actor as `_world_size`.
    pub world_size: usize,
    /// Rank of this worker.
    pub rank: usize,
    /// Whether to use the device's current stream or create a new one.
    pub creation_mode: StreamCreationMode,
    /// Stream reference; ignored by `StreamActor::new`.
    pub id: StreamRef,
    /// CUDA device for the stream; `None` yields an actor without a CUDA stream.
    pub device: Option<CudaDevice>,
    /// Controller actor the stream reports to.
    pub controller_actor: reference::ActorRef<ControllerActor>,
    // NOTE(review): consumer not visible in this part of the file.
    pub respond_with_python_message: bool,
}
455
456impl StreamActor {
457 pub fn new(
458 StreamParams {
459 world_size,
460 rank,
461 id: _,
462 device,
463 controller_actor,
464 creation_mode,
465 respond_with_python_message,
466 }: StreamParams,
467 ) -> Self {
468 Self {
469 _world_size: world_size,
470 rank,
471 env: HashMap::new(),
472 creation_mode,
473 cuda_stream: OnceLock::new(),
474 device,
475 comm: None,
476 controller_actor,
477 remote_process_groups: HashMap::new(),
478 recordings: HashMap::new(),
479 active_recording: None,
480 respond_with_python_message,
481 last_seq_error: None,
482 }
483 }
484}
485
486#[async_trait]
487impl Actor for StreamActor {
488 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
489 CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
493 controller_actor_ref.set(self.controller_actor.clone()).ok()
494 });
495 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
496 ROOT_ACTOR_ID.with(|root_actor_id| {
497 root_actor_id
498 .set(reference::ActorId::root(
499 cx.self_id().proc_id().clone(),
500 cx.self_id().name().to_string(),
501 ))
502 .ok()
503 });
504 if let Some(stream) = self.cuda_stream() {
506 Stream::set_current_stream(stream);
507 }
508 Ok(())
509 }
510
    /// Runs `future` on a dedicated OS thread named `worker-stream`, inside its own
    /// tokio runtime, and returns a tokio task handle that resolves to its output.
    ///
    /// # Panics
    /// * The spawned thread panics if the runtime cannot be built or if the join
    ///   result cannot be sent.
    /// * The returned task panics if the result sender is dropped without sending.
    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        // Carries the future's output from the dedicated thread back to the caller's runtime.
        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
        let builder = std::thread::Builder::new().name("worker-stream".to_string());
        // NOTE(review): the `Result` from `spawn` is discarded; if thread creation
        // fails, `join_tx` is dropped and the task below panics on `unwrap`.
        let _thread_handle = builder.spawn(move || {
            // Multi-thread flavour with a single worker thread.
            // NOTE(review): presumably multi-thread is chosen so `block_in_place` is usable - confirm.
            let rt = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(async {
                tokio::task::block_in_place(|| {
                    // Attach to the Python interpreter on this thread, then immediately
                    // detach for the duration of the future.
                    Python::attach(|py| {
                        py.detach(|| {
                            let result = Handle::current().block_on(future);
                            if join_tx.send(result).is_err() {
                                panic!("could not send join result")
                            }
                        })
                    })
                })
            });
            // Wait (up to one week) for outstanding runtime work before the thread exits.
            rt.shutdown_timeout(Duration::from_weeks(1));
            result
        });

        // Bridge the oneshot into a tokio `JoinHandle` on the caller's runtime.
        tokio::spawn(async move { join_rx.await.unwrap() })
    }
563}
564
/// A resolved argument for a Python call; currently always a Python object.
#[derive(Debug)]
enum PyArg {
    /// An owned reference to a Python object.
    Object(Py<PyAny>),
}
570
571impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
573 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
574 match self {
575 PyArg::Object(obj) => Ok(obj.clone_ref(py).into_bound(py)),
576 }
577 }
578}
579
580impl StreamActor {
    /// Converts a `TensorCell` into a Python object by cloning the underlying tensor
    /// handle and handing it to Python.
    ///
    /// # Panics
    /// Panics if the tensor cannot be converted into a Python object.
    fn tensor_to_pyobject(tensor_cell: TensorCell) -> Py<PyAny> {
        Python::attach(|py| {
            // SAFETY: NOTE(review) - the original justification is not present here.
            // `get_unchecked` / `clone_unsafe` bypass the cell's borrow tracking, so the
            // caller presumably relies on stream ordering to avoid conflicting access;
            // confirm against the `torch_sys2` contract.
            let tensor = unsafe {
                tensor_cell.get_unchecked().clone_unsafe()
            };
            tensor.into_pyobject(py).unwrap().unbind()
        })
    }
594
595 fn pyobject_to_tensor(py: Python<'_>, pyobj: &Py<PyAny>) -> PyResult<TensorCell> {
599 use torch_sys2::Tensor;
600 let tensor = pyobj.bind(py).extract::<Tensor>()?;
601 Ok(TensorCell::new(tensor))
603 }
604
605 fn cuda_stream(&self) -> Option<&Stream> {
606 self.cuda_stream
607 .get_or_init(|| {
608 self.device.map(|device| match self.creation_mode {
609 StreamCreationMode::UseDefaultStream => {
610 Stream::get_current_stream_on_device(device)
611 }
612 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
613 })
614 })
615 .as_ref()
616 }
617
618 fn ref_to_pyobject(&self, ref_: &Ref) -> Result<Py<PyAny>, CallFunctionError> {
619 let pyobject = self
620 .env
621 .get(ref_)
622 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
623 match pyobject {
624 Ok(val) => Ok(val.clone()),
625 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
626 }
627 }
628
629 async fn report_seq_error(
630 &mut self,
631 cx: &Context<'_, Self>,
632 seq: Seq,
633 error: CallFunctionError,
634 ) -> Result<Arc<SeqError>, anyhow::Error> {
635 match error {
636 CallFunctionError::DependentError(root) => Ok(root),
637 CallFunctionError::Error(e) => {
638 if self.active_recording.is_none() {
639 let worker_error = WorkerError {
640 backtrace: format!("{e}"),
641 worker_actor_id: cx.self_id().clone(),
642 };
643 tracing::info!("Propagating remote function error to client: {worker_error}");
644 self.controller_actor
645 .remote_function_failed(cx, seq, worker_error)
646 .await?
647 }
648 let err = Arc::new(SeqError { seq, error: e });
649 self.last_seq_error = Some(err.clone());
650 Ok(err)
651 }
652 }
653 }
654
655 async fn try_define<F>(
656 &mut self,
657 cx: &Context<'_, Self>,
658 seq: Seq,
659 result_refs: Vec<Option<Ref>>,
660 mutates: &Vec<Ref>,
661 f: F,
662 ) -> Result<()>
663 where
664 F: AsyncFnOnce(&mut Self) -> Result<Vec<Py<PyAny>>, CallFunctionError>,
665 {
666 let actual_results = f(self).await;
667 let op_results = actual_results.and_then(|actual_results| {
670 if result_refs.len() == actual_results.len() {
671 Ok(actual_results
672 .into_iter()
673 .zip(result_refs.iter())
674 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
675 .collect::<Vec<(Ref, Py<PyAny>)>>())
676 } else {
677 Err(CallFunctionError::UnexpectedNumberOfReturns(
678 result_refs.len(),
679 actual_results.len(),
680 ))
681 }
682 });
683
684 match op_results {
687 Ok(op_results) => {
688 for (ref_, pyobject) in op_results.into_iter() {
689 let prev = self.env.insert(ref_, Ok(pyobject));
690 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
691 }
692 }
693 Err(err) => {
694 let err = self.report_seq_error(cx, seq, err).await?;
695 for ref_ in result_refs {
696 match ref_ {
697 Some(ref_) => {
698 let prev = self.env.insert(ref_, Err(err.clone()));
699 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
700 }
701 None => {}
702 }
703 }
704 for ref_ in mutates {
705 self.env.insert(*ref_, Err(err.clone()));
706 }
707 }
708 }
709 Ok(())
710 }
711
712 fn call_python_fn<'py>(
713 &mut self,
714 py: Python<'py>,
715 _cx: &Context<Self>,
716 function: Option<ResolvableFunction>,
717 args_kwargs: ArgsKwargs,
718 _mutates: &[Ref],
719 device_meshes: HashMap<Ref, DeviceMesh>,
720 remote_process_groups: HashMap<
721 Ref,
722 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
723 >,
724 ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
725 let (args_tuple, kwargs_dict) = args_kwargs
726 .to_python(py)
727 .map_err(|e| CallFunctionError::Error(e.into()))?;
728 let function = function
729 .map(|function| {
730 function.resolve(py).map_err(|e| {
731 CallFunctionError::InvalidRemoteFunction(format!(
732 "failed to resolve function {}: {}",
733 function,
734 SerializablePyErr::from(py, &e)
735 ))
736 })
737 })
738 .transpose()?;
739
740 let remote_process_groups = remote_process_groups
741 .into_iter()
742 .map(|(gref, (_mesh, _dims, _comm))| {
743 let group = match self.remote_process_groups.entry(gref) {
744 Entry::Occupied(ent) => ent.get().clone_ref(py),
745 Entry::Vacant(_ent) => {
746 panic!("no longer implemented");
747 }
748 };
749 PyResult::Ok((gref, group))
750 })
751 .collect::<Result<HashMap<_, _>, _>>()
752 .map_err(SerializablePyErr::from_fn(py))?;
753
754 let resolve = |val: Bound<'py, PyAny>| {
755 val.extract::<PyTree<Py<PyAny>>>()
756 .map_err(SerializablePyErr::from_fn(py))?
757 .try_into_map(|obj| {
758 Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
759 if let Some(mesh) = device_meshes.get(&ref_) {
760 PyArg::Object(
761 Py::new(py, mesh.clone())
762 .map_err(SerializablePyErr::from_fn(py))?
763 .into(),
764 )
765 } else if let Some(pg) = remote_process_groups.get(&ref_) {
766 PyArg::Object(pg.clone_ref(py))
767 } else {
768 let pyobj = self.ref_to_pyobject(&ref_)?;
769 PyArg::Object(pyobj)
770 }
771 } else {
772 PyArg::Object(obj)
773 })
774 })
775 };
776
777 let py_args: Vec<PyTree<PyArg>> = args_tuple
779 .iter()
780 .map(&resolve)
781 .collect::<Result<_, CallFunctionError>>()?;
782
783 let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
784 .iter()
785 .map(|(k, v)| {
786 let key = k
787 .extract::<String>()
788 .map_err(SerializablePyErr::from_fn(py))?;
789 let value = resolve(v)?;
790 Ok((key, value))
791 })
792 .collect::<Result<_, CallFunctionError>>()?;
793
794 let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
797 let result: Bound<'_, PyAny> =
798 tracing::subscriber::with_default(scoped_subscriber, || {
799 let args = unsafe { py_args.try_to_object_unsafe(py) }
808 .map_err(SerializablePyErr::from_fn(py))?;
809 let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
811 .map_err(SerializablePyErr::from_fn(py))?;
812
813 if let Some(function) = function {
814 function
815 .call(args, Some(kwargs))
816 .map_err(SerializablePyErr::from_fn(py))
817 } else {
818 Ok(args.get_item(0).unwrap())
819 }
820 })?;
821 Ok(result)
822 }
823
824 fn call_python_fn_pytree(
825 &mut self,
826 cx: &hyperactor::Context<Self>,
827 function: ResolvableFunction,
828 args_kwargs: ArgsKwargs,
829 mutates: &[Ref],
830 device_meshes: HashMap<Ref, DeviceMesh>,
831 remote_process_groups: HashMap<
832 Ref,
833 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
834 >,
835 ) -> Result<PyTree<Py<PyAny>>, CallFunctionError> {
836 Python::attach(|py| {
837 let result = self.call_python_fn(
838 py,
839 cx,
840 Some(function),
841 args_kwargs,
842 mutates,
843 device_meshes,
844 remote_process_groups,
845 )?;
846 Ok(PyTree::<Py<PyAny>>::extract_bound(&result)
847 .map_err(SerializablePyErr::from_fn(py))?)
848 })
849 }
850 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
858 let pyobject = self
859 .env
860 .get(&ref_)
861 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
862
863 match pyobject {
864 Ok(val) => Python::attach(|py| {
865 Self::pyobject_to_tensor(py, val)
866 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
867 }),
868 Err(_) => {
869 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
870 Ok(TensorCell::new(t))
871 }
872 }
873 }
874
875 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
876 self.active_recording
877 .as_mut()
878 .and_then(|state| match state {
879 RecordingState::Defining {
880 recording,
881 defined_borrows,
882 } => {
883 match self.recordings.get_mut(recording) {
884 Some(recording) => Some((recording, defined_borrows)),
885 None => panic!("recording not found: {:?}", recording),
887 }
888 }
889 RecordingState::Running => None,
890 })
891 }
892
893 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
894 for ref_ in refs {
895 let rvalue_or_err = self
896 .env
897 .get(ref_)
898 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
899 if let Err(err) = rvalue_or_err {
900 return Ok(Some(err.clone()));
901 }
902 }
903 Ok(None)
904 }
905 async fn send_value_python_message(
906 &mut self,
907 cx: &hyperactor::Context<'_, Self>,
908 seq: Seq,
909 mutates: Vec<Ref>,
910 function: Option<ResolvableFunction>,
911 args_kwargs: ArgsKwargs,
912 device_meshes: HashMap<Ref, DeviceMesh>,
913 ) -> Result<()> {
914 let rank = self.rank;
915 self.try_define(cx, seq, vec![], &vec![], async |self_| {
916 let python_message =
917 Python::attach(|py| -> Result<PythonMessage, CallFunctionError> {
918 let python_result = tokio::task::block_in_place(|| {
919 self_.call_python_fn(
920 py,
921 cx,
922 function,
923 args_kwargs,
924 &mutates,
925 device_meshes,
926 HashMap::new(),
927 )
928 })?;
929 pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
930 })?;
931 let ser = wirevalue::Any::serialize(&python_message).unwrap();
932 self_
933 .controller_actor
934 .fetch_result(cx, seq, Ok(ser))
935 .await?;
936 Ok(vec![])
937 })
938 .await
939 }
940 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
941 let rvalue = self
942 .env
943 .get(&src)
944 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
945 self.env.insert(dest, Python::attach(|_py| rvalue.clone()));
946 Ok(())
947 }
948 async fn call_actor(
949 &mut self,
950 cx: &Context<'_, Self>,
951 params: ActorCallParams,
952 ) -> Result<Py<PyAny>, CallFunctionError> {
953 let local_state: Result<Vec<Py<PyAny>>> = Python::attach(|_py| {
954 params
955 .local_state
956 .into_iter()
957 .map(|elem| {
958 let pyobj = self.ref_to_pyobject(&elem)?;
959 Ok(pyobj.into_any())
960 })
961 .collect()
962 });
963
964 let (send, recv) = cx.open_once_port();
965 let state = LocalState {
966 response_port: send,
967 state: local_state?,
968 };
969 let x: u64 = params.seq.into();
970 let message = LocalStateBrokerMessage::Set(x as usize, state);
971
972 let broker = BrokerId::new(params.broker_id).resolve(cx).await;
973 broker
974 .send(cx, message)
975 .map_err(|e| CallFunctionError::Error(e.into()))?;
976 let result = recv
977 .recv()
978 .await
979 .map_err(|e| CallFunctionError::Error(e.into()))?;
980
981 result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
982 }
983}
984
985#[async_trait]
986#[handle(StreamMessage)]
987impl StreamMessageHandler for StreamActor {
988 async fn call_function(
989 &mut self,
990 cx: &Context<Self>,
991 params: CallFunctionParams,
992 device_meshes: HashMap<Ref, DeviceMesh>,
993 remote_process_groups: HashMap<
994 Ref,
995 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
996 >,
997 ) -> Result<()> {
998 if let Some((recording, _)) = self.get_defining_recording() {
999 recording.messages.push(StreamMessage::CallFunction(
1000 params,
1001 device_meshes,
1002 remote_process_groups,
1003 ));
1004 return Ok(());
1005 }
1006
1007 params.function.panic_if_requested();
1008 self.try_define(
1009 cx,
1010 params.seq,
1011 params.results,
1012 ¶ms.mutates,
1013 async |self| {
1014 tokio::task::block_in_place(|| {
1015 self.call_python_fn_pytree(
1016 cx,
1017 params.function,
1018 params.args_kwargs,
1019 ¶ms.mutates,
1020 device_meshes,
1021 remote_process_groups,
1022 )
1023 .map(|results| results.into_leaves())
1024 })
1025 },
1026 )
1027 .await?;
1028 Ok(())
1029 }
1030
1031 async fn borrow_create(
1032 &mut self,
1033 cx: &Context<Self>,
1034 borrow: u64,
1035 tensor: Ref,
1036 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1037 ) -> Result<()> {
1038 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1039 recording.messages.push(StreamMessage::BorrowCreate {
1040 borrow,
1041 tensor,
1042 first_use_sender,
1043 });
1044 ensure!(
1045 defined_borrows.insert(borrow),
1046 "duplicate borrow create in recording"
1047 );
1048 return Ok(());
1049 }
1050
1051 let pyobj_result = self
1052 .env
1053 .get(&tensor)
1054 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1055
1056 let result = match pyobj_result {
1057 Ok(pyobj) => Python::attach(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
1058 Err(e) => Err(e.clone()),
1059 };
1060
1061 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1062 first_use_sender.send(cx, (event, result)).map_err(|err| {
1063 anyhow!(
1064 "failed sending first use event for borrow {:?}: {:?}",
1065 borrow,
1066 err
1067 )
1068 })
1069 }
1070
1071 async fn borrow_first_use(
1072 &mut self,
1073 _cx: &Context<Self>,
1074 borrow: u64,
1075 result: Ref,
1076 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1077 ) -> Result<()> {
1078 if let Some((recording, _)) = self.get_defining_recording() {
1079 recording.messages.push(StreamMessage::BorrowFirstUse {
1080 borrow,
1081 result,
1082 first_use_receiver: first_use_receiver.clone(),
1083 });
1084 return Ok(());
1085 }
1086
1087 let (first_use_event, cell) =
1088 first_use_receiver
1089 .lock()
1090 .await
1091 .recv()
1092 .await
1093 .map_err(|err| {
1094 anyhow!(
1095 "failed receiving first use event for borrow {:?}: {:?}",
1096 borrow,
1097 err
1098 )
1099 })?;
1100
1101 if let Some(stream) = self.cuda_stream() {
1102 stream.wait_event(
1103 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1104 );
1105 }
1106 match cell {
1107 Ok(cell) => {
1108 let pyobj = Self::tensor_to_pyobject(cell);
1109 self.env.insert(result, Ok(pyobj));
1110 }
1111 Err(err) => {
1112 self.env.insert(result, Err(err.clone()));
1113 }
1114 }
1115 Ok(())
1116 }
1117
1118 async fn borrow_last_use(
1119 &mut self,
1120 cx: &Context<Self>,
1121 borrow: u64,
1122 result: Ref,
1123 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1124 ) -> Result<()> {
1125 if let Some((recording, _)) = self.get_defining_recording() {
1126 recording.messages.push(StreamMessage::BorrowLastUse {
1127 borrow,
1128 result,
1129 last_use_sender,
1130 });
1131 return Ok(());
1132 }
1133
1134 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1135 let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
1136 "Invalid reference for borrow_last_use: {result:#?}"
1137 ))?;
1138 let tensor = match pyobj_or_err {
1139 Ok(pyobj) => Ok(Python::attach(|py| {
1140 Self::pyobject_to_tensor(py, &pyobj).unwrap()
1141 })),
1142 Err(e) => Err(e),
1143 };
1144
1145 last_use_sender.send(cx, (event, tensor)).map_err(|err| {
1146 anyhow!(
1147 "failed sending last use event for borrow {:?}: {:?}",
1148 borrow,
1149 err
1150 )
1151 })
1152 }
1153
1154 async fn borrow_drop(
1155 &mut self,
1156 _cx: &Context<Self>,
1157 borrow: u64,
1158 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1159 ) -> Result<()> {
1160 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1161 recording.messages.push(StreamMessage::BorrowDrop {
1162 borrow,
1163 last_use_receiver: last_use_receiver.clone(),
1164 });
1165 ensure!(
1166 defined_borrows.remove(&borrow),
1167 "borrow drop for borrow not defined in recording"
1168 );
1169 return Ok(());
1170 }
1171
1172 let (last_use_event, _cell) =
1176 last_use_receiver.lock().await.recv().await.map_err(|err| {
1177 anyhow!(
1178 "failed receiving last use event for borrow {:?}: {:?}",
1179 borrow,
1180 err
1181 )
1182 })?;
1183
1184 if let Some(stream) = self.cuda_stream() {
1185 stream.wait_event(
1186 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1187 );
1188 }
1189 Ok(())
1191 }
1192
1193 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1194 if let Some((recording, _)) = self.get_defining_recording() {
1195 recording.messages.push(StreamMessage::DeleteRefs(refs));
1196 return Ok(());
1197 }
1198
1199 for ref_ in refs.iter() {
1200 self.env.remove(ref_);
1201 }
1202 Ok(())
1203 }
1204
1205 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1206 if self.get_defining_recording().is_some() {
1207 bail!("request_status not allowed in recording");
1208 }
1209
1210 Ok(())
1211 }
1212
1213 async fn init_comm(
1214 &mut self,
1215 _cx: &Context<Self>,
1216 comm: ActorHandle<NcclCommActor>,
1217 ) -> Result<()> {
1218 if self.get_defining_recording().is_some() {
1219 bail!("init_comm not allowed in recording");
1220 }
1221
1222 self.comm = Some(comm);
1223 Ok(())
1224 }
1225
1226 async fn reduce(
1227 &mut self,
1228 cx: &Context<Self>,
1229 comm: Arc<ActorHandle<NcclCommActor>>,
1230 dim_size: i64,
1231 result: Ref,
1232 local_tensor: Ref,
1233 factory: Factory,
1234 reduction: Reduction,
1235 scatter: bool,
1236 in_place: bool,
1237 out: Option<Ref>,
1238 ) -> Result<()> {
1239 if let Some((recording, _)) = self.get_defining_recording() {
1240 recording.messages.push(StreamMessage::Reduce {
1241 comm,
1242 dim_size,
1243 result,
1244 local_tensor,
1245 factory,
1246 reduction,
1247 scatter,
1248 in_place,
1249 out,
1250 });
1251 return Ok(());
1252 }
1253
1254 let stream = self
1255 .cuda_stream()
1256 .expect("reductions not yet supported for non-CUDA workers")
1257 .clone();
1258 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1259 let out_cell = out
1260 .map(|out| self.get_or_fake_on_err(out, &factory))
1261 .transpose()?;
1262 let output_cell = match reduction {
1263 Reduction::Stack => {
1264 if scatter {
1265 let output_cell = if in_place {
1266 input_cell.clone()
1267 } else {
1268 out_cell.unwrap_or({
1269 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1270 let cloned = deep_clone(&borrow);
1271 TensorCell::new(cloned)
1272 })
1273 };
1274 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1275 .await?;
1276 output_cell
1277 } else {
1278 ensure!(
1279 !in_place,
1280 "in-place, non-scatter not supported for stack reduce"
1281 );
1282
1283 let output_cell = out_cell.unwrap_or({
1284 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1286 let output =
1287 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1288 TensorCell::new(output)
1289 });
1290
1291 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1292 .await?;
1293 output_cell
1294 }
1295 }
1296 Reduction::ReduceOp(op) => {
1297 if scatter {
1298 ensure!(!in_place, "in-place, scatter not supported for reduce");
1299
1300 let output_cell = out_cell.unwrap_or({
1301 let output = factory_empty(
1302 &factory.size[1..],
1303 factory.dtype,
1304 factory.layout,
1305 factory.device,
1306 );
1307 TensorCell::new(output)
1308 });
1309 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1310 .await?;
1311 output_cell
1312 } else {
1313 let output_cell = if in_place {
1314 input_cell.clone()
1315 } else {
1316 out_cell.map_or(
1317 {
1318 let borrow =
1319 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1320 let cloned = deep_clone(&borrow);
1321 Ok(TensorCell::new(cloned))
1322 },
1323 |out_cell| -> Result<_, anyhow::Error> {
1324 let mut out_borrow =
1325 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1326 let in_borrow =
1327 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1328 out_borrow.copy_(&in_borrow);
1329 drop(out_borrow);
1330 Ok(out_cell)
1331 },
1332 )?
1333 };
1334
1335 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1336 output_cell
1337 }
1338 }
1339 };
1340
1341 let pyobj = Self::tensor_to_pyobject(output_cell);
1342 self.env.insert(result, Ok(pyobj));
1343 Ok(())
1344 }
1345
1346 async fn send_tensor(
1347 &mut self,
1348 cx: &Context<Self>,
1349 result: Ref,
1350 from_rank: Option<usize>,
1351 to_rank: Option<usize>,
1352 tensor: Ref,
1353 factory: Factory,
1354 comm: Arc<ActorHandle<NcclCommActor>>,
1355 ) -> Result<()> {
1356 if let Some((recording, _)) = self.get_defining_recording() {
1357 recording.messages.push(StreamMessage::SendTensor {
1358 result,
1359 from_rank,
1360 to_rank,
1361 tensor,
1362 factory,
1363 comm,
1364 });
1365 return Ok(());
1366 }
1367
1368 if to_rank.is_none() && from_rank.is_none() {
1369 bail!("tried to send tensor without a to/from rank");
1370 }
1371
1372 if from_rank == to_rank {
1374 let input_cell: &std::result::Result<Py<PyAny>, Arc<SeqError>> = self
1375 .env
1376 .get(&tensor)
1377 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1378 let output_cell: Result<Py<PyAny>, Arc<SeqError>> = match input_cell {
1379 Ok(pyobj) => {
1380 Python::attach(|py| -> Result<Py<PyAny>, Arc<SeqError>> {
1381 let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
1382 let borrow = input_tensor.try_borrow().unwrap();
1387 let cloned = deep_clone(&borrow);
1388 let cloned_cell = TensorCell::new(cloned);
1389 Ok(Self::tensor_to_pyobject(cloned_cell))
1390 })
1391 }
1392 Err(err) => Err(err.clone()),
1393 };
1394 self.env.insert(result, output_cell);
1395 return Ok(());
1396 }
1397
1398 let mut messages = Vec::new();
1399
1400 if let Some(to_rank) = to_rank {
1401 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1402 messages.push(CommMessage::Send(
1403 input_cell,
1404 to_rank.try_into().unwrap(),
1405 self.cuda_stream()
1406 .expect("tried to send_tensor on non-cuda stream")
1407 .clone(),
1408 cx.open_once_port().0,
1409 ));
1410 }
1411
1412 if let Some(from_rank) = from_rank {
1413 let output_cell = TensorCell::new(factory_empty(
1414 &factory.size,
1415 factory.dtype,
1416 factory.layout,
1417 factory.device,
1418 ));
1419 messages.push(CommMessage::Recv(
1420 output_cell.clone(),
1421 from_rank.try_into().unwrap(),
1422 self.cuda_stream()
1423 .expect("tried to send_tensor on non-cuda stream")
1424 .clone(),
1425 cx.open_once_port().0,
1426 ));
1427 let pyobj = Self::tensor_to_pyobject(output_cell);
1428 self.env.insert(result, Ok(pyobj));
1429 }
1430
1431 comm.group(
1432 cx,
1433 messages,
1434 self.cuda_stream()
1435 .expect("tried to send_tensor on non-cuda stream")
1436 .clone(),
1437 )
1438 .await?;
1439 Ok(())
1440 }
1441
1442 async fn send_value(
1443 &mut self,
1444 cx: &Context<Self>,
1445 seq: Seq,
1446 worker_actor_id: reference::ActorId,
1447 mutates: Vec<Ref>,
1448 function: Option<ResolvableFunction>,
1449 args_kwargs: ArgsKwargs,
1450 device_meshes: HashMap<Ref, DeviceMesh>,
1451 ) -> Result<()> {
1452 if self.respond_with_python_message {
1453 return self
1454 .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
1455 .await;
1456 }
1457
1458 let result = if let Some(function) = function {
1459 tokio::task::block_in_place(|| {
1461 self.call_python_fn_pytree(
1462 cx,
1463 function,
1464 args_kwargs,
1465 &mutates,
1466 device_meshes,
1467 HashMap::new(),
1468 )
1469 })
1470 } else {
1471 Python::attach(|py| {
1474 let (args, kwargs) = args_kwargs
1475 .to_python(py)
1476 .map_err(|e| CallFunctionError::Error(e.into()))?;
1477 match (args.len(), kwargs.len()) {
1478 (1, 0) => {
1479 let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
1480 arg.extract::<PyTree<Py<PyAny>>>()
1481 .map_err(SerializablePyErr::from_fn(py))?
1482 .try_into_map(|obj| {
1483 let bound_obj = obj.bind(py);
1484 if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1485 self.ref_to_pyobject(&ref_)
1486 } else {
1487 Ok(obj)
1488 }
1489 })
1490 }
1491 _ => Err(CallFunctionError::TooManyArgsForValue(
1492 format!("args with {} elements", args.len()),
1493 format!("kwargs with {} elements", kwargs.len()),
1494 )),
1495 }
1496 })
1497 };
1498
1499 let value = match result {
1500 Ok(pyobject) => Ok(pyobject),
1501 Err(err) => {
1502 let err = self.report_seq_error(cx, seq, err).await?;
1503 for ref_ in mutates {
1504 self.env.insert(ref_, Err(err.clone()));
1505 }
1506 Err(WorkerError {
1507 backtrace: format!("{:?}", err),
1508 worker_actor_id,
1509 })
1510 }
1511 };
1512
1513 let result = match value {
1517 Ok(_value) => {
1518 unreachable!(
1520 "send_value should return early when respond_with_python_message is true"
1521 )
1522 }
1523 Err(e) => Err(e),
1524 };
1525 self.controller_actor.fetch_result(cx, seq, result).await?;
1526
1527 Ok(())
1528 }
1529
1530 async fn send_result_of_actor_call(
1531 &mut self,
1532 cx: &Context<Self>,
1533 params: ActorCallParams,
1534 ) -> anyhow::Result<()> {
1535 let seq = params.seq;
1536 let mutates = params.mutates.clone();
1537 self.try_define(cx, seq, vec![], &mutates, async |self| {
1538 let value = self.call_actor(cx, params).await?;
1539 let result =
1540 Python::attach(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1541 let result = wirevalue::Any::serialize(&result).unwrap();
1542 self.controller_actor
1543 .fetch_result(cx, seq, Ok(result))
1544 .await?;
1545 Ok(vec![])
1546 })
1547 .await
1548 }
1549
1550 async fn call_actor_method(
1551 &mut self,
1552 cx: &Context<Self>,
1553 params: ActorMethodParams,
1554 ) -> anyhow::Result<()> {
1555 let seq = params.call.seq;
1556 let mutates = params.call.mutates.clone();
1557 self.try_define(cx, seq, params.results, &mutates, async |self| {
1558 let result = self.call_actor(cx, params.call).await?;
1559 let result = Python::attach(|py| {
1560 PyTree::<Py<PyAny>>::extract_bound(&result.into_bound(py))
1561 .map_err(SerializablePyErr::from_fn(py))
1562 })?;
1563 Ok(result.into_leaves())
1564 })
1565 .await
1566 }
1567
1568 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1569 if self.active_recording.is_some() {
1570 bail!("different recording already active");
1571 }
1572 match self.recordings.entry(recording) {
1573 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1574 Entry::Vacant(entry) => entry.insert(Recording::new()),
1575 };
1576 self.active_recording = Some(RecordingState::Defining {
1577 recording,
1578 defined_borrows: HashSet::new(),
1579 });
1580 Ok(())
1581 }
1582
1583 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1584 match self.active_recording {
1585 Some(RecordingState::Defining {
1586 recording: active_recording,
1587 ref defined_borrows,
1588 }) if active_recording == recording => {
1589 ensure!(
1590 defined_borrows.is_empty(),
1591 "all borrows created within recording must be dropped within recording"
1592 );
1593 self.active_recording = None;
1594 }
1595 _ => bail!("cannot finalize recording that isn't active"),
1596 }
1597 Ok(())
1598 }
1599
1600 async fn recording_formal(
1601 &mut self,
1602 _cx: &Context<Self>,
1603 result: Ref,
1604 argument_index: usize,
1605 ) -> Result<()> {
1606 match self.get_defining_recording() {
1607 Some((recording, _)) => {
1608 recording.messages.push(StreamMessage::RecordingFormal {
1609 result,
1610 argument_index,
1611 });
1612 }
1613 None => bail!("recording_formal called outside of recording"),
1614 };
1615 Ok(())
1616 }
1617
1618 async fn recording_result(
1619 &mut self,
1620 _cx: &Context<Self>,
1621 result: Ref,
1622 output_index: usize,
1623 ) -> Result<()> {
1624 match self.get_defining_recording() {
1625 Some((recording, _)) => {
1626 recording.messages.push(StreamMessage::RecordingResult {
1627 result,
1628 output_index,
1629 });
1630 }
1631 None => bail!("recording_result called outside of recording"),
1632 };
1633 Ok(())
1634 }
1635
1636 async fn call_recording(
1637 &mut self,
1638 cx: &Context<Self>,
1639 seq: Seq,
1640 recording: Ref,
1641 results: Vec<Ref>,
1642 actuals: Vec<Ref>,
1643 ) -> Result<()> {
1644 if self.active_recording.is_some() {
1645 bail!("cannot call recording while another recording is active");
1646 }
1647
1648 let messages = match self.recordings.get(&recording) {
1649 Some(recording) => recording
1650 .messages
1651 .iter()
1652 .map(|message| message.clone_for_recording())
1653 .collect::<Vec<_>>(),
1654 None => bail!("recording {:?} not found", recording),
1655 };
1656
1657 self.active_recording = Some(RecordingState::Running);
1658
1659 let mut error: Option<Arc<SeqError>> = None;
1664 let mut all_defined_refs = HashSet::new();
1667 let mut all_mutated_refs = HashSet::new();
1670 let mut formal_to_actual_refs = HashMap::new();
1676 self.last_seq_error = None;
1678 for message in messages.into_iter() {
1679 let defined_refs = message.get_defined_refs();
1680 all_defined_refs.extend(defined_refs.clone());
1681
1682 let mutated_refs_with_formals = message.get_mutated_refs();
1683 all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1684 match formal_to_actual_refs.get(ref_) {
1685 Some(actual_ref) => Some(*actual_ref),
1686 None => {
1687 if all_defined_refs.contains(ref_) {
1688 None
1689 } else {
1690 Some(*ref_)
1691 }
1692 }
1693 }
1694 }));
1695
1696 match message {
1697 StreamMessage::RecordingFormal {
1698 result: formal_ref,
1699 argument_index,
1700 } => match actuals.get(argument_index) {
1701 None => bail!("recording_formal called with too few arguments"),
1702 Some(actual_ref) => {
1703 formal_to_actual_refs.insert(formal_ref, *actual_ref);
1704 self.define_ref(formal_ref, *actual_ref)?;
1705 }
1706 },
1707 StreamMessage::RecordingResult {
1708 result: result_ref,
1709 output_index,
1710 } => match results.get(output_index) {
1711 None => bail!("recording_result called with too few results"),
1712 Some(actual_result_ref) => {
1713 self.define_ref(*actual_result_ref, result_ref)?;
1714 }
1715 },
1716 StreamMessage::DeleteRefs(ref refs) => {
1717 for ref_ in refs {
1718 all_defined_refs.remove(ref_);
1719 }
1720 StreamMessageHandler::handle(self, cx, message).await?;
1721 }
1722 StreamMessage::CallFunction { .. } if error.is_some() => {
1723 let error = error.clone().unwrap();
1729 for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1730 self.env.insert(*ref_, Err(error.clone()));
1731 }
1732 }
1733 StreamMessage::BorrowLastUse { ref result, .. } => {
1734 all_defined_refs.remove(result);
1735 StreamMessageHandler::handle(self, cx, message).await?;
1736 }
1737 StreamMessage::Reduce {
1738 local_tensor,
1739 ref out,
1740 ..
1741 } => {
1742 if error.is_none() {
1746 let inputs_to_check = [Some(local_tensor), out.clone()]
1747 .iter()
1748 .filter_map(|r| *r)
1749 .collect::<Vec<_>>();
1750 error = self.get_first_error(inputs_to_check.as_slice())?;
1751 }
1752 StreamMessageHandler::handle(self, cx, message).await?;
1753 }
1754 StreamMessage::SendTensor {
1755 ref tensor,
1756 ref to_rank,
1757 ..
1758 } => {
1759 if to_rank.is_some() && error.is_none() {
1764 error = self.get_first_error(&[*tensor])?;
1765 }
1766 StreamMessageHandler::handle(self, cx, message).await?;
1767 }
1768 _ => {
1769 StreamMessageHandler::handle(self, cx, message).await?;
1770 }
1771 };
1772
1773 match (&error, self.last_seq_error.take()) {
1781 (None, Some(seq_err)) => {
1782 self.controller_actor
1784 .remote_function_failed(
1785 cx,
1786 seq,
1787 WorkerError {
1788 backtrace: format!("recording failed: {}", &seq_err),
1789 worker_actor_id: cx.self_id().clone(),
1790 },
1791 )
1792 .await?;
1793 error = Some(seq_err)
1794 }
1795 _ => {}
1796 }
1797 }
1802
1803 StreamMessageHandler::handle(
1807 self,
1808 cx,
1809 StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1810 )
1811 .await?;
1812
1813 if error.is_some() {
1816 for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1817 self.env.insert(*ref_, Err(error.clone().unwrap()));
1818 }
1819 }
1820
1821 self.active_recording = None;
1822 Ok(())
1823 }
1824
1825 async fn set_ref_unit_tests_only(
1826 &mut self,
1827 _cx: &Context<Self>,
1828 reference: Ref,
1829 value: WireValue,
1830 ) -> Result<()> {
1831 let pyobj =
1832 Python::attach(|py| -> PyResult<Py<PyAny>> { Ok(value.into_pyobject(py)?.unbind()) })?;
1833 self.env.insert(reference, Ok(pyobj));
1834 Ok(())
1835 }
1836
1837 async fn set_tensor_ref_unit_tests_only(
1838 &mut self,
1839 _cx: &Context<Self>,
1840 reference: Ref,
1841 tensor_result: TensorCellResult,
1842 ) -> Result<()> {
1843 match tensor_result {
1844 Ok(tensor_cell) => {
1845 let pyobj = Self::tensor_to_pyobject(tensor_cell);
1846 self.env.insert(reference, Ok(pyobj));
1847 }
1848 Err(err) => {
1849 self.env.insert(reference, Err(err));
1850 }
1851 }
1852 Ok(())
1853 }
1854
1855 async fn get_ref_unit_tests_only(
1856 &mut self,
1857 _cx: &Context<Self>,
1858 reference: Ref,
1859 ) -> Result<Option<Result<WireValue, String>>> {
1860 use pyo3::types::PyBool;
1861 use pyo3::types::PyFloat;
1862 use pyo3::types::PyInt;
1863 use pyo3::types::PyList;
1864 use pyo3::types::PyNone;
1865 use pyo3::types::PyString;
1866 fn pyobject_to_wire(
1868 value: Result<Py<PyAny>, Arc<SeqError>>,
1869 ) -> Result<WireValue, Arc<SeqError>> {
1870 let pyobj = value?;
1871 Python::attach(|py| {
1872 let bound = pyobj.bind(py);
1873 if bound.is_instance_of::<PyBool>() {
1875 Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1876 } else if bound.is_instance_of::<PyInt>() {
1877 Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1878 } else if bound.is_instance_of::<PyList>() {
1879 if let Ok(val) = bound.extract::<Vec<i64>>() {
1880 Ok(WireValue::IntList(val))
1881 } else {
1882 Ok(WireValue::String(format!(
1883 "unsupported list type: {:?}",
1884 bound
1885 )))
1886 }
1887 } else if bound.is_instance_of::<PyFloat>() {
1888 Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1889 } else if bound.is_instance_of::<PyString>() {
1890 Ok(WireValue::String(bound.extract::<String>().unwrap()))
1891 } else if bound.is_instance_of::<PyNone>() {
1892 Ok(WireValue::None(()))
1893 } else {
1894 Ok(WireValue::String(format!(
1895 "unsupported pyobject type: {:?}",
1896 bound
1897 )))
1898 }
1899 })
1900 }
1901 Ok(self.env.get(&reference).map(|pyobj| {
1902 pyobject_to_wire(Python::attach(|_py| pyobj.clone())).map_err(|err| err.to_string())
1903 }))
1904 }
1905
1906 async fn get_tensor_ref_unit_tests_only(
1907 &mut self,
1908 _cx: &Context<Self>,
1909 reference: Ref,
1910 ) -> Result<Option<TensorCellResult>> {
1911 match self.env.get(&reference) {
1912 Some(Ok(pyobj)) => Python::attach(|py| match Self::pyobject_to_tensor(py, pyobj) {
1913 Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1914 Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1915 }),
1916 Some(Err(err)) => Ok(Some(Err(err.clone()))),
1917 None => Ok(None),
1918 }
1919 }
1920}
1921
1922#[cfg(test)]
1923mod tests {
1924 use hyperactor::actor::ActorStatus;
1925 use hyperactor::context;
1926 use hyperactor::supervision::ActorSupervisionEvent;
1927 use monarch_messages::controller::ControllerMessage;
1928 use monarch_messages::worker::StreamCreationMode;
1929 use monarch_types::PickledPyObject;
1930 use pyo3::IntoPyObjectExt;
1931 use timed_test::async_timed_test;
1932 use tokio::sync::watch;
1933 use torch_sys_cuda::nccl::UniqueId;
1934 use torch_sys2::factory_float_tensor;
1935 use torch_sys2::testing::allclose;
1936
1937 use super::*;
1938 use crate::comm::CommParams;
1939 use crate::test_util;
1940
1941 #[allow(dead_code)]
1942 fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
1943 Arc::new(SeqError {
1944 seq: 0.into(),
1945 error: err,
1946 })
1947 }
1948
/// Shared fixture for the stream actor tests: a local proc with a spawned
/// `StreamActor`, a client instance used to message it, and the receiving
/// ends of the controller and supervision ports.
struct TestSetup {
    // Local proc that hosts the actors created for the test.
    proc: Proc,
    // Handle to the stream actor under test.
    stream_actor: ActorHandle<StreamActor>,
    // Client instance passed as the sending context to the actor.
    client: Instance<()>,
    // Receiver for the port installed as the proc's supervision coordinator.
    // NOTE(review): presumably kept only so the port stays open - confirm.
    #[allow(dead_code)]
    supervision_rx: PortReceiver<ActorSupervisionEvent>,
    // Receives the messages the stream actor sends to the controller.
    #[allow(dead_code)]
    controller_rx: PortReceiver<ControllerMessage>,
    // Reference to the attached controller given to the stream actor.
    #[allow(dead_code)]
    controller_actor: reference::ActorRef<ControllerActor>,
    // Next ref handed out by `next_ref`.
    next_ref: Ref,
}
1963
1964 impl TestSetup {
1965 async fn new() -> Result<Self> {
1966 Self::new_with_world_size(1).await
1967 }
1968
1969 async fn new_with_world_size(world_size: usize) -> Result<Self> {
1970 test_util::test_setup()?;
1971
1972 let proc = Proc::local();
1973 let (_, controller_actor, controller_rx) =
1974 proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
1975 let (client, _handle) = proc.instance("client")?;
1976 let (supervision_tx, supervision_rx) = client.open_port();
1977 proc.set_supervision_coordinator(supervision_tx)?;
1978 let stream_actor = proc.spawn(
1979 "stream",
1980 StreamActor::new(StreamParams {
1981 world_size,
1982 rank: 0,
1983 creation_mode: StreamCreationMode::UseDefaultStream,
1984 id: 0.into(),
1985 device: Some(CudaDevice::new(0.into())),
1986 controller_actor: controller_actor.clone(),
1987 respond_with_python_message: false,
1988 }),
1989 )?;
1990
1991 Ok(Self {
1992 proc,
1993 stream_actor,
1994 client,
1995 supervision_rx,
1996 controller_rx,
1997 controller_actor,
1998 next_ref: 0.into(),
1999 })
2000 }
2001
2002 fn next_ref(&mut self) -> Ref {
2003 let ref_ = self.next_ref;
2004 self.next_ref = Ref {
2005 id: self.next_ref.id + 1,
2006 };
2007 ref_
2008 }
2009
2010 async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2011 let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
2012 self.stream_actor
2013 .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2014 .await
2015 }
2016
2017 async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2018 let actual = self
2019 .stream_actor
2020 .get_tensor_ref_unit_tests_only(&self.client, reference)
2021 .await
2022 .unwrap()
2023 .unwrap()
2024 .unwrap();
2025
2026 allclose(
2028 &factory_float_tensor(data, "cpu".parse().unwrap()),
2029 &actual.borrow(),
2030 )
2031 .unwrap()
2032 }
2033
2034 #[allow(dead_code)]
2035 async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2036 let result_error = self
2037 .stream_actor
2038 .get_tensor_ref_unit_tests_only(&self.client, reference)
2039 .await
2040 .unwrap()
2041 .unwrap()
2042 .unwrap_err();
2043
2044 assert!(Arc::ptr_eq(&result_error, &error));
2045 }
2046 }
2047
2048 async fn assert_actor_failed_with_msg(
2049 status_rx: &mut watch::Receiver<ActorStatus>,
2050 expected_msg: String,
2051 ) {
2052 status_rx
2053 .wait_for(|s| matches!(s, ActorStatus::Failed(_)))
2054 .await
2055 .unwrap();
2056 let status = status_rx.borrow().clone();
2057 if let ActorStatus::Failed(msg) = status {
2058 assert!(msg.to_string().contains(&expected_msg));
2059 } else {
2060 panic!("expected ActorStatus::Failed, got {:?}", status);
2061 }
2062 }
2063
2064 async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2065 for ref_ in refs {
2066 assert!(
2067 test_setup
2068 .stream_actor
2069 .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2070 .await
2071 .unwrap()
2072 .is_none()
2073 );
2074 }
2075 }
2076
2077 #[allow(dead_code)]
2078 async fn fetch_result(
2079 cx: &impl context::Actor,
2080 stream_actor: ActorHandle<StreamActor>,
2081 seq: Seq,
2082 reference: Ref,
2083 ) {
2084 let ref_to_send = Python::attach(|py| {
2085 PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2086 });
2087
2088 stream_actor
2089 .send_value(
2090 cx,
2091 seq,
2092 stream_actor.actor_id().clone(),
2093 Vec::new(),
2094 None,
2095 ArgsKwargs::from_wire_values(
2096 vec![WireValue::PyObject(ref_to_send)],
2097 HashMap::new(),
2098 )
2099 .unwrap(),
2100 HashMap::new(),
2101 )
2102 .await
2103 .unwrap()
2104 }
2105
2106 #[allow(dead_code)]
2107 async fn check_fetch_result_error(
2108 cx: &impl context::Actor,
2109 stream_actor: ActorHandle<StreamActor>,
2110 seq: Seq,
2111 reference: Ref,
2112 controller_rx: &mut PortReceiver<ControllerMessage>,
2113 expected_backtrace: &str,
2114 ) {
2115 fetch_result(cx, stream_actor, seq, reference).await;
2116
2117 let controller_msg = controller_rx.recv().await.unwrap();
2118 match controller_msg {
2119 ControllerMessage::FetchResult {
2120 seq: actual_seq,
2121 value: Err(err),
2122 } => {
2123 assert_eq!(actual_seq, seq);
2124 assert!(
2125 err.backtrace.contains(expected_backtrace),
2126 "backtrace did not contain {:?}: {:?}",
2127 expected_backtrace,
2128 err.backtrace
2129 );
2130 }
2131 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2132 };
2133 }
2134
2135 #[allow(dead_code)]
2136 async fn check_fetch_result_value(
2137 cx: &impl context::Actor,
2138 stream_actor: ActorHandle<StreamActor>,
2139 seq: Seq,
2140 reference: Ref,
2141 controller_rx: &mut PortReceiver<ControllerMessage>,
2142 ) {
2143 fetch_result(cx, stream_actor, seq, reference).await;
2144
2145 let controller_msg = controller_rx.recv().await.unwrap();
2146 match controller_msg {
2147 ControllerMessage::FetchResult {
2148 value: Ok(_),
2149 seq: actual_seq,
2150 } => assert_eq!(seq, actual_seq),
2151 _ => panic!("Unexpected controller message: {:?}", controller_msg),
2152 };
2153 }
2154
2155 #[async_timed_test(timeout_secs = 60)]
2156 async fn test_define_recording_other_recording_active() -> Result<()> {
2157 let test_setup = TestSetup::new().await?;
2158 test_setup
2159 .stream_actor
2160 .define_recording(&test_setup.client, 0.into())
2161 .await?;
2162 test_setup
2163 .stream_actor
2164 .define_recording(&test_setup.client, 1.into())
2165 .await?;
2166 assert_actor_failed_with_msg(
2167 &mut test_setup.stream_actor.status(),
2168 "different recording already active".into(),
2169 )
2170 .await;
2171 Ok(())
2172 }
2173
2174 #[async_timed_test(timeout_secs = 60)]
2175 async fn test_define_recording_already_defined() -> Result<()> {
2176 let test_setup = TestSetup::new().await?;
2177 test_setup
2178 .stream_actor
2179 .define_recording(&test_setup.client, 0.into())
2180 .await?;
2181 test_setup
2182 .stream_actor
2183 .finalize_recording(&test_setup.client, 0.into())
2184 .await?;
2185 test_setup
2186 .stream_actor
2187 .define_recording(&test_setup.client, 0.into())
2188 .await?;
2189 assert_actor_failed_with_msg(
2190 &mut test_setup.stream_actor.status(),
2191 "already defined".into(),
2192 )
2193 .await;
2194 Ok(())
2195 }
2196
2197 #[async_timed_test(timeout_secs = 60)]
2198 async fn test_finalize_recording_other_recording_active() -> Result<()> {
2199 let test_setup = TestSetup::new().await?;
2200 test_setup
2201 .stream_actor
2202 .define_recording(&test_setup.client, 0.into())
2203 .await?;
2204 test_setup
2205 .stream_actor
2206 .finalize_recording(&test_setup.client, 1.into())
2207 .await?;
2208 assert_actor_failed_with_msg(
2209 &mut test_setup.stream_actor.status(),
2210 "cannot finalize recording that isn't active".into(),
2211 )
2212 .await;
2213 Ok(())
2214 }
2215
2216 #[async_timed_test(timeout_secs = 60)]
2217 async fn test_recording_formal_outside_recording() -> Result<()> {
2218 let test_setup = TestSetup::new().await?;
2219 test_setup
2220 .stream_actor
2221 .recording_formal(&test_setup.client, 0.into(), 0)
2222 .await?;
2223 assert_actor_failed_with_msg(
2224 &mut test_setup.stream_actor.status(),
2225 "recording_formal called outside of recording".into(),
2226 )
2227 .await;
2228 Ok(())
2229 }
2230
2231 #[async_timed_test(timeout_secs = 60)]
2232 async fn test_recording_result_outside_recording() -> Result<()> {
2233 let test_setup = TestSetup::new().await?;
2234 test_setup
2235 .stream_actor
2236 .recording_result(&test_setup.client, 0.into(), 0)
2237 .await?;
2238 assert_actor_failed_with_msg(
2239 &mut test_setup.stream_actor.status(),
2240 "recording_result called outside of recording".into(),
2241 )
2242 .await;
2243 Ok(())
2244 }
2245
2246 #[async_timed_test(timeout_secs = 60)]
2247 async fn test_call_recording_other_recording_active() -> Result<()> {
2248 let test_setup = TestSetup::new().await?;
2249 test_setup
2250 .stream_actor
2251 .define_recording(&test_setup.client, 0.into())
2252 .await?;
2253 test_setup
2254 .stream_actor
2255 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2256 .await?;
2257 assert_actor_failed_with_msg(
2258 &mut test_setup.stream_actor.status(),
2259 "cannot call recording while another recording is active".into(),
2260 )
2261 .await;
2262 Ok(())
2263 }
2264
2265 #[async_timed_test(timeout_secs = 60)]
2266 async fn test_call_recording_not_found() -> Result<()> {
2267 let test_setup = TestSetup::new().await?;
2268 test_setup
2269 .stream_actor
2270 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2271 .await?;
2272 assert_actor_failed_with_msg(&mut test_setup.stream_actor.status(), "not found".into())
2273 .await;
2274 Ok(())
2275 }
2276
2277 #[async_timed_test(timeout_secs = 60)]
2278 async fn test_recording_formal_too_few_arguments() -> Result<()> {
2279 let test_setup = TestSetup::new().await?;
2280
2281 test_setup
2282 .stream_actor
2283 .define_recording(&test_setup.client, 0.into())
2284 .await?;
2285
2286 test_setup
2287 .stream_actor
2288 .recording_formal(&test_setup.client, 1.into(), 0)
2289 .await?;
2290
2291 test_setup
2292 .stream_actor
2293 .finalize_recording(&test_setup.client, 0.into())
2294 .await?;
2295
2296 test_setup
2297 .stream_actor
2298 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2299 .await?;
2300
2301 assert_actor_failed_with_msg(
2302 &mut test_setup.stream_actor.status(),
2303 "recording_formal called with too few arguments".into(),
2304 )
2305 .await;
2306 Ok(())
2307 }
2308
2309 #[async_timed_test(timeout_secs = 60)]
2310 async fn test_recording_result_too_few_results() -> Result<()> {
2311 let test_setup = TestSetup::new().await?;
2312
2313 test_setup
2314 .stream_actor
2315 .define_recording(&test_setup.client, 0.into())
2316 .await?;
2317
2318 test_setup
2319 .stream_actor
2320 .recording_result(&test_setup.client, 1.into(), 0)
2321 .await?;
2322
2323 test_setup
2324 .stream_actor
2325 .finalize_recording(&test_setup.client, 0.into())
2326 .await?;
2327
2328 test_setup
2329 .stream_actor
2330 .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2331 .await?;
2332
2333 assert_actor_failed_with_msg(
2334 &mut test_setup.stream_actor.status(),
2335 "recording_result called with too few results".into(),
2336 )
2337 .await;
2338 Ok(())
2339 }
2340
2341 #[async_timed_test(timeout_secs = 60)]
2342 async fn test_basic_call_recording() -> Result<()> {
2343 let mut test_setup = TestSetup::new().await?;
2344
2345 test_setup
2349 .stream_actor
2350 .define_recording(&test_setup.client, 0.into())
2351 .await?;
2352
2353 let formal0_ref = 1.into();
2354 let formal0_index = 1;
2355 test_setup
2356 .stream_actor
2357 .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2358 .await?;
2359
2360 let formal1_ref = 2.into();
2361 let formal1_index = 0;
2362 test_setup
2363 .stream_actor
2364 .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2365 .await?;
2366
2367 let result0_ref = formal0_ref;
2368 let result0_index = 0;
2369 test_setup
2370 .stream_actor
2371 .recording_result(&test_setup.client, result0_ref, result0_index)
2372 .await?;
2373
2374 let result1_ref = formal1_ref;
2375 let result1_index = 1;
2376 test_setup
2377 .stream_actor
2378 .recording_result(&test_setup.client, result1_ref, result1_index)
2379 .await?;
2380
2381 test_setup
2382 .stream_actor
2383 .finalize_recording(&test_setup.client, 0.into())
2384 .await?;
2385
2386 let actual0_ref = 3.into();
2387 test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2388
2389 let actual1_ref = 4.into();
2390 test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2391
2392 let actual_result0_ref = 5.into();
2395 let actual_result1_ref = 6.into();
2396 test_setup
2397 .stream_actor
2398 .call_recording(
2399 &test_setup.client,
2400 0.into(),
2401 0.into(),
2402 vec![actual_result0_ref, actual_result1_ref],
2403 vec![actual0_ref, actual1_ref],
2404 )
2405 .await?;
2406
2407 assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2409 assert!(
2410 test_setup
2411 .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2412 .await
2413 );
2414
2415 assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2418 Ok(())
2419 }
2420
2421 #[async_timed_test(timeout_secs = 60)]
2422 async fn test_request_status_in_recording() -> Result<()> {
2423 let test_setup = TestSetup::new().await?;
2424 test_setup
2425 .stream_actor
2426 .define_recording(&test_setup.client, 0.into())
2427 .await?;
2428 test_setup
2429 .stream_actor
2430 .request_status(&test_setup.client)
2431 .await
2432 .expect_err("request_status should have failed");
2433 assert_actor_failed_with_msg(
2434 &mut test_setup.stream_actor.status(),
2435 "request_status not allowed in recording".into(),
2436 )
2437 .await;
2438 Ok(())
2439 }
2440
2441 #[async_timed_test(timeout_secs = 60)]
2442 async fn test_init_comm_in_recording() -> Result<()> {
2443 let test_setup = TestSetup::new().await?;
2444 test_setup
2445 .stream_actor
2446 .define_recording(&test_setup.client, 0.into())
2447 .await?;
2448
2449 let dummy_comm = test_setup.proc.spawn(
2450 "comm",
2451 NcclCommActor::new(CommParams::New {
2452 device: CudaDevice::new(0.into()),
2453 unique_id: UniqueId::new()?,
2454 world_size: 1,
2455 rank: 0,
2456 })
2457 .await
2458 .unwrap(),
2459 )?;
2460
2461 test_setup
2462 .stream_actor
2463 .init_comm(&test_setup.client, dummy_comm)
2464 .await?;
2465 assert_actor_failed_with_msg(
2466 &mut test_setup.stream_actor.status(),
2467 "init_comm not allowed in recording".into(),
2468 )
2469 .await;
2470 Ok(())
2471 }
2472
2473 #[async_timed_test(timeout_secs = 60)]
2474 async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2475 let mut test_setup = TestSetup::new().await?;
2476 test_setup
2477 .stream_actor
2478 .define_recording(&test_setup.client, 0.into())
2479 .await?;
2480
2481 let borrow_id = 1;
2482 let tensor_ref = test_setup.next_ref();
2483 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2484
2485 test_setup
2486 .stream_actor
2487 .borrow_create(
2488 &test_setup.client,
2489 borrow_id,
2490 tensor_ref,
2491 first_use_sender.clone(),
2492 )
2493 .await?;
2494
2495 test_setup
2496 .stream_actor
2497 .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2498 .await?;
2499
2500 assert_actor_failed_with_msg(
2501 &mut test_setup.stream_actor.status(),
2502 "duplicate borrow create in recording".into(),
2503 )
2504 .await;
2505
2506 Ok(())
2507 }
2508
2509 #[async_timed_test(timeout_secs = 60)]
2510 async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2511 let test_setup = TestSetup::new().await?;
2512 test_setup
2513 .stream_actor
2514 .define_recording(&test_setup.client, 0.into())
2515 .await?;
2516
2517 let borrow_id = 1;
2518 let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2519
2520 test_setup
2521 .stream_actor
2522 .borrow_drop(
2523 &test_setup.client,
2524 borrow_id,
2525 Arc::new(Mutex::new(last_use_receiver)),
2526 )
2527 .await?;
2528
2529 assert_actor_failed_with_msg(
2530 &mut test_setup.stream_actor.status(),
2531 "borrow drop for borrow not defined in recording".into(),
2532 )
2533 .await;
2534
2535 Ok(())
2536 }
2537
2538 #[async_timed_test(timeout_secs = 60)]
2539 async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2540 let mut test_setup = TestSetup::new().await?;
2541 test_setup
2542 .stream_actor
2543 .define_recording(&test_setup.client, 0.into())
2544 .await?;
2545
2546 let borrow_id = 1;
2547 let tensor_ref = test_setup.next_ref();
2548 let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2549
2550 test_setup
2551 .stream_actor
2552 .borrow_create(
2553 &test_setup.client,
2554 borrow_id,
2555 tensor_ref,
2556 first_use_sender.clone(),
2557 )
2558 .await?;
2559
2560 test_setup
2562 .stream_actor
2563 .finalize_recording(&test_setup.client, 0.into())
2564 .await?;
2565
2566 assert_actor_failed_with_msg(
2567 &mut test_setup.stream_actor.status(),
2568 "all borrows created within recording must be dropped within recording".into(),
2569 )
2570 .await;
2571
2572 Ok(())
2573 }
2574}