monarch_hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::str::FromStr;
10
11use hyperactor::channel::BindSpec;
12use hyperactor::channel::ChannelAddr;
13use hyperactor::channel::ChannelTransport;
14use hyperactor::channel::TcpMode;
15use hyperactor::channel::TlsAddr;
16use hyperactor::channel::TlsMode;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyRuntimeError;
19use pyo3::exceptions::PyTypeError;
20use pyo3::exceptions::PyValueError;
21use pyo3::prelude::*;
22
23/// Python binding for [`hyperactor::channel::ChannelTransport`]
24///
25/// This enum represents the basic transport types that can be represented
26/// as simple enum variants. For explicit addresses, use `PyBindSpec`.
27#[pyclass(
28    name = "ChannelTransport",
29    module = "monarch._rust_bindings.monarch_hyperactor.channel",
30    eq
31)]
32#[derive(PartialEq, Clone, Copy, Debug)]
33pub enum PyChannelTransport {
34    TcpWithLocalhost,
35    TcpWithHostname,
36    MetaTlsWithHostname,
37    MetaTlsWithIpV6,
38    Tls,
39    Local,
40    Unix,
41    // Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support
42}
43
44#[pymethods]
45impl PyChannelTransport {
46    fn get(&self) -> Self {
47        self.clone()
48    }
49
50    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
51        let getattr_fn = py.import("builtins")?.getattr("getattr")?;
52        let variant_name = match self {
53            PyChannelTransport::TcpWithLocalhost => "TcpWithLocalhost",
54            PyChannelTransport::TcpWithHostname => "TcpWithHostname",
55            PyChannelTransport::MetaTlsWithHostname => "MetaTlsWithHostname",
56            PyChannelTransport::MetaTlsWithIpV6 => "MetaTlsWithIpV6",
57            PyChannelTransport::Tls => "Tls",
58            PyChannelTransport::Local => "Local",
59            PyChannelTransport::Unix => "Unix",
60        };
61        let cls = py
62            .import("monarch._rust_bindings.monarch_hyperactor.channel")?
63            .getattr("ChannelTransport")?;
64        let args = (cls, variant_name).into_bound_py_any(py)?;
65        Ok((getattr_fn, args))
66    }
67}
68
69impl TryFrom<ChannelTransport> for PyChannelTransport {
70    type Error = PyErr;
71
72    fn try_from(transport: ChannelTransport) -> PyResult<Self> {
73        match transport {
74            ChannelTransport::Tcp(TcpMode::Localhost) => Ok(PyChannelTransport::TcpWithLocalhost),
75            ChannelTransport::Tcp(TcpMode::Hostname) => Ok(PyChannelTransport::TcpWithHostname),
76            ChannelTransport::MetaTls(TlsMode::Hostname) => {
77                Ok(PyChannelTransport::MetaTlsWithHostname)
78            }
79            ChannelTransport::MetaTls(TlsMode::IpV6) => Ok(PyChannelTransport::MetaTlsWithIpV6),
80            ChannelTransport::Tls => Ok(PyChannelTransport::Tls),
81            ChannelTransport::Local => Ok(PyChannelTransport::Local),
82            ChannelTransport::Unix => Ok(PyChannelTransport::Unix),
83        }
84    }
85}
86
87/// Python binding for [`hyperactor::channel::BindSpec`]
88#[pyclass(
89    name = "BindSpec",
90    module = "monarch._rust_bindings.monarch_hyperactor.channel"
91)]
92#[derive(Clone, Debug, PartialEq)]
93pub struct PyBindSpec {
94    inner: BindSpec,
95}
96
97#[pymethods]
98impl PyBindSpec {
99    /// Create a new PyBindSpec from a ChannelTransport enum, a string representation,
100    /// or another PyBindSpec object.
101    ///
102    /// Examples:
103    ///     PyBindSpec(ChannelTransport.Unix)
104    ///     PyBindSpec("tcp://127.0.0.1:8080")
105    ///     PyBindSpec(PyBindSpec(ChannelTransport.Unix))
106    #[new]
107    pub fn new(spec: &Bound<'_, PyAny>) -> PyResult<Self> {
108        // First try to extract as PyBindSpec (for when passing an existing spec)
109        if let Ok(bind_spec) = spec.extract::<PyBindSpec>() {
110            return Ok(bind_spec);
111        }
112
113        // Then try to extract as PyChannelTransport enum
114        if let Ok(py_transport) = spec.extract::<PyChannelTransport>() {
115            let transport: ChannelTransport = py_transport.into();
116            return Ok(PyBindSpec {
117                inner: BindSpec::Any(transport),
118            });
119        }
120
121        // Then try to extract as a string and parse it as a BindSpec
122        if let Ok(spec_str) = spec.extract::<String>() {
123            let bind_spec = BindSpec::from_str(&spec_str).map_err(|e| {
124                PyValueError::new_err(format!("invalid str for BindSpec '{}': {}", spec_str, e))
125            })?;
126            return Ok(PyBindSpec { inner: bind_spec });
127        }
128
129        Err(PyTypeError::new_err(
130            "expected ChannelTransport enum, BindSpec, or str",
131        ))
132    }
133
134    fn __str__(&self) -> String {
135        self.inner.to_string()
136    }
137
138    fn __repr__(&self) -> String {
139        format!("PyBindSpec({:?})", self.inner)
140    }
141
142    fn __eq__(&self, other: &Self) -> bool {
143        self.inner == other.inner
144    }
145}
146
147impl From<PyBindSpec> for BindSpec {
148    fn from(spec: PyBindSpec) -> Self {
149        spec.inner
150    }
151}
152
153impl From<BindSpec> for PyBindSpec {
154    fn from(spec: BindSpec) -> Self {
155        PyBindSpec { inner: spec }
156    }
157}
158
159#[pyclass(
160    name = "ChannelAddr",
161    module = "monarch._rust_bindings.monarch_hyperactor.channel"
162)]
163pub struct PyChannelAddr {
164    inner: ChannelAddr,
165}
166
167impl FromStr for PyChannelAddr {
168    type Err = anyhow::Error;
169    fn from_str(addr: &str) -> anyhow::Result<Self> {
170        let inner = ChannelAddr::from_str(addr)?;
171        Ok(Self { inner })
172    }
173}
174
175#[pymethods]
176impl PyChannelAddr {
177    /// Returns an "any" address for the given transport type.
178    /// Primarily used to bind servers. Returned string form of the address
179    /// is of the format `{transport}!{address}`. For example:
180    /// `tcp![::]:0`, `unix!@a0b1c2d3`, `metatls!devgpu001.pci.facebook.com:0`
181    #[staticmethod]
182    pub fn any(transport: PyChannelTransport) -> PyResult<String> {
183        Ok(ChannelAddr::any(transport.into()).to_string())
184    }
185
186    #[staticmethod]
187    pub fn parse(addr: &str) -> PyResult<Self> {
188        Ok(PyChannelAddr::from_str(addr)?)
189    }
190
191    /// Returns the port number (if any) of this channel address,
192    /// `0` for transports for which unix ports do not apply (e.g. `unix`, `local`)
193    pub fn get_port(&self) -> PyResult<u16> {
194        match &self.inner {
195            ChannelAddr::Tcp(socket_addr) => Ok(socket_addr.port()),
196            ChannelAddr::MetaTls(TlsAddr { port, .. }) | ChannelAddr::Tls(TlsAddr { port, .. }) => {
197                Ok(*port)
198            }
199            ChannelAddr::Unix(_) | ChannelAddr::Local(_) => Ok(0),
200            _ => Err(PyRuntimeError::new_err(format!(
201                "unsupported transport: `{:?}` for channel address: `{}`",
202                self.inner.transport(),
203                self.inner
204            ))),
205        }
206    }
207
208    /// Returns the channel transport of this channel address.
209    pub fn get_transport(&self) -> PyResult<PyChannelTransport> {
210        let transport = self.inner.transport();
211        match transport {
212            ChannelTransport::Tcp(mode) => match mode {
213                TcpMode::Localhost => Ok(PyChannelTransport::TcpWithLocalhost),
214                TcpMode::Hostname => Ok(PyChannelTransport::TcpWithHostname),
215            },
216            ChannelTransport::MetaTls(mode) => match mode {
217                TlsMode::Hostname => Ok(PyChannelTransport::MetaTlsWithHostname),
218                TlsMode::IpV6 => Ok(PyChannelTransport::MetaTlsWithIpV6),
219            },
220            ChannelTransport::Tls => Ok(PyChannelTransport::Tls),
221            ChannelTransport::Local => Ok(PyChannelTransport::Local),
222            ChannelTransport::Unix => Ok(PyChannelTransport::Unix),
223        }
224    }
225}
226
227impl From<PyChannelTransport> for ChannelTransport {
228    fn from(val: PyChannelTransport) -> Self {
229        match val {
230            PyChannelTransport::TcpWithLocalhost => ChannelTransport::Tcp(TcpMode::Localhost),
231            PyChannelTransport::TcpWithHostname => ChannelTransport::Tcp(TcpMode::Hostname),
232            PyChannelTransport::MetaTlsWithHostname => ChannelTransport::MetaTls(TlsMode::Hostname),
233            PyChannelTransport::MetaTlsWithIpV6 => ChannelTransport::MetaTls(TlsMode::IpV6),
234            PyChannelTransport::Tls => ChannelTransport::Tls,
235            PyChannelTransport::Local => ChannelTransport::Local,
236            PyChannelTransport::Unix => ChannelTransport::Unix,
237        }
238    }
239}
240
241pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
242    hyperactor_mod.add_class::<PyChannelTransport>()?;
243    hyperactor_mod.add_class::<PyBindSpec>()?;
244    hyperactor_mod.add_class::<PyChannelAddr>()?;
245    Ok(())
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    // TODO: OSS: failed to retrieve ipv6 address
    #[cfg_attr(not(fbcode_build), ignore)]
    fn test_channel_any_and_parse() -> PyResult<()> {
        // Smoke test: the "any" address produced for every transport must parse back.
        let transports = [
            PyChannelTransport::TcpWithLocalhost,
            PyChannelTransport::TcpWithHostname,
            PyChannelTransport::Unix,
            PyChannelTransport::MetaTlsWithHostname,
            PyChannelTransport::MetaTlsWithIpV6,
            PyChannelTransport::Tls,
            PyChannelTransport::Local,
        ];
        transports
            .iter()
            .copied()
            .try_for_each(|transport| -> PyResult<()> {
                let any_addr = PyChannelAddr::any(transport)?;
                PyChannelAddr::parse(&any_addr)?;
                Ok(())
            })
    }

    #[test]
    fn test_channel_addr_get_port() -> PyResult<()> {
        let cases: [(&str, u16); 4] = [
            ("tcp![::]:26600", 26600),
            ("metatls!devgpu1.pci.facebook.com:26600", 26600),
            // Transports without a network port report 0.
            ("local!12345", 0),
            ("unix!@1a2b3c", 0),
        ];
        for (addr, expected_port) in cases {
            assert_eq!(
                PyChannelAddr::parse(addr)?.get_port()?,
                expected_port,
                "address: {}",
                addr
            );
        }
        Ok(())
    }

    #[test]
    fn test_channel_addr_get_transport() -> PyResult<()> {
        let cases: [(&str, PyChannelTransport); 7] = [
            ("tcp![::1]:26600", PyChannelTransport::TcpWithLocalhost),
            ("tcp![::]:26600", PyChannelTransport::TcpWithHostname),
            (
                "metatls!devgpu001.pci.facebook.com:26600",
                PyChannelTransport::MetaTlsWithHostname,
            ),
            ("metatls!::1:26600", PyChannelTransport::MetaTlsWithIpV6),
            // IpV4 will fallback to hostname
            ("metatls!127.0.0.1:26600", PyChannelTransport::MetaTlsWithHostname),
            ("local!12345", PyChannelTransport::Local),
            ("unix!@1a2b3c", PyChannelTransport::Unix),
        ];
        for (addr, expected_transport) in cases {
            assert_eq!(
                PyChannelAddr::parse(addr)?.get_transport()?,
                expected_transport,
                "address: {}",
                addr
            );
        }
        Ok(())
    }
}