monarch_messages/
controller.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use derive_more::Display;
10use hyperactor::ActorRef;
11use hyperactor::HandleClient;
12use hyperactor::Handler;
13use hyperactor::RefClient;
14use hyperactor::reference::ActorId;
15use pyo3::FromPyObject;
16use pyo3::IntoPyObject;
17use pyo3::IntoPyObjectExt;
18use pyo3::types::PyAnyMethods;
19use serde::Deserialize;
20use serde::Serialize;
21use typeuri::Named;
22
23use crate::client::ClientActor;
24use crate::debugger::DebuggerAction;
25use crate::worker::Ref;
26
27/// Used to represent a slice of ranks. This is used to send messages to a subset of workers.
28#[derive(Serialize, Deserialize, Debug, Clone)]
29pub enum Ranks {
30    Slice(ndslice::Slice),
31    SliceList(Vec<ndslice::Slice>),
32}
33
34impl Ranks {
35    pub fn iter_slices<'a>(&'a self) -> std::slice::Iter<'a, ndslice::Slice> {
36        match self {
37            Self::Slice(slice) => std::slice::from_ref(slice).iter(),
38            Self::SliceList(slices) => slices.iter(),
39        }
40    }
41}
42
43/// The sequence number of the operation (message sent to a set of workers). Sequence numbers are
44/// generated by the client, and are strictly increasing.
45#[derive(
46    Debug,
47    Serialize,
48    Deserialize,
49    Clone,
50    PartialEq,
51    Eq,
52    PartialOrd,
53    Hash,
54    Ord,
55    Copy,
56    Named
57)]
58pub struct Seq(u64);
59
60impl Seq {
61    /// Returns the next logical sequence number.
62    #[inline]
63    pub fn next(&self) -> Self {
64        Self(self.0 + 1)
65    }
66
67    pub fn iter_between(start: Self, end: Self) -> impl Iterator<Item = Self> {
68        (start.0..end.0).map(Self)
69    }
70}
71
72impl Default for Seq {
73    #[inline]
74    fn default() -> Self {
75        Self(0)
76    }
77}
78
79impl Display for Seq {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "s{}", self.0)
82    }
83}
84
85impl From<u64> for Seq {
86    #[inline]
87    fn from(value: u64) -> Self {
88        Self(value)
89    }
90}
91
92impl From<Seq> for u64 {
93    #[inline]
94    fn from(value: Seq) -> u64 {
95        value.0
96    }
97}
98
99impl From<&Seq> for u64 {
100    #[inline]
101    fn from(value: &Seq) -> u64 {
102        value.0
103    }
104}
105
106impl FromPyObject<'_> for Seq {
107    fn extract_bound(ob: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
108        Ok(Self(ob.extract::<u64>()?))
109    }
110}
111
112impl<'py> IntoPyObject<'py> for Seq {
113    type Target = pyo3::PyAny;
114    type Output = pyo3::Bound<'py, Self::Target>;
115    type Error = pyo3::PyErr;
116
117    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
118        self.0.into_bound_py_any(py)
119    }
120}
121
122/// Worker operation errors.
123// TODO: Make other exceptions like CallFunctionError, etc. serializable and
124// send them back to the client through WorkerError.
125#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
126#[error("worker {worker_actor_id} error: {backtrace}")]
127pub struct WorkerError {
128    /// The message and/or stack trace of the error.
129    pub backtrace: String,
130
131    /// Actor id of the worker that had the error.
132    // TODO: arguably at this level we only care about the rank
133    pub worker_actor_id: ActorId,
134}
135
136/// Device operation failures.
137#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, thiserror::Error)]
138#[error("device {actor_id} error: {backtrace}")]
139pub struct DeviceFailure {
140    /// The message and/or stack trace of the error.
141    pub backtrace: String,
142
143    /// Address of the device that had the error.
144    pub address: String,
145
146    /// Actor id of the worker that had the error.
147    // TODO: arguably at this level we only care about the rank
148    pub actor_id: ActorId,
149}
150
/// Controller messages. These define the contract that the controller has with the client
/// and workers.
///
/// NOTE(review): the `Handler`/`HandleClient`/`RefClient` derives presumably generate handler
/// and client-side helper APIs from the variant and field names, and the serde encoding may
/// depend on variant/field order. Treat renames and reorderings as breaking until confirmed.
#[derive(Handler, HandleClient, RefClient, Serialize, Deserialize, Debug, Named)]
pub enum ControllerMessage {
    /// Attach a client to the controller. This is used to send messages to the controller
    /// and allow the controller to send messages back to the client.
    Attach {
        /// The client actor that is being attached.
        client_actor: ActorRef<ClientActor>,

        /// The response to indicate if the client was successfully attached.
        #[reply]
        response_port: hyperactor::OncePortRef<()>,
    },

    /// Notify the controller of the dependencies for a worker operation with the same seq.
    /// It is the responsibility of the caller to ensure the seq is unique and strictly
    /// increasing and matches the right message. This will be used by the controller for
    /// history / data dependency tracking.
    /// TODO: Support mutates here as well for proper dep management
    Node {
        /// Sequence number of the worker operation whose dependencies are described.
        seq: Seq,
        /// The set of references defined (or re-defined) by the operation.
        /// These are the operation's outputs.
        defs: Vec<Ref>,
        /// The set of references used by the operation. These are the operation's inputs.
        uses: Vec<Ref>,
    },

    /// Mark references as dropped by the client: the client will never use these
    /// references again. Using a reference after it has been dropped results in
    /// undefined behavior.
    DropRefs {
        /// The references being dropped.
        refs: Vec<Ref>,
    },

    /// Send a message to the workers mapping to the ranks provided in the
    /// given slice. The message is serialized bytes with the underlying datatype being
    /// [`crate::worker::WorkerMessage`] and serialization has been done in a hyperactor
    /// compatible way i.e. using [`bincode`]. These bytes will be forwarded to
    /// the workers as is. This helps provide isolation between the controller and the
    /// workers and avoids the need to pay the cost to deserialize pytrees in the controller.
    Send {
        /// The worker ranks the message is forwarded to.
        ranks: Ranks,
        /// The serialized [`crate::worker::WorkerMessage`], forwarded to the workers unchanged.
        message: wirevalue::Any,
    },

    /// Response to a [`crate::worker::WorkerMessage::CallFunction`] message if
    /// the function errored.
    RemoteFunctionFailed {
        /// Sequence number of the failed operation.
        seq: Seq,
        /// The error reported by the worker.
        error: WorkerError,
    },

    /// Response to a [`crate::worker::WorkerMessage::RequestStatus`] message. The payload will
    /// be set to the seq provided in the original message + 1.
    // TODO: T212094401 take an ActorRef
    Status {
        /// The seq from the original `RequestStatus` message, plus one.
        seq: Seq,
        /// Id of the worker actor reporting its status.
        worker_actor_id: ActorId,
        /// NOTE(review): meaning not evident from this file. Presumably it indicates whether
        /// the status request originated from the controller; confirm.
        controller: bool,
    },

    /// Response to a [`crate::worker::WorkerMessage::SendValue`] message, containing the
    /// requested value. The value is serialized as a [`wirevalue::Any`] and deserialization
    /// is the responsibility of the caller. It should be deserialized as
    /// [`monarch_types::PyTree<RValue>`] using the [`wirevalue::Any::deserialized`] method.
    FetchResult {
        /// Sequence number of the operation whose value is being returned.
        seq: Seq,
        /// The serialized value on success, or the worker error that prevented producing it.
        value: Result<wirevalue::Any, WorkerError>,
    },

    /// This is used in unit tests to get the first incomplete seq for each rank as captured
    /// by the controller. Intended for unit tests only.
    GetFirstIncompleteSeqsUnitTestsOnly {
        /// Port that receives the first incomplete seqs. Presumably there is one entry per
        /// rank, in rank order; confirm.
        #[reply]
        response_port: hyperactor::OncePortRef<Vec<Seq>>,
    },

    /// The message to schedule the next supervision check task on the controller.
    CheckSupervision {},

    /// Debugger message sent from a debugger to be forwarded back to the client.
    DebuggerMessage {
        /// Id of the debugger actor that sent the message.
        debugger_actor_id: ActorId,
        /// The debugger action to forward to the client.
        action: DebuggerAction,
    },
}
// Register `ControllerMessage` with `wirevalue`.
// NOTE(review): presumably this lets the message be carried inside `wirevalue::Any` and
// decoded by type; confirm against the macro's documentation.
wirevalue::register_type!(ControllerMessage);

// Declare the `ControllerActor` behavior for `ControllerMessage`.
// NOTE(review): `ControllerActor` is not imported in this file, so the macro presumably
// defines it; confirm.
hyperactor::behavior!(ControllerActor, ControllerMessage);