monarch_messages/
controller.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use derive_more::Display;
10use hyperactor::ActorRef;
11use hyperactor::HandleClient;
12use hyperactor::Handler;
13use hyperactor::Named;
14use hyperactor::RefClient;
15use hyperactor::data::Serialized;
16use hyperactor::reference::ActorId;
17use pyo3::FromPyObject;
18use pyo3::IntoPyObject;
19use pyo3::IntoPyObjectExt;
20use pyo3::types::PyAnyMethods;
21use serde::Deserialize;
22use serde::Serialize;
23
24use crate::client::ClientActor;
25use crate::debugger::DebuggerAction;
26use crate::worker::Ref;
27
/// Used to represent a slice of ranks. This is used to send messages to a subset of workers.
///
/// Holds either a single [`ndslice::Slice`] or a list of them; use [`Ranks::iter_slices`]
/// to visit the contained slices uniformly regardless of the variant.
// NOTE(review): serde's derived enum encoding includes the variant index (index-based
// formats such as bincode serialize it), so reordering variants likely breaks wire
// compatibility — confirm before changing.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Ranks {
    /// A single slice of ranks.
    Slice(ndslice::Slice),
    /// A list of slices of ranks.
    SliceList(Vec<ndslice::Slice>),
}
34
35impl Ranks {
36    pub fn iter_slices<'a>(&'a self) -> std::slice::Iter<'a, ndslice::Slice> {
37        match self {
38            Self::Slice(slice) => std::slice::from_ref(slice).iter(),
39            Self::SliceList(slices) => slices.iter(),
40        }
41    }
42}
43
44/// The sequence number of the operation (message sent to a set of workers). Sequence numbers are
45/// generated by the client, and are strictly increasing.
46#[derive(
47    Debug,
48    Serialize,
49    Deserialize,
50    Clone,
51    PartialEq,
52    Eq,
53    PartialOrd,
54    Hash,
55    Ord,
56    Copy,
57    Named
58)]
59pub struct Seq(u64);
60
61impl Seq {
62    /// Returns the next logical sequence number.
63    #[inline]
64    pub fn next(&self) -> Self {
65        Self(self.0 + 1)
66    }
67
68    pub fn iter_between(start: Self, end: Self) -> impl Iterator<Item = Self> {
69        (start.0..end.0).map(Self)
70    }
71}
72
73impl Default for Seq {
74    #[inline]
75    fn default() -> Self {
76        Self(0)
77    }
78}
79
80impl Display for Seq {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "s{}", self.0)
83    }
84}
85
86impl From<u64> for Seq {
87    #[inline]
88    fn from(value: u64) -> Self {
89        Self(value)
90    }
91}
92
93impl From<Seq> for u64 {
94    #[inline]
95    fn from(value: Seq) -> u64 {
96        value.0
97    }
98}
99
100impl From<&Seq> for u64 {
101    #[inline]
102    fn from(value: &Seq) -> u64 {
103        value.0
104    }
105}
106
107impl FromPyObject<'_> for Seq {
108    fn extract_bound(ob: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
109        Ok(Self(ob.extract::<u64>()?))
110    }
111}
112
113impl<'py> IntoPyObject<'py> for Seq {
114    type Target = pyo3::PyAny;
115    type Output = pyo3::Bound<'py, Self::Target>;
116    type Error = pyo3::PyErr;
117
118    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
119        self.0.into_bound_py_any(py)
120    }
121}
122
/// Worker operation errors.
///
/// Pairs a textual error description with the id of the worker actor that reported it.
/// `Display` (derived via `thiserror`) renders it as
/// `worker <worker_actor_id> error: <backtrace>`.
// TODO: Make other exceptions like CallFunctionError, etc. serializable and
// send them back to the client through WorkerError.
// NOTE(review): derived serde serializes fields in declaration order, and positional
// formats (e.g. bincode) depend on it — keep the field order stable.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
#[error("worker {worker_actor_id} error: {backtrace}")]
pub struct WorkerError {
    /// The message and/or stack trace of the error.
    ///
    /// Despite the field name, this is a plain `String`, not a `std::backtrace::Backtrace`.
    pub backtrace: String,

    /// Actor id of the worker that had the error.
    // TODO: arguably at this level we only care about the rank
    pub worker_actor_id: ActorId,
}
136
/// Device operation failures.
///
/// `Display` (derived via `thiserror`) renders it as `device <actor_id> error: <backtrace>`.
/// The `address` field is carried along but is not part of that message.
// NOTE(review): derived serde serializes fields in declaration order, and positional
// formats (e.g. bincode) depend on it — keep the field order stable.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
#[error("device {actor_id} error: {backtrace}")]
pub struct DeviceFailure {
    /// The message and/or stack trace of the error.
    pub backtrace: String,

    /// Address of the device that had the error.
    // NOTE(review): the address format is not visible here — confirm with the producer.
    pub address: String,

    /// Actor id of the worker that had the error.
    // TODO: arguably at this level we only care about the rank
    pub actor_id: ActorId,
}
151
/// Controller messages. These define the contract that the controller has with the client
/// and workers.
// NOTE(review): serde's derived enum encoding includes the variant index, so reordering
// variants likely changes the wire format between client, controller, and workers.
#[derive(Handler, HandleClient, RefClient, Serialize, Deserialize, Debug, Named)]
pub enum ControllerMessage {
    /// Attach a client to the controller. This is used to send messages to the controller
    /// and allow the controller to send messages back to the client.
    Attach {
        /// The client actor that is being attached.
        client_actor: ActorRef<ClientActor>,

        /// The response to indicate if the client was successfully attached.
        #[reply]
        response_port: hyperactor::OncePortRef<()>,
    },

    /// Notify the controller of the dependencies for a worker operation with the same seq.
    /// It is the responsibility of the caller to ensure the seq is unique and strictly
    /// increasing and matches the right message. This will be used by the controller for
    /// history / data dependency tracking.
    /// TODO: Support mutates here as well for proper dep management
    Node {
        /// Sequence number of the worker operation whose dependencies are described.
        seq: Seq,
        /// The set of references defined (or re-defined) by the operation.
        /// These are the operation's outputs.
        defs: Vec<Ref>,
        /// The set of references used by the operation. These are the operation's inputs.
        uses: Vec<Ref>,
    },

    /// Mark references as dropped by the client: the client will never use these
    /// references again. Using a reference after it has been dropped results in
    /// undefined behavior.
    DropRefs {
        /// The references being dropped.
        refs: Vec<Ref>,
    },

    /// Send a message to the workers mapping to the ranks provided in the
    /// given slice. The message is serialized bytes with the underlying datatype being
    /// [`crate::worker::WorkerMessage`] and serialization has been done in a hyperactor
    /// compatible way i.e. using [`bincode`]. These bytes will be forwarded to
    /// the workers as is. This helps provide isolation between the controller and the
    /// workers and avoids the need to pay the cost to deserialize pytrees in the controller.
    Send {
        /// The ranks of the workers that should receive the message.
        ranks: Ranks,
        /// The already-serialized worker message, forwarded without inspection.
        message: Serialized,
    },

    /// Response to a [`crate::worker::WorkerMessage::CallFunction`] message if
    /// the function errored.
    RemoteFunctionFailed {
        /// Sequence number associated with the failed operation.
        seq: Seq,
        /// The error reported by the worker.
        error: WorkerError,
    },

    /// Response to a [`crate::worker::WorkerMessage::RequestStatus`] message. The payload will
    /// be set to the seq provided in the original message + 1.
    // TODO: T212094401 take an ActorRef
    Status {
        /// The seq from the original request plus one.
        seq: Seq,
        /// Actor id of the worker reporting its status.
        worker_actor_id: ActorId,
        // NOTE(review): the meaning of this flag is not documented here — confirm with
        // the sender before relying on it.
        controller: bool,
    },

    /// Response to a [`crate::worker::WorkerMessage::SendValue`] message, containing the
    /// requested value. The value is serialized as a `Serialized` and deserialization
    /// is the responsibility of the caller. It should be deserialized as
    /// [`monarch_types::PyTree<RValue>`] using the [`Serialized::deserialized`] method.
    FetchResult {
        /// Sequence number associated with the fetch.
        seq: Seq,
        /// `Ok` with the serialized value, or `Err` with the worker's error.
        value: Result<Serialized, WorkerError>,
    },

    /// This is used in unit tests to get the first incomplete seq for each rank as captured
    /// by the controller.
    GetFirstIncompleteSeqsUnitTestsOnly {
        /// Port on which the per-rank first incomplete seqs are returned.
        #[reply]
        response_port: hyperactor::OncePortRef<Vec<Seq>>,
    },

    /// The message to schedule next supervision check task on the controller.
    CheckSupervision {},

    /// Debugger message sent from a debugger to be forwarded back to the client.
    DebuggerMessage {
        /// Actor id of the debugger that sent the message.
        debugger_actor_id: ActorId,
        /// The debugger action to forward to the client.
        action: DebuggerAction,
    },
}
239
// Declares `ControllerActor` through `hyperactor::alias!`, associating it with the
// `ControllerMessage` message type. NOTE(review): see the macro definition in hyperactor
// for exactly which items this invocation generates.
hyperactor::alias!(ControllerActor, ControllerMessage);