// monarch_types/nccl_types.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! GPU-independent types for NCCL-based communication.
//!
//! These types are used in the message protocol between tensor workers and must
//! be available even in CPU-only builds where `nccl-sys` is not compiled.

use std::fmt;
use std::fmt::Write;

use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
use serde::ser::SerializeSeq;

23/// Rust version of `ncclRedOp_t`.
24#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
25pub enum ReduceOp {
26    Sum = 0,
27    Prod = 1,
28    Max = 2,
29    Min = 3,
30    Avg = 4,
31}
32
33/// Wire-compatible representation of `ncclUniqueId`.
34///
35/// This is a 128-byte opaque identifier used to bootstrap NCCL communicators.
36/// The struct layout and serialization format match `ncclUniqueId` from `nccl-sys`
37/// exactly, so that messages are wire-compatible regardless of whether the sender
38/// or receiver was built with GPU support.
39#[repr(C)]
40#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
41pub struct NcclUniqueId {
42    #[serde(
43        serialize_with = "serialize_array",
44        deserialize_with = "deserialize_array"
45    )]
46    pub internal: [::std::os::raw::c_char; 128usize],
47}
48
49fn deserialize_array<'de, D>(deserializer: D) -> Result<[::std::os::raw::c_char; 128], D::Error>
50where
51    D: Deserializer<'de>,
52{
53    let vec: Vec<::std::os::raw::c_char> = Deserialize::deserialize(deserializer)?;
54    vec.try_into().map_err(|v: Vec<::std::os::raw::c_char>| {
55        serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
56    })
57}
58
59fn serialize_array<S>(
60    array: &[::std::os::raw::c_char; 128],
61    serializer: S,
62) -> Result<S::Ok, S::Error>
63where
64    S: Serializer,
65{
66    let mut seq = serializer.serialize_seq(Some(128))?;
67    for element in array {
68        seq.serialize_element(element)?;
69    }
70    seq.end()
71}
72
73/// Binding for `ncclUniqueId`.
74///
75/// Wraps the raw 128-byte NCCL unique identifier. On GPU builds, this can be
76/// created via `nccl-sys`; on CPU builds, it can only be deserialized from a
77/// message sent by a GPU-capable peer.
78#[derive(Clone, Serialize, Deserialize)]
79pub struct UniqueId {
80    inner: NcclUniqueId,
81}
82
83impl fmt::Debug for UniqueId {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("UniqueId")
86            .field(
87                "inner",
88                &format_args!(
89                    "{}",
90                    self.inner
91                        .internal
92                        .iter()
93                        .fold(String::new(), |mut output, b| {
94                            let _ = write!(output, "{:02x}", b);
95                            output
96                        })
97                ),
98            )
99            .finish()
100    }
101}
102
103impl UniqueId {
104    /// Create a `UniqueId` from raw bytes.
105    pub fn from_internal(internal: [::std::os::raw::c_char; 128]) -> Self {
106        Self {
107            inner: NcclUniqueId { internal },
108        }
109    }
110
111    /// Access the raw bytes.
112    pub fn internal(&self) -> &[::std::os::raw::c_char; 128] {
113        &self.inner.internal
114    }
115
116    /// Access the inner `NcclUniqueId`.
117    pub fn as_nccl_unique_id(&self) -> &NcclUniqueId {
118        &self.inner
119    }
120
121    /// Consume and return the inner `NcclUniqueId`.
122    pub fn into_nccl_unique_id(self) -> NcclUniqueId {
123        self.inner
124    }
125}