Skip to main content

monarch_tensor_worker/
stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Context as _;
19use anyhow::Result;
20use anyhow::anyhow;
21use anyhow::bail;
22use anyhow::ensure;
23use async_trait::async_trait;
24use hyperactor as reference;
25use hyperactor::Actor;
26use hyperactor::Context;
27use hyperactor::Endpoint as _;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortHandle;
32use hyperactor::actor::ActorHandle;
33use hyperactor::handle;
34use hyperactor::id::Label;
35use hyperactor::mailbox::OncePortHandle;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::proc::Proc;
38use monarch_hyperactor::actor::PythonMessage;
39use monarch_hyperactor::actor::PythonMessageKind;
40use monarch_hyperactor::local_state_broker::BrokerId;
41use monarch_hyperactor::local_state_broker::LocalState;
42use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
43use monarch_hyperactor::pickle::pickle;
44use monarch_messages::controller::ControllerMessageClient;
45use monarch_messages::controller::Seq;
46use monarch_messages::controller::WorkerError;
47use monarch_messages::worker::ActorCallParams;
48use monarch_messages::worker::ActorMethodParams;
49use monarch_messages::worker::ArgsKwargs;
50use monarch_messages::worker::CallFunctionError;
51use monarch_messages::worker::CallFunctionParams;
52use monarch_messages::worker::SeqError;
53use monarch_messages::worker::StreamRef;
54use monarch_types::PyTree;
55use monarch_types::SerializablePyErr;
56use monarch_types::TryIntoPyObjectUnsafe;
57use pyo3::prelude::*;
58use tokio::runtime::Handle;
59use tokio::sync::Mutex;
60use tokio::task::JoinHandle;
61use torch_sys_cuda::cuda::Event;
62use torch_sys_cuda::cuda::Stream;
63use torch_sys2::CloneUnsafe;
64use torch_sys2::CudaDevice;
65use torch_sys2::TensorCell;
66use torch_sys2::deep_clone;
67use torch_sys2::factory_empty;
68use torch_sys2::factory_zeros;
69use tracing_subscriber::fmt::Subscriber;
70use typeuri::Named;
71
72use crate::ControllerActor;
73use crate::DeviceMesh;
74use crate::Factory;
75use crate::Reduction;
76use crate::Ref;
77use crate::ResolvableFunction;
78use crate::StreamCreationMode;
79use crate::WireValue;
80use crate::comm::CommMessage;
81use crate::comm::CommMessageClient;
82use crate::comm::NcclCommActor;
83
84pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
85
86// These thread locals are accessed by the python runtime for debugging sessions.
87thread_local! {
88    pub static CONTROLLER_ACTOR_REF: OnceCell<reference::ActorRef<ControllerActor>> = const { OnceCell::new() };
89    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
90    pub static ROOT_ACTOR_ID: OnceCell<reference::ActorId> = const { OnceCell::new() };
91}
92
93fn pickle_python_result(
94    py: Python<'_>,
95    result: Bound<'_, PyAny>,
96    worker_rank: usize,
97) -> Result<PythonMessage, anyhow::Error> {
98    let mut state = pickle(py, result.unbind(), false, false)
99        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
100    let inner = state
101        .take_inner()
102        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?;
103    Ok(PythonMessage::new_from_buf(
104        PythonMessageKind::Result {
105            rank: Some(worker_rank),
106        },
107        inner.take_buffer(),
108    ))
109}
110
111#[derive(Debug)]
112struct Recording {
113    messages: Vec<StreamMessage>,
114}
115
116impl Recording {
117    fn new() -> Self {
118        Self {
119            messages: Vec::new(),
120        }
121    }
122}
123
124#[derive(Debug, PartialEq)]
125enum RecordingState {
126    Defining {
127        recording: Ref,
128        // Set of borrow ids used to track proper borrow usage inside
129        // a recording.
130        defined_borrows: HashSet<u64>,
131    },
132    Running,
133}
134
135/// Messages handled by the stream. Generally these are stream-local versions of
136/// [`crate::WorkerMessage`].
137#[derive(Handler, HandleClient, Debug, Named)]
138pub enum StreamMessage {
139    CallFunction(
140        CallFunctionParams,
141        HashMap<Ref, DeviceMesh>,
142        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
143    ),
144
145    BorrowCreate {
146        /// Id for the borrow.
147        borrow: u64,
148        /// Tensor to borrow.
149        tensor: Ref,
150        /// Port for sending the first use CUDA event + borrowed tensor to
151        /// the borrower.
152        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
153    },
154
155    BorrowFirstUse {
156        /// Id for the borrow.
157        borrow: u64,
158        /// Ref for storing the borrowed tensor.
159        result: Ref,
160        /// Port for receiving the first use CUDA event + borrowed tensor from
161        /// the provider stream.
162        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
163    },
164
165    BorrowLastUse {
166        /// Id for the borrow.
167        borrow: u64,
168        /// Ref for the borrowed tensor.
169        result: Ref,
170        /// Port for sending the last use CUDA event and borrowed tensor.
171        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
172    },
173
174    BorrowDrop {
175        borrow: u64,
176        /// Port for receiving the last use CUDA event and borrowed tensor.
177        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
178    },
179
180    DeleteRefs(Vec<Ref>),
181
182    RequestStatus(#[reply] OncePortHandle<()>),
183
184    InitComm(ActorHandle<NcclCommActor>),
185
186    Reduce {
187        comm: Arc<ActorHandle<NcclCommActor>>,
188        dim_size: i64,
189        result: Ref,
190        local_tensor: Ref,
191        factory: Factory,
192        reduction: Reduction,
193        scatter: bool,
194        in_place: bool,
195        out: Option<Ref>,
196    },
197
198    SendTensor {
199        result: Ref,
200        from_rank: Option<usize>,
201        to_rank: Option<usize>,
202        tensor: Ref,
203        factory: Factory,
204        comm: Option<Arc<ActorHandle<NcclCommActor>>>,
205    },
206
207    SendValue {
208        seq: Seq,
209        worker_actor_id: reference::ActorAddr,
210        mutates: Vec<Ref>,
211        function: Option<ResolvableFunction>,
212        args_kwargs: ArgsKwargs,
213        device_meshes: HashMap<Ref, DeviceMesh>,
214    },
215
216    DefineRecording {
217        recording: Ref,
218    },
219
220    FinalizeRecording {
221        recording: Ref,
222    },
223
224    CallRecording {
225        seq: Seq,
226        recording: Ref,
227        results: Vec<Ref>,
228        actuals: Vec<Ref>,
229    },
230
231    RecordingFormal {
232        result: Ref,
233        argument_index: usize,
234    },
235
236    RecordingResult {
237        result: Ref,
238        output_index: usize,
239    },
240
241    SetRefUnitTestsOnly(Ref, WireValue),
242
243    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),
244
245    GetRefUnitTestsOnly(
246        Ref, // value
247        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
248    ),
249
250    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
251
252    SendResultOfActorCall(ActorCallParams),
253    CallActorMethod(ActorMethodParams),
254}
255
256impl StreamMessage {
257    fn clone_for_recording(&self) -> Self {
258        match self {
259            StreamMessage::RecordingFormal {
260                result,
261                argument_index,
262            } => StreamMessage::RecordingFormal {
263                result: *result,
264                argument_index: *argument_index,
265            },
266            StreamMessage::RecordingResult {
267                result,
268                output_index,
269            } => StreamMessage::RecordingResult {
270                result: *result,
271                output_index: *output_index,
272            },
273            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
274            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
275                StreamMessage::CallFunction(
276                    params.clone(),
277                    device_meshes.clone(),
278                    remote_process_groups.clone(),
279                )
280            }
281            StreamMessage::BorrowCreate {
282                borrow,
283                tensor,
284                first_use_sender,
285            } => StreamMessage::BorrowCreate {
286                borrow: *borrow,
287                tensor: *tensor,
288                first_use_sender: first_use_sender.clone(),
289            },
290            StreamMessage::BorrowFirstUse {
291                borrow,
292                result,
293                first_use_receiver,
294            } => StreamMessage::BorrowFirstUse {
295                borrow: *borrow,
296                result: *result,
297                first_use_receiver: first_use_receiver.clone(),
298            },
299            StreamMessage::BorrowLastUse {
300                borrow,
301                result,
302                last_use_sender,
303            } => StreamMessage::BorrowLastUse {
304                borrow: *borrow,
305                result: *result,
306                last_use_sender: last_use_sender.clone(),
307            },
308            StreamMessage::BorrowDrop {
309                borrow,
310                last_use_receiver,
311            } => StreamMessage::BorrowDrop {
312                borrow: *borrow,
313                last_use_receiver: last_use_receiver.clone(),
314            },
315            StreamMessage::Reduce {
316                comm,
317                dim_size,
318                result,
319                local_tensor,
320                factory,
321                reduction,
322                scatter,
323                in_place,
324                out,
325            } => StreamMessage::Reduce {
326                comm: comm.clone(),
327                dim_size: *dim_size,
328                result: *result,
329                local_tensor: *local_tensor,
330                factory: factory.clone(),
331                reduction: reduction.clone(),
332                scatter: *scatter,
333                in_place: *in_place,
334                out: out.clone(),
335            },
336            StreamMessage::SendTensor {
337                result,
338                from_rank,
339                to_rank,
340                tensor,
341                factory,
342                comm,
343            } => StreamMessage::SendTensor {
344                result: *result,
345                from_rank: *from_rank,
346                to_rank: *to_rank,
347                tensor: *tensor,
348                factory: factory.clone(),
349                comm: comm.clone(),
350            },
351            other => panic!(
352                "StreamMessage variant not supported in recording: {:?}",
353                other
354            ),
355        }
356    }
357
358    // Get the set of refs that this message defines.
359    fn get_defined_refs(&self) -> HashSet<Ref> {
360        match self {
361            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
362            StreamMessage::CallFunction(params, ..) => {
363                params.results.iter().filter_map(|&ref_| ref_).collect()
364            }
365            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
366            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
367            StreamMessage::SendTensor {
368                result, from_rank, ..
369            } => {
370                if from_rank.is_some() {
371                    HashSet::from([*result])
372                } else {
373                    HashSet::new()
374                }
375            }
376            // TODO(slurye): Add SendValue eventually.
377            _ => HashSet::new(),
378        }
379    }
380
381    // Get the set of refs that this message mutates.
382    fn get_mutated_refs(&self) -> HashSet<Ref> {
383        match self {
384            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
385            StreamMessage::Reduce {
386                out,
387                in_place,
388                local_tensor,
389                ..
390            } => {
391                if *in_place {
392                    HashSet::from([*local_tensor])
393                } else if let Some(out) = out {
394                    HashSet::from([*out])
395                } else {
396                    HashSet::new()
397                }
398            }
399            // TODO(slurye): Add SendValue eventually.
400            _ => HashSet::new(),
401        }
402    }
403}
404
405/// A stream represents a linear sequence of execution. Operations on different
406/// streams can execute concurrently.
407///
408/// For CUDA operators, streams will invoke the corresponding stream management
409/// APIs to perform synchronization.
410///
411/// For CPU operators, streams will just execute synchronously on their own OS
412/// thread.
413#[derive(Debug)]
414pub struct StreamActor {
415    _world_size: usize,
416    rank: usize,
417    /// Mapping of refs in the controller environment to TensorIndex in this
418    /// stream's local environment.
419    // TODO(agallagher): Use `ValueError` as the error type.
420    env: HashMap<Ref, Result<Py<PyAny>, Arc<SeqError>>>,
421    /// How to create the stream.
422    creation_mode: StreamCreationMode,
423    /// CUDA stream that this actor will enqueue operations on. None if "device"
424    /// is not a CUDA device.
425    /// NOTE: We lazily create the stream, so that we do it from the dedicated
426    /// Stream OS thread as, otherwise, we see deadlocks when done from
427    /// unexpected threads.
428    cuda_stream: OnceLock<Option<Stream>>,
429    /// Device this stream should be scheduled on.
430    device: Option<CudaDevice>,
431    /// Communicator for this stream. Optional as we lazily initialize it.
432    comm: Option<ActorHandle<NcclCommActor>>,
433    /// Actor ref of the controller that created this stream.
434    controller_actor: reference::ActorRef<ControllerActor>,
435    remote_process_groups: HashMap<Ref, Py<PyAny>>,
436    recordings: HashMap<Ref, Recording>,
437    active_recording: Option<RecordingState>,
438    respond_with_python_message: bool,
439    last_seq_error: Option<Arc<SeqError>>,
440}
441
442/// Parameters for creating a [`Stream`].
443#[derive(Debug, Clone)]
444pub struct StreamParams {
445    pub world_size: usize,
446    pub rank: usize,
447    /// Controls how the underlying CUDA stream is created.
448    pub creation_mode: StreamCreationMode,
449    /// Id of this stream in the worker actor's stream table.
450    pub id: StreamRef,
451    /// Device this stream should be scheduled on. If none, don't do stream
452    /// synchronization.
453    pub device: Option<CudaDevice>,
454    /// Actor ref of the controller that created this stream.
455    pub controller_actor: reference::ActorRef<ControllerActor>,
456    pub respond_with_python_message: bool,
457}
458
459impl StreamActor {
460    pub fn new(
461        StreamParams {
462            world_size,
463            rank,
464            id: _,
465            device,
466            controller_actor,
467            creation_mode,
468            respond_with_python_message,
469        }: StreamParams,
470    ) -> Self {
471        Self {
472            _world_size: world_size,
473            rank,
474            env: HashMap::new(),
475            creation_mode,
476            cuda_stream: OnceLock::new(),
477            device,
478            comm: None,
479            controller_actor,
480            remote_process_groups: HashMap::new(),
481            recordings: HashMap::new(),
482            active_recording: None,
483            respond_with_python_message,
484            last_seq_error: None,
485        }
486    }
487}
488
489#[async_trait]
490impl Actor for StreamActor {
491    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
492        // These thread locals are exposed via python functions, so we need to set them in the
493        // same thread that python will run in. That means we need to initialize them here in
494        // StreamActor::init instead of in StreamActor::new.
495        CONTROLLER_ACTOR_REF.with(
496            |controller_actor_ref: &OnceCell<reference::ActorRef<ControllerActor>>| {
497                controller_actor_ref.set(self.controller_actor.clone()).ok()
498            },
499        );
500        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
501        ROOT_ACTOR_ID.with(|root_actor_id: &OnceCell<reference::ActorId>| {
502            let root_label = cx
503                .self_addr()
504                .label()
505                .cloned()
506                .unwrap_or_else(|| Label::new("stream").unwrap());
507            root_actor_id
508                .set(reference::ActorId::singleton(
509                    root_label,
510                    cx.self_addr().proc_addr().id().clone(),
511                ))
512                .ok()
513        });
514        // Set the current stream for this actor thread.
515        if let Some(stream) = self.cuda_stream() {
516            Stream::set_current_stream(stream);
517        }
518        Ok(())
519    }
520
521    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
522    /// dedicated OS thread. This is because:
523    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
524    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
525    ///     threads is problematic.
526    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
527    where
528        F: Future + Send + 'static,
529        F::Output: Send + 'static,
530    {
531        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
532        // It is important that we spawn a standalone thread for the work here,
533        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
534        // This is because the worker stream may call uninterruptible FFI code
535        // that can deadlock (CUDA, NCCL).
536        // If we use a tokio-managed blocking thread, then runtime teardown will
537        // try to wait for tasks on that thread to reach an await point, and
538        // hang forever.
539        let builder = std::thread::Builder::new().name("worker-stream".to_string());
540        let _thread_handle = builder.spawn(move || {
541            // Spawn a new thread with a single-threaded tokio runtime to run the
542            // actor loop.  We avoid the current-threaded runtime, so that we can
543            // use `block_in_place` for nested async-to-sync-to-async flows.
544            let rt = tokio::runtime::Builder::new_multi_thread()
545                .worker_threads(1)
546                .enable_all()
547                .build()
548                .unwrap();
549            let result = rt.block_on(async {
550                tokio::task::block_in_place(|| {
551                    // Allow e.g. destructing py objects on this thread, which
552                    // can happen at shutdown when the a stream actors env map
553                    // for rvalues is dropped (e.g. P1673311499).
554                    // https://github.com/PyO3/pyo3/discussions/3499
555                    Python::attach(|py| {
556                        py.detach(|| {
557                            let result = Handle::current().block_on(future);
558                            if join_tx.send(result).is_err() {
559                                panic!("could not send join result")
560                            }
561                        })
562                    })
563                })
564            });
565            rt.shutdown_timeout(Duration::from_weeks(1));
566            result
567        });
568
569        // In order to bridge the synchronous join handle with the async world,
570        // smuggle the result through a channel.
571        tokio::spawn(async move { join_rx.await.unwrap() })
572    }
573}
574
575/// The arguments we accept as inputs to Python function calls.
576#[derive(Debug)]
577enum PyArg {
578    Object(Py<PyAny>),
579}
580
581/// Serialize into a `Py<PyAny>`.
582impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
583    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
584        match self {
585            PyArg::Object(obj) => Ok(obj.clone_ref(py).into_bound(py)),
586        }
587    }
588}
589
590impl StreamActor {
591    fn tensor_to_pyobject(tensor_cell: TensorCell) -> Py<PyAny> {
592        Python::attach(|py| {
593            // SAFETY: Cloning a tensor was unsafe because we were tracking their references like
594            // Rust objects (single mutable reference or many immutable references). We are
595            // removing this functionality in upcoming patches, so we use the unsafe version here
596            // until that happens.
597            let tensor = unsafe {
598                // Get the owned tensor by calling clone_unsafe on the reference
599                tensor_cell.get_unchecked().clone_unsafe()
600            };
601            tensor.into_pyobject(py).unwrap().unbind()
602        })
603    }
604
605    /// Extract a TensorCell from a Py<PyAny>.
606    /// SAFETY: Uses new to create the TensorCell. Caller must ensure the Py<PyAny>
607    /// contains a valid tensor.
608    fn pyobject_to_tensor(py: Python<'_>, pyobj: &Py<PyAny>) -> PyResult<TensorCell> {
609        use torch_sys2::Tensor;
610        let tensor = pyobj.bind(py).extract::<Tensor>()?;
611        // Create a new TensorCell from the extracted tensor
612        Ok(TensorCell::new(tensor))
613    }
614
615    fn cuda_stream(&self) -> Option<&Stream> {
616        self.cuda_stream
617            .get_or_init(|| {
618                self.device.map(|device| match self.creation_mode {
619                    StreamCreationMode::UseDefaultStream => {
620                        Stream::get_current_stream_on_device(device)
621                    }
622                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
623                })
624            })
625            .as_ref()
626    }
627
628    fn ref_to_pyobject(&self, ref_: &Ref) -> Result<Py<PyAny>, CallFunctionError> {
629        let pyobject = self
630            .env
631            .get(ref_)
632            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
633        match pyobject {
634            Ok(val) => Ok(val.clone()),
635            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
636        }
637    }
638
639    async fn report_seq_error(
640        &mut self,
641        cx: &Context<'_, Self>,
642        seq: Seq,
643        error: CallFunctionError,
644    ) -> Result<Arc<SeqError>, anyhow::Error> {
645        match error {
646            CallFunctionError::DependentError(root) => Ok(root),
647            CallFunctionError::Error(e) => {
648                if self.active_recording.is_none() {
649                    let worker_error = WorkerError {
650                        backtrace: format!("{e}"),
651                        worker_actor_id: cx.self_addr().clone(),
652                    };
653                    tracing::info!("Propagating remote function error to client: {worker_error}");
654                    self.controller_actor
655                        .remote_function_failed(cx, seq, worker_error)
656                        .await?
657                }
658                let err = Arc::new(SeqError { seq, error: e });
659                self.last_seq_error = Some(err.clone());
660                Ok(err)
661            }
662        }
663    }
664
665    async fn try_define<F>(
666        &mut self,
667        cx: &Context<'_, Self>,
668        seq: Seq,
669        result_refs: Vec<Option<Ref>>,
670        mutates: &Vec<Ref>,
671        f: F,
672    ) -> Result<()>
673    where
674        F: AsyncFnOnce(&mut Self) -> Result<Vec<Py<PyAny>>, CallFunctionError>,
675    {
676        let actual_results = f(self).await;
677        // Check if the expected number of returns is correct, otherwise convert
678        // into an error.
679        let op_results = actual_results.and_then(|actual_results| {
680            if result_refs.len() == actual_results.len() {
681                Ok(actual_results
682                    .into_iter()
683                    .zip(result_refs.iter())
684                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
685                    .collect::<Vec<(Ref, Py<PyAny>)>>())
686            } else {
687                Err(CallFunctionError::UnexpectedNumberOfReturns(
688                    result_refs.len(),
689                    actual_results.len(),
690                ))
691            }
692        });
693
694        // Propagate the results (either the actual values or an error) to the
695        // right entries in the global env mapping.
696        match op_results {
697            Ok(op_results) => {
698                for (ref_, pyobject) in op_results.into_iter() {
699                    let prev = self.env.insert(ref_, Ok(pyobject));
700                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
701                }
702            }
703            Err(err) => {
704                let err = self.report_seq_error(cx, seq, err).await?;
705                for ref_ in result_refs {
706                    match ref_ {
707                        Some(ref_) => {
708                            let prev = self.env.insert(ref_, Err(err.clone()));
709                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
710                        }
711                        None => {}
712                    }
713                }
714                for ref_ in mutates {
715                    self.env.insert(*ref_, Err(err.clone()));
716                }
717            }
718        }
719        Ok(())
720    }
721
722    fn call_python_fn<'py>(
723        &mut self,
724        py: Python<'py>,
725        _cx: &Context<Self>,
726        function: Option<ResolvableFunction>,
727        args_kwargs: ArgsKwargs,
728        _mutates: &[Ref],
729        device_meshes: HashMap<Ref, DeviceMesh>,
730        remote_process_groups: HashMap<
731            Ref,
732            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
733        >,
734    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
735        let (args_tuple, kwargs_dict) = args_kwargs
736            .to_python(py)
737            .map_err(|e| CallFunctionError::Error(e.into()))?;
738        let function = function
739            .map(|function| {
740                function.resolve(py).map_err(|e| {
741                    CallFunctionError::InvalidRemoteFunction(format!(
742                        "failed to resolve function {}: {}",
743                        function,
744                        SerializablePyErr::from(py, &e)
745                    ))
746                })
747            })
748            .transpose()?;
749
750        let remote_process_groups = remote_process_groups
751            .into_iter()
752            .map(|(gref, (_mesh, _dims, _comm))| {
753                let group = match self.remote_process_groups.entry(gref) {
754                    Entry::Occupied(ent) => ent.get().clone_ref(py),
755                    Entry::Vacant(_ent) => {
756                        panic!("no longer implemented");
757                    }
758                };
759                PyResult::Ok((gref, group))
760            })
761            .collect::<Result<HashMap<_, _>, _>>()
762            .map_err(SerializablePyErr::from_fn(py))?;
763
764        let resolve = |val: Bound<'py, PyAny>| {
765            val.extract::<PyTree<Py<PyAny>>>()
766                .map_err(SerializablePyErr::from_fn(py))?
767                .try_into_map(|obj| {
768                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
769                        if let Some(mesh) = device_meshes.get(&ref_) {
770                            PyArg::Object(
771                                Py::new(py, mesh.clone())
772                                    .map_err(SerializablePyErr::from_fn(py))?
773                                    .into(),
774                            )
775                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
776                            PyArg::Object(pg.clone_ref(py))
777                        } else {
778                            let pyobj = self.ref_to_pyobject(&ref_)?;
779                            PyArg::Object(pyobj)
780                        }
781                    } else {
782                        PyArg::Object(obj)
783                    })
784                })
785        };
786
787        // Resolve args and kwargs
788        let py_args: Vec<PyTree<PyArg>> = args_tuple
789            .iter()
790            .map(&resolve)
791            .collect::<Result<_, CallFunctionError>>()?;
792
793        let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
794            .iter()
795            .map(|(k, v)| {
796                let key = k
797                    .extract::<String>()
798                    .map_err(SerializablePyErr::from_fn(py))?;
799                let value = resolve(v)?;
800                Ok((key, value))
801            })
802            .collect::<Result<_, CallFunctionError>>()?;
803
804        // Call function.
805        // Use custom subscriber to route Worker messages to stdout.
806        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
807        let result: Bound<'_, PyAny> =
808            tracing::subscriber::with_default(scoped_subscriber, || {
809                // TODO(agallagher): The args/kwargs conversion traits generate
810                // the appropriate types here, but they get casted to `PyAny`.
811                // It'd be nice to make `TryToPy<PyAny>Unsafe` take a template
812                // arg for the converted py object to avoid this downcast.
813                // SAFETY: Tensor operations were unsafe because we were tracking their references
814                // like Rust objects (single mutable reference or many immutable references). We are
815                // removing this functionality in upcoming patches, so we use the unsafe version here
816                // until that happens.
817                let args = unsafe { py_args.try_to_object_unsafe(py) }
818                    .map_err(SerializablePyErr::from_fn(py))?;
819                // SAFETY: Same as above - reference tracking functionality is being removed.
820                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
821                    .map_err(SerializablePyErr::from_fn(py))?;
822
823                if let Some(function) = function {
824                    function
825                        .call(args, Some(kwargs))
826                        .map_err(SerializablePyErr::from_fn(py))
827                } else {
828                    Ok(args.get_item(0).unwrap())
829                }
830            })?;
831        Ok(result)
832    }
833
834    fn call_python_fn_pytree(
835        &mut self,
836        cx: &hyperactor::Context<Self>,
837        function: ResolvableFunction,
838        args_kwargs: ArgsKwargs,
839        mutates: &[Ref],
840        device_meshes: HashMap<Ref, DeviceMesh>,
841        remote_process_groups: HashMap<
842            Ref,
843            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
844        >,
845    ) -> Result<PyTree<Py<PyAny>>, CallFunctionError> {
846        Python::attach(|py| {
847            let result = self.call_python_fn(
848                py,
849                cx,
850                Some(function),
851                args_kwargs,
852                mutates,
853                device_meshes,
854                remote_process_groups,
855            )?;
856            Ok(PyTree::<Py<PyAny>>::extract_bound(&result)
857                .map_err(SerializablePyErr::from_fn(py))?)
858        })
859    }
860    /// Retrieve `ref_` or create a fake value with the provided factory if it
861    /// is an error. We use this for collective calls, where even if there was
862    /// an upstream failure, we still have participate in the collective to
863    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
864    /// value here of the correct shape; the controller will have been notified
865    /// of the upstream failure and will know to ignore everything dependent on
866    /// it.
867    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
868        let pyobject = self
869            .env
870            .get(&ref_)
871            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
872
873        match pyobject {
874            Ok(val) => Python::attach(|py| {
875                Self::pyobject_to_tensor(py, val)
876                    .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
877            }),
878            Err(_) => {
879                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
880                Ok(TensorCell::new(t))
881            }
882        }
883    }
884
885    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
886        self.active_recording
887            .as_mut()
888            .and_then(|state| match state {
889                RecordingState::Defining {
890                    recording,
891                    defined_borrows,
892                } => {
893                    match self.recordings.get_mut(recording) {
894                        Some(recording) => Some((recording, defined_borrows)),
895                        // Panic, because this would be a logic error in the program.
896                        None => panic!("recording not found: {:?}", recording),
897                    }
898                }
899                RecordingState::Running => None,
900            })
901    }
902
903    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
904        for ref_ in refs {
905            let rvalue_or_err = self
906                .env
907                .get(ref_)
908                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
909            if let Err(err) = rvalue_or_err {
910                return Ok(Some(err.clone()));
911            }
912        }
913        Ok(None)
914    }
915    async fn send_value_python_message(
916        &mut self,
917        cx: &hyperactor::Context<'_, Self>,
918        seq: Seq,
919        mutates: Vec<Ref>,
920        function: Option<ResolvableFunction>,
921        args_kwargs: ArgsKwargs,
922        device_meshes: HashMap<Ref, DeviceMesh>,
923    ) -> Result<()> {
924        let rank = self.rank;
925        self.try_define(cx, seq, vec![], &vec![], async |self_| {
926            let python_message =
927                Python::attach(|py| -> Result<PythonMessage, CallFunctionError> {
928                    let python_result = tokio::task::block_in_place(|| {
929                        self_.call_python_fn(
930                            py,
931                            cx,
932                            function,
933                            args_kwargs,
934                            &mutates,
935                            device_meshes,
936                            HashMap::new(),
937                        )
938                    })?;
939                    pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
940                })?;
941            let ser = wirevalue::Any::serialize(&python_message).unwrap();
942            self_
943                .controller_actor
944                .fetch_result(cx, seq, Ok(ser))
945                .await?;
946            Ok(vec![])
947        })
948        .await
949    }
950    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
951        let rvalue = self
952            .env
953            .get(&src)
954            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
955        self.env.insert(dest, Python::attach(|_py| rvalue.clone()));
956        Ok(())
957    }
958    async fn call_actor(
959        &mut self,
960        cx: &Context<'_, Self>,
961        params: ActorCallParams,
962    ) -> Result<Py<PyAny>, CallFunctionError> {
963        let local_state: Result<Vec<Py<PyAny>>> = Python::attach(|_py| {
964            params
965                .local_state
966                .into_iter()
967                .map(|elem| {
968                    let pyobj = self.ref_to_pyobject(&elem)?;
969                    Ok(pyobj.into_any())
970                })
971                .collect()
972        });
973
974        let (send, recv) = cx.open_once_port();
975        let state = LocalState {
976            response_port: send,
977            state: local_state?,
978        };
979        let x: u64 = params.seq.into();
980        let message = LocalStateBrokerMessage::Set(x as usize, state);
981
982        let broker = BrokerId::new(params.broker_id).resolve(cx).await;
983        broker.post(cx, message);
984        let result = recv
985            .recv()
986            .await
987            .map_err(|e| CallFunctionError::Error(e.into()))?;
988
989        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
990    }
991}
992
993#[async_trait]
994#[handle(StreamMessage)]
995impl StreamMessageHandler for StreamActor {
996    async fn call_function(
997        &mut self,
998        cx: &Context<Self>,
999        params: CallFunctionParams,
1000        device_meshes: HashMap<Ref, DeviceMesh>,
1001        remote_process_groups: HashMap<
1002            Ref,
1003            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1004        >,
1005    ) -> Result<()> {
1006        if let Some((recording, _)) = self.get_defining_recording() {
1007            recording.messages.push(StreamMessage::CallFunction(
1008                params,
1009                device_meshes,
1010                remote_process_groups,
1011            ));
1012            return Ok(());
1013        }
1014
1015        params.function.panic_if_requested();
1016        self.try_define(
1017            cx,
1018            params.seq,
1019            params.results,
1020            &params.mutates,
1021            async |self| {
1022                tokio::task::block_in_place(|| {
1023                    self.call_python_fn_pytree(
1024                        cx,
1025                        params.function,
1026                        params.args_kwargs,
1027                        &params.mutates,
1028                        device_meshes,
1029                        remote_process_groups,
1030                    )
1031                    .map(|results| results.into_leaves())
1032                })
1033            },
1034        )
1035        .await?;
1036        Ok(())
1037    }
1038
1039    async fn borrow_create(
1040        &mut self,
1041        cx: &Context<Self>,
1042        borrow: u64,
1043        tensor: Ref,
1044        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1045    ) -> Result<()> {
1046        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1047            recording.messages.push(StreamMessage::BorrowCreate {
1048                borrow,
1049                tensor,
1050                first_use_sender,
1051            });
1052            ensure!(
1053                defined_borrows.insert(borrow),
1054                "duplicate borrow create in recording"
1055            );
1056            return Ok(());
1057        }
1058
1059        let pyobj_result = self
1060            .env
1061            .get(&tensor)
1062            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1063
1064        let result = match pyobj_result {
1065            Ok(pyobj) => Python::attach(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
1066            Err(e) => Err(e.clone()),
1067        };
1068
1069        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1070        first_use_sender.post(cx, (event, result));
1071        Ok(())
1072    }
1073
1074    async fn borrow_first_use(
1075        &mut self,
1076        _cx: &Context<Self>,
1077        borrow: u64,
1078        result: Ref,
1079        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1080    ) -> Result<()> {
1081        if let Some((recording, _)) = self.get_defining_recording() {
1082            recording.messages.push(StreamMessage::BorrowFirstUse {
1083                borrow,
1084                result,
1085                first_use_receiver: first_use_receiver.clone(),
1086            });
1087            return Ok(());
1088        }
1089
1090        let (first_use_event, cell) =
1091            first_use_receiver
1092                .lock()
1093                .await
1094                .recv()
1095                .await
1096                .map_err(|err| {
1097                    anyhow!(
1098                        "failed receiving first use event for borrow {:?}: {:?}",
1099                        borrow,
1100                        err
1101                    )
1102                })?;
1103
1104        if let Some(stream) = self.cuda_stream() {
1105            stream.wait_event(
1106                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1107            );
1108        }
1109        match cell {
1110            Ok(cell) => {
1111                let pyobj = Self::tensor_to_pyobject(cell);
1112                self.env.insert(result, Ok(pyobj));
1113            }
1114            Err(err) => {
1115                self.env.insert(result, Err(err.clone()));
1116            }
1117        }
1118        Ok(())
1119    }
1120
1121    async fn borrow_last_use(
1122        &mut self,
1123        cx: &Context<Self>,
1124        borrow: u64,
1125        result: Ref,
1126        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1127    ) -> Result<()> {
1128        if let Some((recording, _)) = self.get_defining_recording() {
1129            recording.messages.push(StreamMessage::BorrowLastUse {
1130                borrow,
1131                result,
1132                last_use_sender,
1133            });
1134            return Ok(());
1135        }
1136
1137        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1138        let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
1139            "Invalid reference for borrow_last_use: {result:#?}"
1140        ))?;
1141        let tensor = match pyobj_or_err {
1142            Ok(pyobj) => Ok(Python::attach(|py| {
1143                Self::pyobject_to_tensor(py, &pyobj).unwrap()
1144            })),
1145            Err(e) => Err(e),
1146        };
1147
1148        last_use_sender.post(cx, (event, tensor));
1149        Ok(())
1150    }
1151
1152    async fn borrow_drop(
1153        &mut self,
1154        _cx: &Context<Self>,
1155        borrow: u64,
1156        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1157    ) -> Result<()> {
1158        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1159            recording.messages.push(StreamMessage::BorrowDrop {
1160                borrow,
1161                last_use_receiver: last_use_receiver.clone(),
1162            });
1163            ensure!(
1164                defined_borrows.remove(&borrow),
1165                "borrow drop for borrow not defined in recording"
1166            );
1167            return Ok(());
1168        }
1169
1170        // The borrowed cell isn't used directly, but we still want to receive it here
1171        // so that the underlying tensor isn't dropped until after we synchronize the
1172        // CUDA streams.
1173        let (last_use_event, _cell) =
1174            last_use_receiver.lock().await.recv().await.map_err(|err| {
1175                anyhow!(
1176                    "failed receiving last use event for borrow {:?}: {:?}",
1177                    borrow,
1178                    err
1179                )
1180            })?;
1181
1182        if let Some(stream) = self.cuda_stream() {
1183            stream.wait_event(
1184                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1185            );
1186        }
1187        // let the cell drop.
1188        Ok(())
1189    }
1190
1191    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1192        if let Some((recording, _)) = self.get_defining_recording() {
1193            recording.messages.push(StreamMessage::DeleteRefs(refs));
1194            return Ok(());
1195        }
1196
1197        for ref_ in refs.iter() {
1198            self.env.remove(ref_);
1199        }
1200        Ok(())
1201    }
1202
1203    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1204        if self.get_defining_recording().is_some() {
1205            bail!("request_status not allowed in recording");
1206        }
1207
1208        Ok(())
1209    }
1210
1211    async fn init_comm(
1212        &mut self,
1213        _cx: &Context<Self>,
1214        comm: ActorHandle<NcclCommActor>,
1215    ) -> Result<()> {
1216        if self.get_defining_recording().is_some() {
1217            bail!("init_comm not allowed in recording");
1218        }
1219
1220        self.comm = Some(comm);
1221        Ok(())
1222    }
1223
1224    async fn reduce(
1225        &mut self,
1226        cx: &Context<Self>,
1227        comm: Arc<ActorHandle<NcclCommActor>>,
1228        dim_size: i64,
1229        result: Ref,
1230        local_tensor: Ref,
1231        factory: Factory,
1232        reduction: Reduction,
1233        scatter: bool,
1234        in_place: bool,
1235        out: Option<Ref>,
1236    ) -> Result<()> {
1237        if let Some((recording, _)) = self.get_defining_recording() {
1238            recording.messages.push(StreamMessage::Reduce {
1239                comm,
1240                dim_size,
1241                result,
1242                local_tensor,
1243                factory,
1244                reduction,
1245                scatter,
1246                in_place,
1247                out,
1248            });
1249            return Ok(());
1250        }
1251
1252        let stream = self
1253            .cuda_stream()
1254            .expect("reductions not yet supported for non-CUDA workers")
1255            .clone();
1256        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1257        let out_cell = out
1258            .map(|out| self.get_or_fake_on_err(out, &factory))
1259            .transpose()?;
1260        let output_cell = match reduction {
1261            Reduction::Stack => {
1262                if scatter {
1263                    let output_cell = if in_place {
1264                        input_cell.clone()
1265                    } else {
1266                        out_cell.unwrap_or({
1267                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1268                            let cloned = deep_clone(&borrow);
1269                            TensorCell::new(cloned)
1270                        })
1271                    };
1272                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1273                        .await?;
1274                    output_cell
1275                } else {
1276                    ensure!(
1277                        !in_place,
1278                        "in-place, non-scatter not supported for stack reduce"
1279                    );
1280
1281                    let output_cell = out_cell.unwrap_or({
1282                        // In Python, this would be [dim_size, *factory.sizes]
1283                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1284                        let output =
1285                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1286                        TensorCell::new(output)
1287                    });
1288
1289                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1290                        .await?;
1291                    output_cell
1292                }
1293            }
1294            Reduction::ReduceOp(op) => {
1295                if scatter {
1296                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1297
1298                    let output_cell = out_cell.unwrap_or({
1299                        let output = factory_empty(
1300                            &factory.size[1..],
1301                            factory.dtype,
1302                            factory.layout,
1303                            factory.device,
1304                        );
1305                        TensorCell::new(output)
1306                    });
1307                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1308                        .await?;
1309                    output_cell
1310                } else {
1311                    let output_cell = if in_place {
1312                        input_cell.clone()
1313                    } else {
1314                        out_cell.map_or(
1315                            {
1316                                let borrow =
1317                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1318                                let cloned = deep_clone(&borrow);
1319                                Ok(TensorCell::new(cloned))
1320                            },
1321                            |out_cell| -> Result<_, anyhow::Error> {
1322                                let mut out_borrow =
1323                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1324                                let in_borrow =
1325                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1326                                out_borrow.copy_(&in_borrow);
1327                                drop(out_borrow);
1328                                Ok(out_cell)
1329                            },
1330                        )?
1331                    };
1332
1333                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1334                    output_cell
1335                }
1336            }
1337        };
1338
1339        let pyobj = Self::tensor_to_pyobject(output_cell);
1340        self.env.insert(result, Ok(pyobj));
1341        Ok(())
1342    }
1343
1344    async fn send_tensor(
1345        &mut self,
1346        cx: &Context<Self>,
1347        result: Ref,
1348        from_rank: Option<usize>,
1349        to_rank: Option<usize>,
1350        tensor: Ref,
1351        factory: Factory,
1352        comm: Option<Arc<ActorHandle<NcclCommActor>>>,
1353    ) -> Result<()> {
1354        if let Some((recording, _)) = self.get_defining_recording() {
1355            recording.messages.push(StreamMessage::SendTensor {
1356                result,
1357                from_rank,
1358                to_rank,
1359                tensor,
1360                factory,
1361                comm,
1362            });
1363            return Ok(());
1364        }
1365
1366        if to_rank.is_none() && from_rank.is_none() {
1367            bail!("tried to send tensor without a to/from rank");
1368        }
1369
1370        // Value is local, so we do not have to actually send it.
1371        if from_rank == to_rank {
1372            let input_cell: &std::result::Result<Py<PyAny>, Arc<SeqError>> = self
1373                .env
1374                .get(&tensor)
1375                .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1376            let output_cell: Result<Py<PyAny>, Arc<SeqError>> = match input_cell {
1377                Ok(pyobj) => {
1378                    Python::attach(|py| -> Result<Py<PyAny>, Arc<SeqError>> {
1379                        let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
1380                        // We create a defensive copy here to prevent mutations on
1381                        // the input tensor from affecting output tensor.
1382                        // Should we copy if input ref == output ref?
1383                        // Should we support copy-on-write to avoid unnecessary copy?
1384                        let borrow = input_tensor.try_borrow().unwrap();
1385                        let cloned = deep_clone(&borrow);
1386                        let cloned_cell = TensorCell::new(cloned);
1387                        Ok(Self::tensor_to_pyobject(cloned_cell))
1388                    })
1389                }
1390                Err(err) => Err(err.clone()),
1391            };
1392            self.env.insert(result, output_cell);
1393            return Ok(());
1394        }
1395
1396        let comm = comm.context("send_tensor requires backend comm")?;
1397
1398        let mut messages = Vec::new();
1399
1400        if let Some(to_rank) = to_rank {
1401            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1402            messages.push(CommMessage::Send(
1403                input_cell,
1404                to_rank.try_into().unwrap(),
1405                self.cuda_stream()
1406                    .expect("tried to send_tensor on non-cuda stream")
1407                    .clone(),
1408                cx.open_once_port().0,
1409            ));
1410        }
1411
1412        if let Some(from_rank) = from_rank {
1413            let output_cell = TensorCell::new(factory_empty(
1414                &factory.size,
1415                factory.dtype,
1416                factory.layout,
1417                factory.device,
1418            ));
1419            messages.push(CommMessage::Recv(
1420                output_cell.clone(),
1421                from_rank.try_into().unwrap(),
1422                self.cuda_stream()
1423                    .expect("tried to send_tensor on non-cuda stream")
1424                    .clone(),
1425                cx.open_once_port().0,
1426            ));
1427            let pyobj = Self::tensor_to_pyobject(output_cell);
1428            self.env.insert(result, Ok(pyobj));
1429        }
1430
1431        comm.group(
1432            cx,
1433            messages,
1434            self.cuda_stream()
1435                .expect("tried to send_tensor on non-cuda stream")
1436                .clone(),
1437        )
1438        .await?;
1439        Ok(())
1440    }
1441
    /// Resolves a value on this stream and reports it to the controller, via
    /// `fetch_result`, as the result of `seq`.
    ///
    /// When `respond_with_python_message` is set, this delegates entirely to
    /// `send_value_python_message`. Otherwise the value is resolved in one of
    /// two ways:
    /// - by calling `function`, if it is given;
    /// - otherwise from the single positional argument in `args_kwargs`. Any
    ///   leaf of that argument that converts to a `Ref` is replaced by the
    ///   object bound to it in this stream's environment.
    ///
    /// On failure, the error is reported via `report_seq_error`, and every ref
    /// in `mutates` is bound to the resulting error. A `WorkerError` is then
    /// sent to the controller.
    ///
    /// NOTE(review): a successful resolution in the non-Python-message branch
    /// hits `unreachable!`. This assumes `respond_with_python_message` is always
    /// true. If that flag can ever be false, a successful `send_value` panics
    /// the actor; confirm the invariant.
    async fn send_value(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: reference::ActorAddr,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        if self.respond_with_python_message {
            return self
                .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
                .await;
        }

        let result = if let Some(function) = function {
            // If a function was provided, use that to resolve the value.
            tokio::task::block_in_place(|| {
                self.call_python_fn_pytree(
                    cx,
                    function,
                    args_kwargs,
                    &mutates,
                    device_meshes,
                    HashMap::new(),
                )
            })
        } else {
            // If there's no function provided, there should be exactly one arg
            // and no kwargs.
            Python::attach(|py| {
                let (args, kwargs) = args_kwargs
                    .to_python(py)
                    .map_err(|e| CallFunctionError::Error(e.into()))?;
                match (args.len(), kwargs.len()) {
                    (1, 0) => {
                        let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
                        // Walk the argument as a PyTree, substituting every
                        // `Ref` leaf with the object it refers to.
                        arg.extract::<PyTree<Py<PyAny>>>()
                            .map_err(SerializablePyErr::from_fn(py))?
                            .try_into_map(|obj| {
                                let bound_obj = obj.bind(py);
                                if let Ok(ref_) = Ref::from_py_object(bound_obj) {
                                    self.ref_to_pyobject(&ref_)
                                } else {
                                    Ok(obj)
                                }
                            })
                    }
                    _ => Err(CallFunctionError::TooManyArgsForValue(
                        format!("args with {} elements", args.len()),
                        format!("kwargs with {} elements", kwargs.len()),
                    )),
                }
            })
        };

        // On failure, report the error for `seq`, then mark every mutated ref
        // with it so that downstream consumers see the failure.
        let value = match result {
            Ok(pyobject) => Ok(pyobject),
            Err(err) => {
                let err = self.report_seq_error(cx, seq, err).await?;
                for ref_ in mutates {
                    self.env.insert(ref_, Err(err.clone()));
                }
                Err(WorkerError {
                    backtrace: format!("{:?}", err),
                    worker_actor_id,
                })
            }
        };

        // Actually send the value.
        // NOTE: respond_with_python_message is always true, so serialization is not needed
        // The controller will receive the value through send_value_python_message instead
        let result = match value {
            Ok(_value) => {
                // This code path is never executed since respond_with_python_message is true
                unreachable!(
                    "send_value should return early when respond_with_python_message is true"
                )
            }
            Err(e) => Err(e),
        };
        self.controller_actor.fetch_result(cx, seq, result).await?;

        Ok(())
    }
1529
1530    async fn send_result_of_actor_call(
1531        &mut self,
1532        cx: &Context<Self>,
1533        params: ActorCallParams,
1534    ) -> anyhow::Result<()> {
1535        let seq = params.seq;
1536        let mutates = params.mutates.clone();
1537        self.try_define(cx, seq, vec![], &mutates, async |self| {
1538            let value = self.call_actor(cx, params).await?;
1539            let result =
1540                Python::attach(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1541            let result = wirevalue::Any::serialize(&result).unwrap();
1542            self.controller_actor
1543                .fetch_result(cx, seq, Ok(result))
1544                .await?;
1545            Ok(vec![])
1546        })
1547        .await
1548    }
1549
1550    async fn call_actor_method(
1551        &mut self,
1552        cx: &Context<Self>,
1553        params: ActorMethodParams,
1554    ) -> anyhow::Result<()> {
1555        let seq = params.call.seq;
1556        let mutates = params.call.mutates.clone();
1557        self.try_define(cx, seq, params.results, &mutates, async |self| {
1558            let result = self.call_actor(cx, params.call).await?;
1559            let result = Python::attach(|py| {
1560                PyTree::<Py<PyAny>>::extract_bound(&result.into_bound(py))
1561                    .map_err(SerializablePyErr::from_fn(py))
1562            })?;
1563            Ok(result.into_leaves())
1564        })
1565        .await
1566    }
1567
1568    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1569        if self.active_recording.is_some() {
1570            bail!("different recording already active");
1571        }
1572        match self.recordings.entry(recording) {
1573            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1574            Entry::Vacant(entry) => entry.insert(Recording::new()),
1575        };
1576        self.active_recording = Some(RecordingState::Defining {
1577            recording,
1578            defined_borrows: HashSet::new(),
1579        });
1580        Ok(())
1581    }
1582
1583    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1584        match self.active_recording {
1585            Some(RecordingState::Defining {
1586                recording: active_recording,
1587                ref defined_borrows,
1588            }) if active_recording == recording => {
1589                ensure!(
1590                    defined_borrows.is_empty(),
1591                    "all borrows created within recording must be dropped within recording"
1592                );
1593                self.active_recording = None;
1594            }
1595            _ => bail!("cannot finalize recording that isn't active"),
1596        }
1597        Ok(())
1598    }
1599
1600    async fn recording_formal(
1601        &mut self,
1602        _cx: &Context<Self>,
1603        result: Ref,
1604        argument_index: usize,
1605    ) -> Result<()> {
1606        match self.get_defining_recording() {
1607            Some((recording, _)) => {
1608                recording.messages.push(StreamMessage::RecordingFormal {
1609                    result,
1610                    argument_index,
1611                });
1612            }
1613            None => bail!("recording_formal called outside of recording"),
1614        };
1615        Ok(())
1616    }
1617
1618    async fn recording_result(
1619        &mut self,
1620        _cx: &Context<Self>,
1621        result: Ref,
1622        output_index: usize,
1623    ) -> Result<()> {
1624        match self.get_defining_recording() {
1625            Some((recording, _)) => {
1626                recording.messages.push(StreamMessage::RecordingResult {
1627                    result,
1628                    output_index,
1629                });
1630            }
1631            None => bail!("recording_result called outside of recording"),
1632        };
1633        Ok(())
1634    }
1635
1636    async fn call_recording(
1637        &mut self,
1638        cx: &Context<Self>,
1639        seq: Seq,
1640        recording: Ref,
1641        results: Vec<Ref>,
1642        actuals: Vec<Ref>,
1643    ) -> Result<()> {
1644        if self.active_recording.is_some() {
1645            bail!("cannot call recording while another recording is active");
1646        }
1647
1648        let messages = match self.recordings.get(&recording) {
1649            Some(recording) => recording
1650                .messages
1651                .iter()
1652                .map(|message| message.clone_for_recording())
1653                .collect::<Vec<_>>(),
1654            None => bail!("recording {:?} not found", recording),
1655        };
1656
1657        self.active_recording = Some(RecordingState::Running);
1658
1659        // Global error for all messages in the recording. The first time a message
1660        // fails in the recording, we set the error. We then need to propagate this
1661        // error to all of the refs mutated by the entire recording, as well as the
1662        // result refs.
1663        let mut error: Option<Arc<SeqError>> = None;
1664        // The set of all refs defined by this recording (excluding "results"),
1665        // which we need to ensure are deleted when the recording is done executing.
1666        let mut all_defined_refs = HashSet::new();
1667        // The set of all refs mutated by this recording. If there is an error with
1668        // any message, all of these refs need to have the correct error set.
1669        let mut all_mutated_refs = HashSet::new();
1670        // Map from the result ref of a RecordingFormal message to the associated
1671        // actual ref from "actuals". We need to track this in order to properly
1672        // handle recordings that mutate refs contained in "actuals" -- every
1673        // message in the recording that interacts with the recording inputs will
1674        // interact with the formal ref rather than the actual ref.
1675        let mut formal_to_actual_refs = HashMap::new();
1676        // clear any pre-existing error messages before recording started
1677        self.last_seq_error = None;
1678        for message in messages.into_iter() {
1679            let defined_refs = message.get_defined_refs();
1680            all_defined_refs.extend(defined_refs.clone());
1681
1682            let mutated_refs_with_formals = message.get_mutated_refs();
1683            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1684                match formal_to_actual_refs.get(ref_) {
1685                    Some(actual_ref) => Some(*actual_ref),
1686                    None => {
1687                        if all_defined_refs.contains(ref_) {
1688                            None
1689                        } else {
1690                            Some(*ref_)
1691                        }
1692                    }
1693                }
1694            }));
1695
1696            match message {
1697                StreamMessage::RecordingFormal {
1698                    result: formal_ref,
1699                    argument_index,
1700                } => match actuals.get(argument_index) {
1701                    None => bail!("recording_formal called with too few arguments"),
1702                    Some(actual_ref) => {
1703                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
1704                        self.define_ref(formal_ref, *actual_ref)?;
1705                    }
1706                },
1707                StreamMessage::RecordingResult {
1708                    result: result_ref,
1709                    output_index,
1710                } => match results.get(output_index) {
1711                    None => bail!("recording_result called with too few results"),
1712                    Some(actual_result_ref) => {
1713                        self.define_ref(*actual_result_ref, result_ref)?;
1714                    }
1715                },
1716                StreamMessage::DeleteRefs(ref refs) => {
1717                    for ref_ in refs {
1718                        all_defined_refs.remove(ref_);
1719                    }
1720                    StreamMessageHandler::handle(self, cx, message).await?;
1721                }
1722                StreamMessage::CallFunction { .. } if error.is_some() => {
1723                    // CallFunction is expensive. If the recording already failed, then
1724                    // just update the necessary refs with the error. Most of the other
1725                    // message types need to run regardless because there are other actors
1726                    // that expect the call to happen (e.g., all of the borrow messages,
1727                    // pipe send/recv, send_tensor, reduce, etc.).
1728                    let error = error.clone().unwrap();
1729                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1730                        self.env.insert(*ref_, Err(error.clone()));
1731                    }
1732                }
1733                StreamMessage::BorrowLastUse { ref result, .. } => {
1734                    all_defined_refs.remove(result);
1735                    StreamMessageHandler::handle(self, cx, message).await?;
1736                }
1737                StreamMessage::Reduce {
1738                    local_tensor,
1739                    ref out,
1740                    ..
1741                } => {
1742                    // Reduce doesn't propagate errors to the result ref, so we need
1743                    // to check for existing errors on the input tensors and set the
1744                    // recording's error if necessary.
1745                    if error.is_none() {
1746                        let inputs_to_check = [Some(local_tensor), out.clone()]
1747                            .iter()
1748                            .filter_map(|r| *r)
1749                            .collect::<Vec<_>>();
1750                        error = self.get_first_error(inputs_to_check.as_slice())?;
1751                    }
1752                    StreamMessageHandler::handle(self, cx, message).await?;
1753                }
1754                StreamMessage::SendTensor {
1755                    ref tensor,
1756                    ref to_rank,
1757                    ..
1758                } => {
1759                    // If this rank is sending a tensor (e.g., to_rank has a value),
1760                    // we need to check for existing errors on the input tensor, because
1761                    // the error is only propagated to the result ref when this rank
1762                    // is also receiving a tensor.
1763                    if to_rank.is_some() && error.is_none() {
1764                        error = self.get_first_error(&[*tensor])?;
1765                    }
1766                    StreamMessageHandler::handle(self, cx, message).await?;
1767                }
1768                _ => {
1769                    StreamMessageHandler::handle(self, cx, message).await?;
1770                }
1771            };
1772
1773            // It's not entirely trivial to determine whether a message "failed" or not.
1774            // For example, the CallFunction message can return Ok(..) if there is an error
1775            // in the underlying function call. But in that case, we would still want to
1776            // consider the recording call as "failed". Unlike in python, where we can just
1777            // wrap everything in try-except, in rust, we keep track of the last report SeqError, which
1778            // we clear before handling each recording message. If we see it is set, the
1779            // we know the recording has faild.
1780            match (&error, self.last_seq_error.take()) {
1781                (None, Some(seq_err)) => {
1782                    // Report failure to the controller.
1783                    self.controller_actor
1784                        .remote_function_failed(
1785                            cx,
1786                            seq,
1787                            WorkerError {
1788                                backtrace: format!("recording failed: {}", &seq_err),
1789                                worker_actor_id: cx.self_addr().clone(),
1790                            },
1791                        )
1792                        .await?;
1793                    error = Some(seq_err)
1794                }
1795                _ => {}
1796            }
1797            // Continue processing the remaining stream messages regardless of error.
1798            // We need to do this partially for error propagation, but also because
1799            // certain messages (like borrows and reductions) need to run regardless
1800            // in order to prevent deadlocks.
1801        }
1802
1803        // Delete the formal refs and some subset of the RecordingResult refs. The
1804        // controller should have generated DeleteRefs messages for all other refs
1805        // defined by the recording.
1806        StreamMessageHandler::handle(
1807            self,
1808            cx,
1809            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1810        )
1811        .await?;
1812
1813        // Any refs mutated by the recording and all results should have the same error
1814        // (the original error that caused the recording to fail).
1815        if error.is_some() {
1816            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1817                self.env.insert(*ref_, Err(error.clone().unwrap()));
1818            }
1819        }
1820
1821        self.active_recording = None;
1822        Ok(())
1823    }
1824
1825    async fn set_ref_unit_tests_only(
1826        &mut self,
1827        _cx: &Context<Self>,
1828        reference: Ref,
1829        value: WireValue,
1830    ) -> Result<()> {
1831        let pyobj =
1832            Python::attach(|py| -> PyResult<Py<PyAny>> { Ok(value.into_pyobject(py)?.unbind()) })?;
1833        self.env.insert(reference, Ok(pyobj));
1834        Ok(())
1835    }
1836
1837    async fn set_tensor_ref_unit_tests_only(
1838        &mut self,
1839        _cx: &Context<Self>,
1840        reference: Ref,
1841        tensor_result: TensorCellResult,
1842    ) -> Result<()> {
1843        match tensor_result {
1844            Ok(tensor_cell) => {
1845                let pyobj = Self::tensor_to_pyobject(tensor_cell);
1846                self.env.insert(reference, Ok(pyobj));
1847            }
1848            Err(err) => {
1849                self.env.insert(reference, Err(err));
1850            }
1851        }
1852        Ok(())
1853    }
1854
1855    async fn get_ref_unit_tests_only(
1856        &mut self,
1857        _cx: &Context<Self>,
1858        reference: Ref,
1859    ) -> Result<Option<Result<WireValue, String>>> {
1860        use pyo3::types::PyBool;
1861        use pyo3::types::PyFloat;
1862        use pyo3::types::PyInt;
1863        use pyo3::types::PyList;
1864        use pyo3::types::PyNone;
1865        use pyo3::types::PyString;
1866        /// For testing only, doesn't support Tensor or TensorList.
1867        fn pyobject_to_wire(
1868            value: Result<Py<PyAny>, Arc<SeqError>>,
1869        ) -> Result<WireValue, Arc<SeqError>> {
1870            let pyobj = value?;
1871            Python::attach(|py| {
1872                let bound = pyobj.bind(py);
1873                // Check bool before int since Python's bool is a subclass of int
1874                if bound.is_instance_of::<PyBool>() {
1875                    Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1876                } else if bound.is_instance_of::<PyInt>() {
1877                    Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1878                } else if bound.is_instance_of::<PyList>() {
1879                    if let Ok(val) = bound.extract::<Vec<i64>>() {
1880                        Ok(WireValue::IntList(val))
1881                    } else {
1882                        Ok(WireValue::String(format!(
1883                            "unsupported list type: {:?}",
1884                            bound
1885                        )))
1886                    }
1887                } else if bound.is_instance_of::<PyFloat>() {
1888                    Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1889                } else if bound.is_instance_of::<PyString>() {
1890                    Ok(WireValue::String(bound.extract::<String>().unwrap()))
1891                } else if bound.is_instance_of::<PyNone>() {
1892                    Ok(WireValue::None(()))
1893                } else {
1894                    Ok(WireValue::String(format!(
1895                        "unsupported pyobject type: {:?}",
1896                        bound
1897                    )))
1898                }
1899            })
1900        }
1901        Ok(self.env.get(&reference).map(|pyobj| {
1902            pyobject_to_wire(Python::attach(|_py| pyobj.clone())).map_err(|err| err.to_string())
1903        }))
1904    }
1905
1906    async fn get_tensor_ref_unit_tests_only(
1907        &mut self,
1908        _cx: &Context<Self>,
1909        reference: Ref,
1910    ) -> Result<Option<TensorCellResult>> {
1911        match self.env.get(&reference) {
1912            Some(Ok(pyobj)) => Python::attach(|py| match Self::pyobject_to_tensor(py, pyobj) {
1913                Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1914                Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1915            }),
1916            Some(Err(err)) => Ok(Some(Err(err.clone()))),
1917            None => Ok(None),
1918        }
1919    }
1920}
1921
1922#[cfg(all(test, fbcode_build))]
1923mod tests {
1924    use hyperactor::actor::ActorStatus;
1925    use hyperactor::context;
1926    use hyperactor::supervision::ActorSupervisionEvent;
1927    use monarch_messages::controller::ControllerMessage;
1928    use monarch_messages::worker::StreamCreationMode;
1929    use monarch_types::PickledPyObject;
1930    use monarch_types::UniqueId;
1931    use pyo3::IntoPyObjectExt;
1932    use timed_test::async_timed_test;
1933    use tokio::sync::watch;
1934    use torch_sys_cuda::nccl::UniqueIdExt;
1935    use torch_sys2::factory_float_tensor;
1936    use torch_sys2::testing::allclose;
1937
1938    use super::*;
1939    use crate::comm::CommParams;
1940    use crate::test_util;
1941
1942    #[allow(dead_code)]
1943    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
1944        Arc::new(SeqError {
1945            seq: 0.into(),
1946            error: err,
1947        })
1948    }
1949
    /// Test fixture. It bundles an isolated proc, the `StreamActor` under test,
    /// and the client/controller endpoints the tests use to drive and observe it.
    // Fields are dropped in declaration order.
    struct TestSetup {
        proc: Proc,
        stream_actor: ActorHandle<StreamActor>,
        // Instance passed as the sender context for every actor call in the tests.
        client: Instance<()>,
        // Unused, but necessary, because proc needs a supervision
        // port -- otherwise an actor failure will cause a crash.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        // Receives the messages the stream actor sends to its controller
        // (e.g. FetchResult, as read by check_fetch_result_*).
        #[allow(dead_code)]
        controller_rx: PortReceiver<ControllerMessage>,
        // Controller ref handed to the stream actor via StreamParams.
        #[allow(dead_code)]
        controller_actor: reference::ActorRef<ControllerActor>,
        // Next ref id returned by `next_ref`.
        next_ref: Ref,
    }
1964
1965    impl TestSetup {
1966        async fn new() -> Result<Self> {
1967            Self::new_with_world_size(1).await
1968        }
1969
1970        async fn new_with_world_size(world_size: usize) -> Result<Self> {
1971            test_util::test_setup()?;
1972
1973            let proc = Proc::isolated();
1974            let (_, controller_actor, controller_rx) =
1975                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
1976            let (client, _handle) = proc.client("client")?;
1977            let (supervision_tx, supervision_rx) = client.open_port();
1978            proc.set_supervision_coordinator(supervision_tx)?;
1979            let stream_actor = proc.spawn(
1980                "stream",
1981                StreamActor::new(StreamParams {
1982                    world_size,
1983                    rank: 0,
1984                    creation_mode: StreamCreationMode::UseDefaultStream,
1985                    id: 0.into(),
1986                    device: Some(CudaDevice::new(0.into())),
1987                    controller_actor: controller_actor.clone(),
1988                    respond_with_python_message: false,
1989                }),
1990            )?;
1991
1992            Ok(Self {
1993                proc,
1994                stream_actor,
1995                client,
1996                supervision_rx,
1997                controller_rx,
1998                controller_actor,
1999                next_ref: 0.into(),
2000            })
2001        }
2002
2003        fn next_ref(&mut self) -> Ref {
2004            let ref_ = self.next_ref;
2005            self.next_ref = Ref {
2006                id: self.next_ref.id + 1,
2007            };
2008            ref_
2009        }
2010
2011        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2012            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
2013            self.stream_actor
2014                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2015                .await
2016        }
2017
2018        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2019            let actual = self
2020                .stream_actor
2021                .get_tensor_ref_unit_tests_only(&self.client, reference)
2022                .await
2023                .unwrap()
2024                .unwrap()
2025                .unwrap();
2026
2027            // rustfmt-ignore
2028            allclose(
2029                &factory_float_tensor(data, "cpu".parse().unwrap()),
2030                &actual.borrow(),
2031            )
2032            .unwrap()
2033        }
2034
2035        #[allow(dead_code)]
2036        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2037            let result_error = self
2038                .stream_actor
2039                .get_tensor_ref_unit_tests_only(&self.client, reference)
2040                .await
2041                .unwrap()
2042                .unwrap()
2043                .unwrap_err();
2044
2045            assert!(Arc::ptr_eq(&result_error, &error));
2046        }
2047    }
2048
2049    async fn assert_actor_failed_with_msg(
2050        status_rx: &mut watch::Receiver<ActorStatus>,
2051        expected_msg: String,
2052    ) {
2053        status_rx
2054            .wait_for(|s| matches!(s, ActorStatus::Failed(_)))
2055            .await
2056            .unwrap();
2057        let status = status_rx.borrow().clone();
2058        if let ActorStatus::Failed(msg) = status {
2059            assert!(msg.to_string().contains(&expected_msg));
2060        } else {
2061            panic!("expected ActorStatus::Failed, got {:?}", status);
2062        }
2063    }
2064
2065    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2066        for ref_ in refs {
2067            assert!(
2068                test_setup
2069                    .stream_actor
2070                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2071                    .await
2072                    .unwrap()
2073                    .is_none()
2074            );
2075        }
2076    }
2077
2078    #[allow(dead_code)]
2079    async fn fetch_result(
2080        cx: &impl context::Actor,
2081        stream_actor: ActorHandle<StreamActor>,
2082        seq: Seq,
2083        reference: Ref,
2084    ) {
2085        let ref_to_send = Python::attach(|py| {
2086            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2087        });
2088
2089        stream_actor
2090            .send_value(
2091                cx,
2092                seq,
2093                stream_actor.actor_addr().clone(),
2094                Vec::new(),
2095                None,
2096                ArgsKwargs::from_wire_values(
2097                    vec![WireValue::PyObject(ref_to_send)],
2098                    HashMap::new(),
2099                )
2100                .unwrap(),
2101                HashMap::new(),
2102            )
2103            .await
2104            .unwrap()
2105    }
2106
2107    #[allow(dead_code)]
2108    async fn check_fetch_result_error(
2109        cx: &impl context::Actor,
2110        stream_actor: ActorHandle<StreamActor>,
2111        seq: Seq,
2112        reference: Ref,
2113        controller_rx: &mut PortReceiver<ControllerMessage>,
2114        expected_backtrace: &str,
2115    ) {
2116        fetch_result(cx, stream_actor, seq, reference).await;
2117
2118        let controller_msg = controller_rx.recv().await.unwrap();
2119        match controller_msg {
2120            ControllerMessage::FetchResult {
2121                seq: actual_seq,
2122                value: Err(err),
2123            } => {
2124                assert_eq!(actual_seq, seq);
2125                assert!(
2126                    err.backtrace.contains(expected_backtrace),
2127                    "backtrace did not contain {:?}: {:?}",
2128                    expected_backtrace,
2129                    err.backtrace
2130                );
2131            }
2132            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2133        };
2134    }
2135
2136    #[allow(dead_code)]
2137    async fn check_fetch_result_value(
2138        cx: &impl context::Actor,
2139        stream_actor: ActorHandle<StreamActor>,
2140        seq: Seq,
2141        reference: Ref,
2142        controller_rx: &mut PortReceiver<ControllerMessage>,
2143    ) {
2144        fetch_result(cx, stream_actor, seq, reference).await;
2145
2146        let controller_msg = controller_rx.recv().await.unwrap();
2147        match controller_msg {
2148            ControllerMessage::FetchResult {
2149                value: Ok(_),
2150                seq: actual_seq,
2151            } => assert_eq!(seq, actual_seq),
2152            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2153        };
2154    }
2155
2156    #[async_timed_test(timeout_secs = 60)]
2157    async fn test_define_recording_other_recording_active() -> Result<()> {
2158        let test_setup = TestSetup::new().await?;
2159        test_setup
2160            .stream_actor
2161            .define_recording(&test_setup.client, 0.into())
2162            .await?;
2163        test_setup
2164            .stream_actor
2165            .define_recording(&test_setup.client, 1.into())
2166            .await?;
2167        assert_actor_failed_with_msg(
2168            &mut test_setup.stream_actor.status(),
2169            "different recording already active".into(),
2170        )
2171        .await;
2172        Ok(())
2173    }
2174
2175    #[async_timed_test(timeout_secs = 60)]
2176    async fn test_define_recording_already_defined() -> Result<()> {
2177        let test_setup = TestSetup::new().await?;
2178        test_setup
2179            .stream_actor
2180            .define_recording(&test_setup.client, 0.into())
2181            .await?;
2182        test_setup
2183            .stream_actor
2184            .finalize_recording(&test_setup.client, 0.into())
2185            .await?;
2186        test_setup
2187            .stream_actor
2188            .define_recording(&test_setup.client, 0.into())
2189            .await?;
2190        assert_actor_failed_with_msg(
2191            &mut test_setup.stream_actor.status(),
2192            "already defined".into(),
2193        )
2194        .await;
2195        Ok(())
2196    }
2197
2198    #[async_timed_test(timeout_secs = 60)]
2199    async fn test_finalize_recording_other_recording_active() -> Result<()> {
2200        let test_setup = TestSetup::new().await?;
2201        test_setup
2202            .stream_actor
2203            .define_recording(&test_setup.client, 0.into())
2204            .await?;
2205        test_setup
2206            .stream_actor
2207            .finalize_recording(&test_setup.client, 1.into())
2208            .await?;
2209        assert_actor_failed_with_msg(
2210            &mut test_setup.stream_actor.status(),
2211            "cannot finalize recording that isn't active".into(),
2212        )
2213        .await;
2214        Ok(())
2215    }
2216
2217    #[async_timed_test(timeout_secs = 60)]
2218    async fn test_recording_formal_outside_recording() -> Result<()> {
2219        let test_setup = TestSetup::new().await?;
2220        test_setup
2221            .stream_actor
2222            .recording_formal(&test_setup.client, 0.into(), 0)
2223            .await?;
2224        assert_actor_failed_with_msg(
2225            &mut test_setup.stream_actor.status(),
2226            "recording_formal called outside of recording".into(),
2227        )
2228        .await;
2229        Ok(())
2230    }
2231
2232    #[async_timed_test(timeout_secs = 60)]
2233    async fn test_recording_result_outside_recording() -> Result<()> {
2234        let test_setup = TestSetup::new().await?;
2235        test_setup
2236            .stream_actor
2237            .recording_result(&test_setup.client, 0.into(), 0)
2238            .await?;
2239        assert_actor_failed_with_msg(
2240            &mut test_setup.stream_actor.status(),
2241            "recording_result called outside of recording".into(),
2242        )
2243        .await;
2244        Ok(())
2245    }
2246
2247    #[async_timed_test(timeout_secs = 60)]
2248    async fn test_call_recording_other_recording_active() -> Result<()> {
2249        let test_setup = TestSetup::new().await?;
2250        test_setup
2251            .stream_actor
2252            .define_recording(&test_setup.client, 0.into())
2253            .await?;
2254        test_setup
2255            .stream_actor
2256            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2257            .await?;
2258        assert_actor_failed_with_msg(
2259            &mut test_setup.stream_actor.status(),
2260            "cannot call recording while another recording is active".into(),
2261        )
2262        .await;
2263        Ok(())
2264    }
2265
2266    #[async_timed_test(timeout_secs = 60)]
2267    async fn test_call_recording_not_found() -> Result<()> {
2268        let test_setup = TestSetup::new().await?;
2269        test_setup
2270            .stream_actor
2271            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2272            .await?;
2273        assert_actor_failed_with_msg(&mut test_setup.stream_actor.status(), "not found".into())
2274            .await;
2275        Ok(())
2276    }
2277
2278    #[async_timed_test(timeout_secs = 60)]
2279    async fn test_recording_formal_too_few_arguments() -> Result<()> {
2280        let test_setup = TestSetup::new().await?;
2281
2282        test_setup
2283            .stream_actor
2284            .define_recording(&test_setup.client, 0.into())
2285            .await?;
2286
2287        test_setup
2288            .stream_actor
2289            .recording_formal(&test_setup.client, 1.into(), 0)
2290            .await?;
2291
2292        test_setup
2293            .stream_actor
2294            .finalize_recording(&test_setup.client, 0.into())
2295            .await?;
2296
2297        test_setup
2298            .stream_actor
2299            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2300            .await?;
2301
2302        assert_actor_failed_with_msg(
2303            &mut test_setup.stream_actor.status(),
2304            "recording_formal called with too few arguments".into(),
2305        )
2306        .await;
2307        Ok(())
2308    }
2309
2310    #[async_timed_test(timeout_secs = 60)]
2311    async fn test_recording_result_too_few_results() -> Result<()> {
2312        let test_setup = TestSetup::new().await?;
2313
2314        test_setup
2315            .stream_actor
2316            .define_recording(&test_setup.client, 0.into())
2317            .await?;
2318
2319        test_setup
2320            .stream_actor
2321            .recording_result(&test_setup.client, 1.into(), 0)
2322            .await?;
2323
2324        test_setup
2325            .stream_actor
2326            .finalize_recording(&test_setup.client, 0.into())
2327            .await?;
2328
2329        test_setup
2330            .stream_actor
2331            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2332            .await?;
2333
2334        assert_actor_failed_with_msg(
2335            &mut test_setup.stream_actor.status(),
2336            "recording_result called with too few results".into(),
2337        )
2338        .await;
2339        Ok(())
2340    }
2341
2342    #[async_timed_test(timeout_secs = 60)]
2343    async fn test_basic_call_recording() -> Result<()> {
2344        let mut test_setup = TestSetup::new().await?;
2345
2346        // Define a recording equivalent to:
2347        // def f(x, y):
2348        //   return y, x
2349        test_setup
2350            .stream_actor
2351            .define_recording(&test_setup.client, 0.into())
2352            .await?;
2353
2354        let formal0_ref = 1.into();
2355        let formal0_index = 1;
2356        test_setup
2357            .stream_actor
2358            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2359            .await?;
2360
2361        let formal1_ref = 2.into();
2362        let formal1_index = 0;
2363        test_setup
2364            .stream_actor
2365            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2366            .await?;
2367
2368        let result0_ref = formal0_ref;
2369        let result0_index = 0;
2370        test_setup
2371            .stream_actor
2372            .recording_result(&test_setup.client, result0_ref, result0_index)
2373            .await?;
2374
2375        let result1_ref = formal1_ref;
2376        let result1_index = 1;
2377        test_setup
2378            .stream_actor
2379            .recording_result(&test_setup.client, result1_ref, result1_index)
2380            .await?;
2381
2382        test_setup
2383            .stream_actor
2384            .finalize_recording(&test_setup.client, 0.into())
2385            .await?;
2386
2387        let actual0_ref = 3.into();
2388        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2389
2390        let actual1_ref = 4.into();
2391        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2392
2393        // Call the recording with valid tensors for the actual inputs,
2394        // and store the results in refs 5 and 6.
2395        let actual_result0_ref = 5.into();
2396        let actual_result1_ref = 6.into();
2397        test_setup
2398            .stream_actor
2399            .call_recording(
2400                &test_setup.client,
2401                0.into(),
2402                0.into(),
2403                vec![actual_result0_ref, actual_result1_ref],
2404                vec![actual0_ref, actual1_ref],
2405            )
2406            .await?;
2407
2408        // Ensure the results are correct.
2409        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2410        assert!(
2411            test_setup
2412                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2413                .await
2414        );
2415
2416        // Ensure the temporary refs associated with the formals/results have
2417        // been deleted.
2418        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2419        Ok(())
2420    }
2421
2422    #[async_timed_test(timeout_secs = 60)]
2423    async fn test_request_status_in_recording() -> Result<()> {
2424        let test_setup = TestSetup::new().await?;
2425        test_setup
2426            .stream_actor
2427            .define_recording(&test_setup.client, 0.into())
2428            .await?;
2429        test_setup
2430            .stream_actor
2431            .request_status(&test_setup.client)
2432            .await
2433            .expect_err("request_status should have failed");
2434        assert_actor_failed_with_msg(
2435            &mut test_setup.stream_actor.status(),
2436            "request_status not allowed in recording".into(),
2437        )
2438        .await;
2439        Ok(())
2440    }
2441
2442    #[async_timed_test(timeout_secs = 60)]
2443    async fn test_init_comm_in_recording() -> Result<()> {
2444        let test_setup = TestSetup::new().await?;
2445        test_setup
2446            .stream_actor
2447            .define_recording(&test_setup.client, 0.into())
2448            .await?;
2449
2450        let dummy_comm = test_setup.proc.spawn(
2451            "comm",
2452            NcclCommActor::new(CommParams::New {
2453                device: CudaDevice::new(0.into()),
2454                unique_id: UniqueId::new_nccl()?,
2455                world_size: 1,
2456                rank: 0,
2457            })
2458            .await
2459            .unwrap(),
2460        )?;
2461
2462        test_setup
2463            .stream_actor
2464            .init_comm(&test_setup.client, dummy_comm)
2465            .await?;
2466        assert_actor_failed_with_msg(
2467            &mut test_setup.stream_actor.status(),
2468            "init_comm not allowed in recording".into(),
2469        )
2470        .await;
2471        Ok(())
2472    }
2473
2474    #[async_timed_test(timeout_secs = 60)]
2475    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2476        let mut test_setup = TestSetup::new().await?;
2477        test_setup
2478            .stream_actor
2479            .define_recording(&test_setup.client, 0.into())
2480            .await?;
2481
2482        let borrow_id = 1;
2483        let tensor_ref = test_setup.next_ref();
2484        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2485
2486        test_setup
2487            .stream_actor
2488            .borrow_create(
2489                &test_setup.client,
2490                borrow_id,
2491                tensor_ref,
2492                first_use_sender.clone(),
2493            )
2494            .await?;
2495
2496        test_setup
2497            .stream_actor
2498            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2499            .await?;
2500
2501        assert_actor_failed_with_msg(
2502            &mut test_setup.stream_actor.status(),
2503            "duplicate borrow create in recording".into(),
2504        )
2505        .await;
2506
2507        Ok(())
2508    }
2509
2510    #[async_timed_test(timeout_secs = 60)]
2511    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2512        let test_setup = TestSetup::new().await?;
2513        test_setup
2514            .stream_actor
2515            .define_recording(&test_setup.client, 0.into())
2516            .await?;
2517
2518        let borrow_id = 1;
2519        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2520
2521        test_setup
2522            .stream_actor
2523            .borrow_drop(
2524                &test_setup.client,
2525                borrow_id,
2526                Arc::new(Mutex::new(last_use_receiver)),
2527            )
2528            .await?;
2529
2530        assert_actor_failed_with_msg(
2531            &mut test_setup.stream_actor.status(),
2532            "borrow drop for borrow not defined in recording".into(),
2533        )
2534        .await;
2535
2536        Ok(())
2537    }
2538
2539    #[async_timed_test(timeout_secs = 60)]
2540    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2541        let mut test_setup = TestSetup::new().await?;
2542        test_setup
2543            .stream_actor
2544            .define_recording(&test_setup.client, 0.into())
2545            .await?;
2546
2547        let borrow_id = 1;
2548        let tensor_ref = test_setup.next_ref();
2549        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2550
2551        test_setup
2552            .stream_actor
2553            .borrow_create(
2554                &test_setup.client,
2555                borrow_id,
2556                tensor_ref,
2557                first_use_sender.clone(),
2558            )
2559            .await?;
2560
2561        // Attempt to finalize the recording without dropping the borrow
2562        test_setup
2563            .stream_actor
2564            .finalize_recording(&test_setup.client, 0.into())
2565            .await?;
2566
2567        assert_actor_failed_with_msg(
2568            &mut test_setup.stream_actor.status(),
2569            "all borrows created within recording must be dropped within recording".into(),
2570        )
2571        .await;
2572
2573        Ok(())
2574    }
2575}