monarch_tensor_worker/
stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::Mailbox;
36use hyperactor::mailbox::OncePortHandle;
37use hyperactor::mailbox::PortReceiver;
38use hyperactor::proc::Proc;
39use monarch_hyperactor::actor::PythonMessage;
40use monarch_hyperactor::actor::PythonMessageKind;
41use monarch_hyperactor::local_state_broker::BrokerId;
42use monarch_hyperactor::local_state_broker::LocalState;
43use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::CallFunctionError;
50use monarch_messages::worker::CallFunctionParams;
51use monarch_messages::worker::SeqError;
52use monarch_messages::worker::StreamRef;
53use monarch_types::PyTree;
54use monarch_types::SerializablePyErr;
55use monarch_types::TryIntoPyObjectUnsafe;
56use pyo3::prelude::*;
57use pyo3::types::PyTuple;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys::BorrowType;
62use torch_sys::CudaDevice;
63use torch_sys::MultiBorrow;
64use torch_sys::RValue;
65use torch_sys::TensorCell;
66use torch_sys::deep_clone;
67use torch_sys::factory_empty;
68use torch_sys::factory_zeros;
69use torch_sys_cuda::cuda::Event;
70use torch_sys_cuda::cuda::Stream;
71use tracing_subscriber::fmt::Subscriber;
72
73use crate::ControllerActor;
74use crate::DeviceMesh;
75use crate::Factory;
76use crate::Reduction;
77use crate::Ref;
78use crate::ResolvableFunction;
79use crate::StreamCreationMode;
80use crate::WireValue;
81use crate::comm::CommBackend;
82use crate::comm::CommMessage;
83use crate::comm::CommMessageClient;
84use crate::comm::NcclCommActor;
85use crate::pipe::PipeMessage;
86
87pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
88
89// These thread locals are accessed by the python runtime for debugging sessions.
90thread_local! {
91    pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
92    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
93    pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
94}
95
96fn pickle_python_result(
97    py: Python<'_>,
98    result: Bound<'_, PyAny>,
99    worker_actor_id: ActorId,
100) -> Result<PythonMessage, anyhow::Error> {
101    let pickle = py
102        .import("monarch._src.actor.actor_mesh")
103        .unwrap()
104        .getattr("_pickle")
105        .unwrap();
106    let data: Vec<u8> = pickle
107        .call1((result,))
108        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
109        .extract()
110        .unwrap();
111    Ok(PythonMessage::new_from_buf(
112        PythonMessageKind::Result {
113            rank: Some(worker_actor_id.rank()),
114        },
115        data,
116    ))
117}
118
119#[derive(Debug)]
120struct Recording {
121    messages: Vec<StreamMessage>,
122}
123
124impl Recording {
125    fn new() -> Self {
126        Self {
127            messages: Vec::new(),
128        }
129    }
130}
131
132#[derive(Debug, PartialEq)]
133enum RecordingState {
134    Defining {
135        recording: Ref,
136        // Set of borrow ids used to track proper borrow usage inside
137        // a recording.
138        defined_borrows: HashSet<u64>,
139    },
140    Running,
141}
142
/// Messages handled by the stream. Generally these are stream-local versions of
/// [`crate::WorkerMessage`].
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Invoke a function. Carries the call parameters, the device meshes the
    /// call may reference (keyed by ref), and the remote process groups it may
    /// reference (keyed by ref, each given as mesh, dimension names, and comm
    /// actor).
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Provider side of a borrow: make `tensor` available to the borrower
    /// through `first_use_sender`.
    BorrowCreate {
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Port for sending the first use CUDA event + borrowed tensor to
        /// the borrower.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Borrower side of a borrow: receive the borrowed tensor and bind it to
    /// `result`.
    BorrowFirstUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for storing the borrowed tensor.
        result: Ref,
        /// Port for receiving the first use CUDA event + borrowed tensor from
        /// the provider stream.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Borrower side of a borrow: hand the borrowed tensor back through
    /// `last_use_sender` after its last use.
    BorrowLastUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for the borrowed tensor.
        result: Ref,
        /// Port for sending the last use CUDA event and borrowed tensor.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Provider side of a borrow: receive the borrowed tensor back and end the
    /// borrow.
    BorrowDrop {
        /// Id for the borrow.
        borrow: u64,
        /// Port for receiving the last use CUDA event and borrowed tensor.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Delete the given refs (presumably from the stream's environment —
    /// TODO confirm against the handler).
    DeleteRefs(Vec<Ref>),

    /// Reply on the given port once this message is handled.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Provide the communicator actor for this stream (its `comm` is lazily
    /// initialized).
    InitComm(ActorHandle<NcclCommActor>),

    /// Reduction of `local_tensor` over `comm`, defining `result`. When
    /// `in_place` is set it mutates `local_tensor`; otherwise it mutates `out`
    /// if one is given.
    Reduce {
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    },

    /// Point-to-point tensor transfer over `comm`. `result` is only defined
    /// when `from_rank` is set.
    SendTensor {
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// Compute a value from `function` (if any) and its arguments and send it
    /// back for `seq`; `pipe`, when present, is presumably the destination —
    /// TODO confirm against the handler.
    SendValue {
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        device_meshes: HashMap<Ref, DeviceMesh>,
        pipe: Option<PortHandle<PipeMessage>>,
    },

    /// Define the non-`None` entries of `results` (presumably from a value
    /// received over `pipe` — TODO confirm against the handler).
    SetValue {
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: PortHandle<PipeMessage>,
    },

    /// Start defining the recording identified by `recording`.
    DefineRecording {
        recording: Ref,
    },

    /// Finish defining the recording identified by `recording`.
    FinalizeRecording {
        recording: Ref,
    },

    /// Run `recording`, presumably binding `actuals` to its formals and its
    /// outputs to `results` — TODO confirm against the handler.
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Inside a recording: defines `result` as the formal argument at
    /// `argument_index`.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Inside a recording: marks `result` as the output at `output_index`.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Test-only: store a wire value under a ref.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Test-only: store a tensor result under a ref.
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Test-only: read back the value stored under a ref, if any.
    GetRefUnitTestsOnly(
        Ref, // value
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Test-only: read back the tensor result stored under a ref, if any.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// Send the result of an actor call described by `ActorCallParams`; the
    /// `ActorId` presumably identifies the worker — TODO confirm.
    SendResultOfActorCall(ActorId, ActorCallParams),
    /// Invoke an actor method described by `ActorMethodParams`.
    CallActorMethod(ActorMethodParams),
}
271
272impl StreamMessage {
273    fn clone_for_recording(&self) -> Self {
274        match self {
275            StreamMessage::RecordingFormal {
276                result,
277                argument_index,
278            } => StreamMessage::RecordingFormal {
279                result: *result,
280                argument_index: *argument_index,
281            },
282            StreamMessage::RecordingResult {
283                result,
284                output_index,
285            } => StreamMessage::RecordingResult {
286                result: *result,
287                output_index: *output_index,
288            },
289            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
290            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
291                StreamMessage::CallFunction(
292                    params.clone(),
293                    device_meshes.clone(),
294                    remote_process_groups.clone(),
295                )
296            }
297            StreamMessage::BorrowCreate {
298                borrow,
299                tensor,
300                first_use_sender,
301            } => StreamMessage::BorrowCreate {
302                borrow: *borrow,
303                tensor: *tensor,
304                first_use_sender: first_use_sender.clone(),
305            },
306            StreamMessage::BorrowFirstUse {
307                borrow,
308                result,
309                first_use_receiver,
310            } => StreamMessage::BorrowFirstUse {
311                borrow: *borrow,
312                result: *result,
313                first_use_receiver: first_use_receiver.clone(),
314            },
315            StreamMessage::BorrowLastUse {
316                borrow,
317                result,
318                last_use_sender,
319            } => StreamMessage::BorrowLastUse {
320                borrow: *borrow,
321                result: *result,
322                last_use_sender: last_use_sender.clone(),
323            },
324            StreamMessage::BorrowDrop {
325                borrow,
326                last_use_receiver,
327            } => StreamMessage::BorrowDrop {
328                borrow: *borrow,
329                last_use_receiver: last_use_receiver.clone(),
330            },
331            StreamMessage::Reduce {
332                comm,
333                dim_size,
334                result,
335                local_tensor,
336                factory,
337                reduction,
338                scatter,
339                in_place,
340                out,
341            } => StreamMessage::Reduce {
342                comm: comm.clone(),
343                dim_size: *dim_size,
344                result: *result,
345                local_tensor: *local_tensor,
346                factory: factory.clone(),
347                reduction: reduction.clone(),
348                scatter: *scatter,
349                in_place: *in_place,
350                out: out.clone(),
351            },
352            StreamMessage::SendTensor {
353                result,
354                from_rank,
355                to_rank,
356                tensor,
357                factory,
358                comm,
359            } => StreamMessage::SendTensor {
360                result: *result,
361                from_rank: *from_rank,
362                to_rank: *to_rank,
363                tensor: *tensor,
364                factory: factory.clone(),
365                comm: comm.clone(),
366            },
367            StreamMessage::SetValue { seq, results, pipe } => StreamMessage::SetValue {
368                seq: seq.clone(),
369                results: results.clone(),
370                pipe: pipe.clone(),
371            },
372            other => panic!(
373                "StreamMessage variant not supported in recording: {:?}",
374                other
375            ),
376        }
377    }
378
379    // Get the set of refs that this message defines.
380    fn get_defined_refs(&self) -> HashSet<Ref> {
381        match self {
382            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
383            StreamMessage::CallFunction(params, ..) => {
384                params.results.iter().filter_map(|&ref_| ref_).collect()
385            }
386            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
387            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
388            StreamMessage::SendTensor {
389                result, from_rank, ..
390            } => {
391                if from_rank.is_some() {
392                    HashSet::from([*result])
393                } else {
394                    HashSet::new()
395                }
396            }
397            StreamMessage::SetValue { results, .. } => {
398                results.iter().filter_map(|&ref_| ref_).collect()
399            }
400            // TODO(slurye): Add SendValue eventually.
401            _ => HashSet::new(),
402        }
403    }
404
405    // Get the set of refs that this message mutates.
406    fn get_mutated_refs(&self) -> HashSet<Ref> {
407        match self {
408            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
409            StreamMessage::Reduce {
410                out,
411                in_place,
412                local_tensor,
413                ..
414            } => {
415                if *in_place {
416                    HashSet::from([*local_tensor])
417                } else if let Some(out) = out {
418                    HashSet::from([*out])
419                } else {
420                    HashSet::new()
421                }
422            }
423            // TODO(slurye): Add SendValue eventually.
424            _ => HashSet::new(),
425        }
426    }
427}
428
/// A stream represents a linear sequence of execution. Operations on different
/// streams can execute concurrently.
///
/// For CUDA operators, streams will invoke the corresponding stream management
/// APIs to perform synchronization.
///
/// For CPU operators, streams will just execute synchronously on their own OS
/// thread.
#[derive(Debug)]
pub struct StreamActor {
    /// World size used when initializing the torch process group and when
    /// building `CommBackend`s for remote process groups.
    world_size: usize,
    /// This worker's rank, used alongside `world_size` for the same purposes.
    rank: usize,
    /// Mapping of refs in the controller environment to the values they hold
    /// in this stream's local environment. An `Err` entry records the
    /// `SeqError` that prevented the value from being computed; reading such a
    /// ref yields a `CallFunctionError::DependentError`.
    // TODO(agallagher): Use `ValueError` as the error type.
    env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
    /// How to create the stream.
    creation_mode: StreamCreationMode,
    /// CUDA stream that this actor will enqueue operations on. None if "device"
    /// is not a CUDA device.
    /// NOTE: We lazily create the stream, so that we do it from the dedicated
    /// Stream OS thread as, otherwise, we see deadlocks when done from
    /// unexpected threads.
    cuda_stream: OnceLock<Option<Stream>>,
    /// Device this stream should be scheduled on.
    device: Option<CudaDevice>,
    /// Communicator for this stream. Optional as we lazily initialize it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Actor ref of the controller that created this stream.
    controller_actor: ActorRef<ControllerActor>,
    /// Cache of torch process-group objects built for remote process groups,
    /// keyed by the ref the controller uses for each group.
    remote_process_groups: HashMap<Ref, PyObject>,
    /// Recordings known to this stream, keyed by recording ref.
    recordings: HashMap<Ref, Recording>,
    /// The recording currently being defined or run, if any. While set,
    /// `report_seq_error` does not forward errors to the controller.
    active_recording: Option<RecordingState>,
    // NOTE(review): only stored from `StreamParams` in the visible code;
    // presumably selects `PythonMessage` responses — confirm with its readers.
    respond_with_python_message: bool,
    /// Most recent error created by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
465
/// Parameters for creating a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// World size used for torch process-group initialization.
    pub world_size: usize,
    /// This worker's rank within `world_size`.
    pub rank: usize,
    /// Controls how the underlying CUDA stream is created.
    pub creation_mode: StreamCreationMode,
    /// Id of this stream in the worker actor's stream table.
    /// (Not retained by `StreamActor::new`.)
    pub id: StreamRef,
    /// Device this stream should be scheduled on. If none, don't do stream
    /// synchronization.
    pub device: Option<CudaDevice>,
    /// Actor ref of the controller that created this stream.
    pub controller_actor: ActorRef<ControllerActor>,
    // NOTE(review): presumably selects `PythonMessage`-based responses —
    // confirm with the code that reads it.
    pub respond_with_python_message: bool,
}
482
483#[async_trait]
484impl Actor for StreamActor {
485    type Params = StreamParams;
486    async fn new(
487        StreamParams {
488            world_size,
489            rank,
490            id: _,
491            device,
492            controller_actor,
493            creation_mode,
494            respond_with_python_message,
495        }: Self::Params,
496    ) -> Result<Self> {
497        Ok(Self {
498            world_size,
499            rank,
500            env: HashMap::new(),
501            creation_mode,
502            cuda_stream: OnceLock::new(),
503            device,
504            comm: None,
505            controller_actor,
506            remote_process_groups: HashMap::new(),
507            recordings: HashMap::new(),
508            active_recording: None,
509            respond_with_python_message,
510            last_seq_error: None,
511        })
512    }
513
514    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
515        // These thread locals are exposed via python functions, so we need to set them in the
516        // same thread that python will run in. That means we need to initialize them here in
517        // StreamActor::init instead of in StreamActor::new.
518        CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
519            controller_actor_ref.set(self.controller_actor.clone()).ok()
520        });
521        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
522        ROOT_ACTOR_ID.with(|root_actor_id| {
523            root_actor_id
524                .set(ActorId::root(
525                    cx.self_id().proc_id().clone(),
526                    cx.self_id().name().to_string(),
527                ))
528                .ok()
529        });
530        // Set the current stream for this actor thread.
531        if let Some(stream) = self.cuda_stream() {
532            Stream::set_current_stream(stream);
533        }
534        Ok(())
535    }
536
537    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
538    /// dedicated OS thread. This is because:
539    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
540    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
541    ///     threads is problematic.
542    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
543    where
544        F: Future + Send + 'static,
545        F::Output: Send + 'static,
546    {
547        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
548        // It is important that we spawn a standalone thread for the work here,
549        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
550        // This is because the worker stream may call uninterruptible FFI code
551        // that can deadlock (CUDA, NCCL).
552        // If we use a tokio-managed blocking thread, then runtime teardown will
553        // try to wait for tasks on that thread to reach an await point, and
554        // hang forever.
555        let builder = std::thread::Builder::new().name("worker-stream".to_string());
556        let _thread_handle = builder.spawn(move || {
557            // Spawn a new thread with a single-threaded tokio runtime to run the
558            // actor loop.  We avoid the current-threaded runtime, so that we can
559            // use `block_in_place` for nested async-to-sync-to-async flows.
560            let rt = tokio::runtime::Builder::new_multi_thread()
561                .worker_threads(1)
562                .enable_io()
563                .build()
564                .unwrap();
565            let result = rt.block_on(async {
566                tokio::task::block_in_place(|| {
567                    // Allow e.g. destructing py objects on this thread, which
568                    // can happen at shutdown when the a stream actors env map
569                    // for rvalues is dropped (e.g. P1673311499).
570                    // https://github.com/PyO3/pyo3/discussions/3499
571                    Python::with_gil(|py| {
572                        py.allow_threads(|| {
573                            let result = Handle::current().block_on(future);
574                            if join_tx.send(result).is_err() {
575                                panic!("could not send join result")
576                            }
577                        })
578                    })
579                })
580            });
581            rt.shutdown_timeout(Duration::from_weeks(1));
582            result
583        });
584
585        // In order to bridge the synchronous join handle with the async world,
586        // smuggle the result through a channel.
587        tokio::spawn(async move { join_rx.await.unwrap() })
588    }
589}
590
591/// The arguments we accept as inputs to Python function calls.
592#[derive(Debug)]
593enum PyArg<'a> {
594    RValue(RValue),
595    DeviceMesh(&'a DeviceMesh),
596    PyObject(PyObject),
597}
598
599/// Serialize into a `PyObject`.
600impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
601    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
602        match self {
603            // SAFETY: This inherits the unsafety of `rvalue_to_ivalue` (see comment
604            // above).
605            PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
606            PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
607            PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
608        }
609    }
610}
611
612impl StreamActor {
613    fn cuda_stream(&self) -> Option<&Stream> {
614        self.cuda_stream
615            .get_or_init(|| {
616                self.device.map(|device| match self.creation_mode {
617                    StreamCreationMode::UseDefaultStream => {
618                        Stream::get_current_stream_on_device(device)
619                    }
620                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
621                })
622            })
623            .as_ref()
624    }
625
626    fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
627        let rvalue = self
628            .env
629            .get(ref_)
630            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
631        match rvalue {
632            Ok(val) => Ok(val.clone()),
633            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
634        }
635    }
636
637    fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
638        let ret = match value {
639            WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
640            // TODO: We might want to support GenericList / GenericDict etc.
641            WireValue::RefList(val) => {
642                let mut ret = Vec::with_capacity(val.len());
643                for v in val {
644                    match self.ref_to_rvalue(&v) {
645                        Ok(RValue::Tensor(t)) => ret.push(t),
646                        Err(err) => {
647                            return Err(err);
648                        }
649                        Ok(val) => {
650                            return Err(CallFunctionError::UnsupportedArgType(
651                                "wire_to_rvalue".into(),
652                                format!("RefList([{:?}])", val),
653                            ));
654                        }
655                    }
656                }
657                RValue::TensorList(ret)
658            }
659            WireValue::Int(val) => RValue::Int(val),
660            WireValue::IntList(val) => RValue::IntList(val),
661            WireValue::Double(val) => RValue::Double(val),
662            WireValue::Bool(val) => RValue::Bool(val),
663            WireValue::String(val) => RValue::String(val),
664            WireValue::Device(val) => RValue::Device(val),
665            WireValue::Layout(val) => RValue::Layout(val),
666            WireValue::ScalarType(val) => RValue::ScalarType(val),
667            WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
668            WireValue::PyObject(val) => RValue::PyObject(val),
669            WireValue::None(()) => RValue::None,
670            WireValue::IValue(val) => RValue::Opaque(val.into()),
671        };
672        Ok(ret)
673    }
674
675    async fn report_seq_error(
676        &mut self,
677        cx: &Context<'_, Self>,
678        seq: Seq,
679        error: CallFunctionError,
680    ) -> Result<Arc<SeqError>, anyhow::Error> {
681        match error {
682            CallFunctionError::DependentError(root) => Ok(root),
683            CallFunctionError::Error(e) => {
684                if self.active_recording.is_none() {
685                    let worker_error = WorkerError {
686                        backtrace: format!("{e}"),
687                        worker_actor_id: cx.self_id().clone(),
688                    };
689                    tracing::info!("Propagating remote function error to client: {worker_error}");
690                    self.controller_actor
691                        .remote_function_failed(cx, seq, worker_error)
692                        .await?
693                }
694                let err = Arc::new(SeqError { seq, error: e });
695                self.last_seq_error = Some(err.clone());
696                Ok(err)
697            }
698        }
699    }
700
701    async fn try_define<F>(
702        &mut self,
703        cx: &Context<'_, Self>,
704        seq: Seq,
705        result_refs: Vec<Option<Ref>>,
706        mutates: &Vec<Ref>,
707        f: F,
708    ) -> Result<()>
709    where
710        F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
711    {
712        let actual_results = f(self).await;
713        // Check if the expected number of returns is correct, otherwise convert
714        // into an error.
715        let op_results = actual_results.and_then(|actual_results| {
716            if result_refs.len() == actual_results.len() {
717                Ok(actual_results
718                    .into_iter()
719                    .zip(result_refs.iter())
720                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
721                    .collect::<Vec<(Ref, RValue)>>())
722            } else {
723                Err(CallFunctionError::UnexpectedNumberOfReturns(
724                    result_refs.len(),
725                    actual_results.len(),
726                ))
727            }
728        });
729
730        // Propagate the results (either the actual values or an error) to the
731        // right entries in the global env mapping.
732        match op_results {
733            Ok(op_results) => {
734                for (ref_, rvalue) in op_results.into_iter() {
735                    let prev = self.env.insert(ref_, Ok(rvalue));
736                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
737                }
738            }
739            Err(err) => {
740                let err = self.report_seq_error(cx, seq, err).await?;
741                for ref_ in result_refs {
742                    match ref_ {
743                        Some(ref_) => {
744                            let prev = self.env.insert(ref_, Err(err.clone()));
745                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
746                        }
747                        None => {}
748                    }
749                }
750                for ref_ in mutates {
751                    self.env.insert(*ref_, Err(err.clone()));
752                }
753            }
754        }
755        Ok(())
756    }
757
758    fn call_torch_op(
759        &self,
760        op: String,
761        overload: String,
762        args: Vec<WireValue>,
763        kwargs: HashMap<String, WireValue>,
764    ) -> Result<Vec<RValue>, CallFunctionError> {
765        let args = args
766            .into_iter()
767            .map(|arg| self.wire_to_rvalue(arg))
768            .collect::<Result<Vec<_>, _>>()?;
769        let kwargs = kwargs
770            .into_iter()
771            .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
772            .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
773
774        let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
775
776        // Handle the case where the op returns nothing and convert it to a list of None.
777        // This is to ensure handle results does not error out as the client will call
778        // such a function with expected results of size 1.
779        Ok(if results.is_empty() {
780            vec![RValue::None]
781        } else {
782            results
783        })
784    }
785
786    fn call_python_fn<'py>(
787        &mut self,
788        py: Python<'py>,
789        cx: &Context<Self>,
790        function: Option<ResolvableFunction>,
791        args: Vec<WireValue>,
792        kwargs: HashMap<String, WireValue>,
793        mutates: &[Ref],
794        device_meshes: HashMap<Ref, DeviceMesh>,
795        remote_process_groups: HashMap<
796            Ref,
797            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
798        >,
799    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
800        let function = function
801            .map(|function| {
802                function.resolve(py).map_err(|e| {
803                    CallFunctionError::InvalidRemoteFunction(format!(
804                        "failed to resolve function {}: {}",
805                        function,
806                        SerializablePyErr::from(py, &e)
807                    ))
808                })
809            })
810            .transpose()?;
811
812        let remote_process_groups = remote_process_groups
813            .into_iter()
814            .map(|(gref, (mesh, dims, comm))| {
815                let group = match self.remote_process_groups.entry(gref) {
816                    Entry::Occupied(ent) => ent.get().clone_ref(py),
817                    Entry::Vacant(ent) => {
818                        // We need to run `init_process_group` before any
819                        // remote process groups can get created.
820                        torch_sys::backend::ensure_init_process_group(
821                            py,
822                            self.world_size,
823                            self.rank,
824                        )?;
825
826                        // Create a backend object to wrap the comm and use
827                        // it to create a new torch group.
828                        let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
829                        let group_size = ranks.len();
830                        let backend = CommBackend::new(
831                            comm,
832                            Mailbox::new_detached(cx.self_id().clone()),
833                            self.rank,
834                            group_size,
835                            self.world_size,
836                        );
837                        ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
838                            .clone_ref(py)
839                    }
840                };
841                PyResult::Ok((gref, group))
842            })
843            .collect::<Result<HashMap<_, _>, _>>()
844            .map_err(SerializablePyErr::from_fn(py))?;
845
846        // SAFETY: We will be making an unchecked clone of each tensor to pass to to
847        // C++, so we need to hold a borrow of each input tensor for the duration of
848        // this function.
849        let mut multiborrow = MultiBorrow::new();
850
851        let resolve = |val: WireValue| {
852            val.into_py_object()
853                .map_err(|e| {
854                    CallFunctionError::UnsupportedArgType(
855                        format!("{:?}", function),
856                        format!("{:?}", e),
857                    )
858                })?
859                .unpickle(py)
860                .map_err(SerializablePyErr::from_fn(py))?
861                .extract::<PyTree<PyObject>>()
862                .map_err(SerializablePyErr::from_fn(py))?
863                .try_into_map(|obj| {
864                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
865                        if let Some(mesh) = device_meshes.get(&ref_) {
866                            PyArg::DeviceMesh(mesh)
867                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
868                            PyArg::PyObject(pg.clone_ref(py))
869                        } else {
870                            let rval = self.ref_to_rvalue(&ref_)?;
871                            PyArg::RValue(rval)
872                        }
873                    } else {
874                        PyArg::PyObject(obj)
875                    })
876                })
877        };
878
879        // Resolve refs
880        let py_args: Vec<PyTree<PyArg>> = args
881            .into_iter()
882            .map(resolve)
883            .collect::<Result<_, CallFunctionError>>()?;
884        let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
885            .into_iter()
886            .map(|(k, object)| Ok((k, resolve(object)?)))
887            .collect::<Result<_, CallFunctionError>>()?;
888
889        // Add a shared-borrow for each rvalue reference.
890        py_args
891            .iter()
892            .chain(py_kwargs.values())
893            .flat_map(|o| o.iter())
894            .for_each(|arg| {
895                if let PyArg::RValue(rval) = arg {
896                    multiborrow.add(rval, BorrowType::Shared);
897                }
898            });
899
900        // Add mutable borrows for params we're mutating.
901        let mutates: Vec<_> = mutates
902            .iter()
903            .map(|r| self.ref_to_rvalue(r))
904            .collect::<Result<_, CallFunctionError>>()?;
905        mutates
906            .iter()
907            .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
908
909        // Execute the borrow.
910        let _borrow = multiborrow.borrow()?;
911
912        // Call function.
913        // Use custom subscriber to route Worker messages to stdout.
914        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
915        let result: Bound<'_, PyAny> =
916            tracing::subscriber::with_default(scoped_subscriber, || {
917                // SAFETY: The borrows above guard the unchecked clones done by
918                // `rvalue_to_ivalue`. This may result in multiple mutable
919                // references to tensor data, but the Python side is responsible
920                // for making sure that is safe
921                // TODO(agallagher): The args/kwargs conversion traits generate
922                // the appropriate types here, but they get casted to `PyAny`.
923                // It'd be nice to make `TryToPyObjectUnsafe` take a template
924                // arg for the converted py object to avoid this downcast.
925                let args = unsafe { py_args.try_to_object_unsafe(py) }
926                    .map_err(SerializablePyErr::from_fn(py))?;
927                // SAFETY: above
928                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
929                    .map_err(SerializablePyErr::from_fn(py))?;
930
931                if let Some(function) = function {
932                    function
933                        .call(args, Some(kwargs))
934                        .map_err(SerializablePyErr::from_fn(py))
935                } else {
936                    Ok(args.get_item(0).unwrap())
937                }
938            })?;
939        Ok(result)
940    }
941
942    fn call_python_fn_pytree(
943        &mut self,
944        cx: &hyperactor::Context<Self>,
945        function: ResolvableFunction,
946        args: Vec<WireValue>,
947        kwargs: HashMap<String, WireValue>,
948        mutates: &[Ref],
949        device_meshes: HashMap<Ref, DeviceMesh>,
950        remote_process_groups: HashMap<
951            Ref,
952            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
953        >,
954    ) -> Result<PyTree<RValue>, CallFunctionError> {
955        Python::with_gil(|py| {
956            let result = self.call_python_fn(
957                py,
958                cx,
959                Some(function),
960                args,
961                kwargs,
962                mutates,
963                device_meshes,
964                remote_process_groups,
965            )?;
966            Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
967        })
968    }
969    /// Retrieve `ref_` or create a fake value with the provided factory if it
970    /// is an error. We use this for collective calls, where even if there was
971    /// an upstream failure, we still have participate in the collective to
972    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
973    /// value here of the correct shape; the controller will have been notified
974    /// of the upstream failure and will know to ignore everything dependent on
975    /// it.
976    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
977        let rvalue = self
978            .env
979            .get(&ref_)
980            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
981
982        match rvalue {
983            Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
984            Err(_) => {
985                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
986                Ok(TensorCell::new(t))
987            }
988        }
989    }
990
991    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
992        self.active_recording
993            .as_mut()
994            .and_then(|state| match state {
995                RecordingState::Defining {
996                    recording,
997                    defined_borrows,
998                } => {
999                    match self.recordings.get_mut(recording) {
1000                        Some(recording) => Some((recording, defined_borrows)),
1001                        // Panic, because this would be a logic error in the program.
1002                        None => panic!("recording not found: {:?}", recording),
1003                    }
1004                }
1005                RecordingState::Running => None,
1006            })
1007    }
1008
1009    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
1010        for ref_ in refs {
1011            let rvalue_or_err = self
1012                .env
1013                .get(ref_)
1014                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1015            if let Err(err) = rvalue_or_err {
1016                return Ok(Some(err.clone()));
1017            }
1018        }
1019        Ok(None)
1020    }
1021    async fn send_value_python_message(
1022        &mut self,
1023        cx: &hyperactor::Context<'_, Self>,
1024        seq: Seq,
1025        worker_actor_id: ActorId,
1026        mutates: Vec<Ref>,
1027        function: Option<ResolvableFunction>,
1028        args: Vec<WireValue>,
1029        kwargs: HashMap<String, WireValue>,
1030        device_meshes: HashMap<Ref, DeviceMesh>,
1031    ) -> Result<()> {
1032        self.try_define(cx, seq, vec![], &vec![], async |self_| {
1033            let python_message =
1034                Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1035                    let python_result = tokio::task::block_in_place(|| {
1036                        self_.call_python_fn(
1037                            py,
1038                            cx,
1039                            function,
1040                            args,
1041                            kwargs,
1042                            &mutates,
1043                            device_meshes,
1044                            HashMap::new(),
1045                        )
1046                    })?;
1047                    pickle_python_result(py, python_result, worker_actor_id)
1048                        .map_err(CallFunctionError::Error)
1049                })?;
1050            let ser = Serialized::serialize(&python_message).unwrap();
1051            self_
1052                .controller_actor
1053                .fetch_result(cx, seq, Ok(ser))
1054                .await?;
1055            Ok(vec![])
1056        })
1057        .await
1058    }
1059    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1060        let rvalue = self
1061            .env
1062            .get(&src)
1063            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1064        self.env.insert(dest, rvalue.clone());
1065        Ok(())
1066    }
1067    async fn call_actor(
1068        &mut self,
1069        cx: &Context<'_, Self>,
1070        params: ActorCallParams,
1071    ) -> Result<PyObject, CallFunctionError> {
1072        let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1073            params
1074                .local_state
1075                .into_iter()
1076                .map(|elem| {
1077                    // SAFETY: python is gonna make unsafe copies of this stuff anyway
1078                    unsafe {
1079                        let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1080                        Ok(x)
1081                    }
1082                })
1083                .collect()
1084        });
1085
1086        let (send, recv) = cx.open_once_port();
1087        let state = LocalState {
1088            response_port: send,
1089            state: local_state?,
1090        };
1091        let x: u64 = params.seq.into();
1092        let message = LocalStateBrokerMessage::Set(x as usize, state);
1093
1094        let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1095        broker
1096            .send(message)
1097            .map_err(|e| CallFunctionError::Error(e.into()))?;
1098        let result = recv
1099            .recv()
1100            .await
1101            .map_err(|e| CallFunctionError::Error(e.into()))?;
1102
1103        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1104    }
1105}
1106
1107#[async_trait]
1108#[forward(StreamMessage)]
1109impl StreamMessageHandler for StreamActor {
1110    async fn call_function(
1111        &mut self,
1112        cx: &Context<Self>,
1113        params: CallFunctionParams,
1114        device_meshes: HashMap<Ref, DeviceMesh>,
1115        remote_process_groups: HashMap<
1116            Ref,
1117            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1118        >,
1119    ) -> Result<()> {
1120        if let Some((recording, _)) = self.get_defining_recording() {
1121            recording.messages.push(StreamMessage::CallFunction(
1122                params,
1123                device_meshes,
1124                remote_process_groups,
1125            ));
1126            return Ok(());
1127        }
1128
1129        params.function.panic_if_requested();
1130        self.try_define(
1131            cx,
1132            params.seq,
1133            params.results,
1134            &params.mutates,
1135            async |self| {
1136                tokio::task::block_in_place(|| match params.function.as_torch_op() {
1137                    Some((op, overload)) => {
1138                        self.call_torch_op(op, overload, params.args, params.kwargs)
1139                    }
1140                    _ => self
1141                        .call_python_fn_pytree(
1142                            cx,
1143                            params.function,
1144                            params.args,
1145                            params.kwargs,
1146                            &params.mutates,
1147                            device_meshes,
1148                            remote_process_groups,
1149                        )
1150                        .map(|results| results.into_leaves()),
1151                })
1152            },
1153        )
1154        .await?;
1155        Ok(())
1156    }
1157
1158    async fn borrow_create(
1159        &mut self,
1160        _cx: &Context<Self>,
1161        borrow: u64,
1162        tensor: Ref,
1163        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1164    ) -> Result<()> {
1165        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1166            recording.messages.push(StreamMessage::BorrowCreate {
1167                borrow,
1168                tensor,
1169                first_use_sender,
1170            });
1171            ensure!(
1172                defined_borrows.insert(borrow),
1173                "duplicate borrow create in recording"
1174            );
1175            return Ok(());
1176        }
1177
1178        let rvalue_result = self
1179            .env
1180            .get(&tensor)
1181            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1182
1183        let result = match rvalue_result {
1184            Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1185            Err(e) => Err(e.clone()),
1186        };
1187
1188        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1189        first_use_sender.send((event, result)).map_err(|err| {
1190            anyhow!(
1191                "failed sending first use event for borrow {:?}: {:?}",
1192                borrow,
1193                err
1194            )
1195        })
1196    }
1197
1198    async fn borrow_first_use(
1199        &mut self,
1200        _cx: &Context<Self>,
1201        borrow: u64,
1202        result: Ref,
1203        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1204    ) -> Result<()> {
1205        if let Some((recording, _)) = self.get_defining_recording() {
1206            recording.messages.push(StreamMessage::BorrowFirstUse {
1207                borrow,
1208                result,
1209                first_use_receiver: first_use_receiver.clone(),
1210            });
1211            return Ok(());
1212        }
1213
1214        let (first_use_event, cell) =
1215            first_use_receiver
1216                .lock()
1217                .await
1218                .recv()
1219                .await
1220                .map_err(|err| {
1221                    anyhow!(
1222                        "failed receiving first use event for borrow {:?}: {:?}",
1223                        borrow,
1224                        err
1225                    )
1226                })?;
1227
1228        if let Some(stream) = self.cuda_stream() {
1229            stream.wait_event(
1230                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1231            );
1232        }
1233        match cell {
1234            Ok(cell) => {
1235                self.env.insert(result, Ok(cell.into()));
1236            }
1237            Err(err) => {
1238                self.env.insert(result, Err(err.clone()));
1239            }
1240        }
1241        Ok(())
1242    }
1243
1244    async fn borrow_last_use(
1245        &mut self,
1246        _cx: &Context<Self>,
1247        borrow: u64,
1248        result: Ref,
1249        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1250    ) -> Result<()> {
1251        if let Some((recording, _)) = self.get_defining_recording() {
1252            recording.messages.push(StreamMessage::BorrowLastUse {
1253                borrow,
1254                result,
1255                last_use_sender,
1256            });
1257            return Ok(());
1258        }
1259
1260        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1261        let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1262            "Invalid reference for borrow_last_use: {result:#?}"
1263        ))?;
1264        let tensor = match rvalue_or_err {
1265            Ok(RValue::Tensor(t)) => Ok(t),
1266            Err(e) => Err(e),
1267            _ => bail!("invalid rvalue type for borrow_last_use"),
1268        };
1269
1270        last_use_sender.send((event, tensor)).map_err(|err| {
1271            anyhow!(
1272                "failed sending last use event for borrow {:?}: {:?}",
1273                borrow,
1274                err
1275            )
1276        })
1277    }
1278
1279    async fn borrow_drop(
1280        &mut self,
1281        _cx: &Context<Self>,
1282        borrow: u64,
1283        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1284    ) -> Result<()> {
1285        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1286            recording.messages.push(StreamMessage::BorrowDrop {
1287                borrow,
1288                last_use_receiver: last_use_receiver.clone(),
1289            });
1290            ensure!(
1291                defined_borrows.remove(&borrow),
1292                "borrow drop for borrow not defined in recording"
1293            );
1294            return Ok(());
1295        }
1296
1297        // The borrowed cell isn't used directly, but we still want to receive it here
1298        // so that the underlying tensor isn't dropped until after we synchronize the
1299        // CUDA streams.
1300        let (last_use_event, _cell) =
1301            last_use_receiver.lock().await.recv().await.map_err(|err| {
1302                anyhow!(
1303                    "failed receiving last use event for borrow {:?}: {:?}",
1304                    borrow,
1305                    err
1306                )
1307            })?;
1308
1309        if let Some(stream) = self.cuda_stream() {
1310            stream.wait_event(
1311                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1312            );
1313        }
1314        // let the cell drop.
1315        Ok(())
1316    }
1317
1318    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1319        if let Some((recording, _)) = self.get_defining_recording() {
1320            recording.messages.push(StreamMessage::DeleteRefs(refs));
1321            return Ok(());
1322        }
1323
1324        for ref_ in refs.iter() {
1325            self.env.remove(ref_);
1326        }
1327        Ok(())
1328    }
1329
1330    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1331        if self.get_defining_recording().is_some() {
1332            bail!("request_status not allowed in recording");
1333        }
1334
1335        Ok(())
1336    }
1337
1338    async fn init_comm(
1339        &mut self,
1340        _cx: &Context<Self>,
1341        comm: ActorHandle<NcclCommActor>,
1342    ) -> Result<()> {
1343        if self.get_defining_recording().is_some() {
1344            bail!("init_comm not allowed in recording");
1345        }
1346
1347        self.comm = Some(comm);
1348        Ok(())
1349    }
1350
1351    async fn reduce(
1352        &mut self,
1353        cx: &Context<Self>,
1354        comm: Arc<ActorHandle<NcclCommActor>>,
1355        dim_size: i64,
1356        result: Ref,
1357        local_tensor: Ref,
1358        factory: Factory,
1359        reduction: Reduction,
1360        scatter: bool,
1361        in_place: bool,
1362        out: Option<Ref>,
1363    ) -> Result<()> {
1364        if let Some((recording, _)) = self.get_defining_recording() {
1365            recording.messages.push(StreamMessage::Reduce {
1366                comm,
1367                dim_size,
1368                result,
1369                local_tensor,
1370                factory,
1371                reduction,
1372                scatter,
1373                in_place,
1374                out,
1375            });
1376            return Ok(());
1377        }
1378
1379        let stream = self
1380            .cuda_stream()
1381            .expect("reductions not yet supported for non-CUDA workers")
1382            .clone();
1383        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1384        let out_cell = out
1385            .map(|out| self.get_or_fake_on_err(out, &factory))
1386            .transpose()?;
1387        let output_cell = match reduction {
1388            Reduction::Stack => {
1389                if scatter {
1390                    let output_cell = if in_place {
1391                        input_cell.clone()
1392                    } else {
1393                        out_cell.unwrap_or({
1394                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1395                            let cloned = deep_clone(&borrow);
1396                            TensorCell::new(cloned)
1397                        })
1398                    };
1399                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1400                        .await?;
1401                    output_cell
1402                } else {
1403                    ensure!(
1404                        !in_place,
1405                        "in-place, non-scatter not supported for stack reduce"
1406                    );
1407
1408                    let output_cell = out_cell.unwrap_or({
1409                        // In Python, this would be [dim_size, *factory.sizes]
1410                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1411                        let output =
1412                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1413                        TensorCell::new(output)
1414                    });
1415
1416                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1417                        .await?;
1418                    output_cell
1419                }
1420            }
1421            Reduction::ReduceOp(op) => {
1422                if scatter {
1423                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1424
1425                    let output_cell = out_cell.unwrap_or({
1426                        let output = factory_empty(
1427                            &factory.size[1..],
1428                            factory.dtype,
1429                            factory.layout,
1430                            factory.device,
1431                        );
1432                        TensorCell::new(output)
1433                    });
1434                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1435                        .await?;
1436                    output_cell
1437                } else {
1438                    let output_cell = if in_place {
1439                        input_cell.clone()
1440                    } else {
1441                        out_cell.map_or(
1442                            {
1443                                let borrow =
1444                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1445                                let cloned = deep_clone(&borrow);
1446                                Ok(TensorCell::new(cloned))
1447                            },
1448                            |out_cell| -> Result<_, anyhow::Error> {
1449                                let mut out_borrow =
1450                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1451                                let in_borrow =
1452                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1453                                out_borrow.copy_(&in_borrow);
1454                                drop(out_borrow);
1455                                Ok(out_cell)
1456                            },
1457                        )?
1458                    };
1459
1460                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1461                    output_cell
1462                }
1463            }
1464        };
1465
1466        self.env.insert(result, Ok(output_cell.into()));
1467        Ok(())
1468    }
1469
1470    async fn send_tensor(
1471        &mut self,
1472        cx: &Context<Self>,
1473        result: Ref,
1474        from_rank: Option<usize>,
1475        to_rank: Option<usize>,
1476        tensor: Ref,
1477        factory: Factory,
1478        comm: Arc<ActorHandle<NcclCommActor>>,
1479    ) -> Result<()> {
1480        if let Some((recording, _)) = self.get_defining_recording() {
1481            recording.messages.push(StreamMessage::SendTensor {
1482                result,
1483                from_rank,
1484                to_rank,
1485                tensor,
1486                factory,
1487                comm,
1488            });
1489            return Ok(());
1490        }
1491
1492        if to_rank.is_none() && from_rank.is_none() {
1493            bail!("tried to send tensor without a to/from rank");
1494        }
1495
1496        // Value is local, so we do not have to actually send it.
1497        if from_rank == to_rank {
1498            let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1499                .env
1500                .get(&tensor)
1501                .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1502            let output_cell = match input_cell {
1503                Ok(RValue::Tensor(input_cell)) => {
1504                    // We create a defensive copy here to prevent mutations on
1505                    // the input tensor from affecting output tensor.
1506                    // Should we copy if input ref == output ref?
1507                    // Should we support copy-on-write to avoid unnecessary copy?
1508                    let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1509                    let cloned = deep_clone(&borrow);
1510                    Ok(RValue::Tensor(TensorCell::new(cloned)))
1511                }
1512                Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1513                Err(err) => Err(err.clone()),
1514            };
1515            self.env.insert(result, output_cell);
1516            return Ok(());
1517        }
1518
1519        let mut messages = Vec::new();
1520
1521        if let Some(to_rank) = to_rank {
1522            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1523            messages.push(CommMessage::Send(
1524                input_cell,
1525                to_rank.try_into().unwrap(),
1526                self.cuda_stream()
1527                    .expect("tried to send_tensor on non-cuda stream")
1528                    .clone(),
1529                cx.open_once_port().0,
1530            ));
1531        }
1532
1533        if let Some(from_rank) = from_rank {
1534            let output_cell = TensorCell::new(factory_empty(
1535                &factory.size,
1536                factory.dtype,
1537                factory.layout,
1538                factory.device,
1539            ));
1540            messages.push(CommMessage::Recv(
1541                output_cell.clone(),
1542                from_rank.try_into().unwrap(),
1543                self.cuda_stream()
1544                    .expect("tried to send_tensor on non-cuda stream")
1545                    .clone(),
1546                cx.open_once_port().0,
1547            ));
1548            self.env.insert(result, Ok(output_cell.into()));
1549        }
1550
1551        comm.group(
1552            cx,
1553            messages,
1554            self.cuda_stream()
1555                .expect("tried to send_tensor on non-cuda stream")
1556                .clone(),
1557        )
1558        .await?;
1559        Ok(())
1560    }
1561
1562    async fn send_value(
1563        &mut self,
1564        cx: &Context<Self>,
1565        seq: Seq,
1566        worker_actor_id: ActorId,
1567        mutates: Vec<Ref>,
1568        function: Option<ResolvableFunction>,
1569        args: Vec<WireValue>,
1570        kwargs: HashMap<String, WireValue>,
1571        device_meshes: HashMap<Ref, DeviceMesh>,
1572        pipe: Option<PortHandle<PipeMessage>>,
1573    ) -> Result<()> {
1574        if self.respond_with_python_message && pipe.is_none() {
1575            return self
1576                .send_value_python_message(
1577                    cx,
1578                    seq,
1579                    worker_actor_id,
1580                    mutates,
1581                    function,
1582                    args,
1583                    kwargs,
1584                    device_meshes,
1585                )
1586                .await;
1587        }
1588        let result = if let Some(function) = function {
1589            // If a function was provided, use that to resolve the value.
1590            match function.as_torch_op() {
1591                Some((op, overload)) => {
1592                    self.call_torch_op(op, overload, args, kwargs)
1593                        .map(|rvalues| {
1594                            if rvalues.len() == 1 {
1595                                Ok(rvalues[0].clone().into())
1596                            } else {
1597                                // TODO: Replace with native pytrees when possible
1598                                Python::with_gil(|py| {
1599                                    Ok((|| {
1600                                        let py_rvalues = rvalues
1601                                            .into_iter()
1602                                            // SAFETY: This inherits the unsafety of `try_to_object_unsafe`.
1603                                            .map(|rvalue| unsafe {
1604                                                rvalue.try_to_object_unsafe(py)
1605                                            })
1606                                            .collect::<Result<Vec<_>, _>>()?;
1607                                        PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1608                                    })()
1609                                    .map_err(SerializablePyErr::from_fn(py))?)
1610                                })
1611                            }
1612                        })?
1613                }
1614                // Use block-in-place to allow nested callbacks to re-enter the
1615                // runtime to run async code.
1616                _ => tokio::task::block_in_place(|| {
1617                    self.call_python_fn_pytree(
1618                        cx,
1619                        function,
1620                        args,
1621                        kwargs,
1622                        &mutates,
1623                        device_meshes,
1624                        HashMap::new(),
1625                    )
1626                }),
1627            }
1628        } else {
1629            // If there's no function provided, there should be exactly one arg
1630            // and no kwargs.
1631            match (args.len(), kwargs.len()) {
1632                (1, 0) => Python::with_gil(|py| {
1633                    let arg = args[0]
1634                        .as_py_object()
1635                        .ok_or_else(|| {
1636                            CallFunctionError::UnsupportedArgType(
1637                                "send_value".to_string(),
1638                                "expected a PyObject as the first arg".to_string(),
1639                            )
1640                        })?
1641                        .unpickle(py)
1642                        .map_err(SerializablePyErr::from_fn(py))?;
1643                    arg.extract::<PyTree<PyObject>>()
1644                        .map_err(SerializablePyErr::from_fn(py))?
1645                        .try_into_map(|obj| {
1646                            let bound_obj = obj.bind(py);
1647                            if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1648                                self.ref_to_rvalue(&ref_)
1649                            } else {
1650                                Ok(bound_obj
1651                                    .extract::<RValue>()
1652                                    .map_err(SerializablePyErr::from_fn(py))?)
1653                            }
1654                        })
1655                }),
1656                _ => Err(CallFunctionError::TooManyArgsForValue(
1657                    format!("{:?}", args),
1658                    format!("{:?}", kwargs),
1659                )),
1660            }
1661        };
1662
1663        let value = match result {
1664            Ok(rvalue) => {
1665                // When returning a tensor, we copy out to decouple from the GPU,
1666                // as the worker will either serialize and send this to the controller
1667                // or to a pipe and we see hangs if it tries to pull from the GPU
1668                // in its thread.
1669                Ok(rvalue.into_map(|rval| match rval {
1670                    RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1671                    RValue::TensorList(tensors) => RValue::TensorList(
1672                        tensors
1673                            .into_iter()
1674                            .map(|tensor| tensor.try_cpu().unwrap())
1675                            .collect(),
1676                    ),
1677                    rval => rval,
1678                }))
1679            }
1680            Err(err) => {
1681                let err = self.report_seq_error(cx, seq, err).await?;
1682                for ref_ in mutates {
1683                    self.env.insert(ref_, Err(err.clone()));
1684                }
1685                Err(WorkerError {
1686                    backtrace: format!("{:?}", err),
1687                    worker_actor_id,
1688                })
1689            }
1690        };
1691
1692        // Actually send the value.
1693        if let Some(pipe) = pipe {
1694            pipe.send(PipeMessage::SendValue(value))?;
1695        } else {
1696            let result = match value {
1697                Ok(value) => Ok(Serialized::serialize_anon(&value).map_err(anyhow::Error::from)?),
1698                Err(e) => Err(e),
1699            };
1700            self.controller_actor.fetch_result(cx, seq, result).await?;
1701        }
1702
1703        Ok(())
1704    }
1705
1706    async fn send_result_of_actor_call(
1707        &mut self,
1708        cx: &Context<Self>,
1709        worker_actor_id: ActorId,
1710        params: ActorCallParams,
1711    ) -> anyhow::Result<()> {
1712        let seq = params.seq;
1713        let mutates = params.mutates.clone();
1714        self.try_define(cx, seq, vec![], &mutates, async |self| {
1715            let value = self.call_actor(cx, params).await?;
1716            let result = Python::with_gil(|py| {
1717                pickle_python_result(py, value.into_bound(py), worker_actor_id)
1718            })?;
1719            let result = Serialized::serialize(&result).unwrap();
1720            self.controller_actor
1721                .fetch_result(cx, seq, Ok(result))
1722                .await?;
1723            Ok(vec![])
1724        })
1725        .await
1726    }
1727
1728    async fn call_actor_method(
1729        &mut self,
1730        cx: &Context<Self>,
1731        params: ActorMethodParams,
1732    ) -> anyhow::Result<()> {
1733        let seq = params.call.seq;
1734        let mutates = params.call.mutates.clone();
1735        self.try_define(cx, seq, params.results, &mutates, async |self| {
1736            let result = self.call_actor(cx, params.call).await?;
1737            let result = Python::with_gil(|py| {
1738                PyTree::<RValue>::extract_bound(&result.into_bound(py))
1739                    .map_err(SerializablePyErr::from_fn(py))
1740            })?;
1741            Ok(result.into_leaves())
1742        })
1743        .await
1744    }
1745
1746    async fn set_value(
1747        &mut self,
1748        cx: &Context<Self>,
1749        seq: Seq,
1750        results: Vec<Option<Ref>>,
1751        pipe: PortHandle<PipeMessage>,
1752    ) -> Result<()> {
1753        if let Some((recording, _)) = self.get_defining_recording() {
1754            recording
1755                .messages
1756                .push(StreamMessage::SetValue { seq, results, pipe });
1757            return Ok(());
1758        }
1759
1760        self.try_define(cx, seq, results, &vec![], async |self| {
1761            let (tx, rx) = cx.open_once_port();
1762            pipe.send(PipeMessage::RecvValue(tx))
1763                .map_err(anyhow::Error::from)
1764                .map_err(CallFunctionError::from)?;
1765            let value = rx.recv().await.map_err(anyhow::Error::from)?;
1766            Ok(value.into_leaves())
1767        })
1768        .await
1769    }
1770
1771    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1772        if self.active_recording.is_some() {
1773            bail!("different recording already active");
1774        }
1775        match self.recordings.entry(recording) {
1776            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1777            Entry::Vacant(entry) => entry.insert(Recording::new()),
1778        };
1779        self.active_recording = Some(RecordingState::Defining {
1780            recording,
1781            defined_borrows: HashSet::new(),
1782        });
1783        Ok(())
1784    }
1785
1786    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1787        match self.active_recording {
1788            Some(RecordingState::Defining {
1789                recording: active_recording,
1790                ref defined_borrows,
1791            }) if active_recording == recording => {
1792                ensure!(
1793                    defined_borrows.is_empty(),
1794                    "all borrows created within recording must be dropped within recording"
1795                );
1796                self.active_recording = None;
1797            }
1798            _ => bail!("cannot finalize recording that isn't active"),
1799        }
1800        Ok(())
1801    }
1802
1803    async fn recording_formal(
1804        &mut self,
1805        _cx: &Context<Self>,
1806        result: Ref,
1807        argument_index: usize,
1808    ) -> Result<()> {
1809        match self.get_defining_recording() {
1810            Some((recording, _)) => {
1811                recording.messages.push(StreamMessage::RecordingFormal {
1812                    result,
1813                    argument_index,
1814                });
1815            }
1816            None => bail!("recording_formal called outside of recording"),
1817        };
1818        Ok(())
1819    }
1820
1821    async fn recording_result(
1822        &mut self,
1823        _cx: &Context<Self>,
1824        result: Ref,
1825        output_index: usize,
1826    ) -> Result<()> {
1827        match self.get_defining_recording() {
1828            Some((recording, _)) => {
1829                recording.messages.push(StreamMessage::RecordingResult {
1830                    result,
1831                    output_index,
1832                });
1833            }
1834            None => bail!("recording_result called outside of recording"),
1835        };
1836        Ok(())
1837    }
1838
1839    async fn call_recording(
1840        &mut self,
1841        cx: &Context<Self>,
1842        seq: Seq,
1843        recording: Ref,
1844        results: Vec<Ref>,
1845        actuals: Vec<Ref>,
1846    ) -> Result<()> {
1847        if self.active_recording.is_some() {
1848            bail!("cannot call recording while another recording is active");
1849        }
1850
1851        let messages = match self.recordings.get(&recording) {
1852            Some(recording) => recording
1853                .messages
1854                .iter()
1855                .map(|message| message.clone_for_recording())
1856                .collect::<Vec<_>>(),
1857            None => bail!("recording {:?} not found", recording),
1858        };
1859
1860        self.active_recording = Some(RecordingState::Running);
1861
1862        // Global error for all messages in the recording. The first time a message
1863        // fails in the recording, we set the error. We then need to propagate this
1864        // error to all of the refs mutated by the entire recording, as well as the
1865        // result refs.
1866        let mut error: Option<Arc<SeqError>> = None;
1867        // The set of all refs defined by this recording (excluding "results"),
1868        // which we need to ensure are deleted when the recording is done executing.
1869        let mut all_defined_refs = HashSet::new();
1870        // The set of all refs mutated by this recording. If there is an error with
1871        // any message, all of these refs need to have the correct error set.
1872        let mut all_mutated_refs = HashSet::new();
1873        // Map from the result ref of a RecordingFormal message to the associated
1874        // actual ref from "actuals". We need to track this in order to properly
1875        // handle recordings that mutate refs contained in "actuals" -- every
1876        // message in the recording that interacts with the recording inputs will
1877        // interact with the formal ref rather than the actual ref.
1878        let mut formal_to_actual_refs = HashMap::new();
1879        // clear any pre-existing error messages before recording started
1880        self.last_seq_error = None;
1881        for message in messages.into_iter() {
1882            let defined_refs = message.get_defined_refs();
1883            all_defined_refs.extend(defined_refs.clone());
1884
1885            let mutated_refs_with_formals = message.get_mutated_refs();
1886            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1887                match formal_to_actual_refs.get(ref_) {
1888                    Some(actual_ref) => Some(*actual_ref),
1889                    None => {
1890                        if all_defined_refs.contains(ref_) {
1891                            None
1892                        } else {
1893                            Some(*ref_)
1894                        }
1895                    }
1896                }
1897            }));
1898
1899            match message {
1900                StreamMessage::RecordingFormal {
1901                    result: formal_ref,
1902                    argument_index,
1903                } => match actuals.get(argument_index) {
1904                    None => bail!("recording_formal called with too few arguments"),
1905                    Some(actual_ref) => {
1906                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
1907                        self.define_ref(formal_ref, *actual_ref)?;
1908                    }
1909                },
1910                StreamMessage::RecordingResult {
1911                    result: result_ref,
1912                    output_index,
1913                } => match results.get(output_index) {
1914                    None => bail!("recording_result called with too few results"),
1915                    Some(actual_result_ref) => {
1916                        self.define_ref(*actual_result_ref, result_ref)?;
1917                    }
1918                },
1919                StreamMessage::DeleteRefs(ref refs) => {
1920                    for ref_ in refs {
1921                        all_defined_refs.remove(ref_);
1922                    }
1923                    StreamMessageHandler::handle(self, cx, message).await?;
1924                }
1925                StreamMessage::CallFunction { .. } if error.is_some() => {
1926                    // CallFunction is expensive. If the recording already failed, then
1927                    // just update the necessary refs with the error. Most of the other
1928                    // message types need to run regardless because there are other actors
1929                    // that expect the call to happen (e.g., all of the borrow messages,
1930                    // pipe send/recv, send_tensor, reduce, etc.).
1931                    let error = error.clone().unwrap();
1932                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1933                        self.env.insert(*ref_, Err(error.clone()));
1934                    }
1935                }
1936                StreamMessage::BorrowLastUse { ref result, .. } => {
1937                    all_defined_refs.remove(result);
1938                    StreamMessageHandler::handle(self, cx, message).await?;
1939                }
1940                StreamMessage::Reduce {
1941                    local_tensor,
1942                    ref out,
1943                    ..
1944                } => {
1945                    // Reduce doesn't propagate errors to the result ref, so we need
1946                    // to check for existing errors on the input tensors and set the
1947                    // recording's error if necessary.
1948                    if error.is_none() {
1949                        let inputs_to_check = [Some(local_tensor), out.clone()]
1950                            .iter()
1951                            .filter_map(|r| *r)
1952                            .collect::<Vec<_>>();
1953                        error = self.get_first_error(inputs_to_check.as_slice())?;
1954                    }
1955                    StreamMessageHandler::handle(self, cx, message).await?;
1956                }
1957                StreamMessage::SendTensor {
1958                    ref tensor,
1959                    ref to_rank,
1960                    ..
1961                } => {
1962                    // If this rank is sending a tensor (e.g., to_rank has a value),
1963                    // we need to check for existing errors on the input tensor, because
1964                    // the error is only propagated to the result ref when this rank
1965                    // is also receiving a tensor.
1966                    if to_rank.is_some() && error.is_none() {
1967                        error = self.get_first_error(&[*tensor])?;
1968                    }
1969                    StreamMessageHandler::handle(self, cx, message).await?;
1970                }
1971                _ => {
1972                    StreamMessageHandler::handle(self, cx, message).await?;
1973                }
1974            };
1975
1976            // It's not entirely trivial to determine whether a message "failed" or not.
1977            // For example, the CallFunction message can return Ok(..) if there is an error
1978            // in the underlying function call. But in that case, we would still want to
1979            // consider the recording call as "failed". Unlike in python, where we can just
1980            // wrap everything in try-except, in rust, we keep track of the last report SeqError, which
1981            // we clear before handling each recording message. If we see it is set, the
1982            // we know the recording has faild.
1983            match (&error, self.last_seq_error.take()) {
1984                (None, Some(seq_err)) => {
1985                    // Report failure to the controller.
1986                    self.controller_actor
1987                        .remote_function_failed(
1988                            cx,
1989                            seq,
1990                            WorkerError {
1991                                backtrace: format!("recording failed: {}", &seq_err),
1992                                worker_actor_id: cx.self_id().clone(),
1993                            },
1994                        )
1995                        .await?;
1996                    error = Some(seq_err)
1997                }
1998                _ => {}
1999            }
2000            // Continue processing the remaining stream messages regardless of error.
2001            // We need to do this partially for error propagation, but also because
2002            // certain messages (like borrows and reductions) need to run regardless
2003            // in order to prevent deadlocks.
2004        }
2005
2006        // Delete the formal refs and some subset of the RecordingResult refs. The
2007        // controller should have generated DeleteRefs messages for all other refs
2008        // defined by the recording.
2009        StreamMessageHandler::handle(
2010            self,
2011            cx,
2012            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
2013        )
2014        .await?;
2015
2016        // Any refs mutated by the recording and all results should have the same error
2017        // (the original error that caused the recording to fail).
2018        if error.is_some() {
2019            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
2020                self.env.insert(*ref_, Err(error.clone().unwrap()));
2021            }
2022        }
2023
2024        self.active_recording = None;
2025        Ok(())
2026    }
2027
2028    async fn set_ref_unit_tests_only(
2029        &mut self,
2030        _cx: &Context<Self>,
2031        reference: Ref,
2032        value: WireValue,
2033    ) -> Result<()> {
2034        self.env
2035            .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
2036        Ok(())
2037    }
2038
2039    async fn set_tensor_ref_unit_tests_only(
2040        &mut self,
2041        _cx: &Context<Self>,
2042        reference: Ref,
2043        tensor_result: TensorCellResult,
2044    ) -> Result<()> {
2045        match tensor_result {
2046            Ok(tensor_cell) => {
2047                self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
2048            }
2049            Err(err) => {
2050                self.env.insert(reference, Err(err));
2051            }
2052        }
2053        Ok(())
2054    }
2055
2056    async fn get_ref_unit_tests_only(
2057        &mut self,
2058        _cx: &Context<Self>,
2059        reference: Ref,
2060    ) -> Result<Option<Result<WireValue, String>>> {
2061        /// For testing only, doesn't support Tensor or TensorList.
2062        fn rvalue_to_wire(
2063            value: Result<RValue, Arc<SeqError>>,
2064        ) -> Result<WireValue, Arc<SeqError>> {
2065            Ok(match value? {
2066                RValue::Int(val) => WireValue::Int(val),
2067                RValue::IntList(val) => WireValue::IntList(val),
2068                RValue::Double(val) => WireValue::Double(val),
2069                RValue::Bool(val) => WireValue::Bool(val),
2070                RValue::String(val) => WireValue::String(val),
2071                RValue::Layout(val) => WireValue::Layout(val),
2072                RValue::Device(val) => WireValue::Device(val),
2073                RValue::ScalarType(val) => WireValue::ScalarType(val),
2074                RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2075                RValue::None => WireValue::None(()),
2076                other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2077            })
2078        }
2079        Ok(self
2080            .env
2081            .get(&reference)
2082            .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2083    }
2084
2085    async fn get_tensor_ref_unit_tests_only(
2086        &mut self,
2087        _cx: &Context<Self>,
2088        reference: Ref,
2089    ) -> Result<Option<TensorCellResult>> {
2090        match self.env.get(&reference) {
2091            Some(Ok(rvalue)) => match rvalue {
2092                RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2093                other => bail!("expected tensor, got {:?}", other),
2094            },
2095            Some(Err(err)) => Ok(Some(Err(err.clone()))),
2096            None => Ok(None),
2097        }
2098    }
2099}
2100
2101#[cfg(test)]
2102mod tests {
2103    use hyperactor::actor::ActorStatus;
2104    use hyperactor::cap;
2105    use hyperactor::supervision::ActorSupervisionEvent;
2106    use monarch_messages::controller::ControllerMessage;
2107    use monarch_messages::worker::StreamCreationMode;
2108    use monarch_types::PickledPyObject;
2109    use pyo3::IntoPyObjectExt;
2110    use timed_test::async_timed_test;
2111    use torch_sys::factory_float_tensor;
2112    use torch_sys::testing::allclose;
2113    use torch_sys_cuda::nccl::UniqueId;
2114
2115    use super::*;
2116    use crate::comm::CommParams;
2117    use crate::test_util;
2118
2119    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2120        Arc::new(SeqError {
2121            seq: 0.into(),
2122            error: err,
2123        })
2124    }
2125
    /// Handles for driving a single `StreamActor` in unit tests.
    struct TestSetup {
        /// Local proc hosting the controller port, the client mailbox, and the
        /// stream actor.
        proc: Proc,
        /// The stream actor under test.
        stream_actor: ActorHandle<StreamActor>,
        /// Mailbox used as the sender for messages to the stream actor.
        client: Mailbox,
        // Unused, but necessary, because proc needs a supervision
        // port -- otherwise an actor failure will cause a crash.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        /// Receives the messages the stream actor sends to its controller.
        controller_rx: PortReceiver<ControllerMessage>,
        /// Reference to the attached controller port, as passed to the stream
        /// actor's `StreamParams`.
        controller_actor: ActorRef<ControllerActor>,
        /// Next ref id handed out by `next_ref()`.
        next_ref: Ref,
    }
2138
2139    impl TestSetup {
2140        async fn new() -> Result<Self> {
2141            Self::new_with_world_size(1).await
2142        }
2143
2144        async fn new_with_world_size(world_size: usize) -> Result<Self> {
2145            test_util::test_setup()?;
2146
2147            let proc = Proc::local();
2148            let (_, controller_actor, controller_rx) =
2149                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2150            let client = proc.attach("client")?;
2151            let (supervision_tx, supervision_rx) = client.open_port();
2152            proc.set_supervision_coordinator(supervision_tx)?;
2153            let stream_actor = proc
2154                .spawn::<StreamActor>(
2155                    "stream",
2156                    StreamParams {
2157                        world_size,
2158                        rank: 0,
2159                        creation_mode: StreamCreationMode::UseDefaultStream,
2160                        id: 0.into(),
2161                        device: Some(CudaDevice::new(0.into())),
2162                        controller_actor: controller_actor.clone(),
2163                        respond_with_python_message: false,
2164                    },
2165                )
2166                .await?;
2167
2168            Ok(Self {
2169                proc,
2170                stream_actor,
2171                client,
2172                supervision_rx,
2173                controller_rx,
2174                controller_actor,
2175                next_ref: 0.into(),
2176            })
2177        }
2178
2179        fn next_ref(&mut self) -> Ref {
2180            let ref_ = self.next_ref;
2181            self.next_ref = Ref {
2182                id: self.next_ref.id + 1,
2183            };
2184            ref_
2185        }
2186
2187        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2188            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2189            self.stream_actor
2190                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2191                .await
2192        }
2193
2194        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2195            let actual = self
2196                .stream_actor
2197                .get_tensor_ref_unit_tests_only(&self.client, reference)
2198                .await
2199                .unwrap()
2200                .unwrap()
2201                .unwrap();
2202            let x = allclose(
2203                &factory_float_tensor(data, "cpu".try_into().unwrap()),
2204                &actual.borrow(),
2205            )
2206            .unwrap();
2207            x
2208        }
2209
2210        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2211            let result_error = self
2212                .stream_actor
2213                .get_tensor_ref_unit_tests_only(&self.client, reference)
2214                .await
2215                .unwrap()
2216                .unwrap()
2217                .unwrap_err();
2218
2219            assert!(Arc::ptr_eq(&result_error, &error));
2220        }
2221    }
2222
2223    async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2224        loop {
2225            let status = proc
2226                .ledger_snapshot()
2227                .roots
2228                .get(actor_id)
2229                .unwrap()
2230                .status
2231                .clone();
2232            if let ActorStatus::Failed(msg) = status {
2233                assert!(msg.contains(&expected_msg));
2234                break;
2235            } else {
2236                tokio::task::yield_now().await;
2237            }
2238        }
2239    }
2240
2241    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2242        for ref_ in refs {
2243            assert!(
2244                test_setup
2245                    .stream_actor
2246                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2247                    .await
2248                    .unwrap()
2249                    .is_none()
2250            );
2251        }
2252    }
2253
2254    async fn fetch_result(
2255        caps: &impl cap::CanSend,
2256        stream_actor: ActorHandle<StreamActor>,
2257        seq: Seq,
2258        reference: Ref,
2259    ) {
2260        let ref_to_send = Python::with_gil(|py| {
2261            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2262        });
2263
2264        stream_actor
2265            .send_value(
2266                caps,
2267                seq,
2268                stream_actor.actor_id().clone(),
2269                Vec::new(),
2270                None,
2271                vec![WireValue::PyObject(ref_to_send)],
2272                HashMap::new(),
2273                HashMap::new(),
2274                None,
2275            )
2276            .await
2277            .unwrap()
2278    }
2279
2280    async fn check_fetch_result_error(
2281        caps: &impl cap::CanSend,
2282        stream_actor: ActorHandle<StreamActor>,
2283        seq: Seq,
2284        reference: Ref,
2285        controller_rx: &mut PortReceiver<ControllerMessage>,
2286        expected_backtrace: &str,
2287    ) {
2288        fetch_result(caps, stream_actor, seq, reference).await;
2289
2290        let controller_msg = controller_rx.recv().await.unwrap();
2291        match controller_msg {
2292            ControllerMessage::FetchResult {
2293                seq: actual_seq,
2294                value: Err(err),
2295            } => {
2296                assert_eq!(actual_seq, seq);
2297                assert!(
2298                    err.backtrace.contains(expected_backtrace),
2299                    "backtrace did not contain {:?}: {:?}",
2300                    expected_backtrace,
2301                    err.backtrace
2302                );
2303            }
2304            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2305        };
2306    }
2307
2308    async fn check_fetch_result_value(
2309        caps: &impl cap::CanSend,
2310        stream_actor: ActorHandle<StreamActor>,
2311        seq: Seq,
2312        reference: Ref,
2313        controller_rx: &mut PortReceiver<ControllerMessage>,
2314    ) {
2315        fetch_result(caps, stream_actor, seq, reference).await;
2316
2317        let controller_msg = controller_rx.recv().await.unwrap();
2318        match controller_msg {
2319            ControllerMessage::FetchResult {
2320                value: Ok(_),
2321                seq: actual_seq,
2322            } => assert_eq!(seq, actual_seq),
2323            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2324        };
2325    }
2326
2327    #[async_timed_test(timeout_secs = 60)]
2328    async fn test_define_recording_other_recording_active() -> Result<()> {
2329        let test_setup = TestSetup::new().await?;
2330        test_setup
2331            .stream_actor
2332            .define_recording(&test_setup.client, 0.into())
2333            .await?;
2334        test_setup
2335            .stream_actor
2336            .define_recording(&test_setup.client, 1.into())
2337            .await?;
2338        assert_actor_failed_with_msg(
2339            &test_setup.proc,
2340            test_setup.stream_actor.actor_id(),
2341            "different recording already active".into(),
2342        )
2343        .await;
2344        Ok(())
2345    }
2346
2347    #[async_timed_test(timeout_secs = 60)]
2348    async fn test_define_recording_already_defined() -> Result<()> {
2349        let test_setup = TestSetup::new().await?;
2350        test_setup
2351            .stream_actor
2352            .define_recording(&test_setup.client, 0.into())
2353            .await?;
2354        test_setup
2355            .stream_actor
2356            .finalize_recording(&test_setup.client, 0.into())
2357            .await?;
2358        test_setup
2359            .stream_actor
2360            .define_recording(&test_setup.client, 0.into())
2361            .await?;
2362        assert_actor_failed_with_msg(
2363            &test_setup.proc,
2364            test_setup.stream_actor.actor_id(),
2365            "already defined".into(),
2366        )
2367        .await;
2368        Ok(())
2369    }
2370
2371    #[async_timed_test(timeout_secs = 60)]
2372    async fn test_finalize_recording_other_recording_active() -> Result<()> {
2373        let test_setup = TestSetup::new().await?;
2374        test_setup
2375            .stream_actor
2376            .define_recording(&test_setup.client, 0.into())
2377            .await?;
2378        test_setup
2379            .stream_actor
2380            .finalize_recording(&test_setup.client, 1.into())
2381            .await?;
2382        assert_actor_failed_with_msg(
2383            &test_setup.proc,
2384            test_setup.stream_actor.actor_id(),
2385            "cannot finalize recording that isn't active".into(),
2386        )
2387        .await;
2388        Ok(())
2389    }
2390
2391    #[async_timed_test(timeout_secs = 60)]
2392    async fn test_recording_formal_outside_recording() -> Result<()> {
2393        let test_setup = TestSetup::new().await?;
2394        test_setup
2395            .stream_actor
2396            .recording_formal(&test_setup.client, 0.into(), 0)
2397            .await?;
2398        assert_actor_failed_with_msg(
2399            &test_setup.proc,
2400            test_setup.stream_actor.actor_id(),
2401            "recording_formal called outside of recording".into(),
2402        )
2403        .await;
2404        Ok(())
2405    }
2406
2407    #[async_timed_test(timeout_secs = 60)]
2408    async fn test_recording_result_outside_recording() -> Result<()> {
2409        let test_setup = TestSetup::new().await?;
2410        test_setup
2411            .stream_actor
2412            .recording_result(&test_setup.client, 0.into(), 0)
2413            .await?;
2414        assert_actor_failed_with_msg(
2415            &test_setup.proc,
2416            test_setup.stream_actor.actor_id(),
2417            "recording_result called outside of recording".into(),
2418        )
2419        .await;
2420        Ok(())
2421    }
2422
2423    #[async_timed_test(timeout_secs = 60)]
2424    async fn test_call_recording_other_recording_active() -> Result<()> {
2425        let test_setup = TestSetup::new().await?;
2426        test_setup
2427            .stream_actor
2428            .define_recording(&test_setup.client, 0.into())
2429            .await?;
2430        test_setup
2431            .stream_actor
2432            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2433            .await?;
2434        assert_actor_failed_with_msg(
2435            &test_setup.proc,
2436            test_setup.stream_actor.actor_id(),
2437            "cannot call recording while another recording is active".into(),
2438        )
2439        .await;
2440        Ok(())
2441    }
2442
2443    #[async_timed_test(timeout_secs = 60)]
2444    async fn test_call_recording_not_found() -> Result<()> {
2445        let test_setup = TestSetup::new().await?;
2446        test_setup
2447            .stream_actor
2448            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2449            .await?;
2450        assert_actor_failed_with_msg(
2451            &test_setup.proc,
2452            test_setup.stream_actor.actor_id(),
2453            "not found".into(),
2454        )
2455        .await;
2456        Ok(())
2457    }
2458
2459    #[async_timed_test(timeout_secs = 60)]
2460    async fn test_recording_formal_too_few_arguments() -> Result<()> {
2461        let test_setup = TestSetup::new().await?;
2462
2463        test_setup
2464            .stream_actor
2465            .define_recording(&test_setup.client, 0.into())
2466            .await?;
2467
2468        test_setup
2469            .stream_actor
2470            .recording_formal(&test_setup.client, 1.into(), 0)
2471            .await?;
2472
2473        test_setup
2474            .stream_actor
2475            .finalize_recording(&test_setup.client, 0.into())
2476            .await?;
2477
2478        test_setup
2479            .stream_actor
2480            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2481            .await?;
2482
2483        assert_actor_failed_with_msg(
2484            &test_setup.proc,
2485            test_setup.stream_actor.actor_id(),
2486            "recording_formal called with too few arguments".into(),
2487        )
2488        .await;
2489        Ok(())
2490    }
2491
2492    #[async_timed_test(timeout_secs = 60)]
2493    async fn test_recording_result_too_few_results() -> Result<()> {
2494        let test_setup = TestSetup::new().await?;
2495
2496        test_setup
2497            .stream_actor
2498            .define_recording(&test_setup.client, 0.into())
2499            .await?;
2500
2501        test_setup
2502            .stream_actor
2503            .recording_result(&test_setup.client, 1.into(), 0)
2504            .await?;
2505
2506        test_setup
2507            .stream_actor
2508            .finalize_recording(&test_setup.client, 0.into())
2509            .await?;
2510
2511        test_setup
2512            .stream_actor
2513            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2514            .await?;
2515
2516        assert_actor_failed_with_msg(
2517            &test_setup.proc,
2518            test_setup.stream_actor.actor_id(),
2519            "recording_result called with too few results".into(),
2520        )
2521        .await;
2522        Ok(())
2523    }
2524
2525    #[async_timed_test(timeout_secs = 60)]
2526    async fn test_basic_call_recording() -> Result<()> {
2527        let mut test_setup = TestSetup::new().await?;
2528
2529        // Define a recording equivalent to:
2530        // def f(x, y):
2531        //   return y, x
2532        test_setup
2533            .stream_actor
2534            .define_recording(&test_setup.client, 0.into())
2535            .await?;
2536
2537        let formal0_ref = 1.into();
2538        let formal0_index = 1;
2539        test_setup
2540            .stream_actor
2541            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2542            .await?;
2543
2544        let formal1_ref = 2.into();
2545        let formal1_index = 0;
2546        test_setup
2547            .stream_actor
2548            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2549            .await?;
2550
2551        let result0_ref = formal0_ref;
2552        let result0_index = 0;
2553        test_setup
2554            .stream_actor
2555            .recording_result(&test_setup.client, result0_ref, result0_index)
2556            .await?;
2557
2558        let result1_ref = formal1_ref;
2559        let result1_index = 1;
2560        test_setup
2561            .stream_actor
2562            .recording_result(&test_setup.client, result1_ref, result1_index)
2563            .await?;
2564
2565        test_setup
2566            .stream_actor
2567            .finalize_recording(&test_setup.client, 0.into())
2568            .await?;
2569
2570        let actual0_ref = 3.into();
2571        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2572
2573        let actual1_ref = 4.into();
2574        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2575
2576        // Call the recording with valid tensors for the actual inputs,
2577        // and store the results in refs 5 and 6.
2578        let actual_result0_ref = 5.into();
2579        let actual_result1_ref = 6.into();
2580        test_setup
2581            .stream_actor
2582            .call_recording(
2583                &test_setup.client,
2584                0.into(),
2585                0.into(),
2586                vec![actual_result0_ref, actual_result1_ref],
2587                vec![actual0_ref, actual1_ref],
2588            )
2589            .await?;
2590
2591        // Ensure the results are correct.
2592        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2593        assert!(
2594            test_setup
2595                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2596                .await
2597        );
2598
2599        // Ensure the temporary refs associated with the formals/results have
2600        // been deleted.
2601        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2602        Ok(())
2603    }
2604
2605    #[async_timed_test(timeout_secs = 60)]
2606    async fn test_request_status_in_recording() -> Result<()> {
2607        let test_setup = TestSetup::new().await?;
2608        test_setup
2609            .stream_actor
2610            .define_recording(&test_setup.client, 0.into())
2611            .await?;
2612        test_setup
2613            .stream_actor
2614            .request_status(&test_setup.client)
2615            .await
2616            .expect_err("request_status should have failed");
2617        assert_actor_failed_with_msg(
2618            &test_setup.proc,
2619            test_setup.stream_actor.actor_id(),
2620            "request_status not allowed in recording".into(),
2621        )
2622        .await;
2623        Ok(())
2624    }
2625
2626    #[async_timed_test(timeout_secs = 60)]
2627    async fn test_init_comm_in_recording() -> Result<()> {
2628        let test_setup = TestSetup::new().await?;
2629        test_setup
2630            .stream_actor
2631            .define_recording(&test_setup.client, 0.into())
2632            .await?;
2633
2634        let dummy_comm = test_setup
2635            .proc
2636            .spawn::<NcclCommActor>(
2637                "comm",
2638                CommParams::New {
2639                    device: CudaDevice::new(0.into()),
2640                    unique_id: UniqueId::new()?,
2641                    world_size: 1,
2642                    rank: 0,
2643                },
2644            )
2645            .await?;
2646
2647        test_setup
2648            .stream_actor
2649            .init_comm(&test_setup.client, dummy_comm)
2650            .await?;
2651        assert_actor_failed_with_msg(
2652            &test_setup.proc,
2653            test_setup.stream_actor.actor_id(),
2654            "init_comm not allowed in recording".into(),
2655        )
2656        .await;
2657        Ok(())
2658    }
2659
2660    #[async_timed_test(timeout_secs = 60)]
2661    async fn test_call_function_in_recording() -> Result<()> {
2662        let mut test_setup = TestSetup::new().await?;
2663
2664        // Define a recording equivalent to:
2665        // def f(x, y):
2666        //   w = x + y
2667        //   nonlocal z
2668        //   z.add_(1.0)
2669        //   return w + z
2670        test_setup
2671            .stream_actor
2672            .define_recording(&test_setup.client, 0.into())
2673            .await?;
2674
2675        let formal0_ref = test_setup.next_ref();
2676        let formal0_index = 0;
2677        test_setup
2678            .stream_actor
2679            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2680            .await?;
2681
2682        let formal1_ref = test_setup.next_ref();
2683        let formal1_index = 1;
2684        test_setup
2685            .stream_actor
2686            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2687            .await?;
2688
2689        let captured_ref = test_setup.next_ref();
2690        let result_captured_ref = test_setup.next_ref();
2691        let add_one_function =
2692            ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2693        let add_tensors_function =
2694            ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2695
2696        let add_result_ref_0 = test_setup.next_ref();
2697        test_setup
2698            .stream_actor
2699            .call_function(
2700                &test_setup.client,
2701                CallFunctionParams {
2702                    seq: 100.into(),
2703                    function: add_tensors_function.clone(),
2704                    args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2705                    kwargs: HashMap::new(),
2706                    results: vec![Some(add_result_ref_0)],
2707                    mutates: vec![],
2708                    stream: 0.into(),
2709                    remote_process_groups: Vec::new(),
2710                },
2711                HashMap::new(),
2712                HashMap::new(),
2713            )
2714            .await?;
2715
2716        test_setup
2717            .stream_actor
2718            .call_function(
2719                &test_setup.client,
2720                CallFunctionParams {
2721                    seq: 101.into(),
2722                    function: add_one_function,
2723                    args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2724                    kwargs: HashMap::new(),
2725                    results: vec![Some(result_captured_ref)],
2726                    mutates: vec![captured_ref],
2727                    stream: 0.into(),
2728                    remote_process_groups: Vec::new(),
2729                },
2730                HashMap::new(),
2731                HashMap::new(),
2732            )
2733            .await?;
2734
2735        let add_result_ref_1 = test_setup.next_ref();
2736        test_setup
2737            .stream_actor
2738            .call_function(
2739                &test_setup.client,
2740                CallFunctionParams {
2741                    seq: 102.into(),
2742                    function: add_tensors_function,
2743                    args: vec![
2744                        WireValue::Ref(add_result_ref_0),
2745                        WireValue::Ref(captured_ref),
2746                    ],
2747                    kwargs: HashMap::new(),
2748                    results: vec![Some(add_result_ref_1)],
2749                    mutates: vec![],
2750                    stream: 0.into(),
2751                    remote_process_groups: Vec::new(),
2752                },
2753                HashMap::new(),
2754                HashMap::new(),
2755            )
2756            .await?;
2757
2758        test_setup
2759            .stream_actor
2760            .recording_result(&test_setup.client, add_result_ref_1, 0)
2761            .await?;
2762
2763        test_setup
2764            .stream_actor
2765            .delete_refs(
2766                &test_setup.client,
2767                vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2768            )
2769            .await?;
2770
2771        test_setup
2772            .stream_actor
2773            .finalize_recording(&test_setup.client, 0.into())
2774            .await?;
2775
2776        let actual0_ref = test_setup.next_ref();
2777        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2778
2779        let actual1_ref = test_setup.next_ref();
2780        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2781
2782        test_setup
2783            .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2784            .await?;
2785
2786        let actual_result_ref = test_setup.next_ref();
2787        test_setup
2788            .stream_actor
2789            .call_recording(
2790                &test_setup.client,
2791                0.into(),
2792                0.into(),
2793                vec![actual_result_ref],
2794                vec![actual0_ref, actual1_ref],
2795            )
2796            .await?;
2797
2798        assert!(
2799            test_setup
2800                .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2801                .await
2802        );
2803
2804        // Set actual1_tensor to a bad shape which will cause the recording to fail.
2805        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2806
2807        let actual_result_ref = test_setup.next_ref();
2808        test_setup
2809            .stream_actor
2810            .call_recording(
2811                &test_setup.client,
2812                1.into(),
2813                0.into(),
2814                vec![actual_result_ref],
2815                vec![actual0_ref, actual1_ref],
2816            )
2817            .await?;
2818
2819        // Both inputs should still be valid.
2820        for ref_ in [actual0_ref, actual1_ref] {
2821            let _ = test_setup
2822                .stream_actor
2823                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2824                .await?
2825                .unwrap()
2826                .unwrap();
2827        }
2828
2829        for ref_ in [captured_ref, actual_result_ref] {
2830            let result_error = test_setup
2831                .stream_actor
2832                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2833                .await?
2834                .unwrap()
2835                .unwrap_err();
2836            // Check that the error contains the expected strings
2837            let error_str = result_error.to_string();
2838            assert!(
2839                error_str.contains("torch operator error"),
2840                "Error should contain 'torch operator failed': {}",
2841                error_str
2842            );
2843        }
2844
2845        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2846        match controller_msg {
2847            ControllerMessage::RemoteFunctionFailed { seq, error } => {
2848                assert_eq!(seq, 1.into());
2849                assert!(
2850                    error.backtrace.contains("torch operator error"),
2851                    "Unexpected WorkerError: {:?}",
2852                    error
2853                );
2854            }
2855            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2856        };
2857
2858        // Reset input tensor to a valid shape.
2859        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2860
2861        // captured_tensor should still have an error, so calling
2862        // the recording should set DependentErrors and not report
2863        // anything to the controller.
2864        let actual_result_ref = test_setup.next_ref();
2865        test_setup
2866            .stream_actor
2867            .call_recording(
2868                &test_setup.client,
2869                2.into(),
2870                0.into(),
2871                vec![actual_result_ref],
2872                vec![actual0_ref, actual1_ref],
2873            )
2874            .await?;
2875
2876        // Both inputs should still be valid.
2877        for ref_ in [actual0_ref, actual1_ref] {
2878            let _ = test_setup
2879                .stream_actor
2880                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2881                .await?
2882                .unwrap()
2883                .unwrap();
2884        }
2885
2886        for ref_ in [captured_ref, actual_result_ref] {
2887            let result_error = test_setup
2888                .stream_actor
2889                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2890                .await?
2891                .unwrap()
2892                .unwrap_err();
2893            // Check that the error contains the expected strings
2894            let error_str = result_error.to_string();
2895            assert!(
2896                error_str.contains("torch operator error"),
2897                "Error should contain input error: {}",
2898                error_str
2899            );
2900        }
2901
2902        // This tests that the DependentError was never reported to the controller.
2903        // If it were reported to the controller, the next message would match
2904        // RemoteFunctionFailed instead of FetchResult.
2905        check_fetch_result_error(
2906            &test_setup.client,
2907            test_setup.stream_actor.clone(),
2908            3.into(),
2909            captured_ref,
2910            &mut test_setup.controller_rx,
2911            "torch operator error",
2912        )
2913        .await;
2914
2915        Ok(())
2916    }
2917
2918    #[async_timed_test(timeout_secs = 60)]
2919    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2920        let mut test_setup = TestSetup::new().await?;
2921        test_setup
2922            .stream_actor
2923            .define_recording(&test_setup.client, 0.into())
2924            .await?;
2925
2926        let borrow_id = 1;
2927        let tensor_ref = test_setup.next_ref();
2928        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2929
2930        test_setup
2931            .stream_actor
2932            .borrow_create(
2933                &test_setup.client,
2934                borrow_id,
2935                tensor_ref,
2936                first_use_sender.clone(),
2937            )
2938            .await?;
2939
2940        test_setup
2941            .stream_actor
2942            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2943            .await?;
2944
2945        assert_actor_failed_with_msg(
2946            &test_setup.proc,
2947            test_setup.stream_actor.actor_id(),
2948            "duplicate borrow create in recording".into(),
2949        )
2950        .await;
2951
2952        Ok(())
2953    }
2954
2955    #[async_timed_test(timeout_secs = 60)]
2956    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2957        let test_setup = TestSetup::new().await?;
2958        test_setup
2959            .stream_actor
2960            .define_recording(&test_setup.client, 0.into())
2961            .await?;
2962
2963        let borrow_id = 1;
2964        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2965
2966        test_setup
2967            .stream_actor
2968            .borrow_drop(
2969                &test_setup.client,
2970                borrow_id,
2971                Arc::new(Mutex::new(last_use_receiver)),
2972            )
2973            .await?;
2974
2975        assert_actor_failed_with_msg(
2976            &test_setup.proc,
2977            test_setup.stream_actor.actor_id(),
2978            "borrow drop for borrow not defined in recording".into(),
2979        )
2980        .await;
2981
2982        Ok(())
2983    }
2984
2985    #[async_timed_test(timeout_secs = 60)]
2986    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2987        let mut test_setup = TestSetup::new().await?;
2988        test_setup
2989            .stream_actor
2990            .define_recording(&test_setup.client, 0.into())
2991            .await?;
2992
2993        let borrow_id = 1;
2994        let tensor_ref = test_setup.next_ref();
2995        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2996
2997        test_setup
2998            .stream_actor
2999            .borrow_create(
3000                &test_setup.client,
3001                borrow_id,
3002                tensor_ref,
3003                first_use_sender.clone(),
3004            )
3005            .await?;
3006
3007        // Attempt to finalize the recording without dropping the borrow
3008        test_setup
3009            .stream_actor
3010            .finalize_recording(&test_setup.client, 0.into())
3011            .await?;
3012
3013        assert_actor_failed_with_msg(
3014            &test_setup.proc,
3015            test_setup.stream_actor.actor_id(),
3016            "all borrows created within recording must be dropped within recording".into(),
3017        )
3018        .await;
3019
3020        Ok(())
3021    }
3022
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_in_recording() -> Result<()> {
        // Two recordings on two streams cooperate through a borrow. The lender
        // (test_setup.stream_actor, stream 0) lends its formal input. The
        // borrower (stream 1) adds the borrowed tensor to one of its own
        // tensors. The test checks the happy path, a failure local to the
        // borrower, and an error in the lender's input that must propagate as
        // a dependent error without being reported to the controller.
        let mut test_setup = TestSetup::new().await?;

        let borrower_stream = test_setup
            .proc
            .spawn::<StreamActor>(
                "stream1",
                StreamParams {
                    world_size: 1,
                    rank: 0,
                    creation_mode: StreamCreationMode::CreateNewStream,
                    id: 1.into(),
                    device: Some(CudaDevice::new(0.into())),
                    controller_actor: test_setup.controller_actor.clone(),
                    respond_with_python_message: false,
                },
            )
            .await?;

        let lender_stream = test_setup.stream_actor.clone();

        // The first-use port goes from the lender's borrow_create to the
        // borrower's borrow_first_use. The last-use port goes from the
        // borrower's borrow_last_use back to the lender's borrow_drop.
        let borrow_id = 1;
        let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
        let (last_use_sender, last_use_receiver) = test_setup.client.open_port();

        // Lender stream (stream 0): define a recording that takes one formal,
        // creates a borrow of it, and drops that borrow.
        lender_stream
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let formal_ref = test_setup.next_ref();
        lender_stream
            .recording_formal(&test_setup.client, formal_ref, 0)
            .await?;

        lender_stream
            .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
            .await?;

        lender_stream
            .borrow_drop(
                &test_setup.client,
                borrow_id,
                Arc::new(Mutex::new(last_use_receiver)),
            )
            .await?;

        lender_stream
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        // A tensor owned by the borrower stream, captured by its recording.
        let borrower_tensor_ref = test_setup.next_ref();
        let borrower_tensor = TensorCell::new(factory_float_tensor(
            &[1.0, 2.0, 3.0],
            "cuda".try_into().unwrap(),
        ));

        borrower_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                borrower_tensor_ref,
                Ok(borrower_tensor.clone()),
            )
            .await?;

        // Borrower stream (stream 1): define a recording that uses the borrow
        // from the lender stream and returns borrowed + borrower_tensor.
        borrower_stream
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrowed_ref = test_setup.next_ref();

        borrower_stream
            .borrow_first_use(
                &test_setup.client,
                borrow_id,
                borrowed_ref,
                Arc::new(Mutex::new(first_use_receiver)),
            )
            .await?;

        let result_ref = test_setup.next_ref();
        borrower_stream
            .call_function(
                &test_setup.client,
                CallFunctionParams {
                    seq: 100.into(),
                    function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
                    args: vec![
                        WireValue::Ref(borrowed_ref),
                        WireValue::Ref(borrower_tensor_ref),
                    ],
                    kwargs: HashMap::new(),
                    results: vec![Some(result_ref)],
                    mutates: vec![],
                    stream: 1.into(),
                    remote_process_groups: Vec::new(),
                },
                HashMap::new(),
                HashMap::new(),
            )
            .await?;

        borrower_stream
            .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
            .await?;

        borrower_stream
            .recording_result(&test_setup.client, result_ref, 0)
            .await?;

        borrower_stream
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        // Set up a tensor in the lender stream and call the recording.
        let input_tensor_ref = test_setup.next_ref();
        test_setup
            .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
            .await?;

        let result_tensor_ref = test_setup.next_ref();

        // The two recordings exchange the borrow through the first-use and
        // last-use ports, so both calls are driven concurrently.
        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            0.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            0.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        tokio::try_join!(lender_future, borrower_future)?;

        // [4, 5, 6] (borrowed) + [1, 2, 3] (borrower_tensor) = [5, 7, 9].
        let result_tensor = borrower_stream
            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
            .await?
            .unwrap()
            .unwrap();

        let expected_tensor = TensorCell::new(factory_float_tensor(
            &[5.0, 7.0, 9.0],
            "cpu".try_into().unwrap(),
        ));
        assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());

        // Set borrower_tensor to a tensor with only 2 elements to cause a failure.
        let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
            &[1.0, 2.0],
            "cuda".try_into().unwrap(),
        ));
        borrower_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                borrower_tensor_ref,
                Ok(invalid_borrower_tensor.clone()),
            )
            .await?;

        // Call the recording again.
        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            1.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            1.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        tokio::try_join!(lender_future, borrower_future)?;

        // Check that the borrower_stream reports the error to the controller.
        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::RemoteFunctionFailed { seq, error } => {
                assert_eq!(seq, 1.into());
                assert!(
                    error.backtrace.contains("recording failed"),
                    "Unexpected WorkerError: {:?}",
                    error
                );
                assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
            }
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };

        // Check that no error was reported from the lender stream
        check_fetch_result_value(
            &test_setup.client,
            lender_stream.clone(),
            2.into(),
            input_tensor_ref,
            &mut test_setup.controller_rx,
        )
        .await;

        // Set the recording's input tensor to an error.
        let input_error = fake_seq_error(anyhow!("input error"));
        lender_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                input_tensor_ref,
                Err(input_error.clone()),
            )
            .await?;

        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            3.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            3.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        tokio::try_join!(lender_future, borrower_future)?;

        // Verify that borrower_stream sets a CallFunctionError::DependentError on result_tensor_ref.
        let result_error = borrower_stream
            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
            .await?
            .unwrap()
            .unwrap_err();

        // Check that the error contains the expected strings
        let error_str = result_error.to_string();
        assert!(
            error_str.contains("input error"),
            "Error should contain input error: {}",
            error_str
        );

        // The dependent error must also embed the full text of the original
        // input error, not just the "input error" substring.
        let input_error_str = input_error.to_string();
        assert!(
            error_str.contains(&input_error_str),
            "Error should contain the original error: {}",
            error_str
        );

        // Neither stream should have reported the input error to the controller.
        // Check the lender first: its input ref should still hold the error,
        // and check_fetch_result_error expects the fetch result to be the next
        // controller message.
        check_fetch_result_error(
            &test_setup.client,
            lender_stream,
            4.into(),
            input_tensor_ref,
            &mut test_setup.controller_rx,
            "input error",
        )
        .await;

        // Likewise for the borrower stream's dependent result.
        check_fetch_result_error(
            &test_setup.client,
            borrower_stream,
            5.into(),
            result_tensor_ref,
            &mut test_setup.controller_rx,
            "input error",
        )
        .await;

        Ok(())
    }
3310
3311    #[async_timed_test(timeout_secs = 60)]
3312    async fn test_reduce_in_recording() -> Result<()> {
3313        let mut test_setup = TestSetup::new().await?;
3314        let recording_ref = test_setup.next_ref();
3315
3316        let comm = Arc::new(
3317            test_setup
3318                .proc
3319                .spawn::<NcclCommActor>(
3320                    "comm",
3321                    CommParams::New {
3322                        device: CudaDevice::new(0.into()),
3323                        unique_id: UniqueId::new()?,
3324                        world_size: 1,
3325                        rank: 0,
3326                    },
3327                )
3328                .await?,
3329        );
3330
3331        let factory = Factory {
3332            size: vec![3],
3333            dtype: torch_sys::ScalarType::Float,
3334            layout: torch_sys::Layout::Strided,
3335            device: "cuda".try_into().unwrap(),
3336        };
3337
3338        let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3339
3340        test_setup
3341            .stream_actor
3342            .define_recording(&test_setup.client, recording_ref)
3343            .await?;
3344
3345        let formal_tensor_ref_0 = test_setup.next_ref();
3346        let formal_tensor_ref_1 = test_setup.next_ref();
3347        let formal_tensor_ref_2 = test_setup.next_ref();
3348
3349        test_setup
3350            .stream_actor
3351            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3352            .await?;
3353        test_setup
3354            .stream_actor
3355            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3356            .await?;
3357        test_setup
3358            .stream_actor
3359            .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3360            .await?;
3361
3362        let intermediate_tensor_ref_0 = test_setup.next_ref();
3363
3364        // Handle case with in_place = true.
3365        test_setup
3366            .stream_actor
3367            .reduce(
3368                &test_setup.client,
3369                comm.clone(),
3370                1,
3371                intermediate_tensor_ref_0,
3372                formal_tensor_ref_0,
3373                factory.clone(),
3374                reduction.clone(),
3375                false,
3376                true,
3377                None,
3378            )
3379            .await?;
3380
3381        // Handle case with in_place = false and out = None.
3382        let intermediate_tensor_ref_1 = test_setup.next_ref();
3383        test_setup
3384            .stream_actor
3385            .reduce(
3386                &test_setup.client,
3387                comm.clone(),
3388                1,
3389                intermediate_tensor_ref_1,
3390                formal_tensor_ref_1,
3391                factory.clone(),
3392                reduction.clone(),
3393                false,
3394                false,
3395                None,
3396            )
3397            .await?;
3398
3399        let intermediate_tensor_ref_2 = test_setup.next_ref();
3400
3401        // Third reduce call with out = formal_tensor_ref_2
3402        test_setup
3403            .stream_actor
3404            .reduce(
3405                &test_setup.client,
3406                comm.clone(),
3407                1,
3408                intermediate_tensor_ref_2,
3409                intermediate_tensor_ref_1,
3410                factory.clone(),
3411                reduction.clone(),
3412                false,
3413                false,
3414                Some(formal_tensor_ref_2),
3415            )
3416            .await?;
3417
3418        test_setup
3419            .stream_actor
3420            .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3421            .await?;
3422
3423        test_setup
3424            .stream_actor
3425            .finalize_recording(&test_setup.client, recording_ref)
3426            .await?;
3427
3428        let input_tensor_ref_0 = test_setup.next_ref();
3429        let input_tensor_ref_1 = test_setup.next_ref();
3430        let input_tensor_ref_2 = test_setup.next_ref();
3431
3432        test_setup
3433            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3434            .await?;
3435
3436        test_setup
3437            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3438            .await?;
3439
3440        test_setup
3441            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3442            .await?;
3443
3444        let output_ref = test_setup.next_ref();
3445
3446        test_setup
3447            .stream_actor
3448            .call_recording(
3449                &test_setup.client,
3450                0.into(),
3451                recording_ref,
3452                vec![output_ref],
3453                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3454            )
3455            .await?;
3456
3457        // Validate that input_tensor_ref_0 is unchanged.
3458        assert!(
3459            test_setup
3460                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3461                .await
3462        );
3463        // All the other inputs/outputs should be equal to input 1
3464        for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3465            assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3466        }
3467
3468        // Set an error on input 0
3469        let input_error = fake_seq_error(anyhow!("input error"));
3470        test_setup
3471            .stream_actor
3472            .set_tensor_ref_unit_tests_only(
3473                &test_setup.client,
3474                input_tensor_ref_0,
3475                Err(input_error.clone()),
3476            )
3477            .await?;
3478
3479        test_setup
3480            .stream_actor
3481            .call_recording(
3482                &test_setup.client,
3483                1.into(),
3484                recording_ref,
3485                vec![output_ref],
3486                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3487            )
3488            .await?;
3489
3490        // Verify that input_tensor_ref_0, input_tensor_ref_2, and output_ref have a dependent error.
3491        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3492            test_setup
3493                .validate_dependent_error(ref_, input_error.clone())
3494                .await;
3495        }
3496
3497        // Verify that input_tensor_ref_1 is untouched.
3498        assert!(
3499            test_setup
3500                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3501                .await
3502        );
3503
3504        // Verify that no failure was reported to the controller.
3505        check_fetch_result_value(
3506            &test_setup.client,
3507            test_setup.stream_actor.clone(),
3508            2.into(),
3509            input_tensor_ref_1,
3510            &mut test_setup.controller_rx,
3511        )
3512        .await;
3513
3514        // Reset input tensors 0 and 2 to their original values
3515        test_setup
3516            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3517            .await?;
3518        test_setup
3519            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3520            .await?;
3521
3522        // Set an error on input tensor 1
3523        test_setup
3524            .stream_actor
3525            .set_tensor_ref_unit_tests_only(
3526                &test_setup.client,
3527                input_tensor_ref_1,
3528                Err(input_error.clone()),
3529            )
3530            .await?;
3531
3532        test_setup
3533            .stream_actor
3534            .call_recording(
3535                &test_setup.client,
3536                3.into(),
3537                recording_ref,
3538                vec![output_ref],
3539                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3540            )
3541            .await?;
3542
3543        // Validate that the mutated inputs and the output have a dependent error containing
3544        // the input error
3545        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3546            test_setup
3547                .validate_dependent_error(ref_, input_error.clone())
3548                .await;
3549        }
3550
3551        // Validate that no error was reported to the controller
3552        check_fetch_result_error(
3553            &test_setup.client,
3554            test_setup.stream_actor.clone(),
3555            4.into(),
3556            input_tensor_ref_1,
3557            &mut test_setup.controller_rx,
3558            "input error",
3559        )
3560        .await;
3561
3562        // Reset input tensors 0 and 1 to their original values
3563        test_setup
3564            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3565            .await?;
3566        test_setup
3567            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3568            .await?;
3569
3570        // Set an error on input tensor 2
3571        test_setup
3572            .stream_actor
3573            .set_tensor_ref_unit_tests_only(
3574                &test_setup.client,
3575                input_tensor_ref_2,
3576                Err(input_error.clone()),
3577            )
3578            .await?;
3579
3580        test_setup
3581            .stream_actor
3582            .call_recording(
3583                &test_setup.client,
3584                5.into(),
3585                recording_ref,
3586                vec![output_ref],
3587                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3588            )
3589            .await?;
3590
3591        // Validate that input tensor 1 has its original values
3592        assert!(
3593            test_setup
3594                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3595                .await
3596        );
3597
3598        // Validate that the mutated inputs and the output have a dependent error containing
3599        // the input error
3600        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3601            test_setup
3602                .validate_dependent_error(ref_, input_error.clone())
3603                .await;
3604        }
3605
3606        // Validate that no error was reported to the controller
3607        check_fetch_result_value(
3608            &test_setup.client,
3609            test_setup.stream_actor.clone(),
3610            6.into(),
3611            input_tensor_ref_1,
3612            &mut test_setup.controller_rx,
3613        )
3614        .await;
3615
3616        Ok(())
3617    }
3618
3619    #[async_timed_test(timeout_secs = 60)]
3620    async fn test_send_tensor_in_recording() -> Result<()> {
3621        let mut test_setup = TestSetup::new_with_world_size(2).await?;
3622        let recording_ref = test_setup.next_ref();
3623
3624        let unique_id = UniqueId::new()?;
3625        let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3626            "comm0",
3627            CommParams::New {
3628                device: CudaDevice::new(0.into()),
3629                unique_id: unique_id.clone(),
3630                world_size: 2,
3631                rank: 0,
3632            },
3633        );
3634        let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3635            "comm1",
3636            CommParams::New {
3637                device: CudaDevice::new(1.into()),
3638                unique_id,
3639                world_size: 2,
3640                rank: 1,
3641            },
3642        );
3643        let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3644        let comm0 = Arc::new(comm0);
3645        let comm1 = Arc::new(comm1);
3646
3647        let factory = Factory {
3648            size: vec![3],
3649            dtype: torch_sys::ScalarType::Float,
3650            layout: torch_sys::Layout::Strided,
3651            device: "cuda".try_into().unwrap(),
3652        };
3653
3654        let send_stream = test_setup.stream_actor.clone();
3655        let recv_stream = test_setup
3656            .proc
3657            .spawn::<StreamActor>(
3658                "recv_stream",
3659                StreamParams {
3660                    world_size: 2,
3661                    rank: 1,
3662                    creation_mode: StreamCreationMode::CreateNewStream,
3663                    id: 1.into(),
3664                    device: Some(CudaDevice::new(1.into())),
3665                    controller_actor: test_setup.controller_actor.clone(),
3666                    respond_with_python_message: false,
3667                },
3668            )
3669            .await?;
3670
3671        send_stream
3672            .define_recording(&test_setup.client, recording_ref)
3673            .await?;
3674        recv_stream
3675            .define_recording(&test_setup.client, recording_ref)
3676            .await?;
3677
3678        let formal_tensor_ref_0 = test_setup.next_ref();
3679        let formal_tensor_ref_1 = test_setup.next_ref();
3680
3681        send_stream
3682            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3683            .await?;
3684        send_stream
3685            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3686            .await?;
3687
3688        let _ref = test_setup.next_ref();
3689        send_stream
3690            .send_tensor(
3691                &test_setup.client,
3692                _ref,
3693                None,
3694                Some(1),
3695                formal_tensor_ref_0,
3696                factory.clone(),
3697                comm0.clone(),
3698            )
3699            .await?;
3700
3701        let result_ref_0 = test_setup.next_ref();
3702        let _ref = test_setup.next_ref();
3703        recv_stream
3704            .send_tensor(
3705                &test_setup.client,
3706                result_ref_0,
3707                Some(0),
3708                None,
3709                _ref,
3710                factory.clone(),
3711                comm1,
3712            )
3713            .await?;
3714
3715        let result_ref_1 = test_setup.next_ref();
3716        send_stream
3717            .send_tensor(
3718                &test_setup.client,
3719                result_ref_1,
3720                Some(0),
3721                Some(0),
3722                formal_tensor_ref_1,
3723                factory.clone(),
3724                comm0,
3725            )
3726            .await?;
3727
3728        send_stream
3729            .recording_result(&test_setup.client, result_ref_1, 0)
3730            .await?;
3731        recv_stream
3732            .recording_result(&test_setup.client, result_ref_0, 0)
3733            .await?;
3734
3735        send_stream
3736            .finalize_recording(&test_setup.client, recording_ref)
3737            .await?;
3738        recv_stream
3739            .finalize_recording(&test_setup.client, recording_ref)
3740            .await?;
3741
3742        let input_tensor_ref_0 = test_setup.next_ref();
3743        let input_tensor_ref_1 = test_setup.next_ref();
3744        test_setup
3745            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3746            .await?;
3747        test_setup
3748            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3749            .await?;
3750
3751        let actual_result_ref_0 = test_setup.next_ref();
3752        let actual_result_ref_1 = test_setup.next_ref();
3753        let send_fut = send_stream.call_recording(
3754            &test_setup.client,
3755            0.into(),
3756            recording_ref,
3757            vec![actual_result_ref_1],
3758            vec![input_tensor_ref_0, input_tensor_ref_1],
3759        );
3760        let recv_fut = recv_stream.call_recording(
3761            &test_setup.client,
3762            0.into(),
3763            recording_ref,
3764            vec![actual_result_ref_0],
3765            vec![],
3766        );
3767        tokio::try_join!(send_fut, recv_fut)?;
3768
3769        assert!(
3770            test_setup
3771                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3772                .await
3773        );
3774        assert!(
3775            test_setup
3776                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3777                .await
3778        );
3779        assert!(
3780            test_setup
3781                .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3782                .await
3783        );
3784
3785        let actual_result_0 = recv_stream
3786            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3787            .await
3788            .unwrap()
3789            .unwrap()
3790            .unwrap();
3791        assert!(allclose(
3792            &actual_result_0.borrow(),
3793            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3794        )?);
3795
3796        // Validate that failure wasn't reported to controller.
3797        check_fetch_result_value(
3798            &test_setup.client,
3799            send_stream.clone(),
3800            1.into(),
3801            actual_result_ref_1,
3802            &mut test_setup.controller_rx,
3803        )
3804        .await;
3805        check_fetch_result_value(
3806            &test_setup.client,
3807            recv_stream.clone(),
3808            2.into(),
3809            actual_result_ref_0,
3810            &mut test_setup.controller_rx,
3811        )
3812        .await;
3813
3814        let input_error = fake_seq_error(anyhow!("input error"));
3815        send_stream
3816            .set_tensor_ref_unit_tests_only(
3817                &test_setup.client,
3818                input_tensor_ref_0,
3819                Err(input_error.clone()),
3820            )
3821            .await?;
3822
3823        let send_fut = send_stream.call_recording(
3824            &test_setup.client,
3825            3.into(),
3826            recording_ref,
3827            vec![actual_result_ref_1],
3828            vec![input_tensor_ref_0, input_tensor_ref_1],
3829        );
3830        let recv_fut = recv_stream.call_recording(
3831            &test_setup.client,
3832            3.into(),
3833            recording_ref,
3834            vec![actual_result_ref_0],
3835            vec![],
3836        );
3837        tokio::try_join!(send_fut, recv_fut)?;
3838
3839        // The result on recv_stream should have a value, but it will be garbage.
3840        let _ = recv_stream
3841            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3842            .await
3843            .unwrap()
3844            .unwrap()
3845            .unwrap();
3846
3847        test_setup
3848            .validate_dependent_error(actual_result_ref_1, input_error.clone())
3849            .await;
3850
3851        // Input 1 should be untouched.
3852        assert!(
3853            test_setup
3854                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3855                .await
3856        );
3857
3858        // Validate that failure wasn't reported to controller.
3859        check_fetch_result_error(
3860            &test_setup.client,
3861            send_stream.clone(),
3862            4.into(),
3863            actual_result_ref_1,
3864            &mut test_setup.controller_rx,
3865            "input error",
3866        )
3867        .await;
3868        check_fetch_result_value(
3869            &test_setup.client,
3870            recv_stream.clone(),
3871            5.into(),
3872            actual_result_ref_0,
3873            &mut test_setup.controller_rx,
3874        )
3875        .await;
3876
3877        test_setup
3878            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3879            .await?;
3880        send_stream
3881            .set_tensor_ref_unit_tests_only(
3882                &test_setup.client,
3883                input_tensor_ref_1,
3884                Err(input_error.clone()),
3885            )
3886            .await?;
3887
3888        let send_fut = send_stream.call_recording(
3889            &test_setup.client,
3890            6.into(),
3891            recording_ref,
3892            vec![actual_result_ref_1],
3893            vec![input_tensor_ref_0, input_tensor_ref_1],
3894        );
3895        let recv_fut = recv_stream.call_recording(
3896            &test_setup.client,
3897            6.into(),
3898            recording_ref,
3899            vec![actual_result_ref_0],
3900            vec![],
3901        );
3902        tokio::try_join!(send_fut, recv_fut)?;
3903
3904        let actual_result_0 = recv_stream
3905            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3906            .await
3907            .unwrap()
3908            .unwrap()
3909            .unwrap();
3910        assert!(allclose(
3911            &actual_result_0.borrow(),
3912            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3913        )?);
3914
3915        assert!(
3916            test_setup
3917                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3918                .await
3919        );
3920
3921        test_setup
3922            .validate_dependent_error(actual_result_ref_1, input_error)
3923            .await;
3924
3925        // Validate that failure wasn't reported to controller.
3926        check_fetch_result_error(
3927            &test_setup.client,
3928            send_stream.clone(),
3929            7.into(),
3930            actual_result_ref_1,
3931            &mut test_setup.controller_rx,
3932            "input error",
3933        )
3934        .await;
3935        check_fetch_result_value(
3936            &test_setup.client,
3937            recv_stream.clone(),
3938            8.into(),
3939            actual_result_ref_0,
3940            &mut test_setup.controller_rx,
3941        )
3942        .await;
3943
3944        Ok(())
3945    }
3946
3947    #[async_timed_test(timeout_secs = 60)]
3948    async fn test_set_value_in_recording_valid_pipe() -> Result<()> {
3949        let mut test_setup = TestSetup::new().await?;
3950
3951        let (pipe_tx, mut pipe_rx) = test_setup.client.open_port();
3952
3953        let recording_ref = test_setup.next_ref();
3954        test_setup
3955            .stream_actor
3956            .define_recording(&test_setup.client, recording_ref)
3957            .await?;
3958
3959        let result_ref_0 = test_setup.next_ref();
3960
3961        test_setup
3962            .stream_actor
3963            .set_value(
3964                &test_setup.client,
3965                0.into(),
3966                vec![Some(result_ref_0)],
3967                pipe_tx,
3968            )
3969            .await?;
3970
3971        test_setup
3972            .stream_actor
3973            .recording_result(&test_setup.client, result_ref_0, 0)
3974            .await?;
3975
3976        test_setup
3977            .stream_actor
3978            .finalize_recording(&test_setup.client, recording_ref)
3979            .await?;
3980
3981        let real_result_ref = test_setup.next_ref();
3982        let recording_fut = test_setup.stream_actor.call_recording(
3983            &test_setup.client,
3984            0.into(),
3985            recording_ref,
3986            vec![real_result_ref],
3987            vec![],
3988        );
3989
3990        let pipe_fut = async {
3991            let msg = pipe_rx.recv().await.unwrap();
3992            match msg {
3993                PipeMessage::RecvValue(tx) => {
3994                    tx.send(PyTree::from(RValue::Tensor(TensorCell::new(
3995                        factory_float_tensor(&[1.0, 2.0, 3.0], "cuda".try_into().unwrap()),
3996                    ))))
3997                    .unwrap();
3998                }
3999                _ => panic!("Unexpected message"),
4000            }
4001            Ok(())
4002        };
4003
4004        tokio::try_join!(recording_fut, pipe_fut)?;
4005
4006        assert!(test_setup.allclose(real_result_ref, &[1.0, 2.0, 3.0]).await);
4007
4008        // This will cause the next call to set_value to fail.
4009        drop(pipe_rx);
4010
4011        let real_result_ref = test_setup.next_ref();
4012        test_setup
4013            .stream_actor
4014            .call_recording(
4015                &test_setup.client,
4016                1.into(),
4017                recording_ref,
4018                vec![real_result_ref],
4019                vec![],
4020            )
4021            .await?;
4022
4023        let real_result_err = test_setup
4024            .stream_actor
4025            .get_tensor_ref_unit_tests_only(&test_setup.client, real_result_ref)
4026            .await?
4027            .unwrap()
4028            .unwrap_err();
4029        // Check that the error contains the expected string
4030        let error_str = real_result_err.to_string();
4031        assert!(
4032            error_str.contains("send error"),
4033            "Error should contain 'send error': {}",
4034            error_str
4035        );
4036
4037        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
4038        match controller_msg {
4039            ControllerMessage::RemoteFunctionFailed { seq, error } => {
4040                assert_eq!(seq, 1.into());
4041                assert!(
4042                    error.backtrace.contains("send error"),
4043                    "Unexpected WorkerError: {:?}",
4044                    error
4045                );
4046            }
4047            _ => panic!("Unexpected controller message: {:?}", controller_msg),
4048        };
4049
4050        Ok(())
4051    }
4052}