hyperactor_mesh/proc_mesh/
mesh_agent.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The mesh agent actor manages procs in ProcMeshes.
10
11use std::collections::HashMap;
12use std::mem::replace;
13use std::sync::Arc;
14use std::sync::Mutex;
15use std::sync::RwLock;
16
17use async_trait::async_trait;
18use enum_as_inner::EnumAsInner;
19use hyperactor::Actor;
20use hyperactor::ActorHandle;
21use hyperactor::ActorId;
22use hyperactor::Context;
23use hyperactor::Data;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::Instance;
27use hyperactor::Named;
28use hyperactor::OncePortRef;
29use hyperactor::PortHandle;
30use hyperactor::PortRef;
31use hyperactor::ProcId;
32use hyperactor::RefClient;
33use hyperactor::actor::ActorStatus;
34use hyperactor::actor::remote::Remote;
35use hyperactor::channel;
36use hyperactor::channel::ChannelAddr;
37use hyperactor::clock::Clock;
38use hyperactor::clock::RealClock;
39use hyperactor::mailbox::BoxedMailboxSender;
40use hyperactor::mailbox::DialMailboxRouter;
41use hyperactor::mailbox::IntoBoxedMailboxSender;
42use hyperactor::mailbox::MailboxClient;
43use hyperactor::mailbox::MailboxSender;
44use hyperactor::mailbox::MessageEnvelope;
45use hyperactor::mailbox::Undeliverable;
46use hyperactor::proc::Proc;
47use hyperactor::supervision::ActorSupervisionEvent;
48use serde::Deserialize;
49use serde::Serialize;
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
52pub enum GspawnResult {
53    Success { rank: usize, actor_id: ActorId },
54    Error(String),
55}
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
58pub enum StopActorResult {
59    Success,
60    Timeout,
61    NotFound,
62}
63
64#[derive(
65    Debug,
66    Clone,
67    PartialEq,
68    Serialize,
69    Deserialize,
70    Handler,
71    HandleClient,
72    RefClient,
73    Named
74)]
75pub(crate) enum MeshAgentMessage {
76    /// Configure the proc in the mesh.
77    Configure {
78        /// The rank of this proc in the mesh.
79        rank: usize,
80        /// The forwarder to send messages to unknown destinations.
81        forwarder: ChannelAddr,
82        /// The supervisor port to which the agent should report supervision events.
83        supervisor: PortRef<ActorSupervisionEvent>,
84        /// An address book to use for direct dialing.
85        address_book: HashMap<ProcId, ChannelAddr>,
86        /// The agent should write its rank to this port when it successfully
87        /// configured.
88        configured: PortRef<usize>,
89    },
90
91    /// Spawn an actor on the proc to the provided name.
92    Gspawn {
93        /// registered actor type
94        actor_type: String,
95        /// spawned actor name
96        actor_name: String,
97        /// serialized parameters
98        params_data: Data,
99        /// reply port; the proc should send its rank to indicated a spawned actor
100        status_port: PortRef<GspawnResult>,
101    },
102
103    /// Stop actors of a specific mesh name
104    StopActor {
105        /// The actor to stop
106        actor_id: ActorId,
107        /// The timeout for waiting for the actor to stop
108        timeout_ms: u64,
109        /// The result when trying to stop the actor
110        #[reply]
111        stopped: OncePortRef<StopActorResult>,
112    },
113}
114
115/// A mesh agent is responsible for managing procs in a [`ProcMesh`].
116#[derive(Debug)]
117#[hyperactor::export(handlers=[MeshAgentMessage])]
118pub struct MeshAgent {
119    proc: Proc,
120    remote: Remote,
121    sender: ReconfigurableMailboxSender,
122    rank: Option<usize>,
123    supervisor: Option<PortRef<ActorSupervisionEvent>>,
124}
125
126impl MeshAgent {
127    #[hyperactor::instrument]
128    pub(crate) async fn bootstrap(
129        proc_id: ProcId,
130    ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
131        let sender = ReconfigurableMailboxSender::new();
132        let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
133
134        // Wire up this proc to the global router so that any meshes managed by
135        // this process can reach actors in this proc.
136        super::global_router().bind(proc_id.into(), proc.clone());
137
138        let agent = MeshAgent {
139            proc: proc.clone(),
140            remote: Remote::collect(),
141            sender,
142            rank: None,       // not yet assigned
143            supervisor: None, // not yet assigned
144        };
145        let handle = proc.spawn::<Self>("mesh", agent).await?;
146        tracing::info!("bootstrap_end");
147        Ok((proc, handle))
148    }
149}
150
151#[async_trait]
152impl Actor for MeshAgent {
153    type Params = Self;
154
155    async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
156        Ok(params)
157    }
158
159    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
160        self.proc.set_supervision_coordinator(this.port())?;
161        Ok(())
162    }
163}
164
165#[async_trait]
166#[hyperactor::forward(MeshAgentMessage)]
167impl MeshAgentMessageHandler for MeshAgent {
168    async fn configure(
169        &mut self,
170        cx: &Context<Self>,
171        rank: usize,
172        forwarder: ChannelAddr,
173        supervisor: PortRef<ActorSupervisionEvent>,
174        address_book: HashMap<ProcId, ChannelAddr>,
175        configured: PortRef<usize>,
176    ) -> Result<(), anyhow::Error> {
177        // Set the supervisor first so that we can handle supervison events that might
178        // occur from configuration failures. Though we should instead report these directly
179        // for better ergonomics in the allocator.
180        self.supervisor = Some(supervisor);
181
182        // Wire up the local proc to the global (process) router. This ensures that child
183        // meshes are reachable from any actor created by this mesh.
184        let client = MailboxClient::new(channel::dial(forwarder)?);
185
186        // `HYPERACTOR_MESH_ROUTER_CONFIG_NO_GLOBAL_FALLBACK` may be
187        // set as a means of failure injection in the testing of
188        // supervision codepaths.
189        let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
190            let default = super::global_router().fallback(client.into_boxed());
191            DialMailboxRouter::new_with_default(default.into_boxed())
192        } else {
193            DialMailboxRouter::new_with_default(client.into_boxed())
194        };
195
196        for (proc_id, addr) in address_book {
197            router.bind(proc_id.into(), addr);
198        }
199
200        if self.sender.configure(router.into_boxed()) {
201            self.rank = Some(rank);
202            configured.send(cx, rank)?;
203        } else {
204            tracing::error!("tried to reconfigure mesh agent");
205        }
206        Ok(())
207    }
208
209    async fn gspawn(
210        &mut self,
211        cx: &Context<Self>,
212        actor_type: String,
213        actor_name: String,
214        params_data: Data,
215        status_port: PortRef<GspawnResult>,
216    ) -> Result<(), anyhow::Error> {
217        let actor_id = match self
218            .remote
219            .gspawn(&self.proc, &actor_type, &actor_name, params_data)
220            .await
221        {
222            Ok(id) => id,
223            Err(err) => {
224                status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
225                return Err(anyhow::anyhow!("gspawn failed"));
226            }
227        };
228        let rank = match self.rank {
229            Some(rank) => rank,
230            None => {
231                let err = "tried to spawn on unconfigured proc";
232                status_port.send(cx, GspawnResult::Error(err.to_string()))?;
233                return Err(anyhow::anyhow!(err));
234            }
235        };
236        status_port.send(cx, GspawnResult::Success { rank, actor_id })?;
237        Ok(())
238    }
239
240    async fn stop_actor(
241        &mut self,
242        _cx: &Context<Self>,
243        actor_id: ActorId,
244        timeout_ms: u64,
245    ) -> Result<StopActorResult, anyhow::Error> {
246        tracing::info!("Stopping actor: {}", actor_id);
247
248        if let Some(mut status) = self.proc.stop_actor(&actor_id) {
249            match RealClock
250                .timeout(
251                    tokio::time::Duration::from_millis(timeout_ms),
252                    status.wait_for(|state: &ActorStatus| matches!(*state, ActorStatus::Stopped)),
253                )
254                .await
255            {
256                Ok(_) => Ok(StopActorResult::Success),
257                Err(_) => Ok(StopActorResult::Timeout),
258            }
259        } else {
260            Ok(StopActorResult::NotFound)
261        }
262    }
263}
264
265#[async_trait]
266impl Handler<ActorSupervisionEvent> for MeshAgent {
267    async fn handle(
268        &mut self,
269        cx: &Context<Self>,
270        event: ActorSupervisionEvent,
271    ) -> anyhow::Result<()> {
272        if let Some(supervisor) = &self.supervisor {
273            supervisor.send(cx, event)?;
274        } else {
275            tracing::error!(
276                "proc {}: could not propagate supervision event {:?}: crashing",
277                cx.self_id().proc_id(),
278                event
279            );
280
281            // We should have a custom "crash" function here, so that this works
282            // in testing of the LocalAllocator, etc.
283            std::process::exit(1);
284        }
285        Ok(())
286    }
287}
288
289/// A mailbox sender that initially queues messages, and then relays them to
290/// an underlying sender once configured.
291#[derive(Clone)]
292pub(crate) struct ReconfigurableMailboxSender {
293    state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
294}
295
296impl std::fmt::Debug for ReconfigurableMailboxSender {
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        // Not super helpful, but we definitely don't wan to acquire any locks
299        // in a Debug formatter.
300        f.debug_struct("ReconfigurableMailboxSender").finish()
301    }
302}
303
304type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
305
306#[derive(EnumAsInner, Debug)]
307enum ReconfigurableMailboxSenderState {
308    Queueing(Mutex<Vec<Post>>),
309    Configured(BoxedMailboxSender),
310}
311
312impl ReconfigurableMailboxSender {
313    pub(crate) fn new() -> Self {
314        Self {
315            state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
316                Mutex::new(Vec::new()),
317            ))),
318        }
319    }
320
321    /// Configure this mailbox with the provided sender. This will first
322    /// enqueue any pending messages onto the sender; future messages are
323    /// posted directly to the configured sender.
324    pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
325        let mut state = self.state.write().unwrap();
326        if state.is_configured() {
327            return false;
328        }
329
330        let queued = replace(
331            &mut *state,
332            ReconfigurableMailboxSenderState::Configured(sender.clone()),
333        );
334
335        for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
336            sender.post(envelope, return_handle);
337        }
338        *state = ReconfigurableMailboxSenderState::Configured(sender);
339        true
340    }
341}
342
343impl MailboxSender for ReconfigurableMailboxSender {
344    fn post(
345        &self,
346        envelope: MessageEnvelope,
347        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
348    ) {
349        match *self.state.read().unwrap() {
350            ReconfigurableMailboxSenderState::Queueing(ref queue) => {
351                queue.lock().unwrap().push((envelope, return_handle));
352            }
353            ReconfigurableMailboxSenderState::Configured(ref sender) => {
354                sender.post(envelope, return_handle);
355            }
356        }
357    }
358}
359
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::attrs::Attrs;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;

    use super::*;

    /// A terminal sender that records every envelope posted to it, so that
    /// tests can inspect what was delivered and in which order.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            let messages = Arc::new(Mutex::new(Vec::new()));
            Self { messages }
        }

        /// A snapshot of the envelopes recorded so far.
        fn get_messages(&self) -> Vec<MessageEnvelope> {
            let recorded = self.messages.lock().unwrap();
            (*recorded).clone()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            let mut recorded = self.messages.lock().unwrap();
            recorded.push(envelope);
        }
    }

    /// Build a test envelope carrying `data` as its payload.
    fn envelope(data: u64) -> MessageEnvelope {
        let from = id!(world[0].sender);
        let to = id!(world[0].receiver[0][1]);
        MessageEnvelope::serialize(from, to, &data, Attrs::new()).unwrap()
    }

    /// Build a return handle for undeliverable messages; nothing reads from it.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    /// Decode, in delivery order, the payloads recorded by `recorder`.
    fn delivered(recorder: &QueueingMailboxSender) -> Vec<u64> {
        recorder
            .get_messages()
            .iter()
            .map(|message| message.deserialized::<u64>().unwrap())
            .collect()
    }

    /// An unconfigured sender under test, a recorder, and the recorder boxed
    /// up ready to be installed.
    fn fixture() -> (
        ReconfigurableMailboxSender,
        QueueingMailboxSender,
        BoxedMailboxSender,
    ) {
        let recorder = QueueingMailboxSender::new();
        let boxed = BoxedMailboxSender::new(recorder.clone());
        (ReconfigurableMailboxSender::new(), recorder, boxed)
    }

    #[test]
    fn test_queueing_before_configure() {
        let (sender, recorder, boxed) = fixture();
        let handle = return_handle();

        for data in [1, 2] {
            sender.post(envelope(data), handle.clone());
        }

        // Nothing is delivered until the sender is configured ...
        assert_eq!(recorder.get_messages().len(), 0);

        sender.configure(boxed);

        // ... at which point the queue is flushed, in order.
        assert_eq!(delivered(&recorder), vec![1, 2]);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let (sender, recorder, boxed) = fixture();
        sender.configure(boxed);

        let handle = return_handle();
        for data in [3, 4] {
            sender.post(envelope(data), handle.clone());
        }

        assert_eq!(delivered(&recorder), vec![3, 4]);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let boxed = BoxedMailboxSender::new(QueueingMailboxSender::new());

        // Only the first configuration is accepted.
        let first = sender.configure(boxed.clone());
        let second = sender.configure(boxed);
        assert!(first);
        assert!(!second);
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let (sender, recorder, boxed) = fixture();
        let handle = return_handle();

        // Queued before configuration ...
        for data in [5, 6] {
            sender.post(envelope(data), handle.clone());
        }

        sender.configure(boxed);

        // ... and delivered directly after it.
        for data in [7, 8] {
            sender.post(envelope(data), handle.clone());
        }

        assert_eq!(delivered(&recorder), vec![5, 6, 7, 8]);
    }
}
495}