1use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorId;
51use hyperactor::Mailbox;
52use hyperactor::Named;
53use hyperactor::OncePortRef;
54use hyperactor::PortRef;
55use hyperactor::cap::CanOpenPort;
56use hyperactor::cap::CanSend;
57use hyperactor::clock::Clock;
58use hyperactor::clock::RealClock;
59use hyperactor::mailbox::OncePortReceiver;
60use hyperactor::mailbox::PortReceiver;
61use hyperactor::mailbox::open_once_port;
62use hyperactor::mailbox::open_port;
63use hyperactor::message::Bind;
64use hyperactor::message::Bindings;
65use hyperactor::message::Unbind;
66use pin_project::pin_project;
67use pin_project::pinned_drop;
68use serde::Deserialize;
69use serde::Serialize;
70use tokio::io::AsyncRead;
71use tokio::io::AsyncWrite;
72use tokio_util::io::StreamReader;
73
/// Upper bound on how long [`ConnectionCompleter::complete`] waits for the
/// peer's `Accept` reply before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
76
77#[derive(Debug, Serialize, Deserialize, Named, Clone)]
79enum Io {
80 Data(#[serde(with = "serde_bytes")] Vec<u8>),
82 Eof,
84}
85
86struct OwnedReadHalfStream {
87 port: PortReceiver<Io>,
88 exhausted: bool,
89}
90
91pub struct OwnedReadHalf {
93 peer: ActorId,
94 inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
95}
96
97#[pin_project(PinnedDrop)]
99pub struct OwnedWriteHalf<C: CanSend> {
100 peer: ActorId,
101 #[pin]
102 caps: C,
103 #[pin]
104 port: PortRef<Io>,
105 #[pin]
106 shutdown: bool,
107}
108
109#[pin_project]
111pub struct ActorConnection<C: CanSend> {
112 #[pin]
113 reader: OwnedReadHalf,
114 #[pin]
115 writer: OwnedWriteHalf<C>,
116}
117
118impl<C: CanSend> ActorConnection<C> {
119 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
120 (self.reader, self.writer)
121 }
122
123 pub fn peer(&self) -> &ActorId {
124 self.reader.peer()
125 }
126}
127
128impl OwnedReadHalf {
129 fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
130 Self {
131 peer,
132 inner: StreamReader::new(OwnedReadHalfStream {
133 port,
134 exhausted: false,
135 }),
136 }
137 }
138
139 pub fn peer(&self) -> &ActorId {
140 &self.peer
141 }
142
143 pub fn reunited<C: CanSend>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
144 ActorConnection {
145 reader: self,
146 writer: other,
147 }
148 }
149}
150
151impl<C: CanSend> OwnedWriteHalf<C> {
152 fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
153 Self {
154 peer,
155 caps,
156 port,
157 shutdown: false,
158 }
159 }
160
161 pub fn peer(&self) -> &ActorId {
162 &self.peer
163 }
164
165 pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
166 ActorConnection {
167 reader: other,
168 writer: self,
169 }
170 }
171}
172
173#[pinned_drop]
174impl<C: CanSend> PinnedDrop for OwnedWriteHalf<C> {
175 fn drop(self: Pin<&mut Self>) {
176 let this = self.project();
177 if !*this.shutdown {
178 let _ = this.port.send(&*this.caps, Io::Eof);
179 }
180 }
181}
182
183impl<C: CanSend> AsyncRead for ActorConnection<C> {
184 fn poll_read(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &mut tokio::io::ReadBuf<'_>,
188 ) -> Poll<std::io::Result<()>> {
189 let this = self.project();
191 this.reader.poll_read(cx, buf)
192 }
193}
194
195impl<C: CanSend> AsyncWrite for ActorConnection<C> {
196 fn poll_write(
197 self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &[u8],
200 ) -> Poll<Result<usize, std::io::Error>> {
201 let this = self.project();
203 this.writer.poll_write(cx, buf)
204 }
205
206 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
207 let this = self.project();
208 this.writer.poll_flush(cx)
209 }
210
211 fn poll_shutdown(
212 self: Pin<&mut Self>,
213 cx: &mut Context<'_>,
214 ) -> Poll<Result<(), std::io::Error>> {
215 let this = self.project();
216 this.writer.poll_shutdown(cx)
217 }
218}
219
220impl Stream for OwnedReadHalfStream {
221 type Item = std::io::Result<Cursor<Vec<u8>>>;
222
223 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 if self.exhausted {
226 return Poll::Ready(None);
227 }
228
229 let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
230 match result {
231 Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
232 Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
233 Ok(Io::Eof) => {
235 self.exhausted = true;
236 Poll::Ready(None)
237 }
238 }
239 }
240}
241
242impl FusedStream for OwnedReadHalfStream {
243 fn is_terminated(&self) -> bool {
244 self.exhausted
245 }
246}
247
248impl AsyncRead for OwnedReadHalf {
249 fn poll_read(
250 mut self: Pin<&mut Self>,
251 cx: &mut Context<'_>,
252 buf: &mut tokio::io::ReadBuf<'_>,
253 ) -> Poll<std::io::Result<()>> {
254 Pin::new(&mut self.inner).poll_read(cx, buf)
255 }
256}
257
258impl<C: CanSend> AsyncWrite for OwnedWriteHalf<C> {
259 fn poll_write(
260 self: Pin<&mut Self>,
261 _cx: &mut Context<'_>,
262 buf: &[u8],
263 ) -> Poll<Result<usize, std::io::Error>> {
264 let this = self.project();
265 if *this.shutdown {
266 return Poll::Ready(Err(std::io::Error::new(
267 std::io::ErrorKind::BrokenPipe,
268 "write after shutdown",
269 )));
270 }
271 match this.port.send(&*this.caps, Io::Data(buf.into())) {
272 Ok(()) => Poll::Ready(Ok(buf.len())),
273 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
274 }
275 }
276
277 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
278 Poll::Ready(Ok(()))
279 }
280
281 fn poll_shutdown(
282 self: Pin<&mut Self>,
283 _cx: &mut Context<'_>,
284 ) -> Poll<Result<(), std::io::Error>> {
285 match self.port.send(&self.caps, Io::Eof) {
287 Ok(()) => {
288 let mut this = self.project();
289 *this.shutdown = true;
290 Poll::Ready(Ok(()))
291 }
292 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
293 }
294 }
295}
296
297pub struct ConnectionCompleter<C> {
299 caps: C,
300 conn: PortReceiver<Io>,
301 port: OncePortReceiver<Accept>,
302}
303
304impl<C: CanOpenPort + CanSend> ConnectionCompleter<C> {
305 pub async fn complete(self) -> Result<ActorConnection<C>> {
308 let accept = RealClock
309 .timeout(CONNECT_TIMEOUT, self.port.recv())
310 .await??;
311 Ok(ActorConnection {
312 reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
313 writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
314 })
315 }
316}
317
318#[derive(Debug, Serialize, Deserialize, Named, Clone)]
320pub struct Connect {
321 id: ActorId,
323 conn: PortRef<Io>,
324 return_conn: OncePortRef<Accept>,
326}
327
328impl Connect {
329 pub fn allocate<C: CanOpenPort + CanSend>(
332 id: ActorId,
333 caps: C,
334 ) -> (Self, ConnectionCompleter<C>) {
335 let (conn_tx, conn_rx) = open_port::<Io>(&caps);
336 let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
337 (
338 Self {
339 id,
340 conn: conn_tx.bind(),
341 return_conn: return_tx.bind(),
342 },
343 ConnectionCompleter {
344 caps,
345 conn: conn_rx,
346 port: return_rx,
347 },
348 )
349 }
350}
351
352#[derive(Debug, Serialize, Deserialize, Named, Clone)]
355struct Accept {
356 id: ActorId,
358 conn: PortRef<Io>,
360}
361
362impl Bind for Connect {
363 fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
364 self.conn.bind(bindings)?;
365 self.return_conn.bind(bindings)
366 }
367}
368
369impl Unbind for Connect {
370 fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
371 self.conn.unbind(bindings)?;
372 self.return_conn.unbind(bindings)
373 }
374}
375
376pub async fn accept<C: CanOpenPort + CanSend>(
379 caps: C,
380 self_id: ActorId,
381 message: Connect,
382) -> Result<ActorConnection<C>> {
383 let (tx, rx) = open_port::<Io>(&caps);
384 message.return_conn.send(
385 &caps,
386 Accept {
387 id: self_id,
388 conn: tx.bind(),
389 },
390 )?;
391 Ok(ActorConnection {
392 reader: OwnedReadHalf::new(message.id.clone(), rx),
393 writer: OwnedWriteHalf::new(message.id, caps, message.conn),
394 })
395}
396
397pub async fn connect(
401 mailbox: &Mailbox,
402 port: PortRef<Connect>,
403) -> Result<ActorConnection<Mailbox>> {
404 let (connect, completer) = Connect::allocate(mailbox.actor_id().clone(), mailbox.clone());
405 port.send(mailbox, connect)?;
406 completer.complete().await
407}
408
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts a connection and echoes every byte back.
    #[derive(Debug, Default, Actor)]
    struct EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let conn = accept(cx, cx.self_id().clone(), message).await?;
            let (mut reader, mut writer) = conn.into_split();
            tokio::io::copy(&mut reader, &mut writer).await?;
            writer.shutdown().await?;
            Ok(())
        }
    }

    /// Bytes written to an echo actor come back unchanged, followed by EOF.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;
        let actor = proc.spawn::<EchoActor>("actor", ()).await?;
        let conn = connect(&client, actor.port().bind()).await?;
        let (mut reader, mut writer) = conn.into_split();
        let payload = [3u8, 4u8, 5u8, 6u8];

        let write_side = async move {
            writer.write_all(&payload).await?;
            writer.shutdown().await?;
            anyhow::Ok(())
        };
        let read_side = async {
            let mut echoed = vec![];
            reader.read_to_end(&mut echoed).await?;
            assert_eq!(&payload, echoed.as_slice());
            anyhow::Ok(())
        };
        try_join!(write_side, read_side)?;
        Ok(())
    }

    /// Dropping the write half without shutting it down still delivers EOF.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;

        let (request, completer) = Connect::allocate(client.actor_id().clone(), client.clone());
        let (mut reader, _) = accept(&client, client.actor_id().clone(), request)
            .await?
            .into_split();
        let (_, mut writer) = completer.complete().await?.into_split();

        let payload = [1u8, 2u8, 3u8];
        writer.write_all(&payload).await?;

        // No explicit shutdown: the drop guard must send the EOF.
        drop(writer);

        let mut received = vec![];
        reader.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        Ok(())
    }

    /// After an explicit shutdown, dropping the write half sends nothing more.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;

        let (request, completer) = Connect::allocate(client.actor_id().clone(), client.clone());
        let (mut reader, _) = accept(&client, client.actor_id().clone(), request)
            .await?
            .into_split();
        let (_, mut writer) = completer.complete().await?.into_split();

        let payload = [1u8, 2u8, 3u8];
        writer.write_all(&payload).await?;

        writer.shutdown().await?;

        let mut received = vec![];
        reader.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        drop(writer);

        // The port must be empty: the drop above must not have sent a second EOF.
        let leftover = reader.inner.into_inner().port.try_recv().unwrap();
        assert!(leftover.is_none());

        Ok(())
    }
}