1use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorAddr;
51use hyperactor::Endpoint as _;
52use hyperactor::OncePortRef;
53use hyperactor::PortRef;
54use hyperactor::context;
55use hyperactor::mailbox::OncePortReceiver;
56use hyperactor::mailbox::PortReceiver;
57use hyperactor::mailbox::open_once_port;
58use hyperactor::mailbox::open_port;
59use hyperactor::message::Bind;
60use hyperactor::message::Bindings;
61use hyperactor::message::Unbind;
62use pin_project::pin_project;
63use pin_project::pinned_drop;
64use serde::Deserialize;
65use serde::Serialize;
66use tokio::io::AsyncRead;
67use tokio::io::AsyncWrite;
68use tokio_util::io::StreamReader;
69use typeuri::Named;
70
/// How long `ConnectionCompleter::complete` waits for the peer's `Accept`
/// reply before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
73
/// A single message exchanged on one direction of a connection.
///
/// NOTE(review): for index-based serde formats the variant order is part of
/// the serialized form; presumably it must stay stable — confirm before
/// reordering or inserting variants.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    /// A chunk of bytes written by the peer (serialized with `serde_bytes`).
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    /// End-of-stream marker, posted when the peer's write half is shut down
    /// or dropped without a prior shutdown.
    Eof,
}
// Registers `Io` with `wirevalue` (presumably required for sending it across
// ports — confirm against the wirevalue crate).
wirevalue::register_type!(Io);
83
/// Adapts a connection's receiving port into a `Stream` of byte chunks that
/// `StreamReader` can turn into an `AsyncRead`.
struct OwnedReadHalfStream {
    /// Port on which the peer's `Io` messages arrive.
    port: PortReceiver<Io>,
    /// Set once `Io::Eof` has been received; afterwards the stream yields `None`.
    exhausted: bool,
}
88
/// The reading half of an `ActorConnection`; implements `AsyncRead`.
pub struct OwnedReadHalf {
    /// Address of the actor on the other end of the connection.
    peer: ActorAddr,
    /// Reader over the stream of incoming data chunks.
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}
94
/// The writing half of an `ActorConnection`; implements `AsyncWrite`.
///
/// Each write is posted to the peer as an `Io::Data` message. Shutting the
/// half down, or dropping it without a prior shutdown, posts `Io::Eof`.
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: context::Actor> {
    /// Address of the actor this half writes to.
    peer: ActorAddr,
    /// Actor context used when posting messages.
    #[pin]
    caps: C,
    /// The peer's receiving port.
    #[pin]
    port: PortRef<Io>,
    /// Set by `poll_shutdown`; suppresses the `Eof` normally posted on drop.
    #[pin]
    shutdown: bool,
}
106
/// A bidirectional byte stream between two actors, built on hyperactor ports.
///
/// `AsyncRead` is delegated to `reader` and `AsyncWrite` to `writer`; use
/// `into_split` to obtain independently owned halves.
#[pin_project]
pub struct ActorConnection<C: context::Actor> {
    /// Incoming direction.
    #[pin]
    reader: OwnedReadHalf,
    /// Outgoing direction.
    #[pin]
    writer: OwnedWriteHalf<C>,
}
115
116impl<C: context::Actor> ActorConnection<C> {
117 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
118 (self.reader, self.writer)
119 }
120
121 pub fn peer(&self) -> &ActorAddr {
122 self.reader.peer()
123 }
124}
125
126impl OwnedReadHalf {
127 fn new(peer: ActorAddr, port: PortReceiver<Io>) -> Self {
128 Self {
129 peer,
130 inner: StreamReader::new(OwnedReadHalfStream {
131 port,
132 exhausted: false,
133 }),
134 }
135 }
136
137 pub fn peer(&self) -> &ActorAddr {
138 &self.peer
139 }
140
141 pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
142 ActorConnection {
143 reader: self,
144 writer: other,
145 }
146 }
147}
148
149impl<C: context::Actor> OwnedWriteHalf<C> {
150 fn new(peer: ActorAddr, caps: C, port: PortRef<Io>) -> Self {
151 Self {
152 peer,
153 caps,
154 port,
155 shutdown: false,
156 }
157 }
158
159 pub fn peer(&self) -> &ActorAddr {
160 &self.peer
161 }
162
163 pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
164 ActorConnection {
165 reader: other,
166 writer: self,
167 }
168 }
169}
170
171#[pinned_drop]
172impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
173 fn drop(self: Pin<&mut Self>) {
174 let this = self.project();
175 if !*this.shutdown {
176 let _ = this.port.post(&*this.caps, Io::Eof);
177 }
178 }
179}
180
181impl<C: context::Actor> AsyncRead for ActorConnection<C> {
182 fn poll_read(
183 self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 buf: &mut tokio::io::ReadBuf<'_>,
186 ) -> Poll<std::io::Result<()>> {
187 let this = self.project();
189 this.reader.poll_read(cx, buf)
190 }
191}
192
193impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
194 fn poll_write(
195 self: Pin<&mut Self>,
196 cx: &mut Context<'_>,
197 buf: &[u8],
198 ) -> Poll<Result<usize, std::io::Error>> {
199 let this = self.project();
201 this.writer.poll_write(cx, buf)
202 }
203
204 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
205 let this = self.project();
206 this.writer.poll_flush(cx)
207 }
208
209 fn poll_shutdown(
210 self: Pin<&mut Self>,
211 cx: &mut Context<'_>,
212 ) -> Poll<Result<(), std::io::Error>> {
213 let this = self.project();
214 this.writer.poll_shutdown(cx)
215 }
216}
217
218impl Stream for OwnedReadHalfStream {
219 type Item = std::io::Result<Cursor<Vec<u8>>>;
220
221 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 if self.exhausted {
224 return Poll::Ready(None);
225 }
226
227 let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
228 match result {
229 Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
230 Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
231 Ok(Io::Eof) => {
233 self.exhausted = true;
234 Poll::Ready(None)
235 }
236 }
237 }
238}
239
240impl FusedStream for OwnedReadHalfStream {
241 fn is_terminated(&self) -> bool {
242 self.exhausted
243 }
244}
245
246impl AsyncRead for OwnedReadHalf {
247 fn poll_read(
248 mut self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &mut tokio::io::ReadBuf<'_>,
251 ) -> Poll<std::io::Result<()>> {
252 Pin::new(&mut self.inner).poll_read(cx, buf)
253 }
254}
255
256impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
257 fn poll_write(
258 self: Pin<&mut Self>,
259 _cx: &mut Context<'_>,
260 buf: &[u8],
261 ) -> Poll<Result<usize, std::io::Error>> {
262 let this = self.project();
263 if *this.shutdown {
264 return Poll::Ready(Err(std::io::Error::new(
265 std::io::ErrorKind::BrokenPipe,
266 "write after shutdown",
267 )));
268 }
269 this.port.post(&*this.caps, Io::Data(buf.into()));
270 Poll::Ready(Ok(buf.len()))
271 }
272
273 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
274 Poll::Ready(Ok(()))
275 }
276
277 fn poll_shutdown(
278 self: Pin<&mut Self>,
279 _cx: &mut Context<'_>,
280 ) -> Poll<Result<(), std::io::Error>> {
281 self.port.post(&self.caps, Io::Eof);
283 let mut this = self.project();
284 *this.shutdown = true;
285 Poll::Ready(Ok(()))
286 }
287}
288
/// Connector-side state for finishing a handshake started by `Connect::allocate`.
pub struct ConnectionCompleter<C> {
    /// Actor context used for posting on the resulting connection.
    caps: C,
    /// Receives the data the acceptor writes to this side of the connection.
    conn: PortReceiver<Io>,
    /// Receives the acceptor's `Accept` reply.
    port: OncePortReceiver<Accept>,
}
295
296impl<C: context::Actor> ConnectionCompleter<C> {
297 pub async fn complete(self) -> Result<ActorConnection<C>> {
300 let accept = tokio::time::timeout(CONNECT_TIMEOUT, self.port.recv()).await??;
301 Ok(ActorConnection {
302 reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
303 writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
304 })
305 }
306}
307
/// Handshake message asking an actor to open an `ActorConnection` back to the sender.
///
/// Created by `Connect::allocate`; the receiving actor completes it with `accept`.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// Address of the connecting side; the acceptor records it as its peer.
    id: ActorAddr,
    /// Port the acceptor writes connection data to.
    conn: PortRef<Io>,
    /// One-shot port for the acceptor's `Accept` reply.
    return_conn: OncePortRef<Accept>,
}
// Registers `Connect` with `wirevalue` (presumably required for sending it as
// an actor message — confirm against the wirevalue crate).
wirevalue::register_type!(Connect);
318
319impl Connect {
320 pub fn allocate<C: context::Actor>(id: ActorAddr, caps: C) -> (Self, ConnectionCompleter<C>) {
323 let (conn_tx, conn_rx) = open_port::<Io>(&caps);
324 let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
325 (
326 Self {
327 id,
328 conn: conn_tx.bind(),
329 return_conn: return_tx.bind(),
330 },
331 ConnectionCompleter {
332 caps,
333 conn: conn_rx,
334 port: return_rx,
335 },
336 )
337 }
338}
339
/// The acceptor's reply to a `Connect`, carrying its own address and data port.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// Address of the accepting actor; the connector records it as its peer.
    id: ActorAddr,
    /// Port the connector writes connection data to.
    conn: PortRef<Io>,
}
// Registers `Accept` with `wirevalue` (presumably required for sending it over
// a port — confirm against the wirevalue crate).
wirevalue::register_type!(Accept);
350
351impl Bind for Connect {
352 fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
353 self.conn.bind(bindings)?;
354 self.return_conn.bind(bindings)
355 }
356}
357
358impl Unbind for Connect {
359 fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
360 self.conn.unbind(bindings)?;
361 self.return_conn.unbind(bindings)
362 }
363}
364
365pub async fn accept<C: context::Actor>(
368 caps: C,
369 self_id: ActorAddr,
370 message: Connect,
371) -> Result<ActorConnection<C>> {
372 let (tx, rx) = open_port::<Io>(&caps);
373 message.return_conn.post(
374 &caps,
375 Accept {
376 id: self_id,
377 conn: tx.bind(),
378 },
379 );
380 Ok(ActorConnection {
381 reader: OwnedReadHalf::new(message.id.clone(), rx),
382 writer: OwnedWriteHalf::new(message.id, caps, message.conn),
383 })
384}
385
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts each incoming `Connect` and echoes back every
    /// byte it reads.
    #[derive(Debug, Default)]
    struct EchoActor {}

    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        // Copies the read half into the write half until the connector's EOF,
        // then shuts down its own write half so the connector also sees EOF.
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_addr().clone(), message)
                .await?
                .into_split();
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    // Round trip through `EchoActor`: the bytes the client writes come back
    // unchanged, and the client's shutdown propagates to the echo's shutdown.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client")?;
        let (connect, completer) = Connect::allocate(client.self_addr().clone(), client);
        let actor = proc.spawn("actor", EchoActor {})?;
        actor.post(&completer.caps, connect);
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        // Write and read concurrently: the echo only finishes after our shutdown.
        try_join!(
            async move {
                wr.write_all(&send).await?;
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    // Dropping a write half without shutting it down still delivers EOF to the
    // reader. Both ends of the connection live on the same client here.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _client_handle) = proc.client("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_addr().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_addr().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // No explicit shutdown: the drop handler must post `Io::Eof`.
        drop(wr);

        // `read_to_end` returns only because EOF arrived.
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    // After an explicit shutdown, dropping the write half must not post a
    // second `Io::Eof`.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _client_handle) = proc.client("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_addr().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_addr().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Explicit shutdown posts the one and only EOF.
        wr.shutdown().await?;

        // Consume everything up to and including that EOF.
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        drop(wr);

        // The port must now be empty: the drop posted nothing.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}