Skip to main content

monarch_hyperactor/
actor_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::future::Future;
10use std::ops::Deref;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::thread;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use futures::future;
18use futures::future::FutureExt;
19use futures::future::Shared;
20use hyperactor::Instance;
21use hyperactor::supervision::ActorSupervisionEvent;
22use hyperactor_mesh::actor_mesh::ActorMesh;
23use hyperactor_mesh::actor_mesh::ActorMeshRef;
24use hyperactor_mesh::sel;
25use monarch_types::py_global;
26use monarch_types::py_module_add_function;
27use ndslice::Region;
28use ndslice::Slice;
29use ndslice::selection::Selection;
30use ndslice::selection::structurally_equal;
31use ndslice::view::Ranked;
32use ndslice::view::RankedSliceable;
33use pyo3::IntoPyObjectExt;
34use pyo3::exceptions::PyNotImplementedError;
35use pyo3::exceptions::PyRuntimeError;
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38use pyo3::types::PyBytes;
39use pyo3::types::PyTuple;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::sync::mpsc::unbounded_channel;
42
43use crate::actor::PythonActor;
44use crate::actor::PythonMessage;
45use crate::actor::PythonMessageKind;
46use crate::context::PyInstance;
47use crate::pickle::PendingMessage;
48use crate::proc::PyActorAddr;
49use crate::pytokio::PyPythonTask;
50use crate::runtime::get_tokio_runtime;
51use crate::runtime::monarch_with_gil;
52use crate::runtime::monarch_with_gil_blocking;
53use crate::shape::PyRegion;
54use crate::supervision::Supervisable;
55use crate::supervision::SupervisionError;
56
57py_global!(
58    is_pending_pickle_allowed,
59    "monarch._src.actor.pickle",
60    "is_pending_pickle_allowed"
61);
62py_global!(_pickle, "monarch._src.actor.actor_mesh", "_pickle");
63
64py_global!(
65    shared_class,
66    "monarch._rust_bindings.monarch_hyperactor.pytokio",
67    "Shared"
68);
69
70/// Trait defining the common interface for actor mesh, mesh ref and actor mesh implementations.
71/// This corresponds to the Python ActorMeshProtocol ABC.
72pub(crate) trait ActorMeshProtocol: Send + Sync {
73    /// Cast a message to actors selected by the given selection using the specified mailbox.
74    fn cast(
75        &self,
76        message: PythonMessage,
77        selection: Selection,
78        instance: &Instance<PythonActor>,
79    ) -> PyResult<()>;
80
81    /// Cast a message, merging caller-supplied envelope headers into
82    /// the outbound request. Implementations that reach the real
83    /// envelope emission site override this to thread `caller_headers`
84    /// through `hyperactor_mesh::ActorMeshRef::cast_with_headers`;
85    /// the default collapses to the non-headers path for impls that
86    /// have no envelope access.
87    fn cast_with_headers(
88        &self,
89        message: PythonMessage,
90        selection: Selection,
91        instance: &Instance<PythonActor>,
92        _caller_headers: hyperactor_config::Flattrs,
93    ) -> PyResult<()> {
94        self.cast(message, selection, instance)
95    }
96
97    /// Cast a pending message (which may contain unresolved async values) to actors.
98    ///
99    /// The default implementation blocks on resolving the message and then calls cast.
100    /// AsyncActorMesh overrides this with an optimized async implementation.
101    fn cast_unresolved(
102        &self,
103        message: PendingMessage,
104        selection: Selection,
105        instance: &Instance<PythonActor>,
106    ) -> PyResult<()> {
107        let message = get_tokio_runtime().block_on(message.resolve())?;
108        self.cast(message, selection, instance)
109    }
110
111    /// Async counterpart of `cast_with_headers`. The default
112    /// resolves the pending message synchronously and delegates;
113    /// `AsyncActorMesh` overrides this to resolve asynchronously
114    /// and route through `cast_with_headers`.
115    fn cast_unresolved_with_headers(
116        &self,
117        message: PendingMessage,
118        selection: Selection,
119        instance: &Instance<PythonActor>,
120        caller_headers: hyperactor_config::Flattrs,
121    ) -> PyResult<()> {
122        let message = get_tokio_runtime().block_on(message.resolve())?;
123        self.cast_with_headers(message, selection, instance, caller_headers)
124    }
125
126    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>;
127
128    /// Stop the actor mesh asynchronously.
129    /// Default implementation raises NotImplementedError for types that don't support stopping.
130    fn stop(&self, _instance: &PyInstance, _reason: String) -> PyResult<PyPythonTask> {
131        Err(PyNotImplementedError::new_err(format!(
132            "stop() is not supported for {}",
133            std::any::type_name::<Self>()
134        )))
135    }
136
137    /// Initialize the actor mesh asynchronously.
138    /// Default implementation returns None (no initialization needed).
139    fn initialized(&self) -> PyResult<PyPythonTask> {
140        PyPythonTask::new(async { Ok(None::<()>) })
141    }
142
143    /// The name of the mesh.
144    fn name(&self) -> PyResult<PyPythonTask>;
145}
146
147pub(crate) trait SupervisableActorMesh: ActorMeshProtocol + Supervisable {
148    fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>>;
149}
150
151/// This just forwards to the rust trait that can implement these bindings
152#[pyclass(
153    name = "PythonActorMesh",
154    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
155)]
156#[derive(Clone)]
157pub(crate) struct PythonActorMesh {
158    inner: Arc<dyn SupervisableActorMesh>,
159}
160
161impl PythonActorMesh {
162    pub(crate) fn new<F>(f: F, supervised: bool) -> Self
163    where
164        F: Future<Output = PyResult<Box<dyn SupervisableActorMesh>>> + Send + 'static,
165    {
166        let f = async move { Ok(Arc::from(f.await?)) }.boxed().shared();
167        PythonActorMesh {
168            inner: Arc::new(AsyncActorMesh::new_queue(f, supervised)),
169        }
170    }
171
172    pub(crate) fn from_impl(inner: Arc<dyn SupervisableActorMesh>) -> Self {
173        PythonActorMesh { inner }
174    }
175
176    pub(crate) fn get_inner(&self) -> Arc<dyn SupervisableActorMesh> {
177        self.inner.clone()
178    }
179}
180
181pub(crate) fn to_hy_sel(selection: &str) -> PyResult<Selection> {
182    match selection {
183        "choose" => Ok(sel!(?)),
184        "all" => Ok(sel!(*)),
185        _ => Err(PyErr::new::<PyValueError, _>(format!(
186            "Invalid selection: {}",
187            selection
188        ))),
189    }
190}
191
192#[pymethods]
193impl PythonActorMesh {
194    #[tracing::instrument(level = "debug", skip_all)]
195    #[pyo3(name = "cast")]
196    fn py_cast(
197        &self,
198        message: &PythonMessage,
199        selection: &str,
200        instance: &PyInstance,
201    ) -> PyResult<()> {
202        let sel = to_hy_sel(selection)?;
203        self.inner.cast(message.clone(), sel, instance.deref())
204    }
205
206    #[hyperactor::instrument]
207    pub(crate) fn cast_unresolved(
208        &self,
209        message: &mut PendingMessage,
210        selection: &str,
211        instance: &PyInstance,
212    ) -> PyResult<()> {
213        let sel = to_hy_sel(selection)?;
214        let message = message.take()?;
215        self.inner.cast_unresolved(message, sel, instance)
216    }
217
218    fn new_with_region(&self, region: &PyRegion) -> PyResult<PythonActorMesh> {
219        let inner = self.inner.new_with_region(region)?;
220        Ok(PythonActorMesh {
221            inner: Arc::from(inner),
222        })
223    }
224
225    fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
226        self.inner.stop(instance, reason)
227    }
228
229    fn initialized(&self) -> PyResult<PyPythonTask> {
230        self.inner.initialized()
231    }
232
233    fn name(&self) -> PyResult<PyPythonTask> {
234        self.inner.name()
235    }
236
237    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
238        self.inner.__reduce__(py)
239    }
240}
241
242#[derive(Debug)]
243pub(crate) struct ClonePyErr {
244    inner: PyErr,
245}
246
247impl From<ClonePyErr> for PyErr {
248    fn from(value: ClonePyErr) -> PyErr {
249        value.inner
250    }
251}
252impl From<PyErr> for ClonePyErr {
253    fn from(inner: PyErr) -> ClonePyErr {
254        ClonePyErr { inner }
255    }
256}
257
258impl Clone for ClonePyErr {
259    fn clone(&self) -> Self {
260        monarch_with_gil_blocking(|py| self.inner.clone_ref(py).into())
261    }
262}
263
264type ActorMeshResult = Result<Arc<dyn SupervisableActorMesh>, ClonePyErr>;
265type ActorMeshFut = Shared<Pin<Box<dyn Future<Output = ActorMeshResult> + Send + 'static>>>;
266
267pub(crate) struct AsyncActorMesh {
268    mesh: ActorMeshFut,
269    queue: UnboundedSender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
270    supervised: bool,
271}
272
273impl AsyncActorMesh {
274    pub(crate) fn new_queue(f: ActorMeshFut, supervised: bool) -> AsyncActorMesh {
275        let (queue, mut recv) = unbounded_channel();
276
277        get_tokio_runtime().spawn(async move {
278            loop {
279                let r = recv.recv().await;
280                if let Some(r) = r {
281                    r.await;
282                } else {
283                    return;
284                }
285            }
286        });
287
288        let mesh = AsyncActorMesh::new(queue, supervised, f);
289        // Eagerly trigger the mesh initialization by pushing an init task onto
290        // the queue. This ensures actors are spawned immediately rather than
291        // waiting for the first endpoint call, which is critical for:
292        // 1. Tests/code that wait for supervision events from actor __init__
293        //    failures without making any endpoint calls.
294        // 2. Ensuring all meshes on a proc are spawned before any errors occur,
295        //    preventing spawn rejections due to stale supervision events.
296        let f = mesh.mesh.clone();
297        mesh.push(async move {
298            let _ = f.await;
299        });
300        mesh
301    }
302
303    fn new(
304        queue: UnboundedSender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
305        supervised: bool,
306        f: ActorMeshFut,
307    ) -> AsyncActorMesh {
308        AsyncActorMesh {
309            mesh: f,
310            queue,
311            supervised,
312        }
313    }
314
315    fn push<F>(&self, f: F)
316    where
317        F: Future<Output = ()> + Send + 'static,
318    {
319        self.queue.send(f.boxed()).unwrap();
320    }
321
322    pub(crate) fn from_impl(mesh: Arc<dyn SupervisableActorMesh>) -> Self {
323        let fut = future::ready(Ok::<Arc<dyn SupervisableActorMesh>, ClonePyErr>(mesh))
324            .boxed()
325            .shared();
326        // Poll the future so that its result can be observed without blocking the tokio runtime.
327        let _ = futures::executor::block_on(fut.clone());
328        Self::new_queue(fut, true)
329    }
330}
331
332impl ActorMeshProtocol for AsyncActorMesh {
333    fn cast(
334        &self,
335        _message: PythonMessage,
336        _selection: Selection,
337        _instance: &Instance<PythonActor>,
338    ) -> PyResult<()> {
339        panic!("not implemented")
340    }
341
342    fn cast_unresolved(
343        &self,
344        message: PendingMessage,
345        selection: Selection,
346        instance: &Instance<PythonActor>,
347    ) -> PyResult<()> {
348        self.cast_unresolved_with_headers(
349            message,
350            selection,
351            instance,
352            hyperactor_config::Flattrs::new(),
353        )
354    }
355
356    fn cast_unresolved_with_headers(
357        &self,
358        message: PendingMessage,
359        selection: Selection,
360        instance: &Instance<PythonActor>,
361        caller_headers: hyperactor_config::Flattrs,
362    ) -> PyResult<()> {
363        let mesh = self.mesh.clone();
364        let instance = instance.clone_for_py();
365        let port = match &message.kind {
366            PythonMessageKind::CallMethod { response_port, .. } => response_port.clone(),
367            _ => None,
368        };
369        self.push(async move {
370            let result = async {
371                let resolved = message.resolve().await?;
372                mesh.await?
373                    .cast_with_headers(resolved, selection, &instance, caller_headers)
374            }
375            .await;
376            if let (Some(mut port_ref), Err(pyerr)) = (port, result) {
377                let _ = monarch_with_gil(|py: Python<'_>| {
378                    let exception_str = crate::logging::format_traceback(py, &pyerr);
379                    tracing::error!(
380                        actor_id = instance.self_addr().to_string(),
381                        "error occurred during cast unresolved: {}",
382                        exception_str
383                    );
384
385                    // Endpoint calls create a response port: the
386                    // PortRef is sent to the remote worker (to send
387                    // results back), and collect_valuemesh owns the
388                    // PortReceiver. If mesh.cast() fails here, we try
389                    // to send the exception back to the caller via
390                    // the PortRef ourselves. But a supervision event
391                    // can cause collect_valuemesh to drop the
392                    // PortReceiver (removing the port from the
393                    // mailbox) before we get here. Disable
394                    // return-undeliverable so a delivery failure
395                    // doesn't bounce back and crash the root client.
396                    //
397                    // TODO: Tie the lifetime of this queued work to
398                    // the PortReceiver (e.g. a cancellation token set
399                    // on drop) so we can distinguish
400                    // supervision-caused failures — where the caller
401                    // already knows — from other cast errors where
402                    // the caller actually needs this exception.
403
404                    port_ref.set_return_undeliverable(false);
405
406                    let mut state =
407                        crate::pickle::pickle(py, pyerr.into_value(py).into_any(), false, false)?;
408                    let _ = port_ref.post(
409                        &instance,
410                        PythonMessage::new_from_buf(
411                            PythonMessageKind::Exception { rank: Some(0) },
412                            state.take_inner()?.take_buffer(),
413                        ),
414                    );
415
416                    Ok::<_, PyErr>(())
417                })
418                .await;
419            }
420        });
421        Ok(())
422    }
423
424    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
425        let fut = self.mesh.clone();
426        match fut.peek().cloned() {
427            Some(mesh) => mesh?.__reduce__(py),
428            None => {
429                let shared =
430                    PyPythonTask::new(async move { Ok(PythonActorMesh::from_impl(fut.await?)) })?
431                        .spawn_abortable()?;
432                // Get Shared.block_on as an unbound method
433                let block_on = shared_class(py).getattr("block_on")?;
434                let args = PyTuple::new(py, [shared.into_pyobject(py)?])?;
435                Ok((block_on, args.into_any()))
436            }
437        }
438    }
439
440    fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
441        let mesh = self.mesh.clone();
442        let instance = monarch_with_gil_blocking(|_py| instance.clone());
443        let (tx, rx) = tokio::sync::oneshot::channel();
444        self.push(async move {
445            let result =
446                async move { mesh.await?.stop(&instance, reason)?.take_task()?.await }.await;
447            if tx.send(result).is_err() {
448                panic!("oneshot failed");
449            }
450        });
451        PyPythonTask::new(async move { rx.await.map_err(anyhow::Error::from)? })
452    }
453
454    fn initialized<'py>(&self) -> PyResult<PyPythonTask> {
455        let mesh = self.mesh.clone();
456        PyPythonTask::new(async {
457            mesh.await?;
458            Ok(None::<()>)
459        })
460    }
461
462    fn name(&self) -> PyResult<PyPythonTask> {
463        let mesh = self.mesh.clone();
464        let (tx, rx) = tokio::sync::oneshot::channel();
465        self.push(async move {
466            let result = async move { mesh.await?.name()?.take_task()?.await }.await;
467            if tx.send(result).is_err() {
468                panic!("oneshot failed");
469            }
470        });
471        PyPythonTask::new(async move { rx.await.map_err(anyhow::Error::from)? })
472    }
473}
474
475#[async_trait]
476impl Supervisable for AsyncActorMesh {
477    async fn supervision_event(&self, instance: &Instance<PythonActor>) -> Option<PyErr> {
478        if !self.supervised {
479            return None;
480        }
481        let mesh = self.mesh.clone();
482        match mesh.await {
483            Ok(mesh) => mesh.supervision_event(instance).await,
484            Err(e) => Some(e.into()),
485        }
486    }
487}
488
489impl SupervisableActorMesh for AsyncActorMesh {
490    fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>> {
491        let mesh = self.mesh.clone();
492        let region = region.clone();
493        Ok(Box::new(AsyncActorMesh::new(
494            self.queue.clone(),
495            self.supervised,
496            async move { Ok(Arc::from(mesh.await?.new_with_region(&region)?)) }
497                .boxed()
498                .shared(),
499        )))
500    }
501}
502
503#[derive(Debug, Clone)]
504#[pyclass(
505    name = "PyActorMesh",
506    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
507)]
508pub(crate) struct PyActorMesh {
509    mesh: ActorMesh<PythonActor>,
510}
511
512#[derive(Debug, Clone)]
513#[pyclass(
514    name = "PyActorMeshRef",
515    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
516)]
517pub(crate) struct PyActorMeshRef {
518    mesh: ActorMeshRef<PythonActor>,
519}
520
521#[derive(Debug, Clone)]
522#[pyclass(
523    name = "PythonActorMeshImpl",
524    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
525)]
526#[expect(
527    clippy::large_enum_variant,
528    reason = "PyO3 #[pyclass] enum; Box wrapping interacts with PyO3 codegen and Python interop — separate diff"
529)]
530pub(crate) enum PythonActorMeshImpl {
531    Owned(PyActorMesh),
532    Ref(PyActorMeshRef),
533}
534
535impl PythonActorMeshImpl {
536    /// Get a new owned [`PythonActorMeshImpl`].
537    pub(crate) fn new_owned(inner: ActorMesh<PythonActor>) -> Self {
538        PythonActorMeshImpl::Owned(PyActorMesh { mesh: inner })
539    }
540
541    /// Get a new ref-based [`PythonActorMeshImpl`].
542    pub(crate) fn new_ref(inner: ActorMeshRef<PythonActor>) -> Self {
543        PythonActorMeshImpl::Ref(PyActorMeshRef { mesh: inner })
544    }
545
546    fn mesh_ref(&self) -> &ActorMeshRef<PythonActor> {
547        match self {
548            PythonActorMeshImpl::Owned(inner) => &inner.mesh,
549            PythonActorMeshImpl::Ref(inner) => &inner.mesh,
550        }
551    }
552}
553
554#[async_trait]
555impl Supervisable for PythonActorMeshImpl {
556    async fn supervision_event(&self, instance: &Instance<PythonActor>) -> Option<PyErr> {
557        let mesh = self.mesh_ref();
558        match mesh.next_supervision_event(instance).await {
559            Ok(supervision_failure) => Some(SupervisionError::new_err_from(supervision_failure)),
560            Err(e) => Some(PyValueError::new_err(e.to_string())),
561        }
562    }
563}
564
565impl ActorMeshProtocol for PythonActorMeshImpl {
566    fn cast(
567        &self,
568        message: PythonMessage,
569        selection: Selection,
570        instance: &Instance<PythonActor>,
571    ) -> PyResult<()> {
572        <ActorMeshRef<PythonActor> as ActorMeshProtocol>::cast(
573            self.mesh_ref(),
574            message,
575            selection,
576            instance,
577        )
578    }
579
580    fn cast_with_headers(
581        &self,
582        message: PythonMessage,
583        selection: Selection,
584        instance: &Instance<PythonActor>,
585        caller_headers: hyperactor_config::Flattrs,
586    ) -> PyResult<()> {
587        <ActorMeshRef<PythonActor> as ActorMeshProtocol>::cast_with_headers(
588            self.mesh_ref(),
589            message,
590            selection,
591            instance,
592            caller_headers,
593        )
594    }
595
596    fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
597        let (slf, instance) = monarch_with_gil_blocking(|_py| (self.clone(), instance.clone()));
598        match slf {
599            PythonActorMeshImpl::Owned(mut mesh) => PyPythonTask::new(async move {
600                mesh.mesh
601                    .stop(instance.deref(), reason)
602                    .await
603                    .map_err(|err| PyValueError::new_err(err.to_string()))
604            }),
605            PythonActorMeshImpl::Ref(_) => Err(PyNotImplementedError::new_err(
606                "Cannot call stop on an ActorMeshRef, requires an owned ActorMesh",
607            )),
608        }
609    }
610
611    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
612        self.mesh_ref().__reduce__(py)
613    }
614
615    fn name(&self) -> PyResult<PyPythonTask> {
616        let name = self.mesh_ref().id().to_string();
617        PyPythonTask::new(async move { Ok(name) })
618    }
619}
620
621impl SupervisableActorMesh for PythonActorMeshImpl {
622    fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>> {
623        assert!(region.as_inner().is_subset(self.mesh_ref().region()));
624        Ok(Box::new(PythonActorMeshImpl::new_ref(
625            self.mesh_ref().sliced(region.as_inner().clone()),
626        )))
627    }
628}
629
630// Convert a hyperactor_mesh::Error to a Python exception. hyperactor_mesh::Error::Supervision becomes a SupervisionError,
631// all others become a RuntimeError.
632fn cast_error_to_py_error(err: hyperactor_mesh::Error) -> PyErr {
633    if let hyperactor_mesh::Error::Supervision(failure) = err {
634        SupervisionError::new_err_from(*failure)
635    } else {
636        PyRuntimeError::new_err(err.to_string())
637    }
638}
639
640impl ActorMeshProtocol for ActorMeshRef<PythonActor> {
641    fn cast(
642        &self,
643        message: PythonMessage,
644        selection: Selection,
645        instance: &Instance<PythonActor>,
646    ) -> PyResult<()> {
647        <Self as ActorMeshProtocol>::cast_with_headers(
648            self,
649            message,
650            selection,
651            instance,
652            hyperactor_config::Flattrs::new(),
653        )
654    }
655
656    fn cast_with_headers(
657        &self,
658        message: PythonMessage,
659        selection: Selection,
660        instance: &Instance<PythonActor>,
661        caller_headers: hyperactor_config::Flattrs,
662    ) -> PyResult<()> {
663        if structurally_equal(&selection, &Selection::All(Box::new(Selection::True))) {
664            ActorMeshRef::<PythonActor>::cast_with_headers(
665                self,
666                instance,
667                &caller_headers,
668                message.clone(),
669            )
670            .map_err(cast_error_to_py_error)?;
671        } else if structurally_equal(&selection, &Selection::Any(Box::new(Selection::True))) {
672            let region = Ranked::region(self);
673            let random_rank = fastrand::usize(0..region.num_ranks());
674            let offset = region
675                .slice()
676                .get(random_rank)
677                .map_err(anyhow::Error::from)?;
678            let singleton_region = Region::new(
679                Vec::new(),
680                Slice::new(offset, Vec::new(), Vec::new()).map_err(anyhow::Error::from)?,
681            );
682            ActorMeshRef::<PythonActor>::cast_with_headers(
683                &self.sliced(singleton_region),
684                instance,
685                &caller_headers,
686                message.clone(),
687            )
688            .map_err(cast_error_to_py_error)?;
689        } else {
690            return Err(PyRuntimeError::new_err(format!(
691                "invalid selection: {:?}",
692                selection
693            )));
694        }
695
696        Ok(())
697    }
698
699    /// Stop the actor mesh asynchronously.
700    fn stop(&self, _instance: &PyInstance, _reason: String) -> PyResult<PyPythonTask> {
701        Err(PyNotImplementedError::new_err(
702            "This cannot be used on ActorMeshRef, only on owned ActorMesh",
703        ))
704    }
705
706    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
707        let bytes = bincode::serde::encode_to_vec(self, bincode::config::legacy())
708            .map_err(|e| PyValueError::new_err(e.to_string()))?;
709        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
710        let module = py
711            .import("monarch._rust_bindings.monarch_hyperactor.actor_mesh")
712            .unwrap();
713        let from_bytes = module.getattr("py_actor_mesh_from_bytes").unwrap();
714        Ok((from_bytes, py_bytes))
715    }
716
717    fn name(&self) -> PyResult<PyPythonTask> {
718        let name = self.id().to_string();
719        PyPythonTask::new(async move { Ok(name) })
720    }
721}
722
723#[pymethods]
724impl PythonActorMeshImpl {
725    fn get(&self, rank: usize) -> PyResult<Option<PyActorAddr>> {
726        Ok(self
727            .mesh_ref()
728            .get(rank)
729            .map(|r| hyperactor::ActorRef::into_actor_addr(r.clone()))
730            .map(PyActorAddr::from))
731    }
732
733    fn __repr__(&self) -> String {
734        format!("PythonActorMeshImpl({:?})", self.mesh_ref())
735    }
736}
737
738#[pyfunction]
739fn py_actor_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PythonActorMesh> {
740    let r: PyResult<ActorMeshRef<PythonActor>> =
741        bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
742            .map(|(v, _)| v)
743            .map_err(|e| PyValueError::new_err(e.to_string()));
744    r.map(|r| AsyncActorMesh::from_impl(Arc::new(PythonActorMeshImpl::new_ref(r))))
745        .map(|r| PythonActorMesh::from_impl(Arc::from(r)))
746}
747
748#[pyclass(
749    name = "ActorSupervisionEvent",
750    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
751)]
752#[derive(Debug)]
753pub struct PyActorSupervisionEvent {
754    inner: ActorSupervisionEvent,
755}
756
757#[pymethods]
758impl PyActorSupervisionEvent {
759    pub(crate) fn __repr__(&self) -> PyResult<String> {
760        Ok(format!("<PyActorSupervisionEvent: {}>", self.inner))
761    }
762
763    #[getter]
764    pub(crate) fn actor_id(&self) -> PyResult<PyActorAddr> {
765        Ok(PyActorAddr::from(self.inner.actor_id.clone()))
766    }
767
768    #[getter]
769    pub(crate) fn actor_status(&self) -> PyResult<String> {
770        Ok(self.inner.actor_status.to_string())
771    }
772}
773
774impl From<ActorSupervisionEvent> for PyActorSupervisionEvent {
775    fn from(event: ActorSupervisionEvent) -> Self {
776        PyActorSupervisionEvent { inner: event }
777    }
778}
779
780#[pyfunction]
781fn py_identity(obj: Py<PyAny>) -> PyResult<Py<PyAny>> {
782    Ok(obj)
783}
784
785/// Holds the GIL for the specified number of seconds without releasing it.
786///
787/// This is a test utility function that spawns a background thread which
788/// acquires the GIL using Rust's Python::attach and holds it for the
789/// specified duration using thread::sleep. Unlike Python code which
790/// periodically releases the GIL, this function holds it continuously.
791///
792/// We intentionally use `std::thread::sleep` here (not `Clock::sleep` or async sleep)
793/// because the purpose is to simulate a blocking operation that holds the GIL without
794/// releasing it. Using an async sleep would release the GIL periodically, defeating
795/// the purpose of this test utility.
796///
797/// Args:
798///     delay_secs: Seconds to wait before acquiring the GIL
799///     hold_secs: Seconds to hold the GIL
800#[pyfunction]
801#[pyo3(name = "hold_gil_for_test", signature = (delay_secs, hold_secs))]
802pub fn hold_gil_for_test(delay_secs: f64, hold_secs: f64) {
803    thread::spawn(move || {
804        // Wait before grabbing the GIL (blocking sleep is fine here, we're in a spawned thread)
805        thread::sleep(Duration::from_secs_f64(delay_secs));
806        // Acquire and hold the GIL - MUST use blocking sleep to keep GIL held
807        Python::attach(|_py| {
808            tracing::info!("start holding the gil...");
809            thread::sleep(Duration::from_secs_f64(hold_secs));
810            tracing::info!("end holding the gil...");
811        });
812    });
813}
814
815pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
816    py_module_add_function!(
817        hyperactor_mod,
818        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
819        py_identity
820    );
821    py_module_add_function!(
822        hyperactor_mod,
823        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
824        py_actor_mesh_from_bytes
825    );
826    py_module_add_function!(
827        hyperactor_mod,
828        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
829        hold_gil_for_test
830    );
831    hyperactor_mod.add_class::<PythonActorMesh>()?;
832    hyperactor_mod.add_class::<PythonActorMeshImpl>()?;
833    hyperactor_mod.add_class::<PyActorSupervisionEvent>()?;
834    Ok(())
835}
836
837#[cfg(test)]
838mod tests {
839    use std::sync::OnceLock;
840    use std::time::Duration;
841
842    use async_trait::async_trait;
843    use hyperactor::Actor;
844    use hyperactor::Context;
845    use hyperactor::Endpoint as _;
846    use hyperactor::Handler;
847    use hyperactor::Instance;
848    use hyperactor::Proc;
849    use hyperactor::actor::Signal;
850    use hyperactor::channel::ChannelTransport;
851    use hyperactor::mailbox;
852    use hyperactor::mailbox::PortReceiver;
853    use hyperactor::proc::WorkCell;
854    use hyperactor::supervision::ActorSupervisionEvent;
855    use hyperactor_mesh::host_mesh::HostMesh;
856    use hyperactor_mesh::mesh_controller::GetSubscriberCount;
857    use hyperactor_mesh::supervision::MeshFailure;
858    use monarch_types::PickledPyObject;
859    use ndslice::extent;
860    use pyo3::Python;
861    use tokio::sync::mpsc;
862
863    use super::*;
864    use crate::actor::PythonActor;
865    use crate::actor::PythonActorParams;
866
    /// Minimal root-client actor for test infrastructure.
    /// Handles MeshFailure by panicking (test failure).
    ///
    /// The three receivers are taken from the actor instance created in
    /// `init_test_instance` and are drained by `TestClient::run`.
    #[derive(Debug)]
    struct TestClient {
        // Incoming `Signal`s; `run` receives them and discards the result.
        signal_rx: PortReceiver<Signal>,
        // Supervision events; `run` passes each one to
        // `Instance::handle_supervision_event`.
        supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
        // Queued work cells; `run` executes each via `WorkCell::handle` and
        // exits its loop once this channel yields `None`.
        work_rx: mpsc::UnboundedReceiver<WorkCell<Self>>,
    }
875
    // Empty impl: `TestClient` overrides nothing and relies entirely on the
    // `Actor` trait's default items.
    impl Actor for TestClient {}
877
878    #[async_trait]
879    impl Handler<MeshFailure> for TestClient {
880        async fn handle(
881            &mut self,
882            _cx: &Context<Self>,
883            msg: MeshFailure,
884        ) -> Result<(), anyhow::Error> {
885            panic!("unexpected supervision failure in test: {}", msg);
886        }
887    }
888
889    impl TestClient {
890        fn run(mut self, instance: &'static Instance<Self>) {
891            tokio::spawn(async move {
892                loop {
893                    tokio::select! {
894                        work = self.work_rx.recv() => {
895                            match work {
896                                Some(work) => {
897                                    let _ = work.handle(&mut self, instance).await;
898                                }
899                                None => break,
900                            }
901                        }
902                        _ = self.signal_rx.recv() => {}
903                        Some(event) = self.supervision_rx.recv() => {
904                            let _ = instance
905                                .handle_supervision_event(&mut self, event)
906                                .await;
907                        }
908                    }
909                }
910            });
911        }
912    }
913
914    fn init_test_instance() -> &'static Instance<TestClient> {
915        static INSTANCE: OnceLock<Instance<TestClient>> = OnceLock::new();
916        let proc = Proc::direct(ChannelTransport::Unix.any(), "test_proc".to_string()).unwrap();
917        let ai = proc.actor_instance("test_client").unwrap();
918
919        INSTANCE
920            .set(ai.instance)
921            .map_err(|_| "already initialized")
922            .unwrap();
923        let instance = INSTANCE.get().unwrap();
924
925        TestClient {
926            signal_rx: ai.signal,
927            supervision_rx: ai.supervision,
928            work_rx: ai.work,
929        }
930        .run(instance);
931
932        instance
933    }
934
935    fn test_instance() -> &'static Instance<TestClient> {
936        static INSTANCE: OnceLock<&'static Instance<TestClient>> = OnceLock::new();
937        INSTANCE.get_or_init(init_test_instance)
938    }
939
940    /// Verify that calling `supervision_event` repeatedly through a
941    /// [`PythonActorMesh`] does not increase the subscriber count on the
942    /// controller.  This guards against a regression where each call
943    /// would create a new supervision subscriber.
944    #[tokio::test]
945    async fn test_subscriber_count_stable_across_supervision_calls() {
946        crate::pytokio::ensure_python();
947
948        let instance = test_instance();
949
950        let mut host_mesh = HostMesh::local_in_process().await.unwrap();
951        let proc_mesh = host_mesh
952            .spawn(instance, "test", extent!(replicas = 2), None, None)
953            .await
954            .unwrap();
955
956        // Create a minimal Python class and pickle it so we can spawn
957        // PythonActor instances (mirroring PyProcMesh::spawn_async).
958        // The class must live in __main__'s globals for pickle to find it.
959        let pickled_type = Python::attach(|py| {
960            py.run(c"class MinimalActor: pass", None, None).unwrap();
961
962            PickledPyObject::pickle(
963                &py.import("__main__")
964                    .unwrap()
965                    .getattr("MinimalActor")
966                    .unwrap(),
967            )
968            .unwrap()
969        });
970
971        let actor_mesh = proc_mesh
972            .spawn::<PythonActor, _>(
973                instance,
974                "test_actors",
975                &PythonActorParams::new(pickled_type, None, None),
976            )
977            .await
978            .unwrap();
979
980        let controller = actor_mesh.controller().as_ref().unwrap().clone();
981
982        // Wrap using the production code path from PyProcMesh::spawn_async.
983        let mesh_impl =
984            async move { Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh))) };
985        let python_actor_mesh = PythonActorMesh::new(
986            async move {
987                let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
988                Ok(mesh_impl)
989            },
990            true,
991        );
992
993        // Instance<PythonActor> required by the Supervisable trait
994        // signature. Only used for subscription routing inside
995        // next_supervision_event.
996        let py_ai = Proc::direct(ChannelTransport::Unix.any(), "py_proc".to_string())
997            .unwrap()
998            .actor_instance::<PythonActor>("py_client")
999            .unwrap();
1000        let py_instance = py_ai.instance;
1001
1002        // Query the subscriber count from the controller.
1003        let (port, mut rx) = mailbox::open_port::<usize>(instance);
1004        controller.post(instance, GetSubscriberCount(port.bind()));
1005        let initial_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
1006            .await
1007            .expect("timed out waiting for subscriber count")
1008            .expect("channel closed");
1009        assert_eq!(initial_count, 0, "should have 0 subscribers initially");
1010
1011        // Call supervision_event through the PythonActorMesh multiple
1012        // times, racing against a short timeout each time.  The mesh is
1013        // healthy so no event fires; we just want to trigger the lazy
1014        // subscriber initialization repeatedly.
1015        for _ in 0..5 {
1016            tokio::select! {
1017                _ = python_actor_mesh.inner.supervision_event(&py_instance) => {
1018                    panic!("unexpected supervision event on healthy mesh");
1019                }
1020                _ = tokio::time::sleep(Duration::from_millis(200)) => {}
1021            }
1022        }
1023
1024        // After 5 calls from the same context, there should be exactly 1
1025        // subscriber (created lazily on the first call, reused thereafter).
1026        let (port, mut rx) = mailbox::open_port::<usize>(instance);
1027        controller.post(instance, GetSubscriberCount(port.bind()));
1028        let after_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
1029            .await
1030            .expect("timed out waiting for subscriber count")
1031            .expect("channel closed");
1032        assert_eq!(
1033            after_count, 1,
1034            "subscriber count should be exactly 1, not growing with each call"
1035        );
1036
1037        // Do 5 more calls to confirm it stays stable.
1038        for _ in 0..5 {
1039            tokio::select! {
1040                _ = python_actor_mesh.inner.supervision_event(&py_instance) => {
1041                    panic!("unexpected supervision event on healthy mesh");
1042                }
1043                _ = tokio::time::sleep(Duration::from_millis(200)) => {}
1044            }
1045        }
1046
1047        let (port, mut rx) = mailbox::open_port::<usize>(instance);
1048        controller.post(instance, GetSubscriberCount(port.bind()));
1049        let final_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
1050            .await
1051            .expect("timed out waiting for subscriber count")
1052            .expect("channel closed");
1053        assert_eq!(
1054            final_count, 1,
1055            "subscriber count should still be 1 after repeated calls"
1056        );
1057
1058        let _ = host_mesh.shutdown(instance).await;
1059    }
1060}