hyperactor_mesh/
router.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module supports (in-process) global routing in hyperactor meshes.
10
11use std::collections::HashMap;
12use std::ops::Deref;
13use std::sync::OnceLock;
14
15use hyperactor::channel;
16use hyperactor::channel::ChannelAddr;
17use hyperactor::channel::ChannelError;
18use hyperactor::channel::ChannelTransport;
19use hyperactor::mailbox::DialMailboxRouter;
20use hyperactor::mailbox::MailboxRouter;
21use hyperactor::mailbox::MailboxServer;
22use tokio::sync::Mutex;
23
24/// The shared, global router for this process.
25pub fn global() -> &'static Router {
26    static GLOBAL_ROUTER: OnceLock<Router> = OnceLock::new();
27    GLOBAL_ROUTER.get_or_init(Router::new)
28}
29
30/// Router augments [`MailboxRouter`] with additional APIs and
31/// bookeeping relevant to meshes.
32pub struct Router {
33    router: MailboxRouter,
34    #[allow(dead_code)] // `servers` isn't read
35    servers: Mutex<HashMap<ChannelTransport, ChannelAddr>>,
36}
37
38/// Deref so that we can use the [`MailboxRouter`] APIs directly.
39impl Deref for Router {
40    type Target = MailboxRouter;
41
42    fn deref(&self) -> &Self::Target {
43        &self.router
44    }
45}
46
47impl Router {
48    /// Create a new router.
49    fn new() -> Self {
50        Self {
51            router: MailboxRouter::new(),
52            servers: Mutex::new(HashMap::new()),
53        }
54    }
55
56    /// Serve this router on the provided transport, returning the address.
57    /// Servers are memoized, and we maintain only one per transport; thus
58    /// subsequent calls using the same transport will return the same address.
59    #[allow(dead_code)]
60    #[hyperactor::instrument]
61    pub async fn serve(&self, transport: &ChannelTransport) -> Result<ChannelAddr, ChannelError> {
62        let mut servers = self.servers.lock().await;
63        if let Some(addr) = servers.get(transport) {
64            return Ok(addr.clone());
65        }
66
67        let (addr, rx) = channel::serve(ChannelAddr::any(transport.clone()))?;
68        self.router.clone().serve(rx);
69        servers.insert(transport.clone(), addr.clone());
70        Ok(addr)
71    }
72
73    /// Binds a [`DialMailboxRouter`] directly into this router. Specifically, each
74    /// prefix served by `router` is bound directly into this [`MailboxRouter`].
75    pub fn bind_dial_router(&self, router: &DialMailboxRouter) {
76        for prefix in router.prefixes() {
77            self.router.bind(prefix, router.clone());
78        }
79    }
80}