monarch_tensor_worker/
stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::OncePortHandle;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::proc::Proc;
38use monarch_hyperactor::actor::PythonMessage;
39use monarch_hyperactor::actor::PythonMessageKind;
40use monarch_hyperactor::buffers::FrozenBuffer;
41use monarch_hyperactor::local_state_broker::BrokerId;
42use monarch_hyperactor::local_state_broker::LocalState;
43use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::CallFunctionError;
50use monarch_messages::worker::CallFunctionParams;
51use monarch_messages::worker::SeqError;
52use monarch_messages::worker::StreamRef;
53use monarch_types::PyTree;
54use monarch_types::SerializablePyErr;
55use monarch_types::TryIntoPyObjectUnsafe;
56use pyo3::prelude::*;
57use pyo3::types::PyTuple;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys::BorrowType;
62use torch_sys::CudaDevice;
63use torch_sys::MultiBorrow;
64use torch_sys::RValue;
65use torch_sys::TensorCell;
66use torch_sys::deep_clone;
67use torch_sys::factory_empty;
68use torch_sys::factory_zeros;
69use torch_sys_cuda::cuda::Event;
70use torch_sys_cuda::cuda::Stream;
71use tracing_subscriber::fmt::Subscriber;
72
73use crate::ControllerActor;
74use crate::DeviceMesh;
75use crate::Factory;
76use crate::Reduction;
77use crate::Ref;
78use crate::ResolvableFunction;
79use crate::StreamCreationMode;
80use crate::WireValue;
81use crate::comm::CommBackend;
82use crate::comm::CommMessage;
83use crate::comm::CommMessageClient;
84use crate::comm::NcclCommActor;
85
/// Either a tensor cell, or the shared [`SeqError`] recorded in place of the
/// tensor. This is the payload exchanged over the borrow ports of
/// [`StreamMessage`].
pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
87
// These thread locals are accessed by the python runtime for debugging sessions.
// They are populated once per stream thread by `StreamActor::init`.
thread_local! {
    // Ref to the controller actor that created the stream running on this thread.
    pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
    // The proc hosting the stream actor running on this thread.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    // Root actor id built from the stream actor's proc id and name.
    pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
}
94
95fn pickle_python_result(
96    py: Python<'_>,
97    result: Bound<'_, PyAny>,
98    worker_rank: usize,
99) -> Result<PythonMessage, anyhow::Error> {
100    let pickle = py
101        .import("monarch._src.actor.actor_mesh")
102        .unwrap()
103        .getattr("_pickle")
104        .unwrap();
105    let data: FrozenBuffer = pickle
106        .call1((result,))
107        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
108        .extract()
109        .unwrap();
110    Ok(PythonMessage::new_from_buf(
111        PythonMessageKind::Result {
112            rank: Some(worker_rank),
113        },
114        data.inner,
115    ))
116}
117
118#[derive(Debug)]
119struct Recording {
120    messages: Vec<StreamMessage>,
121}
122
123impl Recording {
124    fn new() -> Self {
125        Self {
126            messages: Vec::new(),
127        }
128    }
129}
130
/// State of the recording currently active on a stream, if any (stored in
/// `StreamActor::active_recording`).
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// A recording is being defined.
    Defining {
        /// Ref of the recording under definition.
        recording: Ref,
        // Set of borrow ids used to track proper borrow usage inside
        // a recording.
        defined_borrows: HashSet<u64>,
    },
    /// NOTE(review): presumably a previously defined recording is being
    /// executed -- confirm against the `CallRecording` handler.
    Running,
}
141
/// Messages handled by the stream. Generally these are stream-local versions of
/// [`crate::WorkerMessage`].
#[derive(Handler, HandleClient, Debug, Named)]
#[named(register = false)]
pub enum StreamMessage {
    /// Call a function described by the `CallFunctionParams`. The first map
    /// holds the device meshes and the second the remote process group
    /// descriptions (mesh, dims, NCCL comm actor) that refs appearing in the
    /// arguments may resolve to.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Provider side of starting a borrow.
    BorrowCreate {
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Port for sending the first use CUDA event + borrowed tensor to
        /// the borrower.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Borrower side of starting a borrow; defines `result`.
    BorrowFirstUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for storing the borrowed tensor.
        result: Ref,
        /// Port for receiving the first use CUDA event + borrowed tensor from
        /// the provider stream.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Borrower side of ending a borrow.
    BorrowLastUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for the borrowed tensor.
        result: Ref,
        /// Port for sending the last use CUDA event and borrowed tensor.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Provider side of ending a borrow.
    BorrowDrop {
        /// Id for the borrow.
        borrow: u64,
        /// Port for receiving the last use CUDA event and borrowed tensor.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// NOTE(review): presumably removes the given refs from the stream's
    /// environment -- confirm against the handler.
    DeleteRefs(Vec<Ref>),

    /// Request a `()` reply on the given port.
    /// NOTE(review): presumably sent once all earlier messages have been
    /// handled -- confirm against the handler.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Hand the stream an NCCL comm actor handle.
    InitComm(ActorHandle<NcclCommActor>),

    /// Reduce `local_tensor` across `comm`.
    Reduce {
        /// Comm actor to run the reduction on.
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        /// Ref defined by this message.
        result: Ref,
        /// Ref of the local input tensor.
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        /// When set, `local_tensor` is reported as mutated by this message.
        in_place: bool,
        /// When `in_place` is unset, this ref (if any) is reported as mutated
        /// by this message.
        out: Option<Ref>,
    },

    /// Send a tensor between ranks over `comm`.
    SendTensor {
        /// Ref for the received tensor; only reported as defined by this
        /// message when `from_rank` is set.
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        /// Ref of the tensor to send.
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// Not supported inside recordings (see `clone_for_recording`).
    SendValue {
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    /// Start defining the recording identified by `recording`.
    DefineRecording {
        recording: Ref,
    },

    /// Finish defining the recording identified by `recording`.
    FinalizeRecording {
        recording: Ref,
    },

    /// Invoke the recording identified by `recording`.
    /// NOTE(review): presumably `actuals` are bound to the recording's
    /// formals and `results` receive its outputs -- confirm against the
    /// handler.
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Formal argument of a recording; defines `result`.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Output of a recording.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Unit tests only.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Unit tests only.
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Unit tests only.
    GetRefUnitTestsOnly(
        Ref, // value
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Unit tests only.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// Not supported inside recordings (see `clone_for_recording`).
    SendResultOfActorCall(ActorCallParams),
    /// Not supported inside recordings (see `clone_for_recording`).
    CallActorMethod(ActorMethodParams),
}
264
265impl StreamMessage {
266    fn clone_for_recording(&self) -> Self {
267        match self {
268            StreamMessage::RecordingFormal {
269                result,
270                argument_index,
271            } => StreamMessage::RecordingFormal {
272                result: *result,
273                argument_index: *argument_index,
274            },
275            StreamMessage::RecordingResult {
276                result,
277                output_index,
278            } => StreamMessage::RecordingResult {
279                result: *result,
280                output_index: *output_index,
281            },
282            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
283            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
284                StreamMessage::CallFunction(
285                    params.clone(),
286                    device_meshes.clone(),
287                    remote_process_groups.clone(),
288                )
289            }
290            StreamMessage::BorrowCreate {
291                borrow,
292                tensor,
293                first_use_sender,
294            } => StreamMessage::BorrowCreate {
295                borrow: *borrow,
296                tensor: *tensor,
297                first_use_sender: first_use_sender.clone(),
298            },
299            StreamMessage::BorrowFirstUse {
300                borrow,
301                result,
302                first_use_receiver,
303            } => StreamMessage::BorrowFirstUse {
304                borrow: *borrow,
305                result: *result,
306                first_use_receiver: first_use_receiver.clone(),
307            },
308            StreamMessage::BorrowLastUse {
309                borrow,
310                result,
311                last_use_sender,
312            } => StreamMessage::BorrowLastUse {
313                borrow: *borrow,
314                result: *result,
315                last_use_sender: last_use_sender.clone(),
316            },
317            StreamMessage::BorrowDrop {
318                borrow,
319                last_use_receiver,
320            } => StreamMessage::BorrowDrop {
321                borrow: *borrow,
322                last_use_receiver: last_use_receiver.clone(),
323            },
324            StreamMessage::Reduce {
325                comm,
326                dim_size,
327                result,
328                local_tensor,
329                factory,
330                reduction,
331                scatter,
332                in_place,
333                out,
334            } => StreamMessage::Reduce {
335                comm: comm.clone(),
336                dim_size: *dim_size,
337                result: *result,
338                local_tensor: *local_tensor,
339                factory: factory.clone(),
340                reduction: reduction.clone(),
341                scatter: *scatter,
342                in_place: *in_place,
343                out: out.clone(),
344            },
345            StreamMessage::SendTensor {
346                result,
347                from_rank,
348                to_rank,
349                tensor,
350                factory,
351                comm,
352            } => StreamMessage::SendTensor {
353                result: *result,
354                from_rank: *from_rank,
355                to_rank: *to_rank,
356                tensor: *tensor,
357                factory: factory.clone(),
358                comm: comm.clone(),
359            },
360            other => panic!(
361                "StreamMessage variant not supported in recording: {:?}",
362                other
363            ),
364        }
365    }
366
367    // Get the set of refs that this message defines.
368    fn get_defined_refs(&self) -> HashSet<Ref> {
369        match self {
370            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
371            StreamMessage::CallFunction(params, ..) => {
372                params.results.iter().filter_map(|&ref_| ref_).collect()
373            }
374            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
375            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
376            StreamMessage::SendTensor {
377                result, from_rank, ..
378            } => {
379                if from_rank.is_some() {
380                    HashSet::from([*result])
381                } else {
382                    HashSet::new()
383                }
384            }
385            // TODO(slurye): Add SendValue eventually.
386            _ => HashSet::new(),
387        }
388    }
389
390    // Get the set of refs that this message mutates.
391    fn get_mutated_refs(&self) -> HashSet<Ref> {
392        match self {
393            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
394            StreamMessage::Reduce {
395                out,
396                in_place,
397                local_tensor,
398                ..
399            } => {
400                if *in_place {
401                    HashSet::from([*local_tensor])
402                } else if let Some(out) = out {
403                    HashSet::from([*out])
404                } else {
405                    HashSet::new()
406                }
407            }
408            // TODO(slurye): Add SendValue eventually.
409            _ => HashSet::new(),
410        }
411    }
412}
413
/// A stream represents a linear sequence of execution. Operations on different
/// streams can execute concurrently.
///
/// For CUDA operators, streams will invoke the corresponding stream management
/// APIs to perform synchronization.
///
/// For CPU operators, streams will just execute synchronously on their own OS
/// thread.
#[derive(Debug)]
pub struct StreamActor {
    /// World size handed to `ensure_init_process_group` and `CommBackend::new`
    /// when remote process groups are created.
    world_size: usize,
    /// Rank handed to `ensure_init_process_group` and `CommBackend::new` when
    /// remote process groups are created.
    rank: usize,
    /// Mapping of refs in the controller environment to their value in this
    /// stream's local environment, or to the error recorded in place of the
    /// value.
    // TODO(agallagher): Use `ValueError` as the error type.
    env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
    /// How to create the stream.
    creation_mode: StreamCreationMode,
    /// CUDA stream that this actor will enqueue operations on. None if "device"
    /// is not a CUDA device.
    /// NOTE: We lazily create the stream, so that we do it from the dedicated
    /// Stream OS thread as, otherwise, we see deadlocks when done from
    /// unexpected threads.
    cuda_stream: OnceLock<Option<Stream>>,
    /// Device this stream should be scheduled on.
    device: Option<CudaDevice>,
    /// Communicator for this stream. Optional as we lazily initialize it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Actor ref of the controller that created this stream.
    controller_actor: ActorRef<ControllerActor>,
    /// Cache of the Python process group objects created for remote process
    /// group refs, so each group is only created once (see `call_python_fn`).
    remote_process_groups: HashMap<Ref, PyObject>,
    /// Recordings known to this stream, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// State of the recording currently active on this stream, if any. While
    /// this is `Some`, `report_seq_error` does not forward errors to the
    /// controller.
    active_recording: Option<RecordingState>,
    /// Copied from `StreamParams::respond_with_python_message`.
    /// NOTE(review): presumably selects `PythonMessage` replies over the
    /// default reply format -- confirm against the use sites.
    respond_with_python_message: bool,
    /// Most recent error stored by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
450
/// Parameters for creating a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// Stored on the actor as `world_size`.
    pub world_size: usize,
    /// Stored on the actor as `rank`.
    pub rank: usize,
    /// Controls how the underlying CUDA stream is created.
    pub creation_mode: StreamCreationMode,
    /// Id of this stream in the worker actor's stream table.
    /// Currently ignored by `StreamActor::new`.
    pub id: StreamRef,
    /// Device this stream should be scheduled on. If none, don't do stream
    /// synchronization.
    pub device: Option<CudaDevice>,
    /// Actor ref of the controller that created this stream.
    pub controller_actor: ActorRef<ControllerActor>,
    /// Stored on the actor as `respond_with_python_message`.
    pub respond_with_python_message: bool,
}
467
#[async_trait]
impl Actor for StreamActor {
    type Params = StreamParams;

    /// Build the actor state from `StreamParams`. The CUDA stream and the
    /// communicator are left uninitialized here and created lazily; the `id`
    /// parameter is not used.
    async fn new(
        StreamParams {
            world_size,
            rank,
            id: _,
            device,
            controller_actor,
            creation_mode,
            respond_with_python_message,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            world_size,
            rank,
            env: HashMap::new(),
            creation_mode,
            cuda_stream: OnceLock::new(),
            device,
            comm: None,
            controller_actor,
            remote_process_groups: HashMap::new(),
            recordings: HashMap::new(),
            active_recording: None,
            respond_with_python_message,
            last_seq_error: None,
        })
    }

    /// Populate the thread-local debugging handles and make this actor's CUDA
    /// stream (if it has one) the current stream.
    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
        // These thread locals are exposed via python functions, so we need to set them in the
        // same thread that python will run in. That means we need to initialize them here in
        // StreamActor::init instead of in StreamActor::new.
        // `.ok()` discards the error `OnceCell::set` returns when the cell is
        // already set, so an existing value is left in place.
        CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
            controller_actor_ref.set(self.controller_actor.clone()).ok()
        });
        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
        ROOT_ACTOR_ID.with(|root_actor_id| {
            root_actor_id
                .set(ActorId::root(
                    cx.self_id().proc_id().clone(),
                    cx.self_id().name().to_string(),
                ))
                .ok()
        });
        // Set the current stream for this actor thread.
        if let Some(stream) = self.cuda_stream() {
            Stream::set_current_stream(stream);
        }
        Ok(())
    }

    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
    /// dedicated OS thread. This is because:
    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
    ///     threads is problematic.
    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
        // It is important that we spawn a standalone thread for the work here,
        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
        // This is because the worker stream may call uninterruptible FFI code
        // that can deadlock (CUDA, NCCL).
        // If we use a tokio-managed blocking thread, then runtime teardown will
        // try to wait for tasks on that thread to reach an await point, and
        // hang forever.
        let builder = std::thread::Builder::new().name("worker-stream".to_string());
        // NOTE(review): the `io::Result` from `spawn` is not inspected. If the
        // thread fails to start, `join_tx` is dropped and the `unwrap` on
        // `join_rx` below panics instead.
        let _thread_handle = builder.spawn(move || {
            // Run the actor loop on a tokio runtime with a single worker
            // thread. We avoid the current-thread runtime flavor, so that we
            // can use `block_in_place` for nested async-to-sync-to-async flows.
            let rt = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(async {
                tokio::task::block_in_place(|| {
                    // Allow e.g. destructing py objects on this thread, which
                    // can happen at shutdown when a stream actor's env map
                    // for rvalues is dropped (e.g. P1673311499).
                    // https://github.com/PyO3/pyo3/discussions/3499
                    Python::with_gil(|py| {
                        py.allow_threads(|| {
                            let result = Handle::current().block_on(future);
                            if join_tx.send(result).is_err() {
                                panic!("could not send join result")
                            }
                        })
                    })
                })
            });
            rt.shutdown_timeout(Duration::from_weeks(1));
            result
        });

        // In order to bridge the synchronous join handle with the async world,
        // smuggle the result through a channel.
        tokio::spawn(async move { join_rx.await.unwrap() })
    }
}
575
/// The arguments we accept as inputs to Python function calls.
#[derive(Debug)]
enum PyArg<'a> {
    /// A value resolved from the stream's environment.
    RValue(RValue),
    /// A borrowed device mesh; cloned into a new Python object on conversion.
    DeviceMesh(&'a DeviceMesh),
    /// An existing Python object, passed through as a new reference.
    PyObject(PyObject),
}
583
584/// Serialize into a `PyObject`.
585impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
586    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
587        match self {
588            // SAFETY: This inherits the unsafety of `rvalue_to_ivalue` (see comment
589            // above).
590            PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
591            PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
592            PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
593        }
594    }
595}
596
597impl StreamActor {
598    fn cuda_stream(&self) -> Option<&Stream> {
599        self.cuda_stream
600            .get_or_init(|| {
601                self.device.map(|device| match self.creation_mode {
602                    StreamCreationMode::UseDefaultStream => {
603                        Stream::get_current_stream_on_device(device)
604                    }
605                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
606                })
607            })
608            .as_ref()
609    }
610
611    fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
612        let rvalue = self
613            .env
614            .get(ref_)
615            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
616        match rvalue {
617            Ok(val) => Ok(val.clone()),
618            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
619        }
620    }
621
622    fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
623        let ret = match value {
624            WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
625            // TODO: We might want to support GenericList / GenericDict etc.
626            WireValue::RefList(val) => {
627                let mut ret = Vec::with_capacity(val.len());
628                for v in val {
629                    match self.ref_to_rvalue(&v) {
630                        Ok(RValue::Tensor(t)) => ret.push(t),
631                        Err(err) => {
632                            return Err(err);
633                        }
634                        Ok(val) => {
635                            return Err(CallFunctionError::UnsupportedArgType(
636                                "wire_to_rvalue".into(),
637                                format!("RefList([{:?}])", val),
638                            ));
639                        }
640                    }
641                }
642                RValue::TensorList(ret)
643            }
644            WireValue::Int(val) => RValue::Int(val),
645            WireValue::IntList(val) => RValue::IntList(val),
646            WireValue::Double(val) => RValue::Double(val),
647            WireValue::Bool(val) => RValue::Bool(val),
648            WireValue::String(val) => RValue::String(val),
649            WireValue::Device(val) => RValue::Device(val),
650            WireValue::Layout(val) => RValue::Layout(val),
651            WireValue::ScalarType(val) => RValue::ScalarType(val),
652            WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
653            WireValue::PyObject(val) => RValue::PyObject(val),
654            WireValue::None(()) => RValue::None,
655            WireValue::IValue(val) => RValue::Opaque(val.into()),
656        };
657        Ok(ret)
658    }
659
660    async fn report_seq_error(
661        &mut self,
662        cx: &Context<'_, Self>,
663        seq: Seq,
664        error: CallFunctionError,
665    ) -> Result<Arc<SeqError>, anyhow::Error> {
666        match error {
667            CallFunctionError::DependentError(root) => Ok(root),
668            CallFunctionError::Error(e) => {
669                if self.active_recording.is_none() {
670                    let worker_error = WorkerError {
671                        backtrace: format!("{e}"),
672                        worker_actor_id: cx.self_id().clone(),
673                    };
674                    tracing::info!("Propagating remote function error to client: {worker_error}");
675                    self.controller_actor
676                        .remote_function_failed(cx, seq, worker_error)
677                        .await?
678                }
679                let err = Arc::new(SeqError { seq, error: e });
680                self.last_seq_error = Some(err.clone());
681                Ok(err)
682            }
683        }
684    }
685
686    async fn try_define<F>(
687        &mut self,
688        cx: &Context<'_, Self>,
689        seq: Seq,
690        result_refs: Vec<Option<Ref>>,
691        mutates: &Vec<Ref>,
692        f: F,
693    ) -> Result<()>
694    where
695        F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
696    {
697        let actual_results = f(self).await;
698        // Check if the expected number of returns is correct, otherwise convert
699        // into an error.
700        let op_results = actual_results.and_then(|actual_results| {
701            if result_refs.len() == actual_results.len() {
702                Ok(actual_results
703                    .into_iter()
704                    .zip(result_refs.iter())
705                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
706                    .collect::<Vec<(Ref, RValue)>>())
707            } else {
708                Err(CallFunctionError::UnexpectedNumberOfReturns(
709                    result_refs.len(),
710                    actual_results.len(),
711                ))
712            }
713        });
714
715        // Propagate the results (either the actual values or an error) to the
716        // right entries in the global env mapping.
717        match op_results {
718            Ok(op_results) => {
719                for (ref_, rvalue) in op_results.into_iter() {
720                    let prev = self.env.insert(ref_, Ok(rvalue));
721                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
722                }
723            }
724            Err(err) => {
725                let err = self.report_seq_error(cx, seq, err).await?;
726                for ref_ in result_refs {
727                    match ref_ {
728                        Some(ref_) => {
729                            let prev = self.env.insert(ref_, Err(err.clone()));
730                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
731                        }
732                        None => {}
733                    }
734                }
735                for ref_ in mutates {
736                    self.env.insert(*ref_, Err(err.clone()));
737                }
738            }
739        }
740        Ok(())
741    }
742
743    fn call_torch_op(
744        &self,
745        op: String,
746        overload: String,
747        args: Vec<WireValue>,
748        kwargs: HashMap<String, WireValue>,
749    ) -> Result<Vec<RValue>, CallFunctionError> {
750        let args = args
751            .into_iter()
752            .map(|arg| self.wire_to_rvalue(arg))
753            .collect::<Result<Vec<_>, _>>()?;
754        let kwargs = kwargs
755            .into_iter()
756            .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
757            .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
758
759        let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
760
761        // Handle the case where the op returns nothing and convert it to a list of None.
762        // This is to ensure handle results does not error out as the client will call
763        // such a function with expected results of size 1.
764        Ok(if results.is_empty() {
765            vec![RValue::None]
766        } else {
767            results
768        })
769    }
770
771    fn call_python_fn<'py>(
772        &mut self,
773        py: Python<'py>,
774        cx: &Context<Self>,
775        function: Option<ResolvableFunction>,
776        args: Vec<WireValue>,
777        kwargs: HashMap<String, WireValue>,
778        mutates: &[Ref],
779        device_meshes: HashMap<Ref, DeviceMesh>,
780        remote_process_groups: HashMap<
781            Ref,
782            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
783        >,
784    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
785        let function = function
786            .map(|function| {
787                function.resolve(py).map_err(|e| {
788                    CallFunctionError::InvalidRemoteFunction(format!(
789                        "failed to resolve function {}: {}",
790                        function,
791                        SerializablePyErr::from(py, &e)
792                    ))
793                })
794            })
795            .transpose()?;
796
797        let remote_process_groups = remote_process_groups
798            .into_iter()
799            .map(|(gref, (mesh, dims, comm))| {
800                let group = match self.remote_process_groups.entry(gref) {
801                    Entry::Occupied(ent) => ent.get().clone_ref(py),
802                    Entry::Vacant(ent) => {
803                        // We need to run `init_process_group` before any
804                        // remote process groups can get created.
805                        torch_sys::backend::ensure_init_process_group(
806                            py,
807                            self.world_size,
808                            self.rank,
809                        )?;
810
811                        // Create a backend object to wrap the comm and use
812                        // it to create a new torch group.
813                        let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
814                        let group_size = ranks.len();
815                        let (child_instance, _) = cx.child()?;
816                        let backend = CommBackend::new(
817                            child_instance,
818                            comm,
819                            self.rank,
820                            group_size,
821                            self.world_size,
822                        );
823                        ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
824                            .clone_ref(py)
825                    }
826                };
827                PyResult::Ok((gref, group))
828            })
829            .collect::<Result<HashMap<_, _>, _>>()
830            .map_err(SerializablePyErr::from_fn(py))?;
831
832        // SAFETY: We will be making an unchecked clone of each tensor to pass to to
833        // C++, so we need to hold a borrow of each input tensor for the duration of
834        // this function.
835        let mut multiborrow = MultiBorrow::new();
836
837        let resolve = |val: WireValue| {
838            val.into_py_object()
839                .map_err(|e| {
840                    CallFunctionError::UnsupportedArgType(
841                        format!("{:?}", function),
842                        format!("{:?}", e),
843                    )
844                })?
845                .unpickle(py)
846                .map_err(SerializablePyErr::from_fn(py))?
847                .extract::<PyTree<PyObject>>()
848                .map_err(SerializablePyErr::from_fn(py))?
849                .try_into_map(|obj| {
850                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
851                        if let Some(mesh) = device_meshes.get(&ref_) {
852                            PyArg::DeviceMesh(mesh)
853                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
854                            PyArg::PyObject(pg.clone_ref(py))
855                        } else {
856                            let rval = self.ref_to_rvalue(&ref_)?;
857                            PyArg::RValue(rval)
858                        }
859                    } else {
860                        PyArg::PyObject(obj)
861                    })
862                })
863        };
864
865        // Resolve refs
866        let py_args: Vec<PyTree<PyArg>> = args
867            .into_iter()
868            .map(resolve)
869            .collect::<Result<_, CallFunctionError>>()?;
870        let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
871            .into_iter()
872            .map(|(k, object)| Ok((k, resolve(object)?)))
873            .collect::<Result<_, CallFunctionError>>()?;
874
875        // Add a shared-borrow for each rvalue reference.
876        py_args
877            .iter()
878            .chain(py_kwargs.values())
879            .flat_map(|o| o.iter())
880            .for_each(|arg| {
881                if let PyArg::RValue(rval) = arg {
882                    multiborrow.add(rval, BorrowType::Shared);
883                }
884            });
885
886        // Add mutable borrows for params we're mutating.
887        let mutates: Vec<_> = mutates
888            .iter()
889            .map(|r| self.ref_to_rvalue(r))
890            .collect::<Result<_, CallFunctionError>>()?;
891        mutates
892            .iter()
893            .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
894
895        // Execute the borrow.
896        let _borrow = multiborrow.borrow()?;
897
898        // Call function.
899        // Use custom subscriber to route Worker messages to stdout.
900        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
901        let result: Bound<'_, PyAny> =
902            tracing::subscriber::with_default(scoped_subscriber, || {
903                // SAFETY: The borrows above guard the unchecked clones done by
904                // `rvalue_to_ivalue`. This may result in multiple mutable
905                // references to tensor data, but the Python side is responsible
906                // for making sure that is safe
907                // TODO(agallagher): The args/kwargs conversion traits generate
908                // the appropriate types here, but they get casted to `PyAny`.
909                // It'd be nice to make `TryToPyObjectUnsafe` take a template
910                // arg for the converted py object to avoid this downcast.
911                let args = unsafe { py_args.try_to_object_unsafe(py) }
912                    .map_err(SerializablePyErr::from_fn(py))?;
913                // SAFETY: above
914                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
915                    .map_err(SerializablePyErr::from_fn(py))?;
916
917                if let Some(function) = function {
918                    function
919                        .call(args, Some(kwargs))
920                        .map_err(SerializablePyErr::from_fn(py))
921                } else {
922                    Ok(args.get_item(0).unwrap())
923                }
924            })?;
925        Ok(result)
926    }
927
928    fn call_python_fn_pytree(
929        &mut self,
930        cx: &hyperactor::Context<Self>,
931        function: ResolvableFunction,
932        args: Vec<WireValue>,
933        kwargs: HashMap<String, WireValue>,
934        mutates: &[Ref],
935        device_meshes: HashMap<Ref, DeviceMesh>,
936        remote_process_groups: HashMap<
937            Ref,
938            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
939        >,
940    ) -> Result<PyTree<RValue>, CallFunctionError> {
941        Python::with_gil(|py| {
942            let result = self.call_python_fn(
943                py,
944                cx,
945                Some(function),
946                args,
947                kwargs,
948                mutates,
949                device_meshes,
950                remote_process_groups,
951            )?;
952            Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
953        })
954    }
955    /// Retrieve `ref_` or create a fake value with the provided factory if it
956    /// is an error. We use this for collective calls, where even if there was
957    /// an upstream failure, we still have participate in the collective to
958    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
959    /// value here of the correct shape; the controller will have been notified
960    /// of the upstream failure and will know to ignore everything dependent on
961    /// it.
962    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
963        let rvalue = self
964            .env
965            .get(&ref_)
966            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
967
968        match rvalue {
969            Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
970            Err(_) => {
971                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
972                Ok(TensorCell::new(t))
973            }
974        }
975    }
976
977    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
978        self.active_recording
979            .as_mut()
980            .and_then(|state| match state {
981                RecordingState::Defining {
982                    recording,
983                    defined_borrows,
984                } => {
985                    match self.recordings.get_mut(recording) {
986                        Some(recording) => Some((recording, defined_borrows)),
987                        // Panic, because this would be a logic error in the program.
988                        None => panic!("recording not found: {:?}", recording),
989                    }
990                }
991                RecordingState::Running => None,
992            })
993    }
994
995    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
996        for ref_ in refs {
997            let rvalue_or_err = self
998                .env
999                .get(ref_)
1000                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1001            if let Err(err) = rvalue_or_err {
1002                return Ok(Some(err.clone()));
1003            }
1004        }
1005        Ok(None)
1006    }
1007    async fn send_value_python_message(
1008        &mut self,
1009        cx: &hyperactor::Context<'_, Self>,
1010        seq: Seq,
1011        mutates: Vec<Ref>,
1012        function: Option<ResolvableFunction>,
1013        args: Vec<WireValue>,
1014        kwargs: HashMap<String, WireValue>,
1015        device_meshes: HashMap<Ref, DeviceMesh>,
1016    ) -> Result<()> {
1017        let rank = self.rank;
1018        self.try_define(cx, seq, vec![], &vec![], async |self_| {
1019            let python_message =
1020                Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1021                    let python_result = tokio::task::block_in_place(|| {
1022                        self_.call_python_fn(
1023                            py,
1024                            cx,
1025                            function,
1026                            args,
1027                            kwargs,
1028                            &mutates,
1029                            device_meshes,
1030                            HashMap::new(),
1031                        )
1032                    })?;
1033                    pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
1034                })?;
1035            let ser = Serialized::serialize(&python_message).unwrap();
1036            self_
1037                .controller_actor
1038                .fetch_result(cx, seq, Ok(ser))
1039                .await?;
1040            Ok(vec![])
1041        })
1042        .await
1043    }
1044    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1045        let rvalue = self
1046            .env
1047            .get(&src)
1048            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1049        self.env.insert(dest, rvalue.clone());
1050        Ok(())
1051    }
1052    async fn call_actor(
1053        &mut self,
1054        cx: &Context<'_, Self>,
1055        params: ActorCallParams,
1056    ) -> Result<PyObject, CallFunctionError> {
1057        let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1058            params
1059                .local_state
1060                .into_iter()
1061                .map(|elem| {
1062                    // SAFETY: python is gonna make unsafe copies of this stuff anyway
1063                    unsafe {
1064                        let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1065                        Ok(x)
1066                    }
1067                })
1068                .collect()
1069        });
1070
1071        let (send, recv) = cx.open_once_port();
1072        let state = LocalState {
1073            response_port: send,
1074            state: local_state?,
1075        };
1076        let x: u64 = params.seq.into();
1077        let message = LocalStateBrokerMessage::Set(x as usize, state);
1078
1079        let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1080        broker
1081            .send(message)
1082            .map_err(|e| CallFunctionError::Error(e.into()))?;
1083        let result = recv
1084            .recv()
1085            .await
1086            .map_err(|e| CallFunctionError::Error(e.into()))?;
1087
1088        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1089    }
1090}
1091
1092#[async_trait]
1093#[forward(StreamMessage)]
1094impl StreamMessageHandler for StreamActor {
1095    async fn call_function(
1096        &mut self,
1097        cx: &Context<Self>,
1098        params: CallFunctionParams,
1099        device_meshes: HashMap<Ref, DeviceMesh>,
1100        remote_process_groups: HashMap<
1101            Ref,
1102            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1103        >,
1104    ) -> Result<()> {
1105        if let Some((recording, _)) = self.get_defining_recording() {
1106            recording.messages.push(StreamMessage::CallFunction(
1107                params,
1108                device_meshes,
1109                remote_process_groups,
1110            ));
1111            return Ok(());
1112        }
1113
1114        params.function.panic_if_requested();
1115        self.try_define(
1116            cx,
1117            params.seq,
1118            params.results,
1119            &params.mutates,
1120            async |self| {
1121                tokio::task::block_in_place(|| match params.function.as_torch_op() {
1122                    Some((op, overload)) => {
1123                        self.call_torch_op(op, overload, params.args, params.kwargs)
1124                    }
1125                    _ => self
1126                        .call_python_fn_pytree(
1127                            cx,
1128                            params.function,
1129                            params.args,
1130                            params.kwargs,
1131                            &params.mutates,
1132                            device_meshes,
1133                            remote_process_groups,
1134                        )
1135                        .map(|results| results.into_leaves()),
1136                })
1137            },
1138        )
1139        .await?;
1140        Ok(())
1141    }
1142
1143    async fn borrow_create(
1144        &mut self,
1145        _cx: &Context<Self>,
1146        borrow: u64,
1147        tensor: Ref,
1148        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1149    ) -> Result<()> {
1150        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1151            recording.messages.push(StreamMessage::BorrowCreate {
1152                borrow,
1153                tensor,
1154                first_use_sender,
1155            });
1156            ensure!(
1157                defined_borrows.insert(borrow),
1158                "duplicate borrow create in recording"
1159            );
1160            return Ok(());
1161        }
1162
1163        let rvalue_result = self
1164            .env
1165            .get(&tensor)
1166            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1167
1168        let result = match rvalue_result {
1169            Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1170            Err(e) => Err(e.clone()),
1171        };
1172
1173        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1174        first_use_sender.send((event, result)).map_err(|err| {
1175            anyhow!(
1176                "failed sending first use event for borrow {:?}: {:?}",
1177                borrow,
1178                err
1179            )
1180        })
1181    }
1182
1183    async fn borrow_first_use(
1184        &mut self,
1185        _cx: &Context<Self>,
1186        borrow: u64,
1187        result: Ref,
1188        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1189    ) -> Result<()> {
1190        if let Some((recording, _)) = self.get_defining_recording() {
1191            recording.messages.push(StreamMessage::BorrowFirstUse {
1192                borrow,
1193                result,
1194                first_use_receiver: first_use_receiver.clone(),
1195            });
1196            return Ok(());
1197        }
1198
1199        let (first_use_event, cell) =
1200            first_use_receiver
1201                .lock()
1202                .await
1203                .recv()
1204                .await
1205                .map_err(|err| {
1206                    anyhow!(
1207                        "failed receiving first use event for borrow {:?}: {:?}",
1208                        borrow,
1209                        err
1210                    )
1211                })?;
1212
1213        if let Some(stream) = self.cuda_stream() {
1214            stream.wait_event(
1215                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1216            );
1217        }
1218        match cell {
1219            Ok(cell) => {
1220                self.env.insert(result, Ok(cell.into()));
1221            }
1222            Err(err) => {
1223                self.env.insert(result, Err(err.clone()));
1224            }
1225        }
1226        Ok(())
1227    }
1228
1229    async fn borrow_last_use(
1230        &mut self,
1231        _cx: &Context<Self>,
1232        borrow: u64,
1233        result: Ref,
1234        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1235    ) -> Result<()> {
1236        if let Some((recording, _)) = self.get_defining_recording() {
1237            recording.messages.push(StreamMessage::BorrowLastUse {
1238                borrow,
1239                result,
1240                last_use_sender,
1241            });
1242            return Ok(());
1243        }
1244
1245        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1246        let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1247            "Invalid reference for borrow_last_use: {result:#?}"
1248        ))?;
1249        let tensor = match rvalue_or_err {
1250            Ok(RValue::Tensor(t)) => Ok(t),
1251            Err(e) => Err(e),
1252            _ => bail!("invalid rvalue type for borrow_last_use"),
1253        };
1254
1255        last_use_sender.send((event, tensor)).map_err(|err| {
1256            anyhow!(
1257                "failed sending last use event for borrow {:?}: {:?}",
1258                borrow,
1259                err
1260            )
1261        })
1262    }
1263
1264    async fn borrow_drop(
1265        &mut self,
1266        _cx: &Context<Self>,
1267        borrow: u64,
1268        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1269    ) -> Result<()> {
1270        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1271            recording.messages.push(StreamMessage::BorrowDrop {
1272                borrow,
1273                last_use_receiver: last_use_receiver.clone(),
1274            });
1275            ensure!(
1276                defined_borrows.remove(&borrow),
1277                "borrow drop for borrow not defined in recording"
1278            );
1279            return Ok(());
1280        }
1281
1282        // The borrowed cell isn't used directly, but we still want to receive it here
1283        // so that the underlying tensor isn't dropped until after we synchronize the
1284        // CUDA streams.
1285        let (last_use_event, _cell) =
1286            last_use_receiver.lock().await.recv().await.map_err(|err| {
1287                anyhow!(
1288                    "failed receiving last use event for borrow {:?}: {:?}",
1289                    borrow,
1290                    err
1291                )
1292            })?;
1293
1294        if let Some(stream) = self.cuda_stream() {
1295            stream.wait_event(
1296                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1297            );
1298        }
1299        // let the cell drop.
1300        Ok(())
1301    }
1302
1303    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1304        if let Some((recording, _)) = self.get_defining_recording() {
1305            recording.messages.push(StreamMessage::DeleteRefs(refs));
1306            return Ok(());
1307        }
1308
1309        for ref_ in refs.iter() {
1310            self.env.remove(ref_);
1311        }
1312        Ok(())
1313    }
1314
1315    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1316        if self.get_defining_recording().is_some() {
1317            bail!("request_status not allowed in recording");
1318        }
1319
1320        Ok(())
1321    }
1322
1323    async fn init_comm(
1324        &mut self,
1325        _cx: &Context<Self>,
1326        comm: ActorHandle<NcclCommActor>,
1327    ) -> Result<()> {
1328        if self.get_defining_recording().is_some() {
1329            bail!("init_comm not allowed in recording");
1330        }
1331
1332        self.comm = Some(comm);
1333        Ok(())
1334    }
1335
1336    async fn reduce(
1337        &mut self,
1338        cx: &Context<Self>,
1339        comm: Arc<ActorHandle<NcclCommActor>>,
1340        dim_size: i64,
1341        result: Ref,
1342        local_tensor: Ref,
1343        factory: Factory,
1344        reduction: Reduction,
1345        scatter: bool,
1346        in_place: bool,
1347        out: Option<Ref>,
1348    ) -> Result<()> {
1349        if let Some((recording, _)) = self.get_defining_recording() {
1350            recording.messages.push(StreamMessage::Reduce {
1351                comm,
1352                dim_size,
1353                result,
1354                local_tensor,
1355                factory,
1356                reduction,
1357                scatter,
1358                in_place,
1359                out,
1360            });
1361            return Ok(());
1362        }
1363
1364        let stream = self
1365            .cuda_stream()
1366            .expect("reductions not yet supported for non-CUDA workers")
1367            .clone();
1368        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1369        let out_cell = out
1370            .map(|out| self.get_or_fake_on_err(out, &factory))
1371            .transpose()?;
1372        let output_cell = match reduction {
1373            Reduction::Stack => {
1374                if scatter {
1375                    let output_cell = if in_place {
1376                        input_cell.clone()
1377                    } else {
1378                        out_cell.unwrap_or({
1379                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1380                            let cloned = deep_clone(&borrow);
1381                            TensorCell::new(cloned)
1382                        })
1383                    };
1384                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1385                        .await?;
1386                    output_cell
1387                } else {
1388                    ensure!(
1389                        !in_place,
1390                        "in-place, non-scatter not supported for stack reduce"
1391                    );
1392
1393                    let output_cell = out_cell.unwrap_or({
1394                        // In Python, this would be [dim_size, *factory.sizes]
1395                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1396                        let output =
1397                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1398                        TensorCell::new(output)
1399                    });
1400
1401                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1402                        .await?;
1403                    output_cell
1404                }
1405            }
1406            Reduction::ReduceOp(op) => {
1407                if scatter {
1408                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1409
1410                    let output_cell = out_cell.unwrap_or({
1411                        let output = factory_empty(
1412                            &factory.size[1..],
1413                            factory.dtype,
1414                            factory.layout,
1415                            factory.device,
1416                        );
1417                        TensorCell::new(output)
1418                    });
1419                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1420                        .await?;
1421                    output_cell
1422                } else {
1423                    let output_cell = if in_place {
1424                        input_cell.clone()
1425                    } else {
1426                        out_cell.map_or(
1427                            {
1428                                let borrow =
1429                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1430                                let cloned = deep_clone(&borrow);
1431                                Ok(TensorCell::new(cloned))
1432                            },
1433                            |out_cell| -> Result<_, anyhow::Error> {
1434                                let mut out_borrow =
1435                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1436                                let in_borrow =
1437                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1438                                out_borrow.copy_(&in_borrow);
1439                                drop(out_borrow);
1440                                Ok(out_cell)
1441                            },
1442                        )?
1443                    };
1444
1445                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1446                    output_cell
1447                }
1448            }
1449        };
1450
1451        self.env.insert(result, Ok(output_cell.into()));
1452        Ok(())
1453    }
1454
1455    async fn send_tensor(
1456        &mut self,
1457        cx: &Context<Self>,
1458        result: Ref,
1459        from_rank: Option<usize>,
1460        to_rank: Option<usize>,
1461        tensor: Ref,
1462        factory: Factory,
1463        comm: Arc<ActorHandle<NcclCommActor>>,
1464    ) -> Result<()> {
1465        if let Some((recording, _)) = self.get_defining_recording() {
1466            recording.messages.push(StreamMessage::SendTensor {
1467                result,
1468                from_rank,
1469                to_rank,
1470                tensor,
1471                factory,
1472                comm,
1473            });
1474            return Ok(());
1475        }
1476
1477        if to_rank.is_none() && from_rank.is_none() {
1478            bail!("tried to send tensor without a to/from rank");
1479        }
1480
1481        // Value is local, so we do not have to actually send it.
1482        if from_rank == to_rank {
1483            let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1484                .env
1485                .get(&tensor)
1486                .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1487            let output_cell = match input_cell {
1488                Ok(RValue::Tensor(input_cell)) => {
1489                    // We create a defensive copy here to prevent mutations on
1490                    // the input tensor from affecting output tensor.
1491                    // Should we copy if input ref == output ref?
1492                    // Should we support copy-on-write to avoid unnecessary copy?
1493                    let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1494                    let cloned = deep_clone(&borrow);
1495                    Ok(RValue::Tensor(TensorCell::new(cloned)))
1496                }
1497                Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1498                Err(err) => Err(err.clone()),
1499            };
1500            self.env.insert(result, output_cell);
1501            return Ok(());
1502        }
1503
1504        let mut messages = Vec::new();
1505
1506        if let Some(to_rank) = to_rank {
1507            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1508            messages.push(CommMessage::Send(
1509                input_cell,
1510                to_rank.try_into().unwrap(),
1511                self.cuda_stream()
1512                    .expect("tried to send_tensor on non-cuda stream")
1513                    .clone(),
1514                cx.open_once_port().0,
1515            ));
1516        }
1517
1518        if let Some(from_rank) = from_rank {
1519            let output_cell = TensorCell::new(factory_empty(
1520                &factory.size,
1521                factory.dtype,
1522                factory.layout,
1523                factory.device,
1524            ));
1525            messages.push(CommMessage::Recv(
1526                output_cell.clone(),
1527                from_rank.try_into().unwrap(),
1528                self.cuda_stream()
1529                    .expect("tried to send_tensor on non-cuda stream")
1530                    .clone(),
1531                cx.open_once_port().0,
1532            ));
1533            self.env.insert(result, Ok(output_cell.into()));
1534        }
1535
1536        comm.group(
1537            cx,
1538            messages,
1539            self.cuda_stream()
1540                .expect("tried to send_tensor on non-cuda stream")
1541                .clone(),
1542        )
1543        .await?;
1544        Ok(())
1545    }
1546
1547    async fn send_value(
1548        &mut self,
1549        cx: &Context<Self>,
1550        seq: Seq,
1551        worker_actor_id: ActorId,
1552        mutates: Vec<Ref>,
1553        function: Option<ResolvableFunction>,
1554        args: Vec<WireValue>,
1555        kwargs: HashMap<String, WireValue>,
1556        device_meshes: HashMap<Ref, DeviceMesh>,
1557    ) -> Result<()> {
1558        if self.respond_with_python_message {
1559            return self
1560                .send_value_python_message(cx, seq, mutates, function, args, kwargs, device_meshes)
1561                .await;
1562        }
1563        let result = if let Some(function) = function {
1564            // If a function was provided, use that to resolve the value.
1565            match function.as_torch_op() {
1566                Some((op, overload)) => {
1567                    self.call_torch_op(op, overload, args, kwargs)
1568                        .map(|rvalues| {
1569                            if rvalues.len() == 1 {
1570                                Ok(rvalues[0].clone().into())
1571                            } else {
1572                                // TODO: Replace with native pytrees when possible
1573                                Python::with_gil(|py| {
1574                                    Ok((|| {
1575                                        let py_rvalues = rvalues
1576                                            .into_iter()
1577                                            // SAFETY: This inherits the unsafety of `try_to_object_unsafe`.
1578                                            .map(|rvalue| unsafe {
1579                                                rvalue.try_to_object_unsafe(py)
1580                                            })
1581                                            .collect::<Result<Vec<_>, _>>()?;
1582                                        PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1583                                    })()
1584                                    .map_err(SerializablePyErr::from_fn(py))?)
1585                                })
1586                            }
1587                        })?
1588                }
1589                // Use block-in-place to allow nested callbacks to re-enter the
1590                // runtime to run async code.
1591                _ => tokio::task::block_in_place(|| {
1592                    self.call_python_fn_pytree(
1593                        cx,
1594                        function,
1595                        args,
1596                        kwargs,
1597                        &mutates,
1598                        device_meshes,
1599                        HashMap::new(),
1600                    )
1601                }),
1602            }
1603        } else {
1604            // If there's no function provided, there should be exactly one arg
1605            // and no kwargs.
1606            match (args.len(), kwargs.len()) {
1607                (1, 0) => Python::with_gil(|py| {
1608                    let arg = args[0]
1609                        .as_py_object()
1610                        .ok_or_else(|| {
1611                            CallFunctionError::UnsupportedArgType(
1612                                "send_value".to_string(),
1613                                "expected a PyObject as the first arg".to_string(),
1614                            )
1615                        })?
1616                        .unpickle(py)
1617                        .map_err(SerializablePyErr::from_fn(py))?;
1618                    arg.extract::<PyTree<PyObject>>()
1619                        .map_err(SerializablePyErr::from_fn(py))?
1620                        .try_into_map(|obj| {
1621                            let bound_obj = obj.bind(py);
1622                            if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1623                                self.ref_to_rvalue(&ref_)
1624                            } else {
1625                                Ok(bound_obj
1626                                    .extract::<RValue>()
1627                                    .map_err(SerializablePyErr::from_fn(py))?)
1628                            }
1629                        })
1630                }),
1631                _ => Err(CallFunctionError::TooManyArgsForValue(
1632                    format!("{:?}", args),
1633                    format!("{:?}", kwargs),
1634                )),
1635            }
1636        };
1637
1638        let value = match result {
1639            Ok(rvalue) => {
1640                // When returning a tensor, we copy out to decouple from the GPU,
1641                // as the worker will either serialize and send this to the controller
1642                // or to a pipe and we see hangs if it tries to pull from the GPU
1643                // in its thread.
1644                Ok(rvalue.into_map(|rval| match rval {
1645                    RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1646                    RValue::TensorList(tensors) => RValue::TensorList(
1647                        tensors
1648                            .into_iter()
1649                            .map(|tensor| tensor.try_cpu().unwrap())
1650                            .collect(),
1651                    ),
1652                    rval => rval,
1653                }))
1654            }
1655            Err(err) => {
1656                let err = self.report_seq_error(cx, seq, err).await?;
1657                for ref_ in mutates {
1658                    self.env.insert(ref_, Err(err.clone()));
1659                }
1660                Err(WorkerError {
1661                    backtrace: format!("{:?}", err),
1662                    worker_actor_id,
1663                })
1664            }
1665        };
1666
1667        // Actually send the value.
1668        let result = match value {
1669            Ok(value) => Ok(Serialized::serialize(&value).map_err(anyhow::Error::from)?),
1670            Err(e) => Err(e),
1671        };
1672        self.controller_actor.fetch_result(cx, seq, result).await?;
1673
1674        Ok(())
1675    }
1676
1677    async fn send_result_of_actor_call(
1678        &mut self,
1679        cx: &Context<Self>,
1680        params: ActorCallParams,
1681    ) -> anyhow::Result<()> {
1682        let seq = params.seq;
1683        let mutates = params.mutates.clone();
1684        self.try_define(cx, seq, vec![], &mutates, async |self| {
1685            let value = self.call_actor(cx, params).await?;
1686            let result =
1687                Python::with_gil(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1688            let result = Serialized::serialize(&result).unwrap();
1689            self.controller_actor
1690                .fetch_result(cx, seq, Ok(result))
1691                .await?;
1692            Ok(vec![])
1693        })
1694        .await
1695    }
1696
1697    async fn call_actor_method(
1698        &mut self,
1699        cx: &Context<Self>,
1700        params: ActorMethodParams,
1701    ) -> anyhow::Result<()> {
1702        let seq = params.call.seq;
1703        let mutates = params.call.mutates.clone();
1704        self.try_define(cx, seq, params.results, &mutates, async |self| {
1705            let result = self.call_actor(cx, params.call).await?;
1706            let result = Python::with_gil(|py| {
1707                PyTree::<RValue>::extract_bound(&result.into_bound(py))
1708                    .map_err(SerializablePyErr::from_fn(py))
1709            })?;
1710            Ok(result.into_leaves())
1711        })
1712        .await
1713    }
1714
1715    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1716        if self.active_recording.is_some() {
1717            bail!("different recording already active");
1718        }
1719        match self.recordings.entry(recording) {
1720            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1721            Entry::Vacant(entry) => entry.insert(Recording::new()),
1722        };
1723        self.active_recording = Some(RecordingState::Defining {
1724            recording,
1725            defined_borrows: HashSet::new(),
1726        });
1727        Ok(())
1728    }
1729
1730    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1731        match self.active_recording {
1732            Some(RecordingState::Defining {
1733                recording: active_recording,
1734                ref defined_borrows,
1735            }) if active_recording == recording => {
1736                ensure!(
1737                    defined_borrows.is_empty(),
1738                    "all borrows created within recording must be dropped within recording"
1739                );
1740                self.active_recording = None;
1741            }
1742            _ => bail!("cannot finalize recording that isn't active"),
1743        }
1744        Ok(())
1745    }
1746
1747    async fn recording_formal(
1748        &mut self,
1749        _cx: &Context<Self>,
1750        result: Ref,
1751        argument_index: usize,
1752    ) -> Result<()> {
1753        match self.get_defining_recording() {
1754            Some((recording, _)) => {
1755                recording.messages.push(StreamMessage::RecordingFormal {
1756                    result,
1757                    argument_index,
1758                });
1759            }
1760            None => bail!("recording_formal called outside of recording"),
1761        };
1762        Ok(())
1763    }
1764
1765    async fn recording_result(
1766        &mut self,
1767        _cx: &Context<Self>,
1768        result: Ref,
1769        output_index: usize,
1770    ) -> Result<()> {
1771        match self.get_defining_recording() {
1772            Some((recording, _)) => {
1773                recording.messages.push(StreamMessage::RecordingResult {
1774                    result,
1775                    output_index,
1776                });
1777            }
1778            None => bail!("recording_result called outside of recording"),
1779        };
1780        Ok(())
1781    }
1782
1783    async fn call_recording(
1784        &mut self,
1785        cx: &Context<Self>,
1786        seq: Seq,
1787        recording: Ref,
1788        results: Vec<Ref>,
1789        actuals: Vec<Ref>,
1790    ) -> Result<()> {
1791        if self.active_recording.is_some() {
1792            bail!("cannot call recording while another recording is active");
1793        }
1794
1795        let messages = match self.recordings.get(&recording) {
1796            Some(recording) => recording
1797                .messages
1798                .iter()
1799                .map(|message| message.clone_for_recording())
1800                .collect::<Vec<_>>(),
1801            None => bail!("recording {:?} not found", recording),
1802        };
1803
1804        self.active_recording = Some(RecordingState::Running);
1805
1806        // Global error for all messages in the recording. The first time a message
1807        // fails in the recording, we set the error. We then need to propagate this
1808        // error to all of the refs mutated by the entire recording, as well as the
1809        // result refs.
1810        let mut error: Option<Arc<SeqError>> = None;
1811        // The set of all refs defined by this recording (excluding "results"),
1812        // which we need to ensure are deleted when the recording is done executing.
1813        let mut all_defined_refs = HashSet::new();
1814        // The set of all refs mutated by this recording. If there is an error with
1815        // any message, all of these refs need to have the correct error set.
1816        let mut all_mutated_refs = HashSet::new();
1817        // Map from the result ref of a RecordingFormal message to the associated
1818        // actual ref from "actuals". We need to track this in order to properly
1819        // handle recordings that mutate refs contained in "actuals" -- every
1820        // message in the recording that interacts with the recording inputs will
1821        // interact with the formal ref rather than the actual ref.
1822        let mut formal_to_actual_refs = HashMap::new();
1823        // clear any pre-existing error messages before recording started
1824        self.last_seq_error = None;
1825        for message in messages.into_iter() {
1826            let defined_refs = message.get_defined_refs();
1827            all_defined_refs.extend(defined_refs.clone());
1828
1829            let mutated_refs_with_formals = message.get_mutated_refs();
1830            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1831                match formal_to_actual_refs.get(ref_) {
1832                    Some(actual_ref) => Some(*actual_ref),
1833                    None => {
1834                        if all_defined_refs.contains(ref_) {
1835                            None
1836                        } else {
1837                            Some(*ref_)
1838                        }
1839                    }
1840                }
1841            }));
1842
1843            match message {
1844                StreamMessage::RecordingFormal {
1845                    result: formal_ref,
1846                    argument_index,
1847                } => match actuals.get(argument_index) {
1848                    None => bail!("recording_formal called with too few arguments"),
1849                    Some(actual_ref) => {
1850                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
1851                        self.define_ref(formal_ref, *actual_ref)?;
1852                    }
1853                },
1854                StreamMessage::RecordingResult {
1855                    result: result_ref,
1856                    output_index,
1857                } => match results.get(output_index) {
1858                    None => bail!("recording_result called with too few results"),
1859                    Some(actual_result_ref) => {
1860                        self.define_ref(*actual_result_ref, result_ref)?;
1861                    }
1862                },
1863                StreamMessage::DeleteRefs(ref refs) => {
1864                    for ref_ in refs {
1865                        all_defined_refs.remove(ref_);
1866                    }
1867                    StreamMessageHandler::handle(self, cx, message).await?;
1868                }
1869                StreamMessage::CallFunction { .. } if error.is_some() => {
1870                    // CallFunction is expensive. If the recording already failed, then
1871                    // just update the necessary refs with the error. Most of the other
1872                    // message types need to run regardless because there are other actors
1873                    // that expect the call to happen (e.g., all of the borrow messages,
1874                    // pipe send/recv, send_tensor, reduce, etc.).
1875                    let error = error.clone().unwrap();
1876                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1877                        self.env.insert(*ref_, Err(error.clone()));
1878                    }
1879                }
1880                StreamMessage::BorrowLastUse { ref result, .. } => {
1881                    all_defined_refs.remove(result);
1882                    StreamMessageHandler::handle(self, cx, message).await?;
1883                }
1884                StreamMessage::Reduce {
1885                    local_tensor,
1886                    ref out,
1887                    ..
1888                } => {
1889                    // Reduce doesn't propagate errors to the result ref, so we need
1890                    // to check for existing errors on the input tensors and set the
1891                    // recording's error if necessary.
1892                    if error.is_none() {
1893                        let inputs_to_check = [Some(local_tensor), out.clone()]
1894                            .iter()
1895                            .filter_map(|r| *r)
1896                            .collect::<Vec<_>>();
1897                        error = self.get_first_error(inputs_to_check.as_slice())?;
1898                    }
1899                    StreamMessageHandler::handle(self, cx, message).await?;
1900                }
1901                StreamMessage::SendTensor {
1902                    ref tensor,
1903                    ref to_rank,
1904                    ..
1905                } => {
1906                    // If this rank is sending a tensor (e.g., to_rank has a value),
1907                    // we need to check for existing errors on the input tensor, because
1908                    // the error is only propagated to the result ref when this rank
1909                    // is also receiving a tensor.
1910                    if to_rank.is_some() && error.is_none() {
1911                        error = self.get_first_error(&[*tensor])?;
1912                    }
1913                    StreamMessageHandler::handle(self, cx, message).await?;
1914                }
1915                _ => {
1916                    StreamMessageHandler::handle(self, cx, message).await?;
1917                }
1918            };
1919
1920            // It's not entirely trivial to determine whether a message "failed" or not.
1921            // For example, the CallFunction message can return Ok(..) if there is an error
1922            // in the underlying function call. But in that case, we would still want to
1923            // consider the recording call as "failed". Unlike in python, where we can just
1924            // wrap everything in try-except, in rust, we keep track of the last report SeqError, which
1925            // we clear before handling each recording message. If we see it is set, the
1926            // we know the recording has faild.
1927            match (&error, self.last_seq_error.take()) {
1928                (None, Some(seq_err)) => {
1929                    // Report failure to the controller.
1930                    self.controller_actor
1931                        .remote_function_failed(
1932                            cx,
1933                            seq,
1934                            WorkerError {
1935                                backtrace: format!("recording failed: {}", &seq_err),
1936                                worker_actor_id: cx.self_id().clone(),
1937                            },
1938                        )
1939                        .await?;
1940                    error = Some(seq_err)
1941                }
1942                _ => {}
1943            }
1944            // Continue processing the remaining stream messages regardless of error.
1945            // We need to do this partially for error propagation, but also because
1946            // certain messages (like borrows and reductions) need to run regardless
1947            // in order to prevent deadlocks.
1948        }
1949
1950        // Delete the formal refs and some subset of the RecordingResult refs. The
1951        // controller should have generated DeleteRefs messages for all other refs
1952        // defined by the recording.
1953        StreamMessageHandler::handle(
1954            self,
1955            cx,
1956            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1957        )
1958        .await?;
1959
1960        // Any refs mutated by the recording and all results should have the same error
1961        // (the original error that caused the recording to fail).
1962        if error.is_some() {
1963            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1964                self.env.insert(*ref_, Err(error.clone().unwrap()));
1965            }
1966        }
1967
1968        self.active_recording = None;
1969        Ok(())
1970    }
1971
1972    async fn set_ref_unit_tests_only(
1973        &mut self,
1974        _cx: &Context<Self>,
1975        reference: Ref,
1976        value: WireValue,
1977    ) -> Result<()> {
1978        self.env
1979            .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
1980        Ok(())
1981    }
1982
1983    async fn set_tensor_ref_unit_tests_only(
1984        &mut self,
1985        _cx: &Context<Self>,
1986        reference: Ref,
1987        tensor_result: TensorCellResult,
1988    ) -> Result<()> {
1989        match tensor_result {
1990            Ok(tensor_cell) => {
1991                self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
1992            }
1993            Err(err) => {
1994                self.env.insert(reference, Err(err));
1995            }
1996        }
1997        Ok(())
1998    }
1999
2000    async fn get_ref_unit_tests_only(
2001        &mut self,
2002        _cx: &Context<Self>,
2003        reference: Ref,
2004    ) -> Result<Option<Result<WireValue, String>>> {
2005        /// For testing only, doesn't support Tensor or TensorList.
2006        fn rvalue_to_wire(
2007            value: Result<RValue, Arc<SeqError>>,
2008        ) -> Result<WireValue, Arc<SeqError>> {
2009            Ok(match value? {
2010                RValue::Int(val) => WireValue::Int(val),
2011                RValue::IntList(val) => WireValue::IntList(val),
2012                RValue::Double(val) => WireValue::Double(val),
2013                RValue::Bool(val) => WireValue::Bool(val),
2014                RValue::String(val) => WireValue::String(val),
2015                RValue::Layout(val) => WireValue::Layout(val),
2016                RValue::Device(val) => WireValue::Device(val),
2017                RValue::ScalarType(val) => WireValue::ScalarType(val),
2018                RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2019                RValue::None => WireValue::None(()),
2020                other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2021            })
2022        }
2023        Ok(self
2024            .env
2025            .get(&reference)
2026            .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2027    }
2028
2029    async fn get_tensor_ref_unit_tests_only(
2030        &mut self,
2031        _cx: &Context<Self>,
2032        reference: Ref,
2033    ) -> Result<Option<TensorCellResult>> {
2034        match self.env.get(&reference) {
2035            Some(Ok(rvalue)) => match rvalue {
2036                RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2037                other => bail!("expected tensor, got {:?}", other),
2038            },
2039            Some(Err(err)) => Ok(Some(Err(err.clone()))),
2040            None => Ok(None),
2041        }
2042    }
2043}
2044
2045#[cfg(test)]
2046mod tests {
2047    use hyperactor::actor::ActorStatus;
2048    use hyperactor::context;
2049    use hyperactor::supervision::ActorSupervisionEvent;
2050    use monarch_messages::controller::ControllerMessage;
2051    use monarch_messages::worker::StreamCreationMode;
2052    use monarch_types::PickledPyObject;
2053    use pyo3::IntoPyObjectExt;
2054    use timed_test::async_timed_test;
2055    use torch_sys::factory_float_tensor;
2056    use torch_sys::testing::allclose;
2057    use torch_sys_cuda::nccl::UniqueId;
2058
2059    use super::*;
2060    use crate::comm::CommParams;
2061    use crate::test_util;
2062
2063    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2064        Arc::new(SeqError {
2065            seq: 0.into(),
2066            error: err,
2067        })
2068    }
2069
2070    struct TestSetup {
2071        proc: Proc,
2072        stream_actor: ActorHandle<StreamActor>,
2073        client: Instance<()>,
2074        // Unused, but necessary, because proc needs a supervision
2075        // port -- otherwise an actor failure will cause a crash.
2076        #[allow(dead_code)]
2077        supervision_rx: PortReceiver<ActorSupervisionEvent>,
2078        controller_rx: PortReceiver<ControllerMessage>,
2079        controller_actor: ActorRef<ControllerActor>,
2080        next_ref: Ref,
2081    }
2082
2083    impl TestSetup {
2084        async fn new() -> Result<Self> {
2085            Self::new_with_world_size(1).await
2086        }
2087
2088        async fn new_with_world_size(world_size: usize) -> Result<Self> {
2089            test_util::test_setup()?;
2090
2091            let proc = Proc::local();
2092            let (_, controller_actor, controller_rx) =
2093                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2094            let (client, _handle) = proc.instance("client")?;
2095            let (supervision_tx, supervision_rx) = client.open_port();
2096            proc.set_supervision_coordinator(supervision_tx)?;
2097            let stream_actor = proc
2098                .spawn::<StreamActor>(
2099                    "stream",
2100                    StreamParams {
2101                        world_size,
2102                        rank: 0,
2103                        creation_mode: StreamCreationMode::UseDefaultStream,
2104                        id: 0.into(),
2105                        device: Some(CudaDevice::new(0.into())),
2106                        controller_actor: controller_actor.clone(),
2107                        respond_with_python_message: false,
2108                    },
2109                )
2110                .await?;
2111
2112            Ok(Self {
2113                proc,
2114                stream_actor,
2115                client,
2116                supervision_rx,
2117                controller_rx,
2118                controller_actor,
2119                next_ref: 0.into(),
2120            })
2121        }
2122
2123        fn next_ref(&mut self) -> Ref {
2124            let ref_ = self.next_ref;
2125            self.next_ref = Ref {
2126                id: self.next_ref.id + 1,
2127            };
2128            ref_
2129        }
2130
2131        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2132            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2133            self.stream_actor
2134                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2135                .await
2136        }
2137
2138        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2139            let actual = self
2140                .stream_actor
2141                .get_tensor_ref_unit_tests_only(&self.client, reference)
2142                .await
2143                .unwrap()
2144                .unwrap()
2145                .unwrap();
2146
2147            let result = allclose(
2148                &factory_float_tensor(data, "cpu".try_into().unwrap()),
2149                &actual.borrow(),
2150            )
2151            .unwrap();
2152            // rustfmt-ignore
2153            result
2154        }
2155
2156        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2157            let result_error = self
2158                .stream_actor
2159                .get_tensor_ref_unit_tests_only(&self.client, reference)
2160                .await
2161                .unwrap()
2162                .unwrap()
2163                .unwrap_err();
2164
2165            assert!(Arc::ptr_eq(&result_error, &error));
2166        }
2167    }
2168
2169    async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2170        loop {
2171            let status = proc
2172                .ledger_snapshot()
2173                .roots
2174                .get(actor_id)
2175                .unwrap()
2176                .status
2177                .clone();
2178            if let ActorStatus::Failed(msg) = status {
2179                assert!(msg.to_string().contains(&expected_msg));
2180                break;
2181            } else {
2182                tokio::task::yield_now().await;
2183            }
2184        }
2185    }
2186
2187    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2188        for ref_ in refs {
2189            assert!(
2190                test_setup
2191                    .stream_actor
2192                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2193                    .await
2194                    .unwrap()
2195                    .is_none()
2196            );
2197        }
2198    }
2199
2200    async fn fetch_result(
2201        cx: &impl context::Actor,
2202        stream_actor: ActorHandle<StreamActor>,
2203        seq: Seq,
2204        reference: Ref,
2205    ) {
2206        let ref_to_send = Python::with_gil(|py| {
2207            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2208        });
2209
2210        stream_actor
2211            .send_value(
2212                cx,
2213                seq,
2214                stream_actor.actor_id().clone(),
2215                Vec::new(),
2216                None,
2217                vec![WireValue::PyObject(ref_to_send)],
2218                HashMap::new(),
2219                HashMap::new(),
2220            )
2221            .await
2222            .unwrap()
2223    }
2224
2225    async fn check_fetch_result_error(
2226        cx: &impl context::Actor,
2227        stream_actor: ActorHandle<StreamActor>,
2228        seq: Seq,
2229        reference: Ref,
2230        controller_rx: &mut PortReceiver<ControllerMessage>,
2231        expected_backtrace: &str,
2232    ) {
2233        fetch_result(cx, stream_actor, seq, reference).await;
2234
2235        let controller_msg = controller_rx.recv().await.unwrap();
2236        match controller_msg {
2237            ControllerMessage::FetchResult {
2238                seq: actual_seq,
2239                value: Err(err),
2240            } => {
2241                assert_eq!(actual_seq, seq);
2242                assert!(
2243                    err.backtrace.contains(expected_backtrace),
2244                    "backtrace did not contain {:?}: {:?}",
2245                    expected_backtrace,
2246                    err.backtrace
2247                );
2248            }
2249            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2250        };
2251    }
2252
2253    async fn check_fetch_result_value(
2254        cx: &impl context::Actor,
2255        stream_actor: ActorHandle<StreamActor>,
2256        seq: Seq,
2257        reference: Ref,
2258        controller_rx: &mut PortReceiver<ControllerMessage>,
2259    ) {
2260        fetch_result(cx, stream_actor, seq, reference).await;
2261
2262        let controller_msg = controller_rx.recv().await.unwrap();
2263        match controller_msg {
2264            ControllerMessage::FetchResult {
2265                value: Ok(_),
2266                seq: actual_seq,
2267            } => assert_eq!(seq, actual_seq),
2268            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2269        };
2270    }
2271
2272    #[async_timed_test(timeout_secs = 60)]
2273    async fn test_define_recording_other_recording_active() -> Result<()> {
2274        let test_setup = TestSetup::new().await?;
2275        test_setup
2276            .stream_actor
2277            .define_recording(&test_setup.client, 0.into())
2278            .await?;
2279        test_setup
2280            .stream_actor
2281            .define_recording(&test_setup.client, 1.into())
2282            .await?;
2283        assert_actor_failed_with_msg(
2284            &test_setup.proc,
2285            test_setup.stream_actor.actor_id(),
2286            "different recording already active".into(),
2287        )
2288        .await;
2289        Ok(())
2290    }
2291
2292    #[async_timed_test(timeout_secs = 60)]
2293    async fn test_define_recording_already_defined() -> Result<()> {
2294        let test_setup = TestSetup::new().await?;
2295        test_setup
2296            .stream_actor
2297            .define_recording(&test_setup.client, 0.into())
2298            .await?;
2299        test_setup
2300            .stream_actor
2301            .finalize_recording(&test_setup.client, 0.into())
2302            .await?;
2303        test_setup
2304            .stream_actor
2305            .define_recording(&test_setup.client, 0.into())
2306            .await?;
2307        assert_actor_failed_with_msg(
2308            &test_setup.proc,
2309            test_setup.stream_actor.actor_id(),
2310            "already defined".into(),
2311        )
2312        .await;
2313        Ok(())
2314    }
2315
2316    #[async_timed_test(timeout_secs = 60)]
2317    async fn test_finalize_recording_other_recording_active() -> Result<()> {
2318        let test_setup = TestSetup::new().await?;
2319        test_setup
2320            .stream_actor
2321            .define_recording(&test_setup.client, 0.into())
2322            .await?;
2323        test_setup
2324            .stream_actor
2325            .finalize_recording(&test_setup.client, 1.into())
2326            .await?;
2327        assert_actor_failed_with_msg(
2328            &test_setup.proc,
2329            test_setup.stream_actor.actor_id(),
2330            "cannot finalize recording that isn't active".into(),
2331        )
2332        .await;
2333        Ok(())
2334    }
2335
2336    #[async_timed_test(timeout_secs = 60)]
2337    async fn test_recording_formal_outside_recording() -> Result<()> {
2338        let test_setup = TestSetup::new().await?;
2339        test_setup
2340            .stream_actor
2341            .recording_formal(&test_setup.client, 0.into(), 0)
2342            .await?;
2343        assert_actor_failed_with_msg(
2344            &test_setup.proc,
2345            test_setup.stream_actor.actor_id(),
2346            "recording_formal called outside of recording".into(),
2347        )
2348        .await;
2349        Ok(())
2350    }
2351
2352    #[async_timed_test(timeout_secs = 60)]
2353    async fn test_recording_result_outside_recording() -> Result<()> {
2354        let test_setup = TestSetup::new().await?;
2355        test_setup
2356            .stream_actor
2357            .recording_result(&test_setup.client, 0.into(), 0)
2358            .await?;
2359        assert_actor_failed_with_msg(
2360            &test_setup.proc,
2361            test_setup.stream_actor.actor_id(),
2362            "recording_result called outside of recording".into(),
2363        )
2364        .await;
2365        Ok(())
2366    }
2367
2368    #[async_timed_test(timeout_secs = 60)]
2369    async fn test_call_recording_other_recording_active() -> Result<()> {
2370        let test_setup = TestSetup::new().await?;
2371        test_setup
2372            .stream_actor
2373            .define_recording(&test_setup.client, 0.into())
2374            .await?;
2375        test_setup
2376            .stream_actor
2377            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2378            .await?;
2379        assert_actor_failed_with_msg(
2380            &test_setup.proc,
2381            test_setup.stream_actor.actor_id(),
2382            "cannot call recording while another recording is active".into(),
2383        )
2384        .await;
2385        Ok(())
2386    }
2387
2388    #[async_timed_test(timeout_secs = 60)]
2389    async fn test_call_recording_not_found() -> Result<()> {
2390        let test_setup = TestSetup::new().await?;
2391        test_setup
2392            .stream_actor
2393            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2394            .await?;
2395        assert_actor_failed_with_msg(
2396            &test_setup.proc,
2397            test_setup.stream_actor.actor_id(),
2398            "not found".into(),
2399        )
2400        .await;
2401        Ok(())
2402    }
2403
2404    #[async_timed_test(timeout_secs = 60)]
2405    async fn test_recording_formal_too_few_arguments() -> Result<()> {
2406        let test_setup = TestSetup::new().await?;
2407
2408        test_setup
2409            .stream_actor
2410            .define_recording(&test_setup.client, 0.into())
2411            .await?;
2412
2413        test_setup
2414            .stream_actor
2415            .recording_formal(&test_setup.client, 1.into(), 0)
2416            .await?;
2417
2418        test_setup
2419            .stream_actor
2420            .finalize_recording(&test_setup.client, 0.into())
2421            .await?;
2422
2423        test_setup
2424            .stream_actor
2425            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2426            .await?;
2427
2428        assert_actor_failed_with_msg(
2429            &test_setup.proc,
2430            test_setup.stream_actor.actor_id(),
2431            "recording_formal called with too few arguments".into(),
2432        )
2433        .await;
2434        Ok(())
2435    }
2436
2437    #[async_timed_test(timeout_secs = 60)]
2438    async fn test_recording_result_too_few_results() -> Result<()> {
2439        let test_setup = TestSetup::new().await?;
2440
2441        test_setup
2442            .stream_actor
2443            .define_recording(&test_setup.client, 0.into())
2444            .await?;
2445
2446        test_setup
2447            .stream_actor
2448            .recording_result(&test_setup.client, 1.into(), 0)
2449            .await?;
2450
2451        test_setup
2452            .stream_actor
2453            .finalize_recording(&test_setup.client, 0.into())
2454            .await?;
2455
2456        test_setup
2457            .stream_actor
2458            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2459            .await?;
2460
2461        assert_actor_failed_with_msg(
2462            &test_setup.proc,
2463            test_setup.stream_actor.actor_id(),
2464            "recording_result called with too few results".into(),
2465        )
2466        .await;
2467        Ok(())
2468    }
2469
2470    #[async_timed_test(timeout_secs = 60)]
2471    async fn test_basic_call_recording() -> Result<()> {
2472        let mut test_setup = TestSetup::new().await?;
2473
2474        // Define a recording equivalent to:
2475        // def f(x, y):
2476        //   return y, x
2477        test_setup
2478            .stream_actor
2479            .define_recording(&test_setup.client, 0.into())
2480            .await?;
2481
2482        let formal0_ref = 1.into();
2483        let formal0_index = 1;
2484        test_setup
2485            .stream_actor
2486            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2487            .await?;
2488
2489        let formal1_ref = 2.into();
2490        let formal1_index = 0;
2491        test_setup
2492            .stream_actor
2493            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2494            .await?;
2495
2496        let result0_ref = formal0_ref;
2497        let result0_index = 0;
2498        test_setup
2499            .stream_actor
2500            .recording_result(&test_setup.client, result0_ref, result0_index)
2501            .await?;
2502
2503        let result1_ref = formal1_ref;
2504        let result1_index = 1;
2505        test_setup
2506            .stream_actor
2507            .recording_result(&test_setup.client, result1_ref, result1_index)
2508            .await?;
2509
2510        test_setup
2511            .stream_actor
2512            .finalize_recording(&test_setup.client, 0.into())
2513            .await?;
2514
2515        let actual0_ref = 3.into();
2516        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2517
2518        let actual1_ref = 4.into();
2519        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2520
2521        // Call the recording with valid tensors for the actual inputs,
2522        // and store the results in refs 5 and 6.
2523        let actual_result0_ref = 5.into();
2524        let actual_result1_ref = 6.into();
2525        test_setup
2526            .stream_actor
2527            .call_recording(
2528                &test_setup.client,
2529                0.into(),
2530                0.into(),
2531                vec![actual_result0_ref, actual_result1_ref],
2532                vec![actual0_ref, actual1_ref],
2533            )
2534            .await?;
2535
2536        // Ensure the results are correct.
2537        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2538        assert!(
2539            test_setup
2540                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2541                .await
2542        );
2543
2544        // Ensure the temporary refs associated with the formals/results have
2545        // been deleted.
2546        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2547        Ok(())
2548    }
2549
2550    #[async_timed_test(timeout_secs = 60)]
2551    async fn test_request_status_in_recording() -> Result<()> {
2552        let test_setup = TestSetup::new().await?;
2553        test_setup
2554            .stream_actor
2555            .define_recording(&test_setup.client, 0.into())
2556            .await?;
2557        test_setup
2558            .stream_actor
2559            .request_status(&test_setup.client)
2560            .await
2561            .expect_err("request_status should have failed");
2562        assert_actor_failed_with_msg(
2563            &test_setup.proc,
2564            test_setup.stream_actor.actor_id(),
2565            "request_status not allowed in recording".into(),
2566        )
2567        .await;
2568        Ok(())
2569    }
2570
2571    #[async_timed_test(timeout_secs = 60)]
2572    async fn test_init_comm_in_recording() -> Result<()> {
2573        let test_setup = TestSetup::new().await?;
2574        test_setup
2575            .stream_actor
2576            .define_recording(&test_setup.client, 0.into())
2577            .await?;
2578
2579        let dummy_comm = test_setup
2580            .proc
2581            .spawn::<NcclCommActor>(
2582                "comm",
2583                CommParams::New {
2584                    device: CudaDevice::new(0.into()),
2585                    unique_id: UniqueId::new()?,
2586                    world_size: 1,
2587                    rank: 0,
2588                },
2589            )
2590            .await?;
2591
2592        test_setup
2593            .stream_actor
2594            .init_comm(&test_setup.client, dummy_comm)
2595            .await?;
2596        assert_actor_failed_with_msg(
2597            &test_setup.proc,
2598            test_setup.stream_actor.actor_id(),
2599            "init_comm not allowed in recording".into(),
2600        )
2601        .await;
2602        Ok(())
2603    }
2604
2605    #[async_timed_test(timeout_secs = 60)]
2606    async fn test_call_function_in_recording() -> Result<()> {
2607        let mut test_setup = TestSetup::new().await?;
2608
2609        // Define a recording equivalent to:
2610        // def f(x, y):
2611        //   w = x + y
2612        //   nonlocal z
2613        //   z.add_(1.0)
2614        //   return w + z
2615        test_setup
2616            .stream_actor
2617            .define_recording(&test_setup.client, 0.into())
2618            .await?;
2619
2620        let formal0_ref = test_setup.next_ref();
2621        let formal0_index = 0;
2622        test_setup
2623            .stream_actor
2624            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2625            .await?;
2626
2627        let formal1_ref = test_setup.next_ref();
2628        let formal1_index = 1;
2629        test_setup
2630            .stream_actor
2631            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2632            .await?;
2633
2634        let captured_ref = test_setup.next_ref();
2635        let result_captured_ref = test_setup.next_ref();
2636        let add_one_function =
2637            ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2638        let add_tensors_function =
2639            ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2640
2641        let add_result_ref_0 = test_setup.next_ref();
2642        test_setup
2643            .stream_actor
2644            .call_function(
2645                &test_setup.client,
2646                CallFunctionParams {
2647                    seq: 100.into(),
2648                    function: add_tensors_function.clone(),
2649                    args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2650                    kwargs: HashMap::new(),
2651                    results: vec![Some(add_result_ref_0)],
2652                    mutates: vec![],
2653                    stream: 0.into(),
2654                    remote_process_groups: Vec::new(),
2655                },
2656                HashMap::new(),
2657                HashMap::new(),
2658            )
2659            .await?;
2660
2661        test_setup
2662            .stream_actor
2663            .call_function(
2664                &test_setup.client,
2665                CallFunctionParams {
2666                    seq: 101.into(),
2667                    function: add_one_function,
2668                    args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2669                    kwargs: HashMap::new(),
2670                    results: vec![Some(result_captured_ref)],
2671                    mutates: vec![captured_ref],
2672                    stream: 0.into(),
2673                    remote_process_groups: Vec::new(),
2674                },
2675                HashMap::new(),
2676                HashMap::new(),
2677            )
2678            .await?;
2679
2680        let add_result_ref_1 = test_setup.next_ref();
2681        test_setup
2682            .stream_actor
2683            .call_function(
2684                &test_setup.client,
2685                CallFunctionParams {
2686                    seq: 102.into(),
2687                    function: add_tensors_function,
2688                    args: vec![
2689                        WireValue::Ref(add_result_ref_0),
2690                        WireValue::Ref(captured_ref),
2691                    ],
2692                    kwargs: HashMap::new(),
2693                    results: vec![Some(add_result_ref_1)],
2694                    mutates: vec![],
2695                    stream: 0.into(),
2696                    remote_process_groups: Vec::new(),
2697                },
2698                HashMap::new(),
2699                HashMap::new(),
2700            )
2701            .await?;
2702
2703        test_setup
2704            .stream_actor
2705            .recording_result(&test_setup.client, add_result_ref_1, 0)
2706            .await?;
2707
2708        test_setup
2709            .stream_actor
2710            .delete_refs(
2711                &test_setup.client,
2712                vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2713            )
2714            .await?;
2715
2716        test_setup
2717            .stream_actor
2718            .finalize_recording(&test_setup.client, 0.into())
2719            .await?;
2720
2721        let actual0_ref = test_setup.next_ref();
2722        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2723
2724        let actual1_ref = test_setup.next_ref();
2725        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2726
2727        test_setup
2728            .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2729            .await?;
2730
2731        let actual_result_ref = test_setup.next_ref();
2732        test_setup
2733            .stream_actor
2734            .call_recording(
2735                &test_setup.client,
2736                0.into(),
2737                0.into(),
2738                vec![actual_result_ref],
2739                vec![actual0_ref, actual1_ref],
2740            )
2741            .await?;
2742
2743        assert!(
2744            test_setup
2745                .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2746                .await
2747        );
2748
2749        // Set actual1_tensor to a bad shape which will cause the recording to fail.
2750        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2751
2752        let actual_result_ref = test_setup.next_ref();
2753        test_setup
2754            .stream_actor
2755            .call_recording(
2756                &test_setup.client,
2757                1.into(),
2758                0.into(),
2759                vec![actual_result_ref],
2760                vec![actual0_ref, actual1_ref],
2761            )
2762            .await?;
2763
2764        // Both inputs should still be valid.
2765        for ref_ in [actual0_ref, actual1_ref] {
2766            let _ = test_setup
2767                .stream_actor
2768                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2769                .await?
2770                .unwrap()
2771                .unwrap();
2772        }
2773
2774        for ref_ in [captured_ref, actual_result_ref] {
2775            let result_error = test_setup
2776                .stream_actor
2777                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2778                .await?
2779                .unwrap()
2780                .unwrap_err();
2781            // Check that the error contains the expected strings
2782            let error_str = result_error.to_string();
2783            assert!(
2784                error_str.contains("torch operator error"),
2785                "Error should contain 'torch operator failed': {}",
2786                error_str
2787            );
2788        }
2789
2790        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2791        match controller_msg {
2792            ControllerMessage::RemoteFunctionFailed { seq, error } => {
2793                assert_eq!(seq, 1.into());
2794                assert!(
2795                    error.backtrace.contains("torch operator error"),
2796                    "Unexpected WorkerError: {:?}",
2797                    error
2798                );
2799            }
2800            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2801        };
2802
2803        // Reset input tensor to a valid shape.
2804        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2805
2806        // captured_tensor should still have an error, so calling
2807        // the recording should set DependentErrors and not report
2808        // anything to the controller.
2809        let actual_result_ref = test_setup.next_ref();
2810        test_setup
2811            .stream_actor
2812            .call_recording(
2813                &test_setup.client,
2814                2.into(),
2815                0.into(),
2816                vec![actual_result_ref],
2817                vec![actual0_ref, actual1_ref],
2818            )
2819            .await?;
2820
2821        // Both inputs should still be valid.
2822        for ref_ in [actual0_ref, actual1_ref] {
2823            let _ = test_setup
2824                .stream_actor
2825                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2826                .await?
2827                .unwrap()
2828                .unwrap();
2829        }
2830
2831        for ref_ in [captured_ref, actual_result_ref] {
2832            let result_error = test_setup
2833                .stream_actor
2834                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2835                .await?
2836                .unwrap()
2837                .unwrap_err();
2838            // Check that the error contains the expected strings
2839            let error_str = result_error.to_string();
2840            assert!(
2841                error_str.contains("torch operator error"),
2842                "Error should contain input error: {}",
2843                error_str
2844            );
2845        }
2846
2847        // This tests that the DependentError was never reported to the controller.
2848        // If it were reported to the controller, the next message would match
2849        // RemoteFunctionFailed instead of FetchResult.
2850        check_fetch_result_error(
2851            &test_setup.client,
2852            test_setup.stream_actor.clone(),
2853            3.into(),
2854            captured_ref,
2855            &mut test_setup.controller_rx,
2856            "torch operator error",
2857        )
2858        .await;
2859
2860        Ok(())
2861    }
2862
2863    #[async_timed_test(timeout_secs = 60)]
2864    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2865        let mut test_setup = TestSetup::new().await?;
2866        test_setup
2867            .stream_actor
2868            .define_recording(&test_setup.client, 0.into())
2869            .await?;
2870
2871        let borrow_id = 1;
2872        let tensor_ref = test_setup.next_ref();
2873        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2874
2875        test_setup
2876            .stream_actor
2877            .borrow_create(
2878                &test_setup.client,
2879                borrow_id,
2880                tensor_ref,
2881                first_use_sender.clone(),
2882            )
2883            .await?;
2884
2885        test_setup
2886            .stream_actor
2887            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2888            .await?;
2889
2890        assert_actor_failed_with_msg(
2891            &test_setup.proc,
2892            test_setup.stream_actor.actor_id(),
2893            "duplicate borrow create in recording".into(),
2894        )
2895        .await;
2896
2897        Ok(())
2898    }
2899
2900    #[async_timed_test(timeout_secs = 60)]
2901    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2902        let test_setup = TestSetup::new().await?;
2903        test_setup
2904            .stream_actor
2905            .define_recording(&test_setup.client, 0.into())
2906            .await?;
2907
2908        let borrow_id = 1;
2909        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2910
2911        test_setup
2912            .stream_actor
2913            .borrow_drop(
2914                &test_setup.client,
2915                borrow_id,
2916                Arc::new(Mutex::new(last_use_receiver)),
2917            )
2918            .await?;
2919
2920        assert_actor_failed_with_msg(
2921            &test_setup.proc,
2922            test_setup.stream_actor.actor_id(),
2923            "borrow drop for borrow not defined in recording".into(),
2924        )
2925        .await;
2926
2927        Ok(())
2928    }
2929
2930    #[async_timed_test(timeout_secs = 60)]
2931    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2932        let mut test_setup = TestSetup::new().await?;
2933        test_setup
2934            .stream_actor
2935            .define_recording(&test_setup.client, 0.into())
2936            .await?;
2937
2938        let borrow_id = 1;
2939        let tensor_ref = test_setup.next_ref();
2940        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2941
2942        test_setup
2943            .stream_actor
2944            .borrow_create(
2945                &test_setup.client,
2946                borrow_id,
2947                tensor_ref,
2948                first_use_sender.clone(),
2949            )
2950            .await?;
2951
2952        // Attempt to finalize the recording without dropping the borrow
2953        test_setup
2954            .stream_actor
2955            .finalize_recording(&test_setup.client, 0.into())
2956            .await?;
2957
2958        assert_actor_failed_with_msg(
2959            &test_setup.proc,
2960            test_setup.stream_actor.actor_id(),
2961            "all borrows created within recording must be dropped within recording".into(),
2962        )
2963        .await;
2964
2965        Ok(())
2966    }
2967
2968    #[async_timed_test(timeout_secs = 60)]
2969    async fn test_borrow_in_recording() -> Result<()> {
2970        let mut test_setup = TestSetup::new().await?;
2971
2972        let borrower_stream = test_setup
2973            .proc
2974            .spawn::<StreamActor>(
2975                "stream1",
2976                StreamParams {
2977                    world_size: 1,
2978                    rank: 0,
2979                    creation_mode: StreamCreationMode::CreateNewStream,
2980                    id: 1.into(),
2981                    device: Some(CudaDevice::new(0.into())),
2982                    controller_actor: test_setup.controller_actor.clone(),
2983                    respond_with_python_message: false,
2984                },
2985            )
2986            .await?;
2987
2988        let lender_stream = test_setup.stream_actor.clone();
2989
2990        let borrow_id = 1;
2991        let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
2992        let (last_use_sender, last_use_receiver) = test_setup.client.open_port();
2993
2994        // Stream 1: Define a recording that creates a borrow and drops it.
2995        lender_stream
2996            .define_recording(&test_setup.client, 0.into())
2997            .await?;
2998
2999        let formal_ref = test_setup.next_ref();
3000        lender_stream
3001            .recording_formal(&test_setup.client, formal_ref, 0)
3002            .await?;
3003
3004        lender_stream
3005            .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
3006            .await?;
3007
3008        lender_stream
3009            .borrow_drop(
3010                &test_setup.client,
3011                borrow_id,
3012                Arc::new(Mutex::new(last_use_receiver)),
3013            )
3014            .await?;
3015
3016        lender_stream
3017            .finalize_recording(&test_setup.client, 0.into())
3018            .await?;
3019
3020        let borrower_tensor_ref = test_setup.next_ref();
3021        let borrower_tensor = TensorCell::new(factory_float_tensor(
3022            &[1.0, 2.0, 3.0],
3023            "cuda".try_into().unwrap(),
3024        ));
3025
3026        borrower_stream
3027            .set_tensor_ref_unit_tests_only(
3028                &test_setup.client,
3029                borrower_tensor_ref,
3030                Ok(borrower_tensor.clone()),
3031            )
3032            .await?;
3033
3034        // Stream 2: Define a recording that uses the borrow from Stream 1.
3035        borrower_stream
3036            .define_recording(&test_setup.client, 0.into())
3037            .await?;
3038
3039        let borrowed_ref = test_setup.next_ref();
3040
3041        borrower_stream
3042            .borrow_first_use(
3043                &test_setup.client,
3044                borrow_id,
3045                borrowed_ref,
3046                Arc::new(Mutex::new(first_use_receiver)),
3047            )
3048            .await?;
3049
3050        let result_ref = test_setup.next_ref();
3051        borrower_stream
3052            .call_function(
3053                &test_setup.client,
3054                CallFunctionParams {
3055                    seq: 100.into(),
3056                    function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
3057                    args: vec![
3058                        WireValue::Ref(borrowed_ref),
3059                        WireValue::Ref(borrower_tensor_ref),
3060                    ],
3061                    kwargs: HashMap::new(),
3062                    results: vec![Some(result_ref)],
3063                    mutates: vec![],
3064                    stream: 1.into(),
3065                    remote_process_groups: Vec::new(),
3066                },
3067                HashMap::new(),
3068                HashMap::new(),
3069            )
3070            .await?;
3071
3072        borrower_stream
3073            .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
3074            .await?;
3075
3076        borrower_stream
3077            .recording_result(&test_setup.client, result_ref, 0)
3078            .await?;
3079
3080        borrower_stream
3081            .finalize_recording(&test_setup.client, 0.into())
3082            .await?;
3083
3084        // Set up a tensor in the lender stream and call the recording.
3085        let input_tensor_ref = test_setup.next_ref();
3086        test_setup
3087            .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
3088            .await?;
3089
3090        let result_tensor_ref = test_setup.next_ref();
3091
3092        let lender_future = lender_stream.call_recording(
3093            &test_setup.client,
3094            0.into(),
3095            0.into(),
3096            vec![],
3097            vec![input_tensor_ref],
3098        );
3099
3100        let borrower_future = borrower_stream.call_recording(
3101            &test_setup.client,
3102            0.into(),
3103            0.into(),
3104            vec![result_tensor_ref],
3105            vec![],
3106        );
3107
3108        tokio::try_join!(lender_future, borrower_future)?;
3109
3110        let result_tensor = borrower_stream
3111            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3112            .await?
3113            .unwrap()
3114            .unwrap();
3115
3116        let expected_tensor = TensorCell::new(factory_float_tensor(
3117            &[5.0, 7.0, 9.0],
3118            "cpu".try_into().unwrap(),
3119        ));
3120        assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());
3121
3122        // Set borrower_tensor to a tensor with only 2 elements to cause a failure.
3123        let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
3124            &[1.0, 2.0],
3125            "cuda".try_into().unwrap(),
3126        ));
3127        borrower_stream
3128            .set_tensor_ref_unit_tests_only(
3129                &test_setup.client,
3130                borrower_tensor_ref,
3131                Ok(invalid_borrower_tensor.clone()),
3132            )
3133            .await?;
3134
3135        // Call the recording again.
3136        let lender_future = lender_stream.call_recording(
3137            &test_setup.client,
3138            1.into(),
3139            0.into(),
3140            vec![],
3141            vec![input_tensor_ref],
3142        );
3143
3144        let borrower_future = borrower_stream.call_recording(
3145            &test_setup.client,
3146            1.into(),
3147            0.into(),
3148            vec![result_tensor_ref],
3149            vec![],
3150        );
3151
3152        tokio::try_join!(lender_future, borrower_future)?;
3153
3154        // Check that the borrower_stream reports the error to the controller.
3155        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
3156        match controller_msg {
3157            ControllerMessage::RemoteFunctionFailed { seq, error } => {
3158                assert_eq!(seq, 1.into());
3159                assert!(
3160                    error.backtrace.contains("recording failed"),
3161                    "Unexpected WorkerError: {:?}",
3162                    error
3163                );
3164                assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
3165            }
3166            _ => panic!("Unexpected controller message: {:?}", controller_msg),
3167        };
3168
3169        // Check that no error was reported from the lender stream
3170        check_fetch_result_value(
3171            &test_setup.client,
3172            lender_stream.clone(),
3173            2.into(),
3174            input_tensor_ref,
3175            &mut test_setup.controller_rx,
3176        )
3177        .await;
3178
3179        // Set the recording's input tensor to an error.
3180        let input_error = fake_seq_error(anyhow!("input error"));
3181        lender_stream
3182            .set_tensor_ref_unit_tests_only(
3183                &test_setup.client,
3184                input_tensor_ref,
3185                Err(input_error.clone()),
3186            )
3187            .await?;
3188
3189        let lender_future = lender_stream.call_recording(
3190            &test_setup.client,
3191            3.into(),
3192            0.into(),
3193            vec![],
3194            vec![input_tensor_ref],
3195        );
3196
3197        let borrower_future = borrower_stream.call_recording(
3198            &test_setup.client,
3199            3.into(),
3200            0.into(),
3201            vec![result_tensor_ref],
3202            vec![],
3203        );
3204
3205        tokio::try_join!(lender_future, borrower_future)?;
3206
3207        // Verify that borrower_stream sets a CallFunctionError::DependentError on result_tensor_ref.
3208        let result_error = borrower_stream
3209            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
3210            .await?
3211            .unwrap()
3212            .unwrap_err();
3213
3214        // Check that the error contains the expected strings
3215        let error_str = result_error.to_string();
3216        assert!(
3217            error_str.contains("input error"),
3218            "Error should contain input error: {}",
3219            error_str
3220        );
3221
3222        // Since we're checking for pointer equality in the original code, we need to ensure
3223        // the error is propagated correctly. We can check that the original error message is contained.
3224        let input_error_str = input_error.to_string();
3225        assert!(
3226            error_str.contains(&input_error_str),
3227            "Error should contain the original error: {}",
3228            error_str
3229        );
3230
3231        // Verify that neither stream sends a failure message to the controller.
3232        check_fetch_result_error(
3233            &test_setup.client,
3234            lender_stream,
3235            4.into(),
3236            input_tensor_ref,
3237            &mut test_setup.controller_rx,
3238            "input error",
3239        )
3240        .await;
3241
3242        // Verify that neither stream sends a failure message to the controller.
3243        check_fetch_result_error(
3244            &test_setup.client,
3245            borrower_stream,
3246            5.into(),
3247            result_tensor_ref,
3248            &mut test_setup.controller_rx,
3249            "input error",
3250        )
3251        .await;
3252
3253        Ok(())
3254    }
3255
3256    #[async_timed_test(timeout_secs = 60)]
        // Records three chained `reduce` calls into a recording and replays it four times:
        // once with valid inputs, then once with an error injected into each of the three
        // inputs in turn. After every replay the test checks which refs still hold values
        // and which hold an error that depends on the injected one.
3257    async fn test_reduce_in_recording() -> Result<()> {
3258        let mut test_setup = TestSetup::new().await?;
3259        let recording_ref = test_setup.next_ref();
3260
            // Single-rank NCCL communicator actor (world_size 1, rank 0) on CUDA device 0.
3261        let comm = Arc::new(
3262            test_setup
3263                .proc
3264                .spawn::<NcclCommActor>(
3265                    "comm",
3266                    CommParams::New {
3267                        device: CudaDevice::new(0.into()),
3268                        unique_id: UniqueId::new()?,
3269                        world_size: 1,
3270                        rank: 0,
3271                    },
3272                )
3273                .await?,
3274        );
3275
            // Describes a 3-element, strided, float tensor on "cuda"; passed to every `reduce`.
3276        let factory = Factory {
3277            size: vec![3],
3278            dtype: torch_sys::ScalarType::Float,
3279            layout: torch_sys::Layout::Strided,
3280            device: "cuda".try_into().unwrap(),
3281        };
3282
3283        let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3284
            // NOTE(review): presumably the stream messages sent between `define_recording` and
            // `finalize_recording` are captured into the recording instead of being executed
            // immediately -- confirm against the StreamActor implementation.
3285        test_setup
3286            .stream_actor
3287            .define_recording(&test_setup.client, recording_ref)
3288            .await?;
3289
3290        let formal_tensor_ref_0 = test_setup.next_ref();
3291        let formal_tensor_ref_1 = test_setup.next_ref();
3292        let formal_tensor_ref_2 = test_setup.next_ref();
3293
            // Register the three formal refs at indices 0, 1 and 2 (presumably the positions of
            // the matching actuals in the list later passed to `call_recording`).
3294        test_setup
3295            .stream_actor
3296            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3297            .await?;
3298        test_setup
3299            .stream_actor
3300            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3301            .await?;
3302        test_setup
3303            .stream_actor
3304            .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3305            .await?;
3306
3307        let intermediate_tensor_ref_0 = test_setup.next_ref();
3308
3309        // Handle case with in_place = true.
            // Input is formal 0; no `out` ref is given.
            // NOTE(review): the `1` and the first bool are positional arguments whose meaning is
            // not visible here (presumably a dimension size and a scatter flag) -- confirm
            // against the `reduce` signature.
3310        test_setup
3311            .stream_actor
3312            .reduce(
3313                &test_setup.client,
3314                comm.clone(),
3315                1,
3316                intermediate_tensor_ref_0,
3317                formal_tensor_ref_0,
3318                factory.clone(),
3319                reduction.clone(),
3320                false,
3321                true,
3322                None,
3323            )
3324            .await?;
3325
3326        // Handle case with in_place = false and out = None.
            // Input is formal 1; the result is bound to intermediate_tensor_ref_1.
3327        let intermediate_tensor_ref_1 = test_setup.next_ref();
3328        test_setup
3329            .stream_actor
3330            .reduce(
3331                &test_setup.client,
3332                comm.clone(),
3333                1,
3334                intermediate_tensor_ref_1,
3335                formal_tensor_ref_1,
3336                factory.clone(),
3337                reduction.clone(),
3338                false,
3339                false,
3340                None,
3341            )
3342            .await?;
3343
3344        let intermediate_tensor_ref_2 = test_setup.next_ref();
3345
3346        // Third reduce call with out = formal_tensor_ref_2
            // Its input is the result of the second reduce (intermediate_tensor_ref_1), so this
            // step chains on the previous one and targets formal 2 as `out`.
3347        test_setup
3348            .stream_actor
3349            .reduce(
3350                &test_setup.client,
3351                comm.clone(),
3352                1,
3353                intermediate_tensor_ref_2,
3354                intermediate_tensor_ref_1,
3355                factory.clone(),
3356                reduction.clone(),
3357                false,
3358                false,
3359                Some(formal_tensor_ref_2),
3360            )
3361            .await?;
3362
            // The result of the third reduce is registered as result 0 of the recording.
3363        test_setup
3364            .stream_actor
3365            .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3366            .await?;
3367
3368        test_setup
3369            .stream_actor
3370            .finalize_recording(&test_setup.client, recording_ref)
3371            .await?;
3372
            // Concrete input tensors used as actuals for formals 0, 1 and 2.
3373        let input_tensor_ref_0 = test_setup.next_ref();
3374        let input_tensor_ref_1 = test_setup.next_ref();
3375        let input_tensor_ref_2 = test_setup.next_ref();
3376
3377        test_setup
3378            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3379            .await?;
3380
3381        test_setup
3382            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3383            .await?;
3384
3385        test_setup
3386            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3387            .await?;
3388
3389        let output_ref = test_setup.next_ref();
3390
            // Replay 1 (seq 0): all inputs valid. The numeric argument passed via `.into()`
            // increases by one with every sequenced call in this test (0 through 6).
3391        test_setup
3392            .stream_actor
3393            .call_recording(
3394                &test_setup.client,
3395                0.into(),
3396                recording_ref,
3397                vec![output_ref],
3398                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3399            )
3400            .await?;
3401
3402        // Validate that input_tensor_ref_0 is unchanged.
3403        assert!(
3404            test_setup
3405                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3406                .await
3407        );
3408        // All the other inputs/outputs should be equal to input 1
            // (input 2 was the `out` target of the third reduce, whose input derives from input 1).
3409        for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3410            assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3411        }
3412
3413        // Set an error on input 0
3414        let input_error = fake_seq_error(anyhow!("input error"));
3415        test_setup
3416            .stream_actor
3417            .set_tensor_ref_unit_tests_only(
3418                &test_setup.client,
3419                input_tensor_ref_0,
3420                Err(input_error.clone()),
3421            )
3422            .await?;
3423
            // Replay 2 (seq 1): input 0 holds the injected error.
3424        test_setup
3425            .stream_actor
3426            .call_recording(
3427                &test_setup.client,
3428                1.into(),
3429                recording_ref,
3430                vec![output_ref],
3431                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3432            )
3433            .await?;
3434
3435        // Verify that input_tensor_ref_0, input_tensor_ref_2, and output_ref have a dependent error.
3436        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3437            test_setup
3438                .validate_dependent_error(ref_, input_error.clone())
3439                .await;
3440        }
3441
3442        // Verify that input_tensor_ref_1 is untouched.
3443        assert!(
3444            test_setup
3445                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3446                .await
3447        );
3448
3449        // Verify that no failure was reported to the controller.
            // NOTE(review): `check_fetch_result_value` is defined outside this view; presumably
            // it fetches the ref at the given seq and inspects `controller_rx` -- confirm.
3450        check_fetch_result_value(
3451            &test_setup.client,
3452            test_setup.stream_actor.clone(),
3453            2.into(),
3454            input_tensor_ref_1,
3455            &mut test_setup.controller_rx,
3456        )
3457        .await;
3458
3459        // Reset input tensors 0 and 2 to their original values
3460        test_setup
3461            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3462            .await?;
3463        test_setup
3464            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3465            .await?;
3466
3467        // Set an error on input tensor 1
3468        test_setup
3469            .stream_actor
3470            .set_tensor_ref_unit_tests_only(
3471                &test_setup.client,
3472                input_tensor_ref_1,
3473                Err(input_error.clone()),
3474            )
3475            .await?;
3476
            // Replay 3 (seq 3): input 1 holds the injected error.
3477        test_setup
3478            .stream_actor
3479            .call_recording(
3480                &test_setup.client,
3481                3.into(),
3482                recording_ref,
3483                vec![output_ref],
3484                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3485            )
3486            .await?;
3487
3488        // Validate that the mutated inputs and the output have a dependent error containing
3489        // the input error
3490        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3491            test_setup
3492                .validate_dependent_error(ref_, input_error.clone())
3493                .await;
3494        }
3495
3496        // Validate that no error was reported to the controller
            // Input 1 still holds the injected error, so this uses the `_error` variant of the
            // helper with the message text "input error" (helper defined outside this view;
            // presumably it expects the fetch to surface that error -- confirm).
3497        check_fetch_result_error(
3498            &test_setup.client,
3499            test_setup.stream_actor.clone(),
3500            4.into(),
3501            input_tensor_ref_1,
3502            &mut test_setup.controller_rx,
3503            "input error",
3504        )
3505        .await;
3506
3507        // Reset input tensors 0 and 1 to their original values
3508        test_setup
3509            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3510            .await?;
3511        test_setup
3512            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3513            .await?;
3514
3515        // Set an error on input tensor 2
3516        test_setup
3517            .stream_actor
3518            .set_tensor_ref_unit_tests_only(
3519                &test_setup.client,
3520                input_tensor_ref_2,
3521                Err(input_error.clone()),
3522            )
3523            .await?;
3524
            // Replay 4 (seq 5): input 2 holds the injected error.
3525        test_setup
3526            .stream_actor
3527            .call_recording(
3528                &test_setup.client,
3529                5.into(),
3530                recording_ref,
3531                vec![output_ref],
3532                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3533            )
3534            .await?;
3535
3536        // Validate that input tensor 1 has its original values
3537        assert!(
3538            test_setup
3539                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3540                .await
3541        );
3542
3543        // Validate that the mutated inputs and the output have a dependent error containing
3544        // the input error
3545        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3546            test_setup
3547                .validate_dependent_error(ref_, input_error.clone())
3548                .await;
3549        }
3550
3551        // Validate that no error was reported to the controller
            // Input 1 was reset to a valid tensor above, so the `_value` variant is used here.
3552        check_fetch_result_value(
3553            &test_setup.client,
3554            test_setup.stream_actor.clone(),
3555            6.into(),
3556            input_tensor_ref_1,
3557            &mut test_setup.controller_rx,
3558        )
3559        .await;
3560
3561        Ok(())
3562    }
3563
3564    #[async_timed_test(timeout_secs = 60)]
        // Records tensor transfers between two stream actors (ranks 0 and 1) plus a same-rank
        // transfer on the sending stream, then replays the recording on both streams three
        // times: with valid inputs, with an error injected into input 0, and with an error
        // injected into input 1. After each replay it checks the values / errors held by the
        // result refs on each stream.
3565    async fn test_send_tensor_in_recording() -> Result<()> {
3566        let mut test_setup = TestSetup::new_with_world_size(2).await?;
3567        let recording_ref = test_setup.next_ref();
3568
            // Two NCCL communicator actors sharing one `unique_id`: rank 0 on CUDA device 0 and
            // rank 1 on CUDA device 1. The spawn futures are not awaited individually; both are
            // driven together by `try_join!` below (presumably because communicator setup needs
            // every rank to participate -- confirm).
3569        let unique_id = UniqueId::new()?;
3570        let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3571            "comm0",
3572            CommParams::New {
3573                device: CudaDevice::new(0.into()),
3574                unique_id: unique_id.clone(),
3575                world_size: 2,
3576                rank: 0,
3577            },
3578        );
3579        let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3580            "comm1",
3581            CommParams::New {
3582                device: CudaDevice::new(1.into()),
3583                unique_id,
3584                world_size: 2,
3585                rank: 1,
3586            },
3587        );
3588        let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3589        let comm0 = Arc::new(comm0);
3590        let comm1 = Arc::new(comm1);
3591
            // Describes a 3-element, strided, float tensor on "cuda"; passed to every `send_tensor`.
3592        let factory = Factory {
3593            size: vec![3],
3594            dtype: torch_sys::ScalarType::Float,
3595            layout: torch_sys::Layout::Strided,
3596            device: "cuda".try_into().unwrap(),
3597        };
3598
            // The stream actor from the test setup acts as the sender; a second stream actor is
            // spawned as rank 1 on CUDA device 1 to act as the receiver.
3599        let send_stream = test_setup.stream_actor.clone();
3600        let recv_stream = test_setup
3601            .proc
3602            .spawn::<StreamActor>(
3603                "recv_stream",
3604                StreamParams {
3605                    world_size: 2,
3606                    rank: 1,
3607                    creation_mode: StreamCreationMode::CreateNewStream,
3608                    id: 1.into(),
3609                    device: Some(CudaDevice::new(1.into())),
3610                    controller_actor: test_setup.controller_actor.clone(),
3611                    respond_with_python_message: false,
3612                },
3613            )
3614            .await?;
3615
            // The same `recording_ref` is defined on both streams; each stream records its own
            // side of the transfer.
3616        send_stream
3617            .define_recording(&test_setup.client, recording_ref)
3618            .await?;
3619        recv_stream
3620            .define_recording(&test_setup.client, recording_ref)
3621            .await?;
3622
3623        let formal_tensor_ref_0 = test_setup.next_ref();
3624        let formal_tensor_ref_1 = test_setup.next_ref();
3625
            // Only the sending stream declares formals; the receiver is replayed with no actuals.
3626        send_stream
3627            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3628            .await?;
3629        send_stream
3630            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3631            .await?;
3632
            // Sender side of the cross-rank transfer: formal 0 with (None, Some(1)).
            // NOTE(review): the two Option arguments are presumably (from_rank, to_rank) --
            // confirm against the `send_tensor` signature. `_ref` is a freshly allocated ref
            // passed in the result position and not read afterwards in this test.
3633        let _ref = test_setup.next_ref();
3634        send_stream
3635            .send_tensor(
3636                &test_setup.client,
3637                _ref,
3638                None,
3639                Some(1),
3640                formal_tensor_ref_0,
3641                factory.clone(),
3642                comm0.clone(),
3643            )
3644            .await?;
3645
            // Receiver side: (Some(0), None) with `result_ref_0` in the result position. The
            // tensor argument is another freshly allocated placeholder ref.
3646        let result_ref_0 = test_setup.next_ref();
3647        let _ref = test_setup.next_ref();
3648        recv_stream
3649            .send_tensor(
3650                &test_setup.client,
3651                result_ref_0,
3652                Some(0),
3653                None,
3654                _ref,
3655                factory.clone(),
3656                comm1,
3657            )
3658            .await?;
3659
            // Same-rank case on the sender: (Some(0), Some(0)) with formal 1 as the tensor and
            // `result_ref_1` as the result.
3660        let result_ref_1 = test_setup.next_ref();
3661        send_stream
3662            .send_tensor(
3663                &test_setup.client,
3664                result_ref_1,
3665                Some(0),
3666                Some(0),
3667                formal_tensor_ref_1,
3668                factory.clone(),
3669                comm0,
3670            )
3671            .await?;
3672
            // Each stream registers its own result at index 0 of its recording.
3673        send_stream
3674            .recording_result(&test_setup.client, result_ref_1, 0)
3675            .await?;
3676        recv_stream
3677            .recording_result(&test_setup.client, result_ref_0, 0)
3678            .await?;
3679
3680        send_stream
3681            .finalize_recording(&test_setup.client, recording_ref)
3682            .await?;
3683        recv_stream
3684            .finalize_recording(&test_setup.client, recording_ref)
3685            .await?;
3686
            // Concrete input tensors used as actuals for the sender's formals 0 and 1.
3687        let input_tensor_ref_0 = test_setup.next_ref();
3688        let input_tensor_ref_1 = test_setup.next_ref();
3689        test_setup
3690            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3691            .await?;
3692        test_setup
3693            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3694            .await?;
3695
            // Replay 1 (seq 0 on both streams): all inputs valid. Both `call_recording` futures
            // are created first and then awaited together with `try_join!`.
3696        let actual_result_ref_0 = test_setup.next_ref();
3697        let actual_result_ref_1 = test_setup.next_ref();
3698        let send_fut = send_stream.call_recording(
3699            &test_setup.client,
3700            0.into(),
3701            recording_ref,
3702            vec![actual_result_ref_1],
3703            vec![input_tensor_ref_0, input_tensor_ref_1],
3704        );
3705        let recv_fut = recv_stream.call_recording(
3706            &test_setup.client,
3707            0.into(),
3708            recording_ref,
3709            vec![actual_result_ref_0],
3710            vec![],
3711        );
3712        tokio::try_join!(send_fut, recv_fut)?;
3713
            // Both inputs keep their values, and the sender's result equals input 1.
3714        assert!(
3715            test_setup
3716                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3717                .await
3718        );
3719        assert!(
3720            test_setup
3721                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3722                .await
3723        );
3724        assert!(
3725            test_setup
3726                .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3727                .await
3728        );
3729
            // The receiver's result is read directly from `recv_stream` and must equal input 0.
            // It is compared against a CPU tensor built by `factory_float_tensor`.
            // NOTE(review): the three `unwrap`s presumably peel the call result, the presence of
            // the ref, and the stored tensor result -- confirm against the handler's return type.
3730        let actual_result_0 = recv_stream
3731            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3732            .await
3733            .unwrap()
3734            .unwrap()
3735            .unwrap();
3736        assert!(allclose(
3737            &actual_result_0.borrow(),
3738            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3739        )?);
3740
3741        // Validate that failure wasn't reported to controller.
            // NOTE(review): `check_fetch_result_value` is defined outside this view; presumably
            // it fetches the ref at the given seq and inspects `controller_rx` -- confirm.
3742        check_fetch_result_value(
3743            &test_setup.client,
3744            send_stream.clone(),
3745            1.into(),
3746            actual_result_ref_1,
3747            &mut test_setup.controller_rx,
3748        )
3749        .await;
3750        check_fetch_result_value(
3751            &test_setup.client,
3752            recv_stream.clone(),
3753            2.into(),
3754            actual_result_ref_0,
3755            &mut test_setup.controller_rx,
3756        )
3757        .await;
3758
            // Inject an error into input 0, the tensor that is sent across ranks.
3759        let input_error = fake_seq_error(anyhow!("input error"));
3760        send_stream
3761            .set_tensor_ref_unit_tests_only(
3762                &test_setup.client,
3763                input_tensor_ref_0,
3764                Err(input_error.clone()),
3765            )
3766            .await?;
3767
            // Replay 2 (seq 3 on both streams): input 0 holds the injected error.
3768        let send_fut = send_stream.call_recording(
3769            &test_setup.client,
3770            3.into(),
3771            recording_ref,
3772            vec![actual_result_ref_1],
3773            vec![input_tensor_ref_0, input_tensor_ref_1],
3774        );
3775        let recv_fut = recv_stream.call_recording(
3776            &test_setup.client,
3777            3.into(),
3778            recording_ref,
3779            vec![actual_result_ref_0],
3780            vec![],
3781        );
3782        tokio::try_join!(send_fut, recv_fut)?;
3783
3784        // The result on recv_stream should have a value, but it will be garbage.
            // Only the presence of a value is checked (via the unwraps); its contents are ignored.
3785        let _ = recv_stream
3786            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3787            .await
3788            .unwrap()
3789            .unwrap()
3790            .unwrap();
3791
            // The sender's result must carry an error that depends on the injected input error.
3792        test_setup
3793            .validate_dependent_error(actual_result_ref_1, input_error.clone())
3794            .await;
3795
3796        // Input 1 should be untouched.
3797        assert!(
3798            test_setup
3799                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3800                .await
3801        );
3802
3803        // Validate that failure wasn't reported to controller.
            // The sender's result holds an error, so the `_error` helper variant is used with the
            // message text "input error"; the receiver's result still holds a value.
3804        check_fetch_result_error(
3805            &test_setup.client,
3806            send_stream.clone(),
3807            4.into(),
3808            actual_result_ref_1,
3809            &mut test_setup.controller_rx,
3810            "input error",
3811        )
3812        .await;
3813        check_fetch_result_value(
3814            &test_setup.client,
3815            recv_stream.clone(),
3816            5.into(),
3817            actual_result_ref_0,
3818            &mut test_setup.controller_rx,
3819        )
3820        .await;
3821
            // Restore input 0 to its original value and move the injected error to input 1.
3822        test_setup
3823            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3824            .await?;
3825        send_stream
3826            .set_tensor_ref_unit_tests_only(
3827                &test_setup.client,
3828                input_tensor_ref_1,
3829                Err(input_error.clone()),
3830            )
3831            .await?;
3832
            // Replay 3 (seq 6 on both streams): input 1 holds the injected error.
3833        let send_fut = send_stream.call_recording(
3834            &test_setup.client,
3835            6.into(),
3836            recording_ref,
3837            vec![actual_result_ref_1],
3838            vec![input_tensor_ref_0, input_tensor_ref_1],
3839        );
3840        let recv_fut = recv_stream.call_recording(
3841            &test_setup.client,
3842            6.into(),
3843            recording_ref,
3844            vec![actual_result_ref_0],
3845            vec![],
3846        );
3847        tokio::try_join!(send_fut, recv_fut)?;
3848
            // The cross-rank transfer of input 0 is expected to succeed again: the receiver's
            // result must equal input 0's values.
3849        let actual_result_0 = recv_stream
3850            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3851            .await
3852            .unwrap()
3853            .unwrap()
3854            .unwrap();
3855        assert!(allclose(
3856            &actual_result_0.borrow(),
3857            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3858        )?);
3859
            // Input 0 keeps its value on the sender.
3860        assert!(
3861            test_setup
3862                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3863                .await
3864        );
3865
            // The sender's result (derived from input 1) must carry the dependent error.
3866        test_setup
3867            .validate_dependent_error(actual_result_ref_1, input_error)
3868            .await;
3869
3870        // Validate that failure wasn't reported to controller.
3871        check_fetch_result_error(
3872            &test_setup.client,
3873            send_stream.clone(),
3874            7.into(),
3875            actual_result_ref_1,
3876            &mut test_setup.controller_rx,
3877            "input error",
3878        )
3879        .await;
3880        check_fetch_result_value(
3881            &test_setup.client,
3882            recv_stream.clone(),
3883            8.into(),
3884            actual_result_ref_0,
3885            &mut test_setup.controller_rx,
3886        )
3887        .await;
3888
3889        Ok(())
3890    }
3891}