// monarch_tensor_worker/stream.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::Context;
25use hyperactor::HandleClient;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::PortHandle;
29use hyperactor::actor::ActorHandle;
30use hyperactor::handle;
31use hyperactor::mailbox::OncePortHandle;
32use hyperactor::mailbox::PortReceiver;
33use hyperactor::proc::Proc;
34use hyperactor::reference;
35use monarch_hyperactor::actor::PythonMessage;
36use monarch_hyperactor::actor::PythonMessageKind;
37use monarch_hyperactor::local_state_broker::BrokerId;
38use monarch_hyperactor::local_state_broker::LocalState;
39use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
40use monarch_hyperactor::pickle::pickle;
41use monarch_messages::controller::ControllerMessageClient;
42use monarch_messages::controller::Seq;
43use monarch_messages::controller::WorkerError;
44use monarch_messages::worker::ActorCallParams;
45use monarch_messages::worker::ActorMethodParams;
46use monarch_messages::worker::ArgsKwargs;
47use monarch_messages::worker::CallFunctionError;
48use monarch_messages::worker::CallFunctionParams;
49use monarch_messages::worker::SeqError;
50use monarch_messages::worker::StreamRef;
51use monarch_types::PyTree;
52use monarch_types::SerializablePyErr;
53use monarch_types::TryIntoPyObjectUnsafe;
54use pyo3::prelude::*;
55use tokio::runtime::Handle;
56use tokio::sync::Mutex;
57use tokio::task::JoinHandle;
58use torch_sys_cuda::cuda::Event;
59use torch_sys_cuda::cuda::Stream;
60use torch_sys2::CloneUnsafe;
61use torch_sys2::CudaDevice;
62use torch_sys2::TensorCell;
63use torch_sys2::deep_clone;
64use torch_sys2::factory_empty;
65use torch_sys2::factory_zeros;
66use tracing_subscriber::fmt::Subscriber;
67use typeuri::Named;
68
69use crate::ControllerActor;
70use crate::DeviceMesh;
71use crate::Factory;
72use crate::Reduction;
73use crate::Ref;
74use crate::ResolvableFunction;
75use crate::StreamCreationMode;
76use crate::WireValue;
77use crate::comm::CommMessage;
78use crate::comm::CommMessageClient;
79use crate::comm::NcclCommActor;
80
/// Result of producing a tensor on a stream: the tensor cell on success, or
/// the shared error of the sequence that failed to produce it.
pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
82
// These thread locals are accessed by the python runtime for debugging sessions.
// They are populated in `StreamActor::init`, which runs on the dedicated
// stream OS thread that Python code also runs on.
thread_local! {
    /// Actor ref of the controller that created this stream.
    pub static CONTROLLER_ACTOR_REF: OnceCell<reference::ActorRef<ControllerActor>> = const { OnceCell::new() };
    /// The proc this stream actor runs in.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    /// Id of the root actor derived from this stream actor's own id.
    pub static ROOT_ACTOR_ID: OnceCell<reference::ActorId> = const { OnceCell::new() };
}
89
90fn pickle_python_result(
91    py: Python<'_>,
92    result: Bound<'_, PyAny>,
93    worker_rank: usize,
94) -> Result<PythonMessage, anyhow::Error> {
95    let mut state = pickle(py, result.unbind(), false, false)
96        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
97    let inner = state
98        .take_inner()
99        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
100    Ok(PythonMessage::new_from_buf(
101        PythonMessageKind::Result {
102            rank: Some(worker_rank),
103        },
104        inner.take_buffer(),
105    ))
106}
107
108#[derive(Debug)]
109struct Recording {
110    messages: Vec<StreamMessage>,
111}
112
113impl Recording {
114    fn new() -> Self {
115        Self {
116            messages: Vec::new(),
117        }
118    }
119}
120
/// Whether this stream is currently defining or replaying a recording.
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// A recording is being defined: messages are captured instead of run.
    Defining {
        /// Ref of the recording currently being defined.
        recording: Ref,
        // Set of borrow ids used to track proper borrow usage inside
        // a recording.
        defined_borrows: HashSet<u64>,
    },
    /// A previously defined recording is being replayed.
    Running,
}
131
/// Messages handled by the stream. Generally these are stream-local versions of
/// [`crate::WorkerMessage`].
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Call a function with the given params. The two maps carry the device
    /// meshes and remote process groups that `Ref`s in the arguments may
    /// resolve to.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Create a borrow of a tensor owned by this stream for use by another
    /// stream.
    BorrowCreate {
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Port for sending the first use CUDA event + borrowed tensor to
        /// the borrower.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// First use of a borrowed tensor on the borrowing stream.
    BorrowFirstUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for storing the borrowed tensor.
        result: Ref,
        /// Port for receiving the first use CUDA event + borrowed tensor from
        /// the provider stream.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Last use of a borrowed tensor on the borrowing stream.
    BorrowLastUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for the borrowed tensor.
        result: Ref,
        /// Port for sending the last use CUDA event and borrowed tensor.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Drop a borrow on the providing stream after its last use.
    BorrowDrop {
        /// Id for the borrow.
        borrow: u64,
        /// Port for receiving the last use CUDA event and borrowed tensor.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Remove the given refs from this stream's environment.
    DeleteRefs(Vec<Ref>),

    /// Reply with `()` once this message is processed, letting the sender
    /// synchronize with the stream's message queue.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Provide the NCCL communicator for this stream (stored lazily; see
    /// `StreamActor::comm`).
    InitComm(ActorHandle<NcclCommActor>),

    /// Perform a collective reduction over `local_tensor` using `comm`,
    /// storing the outcome under `result`.
    Reduce {
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        /// Optional pre-existing output tensor to reduce into.
        out: Option<Ref>,
    },

    /// Point-to-point tensor transfer over `comm`. `from_rank`/`to_rank`
    /// select this rank's role; `factory` describes the tensor so a
    /// placeholder can be created on upstream failure.
    SendTensor {
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// Evaluate `function` with `args_kwargs` (or pass a value through when
    /// `function` is `None`) and send the result back; presumably delivered
    /// toward the client identified via `worker_actor_id` — confirm with the
    /// handler implementation.
    SendValue {
        seq: Seq,
        worker_actor_id: reference::ActorId,
        /// Refs the call may modify; used for error propagation.
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    /// Begin defining the recording identified by `recording`.
    DefineRecording {
        recording: Ref,
    },

    /// Finish defining the recording identified by `recording`.
    FinalizeRecording {
        recording: Ref,
    },

    /// Replay a previously defined recording, binding `actuals` to its
    /// formals and `results` to its outputs.
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Placeholder inside a recording for the caller-supplied argument at
    /// `argument_index`, bound to `result` on replay.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Placeholder inside a recording marking `result` as the recording's
    /// output at `output_index`.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Directly set a ref to a wire value (unit tests only).
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Directly set a ref to a tensor result (unit tests only).
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Read back a ref as a wire value (unit tests only).
    GetRefUnitTestsOnly(
        Ref, // value
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Read back a ref as a tensor result (unit tests only).
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// Call an actor endpoint and store its result on this stream.
    SendResultOfActorCall(ActorCallParams),
    /// Call a method on an actor held on this stream.
    CallActorMethod(ActorMethodParams),
}
252
impl StreamMessage {
    /// Clone this message for storage inside a recording.
    ///
    /// Only the variants that may legally appear in a recording are handled;
    /// any other variant is a programming error and panics.
    fn clone_for_recording(&self) -> Self {
        match self {
            StreamMessage::RecordingFormal {
                result,
                argument_index,
            } => StreamMessage::RecordingFormal {
                result: *result,
                argument_index: *argument_index,
            },
            StreamMessage::RecordingResult {
                result,
                output_index,
            } => StreamMessage::RecordingResult {
                result: *result,
                output_index: *output_index,
            },
            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
                StreamMessage::CallFunction(
                    params.clone(),
                    device_meshes.clone(),
                    remote_process_groups.clone(),
                )
            }
            StreamMessage::BorrowCreate {
                borrow,
                tensor,
                first_use_sender,
            } => StreamMessage::BorrowCreate {
                borrow: *borrow,
                tensor: *tensor,
                first_use_sender: first_use_sender.clone(),
            },
            StreamMessage::BorrowFirstUse {
                borrow,
                result,
                first_use_receiver,
            } => StreamMessage::BorrowFirstUse {
                borrow: *borrow,
                result: *result,
                first_use_receiver: first_use_receiver.clone(),
            },
            StreamMessage::BorrowLastUse {
                borrow,
                result,
                last_use_sender,
            } => StreamMessage::BorrowLastUse {
                borrow: *borrow,
                result: *result,
                last_use_sender: last_use_sender.clone(),
            },
            StreamMessage::BorrowDrop {
                borrow,
                last_use_receiver,
            } => StreamMessage::BorrowDrop {
                borrow: *borrow,
                last_use_receiver: last_use_receiver.clone(),
            },
            StreamMessage::Reduce {
                comm,
                dim_size,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            } => StreamMessage::Reduce {
                comm: comm.clone(),
                dim_size: *dim_size,
                result: *result,
                local_tensor: *local_tensor,
                factory: factory.clone(),
                reduction: reduction.clone(),
                scatter: *scatter,
                in_place: *in_place,
                out: out.clone(),
            },
            StreamMessage::SendTensor {
                result,
                from_rank,
                to_rank,
                tensor,
                factory,
                comm,
            } => StreamMessage::SendTensor {
                result: *result,
                from_rank: *from_rank,
                to_rank: *to_rank,
                tensor: *tensor,
                factory: factory.clone(),
                comm: comm.clone(),
            },
            // The remaining variants (status, comm init, recording control,
            // unit-test helpers, actor calls) are not recordable.
            other => panic!(
                "StreamMessage variant not supported in recording: {:?}",
                other
            ),
        }
    }

    /// Get the set of refs that this message defines, i.e. writes into the
    /// stream's environment when executed.
    fn get_defined_refs(&self) -> HashSet<Ref> {
        match self {
            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
            StreamMessage::CallFunction(params, ..) => {
                params.results.iter().filter_map(|&ref_| ref_).collect()
            }
            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
            StreamMessage::SendTensor {
                result, from_rank, ..
            } => {
                // `result` is only defined when `from_rank` is set —
                // presumably this rank is the receiving side; confirm with
                // the SendTensor handler.
                if from_rank.is_some() {
                    HashSet::from([*result])
                } else {
                    HashSet::new()
                }
            }
            // TODO(slurye): Add SendValue eventually.
            _ => HashSet::new(),
        }
    }

    /// Get the set of refs that this message mutates in place (as opposed to
    /// freshly defining).
    fn get_mutated_refs(&self) -> HashSet<Ref> {
        match self {
            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
            StreamMessage::Reduce {
                out,
                in_place,
                local_tensor,
                ..
            } => {
                // An in-place reduce writes back into its input; otherwise an
                // explicit `out` tensor (if any) is the mutation target.
                if *in_place {
                    HashSet::from([*local_tensor])
                } else if let Some(out) = out {
                    HashSet::from([*out])
                } else {
                    HashSet::new()
                }
            }
            // TODO(slurye): Add SendValue eventually.
            _ => HashSet::new(),
        }
    }
}
401
/// A stream represents a linear sequence of execution. Operations on different
/// streams can execute concurrently.
///
/// For CUDA operators, streams will invoke the corresponding stream management
/// APIs to perform synchronization.
///
/// For CPU operators, streams will just execute synchronously on their own OS
/// thread.
#[derive(Debug)]
pub struct StreamActor {
    /// World size of the worker group; currently unused (hence the leading
    /// underscore), retained from `StreamParams`.
    _world_size: usize,
    /// This worker's rank within the world.
    rank: usize,
    /// Mapping of refs in the controller environment to TensorIndex in this
    /// stream's local environment.
    // TODO(agallagher): Use `ValueError` as the error type.
    env: HashMap<Ref, Result<Py<PyAny>, Arc<SeqError>>>,
    /// How to create the stream.
    creation_mode: StreamCreationMode,
    /// CUDA stream that this actor will enqueue operations on. None if "device"
    /// is not a CUDA device.
    /// NOTE: We lazily create the stream, so that we do it from the dedicated
    /// Stream OS thread as, otherwise, we see deadlocks when done from
    /// unexpected threads.
    cuda_stream: OnceLock<Option<Stream>>,
    /// Device this stream should be scheduled on.
    device: Option<CudaDevice>,
    /// Communicator for this stream. Optional as we lazily initialize it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Actor ref of the controller that created this stream.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Cached Python process-group objects keyed by ref; consulted (never
    /// populated) in `call_python_fn`.
    remote_process_groups: HashMap<Ref, Py<PyAny>>,
    /// Recordings defined on this stream, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// Recording currently being defined or replayed, if any.
    active_recording: Option<RecordingState>,
    /// When true, results are presumably reported back as `PythonMessage`s
    /// (see `pickle_python_result`) — confirm with the message handlers.
    respond_with_python_message: bool,
    /// Most recent error recorded by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
438
/// Parameters for creating a [`Stream`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// Total number of workers in the group.
    pub world_size: usize,
    /// This worker's rank within the world.
    pub rank: usize,
    /// Controls how the underlying CUDA stream is created.
    pub creation_mode: StreamCreationMode,
    /// Id of this stream in the worker actor's stream table.
    pub id: StreamRef,
    /// Device this stream should be scheduled on. If none, don't do stream
    /// synchronization.
    pub device: Option<CudaDevice>,
    /// Actor ref of the controller that created this stream.
    pub controller_actor: reference::ActorRef<ControllerActor>,
    /// Whether results should be sent back as `PythonMessage`s.
    pub respond_with_python_message: bool,
}
455
456impl StreamActor {
457    pub fn new(
458        StreamParams {
459            world_size,
460            rank,
461            id: _,
462            device,
463            controller_actor,
464            creation_mode,
465            respond_with_python_message,
466        }: StreamParams,
467    ) -> Self {
468        Self {
469            _world_size: world_size,
470            rank,
471            env: HashMap::new(),
472            creation_mode,
473            cuda_stream: OnceLock::new(),
474            device,
475            comm: None,
476            controller_actor,
477            remote_process_groups: HashMap::new(),
478            recordings: HashMap::new(),
479            active_recording: None,
480            respond_with_python_message,
481            last_seq_error: None,
482        }
483    }
484}
485
486#[async_trait]
487impl Actor for StreamActor {
488    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
489        // These thread locals are exposed via python functions, so we need to set them in the
490        // same thread that python will run in. That means we need to initialize them here in
491        // StreamActor::init instead of in StreamActor::new.
492        CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
493            controller_actor_ref.set(self.controller_actor.clone()).ok()
494        });
495        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
496        ROOT_ACTOR_ID.with(|root_actor_id| {
497            root_actor_id
498                .set(reference::ActorId::root(
499                    cx.self_id().proc_id().clone(),
500                    cx.self_id().name().to_string(),
501                ))
502                .ok()
503        });
504        // Set the current stream for this actor thread.
505        if let Some(stream) = self.cuda_stream() {
506            Stream::set_current_stream(stream);
507        }
508        Ok(())
509    }
510
511    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
512    /// dedicated OS thread. This is because:
513    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
514    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
515    ///     threads is problematic.
516    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
517    where
518        F: Future + Send + 'static,
519        F::Output: Send + 'static,
520    {
521        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
522        // It is important that we spawn a standalone thread for the work here,
523        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
524        // This is because the worker stream may call uninterruptible FFI code
525        // that can deadlock (CUDA, NCCL).
526        // If we use a tokio-managed blocking thread, then runtime teardown will
527        // try to wait for tasks on that thread to reach an await point, and
528        // hang forever.
529        let builder = std::thread::Builder::new().name("worker-stream".to_string());
530        let _thread_handle = builder.spawn(move || {
531            // Spawn a new thread with a single-threaded tokio runtime to run the
532            // actor loop.  We avoid the current-threaded runtime, so that we can
533            // use `block_in_place` for nested async-to-sync-to-async flows.
534            let rt = tokio::runtime::Builder::new_multi_thread()
535                .worker_threads(1)
536                .enable_all()
537                .build()
538                .unwrap();
539            let result = rt.block_on(async {
540                tokio::task::block_in_place(|| {
541                    // Allow e.g. destructing py objects on this thread, which
542                    // can happen at shutdown when the a stream actors env map
543                    // for rvalues is dropped (e.g. P1673311499).
544                    // https://github.com/PyO3/pyo3/discussions/3499
545                    Python::attach(|py| {
546                        py.detach(|| {
547                            let result = Handle::current().block_on(future);
548                            if join_tx.send(result).is_err() {
549                                panic!("could not send join result")
550                            }
551                        })
552                    })
553                })
554            });
555            rt.shutdown_timeout(Duration::from_weeks(1));
556            result
557        });
558
559        // In order to bridge the synchronous join handle with the async world,
560        // smuggle the result through a channel.
561        tokio::spawn(async move { join_rx.await.unwrap() })
562    }
563}
564
/// The arguments we accept as inputs to Python function calls.
#[derive(Debug)]
enum PyArg {
    /// An owned handle to an arbitrary Python object.
    Object(Py<PyAny>),
}

/// Serialize into a `Py<PyAny>`.
impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
    // SAFETY contract is inherited from `TryIntoPyObjectUnsafe`. For the
    // `Object` variant we only clone a refcounted Python handle, so no
    // additional invariants appear to be required here — confirm against the
    // trait's documented contract.
    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        match self {
            // `clone_ref` bumps the Python refcount; no data is copied.
            PyArg::Object(obj) => Ok(obj.clone_ref(py).into_bound(py)),
        }
    }
}
579
580impl StreamActor {
    /// Convert a `TensorCell` into an owned Python object wrapping the
    /// tensor.
    fn tensor_to_pyobject(tensor_cell: TensorCell) -> Py<PyAny> {
        Python::attach(|py| {
            // SAFETY: Cloning a tensor was unsafe because we were tracking their references like
            // Rust objects (single mutable reference or many immutable references). We are
            // removing this functionality in upcoming patches, so we use the unsafe version here
            // until that happens.
            let tensor = unsafe {
                // Get the owned tensor by calling clone_unsafe on the reference
                tensor_cell.get_unchecked().clone_unsafe()
            };
            // The conversion to a PyObject is treated as infallible; the
            // unwrap would only fire on an internal pyo3/torch binding error.
            tensor.into_pyobject(py).unwrap().unbind()
        })
    }
594
595    /// Extract a TensorCell from a Py<PyAny>.
596    /// SAFETY: Uses new to create the TensorCell. Caller must ensure the Py<PyAny>
597    /// contains a valid tensor.
598    fn pyobject_to_tensor(py: Python<'_>, pyobj: &Py<PyAny>) -> PyResult<TensorCell> {
599        use torch_sys2::Tensor;
600        let tensor = pyobj.bind(py).extract::<Tensor>()?;
601        // Create a new TensorCell from the extracted tensor
602        Ok(TensorCell::new(tensor))
603    }
604
605    fn cuda_stream(&self) -> Option<&Stream> {
606        self.cuda_stream
607            .get_or_init(|| {
608                self.device.map(|device| match self.creation_mode {
609                    StreamCreationMode::UseDefaultStream => {
610                        Stream::get_current_stream_on_device(device)
611                    }
612                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
613                })
614            })
615            .as_ref()
616    }
617
618    fn ref_to_pyobject(&self, ref_: &Ref) -> Result<Py<PyAny>, CallFunctionError> {
619        let pyobject = self
620            .env
621            .get(ref_)
622            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
623        match pyobject {
624            Ok(val) => Ok(val.clone()),
625            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
626        }
627    }
628
    /// Convert a `CallFunctionError` for `seq` into a shared `SeqError`.
    ///
    /// A `DependentError` is unwrapped to its root cause and not re-reported
    /// to the controller. A fresh `Error` is forwarded to the controller as a
    /// `WorkerError` (skipped while a recording is active), then wrapped and
    /// remembered in `last_seq_error`.
    async fn report_seq_error(
        &mut self,
        cx: &Context<'_, Self>,
        seq: Seq,
        error: CallFunctionError,
    ) -> Result<Arc<SeqError>, anyhow::Error> {
        match error {
            CallFunctionError::DependentError(root) => Ok(root),
            CallFunctionError::Error(e) => {
                // Only notify the controller when we're not inside a
                // recording definition/replay.
                if self.active_recording.is_none() {
                    let worker_error = WorkerError {
                        backtrace: format!("{e}"),
                        worker_actor_id: cx.self_id().clone(),
                    };
                    tracing::info!("Propagating remote function error to client: {worker_error}");
                    self.controller_actor
                        .remote_function_failed(cx, seq, worker_error)
                        .await?
                }
                let err = Arc::new(SeqError { seq, error: e });
                self.last_seq_error = Some(err.clone());
                Ok(err)
            }
        }
    }
654
655    async fn try_define<F>(
656        &mut self,
657        cx: &Context<'_, Self>,
658        seq: Seq,
659        result_refs: Vec<Option<Ref>>,
660        mutates: &Vec<Ref>,
661        f: F,
662    ) -> Result<()>
663    where
664        F: AsyncFnOnce(&mut Self) -> Result<Vec<Py<PyAny>>, CallFunctionError>,
665    {
666        let actual_results = f(self).await;
667        // Check if the expected number of returns is correct, otherwise convert
668        // into an error.
669        let op_results = actual_results.and_then(|actual_results| {
670            if result_refs.len() == actual_results.len() {
671                Ok(actual_results
672                    .into_iter()
673                    .zip(result_refs.iter())
674                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
675                    .collect::<Vec<(Ref, Py<PyAny>)>>())
676            } else {
677                Err(CallFunctionError::UnexpectedNumberOfReturns(
678                    result_refs.len(),
679                    actual_results.len(),
680                ))
681            }
682        });
683
684        // Propagate the results (either the actual values or an error) to the
685        // right entries in the global env mapping.
686        match op_results {
687            Ok(op_results) => {
688                for (ref_, pyobject) in op_results.into_iter() {
689                    let prev = self.env.insert(ref_, Ok(pyobject));
690                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
691                }
692            }
693            Err(err) => {
694                let err = self.report_seq_error(cx, seq, err).await?;
695                for ref_ in result_refs {
696                    match ref_ {
697                        Some(ref_) => {
698                            let prev = self.env.insert(ref_, Err(err.clone()));
699                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
700                        }
701                        None => {}
702                    }
703                }
704                for ref_ in mutates {
705                    self.env.insert(*ref_, Err(err.clone()));
706                }
707            }
708        }
709        Ok(())
710    }
711
    /// Resolve `function` (if any) and call it with `args_kwargs`, after
    /// replacing every `Ref` found in the (py-tree) args/kwargs with its live
    /// Python object: a device mesh, a cached remote process group, or a
    /// value from this stream's `env`.
    ///
    /// When `function` is `None`, the first positional argument is returned
    /// unchanged (value pass-through). Output from the call is routed to
    /// stdout via a scoped tracing subscriber.
    fn call_python_fn<'py>(
        &mut self,
        py: Python<'py>,
        _cx: &Context<Self>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        _mutates: &[Ref],
        device_meshes: HashMap<Ref, DeviceMesh>,
        remote_process_groups: HashMap<
            Ref,
            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
        >,
    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
        let (args_tuple, kwargs_dict) = args_kwargs
            .to_python(py)
            .map_err(|e| CallFunctionError::Error(e.into()))?;
        let function = function
            .map(|function| {
                function.resolve(py).map_err(|e| {
                    CallFunctionError::InvalidRemoteFunction(format!(
                        "failed to resolve function {}: {}",
                        function,
                        SerializablePyErr::from(py, &e)
                    ))
                })
            })
            .transpose()?;

        // Map each requested process group ref to its cached Python object.
        // Creating a group here is no longer implemented, so an uncached ref
        // is a hard error.
        let remote_process_groups = remote_process_groups
            .into_iter()
            .map(|(gref, (_mesh, _dims, _comm))| {
                let group = match self.remote_process_groups.entry(gref) {
                    Entry::Occupied(ent) => ent.get().clone_ref(py),
                    Entry::Vacant(_ent) => {
                        panic!("no longer implemented");
                    }
                };
                PyResult::Ok((gref, group))
            })
            .collect::<Result<HashMap<_, _>, _>>()
            .map_err(SerializablePyErr::from_fn(py))?;

        // Walk a py-tree and substitute any `Ref` leaves with their resolved
        // Python objects; non-ref leaves pass through untouched.
        let resolve = |val: Bound<'py, PyAny>| {
            val.extract::<PyTree<Py<PyAny>>>()
                .map_err(SerializablePyErr::from_fn(py))?
                .try_into_map(|obj| {
                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
                        if let Some(mesh) = device_meshes.get(&ref_) {
                            PyArg::Object(
                                Py::new(py, mesh.clone())
                                    .map_err(SerializablePyErr::from_fn(py))?
                                    .into(),
                            )
                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
                            PyArg::Object(pg.clone_ref(py))
                        } else {
                            let pyobj = self.ref_to_pyobject(&ref_)?;
                            PyArg::Object(pyobj)
                        }
                    } else {
                        PyArg::Object(obj)
                    })
                })
        };

        // Resolve args and kwargs
        let py_args: Vec<PyTree<PyArg>> = args_tuple
            .iter()
            .map(&resolve)
            .collect::<Result<_, CallFunctionError>>()?;

        let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
            .iter()
            .map(|(k, v)| {
                let key = k
                    .extract::<String>()
                    .map_err(SerializablePyErr::from_fn(py))?;
                let value = resolve(v)?;
                Ok((key, value))
            })
            .collect::<Result<_, CallFunctionError>>()?;

        // Call function.
        // Use custom subscriber to route Worker messages to stdout.
        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
        let result: Bound<'_, PyAny> =
            tracing::subscriber::with_default(scoped_subscriber, || {
                // TODO(agallagher): The args/kwargs conversion traits generate
                // the appropriate types here, but they get casted to `PyAny`.
                // It'd be nice to make `TryToPy<PyAny>Unsafe` take a template
                // arg for the converted py object to avoid this downcast.
                // SAFETY: Tensor operations were unsafe because we were tracking their references
                // like Rust objects (single mutable reference or many immutable references). We are
                // removing this functionality in upcoming patches, so we use the unsafe version here
                // until that happens.
                let args = unsafe { py_args.try_to_object_unsafe(py) }
                    .map_err(SerializablePyErr::from_fn(py))?;
                // SAFETY: Same as above - reference tracking functionality is being removed.
                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
                    .map_err(SerializablePyErr::from_fn(py))?;

                if let Some(function) = function {
                    function
                        .call(args, Some(kwargs))
                        .map_err(SerializablePyErr::from_fn(py))
                } else {
                    // No function: pass the first positional argument through.
                    Ok(args.get_item(0).unwrap())
                }
            })?;
        Ok(result)
    }
823
824    fn call_python_fn_pytree(
825        &mut self,
826        cx: &hyperactor::Context<Self>,
827        function: ResolvableFunction,
828        args_kwargs: ArgsKwargs,
829        mutates: &[Ref],
830        device_meshes: HashMap<Ref, DeviceMesh>,
831        remote_process_groups: HashMap<
832            Ref,
833            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
834        >,
835    ) -> Result<PyTree<Py<PyAny>>, CallFunctionError> {
836        Python::attach(|py| {
837            let result = self.call_python_fn(
838                py,
839                cx,
840                Some(function),
841                args_kwargs,
842                mutates,
843                device_meshes,
844                remote_process_groups,
845            )?;
846            Ok(PyTree::<Py<PyAny>>::extract_bound(&result)
847                .map_err(SerializablePyErr::from_fn(py))?)
848        })
849    }
850    /// Retrieve `ref_` or create a fake value with the provided factory if it
851    /// is an error. We use this for collective calls, where even if there was
852    /// an upstream failure, we still have participate in the collective to
853    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
854    /// value here of the correct shape; the controller will have been notified
855    /// of the upstream failure and will know to ignore everything dependent on
856    /// it.
857    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
858        let pyobject = self
859            .env
860            .get(&ref_)
861            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
862
863        match pyobject {
864            Ok(val) => Python::attach(|py| {
865                Self::pyobject_to_tensor(py, val)
866                    .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
867            }),
868            Err(_) => {
869                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
870                Ok(TensorCell::new(t))
871            }
872        }
873    }
874
    /// If this stream is currently *defining* a recording, return mutable
    /// access to that recording plus the set of borrow ids created inside it.
    /// Returns `None` when no recording is active, or when a recording is
    /// merely *running* (replaying).
    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
        self.active_recording
            .as_mut()
            .and_then(|state| match state {
                RecordingState::Defining {
                    recording,
                    defined_borrows,
                } => {
                    // The active recording's entry must exist in `recordings`;
                    // it was inserted when the recording was defined.
                    match self.recordings.get_mut(recording) {
                        Some(recording) => Some((recording, defined_borrows)),
                        // Panic, because this would be a logic error in the program.
                        None => panic!("recording not found: {:?}", recording),
                    }
                }
                RecordingState::Running => None,
            })
    }
892
893    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
894        for ref_ in refs {
895            let rvalue_or_err = self
896                .env
897                .get(ref_)
898                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
899            if let Err(err) = rvalue_or_err {
900                return Ok(Some(err.clone()));
901            }
902        }
903        Ok(None)
904    }
    /// Resolve a value (optionally by calling `function` on `args_kwargs`) and
    /// send it back to the controller as a pickled `PythonMessage` through
    /// `fetch_result`.
    ///
    /// Runs inside `try_define` so that a failure is recorded as a seq error
    /// and propagated to every ref in `mutates`.
    async fn send_value_python_message(
        &mut self,
        cx: &hyperactor::Context<'_, Self>,
        seq: Seq,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        let rank = self.rank;
        self.try_define(cx, seq, vec![], &vec![], async |self_| {
            let python_message =
                Python::attach(|py| -> Result<PythonMessage, CallFunctionError> {
                    // block_in_place: the Python call may run arbitrary user
                    // code, so don't starve the async executor thread.
                    let python_result = tokio::task::block_in_place(|| {
                        self_.call_python_fn(
                            py,
                            cx,
                            function,
                            args_kwargs,
                            &mutates,
                            device_meshes,
                            HashMap::new(),
                        )
                    })?;
                    pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
                })?;
            // NOTE(review): a serialization failure here panics the worker —
            // confirm a PythonMessage is always serializable.
            let ser = wirevalue::Any::serialize(&python_message).unwrap();
            self_
                .controller_actor
                .fetch_result(cx, seq, Ok(ser))
                .await?;
            Ok(vec![])
        })
        .await
    }
940    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
941        let rvalue = self
942            .env
943            .get(&src)
944            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
945        self.env.insert(dest, Python::attach(|_py| rvalue.clone()));
946        Ok(())
947    }
    /// Dispatch a call to a Python actor: resolve the refs in
    /// `params.local_state` to Python objects, hand them to the broker keyed
    /// by this call's seq, and await the actor's reply on a one-shot port.
    async fn call_actor(
        &mut self,
        cx: &Context<'_, Self>,
        params: ActorCallParams,
    ) -> Result<Py<PyAny>, CallFunctionError> {
        // Resolve each local-state ref to its stored Python object.
        let local_state: Result<Vec<Py<PyAny>>> = Python::attach(|_py| {
            params
                .local_state
                .into_iter()
                .map(|elem| {
                    let pyobj = self.ref_to_pyobject(&elem)?;
                    Ok(pyobj.into_any())
                })
                .collect()
        });

        // One-shot reply port for the actor's response.
        let (send, recv) = cx.open_once_port();
        let state = LocalState {
            response_port: send,
            state: local_state?,
        };
        // Key the broker entry by seq so the actor can locate this call.
        let x: u64 = params.seq.into();
        let message = LocalStateBrokerMessage::Set(x as usize, state);

        let broker = BrokerId::new(params.broker_id).resolve(cx).await;
        broker
            .send(cx, message)
            .map_err(|e| CallFunctionError::Error(e.into()))?;
        let result = recv
            .recv()
            .await
            .map_err(|e| CallFunctionError::Error(e.into()))?;

        // NOTE(review): this flattens the Python error into its string form,
        // dropping the traceback — confirm that's acceptable to callers.
        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
    }
983}
984
985#[async_trait]
986#[handle(StreamMessage)]
987impl StreamMessageHandler for StreamActor {
    /// Execute a `CallFunctionParams` request: run the Python function and
    /// bind its flattened results to `params.results`. While defining a
    /// recording, the message is captured for later replay instead of run.
    async fn call_function(
        &mut self,
        cx: &Context<Self>,
        params: CallFunctionParams,
        device_meshes: HashMap<Ref, DeviceMesh>,
        remote_process_groups: HashMap<
            Ref,
            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
        >,
    ) -> Result<()> {
        if let Some((recording, _)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::CallFunction(
                params,
                device_meshes,
                remote_process_groups,
            ));
            return Ok(());
        }

        // Deliberately panic if this function is the special panic marker
        // (presumably a failure-injection hook — see `panic_if_requested`).
        params.function.panic_if_requested();
        self.try_define(
            cx,
            params.seq,
            params.results,
            &params.mutates,
            async |self| {
                // block_in_place: user Python code may run for a long time.
                tokio::task::block_in_place(|| {
                    self.call_python_fn_pytree(
                        cx,
                        params.function,
                        params.args_kwargs,
                        &params.mutates,
                        device_meshes,
                        remote_process_groups,
                    )
                    .map(|results| results.into_leaves())
                })
            },
        )
        .await?;
        Ok(())
    }
1030
    /// Lender side of a borrow: look up `tensor` and send it (plus a CUDA
    /// event marking this stream's position, when on CUDA) to the borrowing
    /// stream via `first_use_sender`. In a recording, the message is captured
    /// and the borrow id tracked so we can check it is dropped before the
    /// recording is finalized.
    async fn borrow_create(
        &mut self,
        cx: &Context<Self>,
        borrow: u64,
        tensor: Ref,
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    ) -> Result<()> {
        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::BorrowCreate {
                borrow,
                tensor,
                first_use_sender,
            });
            ensure!(
                defined_borrows.insert(borrow),
                "duplicate borrow create in recording"
            );
            return Ok(());
        }

        let pyobj_result = self
            .env
            .get(&tensor)
            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;

        // Forward either the tensor or the stored upstream error; the
        // borrower will propagate whichever it receives.
        let result = match pyobj_result {
            Ok(pyobj) => Python::attach(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
            Err(e) => Err(e.clone()),
        };

        // Record this stream's current position so the borrower can order
        // itself after our pending work on the tensor.
        let event = self.cuda_stream().map(|stream| stream.record_event(None));
        first_use_sender.send(cx, (event, result)).map_err(|err| {
            anyhow!(
                "failed sending first use event for borrow {:?}: {:?}",
                borrow,
                err
            )
        })
    }
1070
    /// Borrower side of a borrow: receive the tensor (and, on CUDA, the
    /// lender's event) and bind it under `result` in this stream's
    /// environment.
    async fn borrow_first_use(
        &mut self,
        _cx: &Context<Self>,
        borrow: u64,
        result: Ref,
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    ) -> Result<()> {
        if let Some((recording, _)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::BorrowFirstUse {
                borrow,
                result,
                first_use_receiver: first_use_receiver.clone(),
            });
            return Ok(());
        }

        let (first_use_event, cell) =
            first_use_receiver
                .lock()
                .await
                .recv()
                .await
                .map_err(|err| {
                    anyhow!(
                        "failed receiving first use event for borrow {:?}: {:?}",
                        borrow,
                        err
                    )
                })?;

        // Order this stream after the lender's event so its pending work on
        // the tensor completes before any kernel here reads it.
        if let Some(stream) = self.cuda_stream() {
            stream.wait_event(
                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
            );
        }
        // Bind either the borrowed tensor or the lender's stored error.
        match cell {
            Ok(cell) => {
                let pyobj = Self::tensor_to_pyobject(cell);
                self.env.insert(result, Ok(pyobj));
            }
            Err(err) => {
                self.env.insert(result, Err(err.clone()));
            }
        }
        Ok(())
    }
1117
    /// End this stream's use of the borrow: remove `result` from the
    /// environment and send the tensor (plus a CUDA event, when on CUDA) back
    /// toward the lender so it can synchronize before reclaiming the tensor.
    async fn borrow_last_use(
        &mut self,
        cx: &Context<Self>,
        borrow: u64,
        result: Ref,
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    ) -> Result<()> {
        if let Some((recording, _)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::BorrowLastUse {
                borrow,
                result,
                last_use_sender,
            });
            return Ok(());
        }

        // Mark where this stream's work on the borrowed tensor ends.
        let event = self.cuda_stream().map(|stream| stream.record_event(None));
        let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
            "Invalid reference for borrow_last_use: {result:#?}"
        ))?;
        let tensor = match pyobj_or_err {
            Ok(pyobj) => Ok(Python::attach(|py| {
                Self::pyobject_to_tensor(py, &pyobj).unwrap()
            })),
            Err(e) => Err(e),
        };

        last_use_sender.send(cx, (event, tensor)).map_err(|err| {
            anyhow!(
                "failed sending last use event for borrow {:?}: {:?}",
                borrow,
                err
            )
        })
    }
1153
    /// Lender side of ending a borrow: wait for the borrower's last-use event
    /// (synchronizing CUDA streams) and then let the tensor cell drop. In a
    /// recording, the drop is captured and checked against the set of borrows
    /// defined in that recording.
    async fn borrow_drop(
        &mut self,
        _cx: &Context<Self>,
        borrow: u64,
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    ) -> Result<()> {
        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::BorrowDrop {
                borrow,
                last_use_receiver: last_use_receiver.clone(),
            });
            ensure!(
                defined_borrows.remove(&borrow),
                "borrow drop for borrow not defined in recording"
            );
            return Ok(());
        }

        // The borrowed cell isn't used directly, but we still want to receive it here
        // so that the underlying tensor isn't dropped until after we synchronize the
        // CUDA streams.
        let (last_use_event, _cell) =
            last_use_receiver.lock().await.recv().await.map_err(|err| {
                anyhow!(
                    "failed receiving last use event for borrow {:?}: {:?}",
                    borrow,
                    err
                )
            })?;

        // Order this stream after the borrower's final use of the tensor.
        if let Some(stream) = self.cuda_stream() {
            stream.wait_event(
                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
            );
        }
        // let the cell drop.
        Ok(())
    }
1192
1193    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1194        if let Some((recording, _)) = self.get_defining_recording() {
1195            recording.messages.push(StreamMessage::DeleteRefs(refs));
1196            return Ok(());
1197        }
1198
1199        for ref_ in refs.iter() {
1200            self.env.remove(ref_);
1201        }
1202        Ok(())
1203    }
1204
1205    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1206        if self.get_defining_recording().is_some() {
1207            bail!("request_status not allowed in recording");
1208        }
1209
1210        Ok(())
1211    }
1212
1213    async fn init_comm(
1214        &mut self,
1215        _cx: &Context<Self>,
1216        comm: ActorHandle<NcclCommActor>,
1217    ) -> Result<()> {
1218        if self.get_defining_recording().is_some() {
1219            bail!("init_comm not allowed in recording");
1220        }
1221
1222        self.comm = Some(comm);
1223        Ok(())
1224    }
1225
1226    async fn reduce(
1227        &mut self,
1228        cx: &Context<Self>,
1229        comm: Arc<ActorHandle<NcclCommActor>>,
1230        dim_size: i64,
1231        result: Ref,
1232        local_tensor: Ref,
1233        factory: Factory,
1234        reduction: Reduction,
1235        scatter: bool,
1236        in_place: bool,
1237        out: Option<Ref>,
1238    ) -> Result<()> {
1239        if let Some((recording, _)) = self.get_defining_recording() {
1240            recording.messages.push(StreamMessage::Reduce {
1241                comm,
1242                dim_size,
1243                result,
1244                local_tensor,
1245                factory,
1246                reduction,
1247                scatter,
1248                in_place,
1249                out,
1250            });
1251            return Ok(());
1252        }
1253
1254        let stream = self
1255            .cuda_stream()
1256            .expect("reductions not yet supported for non-CUDA workers")
1257            .clone();
1258        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1259        let out_cell = out
1260            .map(|out| self.get_or_fake_on_err(out, &factory))
1261            .transpose()?;
1262        let output_cell = match reduction {
1263            Reduction::Stack => {
1264                if scatter {
1265                    let output_cell = if in_place {
1266                        input_cell.clone()
1267                    } else {
1268                        out_cell.unwrap_or({
1269                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1270                            let cloned = deep_clone(&borrow);
1271                            TensorCell::new(cloned)
1272                        })
1273                    };
1274                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1275                        .await?;
1276                    output_cell
1277                } else {
1278                    ensure!(
1279                        !in_place,
1280                        "in-place, non-scatter not supported for stack reduce"
1281                    );
1282
1283                    let output_cell = out_cell.unwrap_or({
1284                        // In Python, this would be [dim_size, *factory.sizes]
1285                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1286                        let output =
1287                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1288                        TensorCell::new(output)
1289                    });
1290
1291                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1292                        .await?;
1293                    output_cell
1294                }
1295            }
1296            Reduction::ReduceOp(op) => {
1297                if scatter {
1298                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1299
1300                    let output_cell = out_cell.unwrap_or({
1301                        let output = factory_empty(
1302                            &factory.size[1..],
1303                            factory.dtype,
1304                            factory.layout,
1305                            factory.device,
1306                        );
1307                        TensorCell::new(output)
1308                    });
1309                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1310                        .await?;
1311                    output_cell
1312                } else {
1313                    let output_cell = if in_place {
1314                        input_cell.clone()
1315                    } else {
1316                        out_cell.map_or(
1317                            {
1318                                let borrow =
1319                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1320                                let cloned = deep_clone(&borrow);
1321                                Ok(TensorCell::new(cloned))
1322                            },
1323                            |out_cell| -> Result<_, anyhow::Error> {
1324                                let mut out_borrow =
1325                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1326                                let in_borrow =
1327                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1328                                out_borrow.copy_(&in_borrow);
1329                                drop(out_borrow);
1330                                Ok(out_cell)
1331                            },
1332                        )?
1333                    };
1334
1335                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1336                    output_cell
1337                }
1338            }
1339        };
1340
1341        let pyobj = Self::tensor_to_pyobject(output_cell);
1342        self.env.insert(result, Ok(pyobj));
1343        Ok(())
1344    }
1345
    /// Point-to-point tensor transfer. `to_rank` set means this rank sends
    /// `tensor`; `from_rank` set means this rank receives into a fresh buffer
    /// bound under `result`; both set (and equal) means a local copy. In a
    /// recording, the message is captured for replay instead.
    async fn send_tensor(
        &mut self,
        cx: &Context<Self>,
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    ) -> Result<()> {
        if let Some((recording, _)) = self.get_defining_recording() {
            recording.messages.push(StreamMessage::SendTensor {
                result,
                from_rank,
                to_rank,
                tensor,
                factory,
                comm,
            });
            return Ok(());
        }

        if to_rank.is_none() && from_rank.is_none() {
            bail!("tried to send tensor without a to/from rank");
        }

        // Value is local, so we do not have to actually send it.
        if from_rank == to_rank {
            let input_cell: &std::result::Result<Py<PyAny>, Arc<SeqError>> = self
                .env
                .get(&tensor)
                .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
            let output_cell: Result<Py<PyAny>, Arc<SeqError>> = match input_cell {
                Ok(pyobj) => {
                    Python::attach(|py| -> Result<Py<PyAny>, Arc<SeqError>> {
                        let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
                        // We create a defensive copy here to prevent mutations on
                        // the input tensor from affecting output tensor.
                        // Should we copy if input ref == output ref?
                        // Should we support copy-on-write to avoid unnecessary copy?
                        let borrow = input_tensor.try_borrow().unwrap();
                        let cloned = deep_clone(&borrow);
                        let cloned_cell = TensorCell::new(cloned);
                        Ok(Self::tensor_to_pyobject(cloned_cell))
                    })
                }
                // Propagate the upstream error to the result ref.
                Err(err) => Err(err.clone()),
            };
            self.env.insert(result, output_cell);
            return Ok(());
        }

        // Batch send/recv into one grouped NCCL call so a rank that both sends
        // and receives can't deadlock against its peer.
        let mut messages = Vec::new();

        if let Some(to_rank) = to_rank {
            // On an upstream error, a fake zero tensor still participates so
            // the peer's recv completes (see `get_or_fake_on_err`).
            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
            messages.push(CommMessage::Send(
                input_cell,
                to_rank.try_into().unwrap(),
                self.cuda_stream()
                    .expect("tried to send_tensor on non-cuda stream")
                    .clone(),
                cx.open_once_port().0,
            ));
        }

        if let Some(from_rank) = from_rank {
            // Receive into a freshly allocated buffer and bind it immediately.
            let output_cell = TensorCell::new(factory_empty(
                &factory.size,
                factory.dtype,
                factory.layout,
                factory.device,
            ));
            messages.push(CommMessage::Recv(
                output_cell.clone(),
                from_rank.try_into().unwrap(),
                self.cuda_stream()
                    .expect("tried to send_tensor on non-cuda stream")
                    .clone(),
                cx.open_once_port().0,
            ));
            let pyobj = Self::tensor_to_pyobject(output_cell);
            self.env.insert(result, Ok(pyobj));
        }

        comm.group(
            cx,
            messages,
            self.cuda_stream()
                .expect("tried to send_tensor on non-cuda stream")
                .clone(),
        )
        .await?;
        Ok(())
    }
1441
    /// Resolve a value — either by calling `function`, or by extracting the
    /// single positional argument — and report it to the controller for `seq`.
    /// When `respond_with_python_message` is set, delegates entirely to
    /// `send_value_python_message`. On failure, the error is recorded against
    /// every ref in `mutates` and reported as a `WorkerError`.
    async fn send_value(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: reference::ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        if self.respond_with_python_message {
            return self
                .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
                .await;
        }

        let result = if let Some(function) = function {
            // If a function was provided, use that to resolve the value.
            tokio::task::block_in_place(|| {
                self.call_python_fn_pytree(
                    cx,
                    function,
                    args_kwargs,
                    &mutates,
                    device_meshes,
                    HashMap::new(),
                )
            })
        } else {
            // If there's no function provided, there should be exactly one arg
            // and no kwargs.
            Python::attach(|py| {
                let (args, kwargs) = args_kwargs
                    .to_python(py)
                    .map_err(|e| CallFunctionError::Error(e.into()))?;
                match (args.len(), kwargs.len()) {
                    (1, 0) => {
                        let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
                        // Resolve any embedded Refs to their stored values;
                        // non-Ref leaves pass through unchanged.
                        arg.extract::<PyTree<Py<PyAny>>>()
                            .map_err(SerializablePyErr::from_fn(py))?
                            .try_into_map(|obj| {
                                let bound_obj = obj.bind(py);
                                if let Ok(ref_) = Ref::from_py_object(bound_obj) {
                                    self.ref_to_pyobject(&ref_)
                                } else {
                                    Ok(obj)
                                }
                            })
                    }
                    _ => Err(CallFunctionError::TooManyArgsForValue(
                        format!("args with {} elements", args.len()),
                        format!("kwargs with {} elements", kwargs.len()),
                    )),
                }
            })
        };

        // On error: record a seq error, poison the mutated refs, and convert
        // to a WorkerError for the controller.
        let value = match result {
            Ok(pyobject) => Ok(pyobject),
            Err(err) => {
                let err = self.report_seq_error(cx, seq, err).await?;
                for ref_ in mutates {
                    self.env.insert(ref_, Err(err.clone()));
                }
                Err(WorkerError {
                    backtrace: format!("{:?}", err),
                    worker_actor_id,
                })
            }
        };

        // NOTE(review): this assumes `respond_with_python_message` is always
        // true (so the early return above always fires and this code never
        // runs). But this path executes exactly when the flag is FALSE, in
        // which case a successfully computed value hits `unreachable!` and
        // panics the worker — confirm the flag is truly always set, or
        // serialize the Ok value here instead.
        let result = match value {
            Ok(_value) => {
                unreachable!(
                    "send_value should return early when respond_with_python_message is true"
                )
            }
            Err(e) => Err(e),
        };
        self.controller_actor.fetch_result(cx, seq, result).await?;

        Ok(())
    }
1529
    /// Call a Python actor (see `call_actor`) and forward its pickled result
    /// straight to the controller via `fetch_result`. Defines no result refs;
    /// failures are propagated to the refs in `params.mutates` via
    /// `try_define`.
    async fn send_result_of_actor_call(
        &mut self,
        cx: &Context<Self>,
        params: ActorCallParams,
    ) -> anyhow::Result<()> {
        let seq = params.seq;
        let mutates = params.mutates.clone();
        self.try_define(cx, seq, vec![], &mutates, async |self| {
            let value = self.call_actor(cx, params).await?;
            let result =
                Python::attach(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
            // NOTE(review): serialization failure panics the worker — confirm
            // the pickled message is always serializable.
            let result = wirevalue::Any::serialize(&result).unwrap();
            self.controller_actor
                .fetch_result(cx, seq, Ok(result))
                .await?;
            Ok(vec![])
        })
        .await
    }
1549
    /// Call a Python actor method (see `call_actor`) and bind its flattened
    /// results to `params.results`. Failures are propagated to the refs in
    /// `params.call.mutates` via `try_define`.
    async fn call_actor_method(
        &mut self,
        cx: &Context<Self>,
        params: ActorMethodParams,
    ) -> anyhow::Result<()> {
        let seq = params.call.seq;
        let mutates = params.call.mutates.clone();
        self.try_define(cx, seq, params.results, &mutates, async |self| {
            let result = self.call_actor(cx, params.call).await?;
            // Unpack the Python result into a tree so each leaf can be bound
            // to its own result ref.
            let result = Python::attach(|py| {
                PyTree::<Py<PyAny>>::extract_bound(&result.into_bound(py))
                    .map_err(SerializablePyErr::from_fn(py))
            })?;
            Ok(result.into_leaves())
        })
        .await
    }
1567
1568    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1569        if self.active_recording.is_some() {
1570            bail!("different recording already active");
1571        }
1572        match self.recordings.entry(recording) {
1573            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1574            Entry::Vacant(entry) => entry.insert(Recording::new()),
1575        };
1576        self.active_recording = Some(RecordingState::Defining {
1577            recording,
1578            defined_borrows: HashSet::new(),
1579        });
1580        Ok(())
1581    }
1582
    /// Close out the recording currently being defined. Fails if `recording`
    /// isn't the one being defined, or if any borrow created inside the
    /// recording was not dropped before finalization.
    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
        match self.active_recording {
            Some(RecordingState::Defining {
                recording: active_recording,
                ref defined_borrows,
            }) if active_recording == recording => {
                ensure!(
                    defined_borrows.is_empty(),
                    "all borrows created within recording must be dropped within recording"
                );
                self.active_recording = None;
            }
            _ => bail!("cannot finalize recording that isn't active"),
        }
        Ok(())
    }
1599
1600    async fn recording_formal(
1601        &mut self,
1602        _cx: &Context<Self>,
1603        result: Ref,
1604        argument_index: usize,
1605    ) -> Result<()> {
1606        match self.get_defining_recording() {
1607            Some((recording, _)) => {
1608                recording.messages.push(StreamMessage::RecordingFormal {
1609                    result,
1610                    argument_index,
1611                });
1612            }
1613            None => bail!("recording_formal called outside of recording"),
1614        };
1615        Ok(())
1616    }
1617
1618    async fn recording_result(
1619        &mut self,
1620        _cx: &Context<Self>,
1621        result: Ref,
1622        output_index: usize,
1623    ) -> Result<()> {
1624        match self.get_defining_recording() {
1625            Some((recording, _)) => {
1626                recording.messages.push(StreamMessage::RecordingResult {
1627                    result,
1628                    output_index,
1629                });
1630            }
1631            None => bail!("recording_result called outside of recording"),
1632        };
1633        Ok(())
1634    }
1635
1636    async fn call_recording(
1637        &mut self,
1638        cx: &Context<Self>,
1639        seq: Seq,
1640        recording: Ref,
1641        results: Vec<Ref>,
1642        actuals: Vec<Ref>,
1643    ) -> Result<()> {
1644        if self.active_recording.is_some() {
1645            bail!("cannot call recording while another recording is active");
1646        }
1647
1648        let messages = match self.recordings.get(&recording) {
1649            Some(recording) => recording
1650                .messages
1651                .iter()
1652                .map(|message| message.clone_for_recording())
1653                .collect::<Vec<_>>(),
1654            None => bail!("recording {:?} not found", recording),
1655        };
1656
1657        self.active_recording = Some(RecordingState::Running);
1658
1659        // Global error for all messages in the recording. The first time a message
1660        // fails in the recording, we set the error. We then need to propagate this
1661        // error to all of the refs mutated by the entire recording, as well as the
1662        // result refs.
1663        let mut error: Option<Arc<SeqError>> = None;
1664        // The set of all refs defined by this recording (excluding "results"),
1665        // which we need to ensure are deleted when the recording is done executing.
1666        let mut all_defined_refs = HashSet::new();
1667        // The set of all refs mutated by this recording. If there is an error with
1668        // any message, all of these refs need to have the correct error set.
1669        let mut all_mutated_refs = HashSet::new();
1670        // Map from the result ref of a RecordingFormal message to the associated
1671        // actual ref from "actuals". We need to track this in order to properly
1672        // handle recordings that mutate refs contained in "actuals" -- every
1673        // message in the recording that interacts with the recording inputs will
1674        // interact with the formal ref rather than the actual ref.
1675        let mut formal_to_actual_refs = HashMap::new();
1676        // clear any pre-existing error messages before recording started
1677        self.last_seq_error = None;
1678        for message in messages.into_iter() {
1679            let defined_refs = message.get_defined_refs();
1680            all_defined_refs.extend(defined_refs.clone());
1681
1682            let mutated_refs_with_formals = message.get_mutated_refs();
1683            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1684                match formal_to_actual_refs.get(ref_) {
1685                    Some(actual_ref) => Some(*actual_ref),
1686                    None => {
1687                        if all_defined_refs.contains(ref_) {
1688                            None
1689                        } else {
1690                            Some(*ref_)
1691                        }
1692                    }
1693                }
1694            }));
1695
1696            match message {
1697                StreamMessage::RecordingFormal {
1698                    result: formal_ref,
1699                    argument_index,
1700                } => match actuals.get(argument_index) {
1701                    None => bail!("recording_formal called with too few arguments"),
1702                    Some(actual_ref) => {
1703                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
1704                        self.define_ref(formal_ref, *actual_ref)?;
1705                    }
1706                },
1707                StreamMessage::RecordingResult {
1708                    result: result_ref,
1709                    output_index,
1710                } => match results.get(output_index) {
1711                    None => bail!("recording_result called with too few results"),
1712                    Some(actual_result_ref) => {
1713                        self.define_ref(*actual_result_ref, result_ref)?;
1714                    }
1715                },
1716                StreamMessage::DeleteRefs(ref refs) => {
1717                    for ref_ in refs {
1718                        all_defined_refs.remove(ref_);
1719                    }
1720                    StreamMessageHandler::handle(self, cx, message).await?;
1721                }
1722                StreamMessage::CallFunction { .. } if error.is_some() => {
1723                    // CallFunction is expensive. If the recording already failed, then
1724                    // just update the necessary refs with the error. Most of the other
1725                    // message types need to run regardless because there are other actors
1726                    // that expect the call to happen (e.g., all of the borrow messages,
1727                    // pipe send/recv, send_tensor, reduce, etc.).
1728                    let error = error.clone().unwrap();
1729                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1730                        self.env.insert(*ref_, Err(error.clone()));
1731                    }
1732                }
1733                StreamMessage::BorrowLastUse { ref result, .. } => {
1734                    all_defined_refs.remove(result);
1735                    StreamMessageHandler::handle(self, cx, message).await?;
1736                }
1737                StreamMessage::Reduce {
1738                    local_tensor,
1739                    ref out,
1740                    ..
1741                } => {
1742                    // Reduce doesn't propagate errors to the result ref, so we need
1743                    // to check for existing errors on the input tensors and set the
1744                    // recording's error if necessary.
1745                    if error.is_none() {
1746                        let inputs_to_check = [Some(local_tensor), out.clone()]
1747                            .iter()
1748                            .filter_map(|r| *r)
1749                            .collect::<Vec<_>>();
1750                        error = self.get_first_error(inputs_to_check.as_slice())?;
1751                    }
1752                    StreamMessageHandler::handle(self, cx, message).await?;
1753                }
1754                StreamMessage::SendTensor {
1755                    ref tensor,
1756                    ref to_rank,
1757                    ..
1758                } => {
1759                    // If this rank is sending a tensor (e.g., to_rank has a value),
1760                    // we need to check for existing errors on the input tensor, because
1761                    // the error is only propagated to the result ref when this rank
1762                    // is also receiving a tensor.
1763                    if to_rank.is_some() && error.is_none() {
1764                        error = self.get_first_error(&[*tensor])?;
1765                    }
1766                    StreamMessageHandler::handle(self, cx, message).await?;
1767                }
1768                _ => {
1769                    StreamMessageHandler::handle(self, cx, message).await?;
1770                }
1771            };
1772
1773            // It's not entirely trivial to determine whether a message "failed" or not.
1774            // For example, the CallFunction message can return Ok(..) if there is an error
1775            // in the underlying function call. But in that case, we would still want to
1776            // consider the recording call as "failed". Unlike in python, where we can just
1777            // wrap everything in try-except, in rust, we keep track of the last report SeqError, which
1778            // we clear before handling each recording message. If we see it is set, the
1779            // we know the recording has faild.
1780            match (&error, self.last_seq_error.take()) {
1781                (None, Some(seq_err)) => {
1782                    // Report failure to the controller.
1783                    self.controller_actor
1784                        .remote_function_failed(
1785                            cx,
1786                            seq,
1787                            WorkerError {
1788                                backtrace: format!("recording failed: {}", &seq_err),
1789                                worker_actor_id: cx.self_id().clone(),
1790                            },
1791                        )
1792                        .await?;
1793                    error = Some(seq_err)
1794                }
1795                _ => {}
1796            }
1797            // Continue processing the remaining stream messages regardless of error.
1798            // We need to do this partially for error propagation, but also because
1799            // certain messages (like borrows and reductions) need to run regardless
1800            // in order to prevent deadlocks.
1801        }
1802
1803        // Delete the formal refs and some subset of the RecordingResult refs. The
1804        // controller should have generated DeleteRefs messages for all other refs
1805        // defined by the recording.
1806        StreamMessageHandler::handle(
1807            self,
1808            cx,
1809            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1810        )
1811        .await?;
1812
1813        // Any refs mutated by the recording and all results should have the same error
1814        // (the original error that caused the recording to fail).
1815        if error.is_some() {
1816            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1817                self.env.insert(*ref_, Err(error.clone().unwrap()));
1818            }
1819        }
1820
1821        self.active_recording = None;
1822        Ok(())
1823    }
1824
1825    async fn set_ref_unit_tests_only(
1826        &mut self,
1827        _cx: &Context<Self>,
1828        reference: Ref,
1829        value: WireValue,
1830    ) -> Result<()> {
1831        let pyobj =
1832            Python::attach(|py| -> PyResult<Py<PyAny>> { Ok(value.into_pyobject(py)?.unbind()) })?;
1833        self.env.insert(reference, Ok(pyobj));
1834        Ok(())
1835    }
1836
1837    async fn set_tensor_ref_unit_tests_only(
1838        &mut self,
1839        _cx: &Context<Self>,
1840        reference: Ref,
1841        tensor_result: TensorCellResult,
1842    ) -> Result<()> {
1843        match tensor_result {
1844            Ok(tensor_cell) => {
1845                let pyobj = Self::tensor_to_pyobject(tensor_cell);
1846                self.env.insert(reference, Ok(pyobj));
1847            }
1848            Err(err) => {
1849                self.env.insert(reference, Err(err));
1850            }
1851        }
1852        Ok(())
1853    }
1854
1855    async fn get_ref_unit_tests_only(
1856        &mut self,
1857        _cx: &Context<Self>,
1858        reference: Ref,
1859    ) -> Result<Option<Result<WireValue, String>>> {
1860        use pyo3::types::PyBool;
1861        use pyo3::types::PyFloat;
1862        use pyo3::types::PyInt;
1863        use pyo3::types::PyList;
1864        use pyo3::types::PyNone;
1865        use pyo3::types::PyString;
1866        /// For testing only, doesn't support Tensor or TensorList.
1867        fn pyobject_to_wire(
1868            value: Result<Py<PyAny>, Arc<SeqError>>,
1869        ) -> Result<WireValue, Arc<SeqError>> {
1870            let pyobj = value?;
1871            Python::attach(|py| {
1872                let bound = pyobj.bind(py);
1873                // Check bool before int since Python's bool is a subclass of int
1874                if bound.is_instance_of::<PyBool>() {
1875                    Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1876                } else if bound.is_instance_of::<PyInt>() {
1877                    Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1878                } else if bound.is_instance_of::<PyList>() {
1879                    if let Ok(val) = bound.extract::<Vec<i64>>() {
1880                        Ok(WireValue::IntList(val))
1881                    } else {
1882                        Ok(WireValue::String(format!(
1883                            "unsupported list type: {:?}",
1884                            bound
1885                        )))
1886                    }
1887                } else if bound.is_instance_of::<PyFloat>() {
1888                    Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1889                } else if bound.is_instance_of::<PyString>() {
1890                    Ok(WireValue::String(bound.extract::<String>().unwrap()))
1891                } else if bound.is_instance_of::<PyNone>() {
1892                    Ok(WireValue::None(()))
1893                } else {
1894                    Ok(WireValue::String(format!(
1895                        "unsupported pyobject type: {:?}",
1896                        bound
1897                    )))
1898                }
1899            })
1900        }
1901        Ok(self.env.get(&reference).map(|pyobj| {
1902            pyobject_to_wire(Python::attach(|_py| pyobj.clone())).map_err(|err| err.to_string())
1903        }))
1904    }
1905
1906    async fn get_tensor_ref_unit_tests_only(
1907        &mut self,
1908        _cx: &Context<Self>,
1909        reference: Ref,
1910    ) -> Result<Option<TensorCellResult>> {
1911        match self.env.get(&reference) {
1912            Some(Ok(pyobj)) => Python::attach(|py| match Self::pyobject_to_tensor(py, pyobj) {
1913                Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1914                Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1915            }),
1916            Some(Err(err)) => Ok(Some(Err(err.clone()))),
1917            None => Ok(None),
1918        }
1919    }
1920}
1921
1922#[cfg(test)]
1923mod tests {
1924    use hyperactor::actor::ActorStatus;
1925    use hyperactor::context;
1926    use hyperactor::supervision::ActorSupervisionEvent;
1927    use monarch_messages::controller::ControllerMessage;
1928    use monarch_messages::worker::StreamCreationMode;
1929    use monarch_types::PickledPyObject;
1930    use pyo3::IntoPyObjectExt;
1931    use timed_test::async_timed_test;
1932    use tokio::sync::watch;
1933    use torch_sys_cuda::nccl::UniqueId;
1934    use torch_sys2::factory_float_tensor;
1935    use torch_sys2::testing::allclose;
1936
1937    use super::*;
1938    use crate::comm::CommParams;
1939    use crate::test_util;
1940
1941    #[allow(dead_code)]
1942    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
1943        Arc::new(SeqError {
1944            seq: 0.into(),
1945            error: err,
1946        })
1947    }
1948
    /// Shared fixture for stream-actor tests: a local proc hosting a
    /// `StreamActor`, a client instance to message it with, and receivers for
    /// supervision and controller traffic.
    struct TestSetup {
        proc: Proc,
        stream_actor: ActorHandle<StreamActor>,
        client: Instance<()>,
        // Unused, but necessary, because proc needs a supervision
        // port -- otherwise an actor failure will cause a crash.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        #[allow(dead_code)]
        controller_rx: PortReceiver<ControllerMessage>,
        #[allow(dead_code)]
        controller_actor: reference::ActorRef<ControllerActor>,
        // Next fresh ref handed out by `next_ref()`.
        next_ref: Ref,
    }
1963
    impl TestSetup {
        /// Single-rank setup; see `new_with_world_size`.
        async fn new() -> Result<Self> {
            Self::new_with_world_size(1).await
        }

        /// Spin up a local proc with a controller stub, a client instance, a
        /// supervision coordinator, and a `StreamActor` on CUDA device 0 using
        /// the default stream.
        async fn new_with_world_size(world_size: usize) -> Result<Self> {
            test_util::test_setup()?;

            let proc = Proc::local();
            let (_, controller_actor, controller_rx) =
                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
            let (client, _handle) = proc.instance("client")?;
            let (supervision_tx, supervision_rx) = client.open_port();
            proc.set_supervision_coordinator(supervision_tx)?;
            let stream_actor = proc.spawn(
                "stream",
                StreamActor::new(StreamParams {
                    world_size,
                    rank: 0,
                    creation_mode: StreamCreationMode::UseDefaultStream,
                    id: 0.into(),
                    device: Some(CudaDevice::new(0.into())),
                    controller_actor: controller_actor.clone(),
                    respond_with_python_message: false,
                }),
            )?;

            Ok(Self {
                proc,
                stream_actor,
                client,
                supervision_rx,
                controller_rx,
                controller_actor,
                next_ref: 0.into(),
            })
        }

        /// Hand out a fresh, monotonically increasing `Ref`.
        fn next_ref(&mut self) -> Ref {
            let ref_ = self.next_ref;
            self.next_ref = Ref {
                id: self.next_ref.id + 1,
            };
            ref_
        }

        /// Store `data` as a CUDA float tensor under `reference` on the stream.
        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
            self.stream_actor
                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
                .await
        }

        /// Fetch `reference` from the stream (panicking if absent or errored)
        /// and compare it element-wise against `data`.
        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
            let actual = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // rustfmt-ignore
            allclose(
                &factory_float_tensor(data, "cpu".parse().unwrap()),
                &actual.borrow(),
            )
            .unwrap()
        }

        /// Assert that `reference` holds exactly `error` (pointer equality on
        /// the `Arc`), i.e. the error was propagated, not re-wrapped.
        #[allow(dead_code)]
        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
            let result_error = self
                .stream_actor
                .get_tensor_ref_unit_tests_only(&self.client, reference)
                .await
                .unwrap()
                .unwrap()
                .unwrap_err();

            assert!(Arc::ptr_eq(&result_error, &error));
        }
    }
2047
2048    async fn assert_actor_failed_with_msg(
2049        status_rx: &mut watch::Receiver<ActorStatus>,
2050        expected_msg: String,
2051    ) {
2052        status_rx
2053            .wait_for(|s| matches!(s, ActorStatus::Failed(_)))
2054            .await
2055            .unwrap();
2056        let status = status_rx.borrow().clone();
2057        if let ActorStatus::Failed(msg) = status {
2058            assert!(msg.to_string().contains(&expected_msg));
2059        } else {
2060            panic!("expected ActorStatus::Failed, got {:?}", status);
2061        }
2062    }
2063
2064    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2065        for ref_ in refs {
2066            assert!(
2067                test_setup
2068                    .stream_actor
2069                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2070                    .await
2071                    .unwrap()
2072                    .is_none()
2073            );
2074        }
2075    }
2076
2077    #[allow(dead_code)]
2078    async fn fetch_result(
2079        cx: &impl context::Actor,
2080        stream_actor: ActorHandle<StreamActor>,
2081        seq: Seq,
2082        reference: Ref,
2083    ) {
2084        let ref_to_send = Python::attach(|py| {
2085            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2086        });
2087
2088        stream_actor
2089            .send_value(
2090                cx,
2091                seq,
2092                stream_actor.actor_id().clone(),
2093                Vec::new(),
2094                None,
2095                ArgsKwargs::from_wire_values(
2096                    vec![WireValue::PyObject(ref_to_send)],
2097                    HashMap::new(),
2098                )
2099                .unwrap(),
2100                HashMap::new(),
2101            )
2102            .await
2103            .unwrap()
2104    }
2105
2106    #[allow(dead_code)]
2107    async fn check_fetch_result_error(
2108        cx: &impl context::Actor,
2109        stream_actor: ActorHandle<StreamActor>,
2110        seq: Seq,
2111        reference: Ref,
2112        controller_rx: &mut PortReceiver<ControllerMessage>,
2113        expected_backtrace: &str,
2114    ) {
2115        fetch_result(cx, stream_actor, seq, reference).await;
2116
2117        let controller_msg = controller_rx.recv().await.unwrap();
2118        match controller_msg {
2119            ControllerMessage::FetchResult {
2120                seq: actual_seq,
2121                value: Err(err),
2122            } => {
2123                assert_eq!(actual_seq, seq);
2124                assert!(
2125                    err.backtrace.contains(expected_backtrace),
2126                    "backtrace did not contain {:?}: {:?}",
2127                    expected_backtrace,
2128                    err.backtrace
2129                );
2130            }
2131            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2132        };
2133    }
2134
2135    #[allow(dead_code)]
2136    async fn check_fetch_result_value(
2137        cx: &impl context::Actor,
2138        stream_actor: ActorHandle<StreamActor>,
2139        seq: Seq,
2140        reference: Ref,
2141        controller_rx: &mut PortReceiver<ControllerMessage>,
2142    ) {
2143        fetch_result(cx, stream_actor, seq, reference).await;
2144
2145        let controller_msg = controller_rx.recv().await.unwrap();
2146        match controller_msg {
2147            ControllerMessage::FetchResult {
2148                value: Ok(_),
2149                seq: actual_seq,
2150            } => assert_eq!(seq, actual_seq),
2151            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2152        };
2153    }
2154
    /// Defining a new recording while a different one is still being defined
    /// must fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "different recording already active".into(),
        )
        .await;
        Ok(())
    }
2173
    /// Re-defining a recording ref that was already defined and finalized must
    /// fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_define_recording_already_defined() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "already defined".into(),
        )
        .await;
        Ok(())
    }
2196
    /// Finalizing a recording ref different from the one being defined must
    /// fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_finalize_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 1.into())
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "cannot finalize recording that isn't active".into(),
        )
        .await;
        Ok(())
    }
2215
    /// `recording_formal` with no active recording must fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_formal called outside of recording".into(),
        )
        .await;
        Ok(())
    }
2230
    /// `recording_result` with no active recording must fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_outside_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 0.into(), 0)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_result called outside of recording".into(),
        )
        .await;
        Ok(())
    }
2245
    /// Calling a recording while another recording is being defined must fail
    /// the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_other_recording_active() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "cannot call recording while another recording is active".into(),
        )
        .await;
        Ok(())
    }
2264
    /// Calling a recording ref that was never defined must fail the stream
    /// actor with a "not found" error.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_call_recording_not_found() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;
        assert_actor_failed_with_msg(&mut test_setup.stream_actor.status(), "not found".into())
            .await;
        Ok(())
    }
2276
    /// Calling a recording that declares a formal at index 0 with an empty
    /// actuals list must fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_formal_too_few_arguments() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_formal called with too few arguments".into(),
        )
        .await;
        Ok(())
    }
2308
    /// Calling a recording that declares a result at index 0 with an empty
    /// results list must fail the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_recording_result_too_few_results() -> Result<()> {
        let test_setup = TestSetup::new().await?;

        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .recording_result(&test_setup.client, 1.into(), 0)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        test_setup
            .stream_actor
            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
            .await?;

        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "recording_result called with too few results".into(),
        )
        .await;
        Ok(())
    }
2340
    /// End-to-end happy path: record a function that swaps its two arguments,
    /// call it with real tensors, and verify both outputs and the cleanup of
    /// temporary formal refs.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_basic_call_recording() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;

        // Define a recording equivalent to:
        // def f(x, y):
        //   return y, x
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        // Formal 0 binds to actual argument index 1 (and vice versa): the swap.
        let formal0_ref = 1.into();
        let formal0_index = 1;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
            .await?;

        let formal1_ref = 2.into();
        let formal1_index = 0;
        test_setup
            .stream_actor
            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
            .await?;

        // The outputs are simply the formals, passed through.
        let result0_ref = formal0_ref;
        let result0_index = 0;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result0_ref, result0_index)
            .await?;

        let result1_ref = formal1_ref;
        let result1_index = 1;
        test_setup
            .stream_actor
            .recording_result(&test_setup.client, result1_ref, result1_index)
            .await?;

        test_setup
            .stream_actor
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        let actual0_ref = 3.into();
        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;

        let actual1_ref = 4.into();
        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;

        // Call the recording with valid tensors for the actual inputs,
        // and store the results in refs 5 and 6.
        let actual_result0_ref = 5.into();
        let actual_result1_ref = 6.into();
        test_setup
            .stream_actor
            .call_recording(
                &test_setup.client,
                0.into(),
                0.into(),
                vec![actual_result0_ref, actual_result1_ref],
                vec![actual0_ref, actual1_ref],
            )
            .await?;

        // Ensure the results are correct.
        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
        assert!(
            test_setup
                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
                .await
        );

        // Ensure the temporary refs associated with the formals/results have
        // been deleted.
        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
        Ok(())
    }
2420
    /// `request_status` while a recording is being defined must fail the call
    /// and the stream actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_request_status_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;
        test_setup
            .stream_actor
            .request_status(&test_setup.client)
            .await
            .expect_err("request_status should have failed");
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "request_status not allowed in recording".into(),
        )
        .await;
        Ok(())
    }
2440
    /// `init_comm` while a recording is being defined must fail the stream
    /// actor.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_init_comm_in_recording() -> Result<()> {
        let test_setup = TestSetup::new().await?;
        test_setup
            .stream_actor
            .define_recording(&test_setup.client, 0.into())
            .await?;

        // A minimal single-rank NCCL comm actor; only needed so init_comm has
        // something to receive.
        let dummy_comm = test_setup.proc.spawn(
            "comm",
            NcclCommActor::new(CommParams::New {
                device: CudaDevice::new(0.into()),
                unique_id: UniqueId::new()?,
                world_size: 1,
                rank: 0,
            })
            .await
            .unwrap(),
        )?;

        test_setup
            .stream_actor
            .init_comm(&test_setup.client, dummy_comm)
            .await?;
        assert_actor_failed_with_msg(
            &mut test_setup.stream_actor.status(),
            "init_comm not allowed in recording".into(),
        )
        .await;
        Ok(())
    }
2472
2473    #[async_timed_test(timeout_secs = 60)]
2474    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2475        let mut test_setup = TestSetup::new().await?;
2476        test_setup
2477            .stream_actor
2478            .define_recording(&test_setup.client, 0.into())
2479            .await?;
2480
2481        let borrow_id = 1;
2482        let tensor_ref = test_setup.next_ref();
2483        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2484
2485        test_setup
2486            .stream_actor
2487            .borrow_create(
2488                &test_setup.client,
2489                borrow_id,
2490                tensor_ref,
2491                first_use_sender.clone(),
2492            )
2493            .await?;
2494
2495        test_setup
2496            .stream_actor
2497            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2498            .await?;
2499
2500        assert_actor_failed_with_msg(
2501            &mut test_setup.stream_actor.status(),
2502            "duplicate borrow create in recording".into(),
2503        )
2504        .await;
2505
2506        Ok(())
2507    }
2508
2509    #[async_timed_test(timeout_secs = 60)]
2510    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2511        let test_setup = TestSetup::new().await?;
2512        test_setup
2513            .stream_actor
2514            .define_recording(&test_setup.client, 0.into())
2515            .await?;
2516
2517        let borrow_id = 1;
2518        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2519
2520        test_setup
2521            .stream_actor
2522            .borrow_drop(
2523                &test_setup.client,
2524                borrow_id,
2525                Arc::new(Mutex::new(last_use_receiver)),
2526            )
2527            .await?;
2528
2529        assert_actor_failed_with_msg(
2530            &mut test_setup.stream_actor.status(),
2531            "borrow drop for borrow not defined in recording".into(),
2532        )
2533        .await;
2534
2535        Ok(())
2536    }
2537
2538    #[async_timed_test(timeout_secs = 60)]
2539    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2540        let mut test_setup = TestSetup::new().await?;
2541        test_setup
2542            .stream_actor
2543            .define_recording(&test_setup.client, 0.into())
2544            .await?;
2545
2546        let borrow_id = 1;
2547        let tensor_ref = test_setup.next_ref();
2548        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2549
2550        test_setup
2551            .stream_actor
2552            .borrow_create(
2553                &test_setup.client,
2554                borrow_id,
2555                tensor_ref,
2556                first_use_sender.clone(),
2557            )
2558            .await?;
2559
2560        // Attempt to finalize the recording without dropping the borrow
2561        test_setup
2562            .stream_actor
2563            .finalize_recording(&test_setup.client, 0.into())
2564            .await?;
2565
2566        assert_actor_failed_with_msg(
2567            &mut test_setup.stream_actor.status(),
2568            "all borrows created within recording must be dropped within recording".into(),
2569        )
2570        .await;
2571
2572        Ok(())
2573    }
2574}