1use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorId;
51use hyperactor::OncePortRef;
52use hyperactor::PortRef;
53use hyperactor::clock::Clock;
54use hyperactor::clock::RealClock;
55use hyperactor::context;
56use hyperactor::mailbox::OncePortReceiver;
57use hyperactor::mailbox::PortReceiver;
58use hyperactor::mailbox::open_once_port;
59use hyperactor::mailbox::open_port;
60use hyperactor::message::Bind;
61use hyperactor::message::Bindings;
62use hyperactor::message::Unbind;
63use pin_project::pin_project;
64use pin_project::pinned_drop;
65use serde::Deserialize;
66use serde::Serialize;
67use tokio::io::AsyncRead;
68use tokio::io::AsyncWrite;
69use tokio_util::io::StreamReader;
70use typeuri::Named;
71
/// Upper bound on how long [`ConnectionCompleter::complete`] waits for the
/// peer's `Accept` reply to a [`Connect`] request.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
74
75#[derive(Debug, Serialize, Deserialize, Named, Clone)]
77enum Io {
78 Data(#[serde(with = "serde_bytes")] Vec<u8>),
80 Eof,
82}
83wirevalue::register_type!(Io);
84
85struct OwnedReadHalfStream {
86 port: PortReceiver<Io>,
87 exhausted: bool,
88}
89
90pub struct OwnedReadHalf {
92 peer: ActorId,
93 inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
94}
95
96#[pin_project(PinnedDrop)]
98pub struct OwnedWriteHalf<C: context::Actor> {
99 peer: ActorId,
100 #[pin]
101 caps: C,
102 #[pin]
103 port: PortRef<Io>,
104 #[pin]
105 shutdown: bool,
106}
107
108#[pin_project]
110pub struct ActorConnection<C: context::Actor> {
111 #[pin]
112 reader: OwnedReadHalf,
113 #[pin]
114 writer: OwnedWriteHalf<C>,
115}
116
117impl<C: context::Actor> ActorConnection<C> {
118 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
119 (self.reader, self.writer)
120 }
121
122 pub fn peer(&self) -> &ActorId {
123 self.reader.peer()
124 }
125}
126
127impl OwnedReadHalf {
128 fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
129 Self {
130 peer,
131 inner: StreamReader::new(OwnedReadHalfStream {
132 port,
133 exhausted: false,
134 }),
135 }
136 }
137
138 pub fn peer(&self) -> &ActorId {
139 &self.peer
140 }
141
142 pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
143 ActorConnection {
144 reader: self,
145 writer: other,
146 }
147 }
148}
149
150impl<C: context::Actor> OwnedWriteHalf<C> {
151 fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
152 Self {
153 peer,
154 caps,
155 port,
156 shutdown: false,
157 }
158 }
159
160 pub fn peer(&self) -> &ActorId {
161 &self.peer
162 }
163
164 pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
165 ActorConnection {
166 reader: other,
167 writer: self,
168 }
169 }
170}
171
172#[pinned_drop]
173impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
174 fn drop(self: Pin<&mut Self>) {
175 let this = self.project();
176 if !*this.shutdown {
177 let _ = this.port.send(&*this.caps, Io::Eof);
178 }
179 }
180}
181
182impl<C: context::Actor> AsyncRead for ActorConnection<C> {
183 fn poll_read(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &mut tokio::io::ReadBuf<'_>,
187 ) -> Poll<std::io::Result<()>> {
188 let this = self.project();
190 this.reader.poll_read(cx, buf)
191 }
192}
193
194impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
195 fn poll_write(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 buf: &[u8],
199 ) -> Poll<Result<usize, std::io::Error>> {
200 let this = self.project();
202 this.writer.poll_write(cx, buf)
203 }
204
205 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
206 let this = self.project();
207 this.writer.poll_flush(cx)
208 }
209
210 fn poll_shutdown(
211 self: Pin<&mut Self>,
212 cx: &mut Context<'_>,
213 ) -> Poll<Result<(), std::io::Error>> {
214 let this = self.project();
215 this.writer.poll_shutdown(cx)
216 }
217}
218
219impl Stream for OwnedReadHalfStream {
220 type Item = std::io::Result<Cursor<Vec<u8>>>;
221
222 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223 if self.exhausted {
225 return Poll::Ready(None);
226 }
227
228 let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
229 match result {
230 Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
231 Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
232 Ok(Io::Eof) => {
234 self.exhausted = true;
235 Poll::Ready(None)
236 }
237 }
238 }
239}
240
241impl FusedStream for OwnedReadHalfStream {
242 fn is_terminated(&self) -> bool {
243 self.exhausted
244 }
245}
246
247impl AsyncRead for OwnedReadHalf {
248 fn poll_read(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 buf: &mut tokio::io::ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 Pin::new(&mut self.inner).poll_read(cx, buf)
254 }
255}
256
257impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
258 fn poll_write(
259 self: Pin<&mut Self>,
260 _cx: &mut Context<'_>,
261 buf: &[u8],
262 ) -> Poll<Result<usize, std::io::Error>> {
263 let this = self.project();
264 if *this.shutdown {
265 return Poll::Ready(Err(std::io::Error::new(
266 std::io::ErrorKind::BrokenPipe,
267 "write after shutdown",
268 )));
269 }
270 match this.port.send(&*this.caps, Io::Data(buf.into())) {
271 Ok(()) => Poll::Ready(Ok(buf.len())),
272 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
273 }
274 }
275
276 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
277 Poll::Ready(Ok(()))
278 }
279
280 fn poll_shutdown(
281 self: Pin<&mut Self>,
282 _cx: &mut Context<'_>,
283 ) -> Poll<Result<(), std::io::Error>> {
284 match self.port.send(&self.caps, Io::Eof) {
286 Ok(()) => {
287 let mut this = self.project();
288 *this.shutdown = true;
289 Poll::Ready(Ok(()))
290 }
291 Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
292 }
293 }
294}
295
296pub struct ConnectionCompleter<C> {
298 caps: C,
299 conn: PortReceiver<Io>,
300 port: OncePortReceiver<Accept>,
301}
302
303impl<C: context::Actor> ConnectionCompleter<C> {
304 pub async fn complete(self) -> Result<ActorConnection<C>> {
307 let accept = RealClock
308 .timeout(CONNECT_TIMEOUT, self.port.recv())
309 .await??;
310 Ok(ActorConnection {
311 reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
312 writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
313 })
314 }
315}
316
317#[derive(Debug, Serialize, Deserialize, Named, Clone)]
319pub struct Connect {
320 id: ActorId,
322 conn: PortRef<Io>,
323 return_conn: OncePortRef<Accept>,
325}
326wirevalue::register_type!(Connect);
327
328impl Connect {
329 pub fn allocate<C: context::Actor>(id: ActorId, caps: C) -> (Self, ConnectionCompleter<C>) {
332 let (conn_tx, conn_rx) = open_port::<Io>(&caps);
333 let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
334 (
335 Self {
336 id,
337 conn: conn_tx.bind(),
338 return_conn: return_tx.bind(),
339 },
340 ConnectionCompleter {
341 caps,
342 conn: conn_rx,
343 port: return_rx,
344 },
345 )
346 }
347}
348
349#[derive(Debug, Serialize, Deserialize, Named, Clone)]
352struct Accept {
353 id: ActorId,
355 conn: PortRef<Io>,
357}
358wirevalue::register_type!(Accept);
359
360impl Bind for Connect {
361 fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
362 self.conn.bind(bindings)?;
363 self.return_conn.bind(bindings)
364 }
365}
366
367impl Unbind for Connect {
368 fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
369 self.conn.unbind(bindings)?;
370 self.return_conn.unbind(bindings)
371 }
372}
373
374pub async fn accept<C: context::Actor>(
377 caps: C,
378 self_id: ActorId,
379 message: Connect,
380) -> Result<ActorConnection<C>> {
381 let (tx, rx) = open_port::<Io>(&caps);
382 message.return_conn.send(
383 &caps,
384 Accept {
385 id: self_id,
386 conn: tx.bind(),
387 },
388 )?;
389 Ok(ActorConnection {
390 reader: OwnedReadHalf::new(message.id.clone(), rx),
391 writer: OwnedWriteHalf::new(message.id, caps, message.conn),
392 })
393}
394
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts each incoming connection and echoes every byte
    /// back until it reads end-of-stream.
    #[derive(Debug, Default)]
    struct EchoActor {}

    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let connection = accept(cx, cx.self_id().clone(), message).await?;
            let (mut from_client, mut to_client) = connection.into_split();
            // Echo until the client's Eof, then send our own Eof back.
            tokio::io::copy(&mut from_client, &mut to_client).await?;
            to_client.shutdown().await?;
            Ok(())
        }
    }

    /// Round-trips a few bytes through an `EchoActor`.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let (client, _) = proc.instance("client")?;
        let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
        let actor = proc.spawn("actor", EchoActor {})?;
        actor.send(&completer.caps, connect)?;
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let payload = [3u8, 4u8, 5u8, 6u8];

        let write_side = async move {
            wr.write_all(&payload).await?;
            wr.shutdown().await?;
            anyhow::Ok(())
        };
        let read_side = async {
            let mut echoed = Vec::new();
            rd.read_to_end(&mut echoed).await?;
            assert_eq!(&payload, echoed.as_slice());
            anyhow::Ok(())
        };
        try_join!(write_side, read_side)?;
        Ok(())
    }

    /// Dropping a write half without shutting it down must still deliver
    /// end-of-stream to the peer's reader.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        // The `_` patterns drop the unused halves immediately.
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let payload = [1u8, 2u8, 3u8];
        wr.write_all(&payload).await?;

        // No explicit shutdown: the Eof has to come from the drop.
        drop(wr);

        let mut received = Vec::new();
        rd.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        Ok(())
    }

    /// After an explicit shutdown, dropping the write half must not send a
    /// second Eof, so nothing is left queued on the reader's port.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        // The `_` patterns drop the unused halves immediately.
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        let payload = [1u8, 2u8, 3u8];
        wr.write_all(&payload).await?;
        wr.shutdown().await?;

        let mut received = Vec::new();
        rd.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        drop(wr);

        // Reach past the `StreamReader` to the raw port: it must be empty.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}