use std::sync::OnceLock;

use anyhow::Result;
use async_trait::async_trait;
use cxx::CxxVector;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use tokio::runtime::Handle;

use crate::Tensor;
use crate::bridge::ffi;
pub use crate::bridge::ffi::AllToAllOptions;
pub use crate::bridge::ffi::AllreduceOptions;
pub use crate::bridge::ffi::BarrierOptions;
pub use crate::bridge::ffi::BroadcastOptions;
pub use crate::bridge::ffi::GatherOptions;
pub use crate::bridge::ffi::ReduceOp;
pub use crate::bridge::ffi::ReduceOptions;
pub use crate::bridge::ffi::ReduceScatterOptions;
pub use crate::bridge::ffi::ScatterOptions;

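// One-shot guards: backend registration and default-process-group
// initialization each run at most once per process.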
static REGISTER: OnceLock<()> = OnceLock::new();
static INIT: OnceLock<()> = OnceLock::new();

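/// An in-flight collective operation; the async counterpart of
/// `torch.distributed.Work`.
///
/// Illustrative sketch of a trivial, already-complete implementation (the
/// `DoneWork` type is hypothetical, not part of this crate):
///
/// ```ignore
/// struct DoneWork;
///
/// #[async_trait]
/// impl Work for DoneWork {
///     type Error = anyhow::Error;
///
///     async fn wait(&self) -> Result<(), Self::Error> {
///         Ok(())
///     }
///
///     async fn is_completed(&self) -> Result<bool, Self::Error> {
///         Ok(true)
///     }
/// }
/// ```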
#[async_trait]
pub trait Work: Sync + Send + 'static {
    type Error;
    /// Wait for the operation to complete.
    async fn wait(&self) -> Result<(), Self::Error>;
    /// Report whether the operation has completed, without waiting for it.
    async fn is_completed(&self) -> Result<bool, Self::Error>;
}

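/// Type-erased `Work` passed across the cxx bridge.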
pub(crate) struct BoxedWork(pub Box<dyn Work<Error = anyhow::Error>>);

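// Synchronous shims for C++: each call blocks the current thread on the
// async method via the ambient tokio runtime. `Handle::current()` panics if
// no runtime context has been entered, and `block_on` panics when invoked
// from inside an async execution context, so these must be called from a
// plain (non-async) thread running under a runtime.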
impl BoxedWork {
    pub fn wait(&self) -> Result<(), anyhow::Error> {
        Handle::current().block_on(self.0.wait())
    }

    pub fn is_completed(&self) -> Result<bool, anyhow::Error> {
        Handle::current().block_on(self.0.is_completed())
    }
}

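/// A process-group backend: the set of collectives that `torch.distributed`
/// can dispatch to a custom backend. The signatures mirror the C++
/// `ProcessGroup` hooks (hence the `_allgather_base` / `_reduce_scatter_base`
/// names), and tensor lists arrive as `CxxVector<Tensor>` because the calls
/// originate on the C++ side. Each method returns a boxed `Work` handle for
/// the in-flight operation.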
#[async_trait]
pub trait Backend: Sync + Send + 'static {
    type Error;
    async fn allreduce(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: AllreduceOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn allgather(
        &self,
        output: &CxxVector<Tensor>,
        input: &Tensor,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn _allgather_base(
        &self,
        output: &Tensor,
        input: &Tensor,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn barrier(
        &self,
        opts: BarrierOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn reduce(
        &self,
        input: &Tensor,
        opts: ReduceOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn _reduce_scatter_base(
        &self,
        output: &Tensor,
        input: &Tensor,
        opts: ReduceScatterOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn send(
        &self,
        tensors: &CxxVector<Tensor>,
        dst_rank: i32,
        tag: i32,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn recv(
        &self,
        tensors: &CxxVector<Tensor>,
        src_rank: i32,
        tag: i32,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn gather(
        &self,
        outputs: &CxxVector<Tensor>,
        input: &Tensor,
        opts: GatherOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn scatter(
        &self,
        output: &Tensor,
        inputs: &CxxVector<Tensor>,
        opts: ScatterOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn broadcast(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: BroadcastOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn alltoall_base(
        &self,
        output_buffer: &Tensor,
        input_buffer: &Tensor,
        opts: AllToAllOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn alltoall(
        &self,
        output_tensors: &CxxVector<Tensor>,
        input_tensors: &CxxVector<Tensor>,
        opts: AllToAllOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
}

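/// Type-erased `Backend` passed across the cxx bridge.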
pub(crate) struct BoxedBackend(pub Box<dyn Backend<Error = anyhow::Error>>);

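// One blocking shim per collective: C++ calls these synchronously, so each
// shim blocks the calling thread on the async trait method and boxes the
// resulting `Work` for the bridge.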
impl BoxedBackend {
    pub fn allreduce(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: AllreduceOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.allreduce(tensors, opts))?,
        )))
    }

    pub fn allgather(
        &self,
        output: &CxxVector<Tensor>,
        input: &Tensor,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.allgather(output, input))?,
        )))
    }

    pub fn _allgather_base(
        &self,
        output: &Tensor,
        input: &Tensor,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0._allgather_base(output, input))?,
        )))
    }

    pub fn barrier(&self, opts: BarrierOptions) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.barrier(opts))?,
        )))
    }

    pub fn reduce(
        &self,
        input: &Tensor,
        opts: ReduceOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.reduce(input, opts))?,
        )))
    }

    pub fn _reduce_scatter_base(
        &self,
        output: &Tensor,
        input: &Tensor,
        opts: ReduceScatterOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0._reduce_scatter_base(output, input, opts),
        )?)))
    }

    pub fn send(
        &self,
        tensors: &CxxVector<Tensor>,
        dst_rank: i32,
        tag: i32,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.send(tensors, dst_rank, tag))?,
        )))
    }

    pub fn recv(
        &self,
        tensors: &CxxVector<Tensor>,
        src_rank: i32,
        tag: i32,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.recv(tensors, src_rank, tag))?,
        )))
    }

    pub fn gather(
        &self,
        outputs: &CxxVector<Tensor>,
        input: &Tensor,
        opts: GatherOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.gather(outputs, input, opts))?,
        )))
    }

    pub fn scatter(
        &self,
        output: &Tensor,
        inputs: &CxxVector<Tensor>,
        opts: ScatterOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.scatter(output, inputs, opts))?,
        )))
    }

    pub fn broadcast(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: BroadcastOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.broadcast(tensors, opts))?,
        )))
    }

    pub fn alltoall_base(
        &self,
        output_buffer: &Tensor,
        input_buffer: &Tensor,
        opts: AllToAllOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0.alltoall_base(output_buffer, input_buffer, opts),
        )?)))
    }

    pub fn alltoall(
        &self,
        output_tensors: &CxxVector<Tensor>,
        input_tensors: &CxxVector<Tensor>,
        opts: AllToAllOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0.alltoall(output_tensors, input_tensors, opts),
        )?)))
    }
}

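/// Registers the "monarch" backend with `torch.distributed.Backend`,
/// advertising CUDA devices and the extended backend API. Runs at most once
/// per process via `REGISTER` (see `new_group`).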
fn register(py: Python<'_>) -> PyResult<()> {
    let module = py.import("torch.distributed")?;

    let backend = module.getattr("Backend")?;
    let register_backend = backend.getattr("register_backend")?;

    let create_backend = ffi::create_monarch_backend().into_pyobject(py)?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("devices", vec!["cuda"])?;
    kwargs.set_item("extended_api", true)?;

    register_backend
        .call(("monarch", create_backend), Some(&kwargs))
        .inspect_err(|e| tracing::error!("failed to register backend: {}", e))?;

    Ok(())
}

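/// Registers a no-op "null" backend and initializes the default process group
/// with it, backed by a `FileStore` on `/dev/null`, so `torch.distributed`
/// is initialized without any real rendezvous or communication.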
fn init_process_group(py: Python<'_>, world_size: usize, rank: usize) -> PyResult<()> {
    let torchd = py.import("torch.distributed")?;

    let backend = torchd.getattr("Backend")?;
    let register_backend = backend.getattr("register_backend")?;

    let create_backend = ffi::create_null_backend().into_pyobject(py)?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("extended_api", true)?;

    register_backend
        .call(("null", create_backend), Some(&kwargs))
        .inspect_err(|e| tracing::error!("failed to register backend: {}", e))?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("backend", "null")?;
    kwargs.set_item("rank", rank)?;
    kwargs.set_item("store", torchd.call_method1("FileStore", ("/dev/null",))?)?;
    kwargs.set_item("world_size", world_size)?;

    torchd.call_method("init_process_group", (), Some(&kwargs))?;

    Ok(())
}

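/// Idempotent wrapper around `init_process_group`, guarded by `INIT`. The GIL
/// is released before taking the lock so that a competing thread already
/// inside the initializer can reacquire it and finish.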
pub fn ensure_init_process_group(py: Python<'_>, world_size: usize, rank: usize) -> PyResult<()> {
    py.allow_threads(move || {
        INIT.get_or_try_init(move || {
            Python::with_gil(|py| init_process_group(py, world_size, rank))
        })
    })?;
    Ok(())
}

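/// Creates a `torch.distributed` process group whose collectives are served
/// by `backend`.
///
/// The backend is boxed and its raw pointer is passed to the C++ glue through
/// the `pg_options` kwarg as a `u64`; the receiving side is expected to take
/// ownership of that allocation, otherwise it is leaked.
///
/// Illustrative sketch (assumes some hypothetical `MyBackend` implementing
/// `Backend<Error = anyhow::Error>`):
///
/// ```ignore
/// Python::with_gil(|py| {
///     let pg = new_group(py, vec![0, 1], MyBackend::new())?;
///     pg.call_method0("size")?;
///     PyResult::Ok(())
/// })?;
/// ```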
pub fn new_group<'py, B: Backend<Error = anyhow::Error>>(
    py: Python<'py>,
    ranks: Vec<usize>,
    backend: B,
) -> PyResult<Bound<'py, PyAny>> {
    py.allow_threads(|| REGISTER.get_or_try_init(|| Python::with_gil(register)))?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("backend", "monarch")?;
    kwargs.set_item("ranks", ranks)?;
    kwargs.set_item(
        "pg_options",
        Box::into_raw(Box::new(BoxedBackend(Box::new(backend)))) as u64,
    )?;

    let torchd = py.import("torch.distributed")?;

    torchd.call_method("new_group", (), Some(&kwargs))
}