1use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorId;
51use hyperactor::Named;
52use hyperactor::OncePortRef;
53use hyperactor::PortRef;
54use hyperactor::clock::Clock;
55use hyperactor::clock::RealClock;
56use hyperactor::context;
57use hyperactor::mailbox::OncePortReceiver;
58use hyperactor::mailbox::PortReceiver;
59use hyperactor::mailbox::open_once_port;
60use hyperactor::mailbox::open_port;
61use hyperactor::message::Bind;
62use hyperactor::message::Bindings;
63use hyperactor::message::Unbind;
64use pin_project::pin_project;
65use pin_project::pinned_drop;
66use serde::Deserialize;
67use serde::Serialize;
68use tokio::io::AsyncRead;
69use tokio::io::AsyncWrite;
70use tokio_util::io::StreamReader;
71
/// How long [`ConnectionCompleter::complete`] waits for the peer's reply (five seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
74
/// A single message on the byte stream carried between the two halves of a connection.
///
/// NOTE: variant order is part of the serialized form; do not reorder.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    /// One chunk of bytes produced by a single `poll_write` on the sending half.
    /// `serde_bytes` encodes the payload as a byte string rather than a sequence of
    /// individual integers.
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    /// The sender shut down (or dropped) its write half; the receiving stream treats this
    /// as end-of-stream.
    Eof,
}
83
/// Adapts the `Io` port receiver into a `Stream` of byte chunks that `StreamReader` can
/// turn into an `AsyncRead`.
struct OwnedReadHalfStream {
    /// Receives the `Io` messages sent by the peer's write half.
    port: PortReceiver<Io>,
    /// Set once `Io::Eof` has been received; from then on the stream yields `None`.
    exhausted: bool,
}
88
/// The read half of an [`ActorConnection`], implementing [`AsyncRead`].
///
/// Obtained from [`ActorConnection::into_split`].
pub struct OwnedReadHalf {
    /// The actor on the other end of the connection.
    peer: ActorId,
    /// Byte-level reader over the incoming `Io::Data` chunks.
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}
94
/// The write half of an [`ActorConnection`], implementing [`AsyncWrite`].
///
/// Each write is sent to the peer as one `Io::Data` message. Dropping this half without
/// a successful `shutdown` sends a best-effort `Io::Eof` so the peer observes
/// end-of-stream.
///
/// NOTE(review): none of these fields appear to need structural pinning (they are only
/// read or assigned through the projection); `#[pin]` could likely be removed.
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: context::Actor> {
    /// The actor on the other end of the connection.
    peer: ActorId,
    /// Actor context passed to `PortRef::send` for every message.
    #[pin]
    caps: C,
    /// The peer's port that receives this half's `Io` messages.
    #[pin]
    port: PortRef<Io>,
    /// Set once `poll_shutdown` has successfully sent `Io::Eof`.
    #[pin]
    shutdown: bool,
}
106
/// A bidirectional byte stream between two actors, carried over actor ports.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by delegating to its two halves. Use
/// [`ActorConnection::into_split`] to drive the halves independently.
#[pin_project]
pub struct ActorConnection<C: context::Actor> {
    /// Receives bytes written by the peer.
    #[pin]
    reader: OwnedReadHalf,
    /// Sends bytes to the peer.
    #[pin]
    writer: OwnedWriteHalf<C>,
}
115
116impl<C: context::Actor> ActorConnection<C> {
117 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
118 (self.reader, self.writer)
119 }
120
121 pub fn peer(&self) -> &ActorId {
122 self.reader.peer()
123 }
124}
125
126impl OwnedReadHalf {
127 fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
128 Self {
129 peer,
130 inner: StreamReader::new(OwnedReadHalfStream {
131 port,
132 exhausted: false,
133 }),
134 }
135 }
136
137 pub fn peer(&self) -> &ActorId {
138 &self.peer
139 }
140
141 pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
142 ActorConnection {
143 reader: self,
144 writer: other,
145 }
146 }
147}
148
149impl<C: context::Actor> OwnedWriteHalf<C> {
150 fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
151 Self {
152 peer,
153 caps,
154 port,
155 shutdown: false,
156 }
157 }
158
159 pub fn peer(&self) -> &ActorId {
160 &self.peer
161 }
162
163 pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
164 ActorConnection {
165 reader: other,
166 writer: self,
167 }
168 }
169}
170
171#[pinned_drop]
172impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
173 fn drop(self: Pin<&mut Self>) {
174 let this = self.project();
175 if !*this.shutdown {
176 let _ = this.port.send(&*this.caps, Io::Eof);
177 }
178 }
179}
180
181impl<C: context::Actor> AsyncRead for ActorConnection<C> {
182 fn poll_read(
183 self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 buf: &mut tokio::io::ReadBuf<'_>,
186 ) -> Poll<std::io::Result<()>> {
187 let this = self.project();
189 this.reader.poll_read(cx, buf)
190 }
191}
192
193impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
194 fn poll_write(
195 self: Pin<&mut Self>,
196 cx: &mut Context<'_>,
197 buf: &[u8],
198 ) -> Poll<Result<usize, std::io::Error>> {
199 let this = self.project();
201 this.writer.poll_write(cx, buf)
202 }
203
204 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
205 let this = self.project();
206 this.writer.poll_flush(cx)
207 }
208
209 fn poll_shutdown(
210 self: Pin<&mut Self>,
211 cx: &mut Context<'_>,
212 ) -> Poll<Result<(), std::io::Error>> {
213 let this = self.project();
214 this.writer.poll_shutdown(cx)
215 }
216}
217
218impl Stream for OwnedReadHalfStream {
219 type Item = std::io::Result<Cursor<Vec<u8>>>;
220
221 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 if self.exhausted {
224 return Poll::Ready(None);
225 }
226
227 let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
228 match result {
229 Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
230 Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
231 Ok(Io::Eof) => {
233 self.exhausted = true;
234 Poll::Ready(None)
235 }
236 }
237 }
238}
239
240impl FusedStream for OwnedReadHalfStream {
241 fn is_terminated(&self) -> bool {
242 self.exhausted
243 }
244}
245
246impl AsyncRead for OwnedReadHalf {
247 fn poll_read(
248 mut self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &mut tokio::io::ReadBuf<'_>,
251 ) -> Poll<std::io::Result<()>> {
252 Pin::new(&mut self.inner).poll_read(cx, buf)
253 }
254}
255
256impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
257 fn poll_write(
258 self: Pin<&mut Self>,
259 _cx: &mut Context<'_>,
260 buf: &[u8],
261 ) -> Poll<Result<usize, std::io::Error>> {
262 let this = self.project();
263 if *this.shutdown {
264 return Poll::Ready(Err(std::io::Error::new(
265 std::io::ErrorKind::BrokenPipe,
266 "write after shutdown",
267 )));
268 }
269 match this.port.send(&*this.caps, Io::Data(buf.into())) {
270 Ok(()) => Poll::Ready(Ok(buf.len())),
271 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
272 }
273 }
274
275 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
276 Poll::Ready(Ok(()))
277 }
278
279 fn poll_shutdown(
280 self: Pin<&mut Self>,
281 _cx: &mut Context<'_>,
282 ) -> Poll<Result<(), std::io::Error>> {
283 match self.port.send(&self.caps, Io::Eof) {
285 Ok(()) => {
286 let mut this = self.project();
287 *this.shutdown = true;
288 Poll::Ready(Ok(()))
289 }
290 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
291 }
292 }
293}
294
/// The initiator's side of a pending connection, returned by [`Connect::allocate`].
///
/// Deliver the paired [`Connect`] message to the target actor, then call
/// [`ConnectionCompleter::complete`] to wait for its reply.
pub struct ConnectionCompleter<C> {
    /// Actor context the ports below were opened with; becomes the write half's context.
    caps: C,
    /// Receives the peer's `Io` messages; becomes the initiator's read half.
    conn: PortReceiver<Io>,
    /// Receives the peer's `Accept` reply.
    port: OncePortReceiver<Accept>,
}
302impl<C: context::Actor> ConnectionCompleter<C> {
303 pub async fn complete(self) -> Result<ActorConnection<C>> {
306 let accept = RealClock
307 .timeout(CONNECT_TIMEOUT, self.port.recv())
308 .await??;
309 Ok(ActorConnection {
310 reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
311 writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
312 })
313 }
314}
315
/// A request to open an [`ActorConnection`] to the receiving actor.
///
/// Created by [`Connect::allocate`]; the receiver answers it with [`accept`].
///
/// NOTE: field order is part of the serialized form; do not reorder.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// Id the acceptor records as the connection's peer (callers pass their own id).
    id: ActorId,
    /// The initiator's port that receives the acceptor's `Io` messages.
    conn: PortRef<Io>,
    /// Port on which the acceptor sends its `Accept` reply.
    return_conn: OncePortRef<Accept>,
}
325
326impl Connect {
327 pub fn allocate<C: context::Actor>(id: ActorId, caps: C) -> (Self, ConnectionCompleter<C>) {
330 let (conn_tx, conn_rx) = open_port::<Io>(&caps);
331 let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
332 (
333 Self {
334 id,
335 conn: conn_tx.bind(),
336 return_conn: return_tx.bind(),
337 },
338 ConnectionCompleter {
339 caps,
340 conn: conn_rx,
341 port: return_rx,
342 },
343 )
344 }
345}
346
/// The acceptor's reply to a [`Connect`] request.
///
/// NOTE: field order is part of the serialized form; do not reorder.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// The accepting actor, recorded by the initiator as the connection's peer.
    id: ActorId,
    /// The acceptor's port that receives the initiator's `Io` messages.
    conn: PortRef<Io>,
}
356
357impl Bind for Connect {
358 fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
359 self.conn.bind(bindings)?;
360 self.return_conn.bind(bindings)
361 }
362}
363
364impl Unbind for Connect {
365 fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
366 self.conn.unbind(bindings)?;
367 self.return_conn.unbind(bindings)
368 }
369}
370
371pub async fn accept<C: context::Actor>(
374 caps: C,
375 self_id: ActorId,
376 message: Connect,
377) -> Result<ActorConnection<C>> {
378 let (tx, rx) = open_port::<Io>(&caps);
379 message.return_conn.send(
380 &caps,
381 Accept {
382 id: self_id,
383 conn: tx.bind(),
384 },
385 )?;
386 Ok(ActorConnection {
387 reader: OwnedReadHalf::new(message.id.clone(), rx),
388 writer: OwnedWriteHalf::new(message.id, caps, message.conn),
389 })
390}
391
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    // Shadows the glob-imported `futures::task::Context` from `super::*` within this module.
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts a connection and echoes every byte back to the initiator.
    #[derive(Debug, Default, Actor)]
    struct EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        // Copies the incoming stream back until the initiator's write half signals EOF,
        // then shuts down its own write half. The handler does not return until then.
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_id().clone(), message)
                .await?
                .into_split();
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    /// Round-trips bytes through an `EchoActor`: write and shutdown on one task, read to
    /// EOF on another, and check the echoed bytes match.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;
        let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
        let actor = proc.spawn::<EchoActor>("actor", ()).await?;
        actor.send(connect)?;
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        try_join!(
            async move {
                wr.write_all(&send).await?;
                // Shutting down ends the echo actor's `copy`, which in turn ends our read.
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    /// Dropping a write half without calling `shutdown` must still deliver EOF, so the
    /// peer's `read_to_end` terminates.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        // Both ends live in the same instance; only the acceptor's reader and the
        // initiator's writer are kept, and the other halves are discarded.
        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Drop (rather than shutdown) should send `Io::Eof` via `PinnedDrop`.
        drop(wr);

        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    /// After an explicit `shutdown`, dropping the write half must not send a second
    /// `Io::Eof`, so nothing is left queued on the reader's port.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        wr.shutdown().await?;

        // Reading to EOF consumes the single `Io::Eof` sent by `shutdown`.
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        drop(wr);

        // Inspect the underlying port directly: no extra message may have arrived.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}