hyperactor_mesh/
connect.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Actor-based duplex bytestream connections.
10//!
11//! This module provides the equivalent of a `TcpStream` duplex bytestream connection between two actors,
12//! implemented via actor message passing. It allows actors to communicate using familiar `AsyncRead` and
13//! `AsyncWrite` interfaces while leveraging the hyperactor framework's message passing capabilities.
14//!
15//! # Overview
16//!
17//! The connection system consists of:
18//! - [`ActorConnection`]: A duplex connection that implements both `AsyncRead` and `AsyncWrite`
19//! - [`OwnedReadHalf`] and [`OwnedWriteHalf`]: Split halves for independent reading and writing
20//! - [`Connect`] message for establishing connections
21//! - Helper functions [`connect`] and [`accept`] for client and server usage
22//!
23//! # Usage Patterns
24//!
25//! ## Client Side (Initiating Connection)
26//!
27//! Clients use `Connect::allocate()` to create a connection request. This method returns:
28//! 1. A `Connect` message to send to the server to initiate the connection
29//! 2. A `ConnectionCompleter` object that can be awaited for the server to finish connecting,
30//!    returning the `ActorConnection` used by the client.
31//!
32//! The typical pattern is: allocate components, send Connect message to server, await completion.
33//!
34//! ## Server Side (Accepting Connections)
35//!
36//! Servers forward `Connect` messages to the `accept()` helper function to finish setting up the
37//! connection, which returns the `ActorConnection` they can use.
38
39use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorId;
51use hyperactor::Mailbox;
52use hyperactor::Named;
53use hyperactor::OncePortRef;
54use hyperactor::PortRef;
55use hyperactor::cap::CanOpenPort;
56use hyperactor::cap::CanSend;
57use hyperactor::clock::Clock;
58use hyperactor::clock::RealClock;
59use hyperactor::mailbox::OncePortReceiver;
60use hyperactor::mailbox::PortReceiver;
61use hyperactor::mailbox::open_once_port;
62use hyperactor::mailbox::open_port;
63use hyperactor::message::Bind;
64use hyperactor::message::Bindings;
65use hyperactor::message::Unbind;
66use pin_project::pin_project;
67use pin_project::pinned_drop;
68use serde::Deserialize;
69use serde::Serialize;
70use tokio::io::AsyncRead;
71use tokio::io::AsyncWrite;
72use tokio_util::io::StreamReader;
73
/// Timeout for establishing a connection.
///
/// Within this file it bounds how long `ConnectionCompleter::complete` waits for the
/// server's `Accept` reply before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
76
/// Messages sent over the "connection" to facilitate communication.
///
/// Each direction of a connection is a port carrying a sequence of these messages.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    /// A data packet: one chunk of the byte stream.
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    /// Signal the end of one side of the connection.
    Eof,
}
85
/// Adapter that exposes a `PortReceiver<Io>` as a `Stream` of byte chunks, so that it can
/// be wrapped in a `StreamReader`.
struct OwnedReadHalfStream {
    /// Port on which `Io` messages from the peer arrive.
    port: PortReceiver<Io>,
    /// Set once `Io::Eof` has been received; from then on the stream only yields `None`.
    exhausted: bool,
}
90
/// Wrap a `PortReceiver<Io>` as an `AsyncRead`.
///
/// This is the read half of an [`ActorConnection`].
pub struct OwnedReadHalf {
    /// The actor on the other end of the connection.
    peer: ActorId,
    /// Byte-oriented reader layered over the stream of incoming `Io` messages.
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}
96
/// Wrap a `PortRef<Io>` as an `AsyncWrite`.
///
/// This is the write half of an [`ActorConnection`]. Dropping it sends `Io::Eof` to the
/// peer unless the half was already shut down explicitly.
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: CanSend> {
    /// The actor on the other end of the connection.
    peer: ActorId,
    /// Capability used to send messages on `port`.
    #[pin]
    caps: C,
    /// The peer's port that `Io` messages are sent to.
    #[pin]
    port: PortRef<Io>,
    /// Set once `poll_shutdown` has successfully sent `Io::Eof`; writes are rejected
    /// afterwards and no further EOF is sent on drop.
    #[pin]
    shutdown: bool,
}
108
/// A duplex bytestream connection between two actors.  Can generally be used like a `TcpStream`.
///
/// Reads are delegated to the `reader` half and writes to the `writer` half; use
/// `into_split` to obtain the halves for independent use.
#[pin_project]
pub struct ActorConnection<C: CanSend> {
    /// Receives bytes sent by the peer.
    #[pin]
    reader: OwnedReadHalf,
    /// Sends bytes to the peer.
    #[pin]
    writer: OwnedWriteHalf<C>,
}
117
118impl<C: CanSend> ActorConnection<C> {
119    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
120        (self.reader, self.writer)
121    }
122
123    pub fn peer(&self) -> &ActorId {
124        self.reader.peer()
125    }
126}
127
128impl OwnedReadHalf {
129    fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
130        Self {
131            peer,
132            inner: StreamReader::new(OwnedReadHalfStream {
133                port,
134                exhausted: false,
135            }),
136        }
137    }
138
139    pub fn peer(&self) -> &ActorId {
140        &self.peer
141    }
142
143    pub fn reunited<C: CanSend>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
144        ActorConnection {
145            reader: self,
146            writer: other,
147        }
148    }
149}
150
151impl<C: CanSend> OwnedWriteHalf<C> {
152    fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
153        Self {
154            peer,
155            caps,
156            port,
157            shutdown: false,
158        }
159    }
160
161    pub fn peer(&self) -> &ActorId {
162        &self.peer
163    }
164
165    pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
166        ActorConnection {
167            reader: other,
168            writer: self,
169        }
170    }
171}
172
173#[pinned_drop]
174impl<C: CanSend> PinnedDrop for OwnedWriteHalf<C> {
175    fn drop(self: Pin<&mut Self>) {
176        let this = self.project();
177        if !*this.shutdown {
178            let _ = this.port.send(&*this.caps, Io::Eof);
179        }
180    }
181}
182
183impl<C: CanSend> AsyncRead for ActorConnection<C> {
184    fn poll_read(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &mut tokio::io::ReadBuf<'_>,
188    ) -> Poll<std::io::Result<()>> {
189        // Use project() to get pinned references to fields
190        let this = self.project();
191        this.reader.poll_read(cx, buf)
192    }
193}
194
195impl<C: CanSend> AsyncWrite for ActorConnection<C> {
196    fn poll_write(
197        self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        buf: &[u8],
200    ) -> Poll<Result<usize, std::io::Error>> {
201        // Use project() to get pinned references to fields
202        let this = self.project();
203        this.writer.poll_write(cx, buf)
204    }
205
206    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
207        let this = self.project();
208        this.writer.poll_flush(cx)
209    }
210
211    fn poll_shutdown(
212        self: Pin<&mut Self>,
213        cx: &mut Context<'_>,
214    ) -> Poll<Result<(), std::io::Error>> {
215        let this = self.project();
216        this.writer.poll_shutdown(cx)
217    }
218}
219
220impl Stream for OwnedReadHalfStream {
221    type Item = std::io::Result<Cursor<Vec<u8>>>;
222
223    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224        // Once exhausted, always return None
225        if self.exhausted {
226            return Poll::Ready(None);
227        }
228
229        let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
230        match result {
231            Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
232            Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
233            // Break out of stream when we see EOF.
234            Ok(Io::Eof) => {
235                self.exhausted = true;
236                Poll::Ready(None)
237            }
238        }
239    }
240}
241
242impl FusedStream for OwnedReadHalfStream {
243    fn is_terminated(&self) -> bool {
244        self.exhausted
245    }
246}
247
248impl AsyncRead for OwnedReadHalf {
249    fn poll_read(
250        mut self: Pin<&mut Self>,
251        cx: &mut Context<'_>,
252        buf: &mut tokio::io::ReadBuf<'_>,
253    ) -> Poll<std::io::Result<()>> {
254        Pin::new(&mut self.inner).poll_read(cx, buf)
255    }
256}
257
258impl<C: CanSend> AsyncWrite for OwnedWriteHalf<C> {
259    fn poll_write(
260        self: Pin<&mut Self>,
261        _cx: &mut Context<'_>,
262        buf: &[u8],
263    ) -> Poll<Result<usize, std::io::Error>> {
264        let this = self.project();
265        if *this.shutdown {
266            return Poll::Ready(Err(std::io::Error::new(
267                std::io::ErrorKind::BrokenPipe,
268                "write after shutdown",
269            )));
270        }
271        match this.port.send(&*this.caps, Io::Data(buf.into())) {
272            Ok(()) => Poll::Ready(Ok(buf.len())),
273            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
274        }
275    }
276
277    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
278        Poll::Ready(Ok(()))
279    }
280
281    fn poll_shutdown(
282        self: Pin<&mut Self>,
283        _cx: &mut Context<'_>,
284    ) -> Poll<Result<(), std::io::Error>> {
285        // Send EOF on shutdown.
286        match self.port.send(&self.caps, Io::Eof) {
287            Ok(()) => {
288                let mut this = self.project();
289                *this.shutdown = true;
290                Poll::Ready(Ok(()))
291            }
292            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
293        }
294    }
295}
296
/// A helper struct that contains the state needed to complete a connection.
///
/// Created by `Connect::allocate` together with the `Connect` message, and consumed by
/// `complete` once that message has been sent to the server.
pub struct ConnectionCompleter<C> {
    /// Capability that the resulting connection's write half will use to send.
    caps: C,
    /// Receiver for the data the server sends to the client over the connection.
    conn: PortReceiver<Io>,
    /// Receiver for the server's one-shot `Accept` reply.
    port: OncePortReceiver<Accept>,
}
303
304impl<C: CanOpenPort + CanSend> ConnectionCompleter<C> {
305    /// Wait for the server to accept the connection and return the streams that can be used to communicate
306    /// with the server.
307    pub async fn complete(self) -> Result<ActorConnection<C>> {
308        let accept = RealClock
309            .timeout(CONNECT_TIMEOUT, self.port.recv())
310            .await??;
311        Ok(ActorConnection {
312            reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
313            writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
314        })
315    }
316}
317
/// A message sent from a client to initiate a connection.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// The ID of the client initiating the connection.
    id: ActorId,
    /// The port the server will use to send data over the connection to the client.
    conn: PortRef<Io>,
    /// The port the server can use to complete the connection.
    return_conn: OncePortRef<Accept>,
}
327
328impl Connect {
329    /// Allocate a new `Connect` message and return the associated `ConnectionCompleter` that can be used
330    /// to finish setting up the connection.
331    pub fn allocate<C: CanOpenPort + CanSend>(
332        id: ActorId,
333        caps: C,
334    ) -> (Self, ConnectionCompleter<C>) {
335        let (conn_tx, conn_rx) = open_port::<Io>(&caps);
336        let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
337        (
338            Self {
339                id,
340                conn: conn_tx.bind(),
341                return_conn: return_tx.bind(),
342            },
343            ConnectionCompleter {
344                caps,
345                conn: conn_rx,
346                port: return_rx,
347            },
348        )
349    }
350}
351
/// A response message sent from the server back to the client to complete setting
/// up the connection.
///
/// It is delivered on the `return_conn` port carried by the client's `Connect` message.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// The ID of the server that accepted the connection.
    id: ActorId,
    /// The port the client will use to send data over the connection to the server.
    conn: PortRef<Io>,
}
361
362impl Bind for Connect {
363    fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
364        self.conn.bind(bindings)?;
365        self.return_conn.bind(bindings)
366    }
367}
368
369impl Unbind for Connect {
370    fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
371        self.conn.unbind(bindings)?;
372        self.return_conn.unbind(bindings)
373    }
374}
375
376/// Helper used by `Handler<Connect>`s to accept a connection initiated by a `Connect` message and
377/// return `AsyncRead` and `AsyncWrite` streams that can be used to communicate with the other side.
378pub async fn accept<C: CanOpenPort + CanSend>(
379    caps: C,
380    self_id: ActorId,
381    message: Connect,
382) -> Result<ActorConnection<C>> {
383    let (tx, rx) = open_port::<Io>(&caps);
384    message.return_conn.send(
385        &caps,
386        Accept {
387            id: self_id,
388            conn: tx.bind(),
389        },
390    )?;
391    Ok(ActorConnection {
392        reader: OwnedReadHalf::new(message.id.clone(), rx),
393        writer: OwnedWriteHalf::new(message.id, caps, message.conn),
394    })
395}
396
397/// Helper used by clients to initiate a connection by sending a `Connect` message to the given port
398/// and awaiting an `Accept` response. Returns `AsyncRead` and `AsyncWrite` streams that can be used
399/// to communicate with the remote actor.
400pub async fn connect(
401    mailbox: &Mailbox,
402    port: PortRef<Connect>,
403) -> Result<ActorConnection<Mailbox>> {
404    let (connect, completer) = Connect::allocate(mailbox.actor_id().clone(), mailbox.clone());
405    port.send(mailbox, connect)?;
406    completer.complete().await
407}
408
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts a connection and echoes everything it reads back to the peer.
    #[derive(Debug, Default, Actor)]
    struct EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut reader, mut writer) = accept(cx, cx.self_id().clone(), message)
                .await?
                .into_split();
            // Copy until the client signals EOF, then close our own direction.
            tokio::io::copy(&mut reader, &mut writer).await?;
            writer.shutdown().await?;
            Ok(())
        }
    }

    /// Bytes written to an echo actor come back unchanged, followed by EOF.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;
        let actor = proc.spawn::<EchoActor>("actor", ()).await?;
        let (mut reader, mut writer) = connect(&client, actor.port().bind()).await?.into_split();
        let payload = [3u8, 4u8, 5u8, 6u8];
        try_join!(
            async move {
                writer.write_all(&payload).await?;
                writer.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut echoed = vec![];
                reader.read_to_end(&mut echoed).await?;
                assert_eq!(&payload, echoed.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    /// Dropping a write half without shutting it down still delivers EOF to the reader.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;

        // Both ends of the connection live on the same mailbox.
        let (request, pending) = Connect::allocate(client.actor_id().clone(), client.clone());
        let (mut reader, _) = accept(&client, client.actor_id().clone(), request)
            .await?
            .into_split();
        let (_, mut writer) = pending.complete().await?.into_split();

        // Write some data
        let payload = [1u8, 2u8, 3u8];
        writer.write_all(&payload).await?;

        // Drop the writer without explicit shutdown - this should send EOF
        drop(writer);

        // Reader should receive the data and then EOF (causing read_to_end to complete)
        let mut received = vec![];
        reader.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        Ok(())
    }

    /// After an explicit shutdown, dropping the write half must not send a second EOF.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let client = proc.attach("client")?;

        // Both ends of the connection live on the same mailbox.
        let (request, pending) = Connect::allocate(client.actor_id().clone(), client.clone());
        let (mut reader, _) = accept(&client, client.actor_id().clone(), request)
            .await?
            .into_split();
        let (_, mut writer) = pending.complete().await?.into_split();

        // Write some data
        let payload = [1u8, 2u8, 3u8];
        writer.write_all(&payload).await?;

        // Explicitly shutdown the writer - this sends EOF and sets shutdown=true
        writer.shutdown().await?;

        // Reader should receive the data and then EOF (from explicit shutdown, not from drop)
        let mut received = vec![];
        reader.read_to_end(&mut received).await?;
        assert_eq!(&payload, received.as_slice());

        // Drop the writer after explicit shutdown - this should NOT send another EOF
        drop(writer);

        // Verify we didn't see another EOF message.
        assert!(reader.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}
522}