1use std::fmt;
12
13use hyperactor_config::Flattrs;
14use serde::Deserialize;
15use serde::Serialize;
16
17use crate::ActorAddr;
18use crate::PortAddr;
19use crate::context;
20use crate::mailbox::PortLocation;
21
22#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, typeuri::Named)]
24pub enum EndpointLocation {
25 Actor(ActorAddr),
27 Port(PortAddr),
29 Local {
31 actor: ActorAddr,
33 message_type: String,
35 },
36}
37
38impl EndpointLocation {
39 pub fn actor_addr(&self) -> ActorAddr {
41 match self {
42 Self::Actor(actor) => actor.clone(),
43 Self::Port(port) => port.actor_addr(),
44 Self::Local { actor, .. } => actor.clone(),
45 }
46 }
47}
48
49impl From<PortLocation> for EndpointLocation {
50 fn from(location: PortLocation) -> Self {
51 match location {
52 PortLocation::Bound(port) => Self::Port(port),
53 PortLocation::Unbound(actor, message_type) => Self::Local {
54 actor,
55 message_type: message_type.to_string(),
56 },
57 }
58 }
59}
60
61impl fmt::Display for EndpointLocation {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 match self {
64 Self::Actor(actor) => write!(f, "{}", actor),
65 Self::Port(port) => write!(f, "{}", port),
66 Self::Local {
67 actor,
68 message_type,
69 } => write!(f, "{}<{}>", actor, message_type),
70 }
71 }
72}
73
74pub trait Endpoint<M>: crate::private::Sealed {
80 fn endpoint_location(&self) -> EndpointLocation;
82
83 fn post<C>(self, cx: &C, message: M)
85 where
86 C: context::Actor;
87}
88
89pub trait RemoteEndpoint<M>: Endpoint<M> {
94 fn post_with_headers<C>(self, cx: &C, headers: Flattrs, message: M)
96 where
97 C: context::Actor;
98}
99
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hyperactor_config::Flattrs;
    use hyperactor_config::declare_attrs;
    use tokio::sync::mpsc;
    use typeuri::Named;

    use super::*;
    use crate::Actor;
    use crate::Handler;
    use crate::PortRef;
    use crate::actor::Referable;
    use crate::actor::RemoteHandles;
    use crate::proc::Context;
    use crate::proc::Proc;

    // Header key used by the header round-trip test below.
    declare_attrs! {
        attr ENDPOINT_TEST_HEADER: u64;
    }

    /// Message value posted by every test in this module.
    const PAYLOAD: u64 = 123;

    /// Header value attached in `test_remote_endpoint_headers`.
    const HEADER_VALUE: u64 = 456;

    /// Actor that forwards every `u64` it handles to the port it was built with.
    #[derive(Debug)]
    struct EchoActor {
        sink: PortRef<u64>,
    }

    #[async_trait]
    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<u64> for EchoActor {
        /// Re-posts the received value to the configured sink port.
        async fn handle(&mut self, cx: &Context<Self>, value: u64) -> anyhow::Result<()> {
            Endpoint::post(&self.sink, cx, value);
            Ok(())
        }
    }

    /// Marker type used as the behavior of an attached actor handling `u64`.
    struct TestBehavior;

    impl Named for TestBehavior {
        fn typename() -> &'static str {
            "hyperactor::endpoint::tests::TestBehavior"
        }
    }

    impl Referable for TestBehavior {}
    impl RemoteHandles<u64> for TestBehavior {}

    /// Posting through a reference to a spawned actor's handle reaches the actor,
    /// which echoes the value back to the client's port.
    #[tokio::test]
    async fn test_endpoint_actor_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (port, mut inbox) = client.open_port();
        let echo = proc
            .spawn("echo", EchoActor { sink: port.bind() })
            .expect("spawn should succeed");

        Endpoint::post(&echo, &client, PAYLOAD);

        let echoed = inbox.recv().await.expect("message should arrive");
        assert_eq!(echoed, PAYLOAD);
    }

    /// Posting through a reference to an open port handle delivers to its receiver.
    #[tokio::test]
    async fn test_endpoint_port_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (port, mut inbox) = client.open_port();

        Endpoint::post(&port, &client, PAYLOAD);

        let delivered = inbox.recv().await.expect("message should arrive");
        assert_eq!(delivered, PAYLOAD);
    }

    /// Posting through an owned once-port handle delivers to its receiver.
    #[tokio::test]
    async fn test_endpoint_once_port_handle() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (once_port, once_inbox) = client.open_once_port();

        Endpoint::post(once_port, &client, PAYLOAD);

        let delivered = once_inbox.recv().await.expect("message should arrive");
        assert_eq!(delivered, PAYLOAD);
    }

    /// Posting through a reference to an attached actor's ref delivers to the
    /// receiver returned by `attach_actor`.
    #[tokio::test]
    async fn test_endpoint_actor_ref() {
        let proc = Proc::isolated();
        let (client, remote, mut inbox) = proc
            .attach_actor::<TestBehavior, u64>("remote_actor")
            .expect("attach actor should succeed");

        Endpoint::post(&remote, &client, PAYLOAD);

        let delivered = inbox.recv().await.expect("message should arrive");
        assert_eq!(delivered, PAYLOAD);
    }

    /// Posting through a reference to a bound port ref delivers to the port's receiver.
    #[tokio::test]
    async fn test_endpoint_port_ref() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (port, mut inbox) = client.open_port();
        let bound = port.bind();

        Endpoint::post(&bound, &client, PAYLOAD);

        let delivered = inbox.recv().await.expect("message should arrive");
        assert_eq!(delivered, PAYLOAD);
    }

    /// Posting through an owned, bound once-port ref delivers to its receiver.
    #[tokio::test]
    async fn test_endpoint_once_port_ref() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();
        let (once_port, once_inbox) = client.open_once_port();
        let bound = once_port.bind();

        Endpoint::post(bound, &client, PAYLOAD);

        let delivered = once_inbox.recv().await.expect("message should arrive");
        assert_eq!(delivered, PAYLOAD);
    }

    /// Headers given to `post_with_headers` are visible to the receiving handler
    /// together with the message.
    #[tokio::test]
    async fn test_remote_endpoint_headers() {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client").unwrap();

        // Headers to send alongside the message.
        let mut outgoing = Flattrs::new();
        outgoing.set(ENDPOINT_TEST_HEADER, HEADER_VALUE);

        // The handler forwards (header value, message) pairs to this channel
        // so the test body can assert on them.
        let (seen_tx, mut seen_rx) = mpsc::unbounded_channel();
        let handler_port = client.mailbox_for_py().open_handler_enqueue_port(
            move |attrs: Flattrs, payload: u64| {
                let header = attrs
                    .get(ENDPOINT_TEST_HEADER)
                    .expect("header should be present");
                seen_tx
                    .send((header, payload))
                    .expect("test receiver should be alive");
                Ok(())
            },
        );
        let bound = handler_port.bind();

        RemoteEndpoint::post_with_headers(&bound, &client, outgoing, PAYLOAD);

        let observed = seen_rx.recv().await.expect("message should arrive");
        assert_eq!(observed, (HEADER_VALUE, PAYLOAD));
    }
}