monarch_tensor_worker/
stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::Entry;
13use std::future::Future;
14use std::sync::Arc;
15use std::sync::OnceLock;
16use std::time::Duration;
17
18use anyhow::Result;
19use anyhow::anyhow;
20use anyhow::bail;
21use anyhow::ensure;
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorId;
25use hyperactor::ActorRef;
26use hyperactor::Context;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::PortHandle;
31use hyperactor::actor::ActorHandle;
32use hyperactor::forward;
33use hyperactor::mailbox::OncePortHandle;
34use hyperactor::mailbox::PortReceiver;
35use hyperactor::proc::Proc;
36use monarch_hyperactor::actor::PythonMessage;
37use monarch_hyperactor::actor::PythonMessageKind;
38use monarch_hyperactor::buffers::Buffer;
39use monarch_hyperactor::local_state_broker::BrokerId;
40use monarch_hyperactor::local_state_broker::LocalState;
41use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
42use monarch_messages::controller::ControllerMessageClient;
43use monarch_messages::controller::Seq;
44use monarch_messages::controller::WorkerError;
45use monarch_messages::worker::ActorCallParams;
46use monarch_messages::worker::ActorMethodParams;
47use monarch_messages::worker::ArgsKwargs;
48use monarch_messages::worker::CallFunctionError;
49use monarch_messages::worker::CallFunctionParams;
50use monarch_messages::worker::SeqError;
51use monarch_messages::worker::StreamRef;
52use monarch_types::PyTree;
53use monarch_types::SerializablePyErr;
54use monarch_types::TryIntoPyObjectUnsafe;
55use pyo3::prelude::*;
56use tokio::runtime::Handle;
57use tokio::sync::Mutex;
58use tokio::task::JoinHandle;
59use torch_sys_cuda::cuda::Event;
60use torch_sys_cuda::cuda::Stream;
61use torch_sys2::CloneUnsafe;
62use torch_sys2::CudaDevice;
63use torch_sys2::TensorCell;
64use torch_sys2::deep_clone;
65use torch_sys2::factory_empty;
66use torch_sys2::factory_zeros;
67use tracing_subscriber::fmt::Subscriber;
68use typeuri::Named;
69
70use crate::ControllerActor;
71use crate::DeviceMesh;
72use crate::Factory;
73use crate::Reduction;
74use crate::Ref;
75use crate::ResolvableFunction;
76use crate::StreamCreationMode;
77use crate::WireValue;
78use crate::comm::CommMessage;
79use crate::comm::CommMessageClient;
80use crate::comm::NcclCommActor;
81
82pub type TensorCellResult = Result<TensorCell, Arc<SeqError>>;
83
// These thread locals are accessed by the Python runtime for debugging
// sessions. Because they are per-thread, they must be populated on the thread
// that runs the stream (and therefore the Python code), which happens in
// `StreamActor::init`.
thread_local! {
    /// Controller actor that created the stream running on this thread.
    pub static CONTROLLER_ACTOR_REF: OnceCell<ActorRef<ControllerActor>> = const { OnceCell::new() };
    /// Proc hosting the stream actor running on this thread.
    pub static PROC: OnceCell<Proc> = const { OnceCell::new() };
    /// Root actor id built from the stream actor's proc id and name.
    pub static ROOT_ACTOR_ID: OnceCell<ActorId> = const { OnceCell::new() };
}
90
91fn pickle_python_result(
92    py: Python<'_>,
93    result: Bound<'_, PyAny>,
94    worker_rank: usize,
95) -> Result<PythonMessage, anyhow::Error> {
96    let pickle = py
97        .import("monarch._src.actor.actor_mesh")
98        .unwrap()
99        .getattr("_pickle")
100        .unwrap();
101    let mut data: Buffer = pickle
102        .call1((result,))
103        .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
104        .extract()
105        .unwrap();
106    Ok(PythonMessage::new_from_buf(
107        PythonMessageKind::Result {
108            rank: Some(worker_rank),
109        },
110        data.take_part(),
111    ))
112}
113
114#[derive(Debug)]
115struct Recording {
116    messages: Vec<StreamMessage>,
117}
118
119impl Recording {
120    fn new() -> Self {
121        Self {
122            messages: Vec::new(),
123        }
124    }
125}
126
/// Recording activity currently in progress on a stream.
///
/// Held as `Option<RecordingState>`; `None` means no recording is active.
#[derive(Debug, PartialEq)]
enum RecordingState {
    /// A recording is being defined.
    Defining {
        /// Ref identifying the recording under definition.
        recording: Ref,
        /// Set of borrow ids used to track proper borrow usage inside
        /// a recording.
        defined_borrows: HashSet<u64>,
    },
    /// A recording is running (presumably a previously defined one being
    /// executed; confirm against the `CallRecording` handler).
    Running,
}
137
/// Messages handled by the stream. Generally these are stream-local versions of
/// [`crate::WorkerMessage`].
#[derive(Handler, HandleClient, Debug, Named)]
pub enum StreamMessage {
    /// Call a Python function described by the [`CallFunctionParams`].
    ///
    /// The second field maps refs to [`DeviceMesh`]es. The third maps refs to
    /// remote process groups given as `(mesh, dimension names, communicator)`.
    /// Both have the same shape as the maps accepted by `call_python_fn`,
    /// which substitutes matching argument refs.
    CallFunction(
        CallFunctionParams,
        HashMap<Ref, DeviceMesh>,
        HashMap<Ref, (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>)>,
    ),

    /// Begin lending `tensor` under borrow id `borrow`. The tensor and an
    /// optional first-use CUDA event are handed to the borrower through
    /// `first_use_sender`.
    BorrowCreate {
        /// Id for the borrow.
        borrow: u64,
        /// Tensor to borrow.
        tensor: Ref,
        /// Port for sending the first use CUDA event + borrowed tensor to
        /// the borrower.
        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// First use of a borrow on the borrowing stream. The tensor received
    /// through `first_use_receiver` is stored at `result`.
    BorrowFirstUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for storing the borrowed tensor.
        result: Ref,
        /// Port for receiving the first use CUDA event + borrowed tensor from
        /// the provider stream.
        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Last use of a borrow on the borrowing stream. The tensor at `result` and
    /// an optional CUDA event are sent through `last_use_sender`.
    BorrowLastUse {
        /// Id for the borrow.
        borrow: u64,
        /// Ref for the borrowed tensor.
        result: Ref,
        /// Port for sending the last use CUDA event and borrowed tensor.
        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
    },

    /// End of a borrow. The last-use CUDA event and tensor arrive through
    /// `last_use_receiver`.
    BorrowDrop {
        /// Id for the borrow.
        borrow: u64,
        /// Port for receiving the last use CUDA event and borrowed tensor.
        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
    },

    /// Delete the listed refs.
    DeleteRefs(Vec<Ref>),

    /// Status request, answered on the provided reply port.
    RequestStatus(#[reply] OncePortHandle<()>),

    /// Provide the NCCL communicator actor for this stream (the stream's
    /// `comm` field starts out unset).
    InitComm(ActorHandle<NcclCommActor>),

    /// Collective reduction of `local_tensor` over `comm`.
    ///
    /// Defines `result`. Mutates `local_tensor` when `in_place`; otherwise it
    /// mutates `out` when that is set.
    Reduce {
        comm: Arc<ActorHandle<NcclCommActor>>,
        dim_size: i64,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    },

    /// Point-to-point tensor transfer over `comm`. Defines `result` on this
    /// stream only when `from_rank` is set.
    SendTensor {
        result: Ref,
        from_rank: Option<usize>,
        to_rank: Option<usize>,
        tensor: Ref,
        factory: Factory,
        comm: Arc<ActorHandle<NcclCommActor>>,
    },

    /// Compute a value for `seq` from `args_kwargs`, optionally through
    /// `function`, and send it on. NOTE(review): where the value is sent is
    /// decided by the handler, which is outside this view; confirm there.
    SendValue {
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    },

    /// Start defining the recording identified by `recording`.
    DefineRecording {
        recording: Ref,
    },

    /// Finish defining the recording identified by `recording`.
    FinalizeRecording {
        recording: Ref,
    },

    /// Run the recording `recording` for `seq`, with `actuals` as its inputs
    /// and `results` as its outputs (presumably matched by formal/output
    /// index; confirm against the handler).
    CallRecording {
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    },

    /// Inside a recording definition: bind formal argument number
    /// `argument_index` to `result`.
    RecordingFormal {
        result: Ref,
        argument_index: usize,
    },

    /// Inside a recording definition: expose `result` as output number
    /// `output_index`.
    RecordingResult {
        result: Ref,
        output_index: usize,
    },

    /// Test-only: bind a ref to the given wire value.
    SetRefUnitTestsOnly(Ref, WireValue),

    /// Test-only: bind a ref to the given tensor (or error).
    SetTensorRefUnitTestsOnly(Ref, TensorCellResult),

    /// Test-only: read back the value of a ref.
    GetRefUnitTestsOnly(
        Ref, // value
        #[reply] OncePortHandle<Option<Result<WireValue, String>>>,
    ),

    /// Test-only: read back the tensor bound to a ref.
    GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),

    /// Actor-call message carrying [`ActorCallParams`].
    SendResultOfActorCall(ActorCallParams),
    /// Actor-method message carrying [`ActorMethodParams`].
    CallActorMethod(ActorMethodParams),
}
258
259impl StreamMessage {
260    fn clone_for_recording(&self) -> Self {
261        match self {
262            StreamMessage::RecordingFormal {
263                result,
264                argument_index,
265            } => StreamMessage::RecordingFormal {
266                result: *result,
267                argument_index: *argument_index,
268            },
269            StreamMessage::RecordingResult {
270                result,
271                output_index,
272            } => StreamMessage::RecordingResult {
273                result: *result,
274                output_index: *output_index,
275            },
276            StreamMessage::DeleteRefs(refs) => StreamMessage::DeleteRefs(refs.clone()),
277            StreamMessage::CallFunction(params, device_meshes, remote_process_groups) => {
278                StreamMessage::CallFunction(
279                    params.clone(),
280                    device_meshes.clone(),
281                    remote_process_groups.clone(),
282                )
283            }
284            StreamMessage::BorrowCreate {
285                borrow,
286                tensor,
287                first_use_sender,
288            } => StreamMessage::BorrowCreate {
289                borrow: *borrow,
290                tensor: *tensor,
291                first_use_sender: first_use_sender.clone(),
292            },
293            StreamMessage::BorrowFirstUse {
294                borrow,
295                result,
296                first_use_receiver,
297            } => StreamMessage::BorrowFirstUse {
298                borrow: *borrow,
299                result: *result,
300                first_use_receiver: first_use_receiver.clone(),
301            },
302            StreamMessage::BorrowLastUse {
303                borrow,
304                result,
305                last_use_sender,
306            } => StreamMessage::BorrowLastUse {
307                borrow: *borrow,
308                result: *result,
309                last_use_sender: last_use_sender.clone(),
310            },
311            StreamMessage::BorrowDrop {
312                borrow,
313                last_use_receiver,
314            } => StreamMessage::BorrowDrop {
315                borrow: *borrow,
316                last_use_receiver: last_use_receiver.clone(),
317            },
318            StreamMessage::Reduce {
319                comm,
320                dim_size,
321                result,
322                local_tensor,
323                factory,
324                reduction,
325                scatter,
326                in_place,
327                out,
328            } => StreamMessage::Reduce {
329                comm: comm.clone(),
330                dim_size: *dim_size,
331                result: *result,
332                local_tensor: *local_tensor,
333                factory: factory.clone(),
334                reduction: reduction.clone(),
335                scatter: *scatter,
336                in_place: *in_place,
337                out: out.clone(),
338            },
339            StreamMessage::SendTensor {
340                result,
341                from_rank,
342                to_rank,
343                tensor,
344                factory,
345                comm,
346            } => StreamMessage::SendTensor {
347                result: *result,
348                from_rank: *from_rank,
349                to_rank: *to_rank,
350                tensor: *tensor,
351                factory: factory.clone(),
352                comm: comm.clone(),
353            },
354            other => panic!(
355                "StreamMessage variant not supported in recording: {:?}",
356                other
357            ),
358        }
359    }
360
361    // Get the set of refs that this message defines.
362    fn get_defined_refs(&self) -> HashSet<Ref> {
363        match self {
364            StreamMessage::RecordingFormal { result, .. } => HashSet::from([*result]),
365            StreamMessage::CallFunction(params, ..) => {
366                params.results.iter().filter_map(|&ref_| ref_).collect()
367            }
368            StreamMessage::BorrowFirstUse { result, .. } => HashSet::from([*result]),
369            StreamMessage::Reduce { result, .. } => HashSet::from([*result]),
370            StreamMessage::SendTensor {
371                result, from_rank, ..
372            } => {
373                if from_rank.is_some() {
374                    HashSet::from([*result])
375                } else {
376                    HashSet::new()
377                }
378            }
379            // TODO(slurye): Add SendValue eventually.
380            _ => HashSet::new(),
381        }
382    }
383
384    // Get the set of refs that this message mutates.
385    fn get_mutated_refs(&self) -> HashSet<Ref> {
386        match self {
387            StreamMessage::CallFunction(params, ..) => HashSet::from_iter(params.mutates.clone()),
388            StreamMessage::Reduce {
389                out,
390                in_place,
391                local_tensor,
392                ..
393            } => {
394                if *in_place {
395                    HashSet::from([*local_tensor])
396                } else if let Some(out) = out {
397                    HashSet::from([*out])
398                } else {
399                    HashSet::new()
400                }
401            }
402            // TODO(slurye): Add SendValue eventually.
403            _ => HashSet::new(),
404        }
405    }
406}
407
/// A stream represents a linear sequence of execution. Operations on different
/// streams can execute concurrently.
///
/// For CUDA operators, streams will invoke the corresponding stream management
/// APIs to perform synchronization.
///
/// For CPU operators, streams will just execute synchronously on their own OS
/// thread.
#[derive(Debug)]
pub struct StreamActor {
    /// World size from [`StreamParams`]. The field is underscore-prefixed,
    /// presumably because the stream does not currently read it.
    _world_size: usize,
    /// Rank from [`StreamParams`].
    rank: usize,
    /// Mapping of refs in the controller environment to values in this
    /// stream's local environment. Each value is either the Python object for
    /// the ref, or the shared [`SeqError`] that prevented it from being
    /// computed.
    // TODO(agallagher): Use `ValueError` as the error type.
    env: HashMap<Ref, Result<PyObject, Arc<SeqError>>>,
    /// How to create the stream.
    creation_mode: StreamCreationMode,
    /// CUDA stream that this actor will enqueue operations on. None if
    /// `device` is `None`.
    /// NOTE: We lazily create the stream, so that we do it from the dedicated
    /// Stream OS thread as, otherwise, we see deadlocks when done from
    /// unexpected threads.
    cuda_stream: OnceLock<Option<Stream>>,
    /// Device this stream should be scheduled on.
    device: Option<CudaDevice>,
    /// Communicator for this stream. Optional as we lazily initialize it.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Actor ref of the controller that created this stream.
    controller_actor: ActorRef<ControllerActor>,
    /// Python process-group objects cached by ref. `call_python_fn`
    /// substitutes them for matching argument refs.
    remote_process_groups: HashMap<Ref, PyObject>,
    /// Recordings known to this stream, keyed by recording ref.
    recordings: HashMap<Ref, Recording>,
    /// Recording currently being defined or run, if any. While this is set,
    /// `report_seq_error` does not notify the controller of new errors.
    active_recording: Option<RecordingState>,
    /// Flag passed in via [`StreamParams`]. NOTE(review): presumably selects
    /// whether results are reported as [`PythonMessage`]s; confirm in the
    /// handlers.
    respond_with_python_message: bool,
    /// Most recent error created by `report_seq_error`.
    last_seq_error: Option<Arc<SeqError>>,
}
444
/// Parameters for creating a [`StreamActor`].
#[derive(Debug, Clone)]
pub struct StreamParams {
    /// World size (presumably the total number of ranks).
    pub world_size: usize,
    /// Rank of the worker that owns the stream.
    pub rank: usize,
    /// Controls how the underlying CUDA stream is created.
    pub creation_mode: StreamCreationMode,
    /// Id of this stream in the worker actor's stream table.
    pub id: StreamRef,
    /// Device this stream should be scheduled on. If none, don't do stream
    /// synchronization.
    pub device: Option<CudaDevice>,
    /// Actor ref of the controller that created this stream.
    pub controller_actor: ActorRef<ControllerActor>,
    /// NOTE(review): presumably whether results are reported back as
    /// [`PythonMessage`]s; confirm against the stream's handlers.
    pub respond_with_python_message: bool,
}
461
462impl StreamActor {
463    pub fn new(
464        StreamParams {
465            world_size,
466            rank,
467            id: _,
468            device,
469            controller_actor,
470            creation_mode,
471            respond_with_python_message,
472        }: StreamParams,
473    ) -> Self {
474        Self {
475            _world_size: world_size,
476            rank,
477            env: HashMap::new(),
478            creation_mode,
479            cuda_stream: OnceLock::new(),
480            device,
481            comm: None,
482            controller_actor,
483            remote_process_groups: HashMap::new(),
484            recordings: HashMap::new(),
485            active_recording: None,
486            respond_with_python_message,
487            last_seq_error: None,
488        }
489    }
490}
491
492#[async_trait]
493impl Actor for StreamActor {
494    async fn init(&mut self, cx: &Instance<Self>) -> Result<()> {
495        // These thread locals are exposed via python functions, so we need to set them in the
496        // same thread that python will run in. That means we need to initialize them here in
497        // StreamActor::init instead of in StreamActor::new.
498        CONTROLLER_ACTOR_REF.with(|controller_actor_ref| {
499            controller_actor_ref.set(self.controller_actor.clone()).ok()
500        });
501        PROC.with(|proc| proc.set(cx.proc().clone()).ok());
502        ROOT_ACTOR_ID.with(|root_actor_id| {
503            root_actor_id
504                .set(ActorId::root(
505                    cx.self_id().proc_id().clone(),
506                    cx.self_id().name().to_string(),
507                ))
508                .ok()
509        });
510        // Set the current stream for this actor thread.
511        if let Some(stream) = self.cuda_stream() {
512            Stream::set_current_stream(stream);
513        }
514        Ok(())
515    }
516
517    /// Specialize spawn_server_task for StreamActor, because we want to run the stream on a
518    /// dedicated OS thread. This is because:
519    ///   - Streams do expensive blocking CPU operations (like calling CPU kernels).
520    ///   - Torch/CUDA make use of thread-local state, so moving tasks across
521    ///     threads is problematic.
522    fn spawn_server_task<F>(future: F) -> JoinHandle<F::Output>
523    where
524        F: Future + Send + 'static,
525        F::Output: Send + 'static,
526    {
527        let (join_tx, join_rx) = tokio::sync::oneshot::channel();
528        // It is important that we spawn a standalone thread for the work here,
529        // as opposed to using `spawn_blocking` to spawn a tokio-managed thread.
530        // This is because the worker stream may call uninterruptible FFI code
531        // that can deadlock (CUDA, NCCL).
532        // If we use a tokio-managed blocking thread, then runtime teardown will
533        // try to wait for tasks on that thread to reach an await point, and
534        // hang forever.
535        let builder = std::thread::Builder::new().name("worker-stream".to_string());
536        let _thread_handle = builder.spawn(move || {
537            // Spawn a new thread with a single-threaded tokio runtime to run the
538            // actor loop.  We avoid the current-threaded runtime, so that we can
539            // use `block_in_place` for nested async-to-sync-to-async flows.
540            let rt = tokio::runtime::Builder::new_multi_thread()
541                .worker_threads(1)
542                .enable_all()
543                .build()
544                .unwrap();
545            let result = rt.block_on(async {
546                tokio::task::block_in_place(|| {
547                    // Allow e.g. destructing py objects on this thread, which
548                    // can happen at shutdown when the a stream actors env map
549                    // for rvalues is dropped (e.g. P1673311499).
550                    // https://github.com/PyO3/pyo3/discussions/3499
551                    Python::with_gil(|py| {
552                        py.allow_threads(|| {
553                            let result = Handle::current().block_on(future);
554                            if join_tx.send(result).is_err() {
555                                panic!("could not send join result")
556                            }
557                        })
558                    })
559                })
560            });
561            rt.shutdown_timeout(Duration::from_weeks(1));
562            result
563        });
564
565        // In order to bridge the synchronous join handle with the async world,
566        // smuggle the result through a channel.
567        tokio::spawn(async move { join_rx.await.unwrap() })
568    }
569}
570
/// The arguments we accept as inputs to Python function calls.
///
/// Every resolved argument leaf is currently an arbitrary Python object.
#[derive(Debug)]
enum PyArg {
    /// A Python object handed to the callee as a new reference to the same
    /// object.
    PyObject(PyObject),
}
576
577/// Serialize into a `PyObject`.
578impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &PyArg {
579    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
580        match self {
581            PyArg::PyObject(obj) => Ok(obj.clone_ref(py).into_bound(py)),
582        }
583    }
584}
585
586impl StreamActor {
587    fn tensor_to_pyobject(tensor_cell: TensorCell) -> PyObject {
588        Python::with_gil(|py| {
589            // SAFETY: Cloning a tensor was unsafe because we were tracking their references like
590            // Rust objects (single mutable reference or many immutable references). We are
591            // removing this functionality in upcoming patches, so we use the unsafe version here
592            // until that happens.
593            let tensor = unsafe {
594                // Get the owned tensor by calling clone_unsafe on the reference
595                tensor_cell.get_unchecked().clone_unsafe()
596            };
597            tensor.into_pyobject(py).unwrap().unbind()
598        })
599    }
600
601    /// Extract a TensorCell from a PyObject.
602    /// SAFETY: Uses new to create the TensorCell. Caller must ensure the PyObject
603    /// contains a valid tensor.
604    fn pyobject_to_tensor(py: Python<'_>, pyobj: &PyObject) -> PyResult<TensorCell> {
605        use torch_sys2::Tensor;
606        let tensor = pyobj.bind(py).extract::<Tensor>()?;
607        // Create a new TensorCell from the extracted tensor
608        Ok(TensorCell::new(tensor))
609    }
610
611    fn cuda_stream(&self) -> Option<&Stream> {
612        self.cuda_stream
613            .get_or_init(|| {
614                self.device.map(|device| match self.creation_mode {
615                    StreamCreationMode::UseDefaultStream => {
616                        Stream::get_current_stream_on_device(device)
617                    }
618                    StreamCreationMode::CreateNewStream => Stream::new_with_device(device),
619                })
620            })
621            .as_ref()
622    }
623
624    fn ref_to_pyobject(&self, ref_: &Ref) -> Result<PyObject, CallFunctionError> {
625        let pyobject = self
626            .env
627            .get(ref_)
628            .ok_or_else(|| CallFunctionError::RefNotFound(*ref_))?;
629        match pyobject {
630            Ok(val) => Ok(val.clone()),
631            Err(err) => Err(CallFunctionError::DependentError(err.clone())),
632        }
633    }
634
635    async fn report_seq_error(
636        &mut self,
637        cx: &Context<'_, Self>,
638        seq: Seq,
639        error: CallFunctionError,
640    ) -> Result<Arc<SeqError>, anyhow::Error> {
641        match error {
642            CallFunctionError::DependentError(root) => Ok(root),
643            CallFunctionError::Error(e) => {
644                if self.active_recording.is_none() {
645                    let worker_error = WorkerError {
646                        backtrace: format!("{e}"),
647                        worker_actor_id: cx.self_id().clone(),
648                    };
649                    tracing::info!("Propagating remote function error to client: {worker_error}");
650                    self.controller_actor
651                        .remote_function_failed(cx, seq, worker_error)
652                        .await?
653                }
654                let err = Arc::new(SeqError { seq, error: e });
655                self.last_seq_error = Some(err.clone());
656                Ok(err)
657            }
658        }
659    }
660
661    async fn try_define<F>(
662        &mut self,
663        cx: &Context<'_, Self>,
664        seq: Seq,
665        result_refs: Vec<Option<Ref>>,
666        mutates: &Vec<Ref>,
667        f: F,
668    ) -> Result<()>
669    where
670        F: AsyncFnOnce(&mut Self) -> Result<Vec<PyObject>, CallFunctionError>,
671    {
672        let actual_results = f(self).await;
673        // Check if the expected number of returns is correct, otherwise convert
674        // into an error.
675        let op_results = actual_results.and_then(|actual_results| {
676            if result_refs.len() == actual_results.len() {
677                Ok(actual_results
678                    .into_iter()
679                    .zip(result_refs.iter())
680                    .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result)))
681                    .collect::<Vec<(Ref, PyObject)>>())
682            } else {
683                Err(CallFunctionError::UnexpectedNumberOfReturns(
684                    result_refs.len(),
685                    actual_results.len(),
686                ))
687            }
688        });
689
690        // Propagate the results (either the actual values or an error) to the
691        // right entries in the global env mapping.
692        match op_results {
693            Ok(op_results) => {
694                for (ref_, pyobject) in op_results.into_iter() {
695                    let prev = self.env.insert(ref_, Ok(pyobject));
696                    assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
697                }
698            }
699            Err(err) => {
700                let err = self.report_seq_error(cx, seq, err).await?;
701                for ref_ in result_refs {
702                    match ref_ {
703                        Some(ref_) => {
704                            let prev = self.env.insert(ref_, Err(err.clone()));
705                            assert!(prev.is_none(), "Duplicate write to reference: {:?}", ref_);
706                        }
707                        None => {}
708                    }
709                }
710                for ref_ in mutates {
711                    self.env.insert(*ref_, Err(err.clone()));
712                }
713            }
714        }
715        Ok(())
716    }
717
718    fn call_python_fn<'py>(
719        &mut self,
720        py: Python<'py>,
721        _cx: &Context<Self>,
722        function: Option<ResolvableFunction>,
723        args_kwargs: ArgsKwargs,
724        _mutates: &[Ref],
725        device_meshes: HashMap<Ref, DeviceMesh>,
726        remote_process_groups: HashMap<
727            Ref,
728            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
729        >,
730    ) -> Result<Bound<'py, PyAny>, CallFunctionError> {
731        let (args_tuple, kwargs_dict) = args_kwargs
732            .to_python(py)
733            .map_err(|e| CallFunctionError::Error(e.into()))?;
734        let function = function
735            .map(|function| {
736                function.resolve(py).map_err(|e| {
737                    CallFunctionError::InvalidRemoteFunction(format!(
738                        "failed to resolve function {}: {}",
739                        function,
740                        SerializablePyErr::from(py, &e)
741                    ))
742                })
743            })
744            .transpose()?;
745
746        let remote_process_groups = remote_process_groups
747            .into_iter()
748            .map(|(gref, (_mesh, _dims, _comm))| {
749                let group = match self.remote_process_groups.entry(gref) {
750                    Entry::Occupied(ent) => ent.get().clone_ref(py),
751                    Entry::Vacant(_ent) => {
752                        panic!("no longer implemented");
753                    }
754                };
755                PyResult::Ok((gref, group))
756            })
757            .collect::<Result<HashMap<_, _>, _>>()
758            .map_err(SerializablePyErr::from_fn(py))?;
759
760        let resolve = |val: Bound<'py, PyAny>| {
761            val.extract::<PyTree<PyObject>>()
762                .map_err(SerializablePyErr::from_fn(py))?
763                .try_into_map(|obj| {
764                    Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) {
765                        if let Some(mesh) = device_meshes.get(&ref_) {
766                            PyArg::PyObject(
767                                Py::new(py, mesh.clone())
768                                    .map_err(SerializablePyErr::from_fn(py))?
769                                    .into(),
770                            )
771                        } else if let Some(pg) = remote_process_groups.get(&ref_) {
772                            PyArg::PyObject(pg.clone_ref(py))
773                        } else {
774                            let pyobj = self.ref_to_pyobject(&ref_)?;
775                            PyArg::PyObject(pyobj)
776                        }
777                    } else {
778                        PyArg::PyObject(obj)
779                    })
780                })
781        };
782
783        // Resolve args and kwargs
784        let py_args: Vec<PyTree<PyArg>> = args_tuple
785            .iter()
786            .map(&resolve)
787            .collect::<Result<_, CallFunctionError>>()?;
788
789        let py_kwargs: HashMap<String, PyTree<PyArg>> = kwargs_dict
790            .iter()
791            .map(|(k, v)| {
792                let key = k
793                    .extract::<String>()
794                    .map_err(SerializablePyErr::from_fn(py))?;
795                let value = resolve(v)?;
796                Ok((key, value))
797            })
798            .collect::<Result<_, CallFunctionError>>()?;
799
800        // Call function.
801        // Use custom subscriber to route Worker messages to stdout.
802        let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
803        let result: Bound<'_, PyAny> =
804            tracing::subscriber::with_default(scoped_subscriber, || {
805                // TODO(agallagher): The args/kwargs conversion traits generate
806                // the appropriate types here, but they get casted to `PyAny`.
807                // It'd be nice to make `TryToPyObjectUnsafe` take a template
808                // arg for the converted py object to avoid this downcast.
809                // SAFETY: Tensor operations were unsafe because we were tracking their references
810                // like Rust objects (single mutable reference or many immutable references). We are
811                // removing this functionality in upcoming patches, so we use the unsafe version here
812                // until that happens.
813                let args = unsafe { py_args.try_to_object_unsafe(py) }
814                    .map_err(SerializablePyErr::from_fn(py))?;
815                // SAFETY: Same as above - reference tracking functionality is being removed.
816                let kwargs = &unsafe { py_kwargs.try_to_object_unsafe(py) }
817                    .map_err(SerializablePyErr::from_fn(py))?;
818
819                if let Some(function) = function {
820                    function
821                        .call(args, Some(kwargs))
822                        .map_err(SerializablePyErr::from_fn(py))
823                } else {
824                    Ok(args.get_item(0).unwrap())
825                }
826            })?;
827        Ok(result)
828    }
829
830    fn call_python_fn_pytree(
831        &mut self,
832        cx: &hyperactor::Context<Self>,
833        function: ResolvableFunction,
834        args_kwargs: ArgsKwargs,
835        mutates: &[Ref],
836        device_meshes: HashMap<Ref, DeviceMesh>,
837        remote_process_groups: HashMap<
838            Ref,
839            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
840        >,
841    ) -> Result<PyTree<PyObject>, CallFunctionError> {
842        Python::with_gil(|py| {
843            let result = self.call_python_fn(
844                py,
845                cx,
846                Some(function),
847                args_kwargs,
848                mutates,
849                device_meshes,
850                remote_process_groups,
851            )?;
852            Ok(PyTree::<PyObject>::extract_bound(&result)
853                .map_err(SerializablePyErr::from_fn(py))?)
854        })
855    }
856    /// Retrieve `ref_` or create a fake value with the provided factory if it
857    /// is an error. We use this for collective calls, where even if there was
858    /// an upstream failure, we still have participate in the collective to
859    /// avoid deadlocking the other ranks. It's okay to just put a nonsense
860    /// value here of the correct shape; the controller will have been notified
861    /// of the upstream failure and will know to ignore everything dependent on
862    /// it.
863    fn get_or_fake_on_err(&self, ref_: Ref, factory: &Factory) -> Result<TensorCell> {
864        let pyobject = self
865            .env
866            .get(&ref_)
867            .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
868
869        match pyobject {
870            Ok(val) => Python::with_gil(|py| {
871                Self::pyobject_to_tensor(py, val)
872                    .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))
873            }),
874            Err(_) => {
875                let t = factory_zeros(&factory.size, factory.dtype, factory.layout, factory.device);
876                Ok(TensorCell::new(t))
877            }
878        }
879    }
880
881    fn get_defining_recording(&mut self) -> Option<(&mut Recording, &mut HashSet<u64>)> {
882        self.active_recording
883            .as_mut()
884            .and_then(|state| match state {
885                RecordingState::Defining {
886                    recording,
887                    defined_borrows,
888                } => {
889                    match self.recordings.get_mut(recording) {
890                        Some(recording) => Some((recording, defined_borrows)),
891                        // Panic, because this would be a logic error in the program.
892                        None => panic!("recording not found: {:?}", recording),
893                    }
894                }
895                RecordingState::Running => None,
896            })
897    }
898
899    fn get_first_error(&self, refs: &[Ref]) -> Result<Option<Arc<SeqError>>> {
900        for ref_ in refs {
901            let rvalue_or_err = self
902                .env
903                .get(ref_)
904                .ok_or_else(|| anyhow!("tensor not found in stream: {ref_:#?}"))?;
905            if let Err(err) = rvalue_or_err {
906                return Ok(Some(err.clone()));
907            }
908        }
909        Ok(None)
910    }
    /// Evaluates `function` (or, when `function` is `None`, passes through the
    /// first positional argument) and sends the pickled result to the
    /// controller as a `PythonMessage` for `seq`.
    ///
    /// The work runs inside `try_define` with no result refs and an empty
    /// mutates list. Errors returned by the closure go back to `try_define`
    /// (not visible here), which is assumed to report them for `seq`. TODO
    /// confirm.
    ///
    /// NOTE(review): `mutates` is forwarded to `call_python_fn` but not to
    /// `try_define`. Confirm whether a failure should also be recorded on the
    /// mutated refs here; the non-python path in `send_value` does so.
    ///
    /// # Panics
    /// Panics if serializing the `PythonMessage` into `wirevalue::Any` fails.
    async fn send_value_python_message(
        &mut self,
        cx: &hyperactor::Context<'_, Self>,
        seq: Seq,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        // Copy the rank out up front; the closure below works through `self_`.
        let rank = self.rank;
        self.try_define(cx, seq, vec![], &vec![], async |self_| {
            let python_message =
                Python::with_gil(|py| -> Result<PythonMessage, CallFunctionError> {
                    // The Python call is synchronous; `block_in_place` tells the tokio
                    // runtime this worker thread is about to block.
                    let python_result = tokio::task::block_in_place(|| {
                        self_.call_python_fn(
                            py,
                            cx,
                            function,
                            args_kwargs,
                            &mutates,
                            device_meshes,
                            HashMap::new(),
                        )
                    })?;
                    pickle_python_result(py, python_result, rank).map_err(CallFunctionError::Error)
                })?;
            let ser = wirevalue::Any::serialize(&python_message).unwrap();
            self_
                .controller_actor
                .fetch_result(cx, seq, Ok(ser))
                .await?;
            // No result refs are defined by this operation.
            Ok(vec![])
        })
        .await
    }
946    fn define_ref(&mut self, dest: Ref, src: Ref) -> Result<(), anyhow::Error> {
947        let rvalue = self
948            .env
949            .get(&src)
950            .ok_or_else(|| CallFunctionError::RefNotFound(src))?;
951        self.env
952            .insert(dest, Python::with_gil(|_py| rvalue.clone()));
953        Ok(())
954    }
955    async fn call_actor(
956        &mut self,
957        cx: &Context<'_, Self>,
958        params: ActorCallParams,
959    ) -> Result<PyObject, CallFunctionError> {
960        let local_state: Result<Vec<PyObject>> = Python::with_gil(|_py| {
961            params
962                .local_state
963                .into_iter()
964                .map(|elem| {
965                    let pyobj = self.ref_to_pyobject(&elem)?;
966                    Ok(pyobj.into_any())
967                })
968                .collect()
969        });
970
971        let (send, recv) = cx.open_once_port();
972        let state = LocalState {
973            response_port: send,
974            state: local_state?,
975        };
976        let x: u64 = params.seq.into();
977        let message = LocalStateBrokerMessage::Set(x as usize, state);
978
979        let broker = BrokerId::new(params.broker_id).resolve(cx).await;
980        broker
981            .send(cx, message)
982            .map_err(|e| CallFunctionError::Error(e.into()))?;
983        let result = recv
984            .recv()
985            .await
986            .map_err(|e| CallFunctionError::Error(e.into()))?;
987
988        result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
989    }
990}
991
992#[async_trait]
993#[forward(StreamMessage)]
994impl StreamMessageHandler for StreamActor {
995    async fn call_function(
996        &mut self,
997        cx: &Context<Self>,
998        params: CallFunctionParams,
999        device_meshes: HashMap<Ref, DeviceMesh>,
1000        remote_process_groups: HashMap<
1001            Ref,
1002            (DeviceMesh, Vec<String>, Arc<ActorHandle<NcclCommActor>>),
1003        >,
1004    ) -> Result<()> {
1005        if let Some((recording, _)) = self.get_defining_recording() {
1006            recording.messages.push(StreamMessage::CallFunction(
1007                params,
1008                device_meshes,
1009                remote_process_groups,
1010            ));
1011            return Ok(());
1012        }
1013
1014        params.function.panic_if_requested();
1015        self.try_define(
1016            cx,
1017            params.seq,
1018            params.results,
1019            &params.mutates,
1020            async |self| {
1021                tokio::task::block_in_place(|| {
1022                    self.call_python_fn_pytree(
1023                        cx,
1024                        params.function,
1025                        params.args_kwargs,
1026                        &params.mutates,
1027                        device_meshes,
1028                        remote_process_groups,
1029                    )
1030                    .map(|results| results.into_leaves())
1031                })
1032            },
1033        )
1034        .await?;
1035        Ok(())
1036    }
1037
1038    async fn borrow_create(
1039        &mut self,
1040        _cx: &Context<Self>,
1041        borrow: u64,
1042        tensor: Ref,
1043        first_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1044    ) -> Result<()> {
1045        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1046            recording.messages.push(StreamMessage::BorrowCreate {
1047                borrow,
1048                tensor,
1049                first_use_sender,
1050            });
1051            ensure!(
1052                defined_borrows.insert(borrow),
1053                "duplicate borrow create in recording"
1054            );
1055            return Ok(());
1056        }
1057
1058        let pyobj_result = self
1059            .env
1060            .get(&tensor)
1061            .ok_or_else(|| anyhow!("invalid reference for borrow_create: {:#?}", tensor))?;
1062
1063        let result = match pyobj_result {
1064            Ok(pyobj) => Python::with_gil(|py| Ok(Self::pyobject_to_tensor(py, pyobj).unwrap())),
1065            Err(e) => Err(e.clone()),
1066        };
1067
1068        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1069        first_use_sender.send((event, result)).map_err(|err| {
1070            anyhow!(
1071                "failed sending first use event for borrow {:?}: {:?}",
1072                borrow,
1073                err
1074            )
1075        })
1076    }
1077
1078    async fn borrow_first_use(
1079        &mut self,
1080        _cx: &Context<Self>,
1081        borrow: u64,
1082        result: Ref,
1083        first_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1084    ) -> Result<()> {
1085        if let Some((recording, _)) = self.get_defining_recording() {
1086            recording.messages.push(StreamMessage::BorrowFirstUse {
1087                borrow,
1088                result,
1089                first_use_receiver: first_use_receiver.clone(),
1090            });
1091            return Ok(());
1092        }
1093
1094        let (first_use_event, cell) =
1095            first_use_receiver
1096                .lock()
1097                .await
1098                .recv()
1099                .await
1100                .map_err(|err| {
1101                    anyhow!(
1102                        "failed receiving first use event for borrow {:?}: {:?}",
1103                        borrow,
1104                        err
1105                    )
1106                })?;
1107
1108        if let Some(stream) = self.cuda_stream() {
1109            stream.wait_event(
1110                &mut first_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1111            );
1112        }
1113        match cell {
1114            Ok(cell) => {
1115                let pyobj = Self::tensor_to_pyobject(cell);
1116                self.env.insert(result, Ok(pyobj));
1117            }
1118            Err(err) => {
1119                self.env.insert(result, Err(err.clone()));
1120            }
1121        }
1122        Ok(())
1123    }
1124
1125    async fn borrow_last_use(
1126        &mut self,
1127        _cx: &Context<Self>,
1128        borrow: u64,
1129        result: Ref,
1130        last_use_sender: PortHandle<(Option<Event>, TensorCellResult)>,
1131    ) -> Result<()> {
1132        if let Some((recording, _)) = self.get_defining_recording() {
1133            recording.messages.push(StreamMessage::BorrowLastUse {
1134                borrow,
1135                result,
1136                last_use_sender,
1137            });
1138            return Ok(());
1139        }
1140
1141        let event = self.cuda_stream().map(|stream| stream.record_event(None));
1142        let pyobj_or_err = self.env.remove(&result).ok_or(anyhow!(
1143            "Invalid reference for borrow_last_use: {result:#?}"
1144        ))?;
1145        let tensor = match pyobj_or_err {
1146            Ok(pyobj) => Ok(Python::with_gil(|py| {
1147                Self::pyobject_to_tensor(py, &pyobj).unwrap()
1148            })),
1149            Err(e) => Err(e),
1150        };
1151
1152        last_use_sender.send((event, tensor)).map_err(|err| {
1153            anyhow!(
1154                "failed sending last use event for borrow {:?}: {:?}",
1155                borrow,
1156                err
1157            )
1158        })
1159    }
1160
1161    async fn borrow_drop(
1162        &mut self,
1163        _cx: &Context<Self>,
1164        borrow: u64,
1165        last_use_receiver: Arc<Mutex<PortReceiver<(Option<Event>, TensorCellResult)>>>,
1166    ) -> Result<()> {
1167        if let Some((recording, defined_borrows)) = self.get_defining_recording() {
1168            recording.messages.push(StreamMessage::BorrowDrop {
1169                borrow,
1170                last_use_receiver: last_use_receiver.clone(),
1171            });
1172            ensure!(
1173                defined_borrows.remove(&borrow),
1174                "borrow drop for borrow not defined in recording"
1175            );
1176            return Ok(());
1177        }
1178
1179        // The borrowed cell isn't used directly, but we still want to receive it here
1180        // so that the underlying tensor isn't dropped until after we synchronize the
1181        // CUDA streams.
1182        let (last_use_event, _cell) =
1183            last_use_receiver.lock().await.recv().await.map_err(|err| {
1184                anyhow!(
1185                    "failed receiving last use event for borrow {:?}: {:?}",
1186                    borrow,
1187                    err
1188                )
1189            })?;
1190
1191        if let Some(stream) = self.cuda_stream() {
1192            stream.wait_event(
1193                &mut last_use_event.expect("sent borrow to CUDA stream, expected a CUDA event"),
1194            );
1195        }
1196        // let the cell drop.
1197        Ok(())
1198    }
1199
1200    async fn delete_refs(&mut self, _cx: &Context<Self>, refs: Vec<Ref>) -> Result<()> {
1201        if let Some((recording, _)) = self.get_defining_recording() {
1202            recording.messages.push(StreamMessage::DeleteRefs(refs));
1203            return Ok(());
1204        }
1205
1206        for ref_ in refs.iter() {
1207            self.env.remove(ref_);
1208        }
1209        Ok(())
1210    }
1211
1212    async fn request_status(&mut self, _cx: &Context<Self>) -> Result<()> {
1213        if self.get_defining_recording().is_some() {
1214            bail!("request_status not allowed in recording");
1215        }
1216
1217        Ok(())
1218    }
1219
1220    async fn init_comm(
1221        &mut self,
1222        _cx: &Context<Self>,
1223        comm: ActorHandle<NcclCommActor>,
1224    ) -> Result<()> {
1225        if self.get_defining_recording().is_some() {
1226            bail!("init_comm not allowed in recording");
1227        }
1228
1229        self.comm = Some(comm);
1230        Ok(())
1231    }
1232
1233    async fn reduce(
1234        &mut self,
1235        cx: &Context<Self>,
1236        comm: Arc<ActorHandle<NcclCommActor>>,
1237        dim_size: i64,
1238        result: Ref,
1239        local_tensor: Ref,
1240        factory: Factory,
1241        reduction: Reduction,
1242        scatter: bool,
1243        in_place: bool,
1244        out: Option<Ref>,
1245    ) -> Result<()> {
1246        if let Some((recording, _)) = self.get_defining_recording() {
1247            recording.messages.push(StreamMessage::Reduce {
1248                comm,
1249                dim_size,
1250                result,
1251                local_tensor,
1252                factory,
1253                reduction,
1254                scatter,
1255                in_place,
1256                out,
1257            });
1258            return Ok(());
1259        }
1260
1261        let stream = self
1262            .cuda_stream()
1263            .expect("reductions not yet supported for non-CUDA workers")
1264            .clone();
1265        let input_cell = self.get_or_fake_on_err(local_tensor, &factory)?;
1266        let out_cell = out
1267            .map(|out| self.get_or_fake_on_err(out, &factory))
1268            .transpose()?;
1269        let output_cell = match reduction {
1270            Reduction::Stack => {
1271                if scatter {
1272                    let output_cell = if in_place {
1273                        input_cell.clone()
1274                    } else {
1275                        out_cell.unwrap_or({
1276                            let borrow = input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1277                            let cloned = deep_clone(&borrow);
1278                            TensorCell::new(cloned)
1279                        })
1280                    };
1281                    comm.all_to_all_single(cx, output_cell.clone(), input_cell, stream)
1282                        .await?;
1283                    output_cell
1284                } else {
1285                    ensure!(
1286                        !in_place,
1287                        "in-place, non-scatter not supported for stack reduce"
1288                    );
1289
1290                    let output_cell = out_cell.unwrap_or({
1291                        // In Python, this would be [dim_size, *factory.sizes]
1292                        let sizes = [&[dim_size][..], &factory.size[..]].concat();
1293                        let output =
1294                            factory_empty(&sizes, factory.dtype, factory.layout, factory.device);
1295                        TensorCell::new(output)
1296                    });
1297
1298                    comm.all_gather_into_tensor(cx, output_cell.clone(), input_cell, stream)
1299                        .await?;
1300                    output_cell
1301                }
1302            }
1303            Reduction::ReduceOp(op) => {
1304                if scatter {
1305                    ensure!(!in_place, "in-place, scatter not supported for reduce");
1306
1307                    let output_cell = out_cell.unwrap_or({
1308                        let output = factory_empty(
1309                            &factory.size[1..],
1310                            factory.dtype,
1311                            factory.layout,
1312                            factory.device,
1313                        );
1314                        TensorCell::new(output)
1315                    });
1316                    comm.reduce_scatter_tensor(cx, output_cell.clone(), input_cell, op, stream)
1317                        .await?;
1318                    output_cell
1319                } else {
1320                    let output_cell = if in_place {
1321                        input_cell.clone()
1322                    } else {
1323                        out_cell.map_or(
1324                            {
1325                                let borrow =
1326                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1327                                let cloned = deep_clone(&borrow);
1328                                Ok(TensorCell::new(cloned))
1329                            },
1330                            |out_cell| -> Result<_, anyhow::Error> {
1331                                let mut out_borrow =
1332                                    out_cell.try_borrow_mut().map_err(|e| anyhow!("{e:?}"))?;
1333                                let in_borrow =
1334                                    input_cell.try_borrow().map_err(|e| anyhow!("{e:?}"))?;
1335                                out_borrow.copy_(&in_borrow);
1336                                drop(out_borrow);
1337                                Ok(out_cell)
1338                            },
1339                        )?
1340                    };
1341
1342                    comm.all_reduce(cx, output_cell.clone(), op, stream).await?;
1343                    output_cell
1344                }
1345            }
1346        };
1347
1348        let pyobj = Self::tensor_to_pyobject(output_cell);
1349        self.env.insert(result, Ok(pyobj));
1350        Ok(())
1351    }
1352
1353    async fn send_tensor(
1354        &mut self,
1355        cx: &Context<Self>,
1356        result: Ref,
1357        from_rank: Option<usize>,
1358        to_rank: Option<usize>,
1359        tensor: Ref,
1360        factory: Factory,
1361        comm: Arc<ActorHandle<NcclCommActor>>,
1362    ) -> Result<()> {
1363        if let Some((recording, _)) = self.get_defining_recording() {
1364            recording.messages.push(StreamMessage::SendTensor {
1365                result,
1366                from_rank,
1367                to_rank,
1368                tensor,
1369                factory,
1370                comm,
1371            });
1372            return Ok(());
1373        }
1374
1375        if to_rank.is_none() && from_rank.is_none() {
1376            bail!("tried to send tensor without a to/from rank");
1377        }
1378
1379        // Value is local, so we do not have to actually send it.
1380        if from_rank == to_rank {
1381            let input_cell: &std::result::Result<PyObject, Arc<SeqError>> =
1382                self.env
1383                    .get(&tensor)
1384                    .ok_or_else(|| anyhow!("tensor not found in stream: {tensor:#?}"))?;
1385            let output_cell: Result<PyObject, Arc<SeqError>> = match input_cell {
1386                Ok(pyobj) => {
1387                    Python::with_gil(|py| -> Result<PyObject, Arc<SeqError>> {
1388                        let input_tensor = Self::pyobject_to_tensor(py, pyobj).unwrap();
1389                        // We create a defensive copy here to prevent mutations on
1390                        // the input tensor from affecting output tensor.
1391                        // Should we copy if input ref == output ref?
1392                        // Should we support copy-on-write to avoid unnecessary copy?
1393                        let borrow = input_tensor.try_borrow().unwrap();
1394                        let cloned = deep_clone(&borrow);
1395                        let cloned_cell = TensorCell::new(cloned);
1396                        Ok(Self::tensor_to_pyobject(cloned_cell))
1397                    })
1398                }
1399                Err(err) => Err(err.clone()),
1400            };
1401            self.env.insert(result, output_cell);
1402            return Ok(());
1403        }
1404
1405        let mut messages = Vec::new();
1406
1407        if let Some(to_rank) = to_rank {
1408            let input_cell = self.get_or_fake_on_err(tensor, &factory)?;
1409            messages.push(CommMessage::Send(
1410                input_cell,
1411                to_rank.try_into().unwrap(),
1412                self.cuda_stream()
1413                    .expect("tried to send_tensor on non-cuda stream")
1414                    .clone(),
1415                cx.open_once_port().0,
1416            ));
1417        }
1418
1419        if let Some(from_rank) = from_rank {
1420            let output_cell = TensorCell::new(factory_empty(
1421                &factory.size,
1422                factory.dtype,
1423                factory.layout,
1424                factory.device,
1425            ));
1426            messages.push(CommMessage::Recv(
1427                output_cell.clone(),
1428                from_rank.try_into().unwrap(),
1429                self.cuda_stream()
1430                    .expect("tried to send_tensor on non-cuda stream")
1431                    .clone(),
1432                cx.open_once_port().0,
1433            ));
1434            let pyobj = Self::tensor_to_pyobject(output_cell);
1435            self.env.insert(result, Ok(pyobj));
1436        }
1437
1438        comm.group(
1439            cx,
1440            messages,
1441            self.cuda_stream()
1442                .expect("tried to send_tensor on non-cuda stream")
1443                .clone(),
1444        )
1445        .await?;
1446        Ok(())
1447    }
1448
    /// Resolves a value for `seq` and delivers it to the controller.
    ///
    /// When `respond_with_python_message` is set, this delegates entirely to
    /// `send_value_python_message`. Otherwise the value is produced in one of
    /// two ways:
    /// - by calling `function`, or
    /// - with no function, from the single positional argument, with every
    ///   leaf that parses as a `Ref` replaced by the value stored for it.
    ///
    /// Only the failure case is delivered on this path. The error is passed to
    /// `report_seq_error`, stored on every ref in `mutates`, and sent to the
    /// controller as a `WorkerError`.
    ///
    /// NOTE(review): on this non-python path, a *successful* value reaches the
    /// `unreachable!` below and panics. Yet this path is only taken when
    /// `respond_with_python_message` is false, which is the opposite of what
    /// the panic message suggests. The surrounding comments assume the flag is
    /// always true. Confirm that, and consider removing the dead path or
    /// returning an error instead.
    async fn send_value(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: ActorId,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        device_meshes: HashMap<Ref, DeviceMesh>,
    ) -> Result<()> {
        if self.respond_with_python_message {
            return self
                .send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes)
                .await;
        }

        let result = if let Some(function) = function {
            // If a function was provided, use that to resolve the value.
            tokio::task::block_in_place(|| {
                self.call_python_fn_pytree(
                    cx,
                    function,
                    args_kwargs,
                    &mutates,
                    device_meshes,
                    HashMap::new(),
                )
            })
        } else {
            // If there's no function provided, there should be exactly one arg
            // and no kwargs.
            Python::with_gil(|py| {
                let (args, kwargs) = args_kwargs
                    .to_python(py)
                    .map_err(|e| CallFunctionError::Error(e.into()))?;
                match (args.len(), kwargs.len()) {
                    (1, 0) => {
                        let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?;
                        // Replace each leaf that parses as a `Ref` with the value
                        // stored for that ref; other leaves pass through unchanged.
                        arg.extract::<PyTree<PyObject>>()
                            .map_err(SerializablePyErr::from_fn(py))?
                            .try_into_map(|obj| {
                                let bound_obj = obj.bind(py);
                                if let Ok(ref_) = Ref::from_py_object(bound_obj) {
                                    self.ref_to_pyobject(&ref_)
                                } else {
                                    Ok(obj)
                                }
                            })
                    }
                    _ => Err(CallFunctionError::TooManyArgsForValue(
                        format!("args with {} elements", args.len()),
                        format!("kwargs with {} elements", kwargs.len()),
                    )),
                }
            })
        };

        // On failure: report the error for `seq`, store it on every mutated ref,
        // and turn it into a `WorkerError` for the controller.
        let value = match result {
            Ok(pyobject) => Ok(pyobject),
            Err(err) => {
                let err = self.report_seq_error(cx, seq, err).await?;
                for ref_ in mutates {
                    self.env.insert(ref_, Err(err.clone()));
                }
                Err(WorkerError {
                    backtrace: format!("{:?}", err),
                    worker_actor_id,
                })
            }
        };

        // Actually send the value.
        // NOTE: serializing a successful value is not implemented on this path. It
        // relies on respond_with_python_message always being true, in which case the
        // controller receives values through send_value_python_message instead.
        let result = match value {
            Ok(_value) => {
                // Reached only when respond_with_python_message is false AND the value
                // was produced successfully; assumed never to happen (see NOTE above).
                unreachable!(
                    "send_value should return early when respond_with_python_message is true"
                )
            }
            Err(e) => Err(e),
        };
        self.controller_actor.fetch_result(cx, seq, result).await?;

        Ok(())
    }
1536
1537    async fn send_result_of_actor_call(
1538        &mut self,
1539        cx: &Context<Self>,
1540        params: ActorCallParams,
1541    ) -> anyhow::Result<()> {
1542        let seq = params.seq;
1543        let mutates = params.mutates.clone();
1544        self.try_define(cx, seq, vec![], &mutates, async |self| {
1545            let value = self.call_actor(cx, params).await?;
1546            let result =
1547                Python::with_gil(|py| pickle_python_result(py, value.into_bound(py), self.rank))?;
1548            let result = wirevalue::Any::serialize(&result).unwrap();
1549            self.controller_actor
1550                .fetch_result(cx, seq, Ok(result))
1551                .await?;
1552            Ok(vec![])
1553        })
1554        .await
1555    }
1556
1557    async fn call_actor_method(
1558        &mut self,
1559        cx: &Context<Self>,
1560        params: ActorMethodParams,
1561    ) -> anyhow::Result<()> {
1562        let seq = params.call.seq;
1563        let mutates = params.call.mutates.clone();
1564        self.try_define(cx, seq, params.results, &mutates, async |self| {
1565            let result = self.call_actor(cx, params.call).await?;
1566            let result = Python::with_gil(|py| {
1567                PyTree::<PyObject>::extract_bound(&result.into_bound(py))
1568                    .map_err(SerializablePyErr::from_fn(py))
1569            })?;
1570            Ok(result.into_leaves())
1571        })
1572        .await
1573    }
1574
1575    async fn define_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1576        if self.active_recording.is_some() {
1577            bail!("different recording already active");
1578        }
1579        match self.recordings.entry(recording) {
1580            Entry::Occupied(_) => bail!("recording {:?} already defined", recording),
1581            Entry::Vacant(entry) => entry.insert(Recording::new()),
1582        };
1583        self.active_recording = Some(RecordingState::Defining {
1584            recording,
1585            defined_borrows: HashSet::new(),
1586        });
1587        Ok(())
1588    }
1589
1590    async fn finalize_recording(&mut self, _cx: &Context<Self>, recording: Ref) -> Result<()> {
1591        match self.active_recording {
1592            Some(RecordingState::Defining {
1593                recording: active_recording,
1594                ref defined_borrows,
1595            }) if active_recording == recording => {
1596                ensure!(
1597                    defined_borrows.is_empty(),
1598                    "all borrows created within recording must be dropped within recording"
1599                );
1600                self.active_recording = None;
1601            }
1602            _ => bail!("cannot finalize recording that isn't active"),
1603        }
1604        Ok(())
1605    }
1606
1607    async fn recording_formal(
1608        &mut self,
1609        _cx: &Context<Self>,
1610        result: Ref,
1611        argument_index: usize,
1612    ) -> Result<()> {
1613        match self.get_defining_recording() {
1614            Some((recording, _)) => {
1615                recording.messages.push(StreamMessage::RecordingFormal {
1616                    result,
1617                    argument_index,
1618                });
1619            }
1620            None => bail!("recording_formal called outside of recording"),
1621        };
1622        Ok(())
1623    }
1624
1625    async fn recording_result(
1626        &mut self,
1627        _cx: &Context<Self>,
1628        result: Ref,
1629        output_index: usize,
1630    ) -> Result<()> {
1631        match self.get_defining_recording() {
1632            Some((recording, _)) => {
1633                recording.messages.push(StreamMessage::RecordingResult {
1634                    result,
1635                    output_index,
1636                });
1637            }
1638            None => bail!("recording_result called outside of recording"),
1639        };
1640        Ok(())
1641    }
1642
1643    async fn call_recording(
1644        &mut self,
1645        cx: &Context<Self>,
1646        seq: Seq,
1647        recording: Ref,
1648        results: Vec<Ref>,
1649        actuals: Vec<Ref>,
1650    ) -> Result<()> {
1651        if self.active_recording.is_some() {
1652            bail!("cannot call recording while another recording is active");
1653        }
1654
1655        let messages = match self.recordings.get(&recording) {
1656            Some(recording) => recording
1657                .messages
1658                .iter()
1659                .map(|message| message.clone_for_recording())
1660                .collect::<Vec<_>>(),
1661            None => bail!("recording {:?} not found", recording),
1662        };
1663
1664        self.active_recording = Some(RecordingState::Running);
1665
1666        // Global error for all messages in the recording. The first time a message
1667        // fails in the recording, we set the error. We then need to propagate this
1668        // error to all of the refs mutated by the entire recording, as well as the
1669        // result refs.
1670        let mut error: Option<Arc<SeqError>> = None;
1671        // The set of all refs defined by this recording (excluding "results"),
1672        // which we need to ensure are deleted when the recording is done executing.
1673        let mut all_defined_refs = HashSet::new();
1674        // The set of all refs mutated by this recording. If there is an error with
1675        // any message, all of these refs need to have the correct error set.
1676        let mut all_mutated_refs = HashSet::new();
1677        // Map from the result ref of a RecordingFormal message to the associated
1678        // actual ref from "actuals". We need to track this in order to properly
1679        // handle recordings that mutate refs contained in "actuals" -- every
1680        // message in the recording that interacts with the recording inputs will
1681        // interact with the formal ref rather than the actual ref.
1682        let mut formal_to_actual_refs = HashMap::new();
1683        // clear any pre-existing error messages before recording started
1684        self.last_seq_error = None;
1685        for message in messages.into_iter() {
1686            let defined_refs = message.get_defined_refs();
1687            all_defined_refs.extend(defined_refs.clone());
1688
1689            let mutated_refs_with_formals = message.get_mutated_refs();
1690            all_mutated_refs.extend(mutated_refs_with_formals.iter().filter_map(|ref_| {
1691                match formal_to_actual_refs.get(ref_) {
1692                    Some(actual_ref) => Some(*actual_ref),
1693                    None => {
1694                        if all_defined_refs.contains(ref_) {
1695                            None
1696                        } else {
1697                            Some(*ref_)
1698                        }
1699                    }
1700                }
1701            }));
1702
1703            match message {
1704                StreamMessage::RecordingFormal {
1705                    result: formal_ref,
1706                    argument_index,
1707                } => match actuals.get(argument_index) {
1708                    None => bail!("recording_formal called with too few arguments"),
1709                    Some(actual_ref) => {
1710                        formal_to_actual_refs.insert(formal_ref, *actual_ref);
1711                        self.define_ref(formal_ref, *actual_ref)?;
1712                    }
1713                },
1714                StreamMessage::RecordingResult {
1715                    result: result_ref,
1716                    output_index,
1717                } => match results.get(output_index) {
1718                    None => bail!("recording_result called with too few results"),
1719                    Some(actual_result_ref) => {
1720                        self.define_ref(*actual_result_ref, result_ref)?;
1721                    }
1722                },
1723                StreamMessage::DeleteRefs(ref refs) => {
1724                    for ref_ in refs {
1725                        all_defined_refs.remove(ref_);
1726                    }
1727                    StreamMessageHandler::handle(self, cx, message).await?;
1728                }
1729                StreamMessage::CallFunction { .. } if error.is_some() => {
1730                    // CallFunction is expensive. If the recording already failed, then
1731                    // just update the necessary refs with the error. Most of the other
1732                    // message types need to run regardless because there are other actors
1733                    // that expect the call to happen (e.g., all of the borrow messages,
1734                    // pipe send/recv, send_tensor, reduce, etc.).
1735                    let error = error.clone().unwrap();
1736                    for ref_ in defined_refs.iter().chain(mutated_refs_with_formals.iter()) {
1737                        self.env.insert(*ref_, Err(error.clone()));
1738                    }
1739                }
1740                StreamMessage::BorrowLastUse { ref result, .. } => {
1741                    all_defined_refs.remove(result);
1742                    StreamMessageHandler::handle(self, cx, message).await?;
1743                }
1744                StreamMessage::Reduce {
1745                    local_tensor,
1746                    ref out,
1747                    ..
1748                } => {
1749                    // Reduce doesn't propagate errors to the result ref, so we need
1750                    // to check for existing errors on the input tensors and set the
1751                    // recording's error if necessary.
1752                    if error.is_none() {
1753                        let inputs_to_check = [Some(local_tensor), out.clone()]
1754                            .iter()
1755                            .filter_map(|r| *r)
1756                            .collect::<Vec<_>>();
1757                        error = self.get_first_error(inputs_to_check.as_slice())?;
1758                    }
1759                    StreamMessageHandler::handle(self, cx, message).await?;
1760                }
1761                StreamMessage::SendTensor {
1762                    ref tensor,
1763                    ref to_rank,
1764                    ..
1765                } => {
1766                    // If this rank is sending a tensor (e.g., to_rank has a value),
1767                    // we need to check for existing errors on the input tensor, because
1768                    // the error is only propagated to the result ref when this rank
1769                    // is also receiving a tensor.
1770                    if to_rank.is_some() && error.is_none() {
1771                        error = self.get_first_error(&[*tensor])?;
1772                    }
1773                    StreamMessageHandler::handle(self, cx, message).await?;
1774                }
1775                _ => {
1776                    StreamMessageHandler::handle(self, cx, message).await?;
1777                }
1778            };
1779
1780            // It's not entirely trivial to determine whether a message "failed" or not.
1781            // For example, the CallFunction message can return Ok(..) if there is an error
1782            // in the underlying function call. But in that case, we would still want to
1783            // consider the recording call as "failed". Unlike in python, where we can just
1784            // wrap everything in try-except, in rust, we keep track of the last report SeqError, which
1785            // we clear before handling each recording message. If we see it is set, the
1786            // we know the recording has faild.
1787            match (&error, self.last_seq_error.take()) {
1788                (None, Some(seq_err)) => {
1789                    // Report failure to the controller.
1790                    self.controller_actor
1791                        .remote_function_failed(
1792                            cx,
1793                            seq,
1794                            WorkerError {
1795                                backtrace: format!("recording failed: {}", &seq_err),
1796                                worker_actor_id: cx.self_id().clone(),
1797                            },
1798                        )
1799                        .await?;
1800                    error = Some(seq_err)
1801                }
1802                _ => {}
1803            }
1804            // Continue processing the remaining stream messages regardless of error.
1805            // We need to do this partially for error propagation, but also because
1806            // certain messages (like borrows and reductions) need to run regardless
1807            // in order to prevent deadlocks.
1808        }
1809
1810        // Delete the formal refs and some subset of the RecordingResult refs. The
1811        // controller should have generated DeleteRefs messages for all other refs
1812        // defined by the recording.
1813        StreamMessageHandler::handle(
1814            self,
1815            cx,
1816            StreamMessage::DeleteRefs(all_defined_refs.into_iter().collect()),
1817        )
1818        .await?;
1819
1820        // Any refs mutated by the recording and all results should have the same error
1821        // (the original error that caused the recording to fail).
1822        if error.is_some() {
1823            for ref_ in results.iter().chain(all_mutated_refs.iter()) {
1824                self.env.insert(*ref_, Err(error.clone().unwrap()));
1825            }
1826        }
1827
1828        self.active_recording = None;
1829        Ok(())
1830    }
1831
1832    async fn set_ref_unit_tests_only(
1833        &mut self,
1834        _cx: &Context<Self>,
1835        reference: Ref,
1836        value: WireValue,
1837    ) -> Result<()> {
1838        let pyobj =
1839            Python::with_gil(|py| -> PyResult<PyObject> { Ok(value.into_pyobject(py)?.unbind()) })?;
1840        self.env.insert(reference, Ok(pyobj));
1841        Ok(())
1842    }
1843
1844    async fn set_tensor_ref_unit_tests_only(
1845        &mut self,
1846        _cx: &Context<Self>,
1847        reference: Ref,
1848        tensor_result: TensorCellResult,
1849    ) -> Result<()> {
1850        match tensor_result {
1851            Ok(tensor_cell) => {
1852                let pyobj = Self::tensor_to_pyobject(tensor_cell);
1853                self.env.insert(reference, Ok(pyobj));
1854            }
1855            Err(err) => {
1856                self.env.insert(reference, Err(err));
1857            }
1858        }
1859        Ok(())
1860    }
1861
1862    async fn get_ref_unit_tests_only(
1863        &mut self,
1864        _cx: &Context<Self>,
1865        reference: Ref,
1866    ) -> Result<Option<Result<WireValue, String>>> {
1867        use pyo3::types::PyBool;
1868        use pyo3::types::PyFloat;
1869        use pyo3::types::PyInt;
1870        use pyo3::types::PyList;
1871        use pyo3::types::PyNone;
1872        use pyo3::types::PyString;
1873        /// For testing only, doesn't support Tensor or TensorList.
1874        fn pyobject_to_wire(
1875            value: Result<PyObject, Arc<SeqError>>,
1876        ) -> Result<WireValue, Arc<SeqError>> {
1877            let pyobj = value?;
1878            Python::with_gil(|py| {
1879                let bound = pyobj.bind(py);
1880                // Check bool before int since Python's bool is a subclass of int
1881                if bound.is_instance_of::<PyBool>() {
1882                    Ok(WireValue::Bool(bound.extract::<bool>().unwrap()))
1883                } else if bound.is_instance_of::<PyInt>() {
1884                    Ok(WireValue::Int(bound.extract::<i64>().unwrap()))
1885                } else if bound.is_instance_of::<PyList>() {
1886                    if let Ok(val) = bound.extract::<Vec<i64>>() {
1887                        Ok(WireValue::IntList(val))
1888                    } else {
1889                        Ok(WireValue::String(format!(
1890                            "unsupported list type: {:?}",
1891                            bound
1892                        )))
1893                    }
1894                } else if bound.is_instance_of::<PyFloat>() {
1895                    Ok(WireValue::Double(bound.extract::<f64>().unwrap()))
1896                } else if bound.is_instance_of::<PyString>() {
1897                    Ok(WireValue::String(bound.extract::<String>().unwrap()))
1898                } else if bound.is_instance_of::<PyNone>() {
1899                    Ok(WireValue::None(()))
1900                } else {
1901                    Ok(WireValue::String(format!(
1902                        "unsupported pyobject type: {:?}",
1903                        bound
1904                    )))
1905                }
1906            })
1907        }
1908        Ok(self.env.get(&reference).map(|pyobj| {
1909            pyobject_to_wire(Python::with_gil(|_py| pyobj.clone())).map_err(|err| err.to_string())
1910        }))
1911    }
1912
1913    async fn get_tensor_ref_unit_tests_only(
1914        &mut self,
1915        _cx: &Context<Self>,
1916        reference: Ref,
1917    ) -> Result<Option<TensorCellResult>> {
1918        match self.env.get(&reference) {
1919            Some(Ok(pyobj)) => Python::with_gil(|py| match Self::pyobject_to_tensor(py, pyobj) {
1920                Ok(tensor) => Ok(Some(Ok(tensor.try_cpu().unwrap()))),
1921                Err(e) => bail!("expected tensor, got extraction error: {:?}", e),
1922            }),
1923            Some(Err(err)) => Ok(Some(Err(err.clone()))),
1924            None => Ok(None),
1925        }
1926    }
1927}
1928
1929#[cfg(test)]
1930mod tests {
1931    use hyperactor::actor::ActorStatus;
1932    use hyperactor::context;
1933    use hyperactor::supervision::ActorSupervisionEvent;
1934    use monarch_messages::controller::ControllerMessage;
1935    use monarch_messages::worker::StreamCreationMode;
1936    use monarch_types::PickledPyObject;
1937    use pyo3::IntoPyObjectExt;
1938    use timed_test::async_timed_test;
1939    use torch_sys_cuda::nccl::UniqueId;
1940    use torch_sys2::factory_float_tensor;
1941    use torch_sys2::testing::allclose;
1942
1943    use super::*;
1944    use crate::comm::CommParams;
1945    use crate::test_util;
1946
1947    #[allow(dead_code)]
1948    fn fake_seq_error(err: anyhow::Error) -> Arc<SeqError> {
1949        Arc::new(SeqError {
1950            seq: 0.into(),
1951            error: err,
1952        })
1953    }
1954
    /// Test fixture: a local proc hosting the `StreamActor` under test, plus the
    /// handles needed to drive it and observe its output.
    struct TestSetup {
        // Local proc hosting the controller port, the client instance, and the
        // stream actor.
        proc: Proc,
        // The stream actor under test.
        stream_actor: ActorHandle<StreamActor>,
        // Instance passed as the sending context when messaging `stream_actor`.
        client: Instance<()>,
        // Unused, but necessary, because proc needs a supervision
        // port -- otherwise an actor failure will cause a crash.
        #[allow(dead_code)]
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        // Receives the messages the stream actor sends to its controller.
        #[allow(dead_code)]
        controller_rx: PortReceiver<ControllerMessage>,
        // Controller ref handed to the stream actor at construction.
        #[allow(dead_code)]
        controller_actor: ActorRef<ControllerActor>,
        // Next ref id handed out by `next_ref`.
        next_ref: Ref,
    }
1969
1970    impl TestSetup {
1971        async fn new() -> Result<Self> {
1972            Self::new_with_world_size(1).await
1973        }
1974
1975        async fn new_with_world_size(world_size: usize) -> Result<Self> {
1976            test_util::test_setup()?;
1977
1978            let proc = Proc::local();
1979            let (_, controller_actor, controller_rx) =
1980                proc.attach_actor::<ControllerActor, ControllerMessage>("controller")?;
1981            let (client, _handle) = proc.instance("client")?;
1982            let (supervision_tx, supervision_rx) = client.open_port();
1983            proc.set_supervision_coordinator(supervision_tx)?;
1984            let stream_actor = proc.spawn(
1985                "stream",
1986                StreamActor::new(StreamParams {
1987                    world_size,
1988                    rank: 0,
1989                    creation_mode: StreamCreationMode::UseDefaultStream,
1990                    id: 0.into(),
1991                    device: Some(CudaDevice::new(0.into())),
1992                    controller_actor: controller_actor.clone(),
1993                    respond_with_python_message: false,
1994                }),
1995            )?;
1996
1997            Ok(Self {
1998                proc,
1999                stream_actor,
2000                client,
2001                supervision_rx,
2002                controller_rx,
2003                controller_actor,
2004                next_ref: 0.into(),
2005            })
2006        }
2007
2008        fn next_ref(&mut self) -> Ref {
2009            let ref_ = self.next_ref;
2010            self.next_ref = Ref {
2011                id: self.next_ref.id + 1,
2012            };
2013            ref_
2014        }
2015
2016        async fn set_tensor(&mut self, reference: Ref, data: &[f32]) -> Result<()> {
2017            let tensor = TensorCell::new(factory_float_tensor(data, "cuda".parse().unwrap()));
2018            self.stream_actor
2019                .set_tensor_ref_unit_tests_only(&self.client, reference, Ok(tensor))
2020                .await
2021        }
2022
2023        async fn allclose(&mut self, reference: Ref, data: &[f32]) -> bool {
2024            let actual = self
2025                .stream_actor
2026                .get_tensor_ref_unit_tests_only(&self.client, reference)
2027                .await
2028                .unwrap()
2029                .unwrap()
2030                .unwrap();
2031
2032            // rustfmt-ignore
2033            allclose(
2034                &factory_float_tensor(data, "cpu".parse().unwrap()),
2035                &actual.borrow(),
2036            )
2037            .unwrap()
2038        }
2039
2040        #[allow(dead_code)]
2041        async fn validate_dependent_error(&mut self, reference: Ref, error: Arc<SeqError>) {
2042            let result_error = self
2043                .stream_actor
2044                .get_tensor_ref_unit_tests_only(&self.client, reference)
2045                .await
2046                .unwrap()
2047                .unwrap()
2048                .unwrap_err();
2049
2050            assert!(Arc::ptr_eq(&result_error, &error));
2051        }
2052    }
2053
2054    async fn assert_actor_failed_with_msg(proc: &Proc, actor_id: &ActorId, expected_msg: String) {
2055        loop {
2056            let status = proc
2057                .ledger_snapshot()
2058                .roots
2059                .get(actor_id)
2060                .unwrap()
2061                .status
2062                .clone();
2063            if let ActorStatus::Failed(msg) = status {
2064                assert!(msg.to_string().contains(&expected_msg));
2065                break;
2066            } else {
2067                tokio::task::yield_now().await;
2068            }
2069        }
2070    }
2071
2072    async fn assert_refs_do_not_exist(test_setup: &TestSetup, refs: &[Ref]) {
2073        for ref_ in refs {
2074            assert!(
2075                test_setup
2076                    .stream_actor
2077                    .get_tensor_ref_unit_tests_only(&test_setup.client, *ref_)
2078                    .await
2079                    .unwrap()
2080                    .is_none()
2081            );
2082        }
2083    }
2084
2085    #[allow(dead_code)]
2086    async fn fetch_result(
2087        cx: &impl context::Actor,
2088        stream_actor: ActorHandle<StreamActor>,
2089        seq: Seq,
2090        reference: Ref,
2091    ) {
2092        let ref_to_send = Python::with_gil(|py| {
2093            PickledPyObject::pickle(&reference.into_bound_py_any(py).unwrap()).unwrap()
2094        });
2095
2096        stream_actor
2097            .send_value(
2098                cx,
2099                seq,
2100                stream_actor.actor_id().clone(),
2101                Vec::new(),
2102                None,
2103                ArgsKwargs::from_wire_values(
2104                    vec![WireValue::PyObject(ref_to_send)],
2105                    HashMap::new(),
2106                )
2107                .unwrap(),
2108                HashMap::new(),
2109            )
2110            .await
2111            .unwrap()
2112    }
2113
2114    #[allow(dead_code)]
2115    async fn check_fetch_result_error(
2116        cx: &impl context::Actor,
2117        stream_actor: ActorHandle<StreamActor>,
2118        seq: Seq,
2119        reference: Ref,
2120        controller_rx: &mut PortReceiver<ControllerMessage>,
2121        expected_backtrace: &str,
2122    ) {
2123        fetch_result(cx, stream_actor, seq, reference).await;
2124
2125        let controller_msg = controller_rx.recv().await.unwrap();
2126        match controller_msg {
2127            ControllerMessage::FetchResult {
2128                seq: actual_seq,
2129                value: Err(err),
2130            } => {
2131                assert_eq!(actual_seq, seq);
2132                assert!(
2133                    err.backtrace.contains(expected_backtrace),
2134                    "backtrace did not contain {:?}: {:?}",
2135                    expected_backtrace,
2136                    err.backtrace
2137                );
2138            }
2139            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2140        };
2141    }
2142
2143    #[allow(dead_code)]
2144    async fn check_fetch_result_value(
2145        cx: &impl context::Actor,
2146        stream_actor: ActorHandle<StreamActor>,
2147        seq: Seq,
2148        reference: Ref,
2149        controller_rx: &mut PortReceiver<ControllerMessage>,
2150    ) {
2151        fetch_result(cx, stream_actor, seq, reference).await;
2152
2153        let controller_msg = controller_rx.recv().await.unwrap();
2154        match controller_msg {
2155            ControllerMessage::FetchResult {
2156                value: Ok(_),
2157                seq: actual_seq,
2158            } => assert_eq!(seq, actual_seq),
2159            _ => panic!("Unexpected controller message: {:?}", controller_msg),
2160        };
2161    }
2162
2163    #[async_timed_test(timeout_secs = 60)]
2164    async fn test_define_recording_other_recording_active() -> Result<()> {
2165        let test_setup = TestSetup::new().await?;
2166        test_setup
2167            .stream_actor
2168            .define_recording(&test_setup.client, 0.into())
2169            .await?;
2170        test_setup
2171            .stream_actor
2172            .define_recording(&test_setup.client, 1.into())
2173            .await?;
2174        assert_actor_failed_with_msg(
2175            &test_setup.proc,
2176            test_setup.stream_actor.actor_id(),
2177            "different recording already active".into(),
2178        )
2179        .await;
2180        Ok(())
2181    }
2182
2183    #[async_timed_test(timeout_secs = 60)]
2184    async fn test_define_recording_already_defined() -> Result<()> {
2185        let test_setup = TestSetup::new().await?;
2186        test_setup
2187            .stream_actor
2188            .define_recording(&test_setup.client, 0.into())
2189            .await?;
2190        test_setup
2191            .stream_actor
2192            .finalize_recording(&test_setup.client, 0.into())
2193            .await?;
2194        test_setup
2195            .stream_actor
2196            .define_recording(&test_setup.client, 0.into())
2197            .await?;
2198        assert_actor_failed_with_msg(
2199            &test_setup.proc,
2200            test_setup.stream_actor.actor_id(),
2201            "already defined".into(),
2202        )
2203        .await;
2204        Ok(())
2205    }
2206
2207    #[async_timed_test(timeout_secs = 60)]
2208    async fn test_finalize_recording_other_recording_active() -> Result<()> {
2209        let test_setup = TestSetup::new().await?;
2210        test_setup
2211            .stream_actor
2212            .define_recording(&test_setup.client, 0.into())
2213            .await?;
2214        test_setup
2215            .stream_actor
2216            .finalize_recording(&test_setup.client, 1.into())
2217            .await?;
2218        assert_actor_failed_with_msg(
2219            &test_setup.proc,
2220            test_setup.stream_actor.actor_id(),
2221            "cannot finalize recording that isn't active".into(),
2222        )
2223        .await;
2224        Ok(())
2225    }
2226
2227    #[async_timed_test(timeout_secs = 60)]
2228    async fn test_recording_formal_outside_recording() -> Result<()> {
2229        let test_setup = TestSetup::new().await?;
2230        test_setup
2231            .stream_actor
2232            .recording_formal(&test_setup.client, 0.into(), 0)
2233            .await?;
2234        assert_actor_failed_with_msg(
2235            &test_setup.proc,
2236            test_setup.stream_actor.actor_id(),
2237            "recording_formal called outside of recording".into(),
2238        )
2239        .await;
2240        Ok(())
2241    }
2242
2243    #[async_timed_test(timeout_secs = 60)]
2244    async fn test_recording_result_outside_recording() -> Result<()> {
2245        let test_setup = TestSetup::new().await?;
2246        test_setup
2247            .stream_actor
2248            .recording_result(&test_setup.client, 0.into(), 0)
2249            .await?;
2250        assert_actor_failed_with_msg(
2251            &test_setup.proc,
2252            test_setup.stream_actor.actor_id(),
2253            "recording_result called outside of recording".into(),
2254        )
2255        .await;
2256        Ok(())
2257    }
2258
2259    #[async_timed_test(timeout_secs = 60)]
2260    async fn test_call_recording_other_recording_active() -> Result<()> {
2261        let test_setup = TestSetup::new().await?;
2262        test_setup
2263            .stream_actor
2264            .define_recording(&test_setup.client, 0.into())
2265            .await?;
2266        test_setup
2267            .stream_actor
2268            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2269            .await?;
2270        assert_actor_failed_with_msg(
2271            &test_setup.proc,
2272            test_setup.stream_actor.actor_id(),
2273            "cannot call recording while another recording is active".into(),
2274        )
2275        .await;
2276        Ok(())
2277    }
2278
2279    #[async_timed_test(timeout_secs = 60)]
2280    async fn test_call_recording_not_found() -> Result<()> {
2281        let test_setup = TestSetup::new().await?;
2282        test_setup
2283            .stream_actor
2284            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2285            .await?;
2286        assert_actor_failed_with_msg(
2287            &test_setup.proc,
2288            test_setup.stream_actor.actor_id(),
2289            "not found".into(),
2290        )
2291        .await;
2292        Ok(())
2293    }
2294
2295    #[async_timed_test(timeout_secs = 60)]
2296    async fn test_recording_formal_too_few_arguments() -> Result<()> {
2297        let test_setup = TestSetup::new().await?;
2298
2299        test_setup
2300            .stream_actor
2301            .define_recording(&test_setup.client, 0.into())
2302            .await?;
2303
2304        test_setup
2305            .stream_actor
2306            .recording_formal(&test_setup.client, 1.into(), 0)
2307            .await?;
2308
2309        test_setup
2310            .stream_actor
2311            .finalize_recording(&test_setup.client, 0.into())
2312            .await?;
2313
2314        test_setup
2315            .stream_actor
2316            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2317            .await?;
2318
2319        assert_actor_failed_with_msg(
2320            &test_setup.proc,
2321            test_setup.stream_actor.actor_id(),
2322            "recording_formal called with too few arguments".into(),
2323        )
2324        .await;
2325        Ok(())
2326    }
2327
2328    #[async_timed_test(timeout_secs = 60)]
2329    async fn test_recording_result_too_few_results() -> Result<()> {
2330        let test_setup = TestSetup::new().await?;
2331
2332        test_setup
2333            .stream_actor
2334            .define_recording(&test_setup.client, 0.into())
2335            .await?;
2336
2337        test_setup
2338            .stream_actor
2339            .recording_result(&test_setup.client, 1.into(), 0)
2340            .await?;
2341
2342        test_setup
2343            .stream_actor
2344            .finalize_recording(&test_setup.client, 0.into())
2345            .await?;
2346
2347        test_setup
2348            .stream_actor
2349            .call_recording(&test_setup.client, 0.into(), 0.into(), vec![], vec![])
2350            .await?;
2351
2352        assert_actor_failed_with_msg(
2353            &test_setup.proc,
2354            test_setup.stream_actor.actor_id(),
2355            "recording_result called with too few results".into(),
2356        )
2357        .await;
2358        Ok(())
2359    }
2360
2361    #[async_timed_test(timeout_secs = 60)]
2362    async fn test_basic_call_recording() -> Result<()> {
2363        let mut test_setup = TestSetup::new().await?;
2364
2365        // Define a recording equivalent to:
2366        // def f(x, y):
2367        //   return y, x
2368        test_setup
2369            .stream_actor
2370            .define_recording(&test_setup.client, 0.into())
2371            .await?;
2372
2373        let formal0_ref = 1.into();
2374        let formal0_index = 1;
2375        test_setup
2376            .stream_actor
2377            .recording_formal(&test_setup.client, formal0_ref, formal0_index)
2378            .await?;
2379
2380        let formal1_ref = 2.into();
2381        let formal1_index = 0;
2382        test_setup
2383            .stream_actor
2384            .recording_formal(&test_setup.client, formal1_ref, formal1_index)
2385            .await?;
2386
2387        let result0_ref = formal0_ref;
2388        let result0_index = 0;
2389        test_setup
2390            .stream_actor
2391            .recording_result(&test_setup.client, result0_ref, result0_index)
2392            .await?;
2393
2394        let result1_ref = formal1_ref;
2395        let result1_index = 1;
2396        test_setup
2397            .stream_actor
2398            .recording_result(&test_setup.client, result1_ref, result1_index)
2399            .await?;
2400
2401        test_setup
2402            .stream_actor
2403            .finalize_recording(&test_setup.client, 0.into())
2404            .await?;
2405
2406        let actual0_ref = 3.into();
2407        test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?;
2408
2409        let actual1_ref = 4.into();
2410        test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?;
2411
2412        // Call the recording with valid tensors for the actual inputs,
2413        // and store the results in refs 5 and 6.
2414        let actual_result0_ref = 5.into();
2415        let actual_result1_ref = 6.into();
2416        test_setup
2417            .stream_actor
2418            .call_recording(
2419                &test_setup.client,
2420                0.into(),
2421                0.into(),
2422                vec![actual_result0_ref, actual_result1_ref],
2423                vec![actual0_ref, actual1_ref],
2424            )
2425            .await?;
2426
2427        // Ensure the results are correct.
2428        assert!(test_setup.allclose(actual_result0_ref, &[4.0, 5.0]).await);
2429        assert!(
2430            test_setup
2431                .allclose(actual_result1_ref, &[1.0, 2.0, 3.0])
2432                .await
2433        );
2434
2435        // Ensure the temporary refs associated with the formals/results have
2436        // been deleted.
2437        assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await;
2438        Ok(())
2439    }
2440
2441    #[async_timed_test(timeout_secs = 60)]
2442    async fn test_request_status_in_recording() -> Result<()> {
2443        let test_setup = TestSetup::new().await?;
2444        test_setup
2445            .stream_actor
2446            .define_recording(&test_setup.client, 0.into())
2447            .await?;
2448        test_setup
2449            .stream_actor
2450            .request_status(&test_setup.client)
2451            .await
2452            .expect_err("request_status should have failed");
2453        assert_actor_failed_with_msg(
2454            &test_setup.proc,
2455            test_setup.stream_actor.actor_id(),
2456            "request_status not allowed in recording".into(),
2457        )
2458        .await;
2459        Ok(())
2460    }
2461
2462    #[async_timed_test(timeout_secs = 60)]
2463    async fn test_init_comm_in_recording() -> Result<()> {
2464        let test_setup = TestSetup::new().await?;
2465        test_setup
2466            .stream_actor
2467            .define_recording(&test_setup.client, 0.into())
2468            .await?;
2469
2470        let dummy_comm = test_setup.proc.spawn(
2471            "comm",
2472            NcclCommActor::new(CommParams::New {
2473                device: CudaDevice::new(0.into()),
2474                unique_id: UniqueId::new()?,
2475                world_size: 1,
2476                rank: 0,
2477            })
2478            .await
2479            .unwrap(),
2480        )?;
2481
2482        test_setup
2483            .stream_actor
2484            .init_comm(&test_setup.client, dummy_comm)
2485            .await?;
2486        assert_actor_failed_with_msg(
2487            &test_setup.proc,
2488            test_setup.stream_actor.actor_id(),
2489            "init_comm not allowed in recording".into(),
2490        )
2491        .await;
2492        Ok(())
2493    }
2494
2495    #[async_timed_test(timeout_secs = 60)]
2496    async fn test_borrow_create_duplicate_borrow() -> Result<()> {
2497        let mut test_setup = TestSetup::new().await?;
2498        test_setup
2499            .stream_actor
2500            .define_recording(&test_setup.client, 0.into())
2501            .await?;
2502
2503        let borrow_id = 1;
2504        let tensor_ref = test_setup.next_ref();
2505        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2506
2507        test_setup
2508            .stream_actor
2509            .borrow_create(
2510                &test_setup.client,
2511                borrow_id,
2512                tensor_ref,
2513                first_use_sender.clone(),
2514            )
2515            .await?;
2516
2517        test_setup
2518            .stream_actor
2519            .borrow_create(&test_setup.client, borrow_id, tensor_ref, first_use_sender)
2520            .await?;
2521
2522        assert_actor_failed_with_msg(
2523            &test_setup.proc,
2524            test_setup.stream_actor.actor_id(),
2525            "duplicate borrow create in recording".into(),
2526        )
2527        .await;
2528
2529        Ok(())
2530    }
2531
2532    #[async_timed_test(timeout_secs = 60)]
2533    async fn test_borrow_drop_borrow_not_defined() -> Result<()> {
2534        let test_setup = TestSetup::new().await?;
2535        test_setup
2536            .stream_actor
2537            .define_recording(&test_setup.client, 0.into())
2538            .await?;
2539
2540        let borrow_id = 1;
2541        let (_last_use_sender, last_use_receiver) = test_setup.client.open_port();
2542
2543        test_setup
2544            .stream_actor
2545            .borrow_drop(
2546                &test_setup.client,
2547                borrow_id,
2548                Arc::new(Mutex::new(last_use_receiver)),
2549            )
2550            .await?;
2551
2552        assert_actor_failed_with_msg(
2553            &test_setup.proc,
2554            test_setup.stream_actor.actor_id(),
2555            "borrow drop for borrow not defined in recording".into(),
2556        )
2557        .await;
2558
2559        Ok(())
2560    }
2561
2562    #[async_timed_test(timeout_secs = 60)]
2563    async fn test_borrow_not_dropped_before_finalize() -> Result<()> {
2564        let mut test_setup = TestSetup::new().await?;
2565        test_setup
2566            .stream_actor
2567            .define_recording(&test_setup.client, 0.into())
2568            .await?;
2569
2570        let borrow_id = 1;
2571        let tensor_ref = test_setup.next_ref();
2572        let (first_use_sender, _first_use_receiver) = test_setup.client.open_port();
2573
2574        test_setup
2575            .stream_actor
2576            .borrow_create(
2577                &test_setup.client,
2578                borrow_id,
2579                tensor_ref,
2580                first_use_sender.clone(),
2581            )
2582            .await?;
2583
2584        // Attempt to finalize the recording without dropping the borrow
2585        test_setup
2586            .stream_actor
2587            .finalize_recording(&test_setup.client, 0.into())
2588            .await?;
2589
2590        assert_actor_failed_with_msg(
2591            &test_setup.proc,
2592            test_setup.stream_actor.actor_id(),
2593            "all borrows created within recording must be dropped within recording".into(),
2594        )
2595        .await;
2596
2597        Ok(())
2598    }
2599}