Skip to main content

hyperactor_mesh/
connect.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Actor-based duplex bytestream connections.
10//!
11//! This module provides the equivalent of a `TcpStream` duplex bytestream connection between two actors,
12//! implemented via actor message passing. It allows actors to communicate using familiar `AsyncRead` and
13//! `AsyncWrite` interfaces while leveraging the hyperactor framework's message passing capabilities.
14//!
15//! # Overview
16//!
17//! The connection system consists of:
18//! - [`ActorConnection`]: A duplex connection that implements both `AsyncRead` and `AsyncWrite`
19//! - [`OwnedReadHalf`] and [`OwnedWriteHalf`]: Split halves for independent reading and writing
20//! - [`Connect`] message for establishing connections
//! - [`Connect::allocate`] and the [`accept`] helper function, for client-side and server-side usage respectively
22//!
23//! # Usage Patterns
24//!
25//! ## Client Side (Initiating Connection)
26//!
27//! Clients use `Connect::allocate()` to create a connection request. This method returns:
28//! 1. A `Connect` message to send to the server to initiate the connection
//! 2. A `ConnectionCompleter` whose `complete()` method waits for the server to finish connecting
//!    and returns the `ActorConnection` used by the client.
31//!
32//! The typical pattern is: allocate components, send Connect message to server, await completion.
33//!
34//! ## Server Side (Accepting Connections)
35//!
36//! Servers forward `Connect` messages to the `accept()` helper function to finish setting up the
37//! connection, which returns the `ActorConnection` they can use.
38
39use std::io::Cursor;
40use std::pin::Pin;
41use std::time::Duration;
42
43use anyhow::Result;
44use future::Future;
45use futures::Stream;
46use futures::future;
47use futures::stream::FusedStream;
48use futures::task::Context;
49use futures::task::Poll;
50use hyperactor::ActorAddr;
51use hyperactor::Endpoint as _;
52use hyperactor::OncePortRef;
53use hyperactor::PortRef;
54use hyperactor::context;
55use hyperactor::mailbox::OncePortReceiver;
56use hyperactor::mailbox::PortReceiver;
57use hyperactor::mailbox::open_once_port;
58use hyperactor::mailbox::open_port;
59use hyperactor::message::Bind;
60use hyperactor::message::Bindings;
61use hyperactor::message::Unbind;
62use pin_project::pin_project;
63use pin_project::pinned_drop;
64use serde::Deserialize;
65use serde::Serialize;
66use tokio::io::AsyncRead;
67use tokio::io::AsyncWrite;
68use tokio_util::io::StreamReader;
69use typeuri::Named;
70
/// Timeout for establishing a connection: how long `ConnectionCompleter::complete` waits for
/// the server's `Accept` reply. The server-side `accept` helper does not apply a timeout.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
73
74/// Messages sent over the "connection" to facilitate communication.
75#[derive(Debug, Serialize, Deserialize, Named, Clone)]
76enum Io {
77    // A data packet.
78    Data(#[serde(with = "serde_bytes")] Vec<u8>),
79    // Signal the end of one side of the connection.
80    Eof,
81}
82wirevalue::register_type!(Io);
83
84struct OwnedReadHalfStream {
85    port: PortReceiver<Io>,
86    exhausted: bool,
87}
88
89/// Wrap a `PortReceiver<IoMsg>` as a `AsyncRead`.
90pub struct OwnedReadHalf {
91    peer: ActorAddr,
92    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
93}
94
95/// Wrap a `PortRef<IoMsg>` as a `AsyncWrite`.
96#[pin_project(PinnedDrop)]
97pub struct OwnedWriteHalf<C: context::Actor> {
98    peer: ActorAddr,
99    #[pin]
100    caps: C,
101    #[pin]
102    port: PortRef<Io>,
103    #[pin]
104    shutdown: bool,
105}
106
107/// A duplex bytestream connection between two actors.  Can generally be used like a `TcpStream`.
108#[pin_project]
109pub struct ActorConnection<C: context::Actor> {
110    #[pin]
111    reader: OwnedReadHalf,
112    #[pin]
113    writer: OwnedWriteHalf<C>,
114}
115
116impl<C: context::Actor> ActorConnection<C> {
117    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
118        (self.reader, self.writer)
119    }
120
121    pub fn peer(&self) -> &ActorAddr {
122        self.reader.peer()
123    }
124}
125
126impl OwnedReadHalf {
127    fn new(peer: ActorAddr, port: PortReceiver<Io>) -> Self {
128        Self {
129            peer,
130            inner: StreamReader::new(OwnedReadHalfStream {
131                port,
132                exhausted: false,
133            }),
134        }
135    }
136
137    pub fn peer(&self) -> &ActorAddr {
138        &self.peer
139    }
140
141    pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
142        ActorConnection {
143            reader: self,
144            writer: other,
145        }
146    }
147}
148
149impl<C: context::Actor> OwnedWriteHalf<C> {
150    fn new(peer: ActorAddr, caps: C, port: PortRef<Io>) -> Self {
151        Self {
152            peer,
153            caps,
154            port,
155            shutdown: false,
156        }
157    }
158
159    pub fn peer(&self) -> &ActorAddr {
160        &self.peer
161    }
162
163    pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
164        ActorConnection {
165            reader: other,
166            writer: self,
167        }
168    }
169}
170
171#[pinned_drop]
172impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
173    fn drop(self: Pin<&mut Self>) {
174        let this = self.project();
175        if !*this.shutdown {
176            let _ = this.port.post(&*this.caps, Io::Eof);
177        }
178    }
179}
180
181impl<C: context::Actor> AsyncRead for ActorConnection<C> {
182    fn poll_read(
183        self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185        buf: &mut tokio::io::ReadBuf<'_>,
186    ) -> Poll<std::io::Result<()>> {
187        // Use project() to get pinned references to fields
188        let this = self.project();
189        this.reader.poll_read(cx, buf)
190    }
191}
192
193impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
194    fn poll_write(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        buf: &[u8],
198    ) -> Poll<Result<usize, std::io::Error>> {
199        // Use project() to get pinned references to fields
200        let this = self.project();
201        this.writer.poll_write(cx, buf)
202    }
203
204    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
205        let this = self.project();
206        this.writer.poll_flush(cx)
207    }
208
209    fn poll_shutdown(
210        self: Pin<&mut Self>,
211        cx: &mut Context<'_>,
212    ) -> Poll<Result<(), std::io::Error>> {
213        let this = self.project();
214        this.writer.poll_shutdown(cx)
215    }
216}
217
218impl Stream for OwnedReadHalfStream {
219    type Item = std::io::Result<Cursor<Vec<u8>>>;
220
221    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        // Once exhausted, always return None
223        if self.exhausted {
224            return Poll::Ready(None);
225        }
226
227        let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
228        match result {
229            Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
230            Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
231            // Break out of stream when we see EOF.
232            Ok(Io::Eof) => {
233                self.exhausted = true;
234                Poll::Ready(None)
235            }
236        }
237    }
238}
239
240impl FusedStream for OwnedReadHalfStream {
241    fn is_terminated(&self) -> bool {
242        self.exhausted
243    }
244}
245
246impl AsyncRead for OwnedReadHalf {
247    fn poll_read(
248        mut self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &mut tokio::io::ReadBuf<'_>,
251    ) -> Poll<std::io::Result<()>> {
252        Pin::new(&mut self.inner).poll_read(cx, buf)
253    }
254}
255
256impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
257    fn poll_write(
258        self: Pin<&mut Self>,
259        _cx: &mut Context<'_>,
260        buf: &[u8],
261    ) -> Poll<Result<usize, std::io::Error>> {
262        let this = self.project();
263        if *this.shutdown {
264            return Poll::Ready(Err(std::io::Error::new(
265                std::io::ErrorKind::BrokenPipe,
266                "write after shutdown",
267            )));
268        }
269        this.port.post(&*this.caps, Io::Data(buf.into()));
270        Poll::Ready(Ok(buf.len()))
271    }
272
273    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
274        Poll::Ready(Ok(()))
275    }
276
277    fn poll_shutdown(
278        self: Pin<&mut Self>,
279        _cx: &mut Context<'_>,
280    ) -> Poll<Result<(), std::io::Error>> {
281        // Send EOF on shutdown.
282        self.port.post(&self.caps, Io::Eof);
283        let mut this = self.project();
284        *this.shutdown = true;
285        Poll::Ready(Ok(()))
286    }
287}
288
289/// A helper struct that contains the state needed to complete a connection.
290pub struct ConnectionCompleter<C> {
291    caps: C,
292    conn: PortReceiver<Io>,
293    port: OncePortReceiver<Accept>,
294}
295
296impl<C: context::Actor> ConnectionCompleter<C> {
297    /// Wait for the server to accept the connection and return the streams that can be used to communicate
298    /// with the server.
299    pub async fn complete(self) -> Result<ActorConnection<C>> {
300        let accept = tokio::time::timeout(CONNECT_TIMEOUT, self.port.recv()).await??;
301        Ok(ActorConnection {
302            reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
303            writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
304        })
305    }
306}
307
308/// A message sent from a client to initiate a connection.
309#[derive(Debug, Serialize, Deserialize, Named, Clone)]
310pub struct Connect {
311    /// The ID of the client initiating the connection.
312    id: ActorAddr,
313    conn: PortRef<Io>,
314    /// The port the server can use to complete the connection.
315    return_conn: OncePortRef<Accept>,
316}
317wirevalue::register_type!(Connect);
318
319impl Connect {
320    /// Allocate a new `Connect` message and return the associated `ConnectionCompleter` that can be used
321    /// to finish setting up the connection.
322    pub fn allocate<C: context::Actor>(id: ActorAddr, caps: C) -> (Self, ConnectionCompleter<C>) {
323        let (conn_tx, conn_rx) = open_port::<Io>(&caps);
324        let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
325        (
326            Self {
327                id,
328                conn: conn_tx.bind(),
329                return_conn: return_tx.bind(),
330            },
331            ConnectionCompleter {
332                caps,
333                conn: conn_rx,
334                port: return_rx,
335            },
336        )
337    }
338}
339
340/// A response message sent from the server back to the client to complete setting
341/// up the connection.
342#[derive(Debug, Serialize, Deserialize, Named, Clone)]
343struct Accept {
344    /// The ID of the server that accepted the connection.
345    id: ActorAddr,
346    /// The port the client will use to send data over the connection to the server.
347    conn: PortRef<Io>,
348}
349wirevalue::register_type!(Accept);
350
351impl Bind for Connect {
352    fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
353        self.conn.bind(bindings)?;
354        self.return_conn.bind(bindings)
355    }
356}
357
358impl Unbind for Connect {
359    fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
360        self.conn.unbind(bindings)?;
361        self.return_conn.unbind(bindings)
362    }
363}
364
365/// Helper used by `Handler<Connect>`s to accept a connection initiated by a `Connect` message and
366/// return `AsyncRead` and `AsyncWrite` streams that can be used to communicate with the other side.
367pub async fn accept<C: context::Actor>(
368    caps: C,
369    self_id: ActorAddr,
370    message: Connect,
371) -> Result<ActorConnection<C>> {
372    let (tx, rx) = open_port::<Io>(&caps);
373    message.return_conn.post(
374        &caps,
375        Accept {
376            id: self_id,
377            conn: tx.bind(),
378        },
379    );
380    Ok(ActorConnection {
381        reader: OwnedReadHalf::new(message.id.clone(), rx),
382        writer: OwnedWriteHalf::new(message.id, caps, message.conn),
383    })
384}
385
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Test actor that accepts each incoming `Connect` and echoes back every byte it reads,
    /// shutting down its write half once the client's EOF ends the copy.
    #[derive(Debug, Default)]
    struct EchoActor {}

    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_addr().clone(), message)
                .await?
                .into_split();
            // `copy` returns after the client's EOF is read; then propagate EOF back.
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    /// End-to-end round trip: the client sends bytes to `EchoActor` and reads the same bytes
    /// back, followed by EOF.
    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _) = proc.client("client")?;
        let (connect, completer) = Connect::allocate(client.self_addr().clone(), client);
        let actor = proc.spawn("actor", EchoActor {})?;
        // `client` was moved into the completer, so post the request using its `caps`.
        actor.post(&completer.caps, connect);
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        // Write and read concurrently; `read_to_end` only finishes once the echo's EOF arrives.
        try_join!(
            async move {
                wr.write_all(&send).await?;
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    /// Dropping a write half without calling `shutdown` must still deliver EOF to the peer.
    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _client_handle) = proc.client("client")?;

        // The same client context plays both sides: it accepts its own `Connect`.
        let (connect, completer) =
            Connect::allocate(client.self_addr().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_addr().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Drop the writer without explicit shutdown - this should send EOF
        drop(wr);

        // Reader should receive the data and then EOF (causing read_to_end to complete)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    /// After an explicit `shutdown`, dropping the write half must not post a second EOF.
    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::isolated();
        let (client, _client_handle) = proc.client("client")?;

        // The same client context plays both sides: it accepts its own `Connect`.
        let (connect, completer) =
            Connect::allocate(client.self_addr().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_addr().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Explicitly shutdown the writer - this sends EOF and sets shutdown=true
        wr.shutdown().await?;

        // Reader should receive the data and then EOF (from explicit shutdown, not from drop)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        // Drop the writer after explicit shutdown - this should NOT send another EOF
        drop(wr);

        // Verify we didn't see another EOF message: unwrap the reader down to its raw port
        // and check that nothing further is queued.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}