monarch_hyperactor/
endpoint.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::Cell;
10use std::sync::Arc;
11use std::sync::atomic::AtomicUsize;
12use std::sync::atomic::Ordering;
13
14use hyperactor::Instance;
15use hyperactor::mailbox::PortReceiver;
16use hyperactor_mesh::sel;
17use monarch_types::py_global;
18use ndslice::Extent;
19use ndslice::Selection;
20use ndslice::Shape;
21use pyo3::prelude::*;
22use pyo3::types::PyDict;
23use pyo3::types::PyTuple;
24use serde_multipart::Part;
25
26use crate::actor::MethodSpecifier;
27use crate::actor::PythonActor;
28use crate::actor::PythonMessage;
29use crate::actor::PythonMessageKind;
30use crate::actor_mesh::PythonActorMesh;
31use crate::actor_mesh::SupervisableActorMesh;
32use crate::actor_mesh::to_hy_sel;
33use crate::buffers::FrozenBuffer;
34use crate::context::PyInstance;
35use crate::mailbox::PythonPortRef;
36use crate::metrics::ENDPOINT_BROADCAST_ERROR;
37use crate::metrics::ENDPOINT_BROADCAST_THROUGHPUT;
38use crate::metrics::ENDPOINT_CALL_ERROR;
39use crate::metrics::ENDPOINT_CALL_LATENCY_US_HISTOGRAM;
40use crate::metrics::ENDPOINT_CALL_ONE_ERROR;
41use crate::metrics::ENDPOINT_CALL_ONE_LATENCY_US_HISTOGRAM;
42use crate::metrics::ENDPOINT_CALL_ONE_THROUGHPUT;
43use crate::metrics::ENDPOINT_CALL_THROUGHPUT;
44use crate::metrics::ENDPOINT_CHOOSE_ERROR;
45use crate::metrics::ENDPOINT_CHOOSE_LATENCY_US_HISTOGRAM;
46use crate::metrics::ENDPOINT_CHOOSE_THROUGHPUT;
47use crate::metrics::ENDPOINT_STREAM_ERROR;
48use crate::metrics::ENDPOINT_STREAM_LATENCY_US_HISTOGRAM;
49use crate::metrics::ENDPOINT_STREAM_THROUGHPUT;
50use crate::pickle::PendingMessage;
51use crate::pickle::unpickle;
52use crate::pytokio::PyPythonTask;
53use crate::pytokio::PythonTask;
54use crate::shape::PyExtent;
55use crate::shape::PyShape;
56use crate::supervision::Supervisable;
57use crate::supervision::SupervisionError;
58use crate::value_mesh::PyValueMesh;
59
// Accessors for Python-side objects, declared as
// `(rust accessor name, python module, attribute name)`.
// Each accessor is invoked in this file as `name(py)` and the result is
// called like a Python callable.

// `context` — called with no arguments in `Endpoint::get_current_instance`
// to reach the `actor_instance` attribute of the current context.
py_global!(get_context, "monarch._src.actor.actor_mesh", "context");
// `_create_endpoint_message` — called by `ActorEndpoint::create_message`
// to build the pending message for an endpoint invocation.
py_global!(
    create_endpoint_message,
    "monarch._src.actor.actor_mesh",
    "_create_endpoint_message"
);
// `_dispatch_actor_rref` — called by `ActorEndpoint::rref` with
// `(endpoint, args, kwargs)`.
py_global!(
    dispatch_actor_rref,
    "monarch._src.actor.actor_mesh",
    "_dispatch_actor_rref"
);
// `Future` — called with a `coro=` keyword argument to wrap a
// `PyPythonTask` (see `wrap_in_future`).
py_global!(make_future, "monarch._src.actor.future", "Future");
72
73fn unpickle_from_part<'py>(py: Python<'py>, part: Part) -> PyResult<Bound<'py, PyAny>> {
74    unpickle(
75        py,
76        FrozenBuffer {
77            inner: part.into_bytes(),
78        },
79    )
80}
81
/// The type of endpoint operation being performed.
///
/// Used to select the appropriate telemetry metrics for each operation type.
#[derive(Clone, Copy, Debug)]
enum EndpointAdverb {
    /// Used by `collect_valuemesh` (the `call` adverb); selects the
    /// `ENDPOINT_CALL_*` metrics.
    Call,
    /// Passed by `Endpoint::call_one`; selects the `ENDPOINT_CALL_ONE_*` metrics.
    CallOne,
    /// Passed by `Endpoint::choose`; selects the `ENDPOINT_CHOOSE_*` metrics.
    Choose,
    /// Used by `PyValueStream::__next__`; selects the `ENDPOINT_STREAM_*`
    /// latency and error metrics.
    Stream,
}
92
/// RAII guard for recording endpoint call telemetry.
///
/// Records latency on drop, similar to Python's `@_with_telemetry` decorator.
/// Call `mark_error()` before dropping to also record an error.
pub struct RecordEndpointGuard {
    /// Instant the latency is measured from when the guard is dropped.
    start: tokio::time::Instant,
    /// Value of the `method` attribute attached to every recorded metric.
    method_name: String,
    /// Value of the `actor_count` attribute attached to the latency and
    /// error metrics recorded on drop.
    actor_count: usize,
    /// Selects which family of metrics (call / call_one / choose / stream)
    /// the guard records into.
    adverb: EndpointAdverb,
    /// Set by `mark_error()`; read on drop to decide whether the error
    /// counter is incremented. A `Cell` so it can be set through `&self`.
    error_occurred: Cell<bool>,
}
104
105impl RecordEndpointGuard {
106    fn new(
107        start: tokio::time::Instant,
108        method_name: String,
109        actor_count: usize,
110        adverb: EndpointAdverb,
111    ) -> Self {
112        let attributes = hyperactor_telemetry::kv_pairs!(
113            "method" => method_name.clone()
114        );
115        match adverb {
116            EndpointAdverb::Call => {
117                ENDPOINT_CALL_THROUGHPUT.add(1, attributes);
118            }
119            EndpointAdverb::CallOne => {
120                ENDPOINT_CALL_ONE_THROUGHPUT.add(1, attributes);
121            }
122            EndpointAdverb::Choose => {
123                ENDPOINT_CHOOSE_THROUGHPUT.add(1, attributes);
124            }
125            EndpointAdverb::Stream => {
126                // Throughput already recorded once at stream creation in py_stream_collector
127            }
128        }
129
130        Self {
131            start,
132            method_name,
133            actor_count,
134            adverb,
135            error_occurred: Cell::new(false),
136        }
137    }
138
139    fn mark_error(&self) {
140        self.error_occurred.set(true);
141    }
142}
143
144impl Drop for RecordEndpointGuard {
145    fn drop(&mut self) {
146        let actor_count_str = self.actor_count.to_string();
147        let attributes = hyperactor_telemetry::kv_pairs!(
148            "method" => self.method_name.clone(),
149            "actor_count" => actor_count_str
150        );
151        tracing::info!(message = "response received", method = self.method_name);
152
153        let duration_us = self.start.elapsed().as_micros();
154
155        match self.adverb {
156            EndpointAdverb::Call => {
157                ENDPOINT_CALL_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
158            }
159            EndpointAdverb::CallOne => {
160                ENDPOINT_CALL_ONE_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
161            }
162            EndpointAdverb::Choose => {
163                ENDPOINT_CHOOSE_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
164            }
165            EndpointAdverb::Stream => {
166                ENDPOINT_STREAM_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
167            }
168        }
169
170        if self.error_occurred.get() {
171            match self.adverb {
172                EndpointAdverb::Call => {
173                    ENDPOINT_CALL_ERROR.add(1, attributes);
174                }
175                EndpointAdverb::CallOne => {
176                    ENDPOINT_CALL_ONE_ERROR.add(1, attributes);
177                }
178                EndpointAdverb::Choose => {
179                    ENDPOINT_CHOOSE_ERROR.add(1, attributes);
180                }
181                EndpointAdverb::Stream => {
182                    ENDPOINT_STREAM_ERROR.add(1, attributes);
183                }
184            }
185        }
186    }
187}
188
189fn supervision_error_to_pyerr(err: PyErr, qualified_endpoint_name: &Option<String>) -> PyErr {
190    match qualified_endpoint_name {
191        Some(endpoint) => {
192            Python::attach(|py| SupervisionError::set_endpoint_on_err(py, err, endpoint.clone()))
193        }
194        None => err,
195    }
196}
197
198async fn collect_value(
199    rx: &mut PortReceiver<PythonMessage>,
200    supervision_monitor: &Option<Arc<dyn Supervisable>>,
201    instance: &Instance<PythonActor>,
202    qualified_endpoint_name: &Option<String>,
203) -> PyResult<(Part, Option<usize>)> {
204    enum RaceResult {
205        Message(PythonMessage),
206        SupervisionError(PyErr),
207        RecvError(String),
208    }
209
210    let race_result = match supervision_monitor {
211        Some(sup) => {
212            tokio::select! {
213                biased;
214                result = sup.supervision_event(instance) => {
215                    match result {
216                        Some(err) => RaceResult::SupervisionError(err),
217                        None => {
218                            match rx.recv().await {
219                                Ok(msg) => RaceResult::Message(msg),
220                                Err(e) => RaceResult::RecvError(e.to_string()),
221                            }
222                        }
223                    }
224                }
225                msg = rx.recv() => {
226                    match msg {
227                        Ok(m) => RaceResult::Message(m),
228                        Err(e) => RaceResult::RecvError(e.to_string()),
229                    }
230                }
231            }
232        }
233        _ => match rx.recv().await {
234            Ok(msg) => RaceResult::Message(msg),
235            Err(e) => RaceResult::RecvError(e.to_string()),
236        },
237    };
238
239    match race_result {
240        RaceResult::Message(PythonMessage {
241            kind: PythonMessageKind::Result { rank, .. },
242            message,
243            ..
244        }) => Ok((message, rank)),
245        RaceResult::Message(PythonMessage {
246            kind: PythonMessageKind::Exception { .. },
247            message,
248            ..
249        }) => Python::attach(|py| Err(PyErr::from_value(unpickle_from_part(py, message)?))),
250        RaceResult::Message(msg) => Err(pyo3::exceptions::PyValueError::new_err(format!(
251            "unexpected message kind {:?}",
252            msg.kind
253        ))),
254        RaceResult::RecvError(e) => Err(pyo3::exceptions::PyEOFError::new_err(format!(
255            "Port closed: {}",
256            e
257        ))),
258        RaceResult::SupervisionError(err) => {
259            Err(supervision_error_to_pyerr(err, qualified_endpoint_name))
260        }
261    }
262}
263
264async fn collect_valuemesh(
265    extent: Extent,
266    mut rx: PortReceiver<PythonMessage>,
267    method_name: String,
268    supervision_monitor: Option<Arc<dyn Supervisable>>,
269    instance: &Instance<PythonActor>,
270    qualified_endpoint_name: Option<String>,
271) -> PyResult<Py<PyAny>> {
272    let start = tokio::time::Instant::now();
273
274    let expected_count = extent.num_ranks();
275
276    let record_guard = RecordEndpointGuard::new(
277        start,
278        method_name.clone(),
279        expected_count,
280        EndpointAdverb::Call,
281    );
282
283    let mut results: Vec<Option<Part>> = vec![None; expected_count];
284
285    for _ in 0..expected_count {
286        match collect_value(
287            &mut rx,
288            &supervision_monitor,
289            instance,
290            &qualified_endpoint_name,
291        )
292        .await
293        {
294            Ok((message, rank)) => {
295                results[rank.expect("RankedPort receiver got a message without a rank")] =
296                    Some(message);
297            }
298            Err(e) => {
299                record_guard.mark_error();
300                return Err(e);
301            }
302        }
303    }
304
305    Python::attach(|py| {
306        Ok(PyValueMesh::build_dense_from_extent(
307            &extent,
308            results
309                .into_iter()
310                .map(|msg| {
311                    let m = msg.expect("all responses should be filled");
312                    unpickle_from_part(py, m).map(|obj| obj.unbind())
313                })
314                .collect::<PyResult<_>>()?,
315        )?
316        .into_pyobject(py)?
317        .into_any()
318        .unbind())
319    })
320}
321
322fn value_collector(
323    mut receiver: PortReceiver<PythonMessage>,
324    method_name: String,
325    supervision_monitor: Option<Arc<dyn Supervisable>>,
326    instance: Instance<PythonActor>,
327    qualified_endpoint_name: Option<String>,
328    adverb: EndpointAdverb,
329) -> PyResult<PyPythonTask> {
330    Ok(PythonTask::new(async move {
331        let start = tokio::time::Instant::now();
332
333        let record_guard = RecordEndpointGuard::new(start, method_name, 1, adverb);
334
335        match collect_value(
336            &mut receiver,
337            &supervision_monitor,
338            &instance,
339            &qualified_endpoint_name,
340        )
341        .await
342        {
343            Ok((message, _)) => {
344                Python::attach(|py| unpickle_from_part(py, message).map(|obj| obj.unbind()))
345            }
346            Err(e) => {
347                record_guard.mark_error();
348                Err(e)
349            }
350        }
351    })?
352    .into())
353}
354
/// A streaming iterator that yields futures for each response from actors.
///
/// Implements Python's iterator protocol (`__iter__`/`__next__`) to yield
/// `Future` objects that resolve to individual actor responses.
#[pyclass(
    name = "ValueStream",
    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
)]
pub struct PyValueStream {
    /// Port the responses arrive on. Shared behind an async mutex because
    /// every future produced by `__next__` receives from the same port.
    receiver: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
    /// Supervisor for monitoring actor health during streaming.
    supervision_monitor: Option<Arc<dyn Supervisable>>,
    /// Actor instance handed to `collect_value` for each response.
    instance: Instance<PythonActor>,
    /// Number of futures `__next__` may still hand out; iteration stops at 0.
    remaining: AtomicUsize,
    /// Method name used as the `method` telemetry attribute.
    method_name: String,
    /// Endpoint name attached to supervision errors, if known.
    qualified_endpoint_name: Option<String>,
    /// Instant the stream was created; the latency of every yielded future
    /// is measured from here.
    start: tokio::time::Instant,
    /// Number of ranks in the endpoint's extent when the stream was created;
    /// reported as the `actor_count` telemetry attribute.
    actor_count: usize,
    /// The Python `Future` class, resolved once when the stream is created
    /// and called with `coro=` for every yielded item.
    future_class: Py<PyAny>,
}
375
376#[pymethods]
377impl PyValueStream {
378    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
379        slf
380    }
381
382    fn __next__(&self, py: Python<'_>) -> PyResult<Option<Py<PyAny>>> {
383        let remaining = self.remaining.load(Ordering::Relaxed);
384        if remaining == 0 {
385            return Ok(None);
386        }
387        self.remaining.store(remaining - 1, Ordering::Relaxed);
388
389        let receiver = self.receiver.clone();
390        let supervision_monitor = self.supervision_monitor.clone();
391        let instance = self.instance.clone_for_py();
392        let qualified_endpoint_name = self.qualified_endpoint_name.clone();
393        let start = self.start;
394        let method_name = self.method_name.clone();
395        let actor_count = self.actor_count;
396
397        let task: PyPythonTask = PythonTask::new(async move {
398            let record_guard =
399                RecordEndpointGuard::new(start, method_name, actor_count, EndpointAdverb::Stream);
400
401            let mut rx_guard = receiver.lock().await;
402
403            match collect_value(
404                &mut rx_guard,
405                &supervision_monitor,
406                &instance,
407                &qualified_endpoint_name,
408            )
409            .await
410            {
411                Ok((message, _)) => {
412                    Python::attach(|py| unpickle_from_part(py, message).map(|obj| obj.unbind()))
413                }
414                Err(e) => {
415                    record_guard.mark_error();
416                    Err(e)
417                }
418            }
419        })?
420        .into();
421
422        let kwargs = PyDict::new(py);
423        kwargs.set_item("coro", task)?;
424        let future = self.future_class.call(py, (), Some(&kwargs))?;
425        Ok(Some(future))
426    }
427}
428
429fn wrap_in_future(py: Python<'_>, task: PyPythonTask) -> PyResult<Py<PyAny>> {
430    let kwargs = PyDict::new(py);
431    kwargs.set_item("coro", task)?;
432    let future = make_future(py).call((), Some(&kwargs))?;
433    Ok(future.unbind())
434}
435
436/// Trait that defines the core operations an endpoint must provide.
437/// Both ActorEndpoint and RemoteEndpoint implement this trait.
438pub(crate) trait Endpoint {
439    /// Get the extent of the endpoint's targets.
440    fn get_extent(&self, py: Python<'_>) -> PyResult<Extent>;
441
442    /// Get the method name for this endpoint.
443    fn get_method_name(&self) -> &str;
444
445    /// Create and send a message with the given args/kwargs.
446    fn send_message<'py>(
447        &self,
448        py: Python<'py>,
449        args: &Bound<'py, PyTuple>,
450        kwargs: Option<&Bound<'py, PyDict>>,
451        port_ref: Option<&PythonPortRef>,
452        selection: Selection,
453        instance: &Instance<PythonActor>,
454    ) -> PyResult<()>;
455
456    /// Get the supervision_monitor for this endpoint (if any).
457    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>>;
458
459    /// Get the qualified endpoint name for error messages (if any).
460    fn get_qualified_name(&self) -> Option<String>;
461
462    fn get_current_instance(&self, py: Python<'_>) -> PyResult<Instance<PythonActor>> {
463        let context = get_context(py).call0()?;
464        let py_instance: PyRef<PyInstance> = context.getattr("actor_instance")?.extract()?;
465        Ok(py_instance.clone().into_instance())
466    }
467
468    fn open_response_port(
469        &self,
470        instance: &Instance<PythonActor>,
471    ) -> (PythonPortRef, PortReceiver<PythonMessage>) {
472        let (p, receiver) = instance.mailbox_for_py().open_port::<PythonMessage>();
473        (PythonPortRef { inner: p.bind() }, receiver)
474    }
475
476    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
477    fn call<'py>(
478        &self,
479        py: Python<'py>,
480        args: &Bound<'py, PyTuple>,
481        kwargs: Option<&Bound<'py, PyDict>>,
482    ) -> PyResult<Py<PyAny>> {
483        let extent = self.get_extent(py)?;
484        let method_name = self.get_method_name().to_string();
485
486        let instance = self.get_current_instance(py)?;
487        let (port_ref, receiver) = self.open_response_port(&instance);
488
489        let supervision_monitor = self.get_supervision_monitor();
490        let qualified_endpoint_name = self.get_qualified_name();
491
492        self.send_message(py, args, kwargs, Some(&port_ref), sel!(*), &instance)?;
493
494        let instance_for_task = instance.clone_for_py();
495        let task: PyPythonTask = PythonTask::new(async move {
496            collect_valuemesh(
497                extent,
498                receiver,
499                method_name,
500                supervision_monitor,
501                &instance_for_task,
502                qualified_endpoint_name,
503            )
504            .await
505        })?
506        .into();
507
508        wrap_in_future(py, task)
509    }
510
511    /// Load balanced sends a message to one chosen actor and awaits a result.
512    fn choose<'py>(
513        &self,
514        py: Python<'py>,
515        args: &Bound<'py, PyTuple>,
516        kwargs: Option<&Bound<'py, PyDict>>,
517    ) -> PyResult<Py<PyAny>> {
518        let method_name = self.get_method_name();
519
520        let instance = self.get_current_instance(py)?;
521        let (port_ref, receiver) = self.open_response_port(&instance);
522
523        self.send_message(py, args, kwargs, Some(&port_ref), sel!(?), &instance)?;
524
525        let task = value_collector(
526            receiver,
527            method_name.to_string(),
528            self.get_supervision_monitor(),
529            instance.clone_for_py(),
530            self.get_qualified_name(),
531            EndpointAdverb::Choose,
532        )?;
533
534        wrap_in_future(py, task)
535    }
536
537    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
538    fn call_one<'py>(
539        &self,
540        py: Python<'py>,
541        args: &Bound<'py, PyTuple>,
542        kwargs: Option<&Bound<'py, PyDict>>,
543    ) -> PyResult<Py<PyAny>> {
544        let extent = self.get_extent(py)?;
545        let method_name = self.get_method_name();
546
547        if extent.num_ranks() != 1 {
548            return Err(pyo3::exceptions::PyValueError::new_err(format!(
549                "call_one requires exactly 1 actor, but mesh has {}",
550                extent.num_ranks()
551            )));
552        }
553
554        let instance = self.get_current_instance(py)?;
555        let (port_ref, receiver) = self.open_response_port(&instance);
556
557        self.send_message(py, args, kwargs, Some(&port_ref), sel!(*), &instance)?;
558
559        let task = value_collector(
560            receiver,
561            method_name.to_string(),
562            self.get_supervision_monitor(),
563            instance.clone_for_py(),
564            self.get_qualified_name(),
565            EndpointAdverb::CallOne,
566        )?;
567
568        wrap_in_future(py, task)
569    }
570
571    /// Call the endpoint on all actors and return an iterator of Futures.
572    fn stream<'py>(
573        &self,
574        py: Python<'py>,
575        args: &Bound<'py, PyTuple>,
576        kwargs: Option<&Bound<'py, PyDict>>,
577    ) -> PyResult<Py<PyAny>> {
578        let extent = self.get_extent(py)?;
579        let method_name = self.get_method_name().to_string();
580
581        let instance = self.get_current_instance(py)?;
582        let (port_ref, receiver) = self.open_response_port(&instance);
583
584        self.send_message(py, args, kwargs, Some(&port_ref), sel!(*), &instance)?;
585
586        let actor_count = extent.num_ranks();
587        let start = tokio::time::Instant::now();
588        let supervision_monitor = self.get_supervision_monitor();
589        let qualified_endpoint_name = self.get_qualified_name();
590        let future_class = make_future(py).unbind();
591
592        let attributes = hyperactor_telemetry::kv_pairs!(
593            "method" => method_name.clone()
594        );
595        ENDPOINT_STREAM_THROUGHPUT.add(1, attributes);
596
597        let stream = PyValueStream {
598            receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
599            supervision_monitor,
600            instance: instance.clone_for_py(),
601            remaining: AtomicUsize::new(actor_count),
602            method_name,
603            qualified_endpoint_name,
604            start,
605            actor_count,
606            future_class,
607        };
608
609        Ok(stream.into_pyobject(py)?.unbind().into())
610    }
611
612    /// Send a message to all actors without waiting for responses (fire-and-forget).
613    fn broadcast<'py>(
614        &self,
615        py: Python<'py>,
616        args: &Bound<'py, PyTuple>,
617        kwargs: Option<&Bound<'py, PyDict>>,
618    ) -> PyResult<()> {
619        let instance = self.get_current_instance(py)?;
620        let method_name = self.get_method_name();
621        let attributes = hyperactor_telemetry::kv_pairs!(
622            "method" => method_name.to_string()
623        );
624
625        match self.send_message(py, args, kwargs, None, sel!(*), &instance) {
626            Ok(()) => {
627                ENDPOINT_BROADCAST_THROUGHPUT.add(1, attributes);
628                Ok(())
629            }
630            Err(e) => {
631                ENDPOINT_BROADCAST_ERROR.add(1, attributes);
632                Err(e)
633            }
634        }
635    }
636}
637
/// An endpoint (one method) of an actor mesh, exposed to Python as
/// `ActorEndpoint`.
#[pyclass(
    name = "ActorEndpoint",
    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
)]
pub struct ActorEndpoint {
    /// The actor mesh messages are cast to; also returned as the
    /// endpoint's supervision monitor.
    inner: Arc<dyn SupervisableActorMesh>,
    /// Shape whose extent is reported by `get_extent`.
    shape: Shape,
    /// The method this endpoint invokes.
    method: MethodSpecifier,
    /// Mesh name; combined with the method name to form the qualified
    /// endpoint name (`"<mesh_name>.<method>()"`).
    mesh_name: String,
    /// Optional Python signature object, forwarded to
    /// `_create_endpoint_message` (`None` is passed when absent).
    signature: Option<Py<PyAny>>,
    /// Optional Python proc mesh object, forwarded to
    /// `_create_endpoint_message` (`None` is passed when absent).
    proc_mesh: Option<Py<PyAny>>,
    /// Optional propagator forwarded to the Python `_do_propagate` helper.
    propagator: Option<Py<PyAny>>,
}
651
652impl ActorEndpoint {
653    fn create_message<'py>(
654        &self,
655        py: Python<'py>,
656        args: &Bound<'py, PyTuple>,
657        kwargs: Option<&Bound<'py, PyDict>>,
658        port_ref: Option<&PythonPortRef>,
659    ) -> PyResult<PendingMessage> {
660        let port_ref_py: Py<PyAny> = match port_ref {
661            Some(pr) => pr.clone().into_pyobject(py)?.unbind().into(),
662            None => py.None(),
663        };
664
665        let result = create_endpoint_message(py).call1((
666            self.method.clone(),
667            self.signature
668                .as_ref()
669                .map_or_else(|| py.None(), |s| s.clone_ref(py)),
670            args,
671            kwargs
672                .map_or_else(|| PyDict::new(py), |d| d.clone())
673                .into_any(),
674            port_ref_py,
675            self.proc_mesh
676                .as_ref()
677                .map_or_else(|| py.None(), |p| p.clone_ref(py)),
678        ))?;
679        let mut pending: PyRefMut<'_, PendingMessage> = result.extract()?;
680        pending.take()
681    }
682}
683
684impl Endpoint for ActorEndpoint {
685    fn get_extent(&self, _py: Python<'_>) -> PyResult<Extent> {
686        Ok(self.shape.extent())
687    }
688
689    fn get_method_name(&self) -> &str {
690        self.method.name()
691    }
692
693    fn send_message<'py>(
694        &self,
695        py: Python<'py>,
696        args: &Bound<'py, PyTuple>,
697        kwargs: Option<&Bound<'py, PyDict>>,
698        port_ref: Option<&PythonPortRef>,
699        selection: Selection,
700        instance: &Instance<PythonActor>,
701    ) -> PyResult<()> {
702        let message = self.create_message(py, args, kwargs, port_ref)?;
703        self.inner.cast_unresolved(message, selection, instance)
704    }
705
706    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>> {
707        Some(self.inner.clone())
708    }
709
710    fn get_qualified_name(&self) -> Option<String> {
711        Some(format!("{}.{}()", self.mesh_name, self.method.name()))
712    }
713}
714
715#[pymethods]
716impl ActorEndpoint {
717    /// Create a new ActorEndpoint.
718    #[new]
719    #[pyo3(signature = (actor_mesh, method, shape, mesh_name, signature=None, proc_mesh=None, propagator=None))]
720    fn new(
721        actor_mesh: PythonActorMesh,
722        method: MethodSpecifier,
723        shape: PyShape,
724        mesh_name: String,
725        signature: Option<Py<PyAny>>,
726        proc_mesh: Option<Py<PyAny>>,
727        propagator: Option<Py<PyAny>>,
728    ) -> Self {
729        Self {
730            inner: actor_mesh.get_inner(),
731            shape: shape.get_inner().clone(),
732            method,
733            mesh_name,
734            signature,
735            proc_mesh,
736            propagator,
737        }
738    }
739
740    /// Get the method specifier (used by actor_rref for tensor dispatch).
741    #[getter]
742    fn _name(&self) -> MethodSpecifier {
743        self.method.clone()
744    }
745
746    /// Get the signature (used for argument checking in _dispatch_actor_rref).
747    #[getter]
748    fn _signature(&self, py: Python<'_>) -> Py<PyAny> {
749        self.signature
750            .clone()
751            .unwrap_or_else(|| py.None().into_any())
752    }
753
754    /// Get the actor mesh (used by actor_rref for sending messages).
755    #[getter]
756    fn _actor_mesh(&self) -> PythonActorMesh {
757        PythonActorMesh::from_impl(self.inner.clone())
758    }
759
760    /// Propagation method for tensor shape inference.
761    /// Delegates to Python _do_propagate helper.
762    fn _propagate<'py>(
763        &self,
764        py: Python<'py>,
765        args: &Bound<'py, PyAny>,
766        kwargs: &Bound<'py, PyAny>,
767        fake_args: &Bound<'py, PyAny>,
768        fake_kwargs: &Bound<'py, PyAny>,
769    ) -> PyResult<Py<PyAny>> {
770        let do_propagate = py
771            .import("monarch._src.actor.endpoint")?
772            .getattr("_do_propagate")?;
773        let propagator = self
774            .propagator
775            .as_ref()
776            .map(|p| p.clone_ref(py).into_bound(py))
777            .unwrap_or_else(|| py.None().into_bound(py));
778        let cache = PyDict::new(py);
779        do_propagate
780            .call1((&propagator, args, kwargs, fake_args, fake_kwargs, cache))?
781            .extract()
782    }
783
784    /// Propagation for fetch operations.
785    /// Returns None if no propagator is provided, otherwise calls _propagate.
786    fn _fetch_propagate<'py>(
787        &self,
788        py: Python<'py>,
789        args: &Bound<'py, PyAny>,
790        kwargs: &Bound<'py, PyAny>,
791        fake_args: &Bound<'py, PyAny>,
792        fake_kwargs: &Bound<'py, PyAny>,
793    ) -> PyResult<Py<PyAny>> {
794        if self.propagator.is_none() {
795            return Ok(py.None());
796        }
797        self._propagate(py, args, kwargs, fake_args, fake_kwargs)
798    }
799
800    /// Propagation for pipe operations.
801    /// Requires an explicit callable propagator.
802    fn _pipe_propagate<'py>(
803        &self,
804        py: Python<'py>,
805        args: &Bound<'py, PyAny>,
806        kwargs: &Bound<'py, PyAny>,
807        fake_args: &Bound<'py, PyAny>,
808        fake_kwargs: &Bound<'py, PyAny>,
809    ) -> PyResult<Py<PyAny>> {
810        // Check if propagator is callable
811        let is_callable = self
812            .propagator
813            .as_ref()
814            .map(|p| p.bind(py).is_callable())
815            .unwrap_or(false);
816        if !is_callable {
817            return Err(pyo3::exceptions::PyValueError::new_err(
818                "Must specify explicit callable for pipe",
819            ));
820        }
821        self._propagate(py, args, kwargs, fake_args, fake_kwargs)
822    }
823
824    /// Get the rref result by calling the Python dispatch helper.
825    #[pyo3(signature = (*args, **kwargs))]
826    fn rref<'py>(
827        slf: PyRef<'py, Self>,
828        py: Python<'py>,
829        args: &Bound<'py, PyTuple>,
830        kwargs: Option<&Bound<'py, PyDict>>,
831    ) -> PyResult<Py<PyAny>> {
832        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
833
834        // Call _dispatch_actor_rref(endpoint, args, kwargs)
835        let result = dispatch_actor_rref(py).call1((slf.into_pyobject(py)?, args, kwargs_dict))?;
836
837        Ok(result.unbind())
838    }
839
840    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
841    #[pyo3(signature = (*args, **kwargs), name = "call")]
842    fn py_call<'py>(
843        &self,
844        py: Python<'py>,
845        args: &Bound<'py, PyTuple>,
846        kwargs: Option<&Bound<'py, PyDict>>,
847    ) -> PyResult<Py<PyAny>> {
848        self.call(py, args, kwargs)
849    }
850
851    /// Load balanced sends a message to one chosen actor and awaits a result.
852    #[pyo3(signature = (*args, **kwargs), name = "choose")]
853    fn py_choose<'py>(
854        &self,
855        py: Python<'py>,
856        args: &Bound<'py, PyTuple>,
857        kwargs: Option<&Bound<'py, PyDict>>,
858    ) -> PyResult<Py<PyAny>> {
859        self.choose(py, args, kwargs)
860    }
861
862    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
863    #[pyo3(signature = (*args, **kwargs), name = "call_one")]
864    fn py_call_one<'py>(
865        &self,
866        py: Python<'py>,
867        args: &Bound<'py, PyTuple>,
868        kwargs: Option<&Bound<'py, PyDict>>,
869    ) -> PyResult<Py<PyAny>> {
870        self.call_one(py, args, kwargs)
871    }
872
873    /// Call the endpoint on all actors and return an iterator of Futures.
874    #[pyo3(signature = (*args, **kwargs), name = "stream")]
875    fn py_stream<'py>(
876        &self,
877        py: Python<'py>,
878        args: &Bound<'py, PyTuple>,
879        kwargs: Option<&Bound<'py, PyDict>>,
880    ) -> PyResult<Py<PyAny>> {
881        self.stream(py, args, kwargs)
882    }
883
884    /// Send a message to all actors without waiting for responses (fire-and-forget).
885    #[pyo3(signature = (*args, **kwargs), name = "broadcast")]
886    fn py_broadcast<'py>(
887        &self,
888        py: Python<'py>,
889        args: &Bound<'py, PyTuple>,
890        kwargs: Option<&Bound<'py, PyDict>>,
891    ) -> PyResult<()> {
892        self.broadcast(py, args, kwargs)
893    }
894
895    /// Send a message with optional port for response (used by actor_mesh.send).
896    fn _send<'py>(
897        &self,
898        py: Python<'py>,
899        args: &Bound<'py, PyTuple>,
900        kwargs: &Bound<'py, PyDict>,
901        port: Option<&PythonPortRef>,
902        selection: &str,
903    ) -> PyResult<()> {
904        let instance = self.get_current_instance(py)?;
905        let sel = to_hy_sel(selection)?;
906        self.send_message(py, args, Some(kwargs), port, sel, &instance)
907    }
908}
909
/// A Rust wrapper for Python's RemoteImpl endpoint.
///
/// This allows us to implement the adverb methods (call, choose, call_one, stream, broadcast)
/// in Rust while delegating the actual send logic to the Python RemoteImpl._send() method.
#[pyclass(
    name = "Remote",
    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
)]
pub struct Remote {
    /// The wrapped Python RemoteImpl object; its `_get_extent` and `_send`
    /// methods are invoked by the `Endpoint` implementation.
    inner: Py<PyAny>,
}
922
923impl Endpoint for Remote {
924    fn get_extent(&self, py: Python<'_>) -> PyResult<Extent> {
925        let extent: PyExtent = self.inner.call_method0(py, "_get_extent")?.extract(py)?;
926        Ok(extent.into())
927    }
928
929    fn get_method_name(&self) -> &str {
930        "unknown"
931    }
932
933    fn send_message<'py>(
934        &self,
935        py: Python<'py>,
936        args: &Bound<'py, PyTuple>,
937        kwargs: Option<&Bound<'py, PyDict>>,
938        port_ref: Option<&PythonPortRef>,
939        selection: Selection,
940        _instance: &Instance<PythonActor>,
941    ) -> PyResult<()> {
942        let send_kwargs = PyDict::new(py);
943        match port_ref {
944            Some(pr) => send_kwargs.set_item("port", pr.clone())?,
945            None => send_kwargs.set_item("port", py.None())?,
946        }
947
948        let selection_str = match selection {
949            Selection::All(inner) if matches!(*inner, Selection::True) => "all",
950            Selection::Any(inner) if matches!(*inner, Selection::True) => "choose",
951            _ => {
952                panic!("only sel!(*) and sel!(?) should be provided as selection for send_message")
953            }
954        };
955
956        send_kwargs.set_item("selection", selection_str)?;
957
958        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
959        self.inner
960            .call_method(py, "_send", (args.clone(), kwargs_dict), Some(&send_kwargs))?;
961
962        Ok(())
963    }
964
965    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>> {
966        None // Remote endpoints don't have supervision_monitors
967    }
968
969    fn get_qualified_name(&self) -> Option<String> {
970        None // Remote endpoints don't have qualified names
971    }
972}
973
974#[pymethods]
975impl Remote {
976    /// Create a new Remote wrapping a Python RemoteImpl object.
977    #[new]
978    fn new(remote: Py<PyAny>) -> Self {
979        Self { inner: remote }
980    }
981
982    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
983    #[pyo3(signature = (*args, **kwargs), name = "call")]
984    fn py_call<'py>(
985        &self,
986        py: Python<'py>,
987        args: &Bound<'py, PyTuple>,
988        kwargs: Option<&Bound<'py, PyDict>>,
989    ) -> PyResult<Py<PyAny>> {
990        self.call(py, args, kwargs)
991    }
992
993    /// Load balanced sends a message to one chosen actor and awaits a result.
994    #[pyo3(signature = (*args, **kwargs), name = "choose")]
995    fn py_choose<'py>(
996        &self,
997        py: Python<'py>,
998        args: &Bound<'py, PyTuple>,
999        kwargs: Option<&Bound<'py, PyDict>>,
1000    ) -> PyResult<Py<PyAny>> {
1001        self.choose(py, args, kwargs)
1002    }
1003
1004    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
1005    #[pyo3(signature = (*args, **kwargs), name = "call_one")]
1006    fn py_call_one<'py>(
1007        &self,
1008        py: Python<'py>,
1009        args: &Bound<'py, PyTuple>,
1010        kwargs: Option<&Bound<'py, PyDict>>,
1011    ) -> PyResult<Py<PyAny>> {
1012        self.call_one(py, args, kwargs)
1013    }
1014
1015    /// Call the endpoint on all actors and return an iterator of Futures.
1016    #[pyo3(signature = (*args, **kwargs), name = "stream")]
1017    fn py_stream<'py>(
1018        &self,
1019        py: Python<'py>,
1020        args: &Bound<'py, PyTuple>,
1021        kwargs: Option<&Bound<'py, PyDict>>,
1022    ) -> PyResult<Py<PyAny>> {
1023        self.stream(py, args, kwargs)
1024    }
1025
1026    /// Send a message to all actors without waiting for responses (fire-and-forget).
1027    #[pyo3(signature = (*args, **kwargs), name = "broadcast")]
1028    fn py_broadcast<'py>(
1029        &self,
1030        py: Python<'py>,
1031        args: &Bound<'py, PyTuple>,
1032        kwargs: Option<&Bound<'py, PyDict>>,
1033    ) -> PyResult<()> {
1034        self.broadcast(py, args, kwargs)
1035    }
1036
1037    /// Get the rref result by calling the wrapped Remote's rref method.
1038    #[pyo3(signature = (*args, **kwargs))]
1039    fn rref<'py>(
1040        &self,
1041        py: Python<'py>,
1042        args: &Bound<'py, PyTuple>,
1043        kwargs: Option<&Bound<'py, PyDict>>,
1044    ) -> PyResult<Py<PyAny>> {
1045        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
1046        self.inner.call_method(py, "rref", args, Some(&kwargs_dict))
1047    }
1048
1049    /// Get the call name by delegating to the wrapped Remote's _call_name.
1050    fn _call_name(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1051        self.inner.call_method0(py, "_call_name")
1052    }
1053
1054    /// Get the maybe_resolvable property from the wrapped RemoteImpl.
1055    #[getter]
1056    fn _maybe_resolvable(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1057        self.inner.getattr(py, "_maybe_resolvable")
1058    }
1059
1060    /// Get the resolvable property from the wrapped RemoteImpl.
1061    #[getter]
1062    fn _resolvable(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1063        self.inner.getattr(py, "_resolvable")
1064    }
1065
1066    /// Get the remote_impl from the wrapped RemoteImpl.
1067    /// This is needed for function_to_import_path() in function.py to work correctly.
1068    #[getter]
1069    fn _remote_impl(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1070        self.inner.getattr(py, "_remote_impl")
1071    }
1072
1073    /// Propagation method for tensor shape inference.
1074    /// Delegates to the wrapped Remote's _propagate.
1075    fn _propagate<'py>(
1076        &self,
1077        py: Python<'py>,
1078        args: &Bound<'py, PyAny>,
1079        kwargs: &Bound<'py, PyAny>,
1080        fake_args: &Bound<'py, PyAny>,
1081        fake_kwargs: &Bound<'py, PyAny>,
1082    ) -> PyResult<Py<PyAny>> {
1083        self.inner
1084            .call_method1(py, "_propagate", (args, kwargs, fake_args, fake_kwargs))
1085    }
1086
1087    /// Propagation for fetch operations.
1088    /// Delegates to the wrapped Remote's _fetch_propagate.
1089    fn _fetch_propagate<'py>(
1090        &self,
1091        py: Python<'py>,
1092        args: &Bound<'py, PyAny>,
1093        kwargs: &Bound<'py, PyAny>,
1094        fake_args: &Bound<'py, PyAny>,
1095        fake_kwargs: &Bound<'py, PyAny>,
1096    ) -> PyResult<Py<PyAny>> {
1097        self.inner.call_method1(
1098            py,
1099            "_fetch_propagate",
1100            (args, kwargs, fake_args, fake_kwargs),
1101        )
1102    }
1103
1104    /// Propagation for pipe operations.
1105    /// Delegates to the wrapped Remote's _pipe_propagate.
1106    fn _pipe_propagate<'py>(
1107        &self,
1108        py: Python<'py>,
1109        args: &Bound<'py, PyAny>,
1110        kwargs: &Bound<'py, PyAny>,
1111        fake_args: &Bound<'py, PyAny>,
1112        fake_kwargs: &Bound<'py, PyAny>,
1113    ) -> PyResult<Py<PyAny>> {
1114        self.inner.call_method1(
1115            py,
1116            "_pipe_propagate",
1117            (args, kwargs, fake_args, fake_kwargs),
1118        )
1119    }
1120
1121    /// Send a message with optional port for response.
1122    /// Delegates to the wrapped RemoteImpl's _send.
1123    fn _send<'py>(
1124        &self,
1125        py: Python<'py>,
1126        args: &Bound<'py, PyTuple>,
1127        kwargs: &Bound<'py, PyDict>,
1128        port: Option<Py<PyAny>>,
1129        selection: &str,
1130    ) -> PyResult<()> {
1131        self.inner.call_method(
1132            py,
1133            "_send",
1134            (args, kwargs),
1135            Some(&{
1136                let d = PyDict::new(py);
1137                d.set_item("port", port.unwrap_or_else(|| py.None()))?;
1138                d.set_item("selection", selection)?;
1139                d
1140            }),
1141        )?;
1142        Ok(())
1143    }
1144
1145    /// Make RemoteEndpoint callable - delegates to rref() like Remote.__call__.
1146    #[pyo3(signature = (*args, **kwargs))]
1147    fn __call__<'py>(
1148        &self,
1149        py: Python<'py>,
1150        args: &Bound<'py, PyTuple>,
1151        kwargs: Option<&Bound<'py, PyDict>>,
1152    ) -> PyResult<Py<PyAny>> {
1153        self.rref(py, args, kwargs)
1154    }
1155}
1156
1157pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
1158    module.add_class::<PyValueStream>()?;
1159    module.add_class::<ActorEndpoint>()?;
1160    module.add_class::<Remote>()?;
1161
1162    Ok(())
1163}