1use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::context;
51use hyperactor::mailbox::OncePortReceiver;
52use hyperactor::mailbox::PortReceiver;
53use hyperactor::mailbox::open_once_port;
54use hyperactor::mailbox::open_port;
55use hyperactor::message::Bind;
56use hyperactor::message::Bindings;
57use hyperactor::message::Unbind;
58use hyperactor::reference as hyperactor_reference;
59use pin_project::pin_project;
60use pin_project::pinned_drop;
61use serde::Deserialize;
62use serde::Serialize;
63use tokio::io::AsyncRead;
64use tokio::io::AsyncWrite;
65use tokio_util::io::StreamReader;
66use typeuri::Named;
67
/// Maximum time `ConnectionCompleter::complete` waits for the peer's `Accept`
/// reply before failing.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
70
/// Messages carried over a connection's data port.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    /// A chunk of payload bytes (encoded with `serde_bytes`).
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    /// Marks the end of the sending side's byte stream.
    Eof,
}
wirevalue::register_type!(Io);
80
/// Adapter exposing a `PortReceiver<Io>` as a `Stream` of byte chunks, so it
/// can be wrapped in a `StreamReader`.
struct OwnedReadHalfStream {
    /// Port on which `Io` messages from the peer arrive.
    port: PortReceiver<Io>,
    /// Set once `Io::Eof` has been received; the stream yields `None` afterwards.
    exhausted: bool,
}
85
/// Read half of an [`ActorConnection`]; implements `AsyncRead`.
pub struct OwnedReadHalf {
    /// Id of the actor on the other end of the connection.
    peer: hyperactor_reference::ActorId,
    /// Byte reader layered over the stream of incoming `Io::Data` chunks.
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}
91
/// Write half of an [`ActorConnection`]; implements `AsyncWrite`.
///
/// Dropping the half without an explicit shutdown sends `Io::Eof` to the peer
/// (see the `PinnedDrop` impl).
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: context::Actor> {
    /// Id of the actor on the other end of the connection.
    peer: hyperactor_reference::ActorId,
    /// Actor context used as the sender when posting to `port`.
    #[pin]
    caps: C,
    /// Peer port that receives this half's `Io` messages.
    #[pin]
    port: hyperactor_reference::PortRef<Io>,
    /// Set once `poll_shutdown` has sent `Io::Eof`; suppresses the EOF on drop
    /// and makes further writes fail.
    #[pin]
    shutdown: bool,
}
103
/// A bidirectional byte stream between two actors.
///
/// Implements `AsyncRead` and `AsyncWrite` by delegating to its two halves;
/// use `into_split` to obtain them separately.
#[pin_project]
pub struct ActorConnection<C: context::Actor> {
    /// Half that receives bytes from the peer.
    #[pin]
    reader: OwnedReadHalf,
    /// Half that sends bytes to the peer.
    #[pin]
    writer: OwnedWriteHalf<C>,
}
112
113impl<C: context::Actor> ActorConnection<C> {
114 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
115 (self.reader, self.writer)
116 }
117
118 pub fn peer(&self) -> &hyperactor_reference::ActorId {
119 self.reader.peer()
120 }
121}
122
123impl OwnedReadHalf {
124 fn new(peer: hyperactor_reference::ActorId, port: PortReceiver<Io>) -> Self {
125 Self {
126 peer,
127 inner: StreamReader::new(OwnedReadHalfStream {
128 port,
129 exhausted: false,
130 }),
131 }
132 }
133
134 pub fn peer(&self) -> &hyperactor_reference::ActorId {
135 &self.peer
136 }
137
138 pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
139 ActorConnection {
140 reader: self,
141 writer: other,
142 }
143 }
144}
145
146impl<C: context::Actor> OwnedWriteHalf<C> {
147 fn new(
148 peer: hyperactor_reference::ActorId,
149 caps: C,
150 port: hyperactor_reference::PortRef<Io>,
151 ) -> Self {
152 Self {
153 peer,
154 caps,
155 port,
156 shutdown: false,
157 }
158 }
159
160 pub fn peer(&self) -> &hyperactor_reference::ActorId {
161 &self.peer
162 }
163
164 pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
165 ActorConnection {
166 reader: other,
167 writer: self,
168 }
169 }
170}
171
172#[pinned_drop]
173impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
174 fn drop(self: Pin<&mut Self>) {
175 let this = self.project();
176 if !*this.shutdown {
177 let _ = this.port.send(&*this.caps, Io::Eof);
178 }
179 }
180}
181
182impl<C: context::Actor> AsyncRead for ActorConnection<C> {
183 fn poll_read(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &mut tokio::io::ReadBuf<'_>,
187 ) -> Poll<std::io::Result<()>> {
188 let this = self.project();
190 this.reader.poll_read(cx, buf)
191 }
192}
193
194impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
195 fn poll_write(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 buf: &[u8],
199 ) -> Poll<Result<usize, std::io::Error>> {
200 let this = self.project();
202 this.writer.poll_write(cx, buf)
203 }
204
205 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
206 let this = self.project();
207 this.writer.poll_flush(cx)
208 }
209
210 fn poll_shutdown(
211 self: Pin<&mut Self>,
212 cx: &mut Context<'_>,
213 ) -> Poll<Result<(), std::io::Error>> {
214 let this = self.project();
215 this.writer.poll_shutdown(cx)
216 }
217}
218
219impl Stream for OwnedReadHalfStream {
220 type Item = std::io::Result<Cursor<Vec<u8>>>;
221
222 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223 if self.exhausted {
225 return Poll::Ready(None);
226 }
227
228 let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
229 match result {
230 Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
231 Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
232 Ok(Io::Eof) => {
234 self.exhausted = true;
235 Poll::Ready(None)
236 }
237 }
238 }
239}
240
impl FusedStream for OwnedReadHalfStream {
    /// Reports whether `Io::Eof` has been received, after which `poll_next`
    /// only ever yields `None`.
    fn is_terminated(&self) -> bool {
        self.exhausted
    }
}
246
247impl AsyncRead for OwnedReadHalf {
248 fn poll_read(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 buf: &mut tokio::io::ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 Pin::new(&mut self.inner).poll_read(cx, buf)
254 }
255}
256
257impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
258 fn poll_write(
259 self: Pin<&mut Self>,
260 _cx: &mut Context<'_>,
261 buf: &[u8],
262 ) -> Poll<Result<usize, std::io::Error>> {
263 let this = self.project();
264 if *this.shutdown {
265 return Poll::Ready(Err(std::io::Error::new(
266 std::io::ErrorKind::BrokenPipe,
267 "write after shutdown",
268 )));
269 }
270 match this.port.send(&*this.caps, Io::Data(buf.into())) {
271 Ok(()) => Poll::Ready(Ok(buf.len())),
272 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
273 }
274 }
275
276 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
277 Poll::Ready(Ok(()))
278 }
279
280 fn poll_shutdown(
281 self: Pin<&mut Self>,
282 _cx: &mut Context<'_>,
283 ) -> Poll<Result<(), std::io::Error>> {
284 match self.port.send(&self.caps, Io::Eof) {
286 Ok(()) => {
287 let mut this = self.project();
288 *this.shutdown = true;
289 Poll::Ready(Ok(()))
290 }
291 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
292 }
293 }
294}
295
/// Initiator-side handle returned by `Connect::allocate`; call `complete` to
/// wait for the peer's `Accept` and obtain the connection.
pub struct ConnectionCompleter<C> {
    /// Actor context that becomes the sender for the connection's write half.
    caps: C,
    /// Receiver for the data port advertised to the peer in `Connect::conn`.
    conn: PortReceiver<Io>,
    /// Receiver for the one-shot `Accept` reply.
    port: OncePortReceiver<Accept>,
}
302
303impl<C: context::Actor> ConnectionCompleter<C> {
304 pub async fn complete(self) -> Result<ActorConnection<C>> {
307 let accept = tokio::time::timeout(CONNECT_TIMEOUT, self.port.recv()).await??;
308 Ok(ActorConnection {
309 reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
310 writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
311 })
312 }
313}
314
/// Message sent to an actor to request a connection; the receiver answers by
/// calling `accept`.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// Id of the initiating side; the accepting side records it as its peer.
    id: hyperactor_reference::ActorId,
    /// Port on which the initiator receives data written by the acceptor.
    conn: hyperactor_reference::PortRef<Io>,
    /// One-shot port on which the acceptor sends back its `Accept` reply.
    return_conn: hyperactor_reference::OncePortRef<Accept>,
}
wirevalue::register_type!(Connect);
325
326impl Connect {
327 pub fn allocate<C: context::Actor>(
330 id: hyperactor_reference::ActorId,
331 caps: C,
332 ) -> (Self, ConnectionCompleter<C>) {
333 let (conn_tx, conn_rx) = open_port::<Io>(&caps);
334 let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
335 (
336 Self {
337 id,
338 conn: conn_tx.bind(),
339 return_conn: return_tx.bind(),
340 },
341 ConnectionCompleter {
342 caps,
343 conn: conn_rx,
344 port: return_rx,
345 },
346 )
347 }
348}
349
/// Reply sent by the accepting side over `Connect::return_conn`.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// Id of the accepting side; the initiator records it as its peer.
    id: hyperactor_reference::ActorId,
    /// Port on which the acceptor receives data written by the initiator.
    conn: hyperactor_reference::PortRef<Io>,
}
wirevalue::register_type!(Accept);
360
361impl Bind for Connect {
362 fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
363 self.conn.bind(bindings)?;
364 self.return_conn.bind(bindings)
365 }
366}
367
368impl Unbind for Connect {
369 fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
370 self.conn.unbind(bindings)?;
371 self.return_conn.unbind(bindings)
372 }
373}
374
375pub async fn accept<C: context::Actor>(
378 caps: C,
379 self_id: hyperactor_reference::ActorId,
380 message: Connect,
381) -> Result<ActorConnection<C>> {
382 let (tx, rx) = open_port::<Io>(&caps);
383 message.return_conn.send(
384 &caps,
385 Accept {
386 id: self_id,
387 conn: tx.bind(),
388 },
389 )?;
390 Ok(ActorConnection {
391 reader: OwnedReadHalf::new(message.id.clone(), rx),
392 writer: OwnedWriteHalf::new(message.id, caps, message.conn),
393 })
394}
395
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts a connection and echoes back everything it reads.
    #[derive(Debug, Default)]
    struct EchoActor {}

    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_id().clone(), message)
                .await?
                .into_split();
            // Copy until the peer's EOF, then signal our own end-of-stream.
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    /// Round-trips a few bytes through `EchoActor` and checks they come back intact.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let (client, _) = proc.instance("client")?;
        let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
        let actor = proc.spawn("actor", EchoActor {})?;
        actor.send(&completer.caps, connect)?;
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        // Write and read concurrently; the reader finishes once the echo side shuts down.
        try_join!(
            async move {
                wr.write_all(&send).await?;
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    /// Dropping the write half without shutdown must still deliver EOF to the reader.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        // Both ends of the connection live on the same client instance.
        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // No explicit shutdown: the drop handler is expected to send EOF.
        drop(wr);

        let mut recv = vec![];
        // `read_to_end` only returns once EOF has been observed.
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    /// After an explicit shutdown, dropping the write half must not send a second EOF.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Explicit shutdown sends the one and only EOF.
        wr.shutdown().await?;

        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        drop(wr);

        // The underlying port must be empty: no extra message was queued by the drop.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}