hyperactor_mesh/
connect.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Actor-based duplex bytestream connections.
10//!
11//! This module provides the equivalent of a `TcpStream` duplex bytestream connection between two actors,
12//! implemented via actor message passing. It allows actors to communicate using familiar `AsyncRead` and
13//! `AsyncWrite` interfaces while leveraging the hyperactor framework's message passing capabilities.
14//!
15//! # Overview
16//!
17//! The connection system consists of:
18//! - [`ActorConnection`]: A duplex connection that implements both `AsyncRead` and `AsyncWrite`
19//! - [`OwnedReadHalf`] and [`OwnedWriteHalf`]: Split halves for independent reading and writing
20//! - [`Connect`] message for establishing connections
//! - [`Connect::allocate`] and [`accept`] helpers for client and server usage
22//!
23//! # Usage Patterns
24//!
25//! ## Client Side (Initiating Connection)
26//!
27//! Clients use `Connect::allocate()` to create a connection request. This method returns:
28//! 1. A `Connect` message to send to the server to initiate the connection
29//! 2. A `ConnectionCompleter` object that can be awaited for the server to finish connecting,
30//!    returning the `ActorConnection` used by the client.
31//!
32//! The typical pattern is: allocate components, send Connect message to server, await completion.
33//!
34//! ## Server Side (Accepting Connections)
35//!
36//! Servers forward `Connect` messages to the `accept()` helper function to finish setting up the
37//! connection, which returns the `ActorConnection` they can use.
38
39use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorId;
51use hyperactor::Named;
52use hyperactor::OncePortRef;
53use hyperactor::PortRef;
54use hyperactor::clock::Clock;
55use hyperactor::clock::RealClock;
56use hyperactor::context;
57use hyperactor::mailbox::OncePortReceiver;
58use hyperactor::mailbox::PortReceiver;
59use hyperactor::mailbox::open_once_port;
60use hyperactor::mailbox::open_port;
61use hyperactor::message::Bind;
62use hyperactor::message::Bindings;
63use hyperactor::message::Unbind;
64use pin_project::pin_project;
65use pin_project::pinned_drop;
66use serde::Deserialize;
67use serde::Serialize;
68use tokio::io::AsyncRead;
69use tokio::io::AsyncWrite;
70use tokio_util::io::StreamReader;
71
// Upper bound on how long connection establishment may take. In this file it is
// applied by `ConnectionCompleter::complete` (the client side) while waiting for
// the server's `Accept` reply.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
74
/// Messages sent over the "connection" to facilitate communication.
///
/// Each direction of a connection is a port carrying a sequence of these
/// messages: zero or more `Data` chunks, terminated by `Eof`.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    /// A data packet: one chunk of the byte stream.
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    /// Signal the end of one side of the connection.
    Eof,
}
83
/// Adapts a `PortReceiver<Io>` into a `Stream` of byte chunks.
///
/// Once an `Io::Eof` message has been received the stream is marked exhausted
/// and only ever yields `None`.
struct OwnedReadHalfStream {
    // Receiving end of the port the peer sends `Io` messages to.
    port: PortReceiver<Io>,
    // True once `Io::Eof` has been observed on `port`.
    exhausted: bool,
}
88
/// Wrap a `PortReceiver<Io>` as an `AsyncRead`.
pub struct OwnedReadHalf {
    // The actor on the other end of the connection.
    peer: ActorId,
    // Presents the chunks yielded by `OwnedReadHalfStream` as one byte stream.
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}
94
/// Wrap a `PortRef<Io>` as an `AsyncWrite`.
///
/// Dropping the half without an explicit shutdown sends `Io::Eof` to the peer
/// (see the `PinnedDrop` impl).
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: context::Actor> {
    // The actor on the other end of the connection.
    peer: ActorId,
    // Context passed to `port.send` for every outgoing message.
    #[pin]
    caps: C,
    // Port that receives this side's `Io` messages.
    #[pin]
    port: PortRef<Io>,
    // Set once `poll_shutdown` has sent EOF; suppresses the EOF on drop.
    #[pin]
    shutdown: bool,
}
106
/// A duplex bytestream connection between two actors.  Can generally be used like a `TcpStream`.
///
/// Reads are delegated to `reader` and writes to `writer`; use `into_split` to
/// obtain the two halves separately.
#[pin_project]
pub struct ActorConnection<C: context::Actor> {
    // Incoming half: bytes sent by the peer.
    #[pin]
    reader: OwnedReadHalf,
    // Outgoing half: bytes sent to the peer.
    #[pin]
    writer: OwnedWriteHalf<C>,
}
115
116impl<C: context::Actor> ActorConnection<C> {
117    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
118        (self.reader, self.writer)
119    }
120
121    pub fn peer(&self) -> &ActorId {
122        self.reader.peer()
123    }
124}
125
126impl OwnedReadHalf {
127    fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
128        Self {
129            peer,
130            inner: StreamReader::new(OwnedReadHalfStream {
131                port,
132                exhausted: false,
133            }),
134        }
135    }
136
137    pub fn peer(&self) -> &ActorId {
138        &self.peer
139    }
140
141    pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
142        ActorConnection {
143            reader: self,
144            writer: other,
145        }
146    }
147}
148
149impl<C: context::Actor> OwnedWriteHalf<C> {
150    fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
151        Self {
152            peer,
153            caps,
154            port,
155            shutdown: false,
156        }
157    }
158
159    pub fn peer(&self) -> &ActorId {
160        &self.peer
161    }
162
163    pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
164        ActorConnection {
165            reader: other,
166            writer: self,
167        }
168    }
169}
170
171#[pinned_drop]
172impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
173    fn drop(self: Pin<&mut Self>) {
174        let this = self.project();
175        if !*this.shutdown {
176            let _ = this.port.send(&*this.caps, Io::Eof);
177        }
178    }
179}
180
181impl<C: context::Actor> AsyncRead for ActorConnection<C> {
182    fn poll_read(
183        self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185        buf: &mut tokio::io::ReadBuf<'_>,
186    ) -> Poll<std::io::Result<()>> {
187        // Use project() to get pinned references to fields
188        let this = self.project();
189        this.reader.poll_read(cx, buf)
190    }
191}
192
193impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
194    fn poll_write(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        buf: &[u8],
198    ) -> Poll<Result<usize, std::io::Error>> {
199        // Use project() to get pinned references to fields
200        let this = self.project();
201        this.writer.poll_write(cx, buf)
202    }
203
204    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
205        let this = self.project();
206        this.writer.poll_flush(cx)
207    }
208
209    fn poll_shutdown(
210        self: Pin<&mut Self>,
211        cx: &mut Context<'_>,
212    ) -> Poll<Result<(), std::io::Error>> {
213        let this = self.project();
214        this.writer.poll_shutdown(cx)
215    }
216}
217
218impl Stream for OwnedReadHalfStream {
219    type Item = std::io::Result<Cursor<Vec<u8>>>;
220
221    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        // Once exhausted, always return None
223        if self.exhausted {
224            return Poll::Ready(None);
225        }
226
227        let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
228        match result {
229            Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
230            Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
231            // Break out of stream when we see EOF.
232            Ok(Io::Eof) => {
233                self.exhausted = true;
234                Poll::Ready(None)
235            }
236        }
237    }
238}
239
impl FusedStream for OwnedReadHalfStream {
    /// Reports whether the stream has ended, i.e. whether `Io::Eof` has been
    /// received and `exhausted` was set by `poll_next`.
    fn is_terminated(&self) -> bool {
        self.exhausted
    }
}
245
246impl AsyncRead for OwnedReadHalf {
247    fn poll_read(
248        mut self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &mut tokio::io::ReadBuf<'_>,
251    ) -> Poll<std::io::Result<()>> {
252        Pin::new(&mut self.inner).poll_read(cx, buf)
253    }
254}
255
256impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
257    fn poll_write(
258        self: Pin<&mut Self>,
259        _cx: &mut Context<'_>,
260        buf: &[u8],
261    ) -> Poll<Result<usize, std::io::Error>> {
262        let this = self.project();
263        if *this.shutdown {
264            return Poll::Ready(Err(std::io::Error::new(
265                std::io::ErrorKind::BrokenPipe,
266                "write after shutdown",
267            )));
268        }
269        match this.port.send(&*this.caps, Io::Data(buf.into())) {
270            Ok(()) => Poll::Ready(Ok(buf.len())),
271            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
272        }
273    }
274
275    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
276        Poll::Ready(Ok(()))
277    }
278
279    fn poll_shutdown(
280        self: Pin<&mut Self>,
281        _cx: &mut Context<'_>,
282    ) -> Poll<Result<(), std::io::Error>> {
283        // Send EOF on shutdown.
284        match self.port.send(&self.caps, Io::Eof) {
285            Ok(()) => {
286                let mut this = self.project();
287                *this.shutdown = true;
288                Poll::Ready(Ok(()))
289            }
290            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
291        }
292    }
293}
294
/// A helper struct that contains the state needed to complete a connection.
///
/// Created by `Connect::allocate` and consumed by `complete`.
pub struct ConnectionCompleter<C> {
    // Context handed to the connection's write half once it is established.
    caps: C,
    // Receiving end of the data port advertised in the `Connect` message.
    conn: PortReceiver<Io>,
    // One-shot port on which the `Accept` reply arrives.
    port: OncePortReceiver<Accept>,
}
301
302impl<C: context::Actor> ConnectionCompleter<C> {
303    /// Wait for the server to accept the connection and return the streams that can be used to communicate
304    /// with the server.
305    pub async fn complete(self) -> Result<ActorConnection<C>> {
306        let accept = RealClock
307            .timeout(CONNECT_TIMEOUT, self.port.recv())
308            .await??;
309        Ok(ActorConnection {
310            reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
311            writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
312        })
313    }
314}
315
/// A message sent from a client to initiate a connection.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// The ID of the client initiating the connection.
    id: ActorId,
    /// The port the server will use to send data over the connection to the client.
    conn: PortRef<Io>,
    /// The port the server can use to complete the connection.
    return_conn: OncePortRef<Accept>,
}
325
326impl Connect {
327    /// Allocate a new `Connect` message and return the associated `ConnectionCompleter` that can be used
328    /// to finish setting up the connection.
329    pub fn allocate<C: context::Actor>(id: ActorId, caps: C) -> (Self, ConnectionCompleter<C>) {
330        let (conn_tx, conn_rx) = open_port::<Io>(&caps);
331        let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
332        (
333            Self {
334                id,
335                conn: conn_tx.bind(),
336                return_conn: return_tx.bind(),
337            },
338            ConnectionCompleter {
339                caps,
340                conn: conn_rx,
341                port: return_rx,
342            },
343        )
344    }
345}
346
/// A response message sent from the server back to the client to complete setting
/// up the connection.
///
/// Delivered on the `return_conn` port carried by the `Connect` message.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// The ID of the server that accepted the connection.
    id: ActorId,
    /// The port the client will use to send data over the connection to the server.
    conn: PortRef<Io>,
}
356
357impl Bind for Connect {
358    fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
359        self.conn.bind(bindings)?;
360        self.return_conn.bind(bindings)
361    }
362}
363
364impl Unbind for Connect {
365    fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
366        self.conn.unbind(bindings)?;
367        self.return_conn.unbind(bindings)
368    }
369}
370
371/// Helper used by `Handler<Connect>`s to accept a connection initiated by a `Connect` message and
372/// return `AsyncRead` and `AsyncWrite` streams that can be used to communicate with the other side.
373pub async fn accept<C: context::Actor>(
374    caps: C,
375    self_id: ActorId,
376    message: Connect,
377) -> Result<ActorConnection<C>> {
378    let (tx, rx) = open_port::<Io>(&caps);
379    message.return_conn.send(
380        &caps,
381        Accept {
382            id: self_id,
383            conn: tx.bind(),
384        },
385    )?;
386    Ok(ActorConnection {
387        reader: OwnedReadHalf::new(message.id.clone(), rx),
388        writer: OwnedWriteHalf::new(message.id, caps, message.conn),
389    })
390}
391
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    // Test actor that accepts a connection and echoes every byte it reads back
    // to the client.
    #[derive(Debug, Default, Actor)]
    struct EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_id().clone(), message)
                .await?
                .into_split();
            // Copy until the client signals EOF, then close our write side too.
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    // Round trip through `EchoActor`: bytes written by the client come back
    // unchanged, and the stream ends after the echo side shuts down.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;
        let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
        let actor = proc.spawn::<EchoActor>("actor", ()).await?;
        actor.send(connect)?;
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        // Write and read concurrently.
        try_join!(
            async move {
                wr.write_all(&send).await?;
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    // Dropping a write half without calling shutdown must still deliver EOF to
    // the reader on the other side.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        // Both ends of the connection are driven by the same client instance.
        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Drop the writer without explicit shutdown - this should send EOF
        drop(wr);

        // Reader should receive the data and then EOF (causing read_to_end to complete)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    // After an explicit shutdown, dropping the write half must not send a
    // second EOF.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        // Both ends of the connection are driven by the same client instance.
        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Explicitly shutdown the writer - this sends EOF and sets shutdown=true
        wr.shutdown().await?;

        // Reader should receive the data and then EOF (from explicit shutdown, not from drop)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        // Drop the writer after explicit shutdown - this should NOT send another EOF
        drop(wr);

        // Verify we didn't see another EOF message.
        // Unwrap the `StreamReader` to inspect the raw port directly.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}