1use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::PortHandle;
31use hyperactor::actor::ActorHandle;
32use hyperactor::forward;
33use hyperactor::mailbox::OncePortHandle;
34use hyperactor::mailbox::PortReceiver;
35use hyperactor::proc::Proc;
36use monarch_hyperactor::actor::PythonMessage;
37use monarch_hyperactor::actor::PythonMessageKind;
38use monarch_hyperactor::buffers::Buffer;
39use monarch_hyperactor::local_state_broker::BrokerId;
40use monarch_hyperactor::local_state_broker::LocalState;
41use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
42use monarch_messages::controller::ControllerMessageClient;
43use monarch_messages::controller::Seq;
44use monarch_messages::controller::WorkerError;
45use monarch_messages::worker::ActorCallParams;
46use monarch_messages::worker::ActorMethodParams;
47use monarch_messages::worker::ArgsKwargs;
48use monarch_messages::worker::CallFunctionError;
49use monarch_messages::worker::CallFunctionParams;
50use monarch_messages::worker::SeqError;
51use monarch_messages::worker::StreamRef;
52use monarch_types::PyTree;
53use monarch_types::SerializablePyErr;
54use monarch_types::TryIntoPyObjectUnsafe;
55use pyo3::prelude::*;
56use tokio::runtime::Handle;
57use tokio::sync::Mutex;
58use tokio::task::JoinHandle;
59use torch_sys_cuda::cuda::Event;
60use torch_sys_cuda::cuda::Stream;
61use torch_sys2::CloneUnsafe;
62use torch_sys2::CudaDevice;
63use torch_sys2::TensorCell;
64use torch_sys2::deep_clone;
65use torch_sys2::factory_empty;
66use torch_sys2::factory_zeros;
67use tracing_subscriber::fmt::Subscriber;
68use typeuri::Named;
69
70use crate::ControllerActor;
71use crate::DeviceMesh;
72use crate::Factory;
73use crate::Reduction;
74use crate::Ref;
75use crate::ResolvableFunction;
76use crate::StreamCreationMode;
77use crate::WireValue;
78use crate::comm::CommMessage;
79use crate::comm::CommMessageClient;
80use crate::comm::NcclCommActor;
81
82pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
83
84thread_local! {
86 pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
87 pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
88 pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
89}
90
91fn pickle_python_result(
92 py: Python<'_>,
93 result: Bound<'_, PyAny>,
94 worker_rank: usize,
95) -> Result<PythonMessage, anyhow::Error> {
96 let pickle = py
97 .import("monarch._src.actor.actor_mesh")
98 .unwrap()
99 .getattr("_pickle")
100 .unwrap();
101 let mut data: Buffer = pickle
102 .call1((result,))
103 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
104 .extract()
105 .unwrap();
106 Ok(PythonMessage::new_from_buf(
107 PythonMessageKind::Result {
108 rank: Some(worker_rank),
109 },
110 data.take_part(),
111 ))
112}
113
114#[derive(Debug)]
115struct Recording {
116 messages: Vec<StreamMessage>,
117}
118
119impl Recording {
120 fn new() -> Self {
121 Self {
122 messages: Vec::new(),
123 }
124 }
125}
126
127#[derive(Debug, PartialEq)]
128enum RecordingState {
129 Defining {
130 recording: Ref,
131 defined_borrows: HashSet<u64>,
134 },
135 Running,
136}
137
/// Messages handled by a [`StreamActor`].
///
/// While a recording is being defined, the handlers visible in this file append
/// the message to the active recording instead of executing it (and reject
/// `RequestStatus` / `InitComm`).
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Calls a Python function and stores the leaves of its returned pytree under
    /// the call's result refs. The maps provide device meshes and remote process
    /// groups that the call's arguments may reference by ref.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Starts a borrow of a tensor held by this stream.
    BorrowCreate {
        /// Identifier of the borrow.
        borrow: u64,
        /// Ref of the tensor being lent out.
        tensor: Ref,
        /// Port on which the tensor (or its error) is sent, together with a CUDA
        /// event recorded on this stream when the stream has a CUDA device.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// First use of a borrowed tensor on the borrowing stream: waits for the
    /// value and event from the lender and binds the value to `result`.
    BorrowFirstUse {
        borrow: u64,
        /// Ref under which the borrowed tensor is stored on this stream.
        result: Ref,
        /// Presumably the receiving end of `BorrowCreate::first_use_sender` — confirm at the call site.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Last use of a borrowed tensor: removes `result` from this stream's
    /// environment and sends it back with an optional CUDA event.
    BorrowLastUse {
        borrow: u64,
        /// Ref of the borrowed tensor on this stream; removed by the handler.
        result: Ref,
        /// Port on which the tensor (or its error) and the event are sent.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Ends a borrow: waits for the last-use message and, on CUDA, makes this
    /// stream wait on the received event.
    BorrowDrop {
        borrow: u64,
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Removes the listed refs from the stream's environment.
    DeleteRefs(Vec<Ref>),

    /// Status probe; the handler only rejects it while a recording is being defined.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Installs the NCCL communicator actor for this stream.
    InitComm(ActorHandle<NcclCommActor>),

    /// Collective reduction of `local_tensor` through `comm`, storing the output
    /// under `result`. `reduction`, `scatter` and `in_place` choose between
    /// all-to-all, all-gather, reduce-scatter and all-reduce; `out`, when given,
    /// is used as the output buffer.
    Reduce {
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    },

    /// Point-to-point tensor transfer through `comm`. When `from_rank == to_rank`
    /// the tensor is deep-copied locally into `result`.
    SendTensor {
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// Computes a value and sends it to the controller as the result of `seq`
    /// (handler not shown here; presumably via `send_value_python_message`).
    SendValue {
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    /// Presumably begins defining the recording `recording` (handler not shown).
    DefineRecording {
        recording: Ref,
    },

    /// Presumably ends the definition of `recording` (handler not shown).
    FinalizeRecording {
        recording: Ref,
    },

    /// Presumably replays `recording` with `actuals` as inputs and `results` as
    /// outputs (handler not shown).
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Inside a recording, defines `result` as a formal argument at `argument_index`.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Inside a recording, marks `result` as output number `output_index`.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Test-only: binds a ref to a wire value.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Test-only: binds a ref to a tensor result.
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Test-only: reads back the value bound to a ref.
    GetRefUnitTestsOnly(
        Ref,
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Test-only: reads back the tensor bound to a ref.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// Actor call whose result is produced through the local state broker
    /// (handler not shown; presumably uses `call_actor`).
    SendResultOfActorCall(ActorCallParams),
    /// Actor method invocation (handler not shown).
    CallActorMethod(ActorMethodParams),
}
258
259impl StreamMessage {
260 fn clone_for_recording(&self) -> Self {
261 match self {
262 StreamMessage::RecordingFormal {
263 result,
264 argument_index,
265 } => StreamMessage::RecordingFormal {
266 result: *result,
267 argument_index: *argument_index,
268 },
269 StreamMessage::RecordingResult {
270 result,
271 output_index,
272 } => StreamMessage::RecordingResult {
273 result: *result,
274 output_index: *output_index,
275 },
276 StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
277 StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
278 StreamMessage::CallFunction(
279 params.clone(),
280 device_meshes.clone(),
281 remote_process_groups.clone(),
282 )
283 }
284 StreamMessage::BorrowCreate {
285 borrow,
286 tensor,
287 first_use_sender,
288 } => StreamMessage::BorrowCreate {
289 borrow: *borrow,
290 tensor: *tensor,
291 first_use_sender: first_use_sender.clone(),
292 },
293 StreamMessage::BorrowFirstUse {
294 borrow,
295 result,
296 first_use_receiver,
297 } => StreamMessage::BorrowFirstUse {
298 borrow: *borrow,
299 result: *result,
300 first_use_receiver: first_use_receiver.clone(),
301 },
302 StreamMessage::BorrowLastUse {
303 borrow,
304 result,
305 last_use_sender,
306 } => StreamMessage::BorrowLastUse {
307 borrow: *borrow,
308 result: *result,
309 last_use_sender: last_use_sender.clone(),
310 },
311 StreamMessage::BorrowDrop {
312 borrow,
313 last_use_receiver,
314 } => StreamMessage::BorrowDrop {
315 borrow: *borrow,
316 last_use_receiver: last_use_receiver.clone(),
317 },
318 StreamMessage::Reduce {
319 comm,
320 dim_size,
321 result,
322 local_tensor,
323 factory,
324 reduction,
325 scatter,
326 in_place,
327 out,
328 } => StreamMessage::Reduce {
329 comm: comm.clone(),
330 dim_size: *dim_size,
331 result: *result,
332 local_tensor: *local_tensor,
333 factory: factory.clone(),
334 reduction: reduction.clone(),
335 scatter: *scatter,
336 in_place: *in_place,
337 out: out.clone(),
338 },
339 StreamMessage::SendTensor {
340 result,
341 from_rank,
342 to_rank,
343 tensor,
344 factory,
345 comm,
346 } => StreamMessage::SendTensor {
347 result: *result,
348 from_rank: *from_rank,
349 to_rank: *to_rank,
350 tensor: *tensor,
351 factory: factory.clone(),
352 comm: comm.clone(),
353 },
354 other => panic!(
355 "StreamMessage variant not supported in recording: {:?}",
356 other
357 ),
358 }
359 }
360
361 fn get_defined_refs(&self) -> HashSet<Ref> {
363 match self {
364 StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
365 StreamMessage::CallFunction(params, ..) => {
366 params.results.iter().filter_map(|&ref_| ref_).collect()
367 }
368 StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
369 StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
370 StreamMessage::SendTensor {
371 result, from_rank, ..
372 } => {
373 if from_rank.is_some() {
374 HashSet::from([*result])
375 } else {
376 HashSet::new()
377 }
378 }
379 _ => HashSet::new(),
381 }
382 }
383
384 fn get_mutated_refs(&self) -> HashSet<Ref> {
386 match self {
387 StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
388 StreamMessage::Reduce {
389 out,
390 in_place,
391 local_tensor,
392 ..
393 } => {
394 if *in_place {
395 HashSet::from([*local_tensor])
396 } else if let Some(out) = out {
397 HashSet::from([*out])
398 } else {
399 HashSet::new()
400 }
401 }
402 _ => HashSet::new(),
404 }
405 }
406}
407
/// Actor that executes worker operations in order on a single (optionally
/// CUDA) stream, keeping their results in an environment keyed by [`Ref`].
#[derive(Debug)]
pub struct StreamActor {
    // Stored at construction; not read in the code shown here.
    _world_size: usize,
    // This worker's rank; attached to pickled results.
    rank: usize,
    // Value of each defined ref, or the sequence error that prevented it.
    env: HashMap<Ref, Result<PyObject, Arc<SeqError>>>,
    // Whether `cuda_stream` reuses the device's current stream or creates a new one.
    creation_mode: StreamCreationMode,
    // Lazily initialized by `cuda_stream`; holds `None` when there is no device.
    cuda_stream: OnceLock<Option<Stream>>,
    // CUDA device for this stream, if any.
    device: Option<CudaDevice>,
    // Communicator installed by `InitComm`.
    comm: Option<ActorHandle<NcclCommActor>>,
    // Controller that receives failures and fetched results.
    controller_actor: ActorRef<ControllerActor>,
    // Python process-group objects that function arguments can reference by ref.
    remote_process_groups: HashMap<Ref, PyObject>,
    // Recordings known to this stream, keyed by recording ref.
    recordings: HashMap<Ref, Recording>,
    // Current recording state; `None` when no recording is active.
    active_recording: Option<RecordingState>,
    // NOTE(review): not read in the code shown; presumably selects PythonMessage
    // replies — confirm.
    respond_with_python_message: bool,
    // Most recent error reported by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
444
/// Parameters used to construct a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// Number of workers.
    pub world_size: usize,
    /// Rank of this worker.
    pub rank: usize,
    /// How the CUDA stream is obtained (current stream vs. new stream).
    pub creation_mode: StreamCreationMode,
    /// Stream identifier; ignored by `StreamActor::new`.
    pub id: StreamRef,
    /// CUDA device, or `None` for a CPU-only stream.
    pub device: Option<CudaDevice>,
    /// Controller actor to report results and failures to.
    pub controller_actor: ActorRef<ControllerActor>,
    /// NOTE(review): presumably selects PythonMessage replies — confirm.
    pub respond_with_python_message: bool,
}
461
462impl StreamActor {
463 pub fn new(
464 StreamParams {
465 world_size,
466 rank,
467 id: _,
468 device,
469 controller_actor,
470 creation_mode,
471 respond_with_python_message,
472 }: StreamParams,
473 ) -> Self {
474 Self {
475 _world_size: world_size,
476 rank,
477 env: HashMap::new(),
478 creation_mode,
479 cuda_stream: OnceLock::new(),
480 device,
481 comm: None,
482 controller_actor,
483 remote_process_groups: HashMap::new(),
484 recordings: HashMap::new(),
485 active_recording: None,
486 respond_with_python_message,
487 last_seq_error: None,
488 }
489 }
490}
491
492#[async_trait]
493impl Actor for StreamActor {
494 async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
495 CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
499 controller_actor_ref.set(self.controller_actor.clone()).ok()
500 });
501 PROC.with(|proc| proc.set(cx.proc().clone()).ok());
502 ROOT_ACTOR_ID.with(|root_actor_id| {
503 root_actor_id
504 .set(ActorId::root(
505 cx.self_id().proc_id().clone(),
506 cx.self_id().name().to_string(),
507 ))
508 .ok()
509 });
510 if let Some(stream) = self.cuda_stream() {
512 Stream::set_current_stream(stream);
513 }
514 Ok(())
515 }
516
517 fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
523 where
524 F: Future + Send + 'static,
525 F::Output: Send + 'static,
526 {
527 let (join_tx, join_rx) = tokio::sync::oneshot::channel();
528 let builder = std::thread::Builder::new().name("worker-stream".to_string());
536 let _thread_handle = builder.spawn(move || {
537 let rt = tokio::runtime::Builder::new_multi_thread()
541 .worker_threads(1)
542 .enable_all()
543 .build()
544 .unwrap();
545 let result = rt.block_on(async {
546 tokio::task::block_in_place(|| {
547 Python::with_gil(|py| {
552 py.allow_threads(|| {
553 let result = Handle::current().block_on(future);
554 if join_tx.send(result).is_err() {
555 panic!("could not send join result")
556 }
557 })
558 })
559 })
560 });
561 rt.shutdown_timeout(Duration::from_weeks(1));
562 result
563 });
564
565 tokio::spawn(async move { join_rx.await.unwrap() })
568 }
569}
570
571#[derive(Debug)]
573enum PyArg {
574 PyObject(PyObject),
575}
576
577impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
579 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
580 match self {
581 PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
582 }
583 }
584}
585
586impl StreamActor {
587 fn tensor_to_pyobject(tensor_cell: TensorCell) -> PyObject {
588 Python::with_gil(|py| {
589 let tensor = unsafe {
594 tensor_cell.get_unchecked().clone_unsafe()
596 };
597 tensor.into_pyobject(py).unwrap().unbind()
598 })
599 }
600
601 fn pyobject_to_tensor(py: Python<'_>, pyobj: &PyObject) -> PyResult<TensorCell> {
605 use torch_sys2::Tensor;
606 let tensor = pyobj.bind(py).extract::<Tensor>()?;
607 Ok(TensorCell::new(tensor))
609 }
610
611 fn cuda_stream(&self) -> Option<&Stream> {
612 self.cuda_stream
613 .get_or_init(|| {
614 self.device.map(|device| match self.creation_mode {
615 StreamCreationMode::UseDefaultStream => {
616 Stream::get_current_stream_on_device(device)
617 }
618 StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
619 })
620 })
621 .as_ref()
622 }
623
624 fn ref_to_pyobject(&self, ref_: &Ref) -> Result<PyObject, CallFunctionError> {
625 let pyobject = self
626 .env
627 .get(ref_)
628 .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
629 match pyobject {
630 Ok(val) => Ok(val.clone()),
631 Err(err) => Err(CallFunctionError::DependentError(err.clone())),
632 }
633 }
634
635 async fn report_seq_error(
636 &mut self,
637 cx: &Context<'_, Self>,
638 seq: Seq,
639 error: CallFunctionError,
640 ) -> Result<Arc<SeqError>, anyhow::Error> {
641 match error {
642 CallFunctionError::DependentError(root) => Ok(root),
643 CallFunctionError::Error(e) => {
644 if self.active_recording.is_none() {
645 let worker_error = WorkerError {
646 backtrace: format!("{e}"),
647 worker_actor_id: cx.self_id().clone(),
648 };
649 tracing::info!("Propagating remote function error to client: {worker_error}");
650 self.controller_actor
651 .remote_function_failed(cx, seq, worker_error)
652 .await?
653 }
654 let err = Arc::new(SeqError { seq, error: e });
655 self.last_seq_error = Some(err.clone());
656 Ok(err)
657 }
658 }
659 }
660
661 async fn try_define<F>(
662 &mut self,
663 cx: &Context<'_, Self>,
664 seq: Seq,
665 result_refs: Vec<Option<Ref>>,
666 mutates: &Vec<Ref>,
667 f: F,
668 ) -> Result<()>
669 where
670 F: AsyncFnOnce(&mut Self) -> Result<Vec<PyObject>, CallFunctionError>,
671 {
672 let actual_results = f(self).await;
673 let op_results = actual_results.and_then(|actual_results| {
676 if result_refs.len() == actual_results.len() {
677 Ok(actual_results
678 .into_iter()
679 .zip(result_refs.iter())
680 .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
681 .collect::<Vec<(Ref, PyObject)>>())
682 } else {
683 Err(CallFunctionError::UnexpectedNumberOfReturns(
684 result_refs.len(),
685 actual_results.len(),
686 ))
687 }
688 });
689
690 match op_results {
693 Ok(op_results) => {
694 for (ref_, pyobject) in op_results.into_iter() {
695 let prev = self.env.insert(ref_, Ok(pyobject));
696 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
697 }
698 }
699 Err(err) => {
700 let err = self.report_seq_error(cx, seq, err).await?;
701 for ref_ in result_refs {
702 match ref_ {
703 Some(ref_) => {
704 let prev = self.env.insert(ref_, Err(err.clone()));
705 assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
706 }
707 None => {}
708 }
709 }
710 for ref_ in mutates {
711 self.env.insert(*ref_, Err(err.clone()));
712 }
713 }
714 }
715 Ok(())
716 }
717
718 fn call_python_fn<'py>(
719 &mut self,
720 py: Python<'py>,
721 _cx: &Context<Self>,
722 function: Option<ResolvableFunction>,
723 args_kwargs: ArgsKwargs,
724 _mutates: &[Ref],
725 device_meshes: HashMap<Ref, DeviceMesh>,
726 remote_process_groups: HashMap<
727 Ref,
728 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
729 >,
730 ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
731 let (args_tuple, kwargs_dict) = args_kwargs
732 .to_python(py)
733 .map_err(|e| CallFunctionError::Error(e.into()))?;
734 let function = function
735 .map(|function| {
736 function.resolve(py).map_err(|e| {
737 CallFunctionError::InvalidRemoteFunction(format!(
738 "failed to resolve function {}: {}",
739 function,
740 SerializablePyErr::from(py, &e)
741 ))
742 })
743 })
744 .transpose()?;
745
746 let remote_process_groups = remote_process_groups
747 .into_iter()
748 .map(|(gref, (_mesh, _dims, _comm))| {
749 let group = match self.remote_process_groups.entry(gref) {
750 Entry::Occupied(ent) => ent.get().clone_ref(py),
751 Entry::Vacant(_ent) => {
752 panic!("no longer implemented");
753 }
754 };
755 PyResult::Ok((gref, group))
756 })
757 .collect::<Result<HashMap<_, _>, _>>()
758 .map_err(SerializablePyErr::from_fn(py))?;
759
760 let resolve = |val: Bound<'py, PyAny>| {
761 val.extract::<PyTree<PyObject>>()
762 .map_err(SerializablePyErr::from_fn(py))?
763 .try_into_map(|obj| {
764 Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
765 if let Some(mesh) = device_meshes.get(&ref_) {
766 PyArg::PyObject(
767 Py::new(py, mesh.clone())
768 .map_err(SerializablePyErr::from_fn(py))?
769 .into(),
770 )
771 } else if let Some(pg) = remote_process_groups.get(&ref_) {
772 PyArg::PyObject(pg.clone_ref(py))
773 } else {
774 let pyobj = self.ref_to_pyobject(&ref_)?;
775 PyArg::PyObject(pyobj)
776 }
777 } else {
778 PyArg::PyObject(obj)
779 })
780 })
781 };
782
783 let py_args: Vec<PyTree<PyArg>> = args_tuple
785 .iter()
786 .map(&resolve)
787 .collect::<Result<_, CallFunctionError>>()?;
788
789 let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
790 .iter()
791 .map(|(k, v)| {
792 let key = k
793 .extract::<String>()
794 .map_err(SerializablePyErr::from_fn(py))?;
795 let value = resolve(v)?;
796 Ok((key, value))
797 })
798 .collect::<Result<_, CallFunctionError>>()?;
799
800 let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
803 let result: Bound<'_, PyAny> =
804 tracing::subscriber::with_default(scoped_subscriber, || {
805 let args = unsafe { py_args.try_to_object_unsafe(py) }
814 .map_err(SerializablePyErr::from_fn(py))?;
815 let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
817 .map_err(SerializablePyErr::from_fn(py))?;
818
819 if let Some(function) = function {
820 function
821 .call(args, Some(kwargs))
822 .map_err(SerializablePyErr::from_fn(py))
823 } else {
824 Ok(args.get_item(0).unwrap())
825 }
826 })?;
827 Ok(result)
828 }
829
830 fn call_python_fn_pytree(
831 &mut self,
832 cx: &hyperactor::Context<Self>,
833 function: ResolvableFunction,
834 args_kwargs: ArgsKwargs,
835 mutates: &[Ref],
836 device_meshes: HashMap<Ref, DeviceMesh>,
837 remote_process_groups: HashMap<
838 Ref,
839 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
840 >,
841 ) -> Result<PyTree<PyObject>, CallFunctionError> {
842 Python::with_gil(|py| {
843 let result = self.call_python_fn(
844 py,
845 cx,
846 Some(function),
847 args_kwargs,
848 mutates,
849 device_meshes,
850 remote_process_groups,
851 )?;
852 Ok(PyTree::<PyObject>::extract_bound(&result)
853 .map_err(SerializablePyErr::from_fn(py))?)
854 })
855 }
856 fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
864 let pyobject = self
865 .env
866 .get(&ref_)
867 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
868
869 match pyobject {
870 Ok(val) => Python::with_gil(|py| {
871 Self::pyobject_to_tensor(py, val)
872 .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
873 }),
874 Err(_) => {
875 let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
876 Ok(TensorCell::new(t))
877 }
878 }
879 }
880
881 fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
882 self.active_recording
883 .as_mut()
884 .and_then(|state| match state {
885 RecordingState::Defining {
886 recording,
887 defined_borrows,
888 } => {
889 match self.recordings.get_mut(recording) {
890 Some(recording) => Some((recording, defined_borrows)),
891 None => panic!("recording not found: {:?}", recording),
893 }
894 }
895 RecordingState::Running => None,
896 })
897 }
898
899 fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
900 for ref_ in refs {
901 let rvalue_or_err = self
902 .env
903 .get(ref_)
904 .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
905 if let Err(err) = rvalue_or_err {
906 return Ok(Some(err.clone()));
907 }
908 }
909 Ok(None)
910 }
911 async fn send_value_python_message(
912 &mut self,
913 cx: &hyperactor::Context<'_, Self>,
914 seq: Seq,
915 mutates: Vec<Ref>,
916 function: Option<ResolvableFunction>,
917 args_kwargs: ArgsKwargs,
918 device_meshes: HashMap<Ref, DeviceMesh>,
919 ) -> Result<()> {
920 let rank = self.rank;
921 self.try_define(cx, seq, vec![], &vec![], async |self_| {
922 let python_message =
923 Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
924 let python_result = tokio::task::block_in_place(|| {
925 self_.call_python_fn(
926 py,
927 cx,
928 function,
929 args_kwargs,
930 &mutates,
931 device_meshes,
932 HashMap::new(),
933 )
934 })?;
935 pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
936 })?;
937 let ser = wirevalue::Any::serialize(&python_message).unwrap();
938 self_
939 .controller_actor
940 .fetch_result(cx, seq, Ok(ser))
941 .await?;
942 Ok(vec![])
943 })
944 .await
945 }
946 fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
947 let rvalue = self
948 .env
949 .get(&src)
950 .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
951 self.env
952 .insert(dest, Python::with_gil(|_py| rvalue.clone()));
953 Ok(())
954 }
955 async fn call_actor(
956 &mut self,
957 cx: &Context<'_, Self>,
958 params: ActorCallParams,
959 ) -> Result<PyObject, CallFunctionError> {
960 let local_state: Result<Vec<PyObject>> = Python::with_gil(|_py| {
961 params
962 .local_state
963 .into_iter()
964 .map(|elem| {
965 let pyobj = self.ref_to_pyobject(&elem)?;
966 Ok(pyobj.into_any())
967 })
968 .collect()
969 });
970
971 let (send, recv) = cx.open_once_port();
972 let state = LocalState {
973 response_port: send,
974 state: local_state?,
975 };
976 let x: u64 = params.seq.into();
977 let message = LocalStateBrokerMessage::Set(x as usize, state);
978
979 let broker = BrokerId::new(params.broker_id).resolve(cx).await;
980 broker
981 .send(cx, message)
982 .map_err(|e| CallFunctionError::Error(e.into()))?;
983 let result = recv
984 .recv()
985 .await
986 .map_err(|e| CallFunctionError::Error(e.into()))?;
987
988 result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
989 }
990}
991
992#[async_trait]
993#[forward(StreamMessage)]
994impl StreamMessageHandler for StreamActor {
995 async fn call_function(
996 &mut self,
997 cx: &Context<Self>,
998 params: CallFunctionParams,
999 device_meshes: HashMap<Ref, DeviceMesh>,
1000 remote_process_groups: HashMap<
1001 Ref,
1002 (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1003 >,
1004 ) -> Result<()> {
1005 if let Some((recording, _)) = self.get_defining_recording() {
1006 recording.messages.push(StreamMessage::CallFunction(
1007 params,
1008 device_meshes,
1009 remote_process_groups,
1010 ));
1011 return Ok(());
1012 }
1013
1014 params.function.panic_if_requested();
1015 self.try_define(
1016 cx,
1017 params.seq,
1018 params.results,
1019 ¶ms.mutates,
1020 async |self| {
1021 tokio::task::block_in_place(|| {
1022 self.call_python_fn_pytree(
1023 cx,
1024 params.function,
1025 params.args_kwargs,
1026 ¶ms.mutates,
1027 device_meshes,
1028 remote_process_groups,
1029 )
1030 .map(|results| results.into_leaves())
1031 })
1032 },
1033 )
1034 .await?;
1035 Ok(())
1036 }
1037
1038 async fn borrow_create(
1039 &mut self,
1040 _cx: &Context<Self>,
1041 borrow: u64,
1042 tensor: Ref,
1043 first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1044 ) -> Result<()> {
1045 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1046 recording.messages.push(StreamMessage::BorrowCreate {
1047 borrow,
1048 tensor,
1049 first_use_sender,
1050 });
1051 ensure!(
1052 defined_borrows.insert(borrow),
1053 "duplicate borrow create in recording"
1054 );
1055 return Ok(());
1056 }
1057
1058 let pyobj_result = self
1059 .env
1060 .get(&tensor)
1061 .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1062
1063 let result = match pyobj_result {
1064 Ok(pyobj) => Python::with_gil(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
1065 Err(e) => Err(e.clone()),
1066 };
1067
1068 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1069 first_use_sender.send((event, result)).map_err(|err| {
1070 anyhow!(
1071 "failed sending first use event for borrow {:?}: {:?}",
1072 borrow,
1073 err
1074 )
1075 })
1076 }
1077
1078 async fn borrow_first_use(
1079 &mut self,
1080 _cx: &Context<Self>,
1081 borrow: u64,
1082 result: Ref,
1083 first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1084 ) -> Result<()> {
1085 if let Some((recording, _)) = self.get_defining_recording() {
1086 recording.messages.push(StreamMessage::BorrowFirstUse {
1087 borrow,
1088 result,
1089 first_use_receiver: first_use_receiver.clone(),
1090 });
1091 return Ok(());
1092 }
1093
1094 let (first_use_event, cell) =
1095 first_use_receiver
1096 .lock()
1097 .await
1098 .recv()
1099 .await
1100 .map_err(|err| {
1101 anyhow!(
1102 "failed receiving first use event for borrow {:?}: {:?}",
1103 borrow,
1104 err
1105 )
1106 })?;
1107
1108 if let Some(stream) = self.cuda_stream() {
1109 stream.wait_event(
1110 &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1111 );
1112 }
1113 match cell {
1114 Ok(cell) => {
1115 let pyobj = Self::tensor_to_pyobject(cell);
1116 self.env.insert(result, Ok(pyobj));
1117 }
1118 Err(err) => {
1119 self.env.insert(result, Err(err.clone()));
1120 }
1121 }
1122 Ok(())
1123 }
1124
1125 async fn borrow_last_use(
1126 &mut self,
1127 _cx: &Context<Self>,
1128 borrow: u64,
1129 result: Ref,
1130 last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1131 ) -> Result<()> {
1132 if let Some((recording, _)) = self.get_defining_recording() {
1133 recording.messages.push(StreamMessage::BorrowLastUse {
1134 borrow,
1135 result,
1136 last_use_sender,
1137 });
1138 return Ok(());
1139 }
1140
1141 let event = self.cuda_stream().map(|stream| stream.record_event(None));
1142 let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
1143 "Invalid reference for borrow_last_use: {result:#?}"
1144 ))?;
1145 let tensor = match pyobj_or_err {
1146 Ok(pyobj) => Ok(Python::with_gil(|py| {
1147 Self::pyobject_to_tensor(py, &pyobj).unwrap()
1148 })),
1149 Err(e) => Err(e),
1150 };
1151
1152 last_use_sender.send((event, tensor)).map_err(|err| {
1153 anyhow!(
1154 "failed sending last use event for borrow {:?}: {:?}",
1155 borrow,
1156 err
1157 )
1158 })
1159 }
1160
1161 async fn borrow_drop(
1162 &mut self,
1163 _cx: &Context<Self>,
1164 borrow: u64,
1165 last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1166 ) -> Result<()> {
1167 if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1168 recording.messages.push(StreamMessage::BorrowDrop {
1169 borrow,
1170 last_use_receiver: last_use_receiver.clone(),
1171 });
1172 ensure!(
1173 defined_borrows.remove(&borrow),
1174 "borrow drop for borrow not defined in recording"
1175 );
1176 return Ok(());
1177 }
1178
1179 let (last_use_event, _cell) =
1183 last_use_receiver.lock().await.recv().await.map_err(|err| {
1184 anyhow!(
1185 "failed receiving last use event for borrow {:?}: {:?}",
1186 borrow,
1187 err
1188 )
1189 })?;
1190
1191 if let Some(stream) = self.cuda_stream() {
1192 stream.wait_event(
1193 &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1194 );
1195 }
1196 Ok(())
1198 }
1199
1200 async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1201 if let Some((recording, _)) = self.get_defining_recording() {
1202 recording.messages.push(StreamMessage::DeleteRefs(refs));
1203 return Ok(());
1204 }
1205
1206 for ref_ in refs.iter() {
1207 self.env.remove(ref_);
1208 }
1209 Ok(())
1210 }
1211
1212 async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1213 if self.get_defining_recording().is_some() {
1214 bail!("request_status not allowed in recording");
1215 }
1216
1217 Ok(())
1218 }
1219
1220 async fn init_comm(
1221 &mut self,
1222 _cx: &Context<Self>,
1223 comm: ActorHandle<NcclCommActor>,
1224 ) -> Result<()> {
1225 if self.get_defining_recording().is_some() {
1226 bail!("init_comm not allowed in recording");
1227 }
1228
1229 self.comm = Some(comm);
1230 Ok(())
1231 }
1232
1233 async fn reduce(
1234 &mut self,
1235 cx: &Context<Self>,
1236 comm: Arc<ActorHandle<NcclCommActor>>,
1237 dim_size: i64,
1238 result: Ref,
1239 local_tensor: Ref,
1240 factory: Factory,
1241 reduction: Reduction,
1242 scatter: bool,
1243 in_place: bool,
1244 out: Option<Ref>,
1245 ) -> Result<()> {
1246 if let Some((recording, _)) = self.get_defining_recording() {
1247 recording.messages.push(StreamMessage::Reduce {
1248 comm,
1249 dim_size,
1250 result,
1251 local_tensor,
1252 factory,
1253 reduction,
1254 scatter,
1255 in_place,
1256 out,
1257 });
1258 return Ok(());
1259 }
1260
1261 let stream = self
1262 .cuda_stream()
1263 .expect("reductions not yet supported for non-CUDA workers")
1264 .clone();
1265 let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1266 let out_cell = out
1267 .map(|out| self.get_or_fake_on_err(out, &factory))
1268 .transpose()?;
1269 let output_cell = match reduction {
1270 Reduction::Stack => {
1271 if scatter {
1272 let output_cell = if in_place {
1273 input_cell.clone()
1274 } else {
1275 out_cell.unwrap_or({
1276 let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1277 let cloned = deep_clone(&borrow);
1278 TensorCell::new(cloned)
1279 })
1280 };
1281 comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1282 .await?;
1283 output_cell
1284 } else {
1285 ensure!(
1286 !in_place,
1287 "in-place, non-scatter not supported for stack reduce"
1288 );
1289
1290 let output_cell = out_cell.unwrap_or({
1291 let sizes = [&[dim_size][..], &factory.size[..]].concat();
1293 let output =
1294 factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1295 TensorCell::new(output)
1296 });
1297
1298 comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1299 .await?;
1300 output_cell
1301 }
1302 }
1303 Reduction::ReduceOp(op) => {
1304 if scatter {
1305 ensure!(!in_place, "in-place, scatter not supported for reduce");
1306
1307 let output_cell = out_cell.unwrap_or({
1308 let output = factory_empty(
1309 &factory.size[1..],
1310 factory.dtype,
1311 factory.layout,
1312 factory.device,
1313 );
1314 TensorCell::new(output)
1315 });
1316 comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1317 .await?;
1318 output_cell
1319 } else {
1320 let output_cell = if in_place {
1321 input_cell.clone()
1322 } else {
1323 out_cell.map_or(
1324 {
1325 let borrow =
1326 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1327 let cloned = deep_clone(&borrow);
1328 Ok(TensorCell::new(cloned))
1329 },
1330 |out_cell| -> Result<_, anyhow::Error> {
1331 let mut out_borrow =
1332 out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1333 let in_borrow =
1334 input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1335 out_borrow.copy_(&in_borrow);
1336 drop(out_borrow);
1337 Ok(out_cell)
1338 },
1339 )?
1340 };
1341
1342 comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1343 output_cell
1344 }
1345 }
1346 };
1347
1348 let pyobj = Self::tensor_to_pyobject(output_cell);
1349 self.env.insert(result, Ok(pyobj));
1350 Ok(())
1351 }
1352
1353 async fn send_tensor(
1354 &mut self,
1355 cx: &Context<Self>,
1356 result: Ref,
1357 from_rank: Option<usize>,
1358 to_rank: Option<usize>,
1359 tensor: Ref,
1360 factory: Factory,
1361 comm: Arc<ActorHandle<NcclCommActor>>,
1362 ) -> Result<()> {
1363 if let Some((recording, _)) = self.get_defining_recording() {
1364 recording.messages.push(StreamMessage::SendTensor {
1365 result,
1366 from_rank,
1367 to_rank,
1368 tensor,
1369 factory,
1370 comm,
1371 });
1372 return Ok(());
1373 }
1374
1375 if to_rank.is_none() && from_rank.is_none() {
1376 bail!("tried to send tensor without a to/from rank");
1377 }
1378
1379 if from_rank == to_rank {
1381 let input_cell: &std::result::Result<PyObject, Arc<SeqError>> =
1382 self.env
1383 .get(&tensor)
1384 .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1385 let output_cell: Result<PyObject, Arc<SeqError>> = match input_cell {
1386 Ok(pyobj) => {
1387 Python::with_gil(|py| -> Result<PyObject, Arc<SeqError>> {
1388 let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
1389 let borrow = input_tensor.try_borrow().unwrap();
1394 let cloned = deep_clone(&borrow);
1395 let cloned_cell = TensorCell::new(cloned);
1396 Ok(Self::tensor_to_pyobject(cloned_cell))
1397 })
1398 }
1399 Err(err) => Err(err.clone()),
1400 };
1401 self.env.insert(result, output_cell);
1402 return Ok(());
1403 }
1404
1405 let mut messages = Vec::new();
1406
1407 if let Some(to_rank) = to_rank {
1408 let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1409 messages.push(CommMessage::Send(
1410 input_cell,
1411 to_rank.try_into().unwrap(),
1412 self.cuda_stream()
1413 .expect("tried to send_tensor on non-cuda stream")
1414 .clone(),
1415 cx.open_once_port().0,
1416 ));
1417 }
1418
1419 if let Some(from_rank) = from_rank {
1420 let output_cell = TensorCell::new(factory_empty(
1421 &factory.size,
1422 factory.dtype,
1423 factory.layout,
1424 factory.device,
1425 ));
1426 messages.push(CommMessage::Recv(
1427 output_cell.clone(),
1428 from_rank.try_into().unwrap(),
1429 self.cuda_stream()
1430 .expect("tried to send_tensor on non-cuda stream")
1431 .clone(),
1432 cx.open_once_port().0,
1433 ));
1434 let pyobj = Self::tensor_to_pyobject(output_cell);
1435 self.env.insert(result, Ok(pyobj));
1436 }
1437
1438 comm.group(
1439 cx,
1440 messages,
1441 self.cuda_stream()
1442 .expect("tried to send_tensor on non-cuda stream")
1443 .clone(),
1444 )
1445 .await?;
1446 Ok(())
1447 }
1448
1449 async fn send_value(
1450 &mut self,
1451 cx: &Context<Self>,
1452 seq: Seq,
1453 worker_actor_id: ActorId,
1454 mutates: Vec<Ref>,
1455 function: Option<ResolvableFunction>,
1456 args_kwargs: ArgsKwargs,
1457 device_meshes: HashMap<Ref, DeviceMesh>,
1458 ) -> Result<()> {
1459 if self.respond_with_python_message {
1460 return self
1461 .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
1462 .await;
1463 }
1464
1465 let result = if let Some(function) = function {
1466 tokio::task::block_in_place(|| {
1468 self.call_python_fn_pytree(
1469 cx,
1470 function,
1471 args_kwargs,
1472 &mutates,
1473 device_meshes,
1474 HashMap::new(),
1475 )
1476 })
1477 } else {
1478 Python::with_gil(|py| {
1481 let (args, kwargs) = args_kwargs
1482 .to_python(py)
1483 .map_err(|e| CallFunctionError::Error(e.into()))?;
1484 match (args.len(), kwargs.len()) {
1485 (1, 0) => {
1486 let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
1487 arg.extract::<PyTree<PyObject>>()
1488 .map_err(SerializablePyErr::from_fn(py))?
1489 .try_into_map(|obj| {
1490 let bound_obj = obj.bind(py);
1491 if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1492 self.ref_to_pyobject(&ref_)
1493 } else {
1494 Ok(obj)
1495 }
1496 })
1497 }
1498 _ => Err(CallFunctionError::TooManyArgsForValue(
1499 format!("args with {} elements", args.len()),
1500 format!("kwargs with {} elements", kwargs.len()),
1501 )),
1502 }
1503 })
1504 };
1505
1506 let value = match result {
1507 Ok(pyobject) => Ok(pyobject),
1508 Err(err) => {
1509 let err = self.report_seq_error(cx, seq, err).await?;
1510 for ref_ in mutates {
1511 self.env.insert(ref_, Err(err.clone()));
1512 }
1513 Err(WorkerError {
1514 backtrace: format!("{:?}", err),
1515 worker_actor_id,
1516 })
1517 }
1518 };
1519
1520 let result = match value {
1524 Ok(_value) => {
1525 unreachable!(
1527 "send_value should return early when respond_with_python_message is true"
1528 )
1529 }
1530 Err(e) => Err(e),
1531 };
1532 self.controller_actor.fetch_result(cx, seq, result).await?;
1533
1534 Ok(())
1535 }
1536
1537 async fn send_result_of_actor_call(
1538 &mut self,
1539 cx: &Context<Self>,
1540 params: ActorCallParams,
1541 ) -> anyhow::Result<()> {
1542 let seq = params.seq;
1543 let mutates = params.mutates.clone();
1544 self.try_define(cx, seq, vec![], &mutates, async |self| {
1545 let value = self.call_actor(cx, params).await?;
1546 let result =
1547 Python::with_gil(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1548 let result = wirevalue::Any::serialize(&result).unwrap();
1549 self.controller_actor
1550 .fetch_result(cx, seq, Ok(result))
1551 .await?;
1552 Ok(vec![])
1553 })
1554 .await
1555 }
1556
1557 async fn call_actor_method(
1558 &mut self,
1559 cx: &Context<Self>,
1560 params: ActorMethodParams,
1561 ) -> anyhow::Result<()> {
1562 let seq = params.call.seq;
1563 let mutates = params.call.mutates.clone();
1564 self.try_define(cx, seq, params.results, &mutates, async |self| {
1565 let result = self.call_actor(cx, params.call).await?;
1566 let result = Python::with_gil(|py| {
1567 PyTree::<PyObject>::extract_bound(&result.into_bound(py))
1568 .map_err(SerializablePyErr::from_fn(py))
1569 })?;
1570 Ok(result.into_leaves())
1571 })
1572 .await
1573 }
1574
1575 async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1576 if self.active_recording.is_some() {
1577 bail!("different recording already active");
1578 }
1579 match self.recordings.entry(recording) {
1580 Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1581 Entry::Vacant(entry) => entry.insert(Recording::new()),
1582 };
1583 self.active_recording = Some(RecordingState::Defining {
1584 recording,
1585 defined_borrows: HashSet::new(),
1586 });
1587 Ok(())
1588 }
1589
1590 async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1591 match self.active_recording {
1592 Some(RecordingState::Defining {
1593 recording: active_recording,
1594 ref defined_borrows,
1595 }) if active_recording == recording => {
1596 ensure!(
1597 defined_borrows.is_empty(),
1598 "all borrows created within recording must be dropped within recording"
1599 );
1600 self.active_recording = None;
1601 }
1602 _ => bail!("cannot finalize recording that isn't active"),
1603 }
1604 Ok(())
1605 }
1606
1607 async fn recording_formal(
1608 &mut self,
1609 _cx: &Context<Self>,
1610 result: Ref,
1611 argument_index: usize,
1612 ) -> Result<()> {
1613 match self.get_defining_recording() {
1614 Some((recording, _)) => {
1615 recording.messages.push(StreamMessage::RecordingFormal {
1616 result,
1617 argument_index,
1618 });
1619 }
1620 None => bail!("recording_formal called outside of recording"),
1621 };
1622 Ok(())
1623 }
1624
1625 async fn recording_result(
1626 &mut self,
1627 _cx: &Context<Self>,
1628 result: Ref,
1629 output_index: usize,
1630 ) -> Result<()> {
1631 match self.get_defining_recording() {
1632 Some((recording, _)) => {
1633 recording.messages.push(StreamMessage::RecordingResult {
1634 result,
1635 output_index,
1636 });
1637 }
1638 None => bail!("recording_result called outside of recording"),
1639 };
1640 Ok(())
1641 }
1642
    /// Replays a previously defined recording with concrete arguments.
    ///
    /// Each formal declared with `argument_index == i` is aliased to
    /// `actuals[i]`, and each `results[i]` is bound to the recording result
    /// declared with `output_index == i`. Every recorded message is cloned via
    /// `clone_for_recording` and dispatched through `StreamMessageHandler`.
    /// Once an error has been observed, later `CallFunction` messages are not
    /// executed; the refs they define or mutate are bound to that error
    /// instead.
    ///
    /// After replay, all refs defined inside the recording are deleted. If an
    /// error occurred, it is propagated to every ref in `results` and to every
    /// ref outside the recording that the recording mutates.
    ///
    /// # Errors
    /// Fails if another recording is active, if `recording` is unknown, if the
    /// recording references more formals/results than were supplied, or if a
    /// replayed message handler or controller report fails.
    ///
    /// NOTE(review): on an early `?` return, `self.active_recording` is left as
    /// `Some(RecordingState::Running)`; confirm such errors are always fatal to
    /// the actor.
    async fn call_recording(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        if self.active_recording.is_some() {
            bail!("cannot call recording while another recording is active");
        }

        // Clone the messages up front so the recording can be replayed while
        // `self` is mutably borrowed by the handlers below.
        let messages = match self.recordings.get(&recording) {
            Some(recording) => recording
                .messages
                .iter()
                .map(|message| message.clone_for_recording())
                .collect::<Vec<_>>(),
            None => bail!("recording {:?} not found", recording),
        };

        self.active_recording = Some(RecordingState::Running);

        // First error observed during replay: either a pre-existing error on an
        // input checked below, or one raised while handling a message.
        let mut error: Option<Arc<SeqError>> = None;
        // Refs defined by messages inside the recording; deleted after replay.
        let mut all_defined_refs = HashSet::new();
        // Refs visible outside the recording (formals translated to actuals)
        // that the recording mutates; bound to `error` if replay fails.
        let mut all_mutated_refs = HashSet::new();
        // Formal ref -> caller-supplied actual ref.
        let mut formal_to_actual_refs = HashMap::new();
        // Clear any stale value so only errors raised during this replay are
        // picked up after each message below.
        self.last_seq_error = None;
        for message in messages.into_iter() {
            let defined_refs = message.get_defined_refs();
            all_defined_refs.extend(defined_refs.clone());

            // Mutated formals are translated to their actuals; refs defined
            // inside the recording are internal and are not tracked here.
            let mutated_refs_with_formals = message.get_mutated_refs();
            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
                match formal_to_actual_refs.get(ref_) {
                    Some(actual_ref) => Some(*actual_ref),
                    None => {
                        if all_defined_refs.contains(ref_) {
                            None
                        } else {
                            Some(*ref_)
                        }
                    }
                }
            }));

            match message {
                // Alias the formal to the caller's actual for this replay.
                StreamMessage::RecordingFormal {
                    result: formal_ref,
                    argument_index,
                } => match actuals.get(argument_index) {
                    None => bail!("recording_formal called with too few arguments"),
                    Some(actual_ref) => {
                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
                        self.define_ref(formal_ref, *actual_ref)?;
                    }
                },
                // Bind the caller's result ref to the value the recording produced.
                StreamMessage::RecordingResult {
                    result: result_ref,
                    output_index,
                } => match results.get(output_index) {
                    None => bail!("recording_result called with too few results"),
                    Some(actual_result_ref) => {
                        self.define_ref(*actual_result_ref, result_ref)?;
                    }
                },
                // These refs are deleted by the handler, so they must not be
                // deleted again in the final cleanup.
                StreamMessage::DeleteRefs(ref refs) => {
                    for ref_ in refs {
                        all_defined_refs.remove(ref_);
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                // After an error, skip the call and bind everything it would
                // define or mutate to the error.
                StreamMessage::CallFunction { .. } if error.is_some() => {
                    let error = error.clone().unwrap();
                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
                        self.env.insert(*ref_, Err(error.clone()));
                    }
                }
                // `result` is taken out of the final cleanup set before the
                // handler runs; presumably the handler disposes of it — TODO confirm.
                StreamMessage::BorrowLastUse { ref result, .. } => {
                    all_defined_refs.remove(result);
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                // Record any pre-existing error on the reduce inputs before
                // running the reduce.
                StreamMessage::Reduce {
                    local_tensor,
                    ref out,
                    ..
                } => {
                    if error.is_none() {
                        let inputs_to_check = [Some(local_tensor), out.clone()]
                            .iter()
                            .filter_map(|r| *r)
                            .collect::<Vec<_>>();
                        error = self.get_first_error(inputs_to_check.as_slice())?;
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                // Only the sending side (`to_rank` set) reads `tensor`, so only
                // then is it checked for a pre-existing error.
                StreamMessage::SendTensor {
                    ref tensor,
                    ref to_rank,
                    ..
                } => {
                    if to_rank.is_some() && error.is_none() {
                        error = self.get_first_error(&[*tensor])?;
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                _ => {
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
            };

            // If no error was known yet but handling this message set
            // `last_seq_error`, report the recording as failed to the
            // controller once and adopt that error for the rest of the replay.
            match (&error, self.last_seq_error.take()) {
                (None, Some(seq_err)) => {
                    self.controller_actor
                        .remote_function_failed(
                            cx,
                            seq,
                            WorkerError {
                                backtrace: format!("recording failed: {}", &seq_err),
                                worker_actor_id: cx.self_id().clone(),
                            },
                        )
                        .await?;
                    error = Some(seq_err)
                }
                _ => {}
            }
        }

        // Delete every ref that was defined inside the recording (including
        // formals) and was not already deleted during replay.
        StreamMessageHandler::handle(
            self,
            cx,
            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
        )
        .await?;

        // Propagate the error to the recording's outputs and to every external
        // ref it mutates.
        if error.is_some() {
            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
                self.env.insert(*ref_, Err(error.clone().unwrap()));
            }
        }

        self.active_recording = None;
        Ok(())
    }
1831
1832 async fn set_ref_unit_tests_only(
1833 &mut self,
1834 _cx: &Context<Self>,
1835 reference: Ref,
1836 value: WireValue,
1837 ) -> Result<()> {
1838 let pyobj =
1839 Python::with_gil(|py| -> PyResult<PyObject> { Ok(value.into_pyobject(py)?.unbind()) })?;
1840 self.env.insert(reference, Ok(pyobj));
1841 Ok(())
1842 }
1843
1844 async fn set_tensor_ref_unit_tests_only(
1845 &mut self,
1846 _cx: &Context<Self>,
1847 reference: Ref,
1848 tensor_result: TensorCellResult,
1849 ) -> Result<()> {
1850 match tensor_result {
1851 Ok(tensor_cell) => {
1852 let pyobj = Self::tensor_to_pyobject(tensor_cell);
1853 self.env.insert(reference, Ok(pyobj));
1854 }
1855 Err(err) => {
1856 self.env.insert(reference, Err(err));
1857 }
1858 }
1859 Ok(())
1860 }
1861
1862 async fn get_ref_unit_tests_only(
1863 &mut self,
1864 _cx: &Context<Self>,
1865 reference: Ref,
1866 ) -> Result<Option<Result<WireValue, String>>> {
1867 use pyo3::types::PyBool;
1868 use pyo3::types::PyFloat;
1869 use pyo3::types::PyInt;
1870 use pyo3::types::PyList;
1871 use pyo3::types::PyNone;
1872 use pyo3::types::PyString;
1873 fn pyobject_to_wire(
1875 value: Result<PyObject, Arc<SeqError>>,
1876 ) -> Result<WireValue, Arc<SeqError>> {
1877 let pyobj = value?;
1878 Python::with_gil(|py| {
1879 let bound = pyobj.bind(py);
1880 if bound.is_instance_of::<PyBool>() {
1882 Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1883 } else if bound.is_instance_of::<PyInt>() {
1884 Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1885 } else if bound.is_instance_of::<PyList>() {
1886 if let Ok(val) = bound.extract::<Vec<i64>>() {
1887 Ok(WireValue::IntList(val))
1888 } else {
1889 Ok(WireValue::String(format!(
1890 "unsupported list type: {:?}",
1891 bound
1892 )))
1893 }
1894 } else if bound.is_instance_of::<PyFloat>() {
1895 Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1896 } else if bound.is_instance_of::<PyString>() {
1897 Ok(WireValue::String(bound.extract::<String>().unwrap()))
1898 } else if bound.is_instance_of::<PyNone>() {
1899 Ok(WireValue::None(()))
1900 } else {
1901 Ok(WireValue::String(format!(
1902 "unsupported pyobject type: {:?}",
1903 bound
1904 )))
1905 }
1906 })
1907 }
1908 Ok(self.env.get(&reference).map(|pyobj| {
1909 pyobject_to_wire(Python::with_gil(|_py| pyobj.clone())).map_err(|err| err.to_string())
1910 }))
1911 }
1912
1913 async fn get_tensor_ref_unit_tests_only(
1914 &mut self,
1915 _cx: &Context<Self>,
1916 reference: Ref,
1917 ) -> Result<Option<TensorCellResult>> {
1918 match self.env.get(&reference) {
1919 Some(Ok(pyobj)) => Python::with_gil(|py| match Self::pyobject_to_tensor(py, pyobj) {
1920 Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1921 Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1922 }),
1923 Some(Err(err)) => Ok(Some(Err(err.clone()))),
1924 None => Ok(None),
1925 }
1926 }
1927}
1928
#[cfg(test)]
mod tests {
    // Tests for `StreamActor`'s recording API (define / finalize / call),
    // borrow bookkeeping inside recordings, and the rejection of operations
    // that are not allowed while a recording is being defined. Most tests
    // send an invalid sequence of messages and then wait for the stream actor
    // to fail with a specific message.
    use hyperactor::actor::ActorStatus;
    use hyperactor::context;
    use hyperactor::supervision::ActorSupervisionEvent;
    use monarch_messages::controller::ControllerMessage;
    use monarch_messages::worker::StreamCreationMode;
    use monarch_types::PickledPyObject;
    use pyo3::IntoPyObjectExt;
    use timed_test::async_timed_test;
    use torch_sys_cuda::nccl::UniqueId;
    use torch_sys2::factory_float_tensor;
    use torch_sys2::testing::allclose;

    use super::*;
    use crate::comm::CommParams;
    use crate::test_util;

    /// Builds an `Arc<SeqError>` for seq 0 that wraps `err`.
    #[allow(dead_code)]
    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
        Arc::new(SeqError {
            seq: 0.into(),
            error: err,
        })
    }

    /// Test harness: a local proc running one `StreamActor` (rank 0, default
    /// stream, CUDA device 0), a client instance used to send it messages,
    /// and an attached controller whose messages arrive on `controller_rx`.
    struct TestSetup {
        proc: Proc,
        stream_actor: ActorHandle<StreamActor>,
        client: Instance<()>,
        // Receiver for the proc's supervision events; held so the port stays open.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        #[allow(dead_code)]
        controller_rx: PortReceiver<ControllerMessage>,
        #[allow(dead_code)]
        controller_actor: ActorRef<ControllerActor>,
        // Next ref id handed out by `next_ref`.
        next_ref: Ref,
    }

    impl TestSetup {
        /// Harness with a world size of 1.
        async fn new() -> Result<Self> {
            Self::new_with_world_size(1).await
        }

        /// Harness whose stream actor reports `world_size`; the stream actor
        /// is always rank 0 and uses the non-python-message response path.
        async fn new_with_world_size(world_size: usize) -> Result<Self> {
            test_util::test_setup()?;

            let proc = Proc::local();
            let (_, controller_actor, controller_rx) =
                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
            let (client, _handle) = proc.instance("client")?;
            let (supervision_tx, supervision_rx) = client.open_port();
            proc.set_supervision_coordinator(supervision_tx)?;
            let stream_actor = proc.spawn(
                "stream",
                StreamActor::new(StreamParams {
                    world_size,
                    rank: 0,
                    creation_mode: StreamCreationMode::UseDefaultStream,
                    id: 0.into(),
                    device: Some(CudaDevice::new(0.into())),
                    controller_actor: controller_actor.clone(),
                    respond_with_python_message: false,
                }),
            )?;

            Ok(Self {
                proc,
                stream_actor,
                client,
                supervision_rx,
                controller_rx,
                controller_actor,
                next_ref: 0.into(),
            })
        }

        /// Returns the next unused `Ref` (ids count up from 0).
        fn next_ref(&mut self) -> Ref {
            let ref_ = self.next_ref;
            self.next_ref = Ref {
                id: self.next_ref.id + 1,
            };
            ref_
        }

        /// Binds `reference` on the stream actor to a CUDA float tensor
        /// holding `data`.
        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
            self.stream_actor
                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
                .await
        }

        /// Fetches `reference` (copied to CPU) and compares it with `data`.
        /// Panics if the request fails, the ref is missing, or it holds an error.
        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
            let actual = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            allclose(
                &factory_float_tensor(data, "cpu".parse().unwrap()),
                &actual.borrow(),
            )
            .unwrap()
        }

        /// Asserts that `reference` holds exactly `error` (same `Arc`).
        #[allow(dead_code)]
        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
            let result_error = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap_err();

            assert!(Arc::ptr_eq(&result_error, &error));
        }
    }

    /// Polls the proc's ledger until `actor_id` is `Failed`, then asserts the
    /// failure message contains `expected_msg`. If the actor never fails,
    /// this loops until the enclosing test's timeout.
    async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
        loop {
            let status = proc
                .ledger_snapshot()
                .roots
                .get(actor_id)
                .unwrap()
                .status
                .clone();
            if let ActorStatus::Failed(msg) = status {
                assert!(msg.to_string().contains(&expected_msg));
                break;
            } else {
                tokio::task::yield_now().await;
            }
        }
    }

    /// Asserts that none of `refs` is defined on the stream actor.
    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
        for ref_ in refs {
            assert!(
                test_setup
                    .stream_actor
                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
                    .await
                    .unwrap()
                    .is_none()
            );
        }
    }

    /// Sends a `send_value` with no function whose only positional argument
    /// is the pickled `reference`, so the stream resolves and reports it.
    #[allow(dead_code)]
    async fn fetch_result(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
    ) {
        let ref_to_send = Python::with_gil(|py| {
            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
        });

        stream_actor
            .send_value(
                cx,
                seq,
                stream_actor.actor_id().clone(),
                Vec::new(),
                None,
                ArgsKwargs::from_wire_values(
                    vec![WireValue::PyObject(ref_to_send)],
                    HashMap::new(),
                )
                .unwrap(),
                HashMap::new(),
            )
            .await
            .unwrap()
    }

    /// Fetches `reference` and asserts the controller receives a failed
    /// `FetchResult` for `seq` whose backtrace contains `expected_backtrace`.
    #[allow(dead_code)]
    async fn check_fetch_result_error(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
        controller_rx: &mut PortReceiver<ControllerMessage>,
        expected_backtrace: &str,
    ) {
        fetch_result(cx, stream_actor, seq, reference).await;

        let controller_msg = controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::FetchResult {
                seq: actual_seq,
                value: Err(err),
            } => {
                assert_eq!(actual_seq, seq);
                assert!(
                    err.backtrace.contains(expected_backtrace),
                    "backtrace did not contain {:?}: {:?}",
                    expected_backtrace,
                    err.backtrace
                );
            }
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };
    }

    /// Fetches `reference` and asserts the controller receives a successful
    /// `FetchResult` for `seq`.
    #[allow(dead_code)]
    async fn check_fetch_result_value(
        cx: &impl context::Actor,
        stream_actor: ActorHandle<StreamActor>,
        seq: Seq,
        reference: Ref,
        controller_rx: &mut PortReceiver<ControllerMessage>,
    ) {
        fetch_result(cx, stream_actor, seq, reference).await;

        let controller_msg = controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::FetchResult {
                value: Ok(_),
                seq: actual_seq,
            } => assert_eq!(seq, actual_seq),
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };
    }

    // Defining a second recording while one is active fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "different recording already active".into(),
        )
        .await;
        Ok(())
    }

    // Re-defining a finalized recording id fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_already_defined() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "already defined".into(),
        )
        .await;
        Ok(())
    }

    // Finalizing a recording other than the active one fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_finalize_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "cannot finalize recording that isn't active".into(),
        )
        .await;
        Ok(())
    }

    // `recording_formal` outside a recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "recording_formal called outside of recording".into(),
        )
        .await;
        Ok(())
    }

    // `recording_result` outside a recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "recording_result called outside of recording".into(),
        )
        .await;
        Ok(())
    }

    // Calling a recording while another is being defined fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "cannot call recording while another recording is active".into(),
        )
        .await;
        Ok(())
    }

    // Calling an undefined recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_not_found() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "not found".into(),
        )
        .await;
        Ok(())
    }

    // A recording with one formal called with zero actuals fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_too_few_arguments() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "recording_formal called with too few arguments".into(),
        )
        .await;
        Ok(())
    }

    // A recording with one result called with zero result refs fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_too_few_results() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "recording_result called with too few results".into(),
        )
        .await;
        Ok(())
    }

    // A recording that swaps its two arguments: formal0 <- actuals[1] and
    // formal1 <- actuals[0] are returned as results 0 and 1. Afterwards the
    // formals must have been cleaned up.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_basic_call_recording() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let formal0_ref = 1.into();
        let formal0_index = 1;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
            .await?;

        let formal1_ref = 2.into();
        let formal1_index = 0;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
            .await?;

        let result0_ref = formal0_ref;
        let result0_index = 0;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result0_ref, result0_index)
            .await?;

        let result1_ref = formal1_ref;
        let result1_index = 1;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result1_ref, result1_index)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        let actual0_ref = 3.into();
        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;

        let actual1_ref = 4.into();
        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;

        let actual_result0_ref = 5.into();
        let actual_result1_ref = 6.into();
        test_setup
            .stream_actor
            .call_recording(
                &test_setup.client,
                0.into(),
                0.into(),
                vec![actual_result0_ref, actual_result1_ref],
                vec![actual0_ref, actual1_ref],
            )
            .await?;

        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
        assert!(
            test_setup
                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
                .await
        );

        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
        Ok(())
    }

    // `request_status` is rejected while a recording is being defined.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_request_status_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .request_status(&test_setup.client)
            .await
            .expect_err("request_status should have failed");
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "request_status not allowed in recording".into(),
        )
        .await;
        Ok(())
    }

    // `init_comm` is rejected while a recording is being defined.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_init_comm_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let dummy_comm = test_setup.proc.spawn(
            "comm",
            NcclCommActor::new(CommParams::New {
                device: CudaDevice::new(0.into()),
                unique_id: UniqueId::new()?,
                world_size: 1,
                rank: 0,
            })
            .await
            .unwrap(),
        )?;

        test_setup
            .stream_actor
            .init_comm(&test_setup.client, dummy_comm)
            .await?;
        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "init_comm not allowed in recording".into(),
        )
        .await;
        Ok(())
    }

    // Creating the same borrow id twice inside a recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let tensor_ref = test_setup.next_ref();
        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();

        test_setup
            .stream_actor
            .borrow_create(
                &test_setup.client,
                borrow_id,
                tensor_ref,
                first_use_sender.clone(),
            )
            .await?;

        test_setup
            .stream_actor
            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
            .await?;

        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "duplicate borrow create in recording".into(),
        )
        .await;

        Ok(())
    }

    // Dropping a borrow that was never created in the recording fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();

        test_setup
            .stream_actor
            .borrow_drop(
                &test_setup.client,
                borrow_id,
                Arc::new(Mutex::new(last_use_receiver)),
            )
            .await?;

        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "borrow drop for borrow not defined in recording".into(),
        )
        .await;

        Ok(())
    }

    // Finalizing while a borrow created in the recording is still live fails the actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrow_id = 1;
        let tensor_ref = test_setup.next_ref();
        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();

        test_setup
            .stream_actor
            .borrow_create(
                &test_setup.client,
                borrow_id,
                tensor_ref,
                first_use_sender.clone(),
            )
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        assert_actor_failed_with_msg(
            &test_setup.proc,
            test_setup.stream_actor.actor_id(),
            "all borrows created within recording must be dropped within recording".into(),
        )
        .await;

        Ok(())
    }
}