monarch_messages/
controller.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use derive_more::Display;
10use hyperactor::HandleClient;
11use hyperactor::Handler;
12use hyperactor::RefClient;
13use hyperactor::reference;
14use pyo3::FromPyObject;
15use pyo3::IntoPyObject;
16use pyo3::IntoPyObjectExt;
17use pyo3::types::PyAnyMethods;
18use serde::Deserialize;
19use serde::Serialize;
20use typeuri::Named;
21
22use crate::client::ClientActor;
23use crate::debugger::DebuggerAction;
24use crate::worker::Ref;
25
26/// Used to represent a slice of ranks. This is used to send messages to a subset of workers.
27#[derive(Serialize, Deserialize, Debug, Clone)]
28pub enum Ranks {
29    Slice(ndslice::Slice),
30    SliceList(Vec<ndslice::Slice>),
31}
32
33impl Ranks {
34    pub fn iter_slices<'a>(&'a self) -> std::slice::Iter<'a, ndslice::Slice> {
35        match self {
36            Self::Slice(slice) => std::slice::from_ref(slice).iter(),
37            Self::SliceList(slices) => slices.iter(),
38        }
39    }
40}
41
42/// The sequence number of the operation (message sent to a set of workers). Sequence numbers are
43/// generated by the client, and are strictly increasing.
44#[derive(
45    Debug,
46    Serialize,
47    Deserialize,
48    Clone,
49    PartialEq,
50    Eq,
51    PartialOrd,
52    Hash,
53    Ord,
54    Copy,
55    Named
56)]
57pub struct Seq(u64);
58
59impl Seq {
60    /// Returns the next logical sequence number.
61    #[inline]
62    pub fn next(&self) -> Self {
63        Self(self.0 + 1)
64    }
65
66    pub fn iter_between(start: Self, end: Self) -> impl Iterator<Item = Self> {
67        (start.0..end.0).map(Self)
68    }
69}
70
71impl Default for Seq {
72    #[inline]
73    fn default() -> Self {
74        Self(0)
75    }
76}
77
78impl Display for Seq {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        write!(f, "s{}", self.0)
81    }
82}
83
84impl From<u64> for Seq {
85    #[inline]
86    fn from(value: u64) -> Self {
87        Self(value)
88    }
89}
90
91impl From<Seq> for u64 {
92    #[inline]
93    fn from(value: Seq) -> u64 {
94        value.0
95    }
96}
97
98impl From<&Seq> for u64 {
99    #[inline]
100    fn from(value: &Seq) -> u64 {
101        value.0
102    }
103}
104
105impl FromPyObject<'_> for Seq {
106    fn extract_bound(ob: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
107        Ok(Self(ob.extract::<u64>()?))
108    }
109}
110
111impl<'py> IntoPyObject<'py> for Seq {
112    type Target = pyo3::PyAny;
113    type Output = pyo3::Bound<'py, Self::Target>;
114    type Error = pyo3::PyErr;
115
116    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
117        self.0.into_bound_py_any(py)
118    }
119}
120
121/// Worker operation errors.
122// TODO: Make other exceptions like CallFunctionError, etc. serializable and
123// send them back to the client through WorkerError.
124#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
125#[error("worker {worker_actor_id} error: {backtrace}")]
126pub struct WorkerError {
127    /// The message and/or stack trace of the error.
128    pub backtrace: String,
129
130    /// Actor id of the worker that had the error.
131    // TODO: arguably at this level we only care about the rank
132    pub worker_actor_id: reference::ActorId,
133}
134
135/// Device operation failures.
136#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
137#[error("device {actor_id} error: {backtrace}")]
138pub struct DeviceFailure {
139    /// The message and/or stack trace of the error.
140    pub backtrace: String,
141
142    /// Address of the device that had the error.
143    pub address: String,
144
145    /// Actor id of the worker that had the error.
146    // TODO: arguably at this level we only care about the rank
147    pub actor_id: reference::ActorId,
148}
149
150/// Controller messages. These define the contract that the controller has with the client
151/// and workers.
152#[derive(Handler, HandleClient, RefClient, Serialize, Deserialize, Debug, Named)]
153pub enum ControllerMessage {
154    /// Attach a client to the controller. This is used to send messages to the controller
155    /// and allow the controller to send messages back to the client.
156    Attach {
157        /// The client actor that is being attached.
158        client_actor: reference::ActorRef<ClientActor>,
159
160        /// The response to indicate if the client was successfully attached.
161        #[reply]
162        response_port: reference::OncePortRef<()>,
163    },
164
165    /// Notify the controller of the dependencies for a worker operation with the same seq.
166    /// It is the responsibility of the caller to ensure the seq is unique and strictly
167    /// increasing and matches the right message. This will be used by the controller for
168    /// history / data dependency tracking.
169    /// TODO: Support mutates here as well for proper dep management
170    Node {
171        seq: Seq,
172        /// The set of references defined (or re-defined) by the operation.
173        /// These are the operation's outputs.
174        defs: Vec<Ref>,
175        /// The set of references used by the operation. These are the operation's inputs.
176        uses: Vec<Ref>,
177    },
178
179    // Mark references as being dropped by the client: the client will never
180    // use these references again. Doing so results in undefined behavior.
181    DropRefs {
182        refs: Vec<Ref>,
183    },
184
185    /// Send a message to the workers mapping to the ranks provided in the
186    /// given slice. The message is serialized bytes with the underlying datatype being
187    /// [`crate::worker::WorkerMessage`] and serialization has been done in a hyperactor
188    /// compatible way i.e. using [`bincode`]. These bytes will be forwarded to
189    /// the workers as is. This helps provide isolation between the controller and the
190    /// workers and avoids the need to pay the cost to deserialize pytrees in the controller.
191    Send {
192        ranks: Ranks,
193        message: wirevalue::Any,
194    },
195
196    /// Response to a [`crate::worker::WorkerMessage::CallFunction`] message if
197    /// the function errored.
198    RemoteFunctionFailed {
199        seq: Seq,
200        error: WorkerError,
201    },
202
203    /// Response to a [`crate::worker::WorkerMessage::RequestStatus`] message. The payload will
204    /// be set to the seq provided in the original message + 1.
205    // TODO: T212094401 take a ActorRef
206    Status {
207        seq: Seq,
208        worker_actor_id: reference::ActorId,
209        controller: bool,
210    },
211
212    /// Response to a [`crate::worker::WorkerMessage::SendValue`] message, containing the
213    /// requested value. The value is serialized as a `Any` and deserialization
214    /// is the responsibility of the caller. It should be deserialized as
215    /// [`monarch_types::PyTree<RValue>`] using the [`wirevalue::Any::deserialized`] method.
216    FetchResult {
217        seq: Seq,
218        value: Result<wirevalue::Any, WorkerError>,
219    },
220
221    /// This is used in unit tests to get the first incomplete seq for each rank as captured
222    /// by the controller.
223    GetFirstIncompleteSeqsUnitTestsOnly {
224        #[reply]
225        response_port: reference::OncePortRef<Vec<Seq>>,
226    },
227
228    /// The message to schedule next supervision check task on the controller.
229    CheckSupervision {},
230
231    /// Debugger message sent from a debugger to be forwarded back to the client.
232    DebuggerMessage {
233        debugger_actor_id: reference::ActorId,
234        action: DebuggerAction,
235    },
236}
// Registers `ControllerMessage` through wirevalue's `register_type!` macro.
// NOTE(review): presumably required so the message can be carried as a
// `wirevalue::Any` payload — confirm against the wirevalue docs.
wirevalue::register_type!(ControllerMessage);

// Declares `ControllerActor` through hyperactor's `behavior!` macro, associated
// with `ControllerMessage`.
// NOTE(review): presumably a behavior type for referring to actors that handle
// `ControllerMessage` — confirm against the hyperactor docs.
hyperactor::behavior!(ControllerActor, ControllerMessage);