Skip to main content

hyperactor/
endpoint.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Generic send endpoints.
10
11use std::fmt;
12
13use hyperactor_config::Flattrs;
14use serde::Deserialize;
15use serde::Serialize;
16
17use crate::ActorAddr;
18use crate::PortAddr;
19use crate::context;
20use crate::mailbox::PortLocation;
21
22/// The logical location of an endpoint.
23#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, typeuri::Named)]
24pub enum EndpointLocation {
25    /// An actor endpoint.
26    Actor(ActorAddr),
27    /// A port endpoint.
28    Port(PortAddr),
29    /// A local port handle that has not been bound to a routable port.
30    Local {
31        /// The actor that owns the local endpoint.
32        actor: ActorAddr,
33        /// The local endpoint's message type.
34        message_type: String,
35    },
36}
37
38impl EndpointLocation {
39    /// The actor address associated with this endpoint location.
40    pub fn actor_addr(&self) -> ActorAddr {
41        match self {
42            Self::Actor(actor) => actor.clone(),
43            Self::Port(port) => port.actor_addr(),
44            Self::Local { actor, .. } => actor.clone(),
45        }
46    }
47}
48
49impl From<PortLocation> for EndpointLocation {
50    fn from(location: PortLocation) -> Self {
51        match location {
52            PortLocation::Bound(port) => Self::Port(port),
53            PortLocation::Unbound(actor, message_type) => Self::Local {
54                actor,
55                message_type: message_type.to_string(),
56            },
57        }
58    }
59}
60
61impl fmt::Display for EndpointLocation {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        match self {
64            Self::Actor(actor) => write!(f, "{}", actor),
65            Self::Port(port) => write!(f, "{}", port),
66            Self::Local {
67                actor,
68                message_type,
69            } => write!(f, "{}<{}>", actor, message_type),
70        }
71    }
72}
73
74/// A typed endpoint that can receive `M`.
75///
76/// This trait abstracts over local actor handles, local port handles, remote
77/// actor refs, remote port refs, and one-shot ports. It is sealed so that
78/// Hyperactor owns the post semantics for each endpoint kind.
79pub trait Endpoint<M>: crate::private::Sealed {
80    /// The logical location of this endpoint.
81    fn endpoint_location(&self) -> EndpointLocation;
82
83    /// Post `message` to this endpoint from `cx`.
84    fn post<C>(self, cx: &C, message: M)
85    where
86        C: context::Actor;
87}
88
89/// A typed endpoint that can receive `M` with message headers.
90///
91/// `RemoteEndpoint` is implemented only for endpoints whose post path preserves
92/// headers.
93pub trait RemoteEndpoint<M>: Endpoint<M> {
94    /// Post `message` and `headers` to this endpoint from `cx`.
95    fn post_with_headers<C>(self, cx: &C, headers: Flattrs, message: M)
96    where
97        C: context::Actor;
98}
99
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hyperactor_config::Flattrs;
    use hyperactor_config::declare_attrs;
    use tokio::sync::mpsc;
    use typeuri::Named;

    use super::*;
    use crate::Actor;
    use crate::Handler;
    use crate::PortRef;
    use crate::actor::Referable;
    use crate::actor::RemoteHandles;
    use crate::proc::Context;
    use crate::proc::Proc;

    // Header attribute used by the header-preservation test below.
    declare_attrs! {
        attr ENDPOINT_TEST_HEADER: u64;
    }

    /// Test actor that re-posts every `u64` it handles to the port in `tx`.
    #[derive(Debug)]
    struct EchoActor {
        tx: PortRef<u64>,
    }

    #[async_trait]
    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<u64> for EchoActor {
        /// Forwards `message` unchanged to the actor's reply port.
        async fn handle(&mut self, cx: &Context<Self>, message: u64) -> anyhow::Result<()> {
            let reply_port = &self.tx;
            Endpoint::post(reply_port, cx, message);
            Ok(())
        }
    }

    /// Marker type used as the behavior of the attached actor in
    /// `test_endpoint_actor_ref`.
    struct TestBehavior;

    impl Named for TestBehavior {
        fn typename() -> &'static str {
            "hyperactor::endpoint::tests::TestBehavior"
        }
    }

    impl Referable for TestBehavior {}
    impl RemoteHandles<u64> for TestBehavior {}

    /// Posting through a spawned actor's handle reaches its handler, which
    /// echoes the value back on the client's port.
    #[tokio::test]
    async fn test_endpoint_actor_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (reply_handle, mut replies) = client.open_port();
        let echo = EchoActor {
            tx: reply_handle.bind(),
        };
        let actor_handle = proc.spawn("echo", echo).expect("spawn should succeed");

        Endpoint::post(&actor_handle, &client, 123u64);

        let received = replies.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// Posting through a borrowed local port handle delivers the message.
    #[tokio::test]
    async fn test_endpoint_port_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (port_handle, mut receiver) = client.open_port();

        Endpoint::post(&port_handle, &client, 123u64);

        let received = receiver.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// Posting through a one-shot port handle (passed by value) delivers the
    /// message.
    #[tokio::test]
    async fn test_endpoint_once_port_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (once_handle, receiver) = client.open_once_port();

        Endpoint::post(once_handle, &client, 123u64);

        let received = receiver.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// Posting through a borrowed actor ref delivers the message to the
    /// attached actor's receiver.
    #[tokio::test]
    async fn test_endpoint_actor_ref() {
        let proc = Proc::isolated();
        let attached = proc.attach_actor::<TestBehavior, u64>("remote_actor");
        let (client, actor_ref, mut receiver) = attached.expect("attach actor should succeed");

        Endpoint::post(&actor_ref, &client, 123u64);

        let received = receiver.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// Posting through a borrowed port ref (a bound port handle) delivers the
    /// message.
    #[tokio::test]
    async fn test_endpoint_port_ref() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (port_handle, mut receiver) = client.open_port();
        let bound = port_handle.bind();

        Endpoint::post(&bound, &client, 123u64);

        let received = receiver.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// Posting through a bound one-shot port ref (passed by value) delivers
    /// the message.
    #[tokio::test]
    async fn test_endpoint_once_port_ref() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (once_handle, receiver) = client.open_once_port();
        let bound = once_handle.bind();

        Endpoint::post(bound, &client, 123u64);

        let received = receiver.recv().await.expect("message should arrive");
        assert_eq!(received, 123);
    }

    /// `post_with_headers` hands both the header value and the message to the
    /// receiving handler.
    #[tokio::test]
    async fn test_remote_endpoint_headers() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();

        let mut headers = Flattrs::new();
        headers.set(ENDPOINT_TEST_HEADER, 456u64);

        // The handler reports each (header value, message) pair it sees back
        // to the test body over an unbounded channel.
        let (seen_tx, mut seen_rx) = mpsc::unbounded_channel();
        let port = client.mailbox_for_py().open_handler_enqueue_port(
            move |headers: Flattrs, message: u64| {
                let header_value = headers
                    .get(ENDPOINT_TEST_HEADER)
                    .expect("header should be present");
                seen_tx
                    .send((header_value, message))
                    .expect("test receiver should be alive");
                Ok(())
            },
        );
        let port_ref = port.bind();

        RemoteEndpoint::post_with_headers(&port_ref, &client, headers, 123u64);

        let observed = seen_rx.recv().await.expect("message should arrive");
        assert_eq!(observed, (456, 123));
    }
}