hyperactor_mesh/connect.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Actor-based duplex bytestream connections.
//!
//! This module provides the equivalent of a `TcpStream` duplex bytestream connection between two actors,
//! implemented via actor message passing. It allows actors to communicate using familiar `AsyncRead` and
//! `AsyncWrite` interfaces while leveraging the hyperactor framework's message passing capabilities.
//!
//! # Overview
//!
//! The connection system consists of:
//! - [`ActorConnection`]: A duplex connection that implements both `AsyncRead` and `AsyncWrite`
//! - [`OwnedReadHalf`] and [`OwnedWriteHalf`]: Split halves for independent reading and writing
//! - [`Connect`] message for establishing connections
//! - Helper function [`accept`] for servers and [`Connect::allocate`] for clients
//!
//! # Usage Patterns
//!
//! ## Client Side (Initiating Connection)
//!
//! Clients use `Connect::allocate()` to create a connection request. This method returns:
//! 1. A `Connect` message to send to the server to initiate the connection
//! 2. A `ConnectionCompleter` object that can be awaited for the server to finish connecting,
//!    returning the `ActorConnection` used by the client.
//!
//! The typical pattern is: allocate the connection, send the `Connect` message to the server, then await completion, as sketched below.
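//!
//! A minimal sketch of the client side, adapted from `test_simple_connection` in this module
//! (`EchoActor` stands for any actor whose handler accepts [`Connect`]; the sketch reuses the
//! completer's capability to send, as the test does):
//!
//! ```ignore
//! let proc = Proc::local();
//! let (client, _) = proc.instance("client")?;
//! // Allocate the Connect message plus the completer used to await the server's Accept.
//! let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
//! let echo = proc.spawn("echo", EchoActor {})?;
//! echo.send(&completer.caps, connect)?; // deliver the Connect message to the server actor
//! // Once the server accepts, the connection behaves like a TcpStream.
//! let (mut rd, mut wr) = completer.complete().await?.into_split();
//! wr.write_all(b"hello").await?;
//! wr.shutdown().await?; // sends Io::Eof so the peer's reader sees end-of-stream
//! ```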
//!
//! ## Server Side (Accepting Connections)
//!
//! Servers forward `Connect` messages to the `accept()` helper function to finish setting up the
//! connection, which returns the `ActorConnection` they can use.
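//!
//! A minimal sketch of the server side, adapted from the `EchoActor` test in this module: the
//! handler accepts the connection and echoes everything back to the client.
//!
//! ```ignore
//! #[async_trait]
//! impl Handler<Connect> for EchoActor {
//!     async fn handle(&mut self, cx: &Context<Self>, message: Connect) -> Result<(), anyhow::Error> {
//!         // Accept the connection, then echo everything back to the client.
//!         let (mut rd, mut wr) = accept(cx, cx.self_id().clone(), message).await?.into_split();
//!         tokio::io::copy(&mut rd, &mut wr).await?;
//!         wr.shutdown().await?;
//!         Ok(())
//!     }
//! }
//! ```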

use std::io::Cursor;
use std::pin::Pin;
use std::time::Duration;

use anyhow::Result;
use futures::Future;
use futures::Stream;
use futures::future;
use futures::stream::FusedStream;
use futures::task::Context;
use futures::task::Poll;
use hyperactor::ActorId;
use hyperactor::OncePortRef;
use hyperactor::PortRef;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::context;
use hyperactor::mailbox::OncePortReceiver;
use hyperactor::mailbox::PortReceiver;
use hyperactor::mailbox::open_once_port;
use hyperactor::mailbox::open_port;
use hyperactor::message::Bind;
use hyperactor::message::Bindings;
use hyperactor::message::Unbind;
use pin_project::pin_project;
use pin_project::pinned_drop;
use serde::Deserialize;
use serde::Serialize;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio_util::io::StreamReader;
use typeuri::Named;

// Timeout for establishing a connection, used by both client and server.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Messages sent over the "connection" to facilitate communication.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
enum Io {
    // A data packet.
    Data(#[serde(with = "serde_bytes")] Vec<u8>),
    // Signal the end of one side of the connection.
    Eof,
}
wirevalue::register_type!(Io);

struct OwnedReadHalfStream {
    port: PortReceiver<Io>,
    exhausted: bool,
}

/// Wrap a `PortReceiver<Io>` as an `AsyncRead`.
pub struct OwnedReadHalf {
    peer: ActorId,
    inner: StreamReader<OwnedReadHalfStream, Cursor<Vec<u8>>>,
}

/// Wrap a `PortRef<Io>` as an `AsyncWrite`.
#[pin_project(PinnedDrop)]
pub struct OwnedWriteHalf<C: context::Actor> {
    peer: ActorId,
    #[pin]
    caps: C,
    #[pin]
    port: PortRef<Io>,
    #[pin]
    shutdown: bool,
}

/// A duplex bytestream connection between two actors. It can generally be used like a `TcpStream`.
#[pin_project]
pub struct ActorConnection<C: context::Actor> {
    #[pin]
    reader: OwnedReadHalf,
    #[pin]
    writer: OwnedWriteHalf<C>,
}

impl<C: context::Actor> ActorConnection<C> {
    /// Split the connection into independently usable read and write halves.
    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf<C>) {
        (self.reader, self.writer)
    }

    /// The `ActorId` of the actor on the other side of the connection.
    pub fn peer(&self) -> &ActorId {
        self.reader.peer()
    }
}

impl OwnedReadHalf {
    fn new(peer: ActorId, port: PortReceiver<Io>) -> Self {
        Self {
            peer,
            inner: StreamReader::new(OwnedReadHalfStream {
                port,
                exhausted: false,
            }),
        }
    }

    /// The `ActorId` of the actor on the other side of the connection.
    pub fn peer(&self) -> &ActorId {
        &self.peer
    }

    /// Reunite this read half with a write half into an `ActorConnection`.
    pub fn reunited<C: context::Actor>(self, other: OwnedWriteHalf<C>) -> ActorConnection<C> {
        ActorConnection {
            reader: self,
            writer: other,
        }
    }
}

impl<C: context::Actor> OwnedWriteHalf<C> {
    fn new(peer: ActorId, caps: C, port: PortRef<Io>) -> Self {
        Self {
            peer,
            caps,
            port,
            shutdown: false,
        }
    }

    /// The `ActorId` of the actor on the other side of the connection.
    pub fn peer(&self) -> &ActorId {
        &self.peer
    }

    /// Reunite this write half with a read half into an `ActorConnection`.
    pub fn reunited(self, other: OwnedReadHalf) -> ActorConnection<C> {
        ActorConnection {
            reader: other,
            writer: self,
        }
    }
}

#[pinned_drop]
impl<C: context::Actor> PinnedDrop for OwnedWriteHalf<C> {
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();
        // If the writer is dropped without an explicit shutdown, send EOF so the
        // peer's reader observes end-of-stream rather than hanging.
        if !*this.shutdown {
            let _ = this.port.send(&*this.caps, Io::Eof);
        }
    }
}

impl<C: context::Actor> AsyncRead for ActorConnection<C> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Use project() to get pinned references to fields
        let this = self.project();
        this.reader.poll_read(cx, buf)
    }
}

impl<C: context::Actor> AsyncWrite for ActorConnection<C> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Use project() to get pinned references to fields
        let this = self.project();
        this.writer.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.writer.poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.writer.poll_shutdown(cx)
    }
}

impl Stream for OwnedReadHalfStream {
    type Item = std::io::Result<Cursor<Vec<u8>>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Once exhausted, always return None
        if self.exhausted {
            return Poll::Ready(None);
        }

        let result = futures::ready!(Box::pin(self.port.recv()).as_mut().poll(cx));
        match result {
            Err(err) => Poll::Ready(Some(Err(std::io::Error::other(err)))),
            Ok(Io::Data(buf)) => Poll::Ready(Some(Ok(Cursor::new(buf)))),
            // Break out of stream when we see EOF.
            Ok(Io::Eof) => {
                self.exhausted = true;
                Poll::Ready(None)
            }
        }
    }
}

impl FusedStream for OwnedReadHalfStream {
    fn is_terminated(&self) -> bool {
        self.exhausted
    }
}

impl AsyncRead for OwnedReadHalf {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

impl<C: context::Actor> AsyncWrite for OwnedWriteHalf<C> {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        if *this.shutdown {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "write after shutdown",
            )));
        }
        match this.port.send(&*this.caps, Io::Data(buf.into())) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Send EOF on shutdown.
        match self.port.send(&self.caps, Io::Eof) {
            Ok(()) => {
                let mut this = self.project();
                *this.shutdown = true;
                Poll::Ready(Ok(()))
            }
            Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
        }
    }
}

/// A helper struct that contains the state needed to complete a connection.
pub struct ConnectionCompleter<C> {
    caps: C,
    conn: PortReceiver<Io>,
    port: OncePortReceiver<Accept>,
}

impl<C: context::Actor> ConnectionCompleter<C> {
    /// Wait for the server to accept the connection and return the streams that can be used to communicate
    /// with the server.
    pub async fn complete(self) -> Result<ActorConnection<C>> {
        let accept = RealClock
            .timeout(CONNECT_TIMEOUT, self.port.recv())
            .await??;
        Ok(ActorConnection {
            reader: OwnedReadHalf::new(accept.id.clone(), self.conn),
            writer: OwnedWriteHalf::new(accept.id, self.caps, accept.conn),
        })
    }
}

/// A message sent from a client to initiate a connection.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Connect {
    /// The ID of the client initiating the connection.
    id: ActorId,
    /// The port the server uses to send data to the client.
    conn: PortRef<Io>,
    /// The port the server can use to complete the connection.
    return_conn: OncePortRef<Accept>,
}
wirevalue::register_type!(Connect);

impl Connect {
    /// Allocate a new `Connect` message and return the associated `ConnectionCompleter` that can be used
    /// to finish setting up the connection.
    pub fn allocate<C: context::Actor>(id: ActorId, caps: C) -> (Self, ConnectionCompleter<C>) {
        let (conn_tx, conn_rx) = open_port::<Io>(&caps);
        let (return_tx, return_rx) = open_once_port::<Accept>(&caps);
        (
            Self {
                id,
                conn: conn_tx.bind(),
                return_conn: return_tx.bind(),
            },
            ConnectionCompleter {
                caps,
                conn: conn_rx,
                port: return_rx,
            },
        )
    }
}

/// A response message sent from the server back to the client to complete setting
/// up the connection.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
struct Accept {
    /// The ID of the server that accepted the connection.
    id: ActorId,
    /// The port the client will use to send data over the connection to the server.
    conn: PortRef<Io>,
}
wirevalue::register_type!(Accept);

impl Bind for Connect {
    fn bind(&mut self, bindings: &mut Bindings) -> Result<()> {
        self.conn.bind(bindings)?;
        self.return_conn.bind(bindings)
    }
}

impl Unbind for Connect {
    fn unbind(&self, bindings: &mut Bindings) -> Result<()> {
        self.conn.unbind(bindings)?;
        self.return_conn.unbind(bindings)
    }
}

/// Helper used by `Handler<Connect>`s to accept a connection initiated by a `Connect` message and
/// return `AsyncRead` and `AsyncWrite` streams that can be used to communicate with the other side.
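///
/// A minimal sketch, inside a `Handler<Connect>::handle` implementation (mirroring the
/// `EchoActor` test below):
///
/// ```ignore
/// let conn = accept(cx, cx.self_id().clone(), message).await?;
/// let (mut rd, mut wr) = conn.into_split();
/// ```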
pub async fn accept<C: context::Actor>(
    caps: C,
    self_id: ActorId,
    message: Connect,
) -> Result<ActorConnection<C>> {
    let (tx, rx) = open_port::<Io>(&caps);
    message.return_conn.send(
        &caps,
        Accept {
            id: self_id,
            conn: tx.bind(),
        },
    )?;
    Ok(ActorConnection {
        reader: OwnedReadHalf::new(message.id.clone(), rx),
        writer: OwnedWriteHalf::new(message.id, caps, message.conn),
    })
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::try_join;
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[derive(Debug, Default)]
    struct EchoActor {}

    impl Actor for EchoActor {}

    #[async_trait]
    impl Handler<Connect> for EchoActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: Connect,
        ) -> Result<(), anyhow::Error> {
            let (mut rd, mut wr) = accept(cx, cx.self_id().clone(), message)
                .await?
                .into_split();
            tokio::io::copy(&mut rd, &mut wr).await?;
            wr.shutdown().await?;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_simple_connection() -> Result<()> {
        let proc = Proc::local();
        let (client, _) = proc.instance("client")?;
        let (connect, completer) = Connect::allocate(client.self_id().clone(), client);
        let actor = proc.spawn("actor", EchoActor {})?;
        actor.send(&completer.caps, connect)?;
        let (mut rd, mut wr) = completer.complete().await?.into_split();
        let send = [3u8, 4u8, 5u8, 6u8];
        try_join!(
            async move {
                wr.write_all(&send).await?;
                wr.shutdown().await?;
                anyhow::Ok(())
            },
            async {
                let mut recv = vec![];
                rd.read_to_end(&mut recv).await?;
                assert_eq!(&send, recv.as_slice());
                anyhow::Ok(())
            },
        )?;
        Ok(())
    }

    #[tokio::test]
    async fn test_connection_close_on_drop() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Drop the writer without explicit shutdown - this should send EOF
        drop(wr);

        // Reader should receive the data and then EOF (causing read_to_end to complete)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        Ok(())
    }

    #[tokio::test]
    async fn test_no_eof_on_drop_after_shutdown() -> Result<()> {
        let proc = Proc::local();
        let (client, _client_handle) = proc.instance("client")?;

        let (connect, completer) =
            Connect::allocate(client.self_id().clone(), client.clone_for_py());
        let (mut rd, _) = accept(client.clone_for_py(), client.self_id().clone(), connect)
            .await?
            .into_split();
        let (_, mut wr) = completer.complete().await?.into_split();

        // Write some data
        let send = [1u8, 2u8, 3u8];
        wr.write_all(&send).await?;

        // Explicitly shutdown the writer - this sends EOF and sets shutdown=true
        wr.shutdown().await?;

        // Reader should receive the data and then EOF (from explicit shutdown, not from drop)
        let mut recv = vec![];
        rd.read_to_end(&mut recv).await?;
        assert_eq!(&send, recv.as_slice());

        // Drop the writer after explicit shutdown - this should NOT send another EOF
        drop(wr);

        // Verify we didn't see another EOF message.
        assert!(rd.inner.into_inner().port.try_recv().unwrap().is_none());

        Ok(())
    }
}