monarch_hyperactor/
endpoint.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::Cell;
10use std::sync::Arc;
11use std::sync::atomic::AtomicUsize;
12use std::sync::atomic::Ordering;
13
14use hyperactor::Instance;
15use hyperactor::accum::Accumulator;
16use hyperactor::accum::CommReducer;
17use hyperactor::accum::ReducerFactory;
18use hyperactor::accum::ReducerSpec;
19use hyperactor::mailbox::OncePortReceiver;
20use hyperactor::mailbox::PortReceiver;
21use hyperactor_mesh::sel;
22use hyperactor_mesh::value_mesh::ValueOverlay;
23use hyperactor_mesh::value_mesh::rle;
24use monarch_types::py_global;
25use ndslice::Extent;
26use ndslice::Selection;
27use ndslice::Shape;
28use pyo3::prelude::*;
29use pyo3::types::PyDict;
30use pyo3::types::PyTuple;
31use serde_multipart::Part;
32use typeuri::Named;
33
34use crate::actor::MethodSpecifier;
35use crate::actor::PythonActor;
36use crate::actor::PythonMessage;
37use crate::actor::PythonMessageKind;
38use crate::actor::PythonResponseMessage;
39use crate::actor_mesh::PythonActorMesh;
40use crate::actor_mesh::SupervisableActorMesh;
41use crate::actor_mesh::to_hy_sel;
42use crate::buffers::FrozenBuffer;
43use crate::context::PyInstance;
44use crate::mailbox::EitherPortRef;
45use crate::mailbox::PythonOncePortRef;
46use crate::mailbox::PythonPortRef;
47use crate::metrics::ENDPOINT_BROADCAST_ERROR;
48use crate::metrics::ENDPOINT_BROADCAST_THROUGHPUT;
49use crate::metrics::ENDPOINT_CALL_ERROR;
50use crate::metrics::ENDPOINT_CALL_LATENCY_US_HISTOGRAM;
51use crate::metrics::ENDPOINT_CALL_ONE_ERROR;
52use crate::metrics::ENDPOINT_CALL_ONE_LATENCY_US_HISTOGRAM;
53use crate::metrics::ENDPOINT_CALL_ONE_THROUGHPUT;
54use crate::metrics::ENDPOINT_CALL_THROUGHPUT;
55use crate::metrics::ENDPOINT_CHOOSE_ERROR;
56use crate::metrics::ENDPOINT_CHOOSE_LATENCY_US_HISTOGRAM;
57use crate::metrics::ENDPOINT_CHOOSE_THROUGHPUT;
58use crate::metrics::ENDPOINT_STREAM_ERROR;
59use crate::metrics::ENDPOINT_STREAM_LATENCY_US_HISTOGRAM;
60use crate::metrics::ENDPOINT_STREAM_THROUGHPUT;
61use crate::pickle::PendingMessage;
62use crate::pickle::unpickle;
63use crate::pytokio::PyPythonTask;
64use crate::pytokio::PythonTask;
65use crate::shape::PyExtent;
66use crate::shape::PyShape;
67use crate::supervision::Supervisable;
68use crate::supervision::SupervisionError;
69use crate::value_mesh::PyValueMesh;
70
71py_global!(get_context, "monarch._src.actor.actor_mesh", "context");
72py_global!(
73    create_endpoint_message,
74    "monarch._src.actor.actor_mesh",
75    "_create_endpoint_message"
76);
77py_global!(
78    dispatch_actor_rref,
79    "monarch._src.actor.actor_mesh",
80    "_dispatch_actor_rref"
81);
82py_global!(make_future, "monarch._src.actor.future", "Future");
83
84fn unpickle_from_part<'py>(py: Python<'py>, part: Part) -> PyResult<Bound<'py, PyAny>> {
85    unpickle(
86        py,
87        FrozenBuffer {
88            inner: part.into_bytes(),
89        },
90    )
91}
92
/// Which endpoint operation is being measured.
///
/// Selects the family of telemetry metrics (throughput, latency, error) that a
/// `RecordEndpointGuard` writes to.
#[derive(Debug, Clone, Copy)]
enum EndpointAdverb {
    /// `call`: every actor replies and the replies are gathered into a value mesh.
    Call,
    /// `call_one`: the mesh must contain exactly one actor.
    CallOne,
    /// `choose`: a single actor is selected to handle the message.
    Choose,
    /// `stream`: one future is produced per actor reply.
    Stream,
}
103
104/// RAII guard for recording endpoint call telemetry.
105///
106/// Records latency on drop, similar to Python's `@_with_telemetry` decorator.
107/// Call `mark_error()` before dropping to also record an error.
108pub struct RecordEndpointGuard {
109    start: tokio::time::Instant,
110    method_name: String,
111    actor_count: usize,
112    adverb: EndpointAdverb,
113    error_occurred: Cell<bool>,
114}
115
116impl RecordEndpointGuard {
117    fn new(
118        start: tokio::time::Instant,
119        method_name: String,
120        actor_count: usize,
121        adverb: EndpointAdverb,
122    ) -> Self {
123        let attributes = hyperactor_telemetry::kv_pairs!(
124            "method" => method_name.clone()
125        );
126        match adverb {
127            EndpointAdverb::Call => {
128                ENDPOINT_CALL_THROUGHPUT.add(1, attributes);
129            }
130            EndpointAdverb::CallOne => {
131                ENDPOINT_CALL_ONE_THROUGHPUT.add(1, attributes);
132            }
133            EndpointAdverb::Choose => {
134                ENDPOINT_CHOOSE_THROUGHPUT.add(1, attributes);
135            }
136            EndpointAdverb::Stream => {
137                // Throughput already recorded once at stream creation in py_stream_collector
138            }
139        }
140
141        Self {
142            start,
143            method_name,
144            actor_count,
145            adverb,
146            error_occurred: Cell::new(false),
147        }
148    }
149
150    fn mark_error(&self) {
151        self.error_occurred.set(true);
152    }
153}
154
155impl Drop for RecordEndpointGuard {
156    fn drop(&mut self) {
157        let actor_count_str = self.actor_count.to_string();
158        let attributes = hyperactor_telemetry::kv_pairs!(
159            "method" => self.method_name.clone(),
160            "actor_count" => actor_count_str
161        );
162        tracing::info!(message = "response received", method = self.method_name);
163
164        let duration_us = self.start.elapsed().as_micros();
165
166        match self.adverb {
167            EndpointAdverb::Call => {
168                ENDPOINT_CALL_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
169            }
170            EndpointAdverb::CallOne => {
171                ENDPOINT_CALL_ONE_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
172            }
173            EndpointAdverb::Choose => {
174                ENDPOINT_CHOOSE_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
175            }
176            EndpointAdverb::Stream => {
177                ENDPOINT_STREAM_LATENCY_US_HISTOGRAM.record(duration_us as f64, attributes);
178            }
179        }
180
181        if self.error_occurred.get() {
182            match self.adverb {
183                EndpointAdverb::Call => {
184                    ENDPOINT_CALL_ERROR.add(1, attributes);
185                }
186                EndpointAdverb::CallOne => {
187                    ENDPOINT_CALL_ONE_ERROR.add(1, attributes);
188                }
189                EndpointAdverb::Choose => {
190                    ENDPOINT_CHOOSE_ERROR.add(1, attributes);
191                }
192                EndpointAdverb::Stream => {
193                    ENDPOINT_STREAM_ERROR.add(1, attributes);
194                }
195            }
196        }
197    }
198}
199
200fn supervision_error_to_pyerr(err: PyErr, qualified_endpoint_name: &Option<String>) -> PyErr {
201    match qualified_endpoint_name {
202        Some(endpoint) => {
203            Python::attach(|py| SupervisionError::set_endpoint_on_err(py, err, endpoint.clone()))
204        }
205        None => err,
206    }
207}
208
209async fn collect_value(
210    rx: &mut PortReceiver<PythonMessage>,
211    supervision_monitor: &Option<Arc<dyn Supervisable>>,
212    instance: &Instance<PythonActor>,
213    qualified_endpoint_name: &Option<String>,
214) -> PyResult<(Part, Option<usize>)> {
215    enum RaceResult {
216        Message(PythonMessage),
217        SupervisionError(PyErr),
218        RecvError(String),
219    }
220
221    let race_result = match supervision_monitor {
222        Some(sup) => {
223            tokio::select! {
224                biased;
225                result = sup.supervision_event(instance) => {
226                    match result {
227                        Some(err) => RaceResult::SupervisionError(err),
228                        None => {
229                            match rx.recv().await {
230                                Ok(msg) => RaceResult::Message(msg),
231                                Err(e) => RaceResult::RecvError(e.to_string()),
232                            }
233                        }
234                    }
235                }
236                msg = rx.recv() => {
237                    match msg {
238                        Ok(m) => RaceResult::Message(m),
239                        Err(e) => RaceResult::RecvError(e.to_string()),
240                    }
241                }
242            }
243        }
244        _ => match rx.recv().await {
245            Ok(msg) => RaceResult::Message(msg),
246            Err(e) => RaceResult::RecvError(e.to_string()),
247        },
248    };
249
250    match race_result {
251        RaceResult::Message(PythonMessage {
252            kind: PythonMessageKind::Result { rank, .. },
253            message,
254            ..
255        }) => Ok((message, rank)),
256        RaceResult::Message(PythonMessage {
257            kind: PythonMessageKind::Exception { .. },
258            message,
259            ..
260        }) => Python::attach(|py| Err(PyErr::from_value(unpickle_from_part(py, message)?))),
261        RaceResult::Message(msg) => Err(pyo3::exceptions::PyValueError::new_err(format!(
262            "unexpected message kind {:?}",
263            msg.kind
264        ))),
265        RaceResult::RecvError(e) => Err(pyo3::exceptions::PyEOFError::new_err(format!(
266            "Port closed: {}",
267            e
268        ))),
269        RaceResult::SupervisionError(err) => {
270            Err(supervision_error_to_pyerr(err, qualified_endpoint_name))
271        }
272    }
273}
274
275async fn collect_valuemesh(
276    extent: Extent,
277    rx: OncePortReceiver<PythonMessage>,
278    method_name: String,
279    supervision_monitor: Option<Arc<dyn Supervisable>>,
280    instance: &Instance<PythonActor>,
281    qualified_endpoint_name: Option<String>,
282) -> PyResult<Py<PyAny>> {
283    let start = tokio::time::Instant::now();
284
285    let expected_count = extent.num_ranks();
286
287    let record_guard = RecordEndpointGuard::new(
288        start,
289        method_name.clone(),
290        expected_count,
291        EndpointAdverb::Call,
292    );
293
294    enum RaceResult {
295        Collected(PythonMessage),
296        SupervisionError(PyErr),
297        RecvError(String),
298    }
299
300    let race_result = match &supervision_monitor {
301        Some(sup) => {
302            tokio::select! {
303                biased;
304                result = sup.supervision_event(instance) => {
305                    match result {
306                        Some(err) => RaceResult::SupervisionError(err),
307                        None => RaceResult::RecvError(
308                            "supervision monitor closed unexpectedly".to_string()
309                        ),
310                    }
311                }
312                batch = rx.recv() => {
313                    match batch {
314                        Ok(b) => RaceResult::Collected(b),
315                        Err(e) => RaceResult::RecvError(e.to_string()),
316                    }
317                }
318            }
319        }
320        None => match rx.recv().await {
321            Ok(batch) => RaceResult::Collected(batch),
322            Err(e) => RaceResult::RecvError(e.to_string()),
323        },
324    };
325
326    match race_result {
327        RaceResult::Collected(msg) => {
328            let overlay = msg.into_overlay().map_err(|e| {
329                pyo3::exceptions::PyRuntimeError::new_err(format!(
330                    "failed to extract overlay from collected responses: {e}"
331                ))
332            })?;
333            Python::attach(|py| {
334                Ok(PyValueMesh::build_from_parts(
335                    &extent,
336                    overlay.runs().try_fold(
337                        Vec::with_capacity(expected_count),
338                        |mut parts, (range, payload)| match payload {
339                            PythonResponseMessage::Result(part) => {
340                                parts.extend(range.clone().map(|_| part.clone()));
341                                Ok(parts)
342                            }
343                            PythonResponseMessage::Exception(part) => {
344                                record_guard.mark_error();
345                                Python::attach(|py| {
346                                    Err(PyErr::from_value(unpickle_from_part(py, part.clone())?))
347                                })
348                            }
349                        },
350                    )?,
351                )?
352                .into_pyobject(py)?
353                .into_any()
354                .unbind())
355            })
356        }
357        RaceResult::RecvError(e) => {
358            record_guard.mark_error();
359            Err(pyo3::exceptions::PyEOFError::new_err(format!(
360                "Port closed: {}",
361                e
362            )))
363        }
364        RaceResult::SupervisionError(err) => {
365            record_guard.mark_error();
366            Err(supervision_error_to_pyerr(err, &qualified_endpoint_name))
367        }
368    }
369}
370
371fn value_collector(
372    mut receiver: PortReceiver<PythonMessage>,
373    method_name: String,
374    supervision_monitor: Option<Arc<dyn Supervisable>>,
375    instance: Instance<PythonActor>,
376    qualified_endpoint_name: Option<String>,
377    adverb: EndpointAdverb,
378) -> PyResult<PyPythonTask> {
379    Ok(PythonTask::new(async move {
380        let start = tokio::time::Instant::now();
381
382        let record_guard = RecordEndpointGuard::new(start, method_name, 1, adverb);
383
384        match collect_value(
385            &mut receiver,
386            &supervision_monitor,
387            &instance,
388            &qualified_endpoint_name,
389        )
390        .await
391        {
392            Ok((message, _)) => {
393                Python::attach(|py| unpickle_from_part(py, message).map(|obj| obj.unbind()))
394            }
395            Err(e) => {
396                record_guard.mark_error();
397                Err(e)
398            }
399        }
400    })?
401    .into())
402}
403
404/// A streaming iterator that yields futures for each response from actors.
405///
406/// Implements Python's iterator protocol (`__iter__`/`__next__`) to yield
407/// `Future` objects that resolve to individual actor responses.
408#[pyclass(
409    name = "ValueStream",
410    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
411)]
412pub struct PyValueStream {
413    receiver: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
414    /// Supervisor for monitoring actor health during streaming.
415    supervision_monitor: Option<Arc<dyn Supervisable>>,
416    instance: Instance<PythonActor>,
417    remaining: AtomicUsize,
418    method_name: String,
419    qualified_endpoint_name: Option<String>,
420    start: tokio::time::Instant,
421    actor_count: usize,
422    future_class: Py<PyAny>,
423}
424
425#[pymethods]
426impl PyValueStream {
427    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
428        slf
429    }
430
431    fn __next__(&self, py: Python<'_>) -> PyResult<Option<Py<PyAny>>> {
432        let remaining = self.remaining.load(Ordering::Relaxed);
433        if remaining == 0 {
434            return Ok(None);
435        }
436        self.remaining.store(remaining - 1, Ordering::Relaxed);
437
438        let receiver = self.receiver.clone();
439        let supervision_monitor = self.supervision_monitor.clone();
440        let instance = self.instance.clone_for_py();
441        let qualified_endpoint_name = self.qualified_endpoint_name.clone();
442        let start = self.start;
443        let method_name = self.method_name.clone();
444        let actor_count = self.actor_count;
445
446        let task: PyPythonTask = PythonTask::new(async move {
447            let record_guard =
448                RecordEndpointGuard::new(start, method_name, actor_count, EndpointAdverb::Stream);
449
450            let mut rx_guard = receiver.lock().await;
451
452            match collect_value(
453                &mut rx_guard,
454                &supervision_monitor,
455                &instance,
456                &qualified_endpoint_name,
457            )
458            .await
459            {
460                Ok((message, _)) => {
461                    Python::attach(|py| unpickle_from_part(py, message).map(|obj| obj.unbind()))
462                }
463                Err(e) => {
464                    record_guard.mark_error();
465                    Err(e)
466                }
467            }
468        })?
469        .into();
470
471        let kwargs = PyDict::new(py);
472        kwargs.set_item("coro", task)?;
473        let future = self.future_class.call(py, (), Some(&kwargs))?;
474        Ok(Some(future))
475    }
476}
477
478fn wrap_in_future(py: Python<'_>, task: PyPythonTask) -> PyResult<Py<PyAny>> {
479    let kwargs = PyDict::new(py);
480    kwargs.set_item("coro", task)?;
481    let future = make_future(py).call((), Some(&kwargs))?;
482    Ok(future.unbind())
483}
484
485/// Trait that defines the core operations an endpoint must provide.
486/// Both ActorEndpoint and RemoteEndpoint implement this trait.
487pub(crate) trait Endpoint {
488    /// Get the extent of the endpoint's targets.
489    fn get_extent(&self, py: Python<'_>) -> PyResult<Extent>;
490
491    /// Get the method name for this endpoint.
492    fn get_method_name(&self) -> &str;
493
494    /// Create and send a message with the given args/kwargs.
495    fn send_message<'py>(
496        &self,
497        py: Python<'py>,
498        args: &Bound<'py, PyTuple>,
499        kwargs: Option<&Bound<'py, PyDict>>,
500        port_ref: Option<EitherPortRef>,
501        selection: Selection,
502        instance: &Instance<PythonActor>,
503    ) -> PyResult<()>;
504
505    /// Get the supervision_monitor for this endpoint (if any).
506    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>>;
507
508    /// Get the qualified endpoint name for error messages (if any).
509    fn get_qualified_name(&self) -> Option<String>;
510
511    fn get_current_instance(&self, py: Python<'_>) -> PyResult<Instance<PythonActor>> {
512        let context = get_context(py).call0()?;
513        let py_instance: PyRef<PyInstance> = context.getattr("actor_instance")?.extract()?;
514        Ok(py_instance.clone().into_instance())
515    }
516
517    fn open_response_port(
518        &self,
519        instance: &Instance<PythonActor>,
520    ) -> (PythonPortRef, PortReceiver<PythonMessage>) {
521        let (p, receiver) = instance.mailbox_for_py().open_port::<PythonMessage>();
522        (PythonPortRef { inner: p.bind() }, receiver)
523    }
524
525    fn open_reduce_response_port(
526        &self,
527        instance: &Instance<PythonActor>,
528    ) -> (PythonOncePortRef, OncePortReceiver<PythonMessage>) {
529        let (p, receiver) = instance
530            .mailbox_for_py()
531            .open_reduce_port(PythonResponseMessageAccumulator);
532        (PythonOncePortRef::from(p.bind()), receiver)
533    }
534
535    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
536    fn call<'py>(
537        &self,
538        py: Python<'py>,
539        args: &Bound<'py, PyTuple>,
540        kwargs: Option<&Bound<'py, PyDict>>,
541    ) -> PyResult<Py<PyAny>> {
542        let extent = self.get_extent(py)?;
543        let method_name = self.get_method_name().to_string();
544
545        let instance = self.get_current_instance(py)?;
546        let (port_ref, receiver) = self.open_reduce_response_port(&instance);
547
548        let supervision_monitor = self.get_supervision_monitor();
549        let qualified_endpoint_name = self.get_qualified_name();
550
551        self.send_message(
552            py,
553            args,
554            kwargs,
555            Some(EitherPortRef::Once(port_ref)),
556            sel!(*),
557            &instance,
558        )?;
559
560        let instance_for_task = instance.clone_for_py();
561        let task: PyPythonTask = PythonTask::new(async move {
562            collect_valuemesh(
563                extent,
564                receiver,
565                method_name,
566                supervision_monitor,
567                &instance_for_task,
568                qualified_endpoint_name,
569            )
570            .await
571        })?
572        .into();
573
574        wrap_in_future(py, task)
575    }
576
577    /// Load balanced sends a message to one chosen actor and awaits a result.
578    fn choose<'py>(
579        &self,
580        py: Python<'py>,
581        args: &Bound<'py, PyTuple>,
582        kwargs: Option<&Bound<'py, PyDict>>,
583    ) -> PyResult<Py<PyAny>> {
584        let method_name = self.get_method_name();
585
586        let instance = self.get_current_instance(py)?;
587        let (port_ref, receiver) = self.open_response_port(&instance);
588
589        self.send_message(
590            py,
591            args,
592            kwargs,
593            Some(EitherPortRef::Unbounded(port_ref)),
594            sel!(?),
595            &instance,
596        )?;
597
598        let task = value_collector(
599            receiver,
600            method_name.to_string(),
601            self.get_supervision_monitor(),
602            instance.clone_for_py(),
603            self.get_qualified_name(),
604            EndpointAdverb::Choose,
605        )?;
606
607        wrap_in_future(py, task)
608    }
609
610    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
611    fn call_one<'py>(
612        &self,
613        py: Python<'py>,
614        args: &Bound<'py, PyTuple>,
615        kwargs: Option<&Bound<'py, PyDict>>,
616    ) -> PyResult<Py<PyAny>> {
617        let extent = self.get_extent(py)?;
618        let method_name = self.get_method_name();
619
620        if extent.num_ranks() != 1 {
621            return Err(pyo3::exceptions::PyValueError::new_err(format!(
622                "call_one requires exactly 1 actor, but mesh has {}",
623                extent.num_ranks()
624            )));
625        }
626
627        let instance = self.get_current_instance(py)?;
628        let (port_ref, receiver) = self.open_response_port(&instance);
629
630        self.send_message(
631            py,
632            args,
633            kwargs,
634            Some(EitherPortRef::Unbounded(port_ref)),
635            sel!(*),
636            &instance,
637        )?;
638
639        let task = value_collector(
640            receiver,
641            method_name.to_string(),
642            self.get_supervision_monitor(),
643            instance.clone_for_py(),
644            self.get_qualified_name(),
645            EndpointAdverb::CallOne,
646        )?;
647
648        wrap_in_future(py, task)
649    }
650
651    /// Call the endpoint on all actors and return an iterator of Futures.
652    fn stream<'py>(
653        &self,
654        py: Python<'py>,
655        args: &Bound<'py, PyTuple>,
656        kwargs: Option<&Bound<'py, PyDict>>,
657    ) -> PyResult<Py<PyAny>> {
658        let extent = self.get_extent(py)?;
659        let method_name = self.get_method_name().to_string();
660
661        let instance = self.get_current_instance(py)?;
662        let (port_ref, receiver) = self.open_response_port(&instance);
663
664        self.send_message(
665            py,
666            args,
667            kwargs,
668            Some(EitherPortRef::Unbounded(port_ref)),
669            sel!(*),
670            &instance,
671        )?;
672
673        let actor_count = extent.num_ranks();
674        let start = tokio::time::Instant::now();
675        let supervision_monitor = self.get_supervision_monitor();
676        let qualified_endpoint_name = self.get_qualified_name();
677        let future_class = make_future(py).unbind();
678
679        let attributes = hyperactor_telemetry::kv_pairs!(
680            "method" => method_name.clone()
681        );
682        ENDPOINT_STREAM_THROUGHPUT.add(1, attributes);
683
684        let stream = PyValueStream {
685            receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
686            supervision_monitor,
687            instance: instance.clone_for_py(),
688            remaining: AtomicUsize::new(actor_count),
689            method_name,
690            qualified_endpoint_name,
691            start,
692            actor_count,
693            future_class,
694        };
695
696        Ok(stream.into_pyobject(py)?.unbind().into())
697    }
698
699    /// Send a message to all actors without waiting for responses (fire-and-forget).
700    fn broadcast<'py>(
701        &self,
702        py: Python<'py>,
703        args: &Bound<'py, PyTuple>,
704        kwargs: Option<&Bound<'py, PyDict>>,
705    ) -> PyResult<()> {
706        let instance = self.get_current_instance(py)?;
707        let method_name = self.get_method_name();
708        let attributes = hyperactor_telemetry::kv_pairs!(
709            "method" => method_name.to_string()
710        );
711
712        match self.send_message(py, args, kwargs, None, sel!(*), &instance) {
713            Ok(()) => {
714                ENDPOINT_BROADCAST_THROUGHPUT.add(1, attributes);
715                Ok(())
716            }
717            Err(e) => {
718                ENDPOINT_BROADCAST_ERROR.add(1, attributes);
719                Err(e)
720            }
721        }
722    }
723}
724
725#[pyclass(
726    name = "ActorEndpoint",
727    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
728)]
729pub struct ActorEndpoint {
730    inner: Arc<dyn SupervisableActorMesh>,
731    shape: Shape,
732    method: MethodSpecifier,
733    mesh_name: String,
734    signature: Option<Py<PyAny>>,
735    proc_mesh: Option<Py<PyAny>>,
736    propagator: Option<Py<PyAny>>,
737}
738
739impl ActorEndpoint {
740    fn create_message<'py>(
741        &self,
742        py: Python<'py>,
743        args: &Bound<'py, PyTuple>,
744        kwargs: Option<&Bound<'py, PyDict>>,
745        port_ref: Option<EitherPortRef>,
746    ) -> PyResult<PendingMessage> {
747        let port_ref_py: Py<PyAny> = match port_ref {
748            Some(pr) => pr.clone().into_pyobject(py)?.unbind(),
749            None => py.None(),
750        };
751
752        let result = create_endpoint_message(py).call1((
753            self.method.clone(),
754            self.signature
755                .as_ref()
756                .map_or_else(|| py.None(), |s| s.clone_ref(py)),
757            args,
758            kwargs
759                .map_or_else(|| PyDict::new(py), |d| d.clone())
760                .into_any(),
761            port_ref_py,
762            self.proc_mesh
763                .as_ref()
764                .map_or_else(|| py.None(), |p| p.clone_ref(py)),
765        ))?;
766        let mut pending: PyRefMut<'_, PendingMessage> = result.extract()?;
767        pending.take()
768    }
769}
770
771impl Endpoint for ActorEndpoint {
772    fn get_extent(&self, _py: Python<'_>) -> PyResult<Extent> {
773        Ok(self.shape.extent())
774    }
775
776    fn get_method_name(&self) -> &str {
777        self.method.name()
778    }
779
780    fn send_message<'py>(
781        &self,
782        py: Python<'py>,
783        args: &Bound<'py, PyTuple>,
784        kwargs: Option<&Bound<'py, PyDict>>,
785        port_ref: Option<EitherPortRef>,
786        selection: Selection,
787        instance: &Instance<PythonActor>,
788    ) -> PyResult<()> {
789        let message = self.create_message(py, args, kwargs, port_ref)?;
790        self.inner.cast_unresolved(message, selection, instance)
791    }
792
793    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>> {
794        Some(self.inner.clone())
795    }
796
797    fn get_qualified_name(&self) -> Option<String> {
798        Some(format!("{}.{}()", self.mesh_name, self.method.name()))
799    }
800}
801
802#[pymethods]
803impl ActorEndpoint {
804    /// Create a new ActorEndpoint.
805    #[new]
806    #[pyo3(signature = (actor_mesh, method, shape, mesh_name, signature=None, proc_mesh=None, propagator=None))]
807    fn new(
808        actor_mesh: PythonActorMesh,
809        method: MethodSpecifier,
810        shape: PyShape,
811        mesh_name: String,
812        signature: Option<Py<PyAny>>,
813        proc_mesh: Option<Py<PyAny>>,
814        propagator: Option<Py<PyAny>>,
815    ) -> Self {
816        Self {
817            inner: actor_mesh.get_inner(),
818            shape: shape.get_inner().clone(),
819            method,
820            mesh_name,
821            signature,
822            proc_mesh,
823            propagator,
824        }
825    }
826
827    /// Get the method specifier (used by actor_rref for tensor dispatch).
828    #[getter]
829    fn _name(&self) -> MethodSpecifier {
830        self.method.clone()
831    }
832
833    /// Get the signature (used for argument checking in _dispatch_actor_rref).
834    #[getter]
835    fn _signature(&self, py: Python<'_>) -> Py<PyAny> {
836        self.signature
837            .clone()
838            .unwrap_or_else(|| py.None().into_any())
839    }
840
841    /// Get the actor mesh (used by actor_rref for sending messages).
842    #[getter]
843    fn _actor_mesh(&self) -> PythonActorMesh {
844        PythonActorMesh::from_impl(self.inner.clone())
845    }
846
847    /// Propagation method for tensor shape inference.
848    /// Delegates to Python _do_propagate helper.
849    fn _propagate<'py>(
850        &self,
851        py: Python<'py>,
852        args: &Bound<'py, PyAny>,
853        kwargs: &Bound<'py, PyAny>,
854        fake_args: &Bound<'py, PyAny>,
855        fake_kwargs: &Bound<'py, PyAny>,
856    ) -> PyResult<Py<PyAny>> {
857        let do_propagate = py
858            .import("monarch._src.actor.endpoint")?
859            .getattr("_do_propagate")?;
860        let propagator = self
861            .propagator
862            .as_ref()
863            .map(|p| p.clone_ref(py).into_bound(py))
864            .unwrap_or_else(|| py.None().into_bound(py));
865        let cache = PyDict::new(py);
866        do_propagate
867            .call1((&propagator, args, kwargs, fake_args, fake_kwargs, cache))?
868            .extract()
869    }
870
871    /// Propagation for fetch operations.
872    /// Returns None if no propagator is provided, otherwise calls _propagate.
873    fn _fetch_propagate<'py>(
874        &self,
875        py: Python<'py>,
876        args: &Bound<'py, PyAny>,
877        kwargs: &Bound<'py, PyAny>,
878        fake_args: &Bound<'py, PyAny>,
879        fake_kwargs: &Bound<'py, PyAny>,
880    ) -> PyResult<Py<PyAny>> {
881        if self.propagator.is_none() {
882            return Ok(py.None());
883        }
884        self._propagate(py, args, kwargs, fake_args, fake_kwargs)
885    }
886
887    /// Propagation for pipe operations.
888    /// Requires an explicit callable propagator.
889    fn _pipe_propagate<'py>(
890        &self,
891        py: Python<'py>,
892        args: &Bound<'py, PyAny>,
893        kwargs: &Bound<'py, PyAny>,
894        fake_args: &Bound<'py, PyAny>,
895        fake_kwargs: &Bound<'py, PyAny>,
896    ) -> PyResult<Py<PyAny>> {
897        // Check if propagator is callable
898        let is_callable = self
899            .propagator
900            .as_ref()
901            .map(|p| p.bind(py).is_callable())
902            .unwrap_or(false);
903        if !is_callable {
904            return Err(pyo3::exceptions::PyValueError::new_err(
905                "Must specify explicit callable for pipe",
906            ));
907        }
908        self._propagate(py, args, kwargs, fake_args, fake_kwargs)
909    }
910
911    /// Get the rref result by calling the Python dispatch helper.
912    #[pyo3(signature = (*args, **kwargs))]
913    fn rref<'py>(
914        slf: PyRef<'py, Self>,
915        py: Python<'py>,
916        args: &Bound<'py, PyTuple>,
917        kwargs: Option<&Bound<'py, PyDict>>,
918    ) -> PyResult<Py<PyAny>> {
919        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
920
921        // Call _dispatch_actor_rref(endpoint, args, kwargs)
922        let result = dispatch_actor_rref(py).call1((slf.into_pyobject(py)?, args, kwargs_dict))?;
923
924        Ok(result.unbind())
925    }
926
927    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
928    #[pyo3(signature = (*args, **kwargs), name = "call")]
929    fn py_call<'py>(
930        &self,
931        py: Python<'py>,
932        args: &Bound<'py, PyTuple>,
933        kwargs: Option<&Bound<'py, PyDict>>,
934    ) -> PyResult<Py<PyAny>> {
935        self.call(py, args, kwargs)
936    }
937
938    /// Load balanced sends a message to one chosen actor and awaits a result.
939    #[pyo3(signature = (*args, **kwargs), name = "choose")]
940    fn py_choose<'py>(
941        &self,
942        py: Python<'py>,
943        args: &Bound<'py, PyTuple>,
944        kwargs: Option<&Bound<'py, PyDict>>,
945    ) -> PyResult<Py<PyAny>> {
946        self.choose(py, args, kwargs)
947    }
948
949    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
950    #[pyo3(signature = (*args, **kwargs), name = "call_one")]
951    fn py_call_one<'py>(
952        &self,
953        py: Python<'py>,
954        args: &Bound<'py, PyTuple>,
955        kwargs: Option<&Bound<'py, PyDict>>,
956    ) -> PyResult<Py<PyAny>> {
957        self.call_one(py, args, kwargs)
958    }
959
960    /// Call the endpoint on all actors and return an iterator of Futures.
961    #[pyo3(signature = (*args, **kwargs), name = "stream")]
962    fn py_stream<'py>(
963        &self,
964        py: Python<'py>,
965        args: &Bound<'py, PyTuple>,
966        kwargs: Option<&Bound<'py, PyDict>>,
967    ) -> PyResult<Py<PyAny>> {
968        self.stream(py, args, kwargs)
969    }
970
971    /// Send a message to all actors without waiting for responses (fire-and-forget).
972    #[pyo3(signature = (*args, **kwargs), name = "broadcast")]
973    fn py_broadcast<'py>(
974        &self,
975        py: Python<'py>,
976        args: &Bound<'py, PyTuple>,
977        kwargs: Option<&Bound<'py, PyDict>>,
978    ) -> PyResult<()> {
979        self.broadcast(py, args, kwargs)
980    }
981
982    /// Send a message with optional port for response (used by actor_mesh.send).
983    fn _send<'py>(
984        &self,
985        py: Python<'py>,
986        args: &Bound<'py, PyTuple>,
987        kwargs: &Bound<'py, PyDict>,
988        port: Option<EitherPortRef>,
989        selection: &str,
990    ) -> PyResult<()> {
991        let instance = self.get_current_instance(py)?;
992        let sel = to_hy_sel(selection)?;
993        self.send_message(py, args, Some(kwargs), port, sel, &instance)
994    }
995}
996
997/// A Rust wrapper for Python's RemoteImpl endpoint.
998///
999/// This allows us to implement the adverb methods (call, choose, call_one, stream, broadcast)
1000/// in Rust while delegating the actual send logic to the Python RemoteImpl._send() method.
1001#[pyclass(
1002    name = "Remote",
1003    module = "monarch._rust_bindings.monarch_hyperactor.endpoint"
1004)]
1005pub struct Remote {
1006    /// The wrapped Python RemoteImpl object
1007    inner: Py<PyAny>,
1008}
1009
1010impl Endpoint for Remote {
1011    fn get_extent(&self, py: Python<'_>) -> PyResult<Extent> {
1012        let extent: PyExtent = self.inner.call_method0(py, "_get_extent")?.extract(py)?;
1013        Ok(extent.into())
1014    }
1015
1016    fn get_method_name(&self) -> &str {
1017        "unknown"
1018    }
1019
1020    fn send_message<'py>(
1021        &self,
1022        py: Python<'py>,
1023        args: &Bound<'py, PyTuple>,
1024        kwargs: Option<&Bound<'py, PyDict>>,
1025        port_ref: Option<EitherPortRef>,
1026        selection: Selection,
1027        _instance: &Instance<PythonActor>,
1028    ) -> PyResult<()> {
1029        let send_kwargs = PyDict::new(py);
1030        match port_ref {
1031            Some(pr) => send_kwargs.set_item("port", pr.clone())?,
1032            None => send_kwargs.set_item("port", py.None())?,
1033        }
1034
1035        let selection_str = match selection {
1036            Selection::All(inner) if matches!(*inner, Selection::True) => "all",
1037            Selection::Any(inner) if matches!(*inner, Selection::True) => "choose",
1038            _ => {
1039                panic!("only sel!(*) and sel!(?) should be provided as selection for send_message")
1040            }
1041        };
1042
1043        send_kwargs.set_item("selection", selection_str)?;
1044
1045        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
1046        self.inner
1047            .call_method(py, "_send", (args.clone(), kwargs_dict), Some(&send_kwargs))?;
1048
1049        Ok(())
1050    }
1051
1052    fn get_supervision_monitor(&self) -> Option<Arc<dyn Supervisable>> {
1053        None // Remote endpoints don't have supervision_monitors
1054    }
1055
1056    fn get_qualified_name(&self) -> Option<String> {
1057        None // Remote endpoints don't have qualified names
1058    }
1059}
1060
1061#[pymethods]
1062impl Remote {
1063    /// Create a new Remote wrapping a Python RemoteImpl object.
1064    #[new]
1065    fn new(remote: Py<PyAny>) -> Self {
1066        Self { inner: remote }
1067    }
1068
1069    /// Call the endpoint on all actors and collect all responses into a ValueMesh.
1070    #[pyo3(signature = (*args, **kwargs), name = "call")]
1071    fn py_call<'py>(
1072        &self,
1073        py: Python<'py>,
1074        args: &Bound<'py, PyTuple>,
1075        kwargs: Option<&Bound<'py, PyDict>>,
1076    ) -> PyResult<Py<PyAny>> {
1077        self.call(py, args, kwargs)
1078    }
1079
1080    /// Load balanced sends a message to one chosen actor and awaits a result.
1081    #[pyo3(signature = (*args, **kwargs), name = "choose")]
1082    fn py_choose<'py>(
1083        &self,
1084        py: Python<'py>,
1085        args: &Bound<'py, PyTuple>,
1086        kwargs: Option<&Bound<'py, PyDict>>,
1087    ) -> PyResult<Py<PyAny>> {
1088        self.choose(py, args, kwargs)
1089    }
1090
1091    /// Call the endpoint on exactly one actor (the mesh must have exactly one actor).
1092    #[pyo3(signature = (*args, **kwargs), name = "call_one")]
1093    fn py_call_one<'py>(
1094        &self,
1095        py: Python<'py>,
1096        args: &Bound<'py, PyTuple>,
1097        kwargs: Option<&Bound<'py, PyDict>>,
1098    ) -> PyResult<Py<PyAny>> {
1099        self.call_one(py, args, kwargs)
1100    }
1101
1102    /// Call the endpoint on all actors and return an iterator of Futures.
1103    #[pyo3(signature = (*args, **kwargs), name = "stream")]
1104    fn py_stream<'py>(
1105        &self,
1106        py: Python<'py>,
1107        args: &Bound<'py, PyTuple>,
1108        kwargs: Option<&Bound<'py, PyDict>>,
1109    ) -> PyResult<Py<PyAny>> {
1110        self.stream(py, args, kwargs)
1111    }
1112
1113    /// Send a message to all actors without waiting for responses (fire-and-forget).
1114    #[pyo3(signature = (*args, **kwargs), name = "broadcast")]
1115    fn py_broadcast<'py>(
1116        &self,
1117        py: Python<'py>,
1118        args: &Bound<'py, PyTuple>,
1119        kwargs: Option<&Bound<'py, PyDict>>,
1120    ) -> PyResult<()> {
1121        self.broadcast(py, args, kwargs)
1122    }
1123
1124    /// Get the rref result by calling the wrapped Remote's rref method.
1125    #[pyo3(signature = (*args, **kwargs))]
1126    fn rref<'py>(
1127        &self,
1128        py: Python<'py>,
1129        args: &Bound<'py, PyTuple>,
1130        kwargs: Option<&Bound<'py, PyDict>>,
1131    ) -> PyResult<Py<PyAny>> {
1132        let kwargs_dict = kwargs.map_or_else(|| PyDict::new(py), |d| d.clone());
1133        self.inner.call_method(py, "rref", args, Some(&kwargs_dict))
1134    }
1135
1136    /// Get the call name by delegating to the wrapped Remote's _call_name.
1137    fn _call_name(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1138        self.inner.call_method0(py, "_call_name")
1139    }
1140
1141    /// Get the maybe_resolvable property from the wrapped RemoteImpl.
1142    #[getter]
1143    fn _maybe_resolvable(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1144        self.inner.getattr(py, "_maybe_resolvable")
1145    }
1146
1147    /// Get the resolvable property from the wrapped RemoteImpl.
1148    #[getter]
1149    fn _resolvable(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1150        self.inner.getattr(py, "_resolvable")
1151    }
1152
1153    /// Get the remote_impl from the wrapped RemoteImpl.
1154    /// This is needed for function_to_import_path() in function.py to work correctly.
1155    #[getter]
1156    fn _remote_impl(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1157        self.inner.getattr(py, "_remote_impl")
1158    }
1159
1160    /// Propagation method for tensor shape inference.
1161    /// Delegates to the wrapped Remote's _propagate.
1162    fn _propagate<'py>(
1163        &self,
1164        py: Python<'py>,
1165        args: &Bound<'py, PyAny>,
1166        kwargs: &Bound<'py, PyAny>,
1167        fake_args: &Bound<'py, PyAny>,
1168        fake_kwargs: &Bound<'py, PyAny>,
1169    ) -> PyResult<Py<PyAny>> {
1170        self.inner
1171            .call_method1(py, "_propagate", (args, kwargs, fake_args, fake_kwargs))
1172    }
1173
1174    /// Propagation for fetch operations.
1175    /// Delegates to the wrapped Remote's _fetch_propagate.
1176    fn _fetch_propagate<'py>(
1177        &self,
1178        py: Python<'py>,
1179        args: &Bound<'py, PyAny>,
1180        kwargs: &Bound<'py, PyAny>,
1181        fake_args: &Bound<'py, PyAny>,
1182        fake_kwargs: &Bound<'py, PyAny>,
1183    ) -> PyResult<Py<PyAny>> {
1184        self.inner.call_method1(
1185            py,
1186            "_fetch_propagate",
1187            (args, kwargs, fake_args, fake_kwargs),
1188        )
1189    }
1190
1191    /// Propagation for pipe operations.
1192    /// Delegates to the wrapped Remote's _pipe_propagate.
1193    fn _pipe_propagate<'py>(
1194        &self,
1195        py: Python<'py>,
1196        args: &Bound<'py, PyAny>,
1197        kwargs: &Bound<'py, PyAny>,
1198        fake_args: &Bound<'py, PyAny>,
1199        fake_kwargs: &Bound<'py, PyAny>,
1200    ) -> PyResult<Py<PyAny>> {
1201        self.inner.call_method1(
1202            py,
1203            "_pipe_propagate",
1204            (args, kwargs, fake_args, fake_kwargs),
1205        )
1206    }
1207
1208    /// Send a message with optional port for response.
1209    /// Delegates to the wrapped RemoteImpl's _send.
1210    fn _send<'py>(
1211        &self,
1212        py: Python<'py>,
1213        args: &Bound<'py, PyTuple>,
1214        kwargs: &Bound<'py, PyDict>,
1215        port: Option<Py<PyAny>>,
1216        selection: &str,
1217    ) -> PyResult<()> {
1218        self.inner.call_method(
1219            py,
1220            "_send",
1221            (args, kwargs),
1222            Some(&{
1223                let d = PyDict::new(py);
1224                d.set_item("port", port.unwrap_or_else(|| py.None()))?;
1225                d.set_item("selection", selection)?;
1226                d
1227            }),
1228        )?;
1229        Ok(())
1230    }
1231
1232    /// Make RemoteEndpoint callable - delegates to rref() like Remote.__call__.
1233    #[pyo3(signature = (*args, **kwargs))]
1234    fn __call__<'py>(
1235        &self,
1236        py: Python<'py>,
1237        args: &Bound<'py, PyTuple>,
1238        kwargs: Option<&Bound<'py, PyDict>>,
1239    ) -> PyResult<Py<PyAny>> {
1240        self.rref(py, args, kwargs)
1241    }
1242}
1243
1244pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
1245    module.add_class::<PyValueStream>()?;
1246    module.add_class::<ActorEndpoint>()?;
1247    module.add_class::<Remote>()?;
1248
1249    Ok(())
1250}
1251
1252#[derive(Named)]
1253struct PythonResponseMessageReducer;
1254
1255impl CommReducer for PythonResponseMessageReducer {
1256    type Update = PythonMessage;
1257
1258    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
1259        Ok(ValueOverlay::try_from_runs(rle::merge_value_runs(
1260            left.into_overlay()?.into_runs(),
1261            right.into_overlay()?.into_runs(),
1262        ))?
1263        .into())
1264    }
1265}
1266
1267inventory::submit! {
1268    ReducerFactory {
1269        typehash_f: <PythonResponseMessageReducer as Named>::typehash,
1270        builder_f: |_| Ok(Box::new(PythonResponseMessageReducer)),
1271    }
1272}
1273
1274struct PythonResponseMessageAccumulator;
1275
1276impl Accumulator for PythonResponseMessageAccumulator {
1277    type State = PythonMessage;
1278    type Update = PythonMessage;
1279
1280    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
1281        *state = ValueOverlay::try_from_runs(rle::merge_value_runs(
1282            std::mem::take(state).into_overlay()?.into_runs(),
1283            update.into_overlay()?.into_runs(),
1284        ))?
1285        .into();
1286
1287        Ok(())
1288    }
1289
1290    fn reducer_spec(&self) -> Option<ReducerSpec> {
1291        Some(ReducerSpec {
1292            typehash: <PythonResponseMessageReducer as Named>::typehash(),
1293            builder_params: None,
1294        })
1295    }
1296}