torch_sys/backend.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::sync::OnceLock;

use anyhow::Result;
use async_trait::async_trait;
use cxx::CxxVector;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use tokio::runtime::Handle;

use crate::Tensor;
use crate::bridge::ffi;
pub use crate::bridge::ffi::AllToAllOptions;
pub use crate::bridge::ffi::AllreduceOptions;
pub use crate::bridge::ffi::BarrierOptions;
pub use crate::bridge::ffi::BroadcastOptions;
pub use crate::bridge::ffi::GatherOptions;
pub use crate::bridge::ffi::ReduceOp;
pub use crate::bridge::ffi::ReduceOptions;
pub use crate::bridge::ffi::ReduceScatterOptions;
pub use crate::bridge::ffi::ScatterOptions;

static REGISTER: OnceLock<()> = OnceLock::new();
static INIT: OnceLock<()> = OnceLock::new();

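/// A handle to an in-flight collective operation, analogous to c10d's `Work`
/// object: it can be awaited (`wait`) or polled for completion (`is_completed`).
///
/// A minimal sketch of an implementation, assuming a collective that finishes
/// synchronously (the `CompletedWork` type is hypothetical, for illustration
/// only):
///
/// ```ignore
/// struct CompletedWork;
///
/// #[async_trait]
/// impl Work for CompletedWork {
///     type Error = anyhow::Error;
///
///     async fn wait(&self) -> Result<(), Self::Error> {
///         // Nothing to block on: the underlying collective already finished.
///         Ok(())
///     }
///
///     async fn is_completed(&self) -> Result<bool, Self::Error> {
///         Ok(true)
///     }
/// }
/// ```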
#[async_trait]
pub trait Work: Sync + Send + 'static {
    type Error;
    async fn wait(&self) -> Result<(), Self::Error>;
    async fn is_completed(&self) -> Result<bool, Self::Error>;
}

/// A wrapper around the `Work` trait to make it usable from the C++ bridge.
pub(crate) struct BoxedWork(pub Box<dyn Work<Error = anyhow::Error>>);

// TODO(agallagher): Support the `Work` return type for async -- currently we
// do everything sync and bridge into async code.
impl BoxedWork {
    pub fn wait(&self) -> Result<(), anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Handle::current().block_on(self.0.wait())
    }
    pub fn is_completed(&self) -> Result<bool, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Handle::current().block_on(self.0.is_completed())
    }
}

// A Rust translation of the `Backend` base class in torch, used to create custom
// process group backends:
// https://github.com/pytorch/pytorch/blob/6178be822dc3fb307e950c337876f05dd63582b2/torch/csrc/distributed/c10d/Backend.hpp#L20
#[async_trait]
pub trait Backend: Sync + Send + 'static {
    type Error;
    async fn allreduce(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: AllreduceOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn allgather(
        &self,
        output: &CxxVector<Tensor>,
        input: &Tensor,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn _allgather_base(
        &self,
        output: &Tensor,
        input: &Tensor,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn barrier(
        &self,
        opts: BarrierOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn reduce(
        &self,
        input: &Tensor,
        opts: ReduceOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn _reduce_scatter_base(
        &self,
        output: &Tensor,
        input: &Tensor,
        opts: ReduceScatterOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn send(
        &self,
        tensors: &CxxVector<Tensor>,
        dst_rank: i32,
        tag: i32,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn recv(
        &self,
        tensors: &CxxVector<Tensor>,
        src_rank: i32,
        tag: i32,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn gather(
        &self,
        outputs: &CxxVector<Tensor>,
        input: &Tensor,
        opts: GatherOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn scatter(
        &self,
        output: &Tensor,
        inputs: &CxxVector<Tensor>,
        opts: ScatterOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn broadcast(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: BroadcastOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn alltoall_base(
        &self,
        output_buffer: &Tensor,
        input_buffer: &Tensor,
        opts: AllToAllOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
    async fn alltoall(
        &self,
        output_tensors: &CxxVector<Tensor>,
        input_tensors: &CxxVector<Tensor>,
        opts: AllToAllOptions,
    ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error>;
}
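
// A hedged sketch of what a custom backend implementation might look like;
// `MyBackend` and its fields are hypothetical, for illustration only. Each
// collective returns a boxed `Work` handle, and the remaining methods follow
// the same shape as `allreduce`:
//
//     struct MyBackend { rank: usize, world_size: usize }
//
//     #[async_trait]
//     impl Backend for MyBackend {
//         type Error = anyhow::Error;
//
//         async fn allreduce(
//             &self,
//             tensors: &CxxVector<Tensor>,
//             opts: AllreduceOptions,
//         ) -> Result<Box<dyn Work<Error = Self::Error>>, Self::Error> {
//             // Run (or enqueue) the reduction, then hand back a completion handle.
//             todo!()
//         }
//
//         // ...the remaining collectives are implemented analogously.
//     }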

/// A wrapper around the `Backend` trait to make it usable from the C++ bridge.
pub(crate) struct BoxedBackend(pub Box<dyn Backend<Error = anyhow::Error>>);

impl BoxedBackend {
    pub fn allreduce(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: AllreduceOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.allreduce(tensors, opts))?,
        )))
    }

    pub fn allgather(
        &self,
        output: &CxxVector<Tensor>,
        input: &Tensor,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.allgather(output, input))?,
        )))
    }

    pub fn _allgather_base(
        &self,
        output: &Tensor,
        input: &Tensor,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0._allgather_base(output, input))?,
        )))
    }

    pub fn barrier(&self, opts: BarrierOptions) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.barrier(opts))?,
        )))
    }

    pub fn reduce(
        &self,
        input: &Tensor,
        opts: ReduceOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.reduce(input, opts))?,
        )))
    }

    pub fn _reduce_scatter_base(
        &self,
        output: &Tensor,
        input: &Tensor,
        opts: ReduceScatterOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0._reduce_scatter_base(output, input, opts),
        )?)))
    }

    pub fn send(
        &self,
        tensors: &CxxVector<Tensor>,
        dst_rank: i32,
        tag: i32,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.send(tensors, dst_rank, tag))?,
        )))
    }

    pub fn recv(
        &self,
        tensors: &CxxVector<Tensor>,
        src_rank: i32,
        tag: i32,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.recv(tensors, src_rank, tag))?,
        )))
    }

    pub fn gather(
        &self,
        outputs: &CxxVector<Tensor>,
        input: &Tensor,
        opts: GatherOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.gather(outputs, input, opts))?,
        )))
    }

    pub fn scatter(
        &self,
        output: &Tensor,
        inputs: &CxxVector<Tensor>,
        opts: ScatterOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.scatter(output, inputs, opts))?,
        )))
    }

    pub fn broadcast(
        &self,
        tensors: &CxxVector<Tensor>,
        opts: BroadcastOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(
            Handle::current().block_on(self.0.broadcast(tensors, opts))?,
        )))
    }

    pub fn alltoall_base(
        &self,
        output_buffer: &Tensor,
        input_buffer: &Tensor,
        opts: AllToAllOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0.alltoall_base(output_buffer, input_buffer, opts),
        )?)))
    }

    pub fn alltoall(
        &self,
        output_tensors: &CxxVector<Tensor>,
        input_tensors: &CxxVector<Tensor>,
        opts: AllToAllOptions,
    ) -> Result<Box<BoxedWork>, anyhow::Error> {
        // Re-enter the parent's runtime to run async code.
        Ok(Box::new(BoxedWork(Handle::current().block_on(
            self.0.alltoall(output_tensors, input_tensors, opts),
        )?)))
    }
}

fn register(py: Python<'_>) -> PyResult<()> {
    // Import the torch.distributed module.
    let module = py.import("torch.distributed")?;

    // Get the register_backend attribute from Backend.
    let backend = module.getattr("Backend")?;
    let register_backend = backend.getattr("register_backend")?;

    // Create a Python callable from our Rust function.
    let create_backend = ffi::create_monarch_backend().into_pyobject(py)?;

    // We use the extended API so that callers can pass in the inner, pre-
    // initialized backend via `pg_options`.
    let kwargs = PyDict::new(py);
    kwargs.set_item("devices", vec!["cuda"])?;
    kwargs.set_item("extended_api", true)?;

    // Register the backend.
    register_backend
        .call(("monarch", create_backend), Some(&kwargs))
        .inspect_err(|e| tracing::error!("failed to init backend: {}", e))?;

    Ok(())
}

fn init_process_group(py: Python<'_>, world_size: usize, rank: usize) -> PyResult<()> {
    let torchd = py.import("torch.distributed")?;

    // Get the register_backend attribute from Backend.
    let backend = torchd.getattr("Backend")?;
    let register_backend = backend.getattr("register_backend")?;

    // Create a Python callable from our Rust function.
    let create_backend = ffi::create_null_backend().into_pyobject(py)?;

    // We use the extended API so that callers can pass in the inner, pre-
    // initialized backend via `pg_options`.
    let kwargs = PyDict::new(py);
    kwargs.set_item("extended_api", true)?;

    // Register the backend.
    register_backend
        .call(("null", create_backend), Some(&kwargs))
        .inspect_err(|e| tracing::error!("failed to init backend: {}", e))?;

    // Init the process group.
    let kwargs = PyDict::new(py);
    // Use a special noop backend that errors out if it's actually used.
    kwargs.set_item("backend", "null")?;
    kwargs.set_item("rank", rank)?;
    // Since the communicator we give it is pre-initialized, we don't actually
    // end up using the store, but `init_process_group` requires that one be
    // passed in.
    kwargs.set_item("store", torchd.call_method1("FileStore", ("/dev/null",))?)?;
    kwargs.set_item("world_size", world_size)?;

    torchd.call_method("init_process_group", (), Some(&kwargs))?;

    Ok(())
}

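/// Initializes the global `torch.distributed` process group with the no-op
/// "null" backend; repeated calls are harmless thanks to the `INIT` once-lock.
///
/// A hedged usage sketch:
///
/// ```ignore
/// Python::with_gil(|py| ensure_init_process_group(py, /* world_size */ 2, /* rank */ 0))?;
/// ```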
pub fn ensure_init_process_group(py: Python<'_>, world_size: usize, rank: usize) -> PyResult<()> {
    py.allow_threads(move || {
        INIT.get_or_try_init(move || {
            Python::with_gil(|py| init_process_group(py, world_size, rank))
        })
    })?;
    Ok(())
}

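/// Creates a new `torch.distributed` process group that dispatches its
/// collectives to the given Rust `Backend`, registering the "monarch" backend
/// with torch on first use.
///
/// A hedged usage sketch (`MyBackend` is hypothetical, for illustration only):
///
/// ```ignore
/// Python::with_gil(|py| -> PyResult<()> {
///     let pg = new_group(py, vec![0, 1], MyBackend::new())?;
///     // ...hand `pg` to code that expects a torch.distributed process group...
///     Ok(())
/// })?;
/// ```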
pub fn new_group<'py, B: Backend<Error = anyhow::Error>>(
    py: Python<'py>,
    ranks: Vec<usize>,
    backend: B,
) -> PyResult<Bound<'py, PyAny>> {
    // Make sure we've registered the monarch backend.
    py.allow_threads(|| REGISTER.get_or_try_init(|| Python::with_gil(register)))?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("backend", "monarch")?;
    kwargs.set_item("ranks", ranks)?;
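    // Hand the boxed backend to the C++ bridge as a raw pointer encoded in a
    // u64 via `pg_options`; the bridge side is expected to take ownership and
    // reconstitute (and eventually free) the box.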
    kwargs.set_item(
        "pg_options",
        Box::into_raw(Box::new(BoxedBackend(Box::new(backend)))) as u64,
    )?;

    let torchd = py.import("torch.distributed")?;

    torchd.call_method("new_group", (), Some(&kwargs))
}