1use std::future::Future;
10use std::ops::Deref;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::thread;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use futures::future;
18use futures::future::FutureExt;
19use futures::future::Shared;
20use hyperactor::Instance;
21use hyperactor::reference;
22use hyperactor::supervision::ActorSupervisionEvent;
23use hyperactor_mesh::actor_mesh::ActorMesh;
24use hyperactor_mesh::actor_mesh::ActorMeshRef;
25use hyperactor_mesh::sel;
26use monarch_types::py_global;
27use monarch_types::py_module_add_function;
28use ndslice::Region;
29use ndslice::Slice;
30use ndslice::selection::Selection;
31use ndslice::selection::structurally_equal;
32use ndslice::view::Ranked;
33use ndslice::view::RankedSliceable;
34use pyo3::IntoPyObjectExt;
35use pyo3::exceptions::PyNotImplementedError;
36use pyo3::exceptions::PyRuntimeError;
37use pyo3::exceptions::PyValueError;
38use pyo3::prelude::*;
39use pyo3::types::PyBytes;
40use pyo3::types::PyTuple;
41use tokio::sync::mpsc::UnboundedSender;
42use tokio::sync::mpsc::unbounded_channel;
43
44use crate::actor::PythonActor;
45use crate::actor::PythonMessage;
46use crate::actor::PythonMessageKind;
47use crate::context::PyInstance;
48use crate::pickle::PendingMessage;
49use crate::proc::PyActorId;
50use crate::pytokio::PyPythonTask;
51use crate::runtime::get_tokio_runtime;
52use crate::runtime::monarch_with_gil;
53use crate::runtime::monarch_with_gil_blocking;
54use crate::shape::PyRegion;
55use crate::supervision::Supervisable;
56use crate::supervision::SupervisionError;
57
// Python globals declared through the project's `py_global!` macro. Each
// invocation names a Rust-side accessor, the Python module to look in, and the
// attribute to fetch (presumably resolved lazily — macro is defined elsewhere,
// TODO confirm).

// `is_pending_pickle_allowed` from `monarch._src.actor.pickle`.
py_global!(
    is_pending_pickle_allowed,
    "monarch._src.actor.pickle",
    "is_pending_pickle_allowed"
);
// `_pickle` from `monarch._src.actor.actor_mesh`.
py_global!(_pickle, "monarch._src.actor.actor_mesh", "_pickle");

// The pytokio `Shared` class; `AsyncActorMesh::__reduce__` calls
// `shared_class(py)` to reach its `block_on` attribute.
py_global!(
    shared_class,
    "monarch._rust_bindings.monarch_hyperactor.pytokio",
    "Shared"
);
70
71pub(crate) trait ActorMeshProtocol: Send + Sync {
74 fn cast(
76 &self,
77 message: PythonMessage,
78 selection: Selection,
79 instance: &Instance<PythonActor>,
80 ) -> PyResult<()>;
81
82 fn cast_unresolved(
87 &self,
88 message: PendingMessage,
89 selection: Selection,
90 instance: &Instance<PythonActor>,
91 ) -> PyResult<()> {
92 let message = get_tokio_runtime().block_on(message.resolve())?;
93 self.cast(message, selection, instance)
94 }
95
96 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>;
97
98 fn stop(&self, _instance: &PyInstance, _reason: String) -> PyResult<PyPythonTask> {
101 Err(PyNotImplementedError::new_err(format!(
102 "stop() is not supported for {}",
103 std::any::type_name::<Self>()
104 )))
105 }
106
107 fn initialized(&self) -> PyResult<PyPythonTask> {
110 PyPythonTask::new(async { Ok(None::<()>) })
111 }
112
113 fn name(&self) -> PyResult<PyPythonTask>;
115}
116
/// An actor mesh that supports both the messaging protocol
/// ([`ActorMeshProtocol`]) and supervision ([`Supervisable`]).
pub(crate) trait SupervisableActorMesh: ActorMeshProtocol + Supervisable {
    /// Builds a new mesh object for `region`, returned as a boxed trait object.
    fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>>;
}
120
/// Python-visible actor mesh handle (exposed as `PythonActorMesh`).
///
/// It wraps a shared, type-erased [`SupervisableActorMesh`]; cloning the handle
/// clones the `Arc`, not the mesh.
#[pyclass(
    name = "PythonActorMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
#[derive(Clone)]
pub(crate) struct PythonActorMesh {
    // The implementation every method delegates to.
    inner: Arc<dyn SupervisableActorMesh>,
}
130
131impl PythonActorMesh {
132 pub(crate) fn new<F>(f: F, supervised: bool) -> Self
133 where
134 F: Future<Output = PyResult<Box<dyn SupervisableActorMesh>>> + Send + 'static,
135 {
136 let f = async move { Ok(Arc::from(f.await?)) }.boxed().shared();
137 PythonActorMesh {
138 inner: Arc::new(AsyncActorMesh::new_queue(f, supervised)),
139 }
140 }
141
142 pub(crate) fn from_impl(inner: Arc<dyn SupervisableActorMesh>) -> Self {
143 PythonActorMesh { inner }
144 }
145
146 pub(crate) fn get_inner(&self) -> Arc<dyn SupervisableActorMesh> {
147 self.inner.clone()
148 }
149}
150
151pub(crate) fn to_hy_sel(selection: &str) -> PyResult<Selection> {
152 match selection {
153 "choose" => Ok(sel!(?)),
154 "all" => Ok(sel!(*)),
155 _ => Err(PyErr::new::<PyValueError, _>(format!(
156 "Invalid selection: {}",
157 selection
158 ))),
159 }
160}
161
162#[pymethods]
163impl PythonActorMesh {
164 #[tracing::instrument(level = "debug", skip_all)]
165 #[pyo3(name = "cast")]
166 fn py_cast(
167 &self,
168 message: &PythonMessage,
169 selection: &str,
170 instance: &PyInstance,
171 ) -> PyResult<()> {
172 let sel = to_hy_sel(selection)?;
173 self.inner.cast(message.clone(), sel, instance.deref())
174 }
175
176 #[hyperactor::instrument]
177 pub(crate) fn cast_unresolved(
178 &self,
179 message: &mut PendingMessage,
180 selection: &str,
181 instance: &PyInstance,
182 ) -> PyResult<()> {
183 let sel = to_hy_sel(selection)?;
184 let message = message.take()?;
185 self.inner.cast_unresolved(message, sel, instance)
186 }
187
188 fn new_with_region(&self, region: &PyRegion) -> PyResult<PythonActorMesh> {
189 let inner = self.inner.new_with_region(region)?;
190 Ok(PythonActorMesh {
191 inner: Arc::from(inner),
192 })
193 }
194
195 fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
196 self.inner.stop(instance, reason)
197 }
198
199 fn initialized(&self) -> PyResult<PyPythonTask> {
200 self.inner.initialized()
201 }
202
203 fn name(&self) -> PyResult<PyPythonTask> {
204 self.inner.name()
205 }
206
207 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
208 self.inner.__reduce__(py)
209 }
210}
211
/// Wrapper around [`PyErr`] that implements `Clone`, so an error can be the
/// output of a shared future (see `ActorMeshResult`).
#[derive(Debug)]
pub(crate) struct ClonePyErr {
    // The wrapped Python error.
    inner: PyErr,
}
216
217impl From<ClonePyErr> for PyErr {
218 fn from(value: ClonePyErr) -> PyErr {
219 value.inner
220 }
221}
222impl From<PyErr> for ClonePyErr {
223 fn from(inner: PyErr) -> ClonePyErr {
224 ClonePyErr { inner }
225 }
226}
227
228impl Clone for ClonePyErr {
229 fn clone(&self) -> Self {
230 monarch_with_gil_blocking(|py| self.inner.clone_ref(py).into())
231 }
232}
233
// Outcome of building a mesh: the shared implementation, or a cloneable Python error.
type ActorMeshResult = Result<Arc<dyn SupervisableActorMesh>, ClonePyErr>;
// A boxed, `Send` future producing `ActorMeshResult`, wrapped in `Shared` so
// several holders can await (or `peek`) the same result.
type ActorMeshFut = Shared<Pin<Box<dyn Future<Output = ActorMeshResult> + Send + 'static>>>;
236
/// An actor mesh whose underlying implementation may not exist yet.
///
/// Operations are expressed as futures that first await `mesh` and are then
/// run one after another by the worker task spawned in `new_queue`.
pub(crate) struct AsyncActorMesh {
    // Shared future resolving to the real mesh (or the error that prevented it).
    mesh: ActorMeshFut,
    // Sending half of the work queue; the worker awaits each queued future in order.
    queue: UnboundedSender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // When false, `supervision_event` returns `None` without awaiting the mesh.
    supervised: bool,
}
242
243impl AsyncActorMesh {
244 pub(crate) fn new_queue(f: ActorMeshFut, supervised: bool) -> AsyncActorMesh {
245 let (queue, mut recv) = unbounded_channel();
246
247 get_tokio_runtime().spawn(async move {
248 loop {
249 let r = recv.recv().await;
250 if let Some(r) = r {
251 r.await;
252 } else {
253 return;
254 }
255 }
256 });
257
258 let mesh = AsyncActorMesh::new(queue, supervised, f);
259 let f = mesh.mesh.clone();
267 mesh.push(async move {
268 let _ = f.await;
269 });
270 mesh
271 }
272
273 fn new(
274 queue: UnboundedSender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
275 supervised: bool,
276 f: ActorMeshFut,
277 ) -> AsyncActorMesh {
278 AsyncActorMesh {
279 mesh: f,
280 queue,
281 supervised,
282 }
283 }
284
285 fn push<F>(&self, f: F)
286 where
287 F: Future<Output = ()> + Send + 'static,
288 {
289 self.queue.send(f.boxed()).unwrap();
290 }
291
292 pub(crate) fn from_impl(mesh: Arc<dyn SupervisableActorMesh>) -> Self {
293 let fut = future::ready(Ok::<Arc<dyn SupervisableActorMesh>, ClonePyErr>(mesh))
294 .boxed()
295 .shared();
296 let _ = futures::executor::block_on(fut.clone());
298 Self::new_queue(fut, true)
299 }
300}
301
302impl ActorMeshProtocol for AsyncActorMesh {
303 fn cast(
304 &self,
305 _message: PythonMessage,
306 _selection: Selection,
307 _instance: &Instance<PythonActor>,
308 ) -> PyResult<()> {
309 panic!("not implemented")
310 }
311
312 fn cast_unresolved(
313 &self,
314 message: PendingMessage,
315 selection: Selection,
316 instance: &Instance<PythonActor>,
317 ) -> PyResult<()> {
318 let mesh = self.mesh.clone();
319 let instance = instance.clone_for_py();
320 let port = match &message.kind {
321 PythonMessageKind::CallMethod { response_port, .. } => response_port.clone(),
322 _ => None,
323 };
324 self.push(async move {
325 let result = async {
326 let resolved = message.resolve().await?;
327 mesh.await?.cast(resolved, selection, &instance)
328 }
329 .await;
330 if let (Some(mut port_ref), Err(pyerr)) = (port, result) {
331 let _ = monarch_with_gil(|py: Python<'_>| {
332 let exception_str = crate::logging::format_traceback(py, &pyerr);
333 tracing::error!(
334 actor_id = instance.self_id().to_string(),
335 "error occurred during cast unresolved: {}",
336 exception_str
337 );
338
339 port_ref.set_return_undeliverable(false);
359
360 let mut state =
361 crate::pickle::pickle(py, pyerr.into_value(py).into_any(), false, false)?;
362 let _ = port_ref.send(
363 &instance,
364 PythonMessage::new_from_buf(
365 PythonMessageKind::Exception { rank: Some(0) },
366 state.take_inner()?.take_buffer(),
367 ),
368 );
369
370 Ok::<_, PyErr>(())
371 })
372 .await;
373 }
374 });
375 Ok(())
376 }
377
378 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
379 let fut = self.mesh.clone();
380 match fut.peek().cloned() {
381 Some(mesh) => mesh?.__reduce__(py),
382 None => {
383 let shared =
384 PyPythonTask::new(async move { Ok(PythonActorMesh::from_impl(fut.await?)) })?
385 .spawn_abortable()?;
386 let block_on = shared_class(py).getattr("block_on")?;
388 let args = PyTuple::new(py, [shared.into_pyobject(py)?])?;
389 Ok((block_on, args.into_any()))
390 }
391 }
392 }
393
394 fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
395 let mesh = self.mesh.clone();
396 let instance = monarch_with_gil_blocking(|_py| instance.clone());
397 let (tx, rx) = tokio::sync::oneshot::channel();
398 self.push(async move {
399 let result =
400 async move { mesh.await?.stop(&instance, reason)?.take_task()?.await }.await;
401 if tx.send(result).is_err() {
402 panic!("oneshot failed");
403 }
404 });
405 PyPythonTask::new(async move { rx.await.map_err(anyhow::Error::from)? })
406 }
407
408 fn initialized<'py>(&self) -> PyResult<PyPythonTask> {
409 let mesh = self.mesh.clone();
410 PyPythonTask::new(async {
411 mesh.await?;
412 Ok(None::<()>)
413 })
414 }
415
416 fn name(&self) -> PyResult<PyPythonTask> {
417 let mesh = self.mesh.clone();
418 let (tx, rx) = tokio::sync::oneshot::channel();
419 self.push(async move {
420 let result = async move { mesh.await?.name()?.take_task()?.await }.await;
421 if tx.send(result).is_err() {
422 panic!("oneshot failed");
423 }
424 });
425 PyPythonTask::new(async move { rx.await.map_err(anyhow::Error::from)? })
426 }
427}
428
429#[async_trait]
430impl Supervisable for AsyncActorMesh {
431 async fn supervision_event(&self, instance: &Instance<PythonActor>) -> Option<PyErr> {
432 if !self.supervised {
433 return None;
434 }
435 let mesh = self.mesh.clone();
436 match mesh.await {
437 Ok(mesh) => mesh.supervision_event(instance).await,
438 Err(e) => Some(e.into()),
439 }
440 }
441}
442
443impl SupervisableActorMesh for AsyncActorMesh {
444 fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>> {
445 let mesh = self.mesh.clone();
446 let region = region.clone();
447 Ok(Box::new(AsyncActorMesh::new(
448 self.queue.clone(),
449 self.supervised,
450 async move { Ok(Arc::from(mesh.await?.new_with_region(®ion)?)) }
451 .boxed()
452 .shared(),
453 )))
454 }
455}
456
/// Python class `PyActorMesh`: wraps an owned `ActorMesh<PythonActor>`.
#[derive(Debug, Clone)]
#[pyclass(
    name = "PyActorMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
pub(crate) struct PyActorMesh {
    // The owned mesh; `PythonActorMeshImpl::stop` calls `stop` on it.
    mesh: ActorMesh<PythonActor>,
}
465
/// Python class `PyActorMeshRef`: wraps an `ActorMeshRef<PythonActor>`.
#[derive(Debug, Clone)]
#[pyclass(
    name = "PyActorMeshRef",
    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
pub(crate) struct PyActorMeshRef {
    // The mesh reference; it can be cast to and sliced, but not stopped.
    mesh: ActorMeshRef<PythonActor>,
}
474
/// Concrete mesh implementation exposed to Python as `PythonActorMeshImpl`.
#[derive(Debug, Clone)]
#[pyclass(
    name = "PythonActorMeshImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
pub(crate) enum PythonActorMeshImpl {
    // An owned mesh; the only variant that supports `stop`.
    Owned(PyActorMesh),
    // A reference to a mesh, e.g. one produced by slicing or unpickling.
    Ref(PyActorMeshRef),
}
484
485impl PythonActorMeshImpl {
486 pub(crate) fn new_owned(inner: ActorMesh<PythonActor>) -> Self {
488 PythonActorMeshImpl::Owned(PyActorMesh { mesh: inner })
489 }
490
491 pub(crate) fn new_ref(inner: ActorMeshRef<PythonActor>) -> Self {
493 PythonActorMeshImpl::Ref(PyActorMeshRef { mesh: inner })
494 }
495
496 fn mesh_ref(&self) -> &ActorMeshRef<PythonActor> {
497 match self {
498 PythonActorMeshImpl::Owned(inner) => &inner.mesh,
499 PythonActorMeshImpl::Ref(inner) => &inner.mesh,
500 }
501 }
502}
503
504#[async_trait]
505impl Supervisable for PythonActorMeshImpl {
506 async fn supervision_event(&self, instance: &Instance<PythonActor>) -> Option<PyErr> {
507 let mesh = self.mesh_ref();
508 match mesh.next_supervision_event(instance).await {
509 Ok(supervision_failure) => Some(SupervisionError::new_err_from(supervision_failure)),
510 Err(e) => Some(PyValueError::new_err(e.to_string())),
511 }
512 }
513}
514
515impl ActorMeshProtocol for PythonActorMeshImpl {
516 fn cast(
517 &self,
518 message: PythonMessage,
519 selection: Selection,
520 instance: &Instance<PythonActor>,
521 ) -> PyResult<()> {
522 <ActorMeshRef<PythonActor> as ActorMeshProtocol>::cast(
523 self.mesh_ref(),
524 message,
525 selection,
526 instance,
527 )
528 }
529
530 fn stop(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
531 let (slf, instance) = monarch_with_gil_blocking(|_py| (self.clone(), instance.clone()));
532 match slf {
533 PythonActorMeshImpl::Owned(mut mesh) => PyPythonTask::new(async move {
534 mesh.mesh
535 .stop(instance.deref(), reason)
536 .await
537 .map_err(|err| PyValueError::new_err(err.to_string()))
538 }),
539 PythonActorMeshImpl::Ref(_) => Err(PyNotImplementedError::new_err(
540 "Cannot call stop on an ActorMeshRef, requires an owned ActorMesh",
541 )),
542 }
543 }
544
545 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
546 self.mesh_ref().__reduce__(py)
547 }
548
549 fn name(&self) -> PyResult<PyPythonTask> {
550 let name = self.mesh_ref().name().to_string();
551 PyPythonTask::new(async move { Ok(name) })
552 }
553}
554
555impl SupervisableActorMesh for PythonActorMeshImpl {
556 fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn SupervisableActorMesh>> {
557 assert!(region.as_inner().is_subset(self.mesh_ref().region()));
558 Ok(Box::new(PythonActorMeshImpl::new_ref(
559 self.mesh_ref().sliced(region.as_inner().clone()),
560 )))
561 }
562}
563
564fn cast_error_to_py_error(err: hyperactor_mesh::Error) -> PyErr {
567 if let hyperactor_mesh::Error::Supervision(failure) = err {
568 SupervisionError::new_err_from(*failure)
569 } else {
570 PyRuntimeError::new_err(err.to_string())
571 }
572}
573
574impl ActorMeshProtocol for ActorMeshRef<PythonActor> {
575 fn cast(
576 &self,
577 message: PythonMessage,
578 selection: Selection,
579 instance: &Instance<PythonActor>,
580 ) -> PyResult<()> {
581 if structurally_equal(&selection, &Selection::All(Box::new(Selection::True))) {
582 self.cast(instance, message.clone())
583 .map_err(cast_error_to_py_error)?;
584 } else if structurally_equal(&selection, &Selection::Any(Box::new(Selection::True))) {
585 let region = Ranked::region(self);
586 let random_rank = fastrand::usize(0..region.num_ranks());
587 let offset = region
588 .slice()
589 .get(random_rank)
590 .map_err(anyhow::Error::from)?;
591 let singleton_region = Region::new(
592 Vec::new(),
593 Slice::new(offset, Vec::new(), Vec::new()).map_err(anyhow::Error::from)?,
594 );
595 self.sliced(singleton_region)
596 .cast(instance, message.clone())
597 .map_err(cast_error_to_py_error)?;
598 } else {
599 return Err(PyRuntimeError::new_err(format!(
600 "invalid selection: {:?}",
601 selection
602 )));
603 }
604
605 Ok(())
606 }
607
608 fn stop(&self, _instance: &PyInstance, _reason: String) -> PyResult<PyPythonTask> {
610 Err(PyNotImplementedError::new_err(
611 "This cannot be used on ActorMeshRef, only on owned ActorMesh",
612 ))
613 }
614
615 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
616 let bytes = bincode::serialize(self).map_err(|e| PyValueError::new_err(e.to_string()))?;
617 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
618 let module = py
619 .import("monarch._rust_bindings.monarch_hyperactor.actor_mesh")
620 .unwrap();
621 let from_bytes = module.getattr("py_actor_mesh_from_bytes").unwrap();
622 Ok((from_bytes, py_bytes))
623 }
624
625 fn name(&self) -> PyResult<PyPythonTask> {
626 let name = self.name().to_string();
627 PyPythonTask::new(async move { Ok(name) })
628 }
629}
630
631#[pymethods]
632impl PythonActorMeshImpl {
633 fn get(&self, rank: usize) -> PyResult<Option<PyActorId>> {
634 Ok(self
635 .mesh_ref()
636 .get(rank)
637 .map(|r| reference::ActorRef::into_actor_id(r.clone()))
638 .map(PyActorId::from))
639 }
640
641 fn __repr__(&self) -> String {
642 format!("PythonActorMeshImpl({:?})", self.mesh_ref())
643 }
644}
645
646#[pyfunction]
647fn py_actor_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PythonActorMesh> {
648 let r: PyResult<ActorMeshRef<PythonActor>> =
649 bincode::deserialize(bytes.as_bytes()).map_err(|e| PyValueError::new_err(e.to_string()));
650 r.map(|r| AsyncActorMesh::from_impl(Arc::new(PythonActorMeshImpl::new_ref(r))))
651 .map(|r| PythonActorMesh::from_impl(Arc::from(r)))
652}
653
/// Python class `ActorSupervisionEvent`: a read-only view of a hyperactor
/// `ActorSupervisionEvent`.
#[pyclass(
    name = "ActorSupervisionEvent",
    module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
)]
#[derive(Debug)]
pub struct PyActorSupervisionEvent {
    // The wrapped event.
    inner: ActorSupervisionEvent,
}
662
663#[pymethods]
664impl PyActorSupervisionEvent {
665 pub(crate) fn __repr__(&self) -> PyResult<String> {
666 Ok(format!("<PyActorSupervisionEvent: {}>", self.inner))
667 }
668
669 #[getter]
670 pub(crate) fn actor_id(&self) -> PyResult<PyActorId> {
671 Ok(PyActorId::from(self.inner.actor_id.clone()))
672 }
673
674 #[getter]
675 pub(crate) fn actor_status(&self) -> PyResult<String> {
676 Ok(self.inner.actor_status.to_string())
677 }
678}
679
680impl From<ActorSupervisionEvent> for PyActorSupervisionEvent {
681 fn from(event: ActorSupervisionEvent) -> Self {
682 PyActorSupervisionEvent { inner: event }
683 }
684}
685
/// Returns its argument unchanged.
#[pyfunction]
fn py_identity(obj: Py<PyAny>) -> PyResult<Py<PyAny>> {
    Ok(obj)
}
690
691#[pyfunction]
707#[pyo3(name = "hold_gil_for_test", signature = (delay_secs, hold_secs))]
708pub fn hold_gil_for_test(delay_secs: f64, hold_secs: f64) {
709 thread::spawn(move || {
710 thread::sleep(Duration::from_secs_f64(delay_secs));
712 Python::attach(|_py| {
714 tracing::info!("start holding the gil...");
715 thread::sleep(Duration::from_secs_f64(hold_secs));
716 tracing::info!("end holding the gil...");
717 });
718 });
719}
720
/// Registers this file's Python functions and classes on `hyperactor_mod`.
///
/// Functions: `py_identity`, `py_actor_mesh_from_bytes`, `hold_gil_for_test`.
/// Classes: `PythonActorMesh`, `PythonActorMeshImpl`, `PyActorSupervisionEvent`.
///
/// # Errors
/// Propagates any error returned while adding a class.
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    // Each function is added via `py_module_add_function!` together with the
    // fully-qualified Python module name.
    py_module_add_function!(
        hyperactor_mod,
        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
        py_identity
    );
    // Needed for unpickling: `ActorMeshRef::__reduce__` looks this up by name.
    py_module_add_function!(
        hyperactor_mod,
        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
        py_actor_mesh_from_bytes
    );
    py_module_add_function!(
        hyperactor_mod,
        "monarch._rust_bindings.monarch_hyperactor.actor_mesh",
        hold_gil_for_test
    );
    hyperactor_mod.add_class::<PythonActorMesh>()?;
    hyperactor_mod.add_class::<PythonActorMeshImpl>()?;
    hyperactor_mod.add_class::<PyActorSupervisionEvent>()?;
    Ok(())
}
742
#[cfg(test)]
mod tests {
    use std::sync::OnceLock;
    use std::time::Duration;

    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Instance;
    use hyperactor::Proc;
    use hyperactor::actor::Signal;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::mailbox;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::proc::WorkCell;
    use hyperactor::supervision::ActorSupervisionEvent;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use hyperactor_mesh::mesh_controller::GetSubscriberCount;
    use hyperactor_mesh::supervision::MeshFailure;
    use monarch_types::PickledPyObject;
    use ndslice::extent;
    use pyo3::Python;
    use tokio::sync::mpsc;

    use super::*;
    use crate::actor::PythonActor;
    use crate::actor::PythonActorParams;

    /// Minimal client actor used by the test. It owns the receivers of its
    /// actor instance and services them in `run`.
    #[derive(Debug)]
    struct TestClient {
        signal_rx: PortReceiver<Signal>,
        supervision_rx: PortReceiver<ActorSupervisionEvent>,
        work_rx: mpsc::UnboundedReceiver<WorkCell<Self>>,
    }

    impl Actor for TestClient {}

    #[async_trait]
    impl Handler<MeshFailure> for TestClient {
        /// A mesh failure is never expected in this test, so it is fatal.
        async fn handle(
            &mut self,
            _cx: &Context<Self>,
            msg: MeshFailure,
        ) -> Result<(), anyhow::Error> {
            panic!("unexpected supervision failure in test: {}", msg);
        }
    }

    impl TestClient {
        /// Spawns a task that services the client's channels until the work
        /// channel closes.
        fn run(mut self, instance: &'static Instance<Self>) {
            tokio::spawn(async move {
                loop {
                    tokio::select! {
                        work = self.work_rx.recv() => {
                            match work {
                                Some(work) => {
                                    let _ = work.handle(&mut self, instance).await;
                                }
                                // All work senders are gone: stop the loop.
                                None => break,
                            }
                        }
                        // Signals are received and ignored.
                        _ = self.signal_rx.recv() => {}
                        Ok(event) = self.supervision_rx.recv() => {
                            let _ = instance
                                .handle_supervision_event(&mut self, event)
                                .await;
                        }
                    }
                }
            });
        }
    }

    /// Creates the process-wide test client instance and starts its loop.
    ///
    /// Panics if called twice: the instance is stored in a `OnceLock` and a
    /// second `set` fails.
    fn init_test_instance() -> &'static Instance<TestClient> {
        static INSTANCE: OnceLock<Instance<TestClient>> = OnceLock::new();
        let proc = Proc::direct(ChannelTransport::Unix.any(), "test_proc".to_string()).unwrap();
        let ai = proc.actor_instance("test_client").unwrap();

        INSTANCE
            .set(ai.instance)
            .map_err(|_| "already initialized")
            .unwrap();
        let instance = INSTANCE.get().unwrap();

        TestClient {
            signal_rx: ai.signal,
            supervision_rx: ai.supervision,
            work_rx: ai.work,
        }
        .run(instance);

        instance
    }

    /// Returns the shared test client instance, initializing it on first use.
    fn test_instance() -> &'static Instance<TestClient> {
        static INSTANCE: OnceLock<&'static Instance<TestClient>> = OnceLock::new();
        INSTANCE.get_or_init(init_test_instance)
    }

    /// Repeated `supervision_event` calls on a healthy mesh must leave the
    /// controller with exactly one subscriber, however many calls are made.
    #[tokio::test]
    async fn test_subscriber_count_stable_across_supervision_calls() {
        crate::pytokio::ensure_python();

        let instance = test_instance();

        // Two local procs are enough for this test.
        let proc_mesh = ProcMesh::allocate(
            instance,
            Box::new(
                LocalAllocator
                    .allocate(AllocSpec {
                        extent: extent!(replicas = 2),
                        constraints: Default::default(),
                        proc_name: None,
                        transport: ChannelTransport::Local,
                        proc_allocation_mode: Default::default(),
                    })
                    .await
                    .unwrap(),
            ),
            "test",
        )
        .await
        .unwrap();

        // Define a trivial Python class in `__main__` and pickle it to use as
        // the actor type.
        let pickled_type = Python::attach(|py| {
            py.run(c"class MinimalActor: pass", None, None).unwrap();

            PickledPyObject::pickle(
                &py.import("__main__")
                    .unwrap()
                    .getattr("MinimalActor")
                    .unwrap(),
            )
            .unwrap()
        });

        let actor_mesh = proc_mesh
            .spawn::<PythonActor, _>(
                instance,
                "test_actors",
                &PythonActorParams::new(pickled_type, None),
            )
            .await
            .unwrap();

        // Keep a handle to the controller: the mesh itself is moved below.
        let controller = actor_mesh.controller().as_ref().unwrap().clone();

        // Wrap the owned mesh in a supervised `PythonActorMesh`.
        let mesh_impl =
            async move { Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh))) };
        let python_actor_mesh = PythonActorMesh::new(
            async move {
                let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
                Ok(mesh_impl)
            },
            true,
        );

        // `supervision_event` takes an `Instance<PythonActor>`, so create one
        // on a separate proc.
        let py_ai = Proc::direct(ChannelTransport::Unix.any(), "py_proc".to_string())
            .unwrap()
            .actor_instance::<PythonActor>("py_client")
            .unwrap();
        let py_instance = py_ai.instance;

        // Baseline: nobody has asked for supervision events yet.
        let (port, mut rx) = mailbox::open_port::<usize>(instance);
        controller
            .send(instance, GetSubscriberCount(port.bind()))
            .unwrap();
        let initial_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("timed out waiting for subscriber count")
            .expect("channel closed");
        assert_eq!(initial_count, 0, "should have 0 subscribers initially");

        // Start five supervision waits. The mesh is healthy, so each wait
        // should lose the race against the 200ms sleep and be dropped.
        for _ in 0..5 {
            tokio::select! {
                _ = python_actor_mesh.inner.supervision_event(&py_instance) => {
                    panic!("unexpected supervision event on healthy mesh");
                }
                _ = tokio::time::sleep(Duration::from_millis(200)) => {}
            }
        }

        // The five calls must have produced a single subscription.
        let (port, mut rx) = mailbox::open_port::<usize>(instance);
        controller
            .send(instance, GetSubscriberCount(port.bind()))
            .unwrap();
        let after_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("timed out waiting for subscriber count")
            .expect("channel closed");
        assert_eq!(
            after_count, 1,
            "subscriber count should be exactly 1, not growing with each call"
        );

        // A second batch of calls must not add subscribers either.
        for _ in 0..5 {
            tokio::select! {
                _ = python_actor_mesh.inner.supervision_event(&py_instance) => {
                    panic!("unexpected supervision event on healthy mesh");
                }
                _ = tokio::time::sleep(Duration::from_millis(200)) => {}
            }
        }

        let (port, mut rx) = mailbox::open_port::<usize>(instance);
        controller
            .send(instance, GetSubscriberCount(port.bind()))
            .unwrap();
        let final_count = tokio::time::timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("timed out waiting for subscriber count")
            .expect("channel closed");
        assert_eq!(
            final_count, 1,
            "subscriber count should still be 1 after repeated calls"
        );
    }
}