1use std::str::FromStr;
10
11use hyperactor::channel::BindSpec;
12use hyperactor::channel::ChannelAddr;
13use hyperactor::channel::ChannelTransport;
14use hyperactor::channel::TcpMode;
15use hyperactor::channel::TlsAddr;
16use hyperactor::channel::TlsMode;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyRuntimeError;
19use pyo3::exceptions::PyTypeError;
20use pyo3::exceptions::PyValueError;
21use pyo3::prelude::*;
22
23#[pyclass(
28 name = "ChannelTransport",
29 module = "monarch._rust_bindings.monarch_hyperactor.channel",
30 eq
31)]
32#[derive(PartialEq, Clone, Copy, Debug)]
33pub enum PyChannelTransport {
34 TcpWithLocalhost,
35 TcpWithHostname,
36 MetaTlsWithHostname,
37 MetaTlsWithIpV6,
38 Tls,
39 Local,
40 Unix,
41 }
43
44#[pymethods]
45impl PyChannelTransport {
46 fn get(&self) -> Self {
47 self.clone()
48 }
49
50 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
51 let getattr_fn = py.import("builtins")?.getattr("getattr")?;
52 let variant_name = match self {
53 PyChannelTransport::TcpWithLocalhost => "TcpWithLocalhost",
54 PyChannelTransport::TcpWithHostname => "TcpWithHostname",
55 PyChannelTransport::MetaTlsWithHostname => "MetaTlsWithHostname",
56 PyChannelTransport::MetaTlsWithIpV6 => "MetaTlsWithIpV6",
57 PyChannelTransport::Tls => "Tls",
58 PyChannelTransport::Local => "Local",
59 PyChannelTransport::Unix => "Unix",
60 };
61 let cls = py
62 .import("monarch._rust_bindings.monarch_hyperactor.channel")?
63 .getattr("ChannelTransport")?;
64 let args = (cls, variant_name).into_bound_py_any(py)?;
65 Ok((getattr_fn, args))
66 }
67}
68
69impl TryFrom<ChannelTransport> for PyChannelTransport {
70 type Error = PyErr;
71
72 fn try_from(transport: ChannelTransport) -> PyResult<Self> {
73 match transport {
74 ChannelTransport::Tcp(TcpMode::Localhost) => Ok(PyChannelTransport::TcpWithLocalhost),
75 ChannelTransport::Tcp(TcpMode::Hostname) => Ok(PyChannelTransport::TcpWithHostname),
76 ChannelTransport::MetaTls(TlsMode::Hostname) => {
77 Ok(PyChannelTransport::MetaTlsWithHostname)
78 }
79 ChannelTransport::MetaTls(TlsMode::IpV6) => Ok(PyChannelTransport::MetaTlsWithIpV6),
80 ChannelTransport::Tls => Ok(PyChannelTransport::Tls),
81 ChannelTransport::Local => Ok(PyChannelTransport::Local),
82 ChannelTransport::Unix => Ok(PyChannelTransport::Unix),
83 }
84 }
85}
86
87#[pyclass(
89 name = "BindSpec",
90 module = "monarch._rust_bindings.monarch_hyperactor.channel"
91)]
92#[derive(Clone, Debug, PartialEq)]
93pub struct PyBindSpec {
94 inner: BindSpec,
95}
96
97#[pymethods]
98impl PyBindSpec {
99 #[new]
107 pub fn new(spec: &Bound<'_, PyAny>) -> PyResult<Self> {
108 if let Ok(bind_spec) = spec.extract::<PyBindSpec>() {
110 return Ok(bind_spec);
111 }
112
113 if let Ok(py_transport) = spec.extract::<PyChannelTransport>() {
115 let transport: ChannelTransport = py_transport.into();
116 return Ok(PyBindSpec {
117 inner: BindSpec::Any(transport),
118 });
119 }
120
121 if let Ok(spec_str) = spec.extract::<String>() {
123 let bind_spec = BindSpec::from_str(&spec_str).map_err(|e| {
124 PyValueError::new_err(format!("invalid str for BindSpec '{}': {}", spec_str, e))
125 })?;
126 return Ok(PyBindSpec { inner: bind_spec });
127 }
128
129 Err(PyTypeError::new_err(
130 "expected ChannelTransport enum, BindSpec, or str",
131 ))
132 }
133
134 fn __str__(&self) -> String {
135 self.inner.to_string()
136 }
137
138 fn __repr__(&self) -> String {
139 format!("PyBindSpec({:?})", self.inner)
140 }
141
142 fn __eq__(&self, other: &Self) -> bool {
143 self.inner == other.inner
144 }
145}
146
147impl From<PyBindSpec> for BindSpec {
148 fn from(spec: PyBindSpec) -> Self {
149 spec.inner
150 }
151}
152
153impl From<BindSpec> for PyBindSpec {
154 fn from(spec: BindSpec) -> Self {
155 PyBindSpec { inner: spec }
156 }
157}
158
159#[pyclass(
160 name = "ChannelAddr",
161 module = "monarch._rust_bindings.monarch_hyperactor.channel"
162)]
163pub struct PyChannelAddr {
164 inner: ChannelAddr,
165}
166
167impl FromStr for PyChannelAddr {
168 type Err = anyhow::Error;
169 fn from_str(addr: &str) -> anyhow::Result<Self> {
170 let inner = ChannelAddr::from_str(addr)?;
171 Ok(Self { inner })
172 }
173}
174
175#[pymethods]
176impl PyChannelAddr {
177 #[staticmethod]
182 pub fn any(transport: PyChannelTransport) -> PyResult<String> {
183 Ok(ChannelAddr::any(transport.into()).to_string())
184 }
185
186 #[staticmethod]
187 pub fn parse(addr: &str) -> PyResult<Self> {
188 Ok(PyChannelAddr::from_str(addr)?)
189 }
190
191 pub fn get_port(&self) -> PyResult<u16> {
194 match &self.inner {
195 ChannelAddr::Tcp(socket_addr) => Ok(socket_addr.port()),
196 ChannelAddr::MetaTls(TlsAddr { port, .. }) | ChannelAddr::Tls(TlsAddr { port, .. }) => {
197 Ok(*port)
198 }
199 ChannelAddr::Unix(_) | ChannelAddr::Local(_) => Ok(0),
200 _ => Err(PyRuntimeError::new_err(format!(
201 "unsupported transport: `{:?}` for channel address: `{}`",
202 self.inner.transport(),
203 self.inner
204 ))),
205 }
206 }
207
208 pub fn get_transport(&self) -> PyResult<PyChannelTransport> {
210 let transport = self.inner.transport();
211 match transport {
212 ChannelTransport::Tcp(mode) => match mode {
213 TcpMode::Localhost => Ok(PyChannelTransport::TcpWithLocalhost),
214 TcpMode::Hostname => Ok(PyChannelTransport::TcpWithHostname),
215 },
216 ChannelTransport::MetaTls(mode) => match mode {
217 TlsMode::Hostname => Ok(PyChannelTransport::MetaTlsWithHostname),
218 TlsMode::IpV6 => Ok(PyChannelTransport::MetaTlsWithIpV6),
219 },
220 ChannelTransport::Tls => Ok(PyChannelTransport::Tls),
221 ChannelTransport::Local => Ok(PyChannelTransport::Local),
222 ChannelTransport::Unix => Ok(PyChannelTransport::Unix),
223 }
224 }
225}
226
227impl From<PyChannelTransport> for ChannelTransport {
228 fn from(val: PyChannelTransport) -> Self {
229 match val {
230 PyChannelTransport::TcpWithLocalhost => ChannelTransport::Tcp(TcpMode::Localhost),
231 PyChannelTransport::TcpWithHostname => ChannelTransport::Tcp(TcpMode::Hostname),
232 PyChannelTransport::MetaTlsWithHostname => ChannelTransport::MetaTls(TlsMode::Hostname),
233 PyChannelTransport::MetaTlsWithIpV6 => ChannelTransport::MetaTls(TlsMode::IpV6),
234 PyChannelTransport::Tls => ChannelTransport::Tls,
235 PyChannelTransport::Local => ChannelTransport::Local,
236 PyChannelTransport::Unix => ChannelTransport::Unix,
237 }
238 }
239}
240
241pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
242 hyperactor_mod.add_class::<PyChannelTransport>()?;
243 hyperactor_mod.add_class::<PyBindSpec>()?;
244 hyperactor_mod.add_class::<PyChannelAddr>()?;
245 Ok(())
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// The `any` address produced for every transport must parse back successfully.
    #[test]
    #[cfg_attr(not(fbcode_build), ignore)]
    fn test_channel_any_and_parse() -> PyResult<()> {
        let transports = [
            PyChannelTransport::TcpWithLocalhost,
            PyChannelTransport::TcpWithHostname,
            PyChannelTransport::Unix,
            PyChannelTransport::MetaTlsWithHostname,
            PyChannelTransport::MetaTlsWithIpV6,
            PyChannelTransport::Tls,
            PyChannelTransport::Local,
        ];
        for transport in transports {
            let address = PyChannelAddr::any(transport)?;
            PyChannelAddr::parse(&address)?;
        }
        Ok(())
    }

    /// Port extraction: networked addresses report their port, Unix/local report 0.
    #[test]
    fn test_channel_addr_get_port() -> PyResult<()> {
        let cases: [(&str, u16); 4] = [
            ("tcp![::]:26600", 26600),
            ("metatls!devgpu1.pci.facebook.com:26600", 26600),
            ("local!12345", 0),
            ("unix!@1a2b3c", 0),
        ];
        for (addr, expected_port) in cases {
            assert_eq!(
                PyChannelAddr::parse(addr)?.get_port()?,
                expected_port,
                "address `{}`",
                addr
            );
        }
        Ok(())
    }

    /// Transport (and mode) detection from parsed addresses.
    #[test]
    fn test_channel_addr_get_transport() -> PyResult<()> {
        let cases = [
            ("tcp![::1]:26600", PyChannelTransport::TcpWithLocalhost),
            ("tcp![::]:26600", PyChannelTransport::TcpWithHostname),
            (
                "metatls!devgpu001.pci.facebook.com:26600",
                PyChannelTransport::MetaTlsWithHostname,
            ),
            ("metatls!::1:26600", PyChannelTransport::MetaTlsWithIpV6),
            // An IPv4 literal is expected to map to the hostname mode, not IpV6.
            (
                "metatls!127.0.0.1:26600",
                PyChannelTransport::MetaTlsWithHostname,
            ),
            ("local!12345", PyChannelTransport::Local),
            ("unix!@1a2b3c", PyChannelTransport::Unix),
        ];
        for (addr, expected) in cases {
            assert_eq!(
                PyChannelAddr::parse(addr)?.get_transport()?,
                expected,
                "address `{}`",
                addr
            );
        }
        Ok(())
    }
}