monarch_tensor_worker/
stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::data::Serialized;
34use hyperactor::forward;
35use hyperactor::mailbox::Mailbox;
36use hyperactor::mailbox::OncePortHandle;
37use hyperactor::mailbox::PortReceiver;
38use hyperactor::proc::Proc;
39use monarch_hyperactor::actor::PythonMessage;
40use monarch_hyperactor::actor::PythonMessageKind;
41use monarch_hyperactor::buffers::FrozenBuffer;
42use monarch_hyperactor::local_state_broker::BrokerId;
43use monarch_hyperactor::local_state_broker::LocalState;
44use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
45use monarch_messages::controller::ControllerMessageClient;
46use monarch_messages::controller::Seq;
47use monarch_messages::controller::WorkerError;
48use monarch_messages::worker::ActorCallParams;
49use monarch_messages::worker::ActorMethodParams;
50use monarch_messages::worker::CallFunctionError;
51use monarch_messages::worker::CallFunctionParams;
52use monarch_messages::worker::SeqError;
53use monarch_messages::worker::StreamRef;
54use monarch_types::PyTree;
55use monarch_types::SerializablePyErr;
56use monarch_types::TryIntoPyObjectUnsafe;
57use pyo3::prelude::*;
58use pyo3::types::PyTuple;
59use tokio::runtime::Handle;
60use tokio::sync::Mutex;
61use tokio::task::JoinHandle;
62use torch_sys::BorrowType;
63use torch_sys::CudaDevice;
64use torch_sys::MultiBorrow;
65use torch_sys::RValue;
66use torch_sys::TensorCell;
67use torch_sys::deep_clone;
68use torch_sys::factory_empty;
69use torch_sys::factory_zeros;
70use torch_sys_cuda::cuda::Event;
71use torch_sys_cuda::cuda::Stream;
72use tracing_subscriber::fmt::Subscriber;
73
74use crate::ControllerActor;
75use crate::DeviceMesh;
76use crate::Factory;
77use crate::Reduction;
78use crate::Ref;
79use crate::ResolvableFunction;
80use crate::StreamCreationMode;
81use crate::WireValue;
82use crate::comm::CommBackend;
83use crate::comm::CommMessage;
84use crate::comm::CommMessageClient;
85use crate::comm::NcclCommActor;
86use crate::pipe::PipeMessage;
87
/// Either a [`TensorCell`] or the shared ([`Arc`]) [`SeqError`] that stands in
/// for it. This is the payload type carried, together with an optional CUDA
/// event, over the borrow ports declared on [`StreamMessage`].
pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
89
// These thread locals are accessed by the python runtime for debugging sessions.
// Each cell can be set at most once per thread; `StreamActor::init` populates
// them on the thread that runs the stream actor.
thread_local! {
    // Ref to the controller actor that created the stream running on this thread.
    pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
    // The proc that the stream actor on this thread belongs to.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    // Root actor id built from the stream actor's proc id and actor name.
    pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
}
96
97fn pickle_python_result(
98    py: Python<'_>,
99    result: Bound<'_, PyAny>,
100    worker_actor_id: ActorId,
101) -> Result<PythonMessage, anyhow::Error> {
102    let pickle = py
103        .import("monarch._src.actor.actor_mesh")
104        .unwrap()
105        .getattr("_pickle")
106        .unwrap();
107    let data: FrozenBuffer = pickle
108        .call1((result,))
109        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
110        .extract()
111        .unwrap();
112    Ok(PythonMessage::new_from_buf(
113        PythonMessageKind::Result {
114            rank: Some(worker_actor_id.rank()),
115        },
116        data.inner,
117    ))
118}
119
120#[derive(Debug)]
121struct Recording {
122    messages: Vec<StreamMessage>,
123}
124
125impl Recording {
126    fn new() -> Self {
127        Self {
128            messages: Vec::new(),
129        }
130    }
131}
132
/// State of the recording that is currently active on a stream. Held in
/// `StreamActor::active_recording`, which is `None` when no recording is
/// active.
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// A recording is currently being defined.
    Defining {
        /// Ref of the recording being defined.
        recording: Ref,
        // Set of borrow ids used to track proper borrow usage inside
        // a recording.
        defined_borrows: HashSet<u64>,
    },
    /// NOTE(review): presumably a previously defined recording is being
    /// replayed — confirm against the `CallRecording` handler.
    Running,
}
143
/// Messages handled by the stream. Generally these are stream-local versions of
/// [`crate::WorkerMessage`].
#[derive(Handler, HandleClient, Debug, Named)]
#[named(register = false)]
pub enum StreamMessage {
    /// Call a function on this stream.
    ///
    /// Fields, in order: the call parameters; device meshes keyed by ref; and
    /// remote process groups keyed by ref, each given as the mesh, a list of
    /// mesh dimensions (passed to `get_ranks_for_dim_slice`), and the NCCL
    /// comm actor that backs the group.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Create a borrow of `tensor`; the first-use event and the tensor are
    /// sent to the borrower through `first_use_sender`.
    BorrowCreate {
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Port for sending the first use CUDA event + borrowed tensor to
        /// the borrower.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// First use of a borrow on the borrowing side; the borrowed tensor is
    /// stored under `result` (it is one of the refs reported by
    /// `get_defined_refs`).
    BorrowFirstUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for storing the borrowed tensor.
        result: Ref,
        /// Port for receiving the first use CUDA event + borrowed tensor from
        /// the provider stream.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Last use of a borrow on the borrowing side; the last-use event and the
    /// tensor are sent back through `last_use_sender`.
    BorrowLastUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for the borrowed tensor.
        result: Ref,
        /// Port for sending the last use CUDA event and borrowed tensor.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// Drop a borrow once its last-use event and tensor have been received.
    BorrowDrop {
        /// Id for the borrow.
        borrow: u64,
        /// Port for receiving the last use CUDA event and borrowed tensor.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Refs to delete.
    /// NOTE(review): presumably removes them from the stream's environment —
    /// confirm against the handler.
    DeleteRefs(Vec<Ref>),

    /// Status request; the unit reply is delivered on the given once-port.
    /// NOTE(review): presumably answered once previously queued messages have
    /// been processed — confirm against the handler.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Hand the stream its NCCL communicator actor.
    /// NOTE(review): presumably stored in `StreamActor::comm` — confirm.
    InitComm(ActorHandle<NcclCommActor>),

    /// Reduce `local_tensor` across the given communicator.
    ///
    /// Per `get_defined_refs` / `get_mutated_refs`: `result` is defined by
    /// this message; `local_tensor` is mutated when `in_place` is set,
    /// otherwise `out` is mutated when it is provided.
    Reduce {
        /// Communicator actor used for the reduction.
        comm: Arc<ActorHandle<NcclCommActor>>,
        /// NOTE(review): presumably the size of the mesh dimension being
        /// reduced over — confirm against the handler.
        dim_size: i64,
        /// Ref defined by the reduction.
        result: Ref,
        /// Ref of the local input tensor.
        local_tensor: Ref,
        /// NOTE(review): presumably describes how to allocate the output
        /// tensor — confirm against the handler.
        factory: Factory,
        /// Reduction to apply.
        reduction: Reduction,
        /// NOTE(review): presumably selects a reduce-scatter — confirm.
        scatter: bool,
        /// Whether `local_tensor` is modified in place.
        in_place: bool,
        /// Optional ref that is mutated when `in_place` is not set.
        out: Option<Ref>,
    },

    /// Send and/or receive a tensor over the given communicator.
    ///
    /// Per `get_defined_refs`, `result` is only defined when `from_rank` is
    /// set.
    SendTensor {
        /// Ref that is defined when `from_rank` is set.
        result: Ref,
        /// NOTE(review): presumably the rank to receive from — confirm.
        from_rank: Option<usize>,
        /// NOTE(review): presumably the rank to send to — confirm.
        to_rank: Option<usize>,
        /// Ref of the tensor involved in the transfer.
        tensor: Ref,
        /// NOTE(review): presumably describes how to allocate the received
        /// tensor — confirm against the handler.
        factory: Factory,
        /// Communicator actor used for the transfer.
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// NOTE(review): presumably evaluates `function` (when given) on the
    /// arguments and sends the resulting value to `pipe` or back to the
    /// controller — confirm against the handler.
    SendValue {
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        device_meshes: HashMap<Ref, DeviceMesh>,
        pipe: Option<PortHandle<PipeMessage>>,
    },

    /// Define the refs in `results` (the `Some` entries, per
    /// `get_defined_refs`).
    /// NOTE(review): presumably from a value obtained through `pipe` —
    /// confirm against the handler.
    SetValue {
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: PortHandle<PipeMessage>,
    },

    /// NOTE(review): presumably begins defining the recording with this ref —
    /// confirm against the handler.
    DefineRecording {
        recording: Ref,
    },

    /// NOTE(review): presumably ends the definition of the recording with
    /// this ref — confirm against the handler.
    FinalizeRecording {
        recording: Ref,
    },

    /// NOTE(review): presumably replays `recording` with `actuals` as its
    /// inputs and `results` as its outputs — confirm against the handler.
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Defines `result` (per `get_defined_refs`).
    /// NOTE(review): presumably binds it to the recording input at
    /// `argument_index` — confirm against the handler.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// NOTE(review): presumably marks `result` as the recording output at
    /// `output_index` — confirm against the handler.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Unit-test-only message (per its name) carrying a ref and a wire value.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Unit-test-only message (per its name) carrying a ref and a tensor
    /// result.
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Unit-test-only message (per its name) that replies with the value for
    /// a ref, if any.
    GetRefUnitTestsOnly(
        Ref, // value
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Unit-test-only message (per its name) that replies with the tensor
    /// result for a ref, if any.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// NOTE(review): handler not visible in this chunk; presumably sends the
    /// result of an actor call on behalf of the given worker actor — confirm.
    SendResultOfActorCall(ActorId, ActorCallParams),
    /// NOTE(review): handler not visible in this chunk; presumably invokes an
    /// actor method with the given parameters — confirm.
    CallActorMethod(ActorMethodParams),
}
273
274impl StreamMessage {
275    fn clone_for_recording(&self) -> Self {
276        match self {
277            StreamMessage::RecordingFormal {
278                result,
279                argument_index,
280            } => StreamMessage::RecordingFormal {
281                result: *result,
282                argument_index: *argument_index,
283            },
284            StreamMessage::RecordingResult {
285                result,
286                output_index,
287            } => StreamMessage::RecordingResult {
288                result: *result,
289                output_index: *output_index,
290            },
291            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
292            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
293                StreamMessage::CallFunction(
294                    params.clone(),
295                    device_meshes.clone(),
296                    remote_process_groups.clone(),
297                )
298            }
299            StreamMessage::BorrowCreate {
300                borrow,
301                tensor,
302                first_use_sender,
303            } => StreamMessage::BorrowCreate {
304                borrow: *borrow,
305                tensor: *tensor,
306                first_use_sender: first_use_sender.clone(),
307            },
308            StreamMessage::BorrowFirstUse {
309                borrow,
310                result,
311                first_use_receiver,
312            } => StreamMessage::BorrowFirstUse {
313                borrow: *borrow,
314                result: *result,
315                first_use_receiver: first_use_receiver.clone(),
316            },
317            StreamMessage::BorrowLastUse {
318                borrow,
319                result,
320                last_use_sender,
321            } => StreamMessage::BorrowLastUse {
322                borrow: *borrow,
323                result: *result,
324                last_use_sender: last_use_sender.clone(),
325            },
326            StreamMessage::BorrowDrop {
327                borrow,
328                last_use_receiver,
329            } => StreamMessage::BorrowDrop {
330                borrow: *borrow,
331                last_use_receiver: last_use_receiver.clone(),
332            },
333            StreamMessage::Reduce {
334                comm,
335                dim_size,
336                result,
337                local_tensor,
338                factory,
339                reduction,
340                scatter,
341                in_place,
342                out,
343            } => StreamMessage::Reduce {
344                comm: comm.clone(),
345                dim_size: *dim_size,
346                result: *result,
347                local_tensor: *local_tensor,
348                factory: factory.clone(),
349                reduction: reduction.clone(),
350                scatter: *scatter,
351                in_place: *in_place,
352                out: out.clone(),
353            },
354            StreamMessage::SendTensor {
355                result,
356                from_rank,
357                to_rank,
358                tensor,
359                factory,
360                comm,
361            } => StreamMessage::SendTensor {
362                result: *result,
363                from_rank: *from_rank,
364                to_rank: *to_rank,
365                tensor: *tensor,
366                factory: factory.clone(),
367                comm: comm.clone(),
368            },
369            StreamMessage::SetValue { seq, results, pipe } => StreamMessage::SetValue {
370                seq: seq.clone(),
371                results: results.clone(),
372                pipe: pipe.clone(),
373            },
374            other => panic!(
375                "StreamMessage variant not supported in recording: {:?}",
376                other
377            ),
378        }
379    }
380
381    // Get the set of refs that this message defines.
382    fn get_defined_refs(&self) -> HashSet<Ref> {
383        match self {
384            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
385            StreamMessage::CallFunction(params, ..) => {
386                params.results.iter().filter_map(|&ref_| ref_).collect()
387            }
388            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
389            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
390            StreamMessage::SendTensor {
391                result, from_rank, ..
392            } => {
393                if from_rank.is_some() {
394                    HashSet::from([*result])
395                } else {
396                    HashSet::new()
397                }
398            }
399            StreamMessage::SetValue { results, .. } => {
400                results.iter().filter_map(|&ref_| ref_).collect()
401            }
402            // TODO(slurye): Add SendValue eventually.
403            _ => HashSet::new(),
404        }
405    }
406
407    // Get the set of refs that this message mutates.
408    fn get_mutated_refs(&self) -> HashSet<Ref> {
409        match self {
410            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
411            StreamMessage::Reduce {
412                out,
413                in_place,
414                local_tensor,
415                ..
416            } => {
417                if *in_place {
418                    HashSet::from([*local_tensor])
419                } else if let Some(out) = out {
420                    HashSet::from([*out])
421                } else {
422                    HashSet::new()
423                }
424            }
425            // TODO(slurye): Add SendValue eventually.
426            _ => HashSet::new(),
427        }
428    }
429}
430
/// A stream represents a linear sequence of execution. Operations on different
/// streams can execute concurrently.
///
/// For CUDA operators, streams will invoke the corresponding stream management
/// APIs to perform synchronization.
///
/// For CPU operators, streams will just execute synchronously on their own OS
/// thread.
#[derive(Debug)]
pub struct StreamActor {
    /// World size; forwarded to `ensure_init_process_group` and
    /// `CommBackend::new` when a remote process group is first created.
    world_size: usize,
    /// Rank of this worker; forwarded alongside `world_size`.
    rank: usize,
    /// Mapping of refs in the controller environment to their values in this
    /// stream's local environment: either the `RValue` bound to the ref, or
    /// the shared error recorded for it when an operation failed.
    // TODO(agallagher): Use `ValueError` as the error type.
    env: HashMap<Ref, Result<RValue, Arc<SeqError>>>,
    /// How to create the stream.
    creation_mode: StreamCreationMode,
    /// CUDA stream that this actor will enqueue operations on. None if "device"
    /// is not a CUDA device.
    /// NOTE: We lazily create the stream, so that we do it from the dedicated
    /// Stream OS thread as, otherwise, we see deadlocks when done from
    /// unexpected threads.
    cuda_stream: OnceLock<Option<Stream>>,
    /// Device this stream should be scheduled on.
    device: Option<CudaDevice>,
    /// Communicator for this stream. Optional as we lazily initialize it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Actor ref of the controller that created this stream.
    controller_actor: ActorRef<ControllerActor>,
    /// Cache of Python process-group objects keyed by ref; an entry is created
    /// the first time a function call references that remote process group.
    remote_process_groups: HashMap<Ref, PyObject>,
    /// Recordings known to this stream, keyed by their ref.
    recordings: HashMap<Ref, Recording>,
    /// State of the currently active recording, if any. While this is `Some`,
    /// `report_seq_error` does not forward failures to the controller.
    active_recording: Option<RecordingState>,
    /// Copied from `StreamParams::respond_with_python_message`.
    /// NOTE(review): presumably selects replying with `PythonMessage`s —
    /// confirm against the handlers that read it.
    respond_with_python_message: bool,
    /// The most recent error created by `report_seq_error`, if any.
    last_seq_error: Option<Arc<SeqError>>,
}
467
/// Parameters for creating a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// World size; copied into the actor's `world_size` field.
    pub world_size: usize,
    /// Rank of this worker; copied into the actor's `rank` field.
    pub rank: usize,
    /// Controls how the underlying CUDA stream is created.
    pub creation_mode: StreamCreationMode,
    /// Id of this stream in the worker actor's stream table.
    /// NOTE: `StreamActor::new` currently ignores this field.
    pub id: StreamRef,
    /// Device this stream should be scheduled on. If none, don't do stream
    /// synchronization.
    pub device: Option<CudaDevice>,
    /// Actor ref of the controller that created this stream.
    pub controller_actor: ActorRef<ControllerActor>,
    /// Copied into the actor's `respond_with_python_message` field.
    /// NOTE(review): presumably selects replying with `PythonMessage`s —
    /// confirm against the handlers that read it.
    pub respond_with_python_message: bool,
}
484
485#[async_trait]
486impl Actor for StreamActor {
487    type Params = StreamParams;
488    async fn new(
489        StreamParams {
490            world_size,
491            rank,
492            id: _,
493            device,
494            controller_actor,
495            creation_mode,
496            respond_with_python_message,
497        }: Self::Params,
498    ) -> Result<Self> {
499        Ok(Self {
500            world_size,
501            rank,
502            env: HashMap::new(),
503            creation_mode,
504            cuda_stream: OnceLock::new(),
505            device,
506            comm: None,
507            controller_actor,
508            remote_process_groups: HashMap::new(),
509            recordings: HashMap::new(),
510            active_recording: None,
511            respond_with_python_message,
512            last_seq_error: None,
513        })
514    }
515
516    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
517        // These thread locals are exposed via python functions, so we need to set them in the
518        // same thread that python will run in. That means we need to initialize them here in
519        // StreamActor::init instead of in StreamActor::new.
520        CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
521            controller_actor_ref.set(self.controller_actor.clone()).ok()
522        });
523        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
524        ROOT_ACTOR_ID.with(|root_actor_id| {
525            root_actor_id
526                .set(ActorId::root(
527                    cx.self_id().proc_id().clone(),
528                    cx.self_id().name().to_string(),
529                ))
530                .ok()
531        });
532        // Set the current stream for this actor thread.
533        if let Some(stream) = self.cuda_stream() {
534            Stream::set_current_stream(stream);
535        }
536        Ok(())
537    }
538
    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
    /// dedicated OS thread. This is because:
    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
    ///     threads is problematic.
    ///
    /// The returned [`JoinHandle`] belongs to a small tokio task that waits for
    /// the dedicated thread to send back the output of `future`.
    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
        // It is important that we spawn a standalone thread for the work here,
        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
        // This is because the worker stream may call uninterruptible FFI code
        // that can deadlock (CUDA, NCCL).
        // If we use a tokio-managed blocking thread, then runtime teardown will
        // try to wait for tasks on that thread to reach an await point, and
        // hang forever.
        let builder = std::thread::Builder::new().name("worker-stream".to_string());
        // NOTE: the result of `spawn` is not inspected. If the thread cannot be
        // started, `join_tx` is dropped without sending and the `unwrap` in the
        // task spawned at the end of this function panics.
        let _thread_handle = builder.spawn(move || {
            // Run the actor loop on this new thread, inside a multi-thread
            // flavor tokio runtime limited to a single worker thread. We avoid
            // the current-thread runtime, so that we can use `block_in_place`
            // for nested async-to-sync-to-async flows.
            let rt = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_io()
                .build()
                .unwrap();
            let result = rt.block_on(async {
                tokio::task::block_in_place(|| {
                    // Allow e.g. destructing py objects on this thread, which
                    // can happen at shutdown when a stream actor's env map of
                    // rvalues is dropped (e.g. P1673311499).
                    // https://github.com/PyO3/pyo3/discussions/3499
                    Python::with_gil(|py| {
                        py.allow_threads(|| {
                            // Drive the actor future to completion on this
                            // thread, then hand its output to the waiting task.
                            let result = Handle::current().block_on(future);
                            if join_tx.send(result).is_err() {
                                panic!("could not send join result")
                            }
                        })
                    })
                })
            });
            // NOTE(review): the week-long timeout presumably means "wait for
            // the runtime to shut down, effectively without a deadline" —
            // confirm.
            rt.shutdown_timeout(Duration::from_weeks(1));
            result
        });

        // In order to bridge the synchronous join handle with the async world,
        // smuggle the result through a channel.
        tokio::spawn(async move { join_rx.await.unwrap() })
    }
591}
592
/// The arguments we accept as inputs to Python function calls.
#[derive(Debug)]
enum PyArg<'a> {
    /// A value looked up from the stream's environment.
    RValue(RValue),
    /// A borrowed device mesh; it is cloned when converted to a Python object.
    DeviceMesh(&'a DeviceMesh),
    /// A Python object that is passed through as is.
    PyObject(PyObject),
}
600
601/// Serialize into a `PyObject`.
602impl<'a, 'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg<'a> {
603    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
604        match self {
605            // SAFETY: This inherits the unsafety of `rvalue_to_ivalue` (see comment
606            // above).
607            PyArg::RValue(rval) => unsafe { rval.try_to_object_unsafe(py) },
608            PyArg::DeviceMesh(mesh) => Ok(Py::new(py, (*mesh).clone())?.into_bound(py).into_any()),
609            PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
610        }
611    }
612}
613
614impl StreamActor {
615    fn cuda_stream(&self) -> Option<&Stream> {
616        self.cuda_stream
617            .get_or_init(|| {
618                self.device.map(|device| match self.creation_mode {
619                    StreamCreationMode::UseDefaultStream => {
620                        Stream::get_current_stream_on_device(device)
621                    }
622                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
623                })
624            })
625            .as_ref()
626    }
627
628    fn ref_to_rvalue(&self, ref_: &Ref) -> Result<RValue, CallFunctionError> {
629        let rvalue = self
630            .env
631            .get(ref_)
632            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
633        match rvalue {
634            Ok(val) => Ok(val.clone()),
635            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
636        }
637    }
638
639    fn wire_to_rvalue(&self, value: WireValue) -> Result<RValue, CallFunctionError> {
640        let ret = match value {
641            WireValue::Ref(val) => self.ref_to_rvalue(&val)?,
642            // TODO: We might want to support GenericList / GenericDict etc.
643            WireValue::RefList(val) => {
644                let mut ret = Vec::with_capacity(val.len());
645                for v in val {
646                    match self.ref_to_rvalue(&v) {
647                        Ok(RValue::Tensor(t)) => ret.push(t),
648                        Err(err) => {
649                            return Err(err);
650                        }
651                        Ok(val) => {
652                            return Err(CallFunctionError::UnsupportedArgType(
653                                "wire_to_rvalue".into(),
654                                format!("RefList([{:?}])", val),
655                            ));
656                        }
657                    }
658                }
659                RValue::TensorList(ret)
660            }
661            WireValue::Int(val) => RValue::Int(val),
662            WireValue::IntList(val) => RValue::IntList(val),
663            WireValue::Double(val) => RValue::Double(val),
664            WireValue::Bool(val) => RValue::Bool(val),
665            WireValue::String(val) => RValue::String(val),
666            WireValue::Device(val) => RValue::Device(val),
667            WireValue::Layout(val) => RValue::Layout(val),
668            WireValue::ScalarType(val) => RValue::ScalarType(val),
669            WireValue::MemoryFormat(val) => RValue::MemoryFormat(val),
670            WireValue::PyObject(val) => RValue::PyObject(val),
671            WireValue::None(()) => RValue::None,
672            WireValue::IValue(val) => RValue::Opaque(val.into()),
673        };
674        Ok(ret)
675    }
676
677    async fn report_seq_error(
678        &mut self,
679        cx: &Context<'_, Self>,
680        seq: Seq,
681        error: CallFunctionError,
682    ) -> Result<Arc<SeqError>, anyhow::Error> {
683        match error {
684            CallFunctionError::DependentError(root) => Ok(root),
685            CallFunctionError::Error(e) => {
686                if self.active_recording.is_none() {
687                    let worker_error = WorkerError {
688                        backtrace: format!("{e}"),
689                        worker_actor_id: cx.self_id().clone(),
690                    };
691                    tracing::info!("Propagating remote function error to client: {worker_error}");
692                    self.controller_actor
693                        .remote_function_failed(cx, seq, worker_error)
694                        .await?
695                }
696                let err = Arc::new(SeqError { seq, error: e });
697                self.last_seq_error = Some(err.clone());
698                Ok(err)
699            }
700        }
701    }
702
703    async fn try_define<F>(
704        &mut self,
705        cx: &Context<'_, Self>,
706        seq: Seq,
707        result_refs: Vec<Option<Ref>>,
708        mutates: &Vec<Ref>,
709        f: F,
710    ) -> Result<()>
711    where
712        F: AsyncFnOnce(&mut Self) -> Result<Vec<RValue>, CallFunctionError>,
713    {
714        let actual_results = f(self).await;
715        // Check if the expected number of returns is correct, otherwise convert
716        // into an error.
717        let op_results = actual_results.and_then(|actual_results| {
718            if result_refs.len() == actual_results.len() {
719                Ok(actual_results
720                    .into_iter()
721                    .zip(result_refs.iter())
722                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
723                    .collect::<Vec<(Ref, RValue)>>())
724            } else {
725                Err(CallFunctionError::UnexpectedNumberOfReturns(
726                    result_refs.len(),
727                    actual_results.len(),
728                ))
729            }
730        });
731
732        // Propagate the results (either the actual values or an error) to the
733        // right entries in the global env mapping.
734        match op_results {
735            Ok(op_results) => {
736                for (ref_, rvalue) in op_results.into_iter() {
737                    let prev = self.env.insert(ref_, Ok(rvalue));
738                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
739                }
740            }
741            Err(err) => {
742                let err = self.report_seq_error(cx, seq, err).await?;
743                for ref_ in result_refs {
744                    match ref_ {
745                        Some(ref_) => {
746                            let prev = self.env.insert(ref_, Err(err.clone()));
747                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
748                        }
749                        None => {}
750                    }
751                }
752                for ref_ in mutates {
753                    self.env.insert(*ref_, Err(err.clone()));
754                }
755            }
756        }
757        Ok(())
758    }
759
760    fn call_torch_op(
761        &self,
762        op: String,
763        overload: String,
764        args: Vec<WireValue>,
765        kwargs: HashMap<String, WireValue>,
766    ) -> Result<Vec<RValue>, CallFunctionError> {
767        let args = args
768            .into_iter()
769            .map(|arg| self.wire_to_rvalue(arg))
770            .collect::<Result<Vec<_>, _>>()?;
771        let kwargs = kwargs
772            .into_iter()
773            .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
774            .collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
775
776        let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
777
778        // Handle the case where the op returns nothing and convert it to a list of None.
779        // This is to ensure handle results does not error out as the client will call
780        // such a function with expected results of size 1.
781        Ok(if results.is_empty() {
782            vec![RValue::None]
783        } else {
784            results
785        })
786    }
787
788    fn call_python_fn<'py>(
789        &mut self,
790        py: Python<'py>,
791        cx: &Context<Self>,
792        function: Option<ResolvableFunction>,
793        args: Vec<WireValue>,
794        kwargs: HashMap<String, WireValue>,
795        mutates: &[Ref],
796        device_meshes: HashMap<Ref, DeviceMesh>,
797        remote_process_groups: HashMap<
798            Ref,
799            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
800        >,
801    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
802        let function = function
803            .map(|function| {
804                function.resolve(py).map_err(|e| {
805                    CallFunctionError::InvalidRemoteFunction(format!(
806                        "failed to resolve function {}: {}",
807                        function,
808                        SerializablePyErr::from(py, &e)
809                    ))
810                })
811            })
812            .transpose()?;
813
814        let remote_process_groups = remote_process_groups
815            .into_iter()
816            .map(|(gref, (mesh, dims, comm))| {
817                let group = match self.remote_process_groups.entry(gref) {
818                    Entry::Occupied(ent) => ent.get().clone_ref(py),
819                    Entry::Vacant(ent) => {
820                        // We need to run `init_process_group` before any
821                        // remote process groups can get created.
822                        torch_sys::backend::ensure_init_process_group(
823                            py,
824                            self.world_size,
825                            self.rank,
826                        )?;
827
828                        // Create a backend object to wrap the comm and use
829                        // it to create a new torch group.
830                        let ranks = mesh.get_ranks_for_dim_slice(&dims)?;
831                        let group_size = ranks.len();
832                        let backend = CommBackend::new(
833                            comm,
834                            Mailbox::new_detached(cx.self_id().clone()),
835                            self.rank,
836                            group_size,
837                            self.world_size,
838                        );
839                        ent.insert(torch_sys::backend::new_group(py, ranks, backend)?.unbind())
840                            .clone_ref(py)
841                    }
842                };
843                PyResult::Ok((gref, group))
844            })
845            .collect::<Result<HashMap<_, _>, _>>()
846            .map_err(SerializablePyErr::from_fn(py))?;
847
848        // SAFETY: We will be making an unchecked clone of each tensor to pass to to
849        // C++, so we need to hold a borrow of each input tensor for the duration of
850        // this function.
851        let mut multiborrow = MultiBorrow::new();
852
853        let resolve = |val: WireValue| {
854            val.into_py_object()
855                .map_err(|e| {
856                    CallFunctionError::UnsupportedArgType(
857                        format!("{:?}", function),
858                        format!("{:?}", e),
859                    )
860                })?
861                .unpickle(py)
862                .map_err(SerializablePyErr::from_fn(py))?
863                .extract::<PyTree<PyObject>>()
864                .map_err(SerializablePyErr::from_fn(py))?
865                .try_into_map(|obj| {
866                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
867                        if let Some(mesh) = device_meshes.get(&ref_) {
868                            PyArg::DeviceMesh(mesh)
869                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
870                            PyArg::PyObject(pg.clone_ref(py))
871                        } else {
872                            let rval = self.ref_to_rvalue(&ref_)?;
873                            PyArg::RValue(rval)
874                        }
875                    } else {
876                        PyArg::PyObject(obj)
877                    })
878                })
879        };
880
881        // Resolve refs
882        let py_args: Vec<PyTree<PyArg>> = args
883            .into_iter()
884            .map(resolve)
885            .collect::<Result<_, CallFunctionError>>()?;
886        let py_kwargs: HashMap<_, PyTree<PyArg>> = kwargs
887            .into_iter()
888            .map(|(k, object)| Ok((k, resolve(object)?)))
889            .collect::<Result<_, CallFunctionError>>()?;
890
891        // Add a shared-borrow for each rvalue reference.
892        py_args
893            .iter()
894            .chain(py_kwargs.values())
895            .flat_map(|o| o.iter())
896            .for_each(|arg| {
897                if let PyArg::RValue(rval) = arg {
898                    multiborrow.add(rval, BorrowType::Shared);
899                }
900            });
901
902        // Add mutable borrows for params we're mutating.
903        let mutates: Vec<_> = mutates
904            .iter()
905            .map(|r| self.ref_to_rvalue(r))
906            .collect::<Result<_, CallFunctionError>>()?;
907        mutates
908            .iter()
909            .for_each(|rval| multiborrow.add(rval, BorrowType::Mutable));
910
911        // Execute the borrow.
912        let _borrow = multiborrow.borrow()?;
913
914        // Call function.
915        // Use custom subscriber to route Worker messages to stdout.
916        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
917        let result: Bound<'_, PyAny> =
918            tracing::subscriber::with_default(scoped_subscriber, || {
919                // SAFETY: The borrows above guard the unchecked clones done by
920                // `rvalue_to_ivalue`. This may result in multiple mutable
921                // references to tensor data, but the Python side is responsible
922                // for making sure that is safe
923                // TODO(agallagher): The args/kwargs conversion traits generate
924                // the appropriate types here, but they get casted to `PyAny`.
925                // It'd be nice to make `TryToPyObjectUnsafe` take a template
926                // arg for the converted py object to avoid this downcast.
927                let args = unsafe { py_args.try_to_object_unsafe(py) }
928                    .map_err(SerializablePyErr::from_fn(py))?;
929                // SAFETY: above
930                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
931                    .map_err(SerializablePyErr::from_fn(py))?;
932
933                if let Some(function) = function {
934                    function
935                        .call(args, Some(kwargs))
936                        .map_err(SerializablePyErr::from_fn(py))
937                } else {
938                    Ok(args.get_item(0).unwrap())
939                }
940            })?;
941        Ok(result)
942    }
943
944    fn call_python_fn_pytree(
945        &mut self,
946        cx: &hyperactor::Context<Self>,
947        function: ResolvableFunction,
948        args: Vec<WireValue>,
949        kwargs: HashMap<String, WireValue>,
950        mutates: &[Ref],
951        device_meshes: HashMap<Ref, DeviceMesh>,
952        remote_process_groups: HashMap<
953            Ref,
954            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
955        >,
956    ) -> Result<PyTree<RValue>, CallFunctionError> {
957        Python::with_gil(|py| {
958            let result = self.call_python_fn(
959                py,
960                cx,
961                Some(function),
962                args,
963                kwargs,
964                mutates,
965                device_meshes,
966                remote_process_groups,
967            )?;
968            Ok(PyTree::<RValue>::extract_bound(&result).map_err(SerializablePyErr::from_fn(py))?)
969        })
970    }
971    /// Retrieve `ref_` or create a fake value with the provided factory if it
972    /// is an error. We use this for collective calls, where even if there was
973    /// an upstream failure, we still have participate in the collective to
974    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
975    /// value here of the correct shape; the controller will have been notified
976    /// of the upstream failure and will know to ignore everything dependent on
977    /// it.
978    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
979        let rvalue = self
980            .env
981            .get(&ref_)
982            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
983
984        match rvalue {
985            Ok(val) => Ok(val.clone().try_into().map_err(|e| anyhow!("{}", e))?),
986            Err(_) => {
987                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
988                Ok(TensorCell::new(t))
989            }
990        }
991    }
992
993    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
994        self.active_recording
995            .as_mut()
996            .and_then(|state| match state {
997                RecordingState::Defining {
998                    recording,
999                    defined_borrows,
1000                } => {
1001                    match self.recordings.get_mut(recording) {
1002                        Some(recording) => Some((recording, defined_borrows)),
1003                        // Panic, because this would be a logic error in the program.
1004                        None => panic!("recording not found: {:?}", recording),
1005                    }
1006                }
1007                RecordingState::Running => None,
1008            })
1009    }
1010
1011    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
1012        for ref_ in refs {
1013            let rvalue_or_err = self
1014                .env
1015                .get(ref_)
1016                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
1017            if let Err(err) = rvalue_or_err {
1018                return Ok(Some(err.clone()));
1019            }
1020        }
1021        Ok(None)
1022    }
1023    async fn send_value_python_message(
1024        &mut self,
1025        cx: &hyperactor::Context<'_, Self>,
1026        seq: Seq,
1027        worker_actor_id: ActorId,
1028        mutates: Vec<Ref>,
1029        function: Option<ResolvableFunction>,
1030        args: Vec<WireValue>,
1031        kwargs: HashMap<String, WireValue>,
1032        device_meshes: HashMap<Ref, DeviceMesh>,
1033    ) -> Result<()> {
1034        self.try_define(cx, seq, vec![], &vec![], async |self_| {
1035            let python_message =
1036                Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
1037                    let python_result = tokio::task::block_in_place(|| {
1038                        self_.call_python_fn(
1039                            py,
1040                            cx,
1041                            function,
1042                            args,
1043                            kwargs,
1044                            &mutates,
1045                            device_meshes,
1046                            HashMap::new(),
1047                        )
1048                    })?;
1049                    pickle_python_result(py, python_result, worker_actor_id)
1050                        .map_err(CallFunctionError::Error)
1051                })?;
1052            let ser = Serialized::serialize(&python_message).unwrap();
1053            self_
1054                .controller_actor
1055                .fetch_result(cx, seq, Ok(ser))
1056                .await?;
1057            Ok(vec![])
1058        })
1059        .await
1060    }
1061    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
1062        let rvalue = self
1063            .env
1064            .get(&src)
1065            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
1066        self.env.insert(dest, rvalue.clone());
1067        Ok(())
1068    }
1069    async fn call_actor(
1070        &mut self,
1071        cx: &Context<'_, Self>,
1072        params: ActorCallParams,
1073    ) -> Result<PyObject, CallFunctionError> {
1074        let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1075            params
1076                .local_state
1077                .into_iter()
1078                .map(|elem| {
1079                    // SAFETY: python is gonna make unsafe copies of this stuff anyway
1080                    unsafe {
1081                        let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1082                        Ok(x)
1083                    }
1084                })
1085                .collect()
1086        });
1087
1088        let (send, recv) = cx.open_once_port();
1089        let state = LocalState {
1090            response_port: send,
1091            state: local_state?,
1092        };
1093        let x: u64 = params.seq.into();
1094        let message = LocalStateBrokerMessage::Set(x as usize, state);
1095
1096        let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1097        broker
1098            .send(message)
1099            .map_err(|e| CallFunctionError::Error(e.into()))?;
1100        let result = recv
1101            .recv()
1102            .await
1103            .map_err(|e| CallFunctionError::Error(e.into()))?;
1104
1105        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1106    }
1107}
1108
1109#[async_trait]
1110#[forward(StreamMessage)]
1111impl StreamMessageHandler for StreamActor {
1112    async fn call_function(
1113        &mut self,
1114        cx: &Context<Self>,
1115        params: CallFunctionParams,
1116        device_meshes: HashMap<Ref, DeviceMesh>,
1117        remote_process_groups: HashMap<
1118            Ref,
1119            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1120        >,
1121    ) -> Result<()> {
1122        if let Some((recording, _)) = self.get_defining_recording() {
1123            recording.messages.push(StreamMessage::CallFunction(
1124                params,
1125                device_meshes,
1126                remote_process_groups,
1127            ));
1128            return Ok(());
1129        }
1130
1131        params.function.panic_if_requested();
1132        self.try_define(
1133            cx,
1134            params.seq,
1135            params.results,
1136            &params.mutates,
1137            async |self| {
1138                tokio::task::block_in_place(|| match params.function.as_torch_op() {
1139                    Some((op, overload)) => {
1140                        self.call_torch_op(op, overload, params.args, params.kwargs)
1141                    }
1142                    _ => self
1143                        .call_python_fn_pytree(
1144                            cx,
1145                            params.function,
1146                            params.args,
1147                            params.kwargs,
1148                            &params.mutates,
1149                            device_meshes,
1150                            remote_process_groups,
1151                        )
1152                        .map(|results| results.into_leaves()),
1153                })
1154            },
1155        )
1156        .await?;
1157        Ok(())
1158    }
1159
1160    async fn borrow_create(
1161        &mut self,
1162        _cx: &Context<Self>,
1163        borrow: u64,
1164        tensor: Ref,
1165        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1166    ) -> Result<()> {
1167        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1168            recording.messages.push(StreamMessage::BorrowCreate {
1169                borrow,
1170                tensor,
1171                first_use_sender,
1172            });
1173            ensure!(
1174                defined_borrows.insert(borrow),
1175                "duplicate borrow create in recording"
1176            );
1177            return Ok(());
1178        }
1179
1180        let rvalue_result = self
1181            .env
1182            .get(&tensor)
1183            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1184
1185        let result = match rvalue_result {
1186            Ok(rvalue) => Ok(rvalue.clone().try_into().map_err(|e| anyhow!("{}", e))?),
1187            Err(e) => Err(e.clone()),
1188        };
1189
1190        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1191        first_use_sender.send((event, result)).map_err(|err| {
1192            anyhow!(
1193                "failed sending first use event for borrow {:?}: {:?}",
1194                borrow,
1195                err
1196            )
1197        })
1198    }
1199
1200    async fn borrow_first_use(
1201        &mut self,
1202        _cx: &Context<Self>,
1203        borrow: u64,
1204        result: Ref,
1205        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1206    ) -> Result<()> {
1207        if let Some((recording, _)) = self.get_defining_recording() {
1208            recording.messages.push(StreamMessage::BorrowFirstUse {
1209                borrow,
1210                result,
1211                first_use_receiver: first_use_receiver.clone(),
1212            });
1213            return Ok(());
1214        }
1215
1216        let (first_use_event, cell) =
1217            first_use_receiver
1218                .lock()
1219                .await
1220                .recv()
1221                .await
1222                .map_err(|err| {
1223                    anyhow!(
1224                        "failed receiving first use event for borrow {:?}: {:?}",
1225                        borrow,
1226                        err
1227                    )
1228                })?;
1229
1230        if let Some(stream) = self.cuda_stream() {
1231            stream.wait_event(
1232                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1233            );
1234        }
1235        match cell {
1236            Ok(cell) => {
1237                self.env.insert(result, Ok(cell.into()));
1238            }
1239            Err(err) => {
1240                self.env.insert(result, Err(err.clone()));
1241            }
1242        }
1243        Ok(())
1244    }
1245
1246    async fn borrow_last_use(
1247        &mut self,
1248        _cx: &Context<Self>,
1249        borrow: u64,
1250        result: Ref,
1251        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1252    ) -> Result<()> {
1253        if let Some((recording, _)) = self.get_defining_recording() {
1254            recording.messages.push(StreamMessage::BorrowLastUse {
1255                borrow,
1256                result,
1257                last_use_sender,
1258            });
1259            return Ok(());
1260        }
1261
1262        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1263        let rvalue_or_err = self.env.remove(&result).ok_or(anyhow!(
1264            "Invalid reference for borrow_last_use: {result:#?}"
1265        ))?;
1266        let tensor = match rvalue_or_err {
1267            Ok(RValue::Tensor(t)) => Ok(t),
1268            Err(e) => Err(e),
1269            _ => bail!("invalid rvalue type for borrow_last_use"),
1270        };
1271
1272        last_use_sender.send((event, tensor)).map_err(|err| {
1273            anyhow!(
1274                "failed sending last use event for borrow {:?}: {:?}",
1275                borrow,
1276                err
1277            )
1278        })
1279    }
1280
1281    async fn borrow_drop(
1282        &mut self,
1283        _cx: &Context<Self>,
1284        borrow: u64,
1285        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1286    ) -> Result<()> {
1287        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1288            recording.messages.push(StreamMessage::BorrowDrop {
1289                borrow,
1290                last_use_receiver: last_use_receiver.clone(),
1291            });
1292            ensure!(
1293                defined_borrows.remove(&borrow),
1294                "borrow drop for borrow not defined in recording"
1295            );
1296            return Ok(());
1297        }
1298
1299        // The borrowed cell isn't used directly, but we still want to receive it here
1300        // so that the underlying tensor isn't dropped until after we synchronize the
1301        // CUDA streams.
1302        let (last_use_event, _cell) =
1303            last_use_receiver.lock().await.recv().await.map_err(|err| {
1304                anyhow!(
1305                    "failed receiving last use event for borrow {:?}: {:?}",
1306                    borrow,
1307                    err
1308                )
1309            })?;
1310
1311        if let Some(stream) = self.cuda_stream() {
1312            stream.wait_event(
1313                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1314            );
1315        }
1316        // let the cell drop.
1317        Ok(())
1318    }
1319
1320    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1321        if let Some((recording, _)) = self.get_defining_recording() {
1322            recording.messages.push(StreamMessage::DeleteRefs(refs));
1323            return Ok(());
1324        }
1325
1326        for ref_ in refs.iter() {
1327            self.env.remove(ref_);
1328        }
1329        Ok(())
1330    }
1331
1332    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1333        if self.get_defining_recording().is_some() {
1334            bail!("request_status not allowed in recording");
1335        }
1336
1337        Ok(())
1338    }
1339
1340    async fn init_comm(
1341        &mut self,
1342        _cx: &Context<Self>,
1343        comm: ActorHandle<NcclCommActor>,
1344    ) -> Result<()> {
1345        if self.get_defining_recording().is_some() {
1346            bail!("init_comm not allowed in recording");
1347        }
1348
1349        self.comm = Some(comm);
1350        Ok(())
1351    }
1352
1353    async fn reduce(
1354        &mut self,
1355        cx: &Context<Self>,
1356        comm: Arc<ActorHandle<NcclCommActor>>,
1357        dim_size: i64,
1358        result: Ref,
1359        local_tensor: Ref,
1360        factory: Factory,
1361        reduction: Reduction,
1362        scatter: bool,
1363        in_place: bool,
1364        out: Option<Ref>,
1365    ) -> Result<()> {
1366        if let Some((recording, _)) = self.get_defining_recording() {
1367            recording.messages.push(StreamMessage::Reduce {
1368                comm,
1369                dim_size,
1370                result,
1371                local_tensor,
1372                factory,
1373                reduction,
1374                scatter,
1375                in_place,
1376                out,
1377            });
1378            return Ok(());
1379        }
1380
1381        let stream = self
1382            .cuda_stream()
1383            .expect("reductions not yet supported for non-CUDA workers")
1384            .clone();
1385        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1386        let out_cell = out
1387            .map(|out| self.get_or_fake_on_err(out, &factory))
1388            .transpose()?;
1389        let output_cell = match reduction {
1390            Reduction::Stack => {
1391                if scatter {
1392                    let output_cell = if in_place {
1393                        input_cell.clone()
1394                    } else {
1395                        out_cell.unwrap_or({
1396                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1397                            let cloned = deep_clone(&borrow);
1398                            TensorCell::new(cloned)
1399                        })
1400                    };
1401                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1402                        .await?;
1403                    output_cell
1404                } else {
1405                    ensure!(
1406                        !in_place,
1407                        "in-place, non-scatter not supported for stack reduce"
1408                    );
1409
1410                    let output_cell = out_cell.unwrap_or({
1411                        // In Python, this would be [dim_size, *factory.sizes]
1412                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1413                        let output =
1414                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1415                        TensorCell::new(output)
1416                    });
1417
1418                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1419                        .await?;
1420                    output_cell
1421                }
1422            }
1423            Reduction::ReduceOp(op) => {
1424                if scatter {
1425                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1426
1427                    let output_cell = out_cell.unwrap_or({
1428                        let output = factory_empty(
1429                            &factory.size[1..],
1430                            factory.dtype,
1431                            factory.layout,
1432                            factory.device,
1433                        );
1434                        TensorCell::new(output)
1435                    });
1436                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1437                        .await?;
1438                    output_cell
1439                } else {
1440                    let output_cell = if in_place {
1441                        input_cell.clone()
1442                    } else {
1443                        out_cell.map_or(
1444                            {
1445                                let borrow =
1446                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1447                                let cloned = deep_clone(&borrow);
1448                                Ok(TensorCell::new(cloned))
1449                            },
1450                            |out_cell| -> Result<_, anyhow::Error> {
1451                                let mut out_borrow =
1452                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1453                                let in_borrow =
1454                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1455                                out_borrow.copy_(&in_borrow);
1456                                drop(out_borrow);
1457                                Ok(out_cell)
1458                            },
1459                        )?
1460                    };
1461
1462                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1463                    output_cell
1464                }
1465            }
1466        };
1467
1468        self.env.insert(result, Ok(output_cell.into()));
1469        Ok(())
1470    }
1471
1472    async fn send_tensor(
1473        &mut self,
1474        cx: &Context<Self>,
1475        result: Ref,
1476        from_rank: Option<usize>,
1477        to_rank: Option<usize>,
1478        tensor: Ref,
1479        factory: Factory,
1480        comm: Arc<ActorHandle<NcclCommActor>>,
1481    ) -> Result<()> {
1482        if let Some((recording, _)) = self.get_defining_recording() {
1483            recording.messages.push(StreamMessage::SendTensor {
1484                result,
1485                from_rank,
1486                to_rank,
1487                tensor,
1488                factory,
1489                comm,
1490            });
1491            return Ok(());
1492        }
1493
1494        if to_rank.is_none() && from_rank.is_none() {
1495            bail!("tried to send tensor without a to/from rank");
1496        }
1497
1498        // Value is local, so we do not have to actually send it.
1499        if from_rank == to_rank {
1500            let input_cell: &std::result::Result<RValue, Arc<SeqError>> = self
1501                .env
1502                .get(&tensor)
1503                .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1504            let output_cell = match input_cell {
1505                Ok(RValue::Tensor(input_cell)) => {
1506                    // We create a defensive copy here to prevent mutations on
1507                    // the input tensor from affecting output tensor.
1508                    // Should we copy if input ref == output ref?
1509                    // Should we support copy-on-write to avoid unnecessary copy?
1510                    let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1511                    let cloned = deep_clone(&borrow);
1512                    Ok(RValue::Tensor(TensorCell::new(cloned)))
1513                }
1514                Ok(rval) => bail!("tensor ref is not a tensor: {:?}", rval),
1515                Err(err) => Err(err.clone()),
1516            };
1517            self.env.insert(result, output_cell);
1518            return Ok(());
1519        }
1520
1521        let mut messages = Vec::new();
1522
1523        if let Some(to_rank) = to_rank {
1524            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1525            messages.push(CommMessage::Send(
1526                input_cell,
1527                to_rank.try_into().unwrap(),
1528                self.cuda_stream()
1529                    .expect("tried to send_tensor on non-cuda stream")
1530                    .clone(),
1531                cx.open_once_port().0,
1532            ));
1533        }
1534
1535        if let Some(from_rank) = from_rank {
1536            let output_cell = TensorCell::new(factory_empty(
1537                &factory.size,
1538                factory.dtype,
1539                factory.layout,
1540                factory.device,
1541            ));
1542            messages.push(CommMessage::Recv(
1543                output_cell.clone(),
1544                from_rank.try_into().unwrap(),
1545                self.cuda_stream()
1546                    .expect("tried to send_tensor on non-cuda stream")
1547                    .clone(),
1548                cx.open_once_port().0,
1549            ));
1550            self.env.insert(result, Ok(output_cell.into()));
1551        }
1552
1553        comm.group(
1554            cx,
1555            messages,
1556            self.cuda_stream()
1557                .expect("tried to send_tensor on non-cuda stream")
1558                .clone(),
1559        )
1560        .await?;
1561        Ok(())
1562    }
1563
1564    async fn send_value(
1565        &mut self,
1566        cx: &Context<Self>,
1567        seq: Seq,
1568        worker_actor_id: ActorId,
1569        mutates: Vec<Ref>,
1570        function: Option<ResolvableFunction>,
1571        args: Vec<WireValue>,
1572        kwargs: HashMap<String, WireValue>,
1573        device_meshes: HashMap<Ref, DeviceMesh>,
1574        pipe: Option<PortHandle<PipeMessage>>,
1575    ) -> Result<()> {
1576        if self.respond_with_python_message && pipe.is_none() {
1577            return self
1578                .send_value_python_message(
1579                    cx,
1580                    seq,
1581                    worker_actor_id,
1582                    mutates,
1583                    function,
1584                    args,
1585                    kwargs,
1586                    device_meshes,
1587                )
1588                .await;
1589        }
1590        let result = if let Some(function) = function {
1591            // If a function was provided, use that to resolve the value.
1592            match function.as_torch_op() {
1593                Some((op, overload)) => {
1594                    self.call_torch_op(op, overload, args, kwargs)
1595                        .map(|rvalues| {
1596                            if rvalues.len() == 1 {
1597                                Ok(rvalues[0].clone().into())
1598                            } else {
1599                                // TODO: Replace with native pytrees when possible
1600                                Python::with_gil(|py| {
1601                                    Ok((|| {
1602                                        let py_rvalues = rvalues
1603                                            .into_iter()
1604                                            // SAFETY: This inherits the unsafety of `try_to_object_unsafe`.
1605                                            .map(|rvalue| unsafe {
1606                                                rvalue.try_to_object_unsafe(py)
1607                                            })
1608                                            .collect::<Result<Vec<_>, _>>()?;
1609                                        PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1610                                    })()
1611                                    .map_err(SerializablePyErr::from_fn(py))?)
1612                                })
1613                            }
1614                        })?
1615                }
1616                // Use block-in-place to allow nested callbacks to re-enter the
1617                // runtime to run async code.
1618                _ => tokio::task::block_in_place(|| {
1619                    self.call_python_fn_pytree(
1620                        cx,
1621                        function,
1622                        args,
1623                        kwargs,
1624                        &mutates,
1625                        device_meshes,
1626                        HashMap::new(),
1627                    )
1628                }),
1629            }
1630        } else {
1631            // If there's no function provided, there should be exactly one arg
1632            // and no kwargs.
1633            match (args.len(), kwargs.len()) {
1634                (1, 0) => Python::with_gil(|py| {
1635                    let arg = args[0]
1636                        .as_py_object()
1637                        .ok_or_else(|| {
1638                            CallFunctionError::UnsupportedArgType(
1639                                "send_value".to_string(),
1640                                "expected a PyObject as the first arg".to_string(),
1641                            )
1642                        })?
1643                        .unpickle(py)
1644                        .map_err(SerializablePyErr::from_fn(py))?;
1645                    arg.extract::<PyTree<PyObject>>()
1646                        .map_err(SerializablePyErr::from_fn(py))?
1647                        .try_into_map(|obj| {
1648                            let bound_obj = obj.bind(py);
1649                            if let Ok(ref_) = Ref::from_py_object(bound_obj) {
1650                                self.ref_to_rvalue(&ref_)
1651                            } else {
1652                                Ok(bound_obj
1653                                    .extract::<RValue>()
1654                                    .map_err(SerializablePyErr::from_fn(py))?)
1655                            }
1656                        })
1657                }),
1658                _ => Err(CallFunctionError::TooManyArgsForValue(
1659                    format!("{:?}", args),
1660                    format!("{:?}", kwargs),
1661                )),
1662            }
1663        };
1664
1665        let value = match result {
1666            Ok(rvalue) => {
1667                // When returning a tensor, we copy out to decouple from the GPU,
1668                // as the worker will either serialize and send this to the controller
1669                // or to a pipe and we see hangs if it tries to pull from the GPU
1670                // in its thread.
1671                Ok(rvalue.into_map(|rval| match rval {
1672                    RValue::Tensor(tensor) => RValue::Tensor(tensor.try_cpu().unwrap()),
1673                    RValue::TensorList(tensors) => RValue::TensorList(
1674                        tensors
1675                            .into_iter()
1676                            .map(|tensor| tensor.try_cpu().unwrap())
1677                            .collect(),
1678                    ),
1679                    rval => rval,
1680                }))
1681            }
1682            Err(err) => {
1683                let err = self.report_seq_error(cx, seq, err).await?;
1684                for ref_ in mutates {
1685                    self.env.insert(ref_, Err(err.clone()));
1686                }
1687                Err(WorkerError {
1688                    backtrace: format!("{:?}", err),
1689                    worker_actor_id,
1690                })
1691            }
1692        };
1693
1694        // Actually send the value.
1695        if let Some(pipe) = pipe {
1696            pipe.send(PipeMessage::SendValue(value))?;
1697        } else {
1698            let result = match value {
1699                Ok(value) => Ok(Serialized::serialize(&value).map_err(anyhow::Error::from)?),
1700                Err(e) => Err(e),
1701            };
1702            self.controller_actor.fetch_result(cx, seq, result).await?;
1703        }
1704
1705        Ok(())
1706    }
1707
1708    async fn send_result_of_actor_call(
1709        &mut self,
1710        cx: &Context<Self>,
1711        worker_actor_id: ActorId,
1712        params: ActorCallParams,
1713    ) -> anyhow::Result<()> {
1714        let seq = params.seq;
1715        let mutates = params.mutates.clone();
1716        self.try_define(cx, seq, vec![], &mutates, async |self| {
1717            let value = self.call_actor(cx, params).await?;
1718            let result = Python::with_gil(|py| {
1719                pickle_python_result(py, value.into_bound(py), worker_actor_id)
1720            })?;
1721            let result = Serialized::serialize(&result).unwrap();
1722            self.controller_actor
1723                .fetch_result(cx, seq, Ok(result))
1724                .await?;
1725            Ok(vec![])
1726        })
1727        .await
1728    }
1729
1730    async fn call_actor_method(
1731        &mut self,
1732        cx: &Context<Self>,
1733        params: ActorMethodParams,
1734    ) -> anyhow::Result<()> {
1735        let seq = params.call.seq;
1736        let mutates = params.call.mutates.clone();
1737        self.try_define(cx, seq, params.results, &mutates, async |self| {
1738            let result = self.call_actor(cx, params.call).await?;
1739            let result = Python::with_gil(|py| {
1740                PyTree::<RValue>::extract_bound(&result.into_bound(py))
1741                    .map_err(SerializablePyErr::from_fn(py))
1742            })?;
1743            Ok(result.into_leaves())
1744        })
1745        .await
1746    }
1747
1748    async fn set_value(
1749        &mut self,
1750        cx: &Context<Self>,
1751        seq: Seq,
1752        results: Vec<Option<Ref>>,
1753        pipe: PortHandle<PipeMessage>,
1754    ) -> Result<()> {
1755        if let Some((recording, _)) = self.get_defining_recording() {
1756            recording
1757                .messages
1758                .push(StreamMessage::SetValue { seq, results, pipe });
1759            return Ok(());
1760        }
1761
1762        self.try_define(cx, seq, results, &vec![], async |self| {
1763            let (tx, rx) = cx.open_once_port();
1764            pipe.send(PipeMessage::RecvValue(tx))
1765                .map_err(anyhow::Error::from)
1766                .map_err(CallFunctionError::from)?;
1767            let value = rx.recv().await.map_err(anyhow::Error::from)?;
1768            Ok(value.into_leaves())
1769        })
1770        .await
1771    }
1772
1773    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1774        if self.active_recording.is_some() {
1775            bail!("different recording already active");
1776        }
1777        match self.recordings.entry(recording) {
1778            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1779            Entry::Vacant(entry) => entry.insert(Recording::new()),
1780        };
1781        self.active_recording = Some(RecordingState::Defining {
1782            recording,
1783            defined_borrows: HashSet::new(),
1784        });
1785        Ok(())
1786    }
1787
1788    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1789        match self.active_recording {
1790            Some(RecordingState::Defining {
1791                recording: active_recording,
1792                ref defined_borrows,
1793            }) if active_recording == recording => {
1794                ensure!(
1795                    defined_borrows.is_empty(),
1796                    "all borrows created within recording must be dropped within recording"
1797                );
1798                self.active_recording = None;
1799            }
1800            _ => bail!("cannot finalize recording that isn't active"),
1801        }
1802        Ok(())
1803    }
1804
1805    async fn recording_formal(
1806        &mut self,
1807        _cx: &Context<Self>,
1808        result: Ref,
1809        argument_index: usize,
1810    ) -> Result<()> {
1811        match self.get_defining_recording() {
1812            Some((recording, _)) => {
1813                recording.messages.push(StreamMessage::RecordingFormal {
1814                    result,
1815                    argument_index,
1816                });
1817            }
1818            None => bail!("recording_formal called outside of recording"),
1819        };
1820        Ok(())
1821    }
1822
1823    async fn recording_result(
1824        &mut self,
1825        _cx: &Context<Self>,
1826        result: Ref,
1827        output_index: usize,
1828    ) -> Result<()> {
1829        match self.get_defining_recording() {
1830            Some((recording, _)) => {
1831                recording.messages.push(StreamMessage::RecordingResult {
1832                    result,
1833                    output_index,
1834                });
1835            }
1836            None => bail!("recording_result called outside of recording"),
1837        };
1838        Ok(())
1839    }
1840
    /// Replays the messages captured in `recording` on this stream.
    ///
    /// `actuals` supplies the refs bound to the recording's formals (looked up
    /// by each `RecordingFormal`'s `argument_index`), and `results` supplies
    /// the refs that receive the recording's outputs (looked up by each
    /// `RecordingResult`'s `output_index`).
    ///
    /// The recording's error is the first one found: either an existing error
    /// on the inputs of a `Reduce` or sending `SendTensor` message, or a
    /// `SeqError` left in `last_seq_error` after a message was handled. Only
    /// the latter is reported to the controller here (under `seq`, via
    /// `remote_function_failed`). Replay continues after an error; once all
    /// messages are processed, the error is written to every ref in `results`
    /// and to every ref the recording mutated.
    ///
    /// # Errors
    ///
    /// Returns an error if a recording is already active, if `recording` is
    /// unknown, if a formal or result index is out of range for `actuals` or
    /// `results`, or if handling a message or reporting to the controller
    /// returns an error.
    async fn call_recording(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        if self.active_recording.is_some() {
            bail!("cannot call recording while another recording is active");
        }

        // Take a copy of the recorded messages so that `self` can be borrowed
        // mutably while they are replayed.
        let messages = match self.recordings.get(&recording) {
            Some(recording) => recording
                .messages
                .iter()
                .map(|message| message.clone_for_recording())
                .collect::<Vec<_>>(),
            None => bail!("recording {:?} not found", recording),
        };

        self.active_recording = Some(RecordingState::Running);

        // Global error for all messages in the recording. The first time a message
        // fails in the recording, we set the error. We then need to propagate this
        // error to all of the refs mutated by the entire recording, as well as the
        // result refs.
        let mut error: Option<Arc<SeqError>> = None;
        // The set of all refs defined by this recording (excluding "results"),
        // which we need to ensure are deleted when the recording is done executing.
        let mut all_defined_refs = HashSet::new();
        // The set of all refs mutated by this recording. If there is an error with
        // any message, all of these refs need to have the correct error set.
        let mut all_mutated_refs = HashSet::new();
        // Map from the result ref of a RecordingFormal message to the associated
        // actual ref from "actuals". We need to track this in order to properly
        // handle recordings that mutate refs contained in "actuals" -- every
        // message in the recording that interacts with the recording inputs will
        // interact with the formal ref rather than the actual ref.
        let mut formal_to_actual_refs = HashMap::new();
        // Clear any pre-existing error left over from before the recording started.
        self.last_seq_error = None;
        for message in messages.into_iter() {
            let defined_refs = message.get_defined_refs();
            all_defined_refs.extend(defined_refs.clone());

            // Translate mutated formals back to the caller's actual refs, and
            // skip refs that only exist inside the recording.
            let mutated_refs_with_formals = message.get_mutated_refs();
            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
                match formal_to_actual_refs.get(ref_) {
                    Some(actual_ref) => Some(*actual_ref),
                    None => {
                        if all_defined_refs.contains(ref_) {
                            None
                        } else {
                            Some(*ref_)
                        }
                    }
                }
            }));

            match message {
                // Bind the formal ref to the caller-supplied actual.
                StreamMessage::RecordingFormal {
                    result: formal_ref,
                    argument_index,
                } => match actuals.get(argument_index) {
                    None => bail!("recording_formal called with too few arguments"),
                    Some(actual_ref) => {
                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
                        self.define_ref(formal_ref, *actual_ref)?;
                    }
                },
                // Bind the caller-supplied result ref to the recording's output.
                StreamMessage::RecordingResult {
                    result: result_ref,
                    output_index,
                } => match results.get(output_index) {
                    None => bail!("recording_result called with too few results"),
                    Some(actual_result_ref) => {
                        self.define_ref(*actual_result_ref, result_ref)?;
                    }
                },
                StreamMessage::DeleteRefs(ref refs) => {
                    // These refs are deleted by the message itself, so they no
                    // longer need the cleanup pass after the loop.
                    for ref_ in refs {
                        all_defined_refs.remove(ref_);
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                StreamMessage::CallFunction { .. } if error.is_some() => {
                    // CallFunction is expensive. If the recording already failed, then
                    // just update the necessary refs with the error. Most of the other
                    // message types need to run regardless because there are other actors
                    // that expect the call to happen (e.g., all of the borrow messages,
                    // pipe send/recv, send_tensor, reduce, etc.).
                    let error = error.clone().unwrap();
                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
                        self.env.insert(*ref_, Err(error.clone()));
                    }
                }
                StreamMessage::BorrowLastUse { ref result, .. } => {
                    // Exclude the borrow's result ref from the cleanup pass
                    // after the loop before handling the message.
                    all_defined_refs.remove(result);
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                StreamMessage::Reduce {
                    local_tensor,
                    ref out,
                    ..
                } => {
                    // Reduce doesn't propagate errors to the result ref, so we need
                    // to check for existing errors on the input tensors and set the
                    // recording's error if necessary.
                    if error.is_none() {
                        let inputs_to_check = [Some(local_tensor), out.clone()]
                            .iter()
                            .filter_map(|r| *r)
                            .collect::<Vec<_>>();
                        error = self.get_first_error(inputs_to_check.as_slice())?;
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                StreamMessage::SendTensor {
                    ref tensor,
                    ref to_rank,
                    ..
                } => {
                    // If this rank is sending a tensor (e.g., to_rank has a value),
                    // we need to check for existing errors on the input tensor, because
                    // the error is only propagated to the result ref when this rank
                    // is also receiving a tensor.
                    if to_rank.is_some() && error.is_none() {
                        error = self.get_first_error(&[*tensor])?;
                    }
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
                _ => {
                    StreamMessageHandler::handle(self, cx, message).await?;
                }
            };

            // It's not entirely trivial to determine whether a message "failed" or not.
            // For example, the CallFunction message can return Ok(..) if there is an error
            // in the underlying function call. But in that case, we would still want to
            // consider the recording call as "failed". Unlike in python, where we can just
            // wrap everything in try-except, in rust, we keep track of the last reported
            // SeqError, which we clear before handling each recording message. If we see
            // that it is set, then we know the recording has failed.
            match (&error, self.last_seq_error.take()) {
                (None, Some(seq_err)) => {
                    // Report failure to the controller.
                    self.controller_actor
                        .remote_function_failed(
                            cx,
                            seq,
                            WorkerError {
                                backtrace: format!("recording failed: {}", &seq_err),
                                worker_actor_id: cx.self_id().clone(),
                            },
                        )
                        .await?;
                    error = Some(seq_err)
                }
                _ => {}
            }
            // Continue processing the remaining stream messages regardless of error.
            // We need to do this partially for error propagation, but also because
            // certain messages (like borrows and reductions) need to run regardless
            // in order to prevent deadlocks.
        }

        // Delete the formal refs and some subset of the RecordingResult refs. The
        // controller should have generated DeleteRefs messages for all other refs
        // defined by the recording.
        StreamMessageHandler::handle(
            self,
            cx,
            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
        )
        .await?;

        // Any refs mutated by the recording and all results should have the same error
        // (the original error that caused the recording to fail).
        if error.is_some() {
            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
                self.env.insert(*ref_, Err(error.clone().unwrap()));
            }
        }

        self.active_recording = None;
        Ok(())
    }
2029
2030    async fn set_ref_unit_tests_only(
2031        &mut self,
2032        _cx: &Context<Self>,
2033        reference: Ref,
2034        value: WireValue,
2035    ) -> Result<()> {
2036        self.env
2037            .insert(reference, Ok(self.wire_to_rvalue(value).unwrap()));
2038        Ok(())
2039    }
2040
2041    async fn set_tensor_ref_unit_tests_only(
2042        &mut self,
2043        _cx: &Context<Self>,
2044        reference: Ref,
2045        tensor_result: TensorCellResult,
2046    ) -> Result<()> {
2047        match tensor_result {
2048            Ok(tensor_cell) => {
2049                self.env.insert(reference, Ok(RValue::Tensor(tensor_cell)));
2050            }
2051            Err(err) => {
2052                self.env.insert(reference, Err(err));
2053            }
2054        }
2055        Ok(())
2056    }
2057
2058    async fn get_ref_unit_tests_only(
2059        &mut self,
2060        _cx: &Context<Self>,
2061        reference: Ref,
2062    ) -> Result<Option<Result<WireValue, String>>> {
2063        /// For testing only, doesn't support Tensor or TensorList.
2064        fn rvalue_to_wire(
2065            value: Result<RValue, Arc<SeqError>>,
2066        ) -> Result<WireValue, Arc<SeqError>> {
2067            Ok(match value? {
2068                RValue::Int(val) => WireValue::Int(val),
2069                RValue::IntList(val) => WireValue::IntList(val),
2070                RValue::Double(val) => WireValue::Double(val),
2071                RValue::Bool(val) => WireValue::Bool(val),
2072                RValue::String(val) => WireValue::String(val),
2073                RValue::Layout(val) => WireValue::Layout(val),
2074                RValue::Device(val) => WireValue::Device(val),
2075                RValue::ScalarType(val) => WireValue::ScalarType(val),
2076                RValue::MemoryFormat(val) => WireValue::MemoryFormat(val),
2077                RValue::None => WireValue::None(()),
2078                other => WireValue::String(format!("unsupported rvalue type: {:?}", other)),
2079            })
2080        }
2081        Ok(self
2082            .env
2083            .get(&reference)
2084            .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string())))
2085    }
2086
2087    async fn get_tensor_ref_unit_tests_only(
2088        &mut self,
2089        _cx: &Context<Self>,
2090        reference: Ref,
2091    ) -> Result<Option<TensorCellResult>> {
2092        match self.env.get(&reference) {
2093            Some(Ok(rvalue)) => match rvalue {
2094                RValue::Tensor(tensor) => Ok(Some(Ok(tensor.clone().try_cpu().unwrap()))),
2095                other => bail!("expected tensor, got {:?}", other),
2096            },
2097            Some(Err(err)) => Ok(Some(Err(err.clone()))),
2098            None => Ok(None),
2099        }
2100    }
2101}
2102
2103#[cfg(test)]
2104mod tests {
2105    use hyperactor::actor::ActorStatus;
2106    use hyperactor::context;
2107    use hyperactor::supervision::ActorSupervisionEvent;
2108    use monarch_messages::controller::ControllerMessage;
2109    use monarch_messages::worker::StreamCreationMode;
2110    use monarch_types::PickledPyObject;
2111    use pyo3::IntoPyObjectExt;
2112    use timed_test::async_timed_test;
2113    use torch_sys::factory_float_tensor;
2114    use torch_sys::testing::allclose;
2115    use torch_sys_cuda::nccl::UniqueId;
2116
2117    use super::*;
2118    use crate::comm::CommParams;
2119    use crate::test_util;
2120
2121    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
2122        Arc::new(SeqError {
2123            seq: 0.into(),
2124            error: err,
2125        })
2126    }
2127
    /// Shared fixture for the stream actor tests in this module.
    struct TestSetup {
        /// Local proc on which the controller, the client, and the stream
        /// actor are created.
        proc: Proc,
        /// Handle to the `StreamActor` spawned under the name "stream".
        stream_actor: ActorHandle<StreamActor>,
        /// Client instance (created as "client") used as the caller context
        /// when sending messages to the stream actor.
        client: Instance<()>,
        // Unused, but necessary, because proc needs a supervision
        // port -- otherwise an actor failure will cause a crash.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        /// Receiver for the `ControllerMessage`s sent to the attached
        /// "controller" actor.
        controller_rx: PortReceiver<ControllerMessage>,
        /// Reference to the attached "controller" actor; a clone of it is
        /// given to the stream actor in its params.
        controller_actor: ActorRef<ControllerActor>,
        /// The ref that the next call to `next_ref` will return.
        next_ref: Ref,
    }
2140
2141    impl TestSetup {
2142        async fn new() -> Result<Self> {
2143            Self::new_with_world_size(1).await
2144        }
2145
2146        async fn new_with_world_size(world_size: usize) -> Result<Self> {
2147            test_util::test_setup()?;
2148
2149            let proc = Proc::local();
2150            let (_, controller_actor, controller_rx) =
2151                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
2152            let (client, _handle) = proc.instance("client")?;
2153            let (supervision_tx, supervision_rx) = client.open_port();
2154            proc.set_supervision_coordinator(supervision_tx)?;
2155            let stream_actor = proc
2156                .spawn::<StreamActor>(
2157                    "stream",
2158                    StreamParams {
2159                        world_size,
2160                        rank: 0,
2161                        creation_mode: StreamCreationMode::UseDefaultStream,
2162                        id: 0.into(),
2163                        device: Some(CudaDevice::new(0.into())),
2164                        controller_actor: controller_actor.clone(),
2165                        respond_with_python_message: false,
2166                    },
2167                )
2168                .await?;
2169
2170            Ok(Self {
2171                proc,
2172                stream_actor,
2173                client,
2174                supervision_rx,
2175                controller_rx,
2176                controller_actor,
2177                next_ref: 0.into(),
2178            })
2179        }
2180
2181        fn next_ref(&mut self) -> Ref {
2182            let ref_ = self.next_ref;
2183            self.next_ref = Ref {
2184                id: self.next_ref.id + 1,
2185            };
2186            ref_
2187        }
2188
2189        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2190            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".try_into().unwrap()));
2191            self.stream_actor
2192                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2193                .await
2194        }
2195
2196        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2197            let actual = self
2198                .stream_actor
2199                .get_tensor_ref_unit_tests_only(&self.client, reference)
2200                .await
2201                .unwrap()
2202                .unwrap()
2203                .unwrap();
2204
2205            let result = allclose(
2206                &factory_float_tensor(data, "cpu".try_into().unwrap()),
2207                &actual.borrow(),
2208            )
2209            .unwrap();
2210            // rustfmt-ignore
2211            result
2212        }
2213
2214        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2215            let result_error = self
2216                .stream_actor
2217                .get_tensor_ref_unit_tests_only(&self.client, reference)
2218                .await
2219                .unwrap()
2220                .unwrap()
2221                .unwrap_err();
2222
2223            assert!(Arc::ptr_eq(&result_error, &error));
2224        }
2225    }
2226
2227    async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2228        loop {
2229            let status = proc
2230                .ledger_snapshot()
2231                .roots
2232                .get(actor_id)
2233                .unwrap()
2234                .status
2235                .clone();
2236            if let ActorStatus::Failed(msg) = status {
2237                assert!(msg.contains(&expected_msg));
2238                break;
2239            } else {
2240                tokio::task::yield_now().await;
2241            }
2242        }
2243    }
2244
2245    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2246        for ref_ in refs {
2247            assert!(
2248                test_setup
2249                    .stream_actor
2250                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2251                    .await
2252                    .unwrap()
2253                    .is_none()
2254            );
2255        }
2256    }
2257
2258    async fn fetch_result(
2259        cx: &impl context::Actor,
2260        stream_actor: ActorHandle<StreamActor>,
2261        seq: Seq,
2262        reference: Ref,
2263    ) {
2264        let ref_to_send = Python::with_gil(|py| {
2265            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2266        });
2267
2268        stream_actor
2269            .send_value(
2270                cx,
2271                seq,
2272                stream_actor.actor_id().clone(),
2273                Vec::new(),
2274                None,
2275                vec![WireValue::PyObject(ref_to_send)],
2276                HashMap::new(),
2277                HashMap::new(),
2278                None,
2279            )
2280            .await
2281            .unwrap()
2282    }
2283
2284    async fn check_fetch_result_error(
2285        cx: &impl context::Actor,
2286        stream_actor: ActorHandle<StreamActor>,
2287        seq: Seq,
2288        reference: Ref,
2289        controller_rx: &mut PortReceiver<ControllerMessage>,
2290        expected_backtrace: &str,
2291    ) {
2292        fetch_result(cx, stream_actor, seq, reference).await;
2293
2294        let controller_msg = controller_rx.recv().await.unwrap();
2295        match controller_msg {
2296            ControllerMessage::FetchResult {
2297                seq: actual_seq,
2298                value: Err(err),
2299            } => {
2300                assert_eq!(actual_seq, seq);
2301                assert!(
2302                    err.backtrace.contains(expected_backtrace),
2303                    "backtrace did not contain {:?}: {:?}",
2304                    expected_backtrace,
2305                    err.backtrace
2306                );
2307            }
2308            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2309        };
2310    }
2311
2312    async fn check_fetch_result_value(
2313        cx: &impl context::Actor,
2314        stream_actor: ActorHandle<StreamActor>,
2315        seq: Seq,
2316        reference: Ref,
2317        controller_rx: &mut PortReceiver<ControllerMessage>,
2318    ) {
2319        fetch_result(cx, stream_actor, seq, reference).await;
2320
2321        let controller_msg = controller_rx.recv().await.unwrap();
2322        match controller_msg {
2323            ControllerMessage::FetchResult {
2324                value: Ok(_),
2325                seq: actual_seq,
2326            } => assert_eq!(seq, actual_seq),
2327            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2328        };
2329    }
2330
2331    #[async_timed_test(timeout_secs = 60)]
2332    async fn test_define_recording_other_recording_active() -> Result<()> {
2333        let test_setup = TestSetup::new().await?;
2334        test_setup
2335            .stream_actor
2336            .define_recording(&test_setup.client, 0.into())
2337            .await?;
2338        test_setup
2339            .stream_actor
2340            .define_recording(&test_setup.client, 1.into())
2341            .await?;
2342        assert_actor_failed_with_msg(
2343            &test_setup.proc,
2344            test_setup.stream_actor.actor_id(),
2345            "different recording already active".into(),
2346        )
2347        .await;
2348        Ok(())
2349    }
2350
2351    #[async_timed_test(timeout_secs = 60)]
2352    async fn test_define_recording_already_defined() -> Result<()> {
2353        let test_setup = TestSetup::new().await?;
2354        test_setup
2355            .stream_actor
2356            .define_recording(&test_setup.client, 0.into())
2357            .await?;
2358        test_setup
2359            .stream_actor
2360            .finalize_recording(&test_setup.client, 0.into())
2361            .await?;
2362        test_setup
2363            .stream_actor
2364            .define_recording(&test_setup.client, 0.into())
2365            .await?;
2366        assert_actor_failed_with_msg(
2367            &test_setup.proc,
2368            test_setup.stream_actor.actor_id(),
2369            "already defined".into(),
2370        )
2371        .await;
2372        Ok(())
2373    }
2374
2375    #[async_timed_test(timeout_secs = 60)]
2376    async fn test_finalize_recording_other_recording_active() -> Result<()> {
2377        let test_setup = TestSetup::new().await?;
2378        test_setup
2379            .stream_actor
2380            .define_recording(&test_setup.client, 0.into())
2381            .await?;
2382        test_setup
2383            .stream_actor
2384            .finalize_recording(&test_setup.client, 1.into())
2385            .await?;
2386        assert_actor_failed_with_msg(
2387            &test_setup.proc,
2388            test_setup.stream_actor.actor_id(),
2389            "cannot finalize recording that isn't active".into(),
2390        )
2391        .await;
2392        Ok(())
2393    }
2394
2395    #[async_timed_test(timeout_secs = 60)]
2396    async fn test_recording_formal_outside_recording() -> Result<()> {
2397        let test_setup = TestSetup::new().await?;
2398        test_setup
2399            .stream_actor
2400            .recording_formal(&test_setup.client, 0.into(), 0)
2401            .await?;
2402        assert_actor_failed_with_msg(
2403            &test_setup.proc,
2404            test_setup.stream_actor.actor_id(),
2405            "recording_formal called outside of recording".into(),
2406        )
2407        .await;
2408        Ok(())
2409    }
2410
2411    #[async_timed_test(timeout_secs = 60)]
2412    async fn test_recording_result_outside_recording() -> Result<()> {
2413        let test_setup = TestSetup::new().await?;
2414        test_setup
2415            .stream_actor
2416            .recording_result(&test_setup.client, 0.into(), 0)
2417            .await?;
2418        assert_actor_failed_with_msg(
2419            &test_setup.proc,
2420            test_setup.stream_actor.actor_id(),
2421            "recording_result called outside of recording".into(),
2422        )
2423        .await;
2424        Ok(())
2425    }
2426
2427    #[async_timed_test(timeout_secs = 60)]
2428    async fn test_call_recording_other_recording_active() -> Result<()> {
2429        let test_setup = TestSetup::new().await?;
2430        test_setup
2431            .stream_actor
2432            .define_recording(&test_setup.client, 0.into())
2433            .await?;
2434        test_setup
2435            .stream_actor
2436            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2437            .await?;
2438        assert_actor_failed_with_msg(
2439            &test_setup.proc,
2440            test_setup.stream_actor.actor_id(),
2441            "cannot call recording while another recording is active".into(),
2442        )
2443        .await;
2444        Ok(())
2445    }
2446
2447    #[async_timed_test(timeout_secs = 60)]
2448    async fn test_call_recording_not_found() -> Result<()> {
2449        let test_setup = TestSetup::new().await?;
2450        test_setup
2451            .stream_actor
2452            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2453            .await?;
2454        assert_actor_failed_with_msg(
2455            &test_setup.proc,
2456            test_setup.stream_actor.actor_id(),
2457            "not found".into(),
2458        )
2459        .await;
2460        Ok(())
2461    }
2462
2463    #[async_timed_test(timeout_secs = 60)]
2464    async fn test_recording_formal_too_few_arguments() -> Result<()> {
2465        let test_setup = TestSetup::new().await?;
2466
2467        test_setup
2468            .stream_actor
2469            .define_recording(&test_setup.client, 0.into())
2470            .await?;
2471
2472        test_setup
2473            .stream_actor
2474            .recording_formal(&test_setup.client, 1.into(), 0)
2475            .await?;
2476
2477        test_setup
2478            .stream_actor
2479            .finalize_recording(&test_setup.client, 0.into())
2480            .await?;
2481
2482        test_setup
2483            .stream_actor
2484            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2485            .await?;
2486
2487        assert_actor_failed_with_msg(
2488            &test_setup.proc,
2489            test_setup.stream_actor.actor_id(),
2490            "recording_formal called with too few arguments".into(),
2491        )
2492        .await;
2493        Ok(())
2494    }
2495
2496    #[async_timed_test(timeout_secs = 60)]
2497    async fn test_recording_result_too_few_results() -> Result<()> {
2498        let test_setup = TestSetup::new().await?;
2499
2500        test_setup
2501            .stream_actor
2502            .define_recording(&test_setup.client, 0.into())
2503            .await?;
2504
2505        test_setup
2506            .stream_actor
2507            .recording_result(&test_setup.client, 1.into(), 0)
2508            .await?;
2509
2510        test_setup
2511            .stream_actor
2512            .finalize_recording(&test_setup.client, 0.into())
2513            .await?;
2514
2515        test_setup
2516            .stream_actor
2517            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2518            .await?;
2519
2520        assert_actor_failed_with_msg(
2521            &test_setup.proc,
2522            test_setup.stream_actor.actor_id(),
2523            "recording_result called with too few results".into(),
2524        )
2525        .await;
2526        Ok(())
2527    }
2528
2529    #[async_timed_test(timeout_secs = 60)]
2530    async fn test_basic_call_recording() -> Result<()> {
2531        let mut test_setup = TestSetup::new().await?;
2532
2533        // Define a recording equivalent to:
2534        // def f(x, y):
2535        //   return y, x
2536        test_setup
2537            .stream_actor
2538            .define_recording(&test_setup.client, 0.into())
2539            .await?;
2540
2541        let formal0_ref = 1.into();
2542        let formal0_index = 1;
2543        test_setup
2544            .stream_actor
2545            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2546            .await?;
2547
2548        let formal1_ref = 2.into();
2549        let formal1_index = 0;
2550        test_setup
2551            .stream_actor
2552            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2553            .await?;
2554
2555        let result0_ref = formal0_ref;
2556        let result0_index = 0;
2557        test_setup
2558            .stream_actor
2559            .recording_result(&test_setup.client, result0_ref, result0_index)
2560            .await?;
2561
2562        let result1_ref = formal1_ref;
2563        let result1_index = 1;
2564        test_setup
2565            .stream_actor
2566            .recording_result(&test_setup.client, result1_ref, result1_index)
2567            .await?;
2568
2569        test_setup
2570            .stream_actor
2571            .finalize_recording(&test_setup.client, 0.into())
2572            .await?;
2573
2574        let actual0_ref = 3.into();
2575        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2576
2577        let actual1_ref = 4.into();
2578        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2579
2580        // Call the recording with valid tensors for the actual inputs,
2581        // and store the results in refs 5 and 6.
2582        let actual_result0_ref = 5.into();
2583        let actual_result1_ref = 6.into();
2584        test_setup
2585            .stream_actor
2586            .call_recording(
2587                &test_setup.client,
2588                0.into(),
2589                0.into(),
2590                vec![actual_result0_ref, actual_result1_ref],
2591                vec![actual0_ref, actual1_ref],
2592            )
2593            .await?;
2594
2595        // Ensure the results are correct.
2596        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2597        assert!(
2598            test_setup
2599                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2600                .await
2601        );
2602
2603        // Ensure the temporary refs associated with the formals/results have
2604        // been deleted.
2605        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2606        Ok(())
2607    }
2608
2609    #[async_timed_test(timeout_secs = 60)]
2610    async fn test_request_status_in_recording() -> Result<()> {
2611        let test_setup = TestSetup::new().await?;
2612        test_setup
2613            .stream_actor
2614            .define_recording(&test_setup.client, 0.into())
2615            .await?;
2616        test_setup
2617            .stream_actor
2618            .request_status(&test_setup.client)
2619            .await
2620            .expect_err("request_status should have failed");
2621        assert_actor_failed_with_msg(
2622            &test_setup.proc,
2623            test_setup.stream_actor.actor_id(),
2624            "request_status not allowed in recording".into(),
2625        )
2626        .await;
2627        Ok(())
2628    }
2629
2630    #[async_timed_test(timeout_secs = 60)]
2631    async fn test_init_comm_in_recording() -> Result<()> {
2632        let test_setup = TestSetup::new().await?;
2633        test_setup
2634            .stream_actor
2635            .define_recording(&test_setup.client, 0.into())
2636            .await?;
2637
2638        let dummy_comm = test_setup
2639            .proc
2640            .spawn::<NcclCommActor>(
2641                "comm",
2642                CommParams::New {
2643                    device: CudaDevice::new(0.into()),
2644                    unique_id: UniqueId::new()?,
2645                    world_size: 1,
2646                    rank: 0,
2647                },
2648            )
2649            .await?;
2650
2651        test_setup
2652            .stream_actor
2653            .init_comm(&test_setup.client, dummy_comm)
2654            .await?;
2655        assert_actor_failed_with_msg(
2656            &test_setup.proc,
2657            test_setup.stream_actor.actor_id(),
2658            "init_comm not allowed in recording".into(),
2659        )
2660        .await;
2661        Ok(())
2662    }
2663
2664    #[async_timed_test(timeout_secs = 60)]
2665    async fn test_call_function_in_recording() -> Result<()> {
2666        let mut test_setup = TestSetup::new().await?;
2667
2668        // Define a recording equivalent to:
2669        // def f(x, y):
2670        //   w = x + y
2671        //   nonlocal z
2672        //   z.add_(1.0)
2673        //   return w + z
2674        test_setup
2675            .stream_actor
2676            .define_recording(&test_setup.client, 0.into())
2677            .await?;
2678
2679        let formal0_ref = test_setup.next_ref();
2680        let formal0_index = 0;
2681        test_setup
2682            .stream_actor
2683            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2684            .await?;
2685
2686        let formal1_ref = test_setup.next_ref();
2687        let formal1_index = 1;
2688        test_setup
2689            .stream_actor
2690            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2691            .await?;
2692
2693        let captured_ref = test_setup.next_ref();
2694        let result_captured_ref = test_setup.next_ref();
2695        let add_one_function =
2696            ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into());
2697        let add_tensors_function =
2698            ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into());
2699
2700        let add_result_ref_0 = test_setup.next_ref();
2701        test_setup
2702            .stream_actor
2703            .call_function(
2704                &test_setup.client,
2705                CallFunctionParams {
2706                    seq: 100.into(),
2707                    function: add_tensors_function.clone(),
2708                    args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)],
2709                    kwargs: HashMap::new(),
2710                    results: vec![Some(add_result_ref_0)],
2711                    mutates: vec![],
2712                    stream: 0.into(),
2713                    remote_process_groups: Vec::new(),
2714                },
2715                HashMap::new(),
2716                HashMap::new(),
2717            )
2718            .await?;
2719
2720        test_setup
2721            .stream_actor
2722            .call_function(
2723                &test_setup.client,
2724                CallFunctionParams {
2725                    seq: 101.into(),
2726                    function: add_one_function,
2727                    args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)],
2728                    kwargs: HashMap::new(),
2729                    results: vec![Some(result_captured_ref)],
2730                    mutates: vec![captured_ref],
2731                    stream: 0.into(),
2732                    remote_process_groups: Vec::new(),
2733                },
2734                HashMap::new(),
2735                HashMap::new(),
2736            )
2737            .await?;
2738
2739        let add_result_ref_1 = test_setup.next_ref();
2740        test_setup
2741            .stream_actor
2742            .call_function(
2743                &test_setup.client,
2744                CallFunctionParams {
2745                    seq: 102.into(),
2746                    function: add_tensors_function,
2747                    args: vec![
2748                        WireValue::Ref(add_result_ref_0),
2749                        WireValue::Ref(captured_ref),
2750                    ],
2751                    kwargs: HashMap::new(),
2752                    results: vec![Some(add_result_ref_1)],
2753                    mutates: vec![],
2754                    stream: 0.into(),
2755                    remote_process_groups: Vec::new(),
2756                },
2757                HashMap::new(),
2758                HashMap::new(),
2759            )
2760            .await?;
2761
2762        test_setup
2763            .stream_actor
2764            .recording_result(&test_setup.client, add_result_ref_1, 0)
2765            .await?;
2766
2767        test_setup
2768            .stream_actor
2769            .delete_refs(
2770                &test_setup.client,
2771                vec![add_result_ref_0, add_result_ref_1, result_captured_ref],
2772            )
2773            .await?;
2774
2775        test_setup
2776            .stream_actor
2777            .finalize_recording(&test_setup.client, 0.into())
2778            .await?;
2779
2780        let actual0_ref = test_setup.next_ref();
2781        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2782
2783        let actual1_ref = test_setup.next_ref();
2784        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2785
2786        test_setup
2787            .set_tensor(captured_ref, &[7.0, 8.0, 9.0])
2788            .await?;
2789
2790        let actual_result_ref = test_setup.next_ref();
2791        test_setup
2792            .stream_actor
2793            .call_recording(
2794                &test_setup.client,
2795                0.into(),
2796                0.into(),
2797                vec![actual_result_ref],
2798                vec![actual0_ref, actual1_ref],
2799            )
2800            .await?;
2801
2802        assert!(
2803            test_setup
2804                .allclose(actual_result_ref, &[13.0, 16.0, 19.0])
2805                .await
2806        );
2807
2808        // Set actual1_tensor to a bad shape which will cause the recording to fail.
2809        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2810
2811        let actual_result_ref = test_setup.next_ref();
2812        test_setup
2813            .stream_actor
2814            .call_recording(
2815                &test_setup.client,
2816                1.into(),
2817                0.into(),
2818                vec![actual_result_ref],
2819                vec![actual0_ref, actual1_ref],
2820            )
2821            .await?;
2822
2823        // Both inputs should still be valid.
2824        for ref_ in [actual0_ref, actual1_ref] {
2825            let _ = test_setup
2826                .stream_actor
2827                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2828                .await?
2829                .unwrap()
2830                .unwrap();
2831        }
2832
2833        for ref_ in [captured_ref, actual_result_ref] {
2834            let result_error = test_setup
2835                .stream_actor
2836                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2837                .await?
2838                .unwrap()
2839                .unwrap_err();
2840            // Check that the error contains the expected strings
2841            let error_str = result_error.to_string();
2842            assert!(
2843                error_str.contains("torch operator error"),
2844                "Error should contain 'torch operator failed': {}",
2845                error_str
2846            );
2847        }
2848
2849        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
2850        match controller_msg {
2851            ControllerMessage::RemoteFunctionFailed { seq, error } => {
2852                assert_eq!(seq, 1.into());
2853                assert!(
2854                    error.backtrace.contains("torch operator error"),
2855                    "Unexpected WorkerError: {:?}",
2856                    error
2857                );
2858            }
2859            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2860        };
2861
2862        // Reset input tensor to a valid shape.
2863        test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?;
2864
2865        // captured_tensor should still have an error, so calling
2866        // the recording should set DependentErrors and not report
2867        // anything to the controller.
2868        let actual_result_ref = test_setup.next_ref();
2869        test_setup
2870            .stream_actor
2871            .call_recording(
2872                &test_setup.client,
2873                2.into(),
2874                0.into(),
2875                vec![actual_result_ref],
2876                vec![actual0_ref, actual1_ref],
2877            )
2878            .await?;
2879
2880        // Both inputs should still be valid.
2881        for ref_ in [actual0_ref, actual1_ref] {
2882            let _ = test_setup
2883                .stream_actor
2884                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2885                .await?
2886                .unwrap()
2887                .unwrap();
2888        }
2889
2890        for ref_ in [captured_ref, actual_result_ref] {
2891            let result_error = test_setup
2892                .stream_actor
2893                .get_tensor_ref_unit_tests_only(&test_setup.client, ref_)
2894                .await?
2895                .unwrap()
2896                .unwrap_err();
2897            // Check that the error contains the expected strings
2898            let error_str = result_error.to_string();
2899            assert!(
2900                error_str.contains("torch operator error"),
2901                "Error should contain input error: {}",
2902                error_str
2903            );
2904        }
2905
2906        // This tests that the DependentError was never reported to the controller.
2907        // If it were reported to the controller, the next message would match
2908        // RemoteFunctionFailed instead of FetchResult.
2909        check_fetch_result_error(
2910            &test_setup.client,
2911            test_setup.stream_actor.clone(),
2912            3.into(),
2913            captured_ref,
2914            &mut test_setup.controller_rx,
2915            "torch operator error",
2916        )
2917        .await;
2918
2919        Ok(())
2920    }
2921
2922    #[async_timed_test(timeout_secs = 60)]
2923    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2924        let mut test_setup = TestSetup::new().await?;
2925        test_setup
2926            .stream_actor
2927            .define_recording(&test_setup.client, 0.into())
2928            .await?;
2929
2930        let borrow_id = 1;
2931        let tensor_ref = test_setup.next_ref();
2932        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2933
2934        test_setup
2935            .stream_actor
2936            .borrow_create(
2937                &test_setup.client,
2938                borrow_id,
2939                tensor_ref,
2940                first_use_sender.clone(),
2941            )
2942            .await?;
2943
2944        test_setup
2945            .stream_actor
2946            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2947            .await?;
2948
2949        assert_actor_failed_with_msg(
2950            &test_setup.proc,
2951            test_setup.stream_actor.actor_id(),
2952            "duplicate borrow create in recording".into(),
2953        )
2954        .await;
2955
2956        Ok(())
2957    }
2958
2959    #[async_timed_test(timeout_secs = 60)]
2960    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2961        let test_setup = TestSetup::new().await?;
2962        test_setup
2963            .stream_actor
2964            .define_recording(&test_setup.client, 0.into())
2965            .await?;
2966
2967        let borrow_id = 1;
2968        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2969
2970        test_setup
2971            .stream_actor
2972            .borrow_drop(
2973                &test_setup.client,
2974                borrow_id,
2975                Arc::new(Mutex::new(last_use_receiver)),
2976            )
2977            .await?;
2978
2979        assert_actor_failed_with_msg(
2980            &test_setup.proc,
2981            test_setup.stream_actor.actor_id(),
2982            "borrow drop for borrow not defined in recording".into(),
2983        )
2984        .await;
2985
2986        Ok(())
2987    }
2988
2989    #[async_timed_test(timeout_secs = 60)]
2990    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2991        let mut test_setup = TestSetup::new().await?;
2992        test_setup
2993            .stream_actor
2994            .define_recording(&test_setup.client, 0.into())
2995            .await?;
2996
2997        let borrow_id = 1;
2998        let tensor_ref = test_setup.next_ref();
2999        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
3000
3001        test_setup
3002            .stream_actor
3003            .borrow_create(
3004                &test_setup.client,
3005                borrow_id,
3006                tensor_ref,
3007                first_use_sender.clone(),
3008            )
3009            .await?;
3010
3011        // Attempt to finalize the recording without dropping the borrow
3012        test_setup
3013            .stream_actor
3014            .finalize_recording(&test_setup.client, 0.into())
3015            .await?;
3016
3017        assert_actor_failed_with_msg(
3018            &test_setup.proc,
3019            test_setup.stream_actor.actor_id(),
3020            "all borrows created within recording must be dropped within recording".into(),
3021        )
3022        .await;
3023
3024        Ok(())
3025    }
3026
    // End-to-end test of a borrow captured in recordings on two streams. The
    // lender stream records `borrow_create`/`borrow_drop` on its formal
    // argument; the borrower stream records first use, an add with one of its
    // own tensors, and last use of the borrowed ref. Both recordings are then
    // replayed together three times:
    //   1. with valid tensors, checking the numeric result;
    //   2. with a mis-shaped borrower tensor, checking that only the borrower
    //      reports a failure to the controller;
    //   3. with an errored lender input, checking that the error propagates to
    //      the borrower's result without any failure report.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_borrow_in_recording() -> Result<()> {
        let mut test_setup = TestSetup::new().await?;

        // Second stream actor (id 1) on the same proc; it acts as the borrower.
        let borrower_stream = test_setup
            .proc
            .spawn::<StreamActor>(
                "stream1",
                StreamParams {
                    world_size: 1,
                    rank: 0,
                    creation_mode: StreamCreationMode::CreateNewStream,
                    id: 1.into(),
                    device: Some(CudaDevice::new(0.into())),
                    controller_actor: test_setup.controller_actor.clone(),
                    respond_with_python_message: false,
                },
            )
            .await?;

        // The stream actor owned by the test setup acts as the lender.
        let lender_stream = test_setup.stream_actor.clone();

        // Ports connecting the two streams: the lender gets the first-use
        // sender and last-use receiver, the borrower gets the other ends.
        let borrow_id = 1;
        let (first_use_sender, first_use_receiver) = test_setup.client.open_port();
        let (last_use_sender, last_use_receiver) = test_setup.client.open_port();

        // Lender stream: define a recording that creates a borrow and drops it.
        lender_stream
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let formal_ref = test_setup.next_ref();
        lender_stream
            .recording_formal(&test_setup.client, formal_ref, 0)
            .await?;

        lender_stream
            .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender)
            .await?;

        lender_stream
            .borrow_drop(
                &test_setup.client,
                borrow_id,
                Arc::new(Mutex::new(last_use_receiver)),
            )
            .await?;

        lender_stream
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        // Tensor owned by the borrower stream; it is the second operand of the
        // add recorded below.
        let borrower_tensor_ref = test_setup.next_ref();
        let borrower_tensor = TensorCell::new(factory_float_tensor(
            &[1.0, 2.0, 3.0],
            "cuda".try_into().unwrap(),
        ));

        borrower_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                borrower_tensor_ref,
                Ok(borrower_tensor.clone()),
            )
            .await?;

        // Borrower stream: define a recording that uses the borrow from the
        // lender stream.
        borrower_stream
            .define_recording(&test_setup.client, 0.into())
            .await?;

        let borrowed_ref = test_setup.next_ref();

        borrower_stream
            .borrow_first_use(
                &test_setup.client,
                borrow_id,
                borrowed_ref,
                Arc::new(Mutex::new(first_use_receiver)),
            )
            .await?;

        // result = borrowed + borrower_tensor
        let result_ref = test_setup.next_ref();
        borrower_stream
            .call_function(
                &test_setup.client,
                CallFunctionParams {
                    seq: 100.into(),
                    function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()),
                    args: vec![
                        WireValue::Ref(borrowed_ref),
                        WireValue::Ref(borrower_tensor_ref),
                    ],
                    kwargs: HashMap::new(),
                    results: vec![Some(result_ref)],
                    mutates: vec![],
                    stream: 1.into(),
                    remote_process_groups: Vec::new(),
                },
                HashMap::new(),
                HashMap::new(),
            )
            .await?;

        borrower_stream
            .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender)
            .await?;

        borrower_stream
            .recording_result(&test_setup.client, result_ref, 0)
            .await?;

        borrower_stream
            .finalize_recording(&test_setup.client, 0.into())
            .await?;

        // Set up a tensor in the lender stream and call the recording.
        let input_tensor_ref = test_setup.next_ref();
        test_setup
            .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0])
            .await?;

        let result_tensor_ref = test_setup.next_ref();

        // The lender's recording takes one actual and yields no results; the
        // borrower's recording takes no actuals and yields one result.
        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            0.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            0.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        // Await both call_recording requests together.
        tokio::try_join!(lender_future, borrower_future)?;

        let result_tensor = borrower_stream
            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
            .await?
            .unwrap()
            .unwrap();

        // [4, 5, 6] (lender input) + [1, 2, 3] (borrower tensor).
        let expected_tensor = TensorCell::new(factory_float_tensor(
            &[5.0, 7.0, 9.0],
            "cpu".try_into().unwrap(),
        ));
        assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap());

        // Set borrower_tensor to a tensor with only 2 elements to cause a failure.
        let invalid_borrower_tensor = TensorCell::new(factory_float_tensor(
            &[1.0, 2.0],
            "cuda".try_into().unwrap(),
        ));
        borrower_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                borrower_tensor_ref,
                Ok(invalid_borrower_tensor.clone()),
            )
            .await?;

        // Call the recording again.
        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            1.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            1.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        tokio::try_join!(lender_future, borrower_future)?;

        // Check that the borrower_stream reports the error to the controller.
        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
        match controller_msg {
            ControllerMessage::RemoteFunctionFailed { seq, error } => {
                assert_eq!(seq, 1.into());
                assert!(
                    error.backtrace.contains("recording failed"),
                    "Unexpected WorkerError: {:?}",
                    error
                );
                assert_eq!(&error.worker_actor_id, borrower_stream.actor_id());
            }
            _ => panic!("Unexpected controller message: {:?}", controller_msg),
        };

        // Check that no error was reported from the lender stream
        check_fetch_result_value(
            &test_setup.client,
            lender_stream.clone(),
            2.into(),
            input_tensor_ref,
            &mut test_setup.controller_rx,
        )
        .await;

        // Set the recording's input tensor to an error.
        let input_error = fake_seq_error(anyhow!("input error"));
        lender_stream
            .set_tensor_ref_unit_tests_only(
                &test_setup.client,
                input_tensor_ref,
                Err(input_error.clone()),
            )
            .await?;

        let lender_future = lender_stream.call_recording(
            &test_setup.client,
            3.into(),
            0.into(),
            vec![],
            vec![input_tensor_ref],
        );

        let borrower_future = borrower_stream.call_recording(
            &test_setup.client,
            3.into(),
            0.into(),
            vec![result_tensor_ref],
            vec![],
        );

        tokio::try_join!(lender_future, borrower_future)?;

        // Verify that borrower_stream sets a CallFunctionError::DependentError on result_tensor_ref.
        let result_error = borrower_stream
            .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref)
            .await?
            .unwrap()
            .unwrap_err();

        // Check that the error contains the expected strings
        let error_str = result_error.to_string();
        assert!(
            error_str.contains("input error"),
            "Error should contain input error: {}",
            error_str
        );

        // The propagated error is compared by message rather than by identity:
        // the rendered text of the original input error must appear in the
        // error stored on the borrower's result.
        let input_error_str = input_error.to_string();
        assert!(
            error_str.contains(&input_error_str),
            "Error should contain the original error: {}",
            error_str
        );

        // Verify that the lender stream did not send a failure message to the
        // controller: fetching its input ref is expected to yield the input
        // error as the next controller message.
        check_fetch_result_error(
            &test_setup.client,
            lender_stream,
            4.into(),
            input_tensor_ref,
            &mut test_setup.controller_rx,
            "input error",
        )
        .await;

        // Same check for the borrower stream, using its result ref.
        check_fetch_result_error(
            &test_setup.client,
            borrower_stream,
            5.into(),
            result_tensor_ref,
            &mut test_setup.controller_rx,
            "input error",
        )
        .await;

        Ok(())
    }
3314
3315    #[async_timed_test(timeout_secs = 60)]
3316    async fn test_reduce_in_recording() -> Result<()> {
3317        let mut test_setup = TestSetup::new().await?;
3318        let recording_ref = test_setup.next_ref();
3319
3320        let comm = Arc::new(
3321            test_setup
3322                .proc
3323                .spawn::<NcclCommActor>(
3324                    "comm",
3325                    CommParams::New {
3326                        device: CudaDevice::new(0.into()),
3327                        unique_id: UniqueId::new()?,
3328                        world_size: 1,
3329                        rank: 0,
3330                    },
3331                )
3332                .await?,
3333        );
3334
3335        let factory = Factory {
3336            size: vec![3],
3337            dtype: torch_sys::ScalarType::Float,
3338            layout: torch_sys::Layout::Strided,
3339            device: "cuda".try_into().unwrap(),
3340        };
3341
3342        let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum);
3343
3344        test_setup
3345            .stream_actor
3346            .define_recording(&test_setup.client, recording_ref)
3347            .await?;
3348
3349        let formal_tensor_ref_0 = test_setup.next_ref();
3350        let formal_tensor_ref_1 = test_setup.next_ref();
3351        let formal_tensor_ref_2 = test_setup.next_ref();
3352
3353        test_setup
3354            .stream_actor
3355            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3356            .await?;
3357        test_setup
3358            .stream_actor
3359            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3360            .await?;
3361        test_setup
3362            .stream_actor
3363            .recording_formal(&test_setup.client, formal_tensor_ref_2, 2)
3364            .await?;
3365
3366        let intermediate_tensor_ref_0 = test_setup.next_ref();
3367
3368        // Handle case with in_place = true.
3369        test_setup
3370            .stream_actor
3371            .reduce(
3372                &test_setup.client,
3373                comm.clone(),
3374                1,
3375                intermediate_tensor_ref_0,
3376                formal_tensor_ref_0,
3377                factory.clone(),
3378                reduction.clone(),
3379                false,
3380                true,
3381                None,
3382            )
3383            .await?;
3384
3385        // Handle case with in_place = false and out = None.
3386        let intermediate_tensor_ref_1 = test_setup.next_ref();
3387        test_setup
3388            .stream_actor
3389            .reduce(
3390                &test_setup.client,
3391                comm.clone(),
3392                1,
3393                intermediate_tensor_ref_1,
3394                formal_tensor_ref_1,
3395                factory.clone(),
3396                reduction.clone(),
3397                false,
3398                false,
3399                None,
3400            )
3401            .await?;
3402
3403        let intermediate_tensor_ref_2 = test_setup.next_ref();
3404
3405        // Third reduce call with out = formal_tensor_ref_2
3406        test_setup
3407            .stream_actor
3408            .reduce(
3409                &test_setup.client,
3410                comm.clone(),
3411                1,
3412                intermediate_tensor_ref_2,
3413                intermediate_tensor_ref_1,
3414                factory.clone(),
3415                reduction.clone(),
3416                false,
3417                false,
3418                Some(formal_tensor_ref_2),
3419            )
3420            .await?;
3421
3422        test_setup
3423            .stream_actor
3424            .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0)
3425            .await?;
3426
3427        test_setup
3428            .stream_actor
3429            .finalize_recording(&test_setup.client, recording_ref)
3430            .await?;
3431
3432        let input_tensor_ref_0 = test_setup.next_ref();
3433        let input_tensor_ref_1 = test_setup.next_ref();
3434        let input_tensor_ref_2 = test_setup.next_ref();
3435
3436        test_setup
3437            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3438            .await?;
3439
3440        test_setup
3441            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3442            .await?;
3443
3444        test_setup
3445            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3446            .await?;
3447
3448        let output_ref = test_setup.next_ref();
3449
3450        test_setup
3451            .stream_actor
3452            .call_recording(
3453                &test_setup.client,
3454                0.into(),
3455                recording_ref,
3456                vec![output_ref],
3457                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3458            )
3459            .await?;
3460
3461        // Validate that input_tensor_ref_0 is unchanged.
3462        assert!(
3463            test_setup
3464                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3465                .await
3466        );
3467        // All the other inputs/outputs should be equal to input 1
3468        for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] {
3469            assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await);
3470        }
3471
3472        // Set an error on input 0
3473        let input_error = fake_seq_error(anyhow!("input error"));
3474        test_setup
3475            .stream_actor
3476            .set_tensor_ref_unit_tests_only(
3477                &test_setup.client,
3478                input_tensor_ref_0,
3479                Err(input_error.clone()),
3480            )
3481            .await?;
3482
3483        test_setup
3484            .stream_actor
3485            .call_recording(
3486                &test_setup.client,
3487                1.into(),
3488                recording_ref,
3489                vec![output_ref],
3490                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3491            )
3492            .await?;
3493
3494        // Verify that input_tensor_ref_0, input_tensor_ref_2, and output_ref have a dependent error.
3495        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3496            test_setup
3497                .validate_dependent_error(ref_, input_error.clone())
3498                .await;
3499        }
3500
3501        // Verify that input_tensor_ref_1 is untouched.
3502        assert!(
3503            test_setup
3504                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3505                .await
3506        );
3507
3508        // Verify that no failure was reported to the controller.
3509        check_fetch_result_value(
3510            &test_setup.client,
3511            test_setup.stream_actor.clone(),
3512            2.into(),
3513            input_tensor_ref_1,
3514            &mut test_setup.controller_rx,
3515        )
3516        .await;
3517
3518        // Reset input tensors 0 and 2 to their original values
3519        test_setup
3520            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3521            .await?;
3522        test_setup
3523            .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0])
3524            .await?;
3525
3526        // Set an error on input tensor 1
3527        test_setup
3528            .stream_actor
3529            .set_tensor_ref_unit_tests_only(
3530                &test_setup.client,
3531                input_tensor_ref_1,
3532                Err(input_error.clone()),
3533            )
3534            .await?;
3535
3536        test_setup
3537            .stream_actor
3538            .call_recording(
3539                &test_setup.client,
3540                3.into(),
3541                recording_ref,
3542                vec![output_ref],
3543                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3544            )
3545            .await?;
3546
3547        // Validate that the mutated inputs and the output have a dependent error containing
3548        // the input error
3549        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3550            test_setup
3551                .validate_dependent_error(ref_, input_error.clone())
3552                .await;
3553        }
3554
3555        // Validate that no error was reported to the controller
3556        check_fetch_result_error(
3557            &test_setup.client,
3558            test_setup.stream_actor.clone(),
3559            4.into(),
3560            input_tensor_ref_1,
3561            &mut test_setup.controller_rx,
3562            "input error",
3563        )
3564        .await;
3565
3566        // Reset input tensors 0 and 1 to their original values
3567        test_setup
3568            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3569            .await?;
3570        test_setup
3571            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3572            .await?;
3573
3574        // Set an error on input tensor 2
3575        test_setup
3576            .stream_actor
3577            .set_tensor_ref_unit_tests_only(
3578                &test_setup.client,
3579                input_tensor_ref_2,
3580                Err(input_error.clone()),
3581            )
3582            .await?;
3583
3584        test_setup
3585            .stream_actor
3586            .call_recording(
3587                &test_setup.client,
3588                5.into(),
3589                recording_ref,
3590                vec![output_ref],
3591                vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2],
3592            )
3593            .await?;
3594
3595        // Validate that input tensor 1 has its original values
3596        assert!(
3597            test_setup
3598                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3599                .await
3600        );
3601
3602        // Validate that the mutated inputs and the output have a dependent error containing
3603        // the input error
3604        for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] {
3605            test_setup
3606                .validate_dependent_error(ref_, input_error.clone())
3607                .await;
3608        }
3609
3610        // Validate that no error was reported to the controller
3611        check_fetch_result_value(
3612            &test_setup.client,
3613            test_setup.stream_actor.clone(),
3614            6.into(),
3615            input_tensor_ref_1,
3616            &mut test_setup.controller_rx,
3617        )
3618        .await;
3619
3620        Ok(())
3621    }
3622
3623    #[async_timed_test(timeout_secs = 60)]
3624    async fn test_send_tensor_in_recording() -> Result<()> {
3625        let mut test_setup = TestSetup::new_with_world_size(2).await?;
3626        let recording_ref = test_setup.next_ref();
3627
3628        let unique_id = UniqueId::new()?;
3629        let comm0 = test_setup.proc.spawn::<NcclCommActor>(
3630            "comm0",
3631            CommParams::New {
3632                device: CudaDevice::new(0.into()),
3633                unique_id: unique_id.clone(),
3634                world_size: 2,
3635                rank: 0,
3636            },
3637        );
3638        let comm1 = test_setup.proc.spawn::<NcclCommActor>(
3639            "comm1",
3640            CommParams::New {
3641                device: CudaDevice::new(1.into()),
3642                unique_id,
3643                world_size: 2,
3644                rank: 1,
3645            },
3646        );
3647        let (comm0, comm1) = tokio::try_join!(comm0, comm1)?;
3648        let comm0 = Arc::new(comm0);
3649        let comm1 = Arc::new(comm1);
3650
3651        let factory = Factory {
3652            size: vec![3],
3653            dtype: torch_sys::ScalarType::Float,
3654            layout: torch_sys::Layout::Strided,
3655            device: "cuda".try_into().unwrap(),
3656        };
3657
3658        let send_stream = test_setup.stream_actor.clone();
3659        let recv_stream = test_setup
3660            .proc
3661            .spawn::<StreamActor>(
3662                "recv_stream",
3663                StreamParams {
3664                    world_size: 2,
3665                    rank: 1,
3666                    creation_mode: StreamCreationMode::CreateNewStream,
3667                    id: 1.into(),
3668                    device: Some(CudaDevice::new(1.into())),
3669                    controller_actor: test_setup.controller_actor.clone(),
3670                    respond_with_python_message: false,
3671                },
3672            )
3673            .await?;
3674
3675        send_stream
3676            .define_recording(&test_setup.client, recording_ref)
3677            .await?;
3678        recv_stream
3679            .define_recording(&test_setup.client, recording_ref)
3680            .await?;
3681
3682        let formal_tensor_ref_0 = test_setup.next_ref();
3683        let formal_tensor_ref_1 = test_setup.next_ref();
3684
3685        send_stream
3686            .recording_formal(&test_setup.client, formal_tensor_ref_0, 0)
3687            .await?;
3688        send_stream
3689            .recording_formal(&test_setup.client, formal_tensor_ref_1, 1)
3690            .await?;
3691
3692        let _ref = test_setup.next_ref();
3693        send_stream
3694            .send_tensor(
3695                &test_setup.client,
3696                _ref,
3697                None,
3698                Some(1),
3699                formal_tensor_ref_0,
3700                factory.clone(),
3701                comm0.clone(),
3702            )
3703            .await?;
3704
3705        let result_ref_0 = test_setup.next_ref();
3706        let _ref = test_setup.next_ref();
3707        recv_stream
3708            .send_tensor(
3709                &test_setup.client,
3710                result_ref_0,
3711                Some(0),
3712                None,
3713                _ref,
3714                factory.clone(),
3715                comm1,
3716            )
3717            .await?;
3718
3719        let result_ref_1 = test_setup.next_ref();
3720        send_stream
3721            .send_tensor(
3722                &test_setup.client,
3723                result_ref_1,
3724                Some(0),
3725                Some(0),
3726                formal_tensor_ref_1,
3727                factory.clone(),
3728                comm0,
3729            )
3730            .await?;
3731
3732        send_stream
3733            .recording_result(&test_setup.client, result_ref_1, 0)
3734            .await?;
3735        recv_stream
3736            .recording_result(&test_setup.client, result_ref_0, 0)
3737            .await?;
3738
3739        send_stream
3740            .finalize_recording(&test_setup.client, recording_ref)
3741            .await?;
3742        recv_stream
3743            .finalize_recording(&test_setup.client, recording_ref)
3744            .await?;
3745
3746        let input_tensor_ref_0 = test_setup.next_ref();
3747        let input_tensor_ref_1 = test_setup.next_ref();
3748        test_setup
3749            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3750            .await?;
3751        test_setup
3752            .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3753            .await?;
3754
3755        let actual_result_ref_0 = test_setup.next_ref();
3756        let actual_result_ref_1 = test_setup.next_ref();
3757        let send_fut = send_stream.call_recording(
3758            &test_setup.client,
3759            0.into(),
3760            recording_ref,
3761            vec![actual_result_ref_1],
3762            vec![input_tensor_ref_0, input_tensor_ref_1],
3763        );
3764        let recv_fut = recv_stream.call_recording(
3765            &test_setup.client,
3766            0.into(),
3767            recording_ref,
3768            vec![actual_result_ref_0],
3769            vec![],
3770        );
3771        tokio::try_join!(send_fut, recv_fut)?;
3772
3773        assert!(
3774            test_setup
3775                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3776                .await
3777        );
3778        assert!(
3779            test_setup
3780                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3781                .await
3782        );
3783        assert!(
3784            test_setup
3785                .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0])
3786                .await
3787        );
3788
3789        let actual_result_0 = recv_stream
3790            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3791            .await
3792            .unwrap()
3793            .unwrap()
3794            .unwrap();
3795        assert!(allclose(
3796            &actual_result_0.borrow(),
3797            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3798        )?);
3799
3800        // Validate that failure wasn't reported to controller.
3801        check_fetch_result_value(
3802            &test_setup.client,
3803            send_stream.clone(),
3804            1.into(),
3805            actual_result_ref_1,
3806            &mut test_setup.controller_rx,
3807        )
3808        .await;
3809        check_fetch_result_value(
3810            &test_setup.client,
3811            recv_stream.clone(),
3812            2.into(),
3813            actual_result_ref_0,
3814            &mut test_setup.controller_rx,
3815        )
3816        .await;
3817
3818        let input_error = fake_seq_error(anyhow!("input error"));
3819        send_stream
3820            .set_tensor_ref_unit_tests_only(
3821                &test_setup.client,
3822                input_tensor_ref_0,
3823                Err(input_error.clone()),
3824            )
3825            .await?;
3826
3827        let send_fut = send_stream.call_recording(
3828            &test_setup.client,
3829            3.into(),
3830            recording_ref,
3831            vec![actual_result_ref_1],
3832            vec![input_tensor_ref_0, input_tensor_ref_1],
3833        );
3834        let recv_fut = recv_stream.call_recording(
3835            &test_setup.client,
3836            3.into(),
3837            recording_ref,
3838            vec![actual_result_ref_0],
3839            vec![],
3840        );
3841        tokio::try_join!(send_fut, recv_fut)?;
3842
3843        // The result on recv_stream should have a value, but it will be garbage.
3844        let _ = recv_stream
3845            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3846            .await
3847            .unwrap()
3848            .unwrap()
3849            .unwrap();
3850
3851        test_setup
3852            .validate_dependent_error(actual_result_ref_1, input_error.clone())
3853            .await;
3854
3855        // Input 1 should be untouched.
3856        assert!(
3857            test_setup
3858                .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0])
3859                .await
3860        );
3861
3862        // Validate that failure wasn't reported to controller.
3863        check_fetch_result_error(
3864            &test_setup.client,
3865            send_stream.clone(),
3866            4.into(),
3867            actual_result_ref_1,
3868            &mut test_setup.controller_rx,
3869            "input error",
3870        )
3871        .await;
3872        check_fetch_result_value(
3873            &test_setup.client,
3874            recv_stream.clone(),
3875            5.into(),
3876            actual_result_ref_0,
3877            &mut test_setup.controller_rx,
3878        )
3879        .await;
3880
3881        test_setup
3882            .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3883            .await?;
3884        send_stream
3885            .set_tensor_ref_unit_tests_only(
3886                &test_setup.client,
3887                input_tensor_ref_1,
3888                Err(input_error.clone()),
3889            )
3890            .await?;
3891
3892        let send_fut = send_stream.call_recording(
3893            &test_setup.client,
3894            6.into(),
3895            recording_ref,
3896            vec![actual_result_ref_1],
3897            vec![input_tensor_ref_0, input_tensor_ref_1],
3898        );
3899        let recv_fut = recv_stream.call_recording(
3900            &test_setup.client,
3901            6.into(),
3902            recording_ref,
3903            vec![actual_result_ref_0],
3904            vec![],
3905        );
3906        tokio::try_join!(send_fut, recv_fut)?;
3907
3908        let actual_result_0 = recv_stream
3909            .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0)
3910            .await
3911            .unwrap()
3912            .unwrap()
3913            .unwrap();
3914        assert!(allclose(
3915            &actual_result_0.borrow(),
3916            &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap())
3917        )?);
3918
3919        assert!(
3920            test_setup
3921                .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0])
3922                .await
3923        );
3924
3925        test_setup
3926            .validate_dependent_error(actual_result_ref_1, input_error)
3927            .await;
3928
3929        // Validate that failure wasn't reported to controller.
3930        check_fetch_result_error(
3931            &test_setup.client,
3932            send_stream.clone(),
3933            7.into(),
3934            actual_result_ref_1,
3935            &mut test_setup.controller_rx,
3936            "input error",
3937        )
3938        .await;
3939        check_fetch_result_value(
3940            &test_setup.client,
3941            recv_stream.clone(),
3942            8.into(),
3943            actual_result_ref_0,
3944            &mut test_setup.controller_rx,
3945        )
3946        .await;
3947
3948        Ok(())
3949    }
3950
3951    #[async_timed_test(timeout_secs = 60)]
3952    async fn test_set_value_in_recording_valid_pipe() -> Result<()> {
3953        let mut test_setup = TestSetup::new().await?;
3954
3955        let (pipe_tx, mut pipe_rx) = test_setup.client.open_port();
3956
3957        let recording_ref = test_setup.next_ref();
3958        test_setup
3959            .stream_actor
3960            .define_recording(&test_setup.client, recording_ref)
3961            .await?;
3962
3963        let result_ref_0 = test_setup.next_ref();
3964
3965        test_setup
3966            .stream_actor
3967            .set_value(
3968                &test_setup.client,
3969                0.into(),
3970                vec![Some(result_ref_0)],
3971                pipe_tx,
3972            )
3973            .await?;
3974
3975        test_setup
3976            .stream_actor
3977            .recording_result(&test_setup.client, result_ref_0, 0)
3978            .await?;
3979
3980        test_setup
3981            .stream_actor
3982            .finalize_recording(&test_setup.client, recording_ref)
3983            .await?;
3984
3985        let real_result_ref = test_setup.next_ref();
3986        let recording_fut = test_setup.stream_actor.call_recording(
3987            &test_setup.client,
3988            0.into(),
3989            recording_ref,
3990            vec![real_result_ref],
3991            vec![],
3992        );
3993
3994        let pipe_fut = async {
3995            let msg = pipe_rx.recv().await.unwrap();
3996            match msg {
3997                PipeMessage::RecvValue(tx) => {
3998                    tx.send(PyTree::from(RValue::Tensor(TensorCell::new(
3999                        factory_float_tensor(&[1.0, 2.0, 3.0], "cuda".try_into().unwrap()),
4000                    ))))
4001                    .unwrap();
4002                }
4003                _ => panic!("Unexpected message"),
4004            }
4005            Ok(())
4006        };
4007
4008        tokio::try_join!(recording_fut, pipe_fut)?;
4009
4010        assert!(test_setup.allclose(real_result_ref, &[1.0, 2.0, 3.0]).await);
4011
4012        // This will cause the next call to set_value to fail.
4013        drop(pipe_rx);
4014
4015        let real_result_ref = test_setup.next_ref();
4016        test_setup
4017            .stream_actor
4018            .call_recording(
4019                &test_setup.client,
4020                1.into(),
4021                recording_ref,
4022                vec![real_result_ref],
4023                vec![],
4024            )
4025            .await?;
4026
4027        let real_result_err = test_setup
4028            .stream_actor
4029            .get_tensor_ref_unit_tests_only(&test_setup.client, real_result_ref)
4030            .await?
4031            .unwrap()
4032            .unwrap_err();
4033        // Check that the error contains the expected string
4034        let error_str = real_result_err.to_string();
4035        assert!(
4036            error_str.contains("send error"),
4037            "Error should contain 'send error': {}",
4038            error_str
4039        );
4040
4041        let controller_msg = test_setup.controller_rx.recv().await.unwrap();
4042        match controller_msg {
4043            ControllerMessage::RemoteFunctionFailed { seq, error } => {
4044                assert_eq!(seq, 1.into());
4045                assert!(
4046                    error.backtrace.contains("send error"),
4047                    "Unexpected WorkerError: {:?}",
4048                    error
4049                );
4050            }
4051            _ => panic!("Unexpected controller message: {:?}", controller_msg),
4052        };
4053
4054        Ok(())
4055    }
4056}