hyperactor_mesh/proc_mesh/
mesh_agent.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The mesh agent actor manages procs in ProcMeshes.
10
11use std::collections::HashMap;
12use std::mem::take;
13use std::sync::Arc;
14use std::sync::Mutex;
15use std::sync::RwLock;
16use std::sync::RwLockReadGuard;
17
18use async_trait::async_trait;
19use enum_as_inner::EnumAsInner;
20use hyperactor::Actor;
21use hyperactor::ActorHandle;
22use hyperactor::ActorId;
23use hyperactor::Bind;
24use hyperactor::Context;
25use hyperactor::Data;
26use hyperactor::HandleClient;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::OncePortRef;
30use hyperactor::PortHandle;
31use hyperactor::PortRef;
32use hyperactor::ProcId;
33use hyperactor::RefClient;
34use hyperactor::Unbind;
35use hyperactor::WorldId;
36use hyperactor::actor::ActorStatus;
37use hyperactor::actor::remote::Remote;
38use hyperactor::channel;
39use hyperactor::channel::ChannelAddr;
40use hyperactor::clock::Clock;
41use hyperactor::clock::RealClock;
42use hyperactor::mailbox::BoxedMailboxSender;
43use hyperactor::mailbox::DialMailboxRouter;
44use hyperactor::mailbox::IntoBoxedMailboxSender;
45use hyperactor::mailbox::MailboxClient;
46use hyperactor::mailbox::MailboxSender;
47use hyperactor::mailbox::MessageEnvelope;
48use hyperactor::mailbox::Undeliverable;
49use hyperactor::proc::Proc;
50use hyperactor::supervision::ActorSupervisionEvent;
51use serde::Deserialize;
52use serde::Serialize;
53use typeuri::Named;
54
55use crate::actor_mesh::CAST_ACTOR_MESH_ID;
56use crate::proc_mesh::SupervisionEventState;
57use crate::reference::ActorMeshId;
58use crate::resource;
59use crate::v1::Name;
60
/// The outcome of a `Gspawn` request, delivered on the request's
/// `status_port`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
pub enum GspawnResult {
    /// The actor was spawned. `rank` is the rank the agent was configured
    /// with and `actor_id` identifies the newly spawned actor.
    Success { rank: usize, actor_id: ActorId },
    /// The spawn failed; the payload is a human-readable description of the
    /// failure.
    Error(String),
}
wirevalue::register_type!(GspawnResult);
67
/// The outcome of a `StopActor` request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub enum StopActorResult {
    /// The actor reached a terminal status within the requested timeout.
    Success,
    /// The actor was asked to stop but did not reach a terminal status
    /// before the timeout elapsed.
    Timeout,
    /// The proc had no actor to stop for the given actor id.
    NotFound,
}
wirevalue::register_type!(StopActorResult);
75
/// Control messages handled by the [`ProcMeshAgent`].
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    Handler,
    HandleClient,
    RefClient,
    Named
)]
pub(crate) enum MeshAgentMessage {
    /// Configure the proc in the mesh.
    Configure {
        /// The rank of this proc in the mesh.
        rank: usize,
        /// The forwarder to send messages to unknown destinations.
        forwarder: ChannelAddr,
        /// The supervisor port to which the agent should report supervision events.
        supervisor: Option<PortRef<ActorSupervisionEvent>>,
        /// An address book to use for direct dialing.
        address_book: HashMap<ProcId, ChannelAddr>,
        /// The agent writes its rank to this port once it has been
        /// successfully configured.
        configured: PortRef<usize>,
        /// If true, and supervisor is None, record supervision events to be
        /// reported to owning actors later.
        record_supervision_events: bool,
    },

    /// Query the status of the proc. A v0-configured agent replies with
    /// `(rank, true)`.
    Status {
        /// The status of the proc.
        /// To be replaced with fine-grained lifecycle status,
        /// and to use aggregation.
        status: PortRef<(usize, bool)>,
    },

    /// Spawn an actor of a registered type on the proc under the provided name.
    Gspawn {
        /// registered actor type
        actor_type: String,
        /// spawned actor name
        actor_name: String,
        /// serialized parameters
        params_data: Data,
        /// reply port; the agent reports the outcome here, including its
        /// rank and the new actor's id on success
        status_port: PortRef<GspawnResult>,
    },

    /// Stop a single actor on this proc, identified by its actor id.
    StopActor {
        /// The actor to stop
        actor_id: ActorId,
        /// The timeout, in milliseconds, for waiting for the actor to stop
        timeout_ms: u64,
        /// The result when trying to stop the actor
        #[reply]
        stopped: OncePortRef<StopActorResult>,
    },
}
135
/// Internal configuration state of the mesh agent.
#[derive(Debug, EnumAsInner, Default)]
enum State {
    /// Initial state of an agent created through `bootstrap`: it still waits
    /// for a `Configure` message. `sender` is the proc's (not yet
    /// configured) forwarding mailbox sender.
    UnconfiguredV0 {
        sender: ReconfigurableMailboxSender,
    },

    /// State after a successful `Configure`: the agent knows its rank and,
    /// optionally, the port to which supervision events are forwarded.
    ConfiguredV0 {
        sender: ReconfigurableMailboxSender,
        rank: usize,
        supervisor: Option<PortRef<ActorSupervisionEvent>>,
    },

    /// State of an agent created through `boot_v1`; it carries no rank or
    /// supervisor.
    V1,

    /// Placeholder left behind while the state is temporarily moved out with
    /// `std::mem::take` (this is the `Default` variant).
    #[default]
    Invalid,
}
154
155impl State {
156    fn rank(&self) -> Option<usize> {
157        match self {
158            State::ConfiguredV0 { rank, .. } => Some(*rank),
159            _ => None,
160        }
161    }
162
163    fn supervisor(&self) -> Option<PortRef<ActorSupervisionEvent>> {
164        match self {
165            State::ConfiguredV0 { supervisor, .. } => supervisor.clone(),
166            _ => None,
167        }
168    }
169}
170
/// Actor state used for v1 API.
#[derive(Debug)]
struct ActorInstanceState {
    /// The rank carried by the `CreateOrUpdate` request that created this
    /// entry.
    create_rank: usize,
    /// The result of the spawn attempt: the spawned actor's id, or the error
    /// explaining why no actor was spawned.
    spawn: Result<ActorId, anyhow::Error>,
    /// If true, the actor has been stopped. There is no way to restart it, a new
    /// actor must be spawned.
    stopped: bool,
}
180
181/// Normalize events that came via the comm tree. Updates their actor id based on
182/// the message headers for the event.
183pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSupervisionEvent {
184    if let Some(headers) = &event.message_headers {
185        if let Some(actor_mesh_id) = headers.get(CAST_ACTOR_MESH_ID) {
186            match actor_mesh_id {
187                ActorMeshId::V0(proc_mesh_id, actor_name) => {
188                    let old_actor = event.actor_id.clone();
189                    event.actor_id = ActorId(
190                        ProcId::Ranked(WorldId(proc_mesh_id.0.clone()), 0),
191                        actor_name.clone(),
192                        0,
193                    );
194                    tracing::debug!(
195                        actor_id = %old_actor,
196                        "proc supervision: remapped comm-actor id to mesh id from CAST_ACTOR_MESH_ID {}", event.actor_id
197                    );
198                }
199                ActorMeshId::V1(_) => {
200                    tracing::debug!(
201                        "proc supervision: headers present but V1 ActorMeshId; leaving actor_id unchanged"
202                    );
203                }
204            }
205        } else {
206            tracing::debug!(
207                "proc supervision: headers present but no CAST_ACTOR_MESH_ID; leaving actor_id unchanged"
208            );
209        }
210    } else {
211        tracing::debug!("proc supervision: no headers attached; leaving actor_id unchanged");
212    }
213    event
214}
215
/// A mesh agent is responsible for managing procs in a [`ProcMesh`].
#[hyperactor::export(
    handlers=[
        MeshAgentMessage,
        resource::CreateOrUpdate<ActorSpec> { cast = true },
        resource::Stop { cast = true },
        resource::StopAll { cast = true },
        resource::GetState<ActorState> { cast = true },
        resource::GetRankStatus { cast = true },
    ]
)]
pub struct ProcMeshAgent {
    /// The proc managed by this agent; actors are spawned on it and stopped
    /// through it.
    proc: Proc,
    /// Collected at construction time via `Remote::collect()` and used to
    /// spawn actors by their registered type name.
    remote: Remote,
    /// The agent's configuration state (v0 unconfigured/configured, or v1).
    state: State,
    /// Actors created and tracked through the resource behavior.
    actor_states: HashMap<Name, ActorInstanceState>,
    /// If true, and supervisor is None, record supervision events to be reported
    /// to owning actors later.
    record_supervision_events: bool,
    /// If record_supervision_events is true, then this will contain the list
    /// of all events that were received.
    supervision_events: HashMap<ActorId, Vec<ActorSupervisionEvent>>,
}
240
241impl ProcMeshAgent {
242    #[hyperactor::observe_result("MeshAgent")]
243    pub(crate) async fn bootstrap(
244        proc_id: ProcId,
245    ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
246        let sender = ReconfigurableMailboxSender::new();
247        let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
248
249        // Wire up this proc to the global router so that any meshes managed by
250        // this process can reach actors in this proc.
251        super::router::global().bind(proc_id.into(), proc.clone());
252
253        let agent = ProcMeshAgent {
254            proc: proc.clone(),
255            remote: Remote::collect(),
256            state: State::UnconfiguredV0 { sender },
257            actor_states: HashMap::new(),
258            record_supervision_events: false,
259            supervision_events: HashMap::new(),
260        };
261        let handle = proc.spawn::<Self>("mesh", agent)?;
262        Ok((proc, handle))
263    }
264
265    pub(crate) fn boot_v1(proc: Proc) -> Result<ActorHandle<Self>, anyhow::Error> {
266        let agent = ProcMeshAgent {
267            proc: proc.clone(),
268            remote: Remote::collect(),
269            state: State::V1,
270            actor_states: HashMap::new(),
271            record_supervision_events: true,
272            supervision_events: HashMap::new(),
273        };
274        proc.spawn::<Self>("agent", agent)
275    }
276
277    async fn destroy_and_wait_except_current<'a>(
278        &mut self,
279        cx: &Context<'a, Self>,
280        timeout: tokio::time::Duration,
281    ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
282        self.proc
283            .destroy_and_wait_except_current::<Self>(timeout, Some(cx), true)
284            .await
285    }
286}
287
288#[async_trait]
289impl Actor for ProcMeshAgent {
290    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
291        self.proc.set_supervision_coordinator(this.port())?;
292        Ok(())
293    }
294}
295
296#[async_trait]
297#[hyperactor::forward(MeshAgentMessage)]
298impl MeshAgentMessageHandler for ProcMeshAgent {
299    async fn configure(
300        &mut self,
301        cx: &Context<Self>,
302        rank: usize,
303        forwarder: ChannelAddr,
304        supervisor: Option<PortRef<ActorSupervisionEvent>>,
305        address_book: HashMap<ProcId, ChannelAddr>,
306        configured: PortRef<usize>,
307        record_supervision_events: bool,
308    ) -> Result<(), anyhow::Error> {
309        anyhow::ensure!(
310            self.state.is_unconfigured_v0(),
311            "mesh agent cannot be (re-)configured"
312        );
313        self.record_supervision_events = record_supervision_events;
314
315        // Wire up the local proc to the global (process) router. This ensures that child
316        // meshes are reachable from any actor created by this mesh.
317        let client = MailboxClient::new(channel::dial(forwarder)?);
318
319        // `HYPERACTOR_MESH_ROUTER_CONFIG_NO_GLOBAL_FALLBACK` may be
320        // set as a means of failure injection in the testing of
321        // supervision codepaths.
322        let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
323            let default = super::router::global().fallback(client.into_boxed());
324            DialMailboxRouter::new_with_default_direct_addressed_remote_only(default.into_boxed())
325        } else {
326            DialMailboxRouter::new_with_default_direct_addressed_remote_only(client.into_boxed())
327        };
328
329        for (proc_id, addr) in address_book {
330            router.bind(proc_id.into(), addr);
331        }
332
333        let sender = take(&mut self.state).into_unconfigured_v0().unwrap();
334        assert!(sender.configure(router.into_boxed()));
335
336        // This is a bit suboptimal: ideally we'd set the supervisor first, to correctly report
337        // any errors that occur during configuration. However, these should anyway be correctly
338        // caught on process exit.
339        self.state = State::ConfiguredV0 {
340            sender,
341            rank,
342            supervisor,
343        };
344        configured.send(cx, rank)?;
345
346        Ok(())
347    }
348
349    async fn gspawn(
350        &mut self,
351        cx: &Context<Self>,
352        actor_type: String,
353        actor_name: String,
354        params_data: Data,
355        status_port: PortRef<GspawnResult>,
356    ) -> Result<(), anyhow::Error> {
357        anyhow::ensure!(
358            self.state.is_configured_v0(),
359            "mesh agent is not v0 configured"
360        );
361        let actor_id = match self
362            .remote
363            .gspawn(&self.proc, &actor_type, &actor_name, params_data)
364            .await
365        {
366            Ok(id) => id,
367            Err(err) => {
368                status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
369                return Err(anyhow::anyhow!("gspawn failed"));
370            }
371        };
372        status_port.send(
373            cx,
374            GspawnResult::Success {
375                rank: self.state.rank().unwrap(),
376                actor_id,
377            },
378        )?;
379        Ok(())
380    }
381
382    async fn stop_actor(
383        &mut self,
384        _cx: &Context<Self>,
385        actor_id: ActorId,
386        timeout_ms: u64,
387    ) -> Result<StopActorResult, anyhow::Error> {
388        tracing::info!(
389            name = "StopActor",
390            actor_id = %actor_id,
391            actor_name = actor_id.name(),
392        );
393
394        if let Some(mut status) = self.proc.stop_actor(&actor_id) {
395            match RealClock
396                .timeout(
397                    tokio::time::Duration::from_millis(timeout_ms),
398                    status.wait_for(|state: &ActorStatus| state.is_terminal()),
399                )
400                .await
401            {
402                Ok(_) => Ok(StopActorResult::Success),
403                Err(_) => Ok(StopActorResult::Timeout),
404            }
405        } else {
406            Ok(StopActorResult::NotFound)
407        }
408    }
409
410    async fn status(
411        &mut self,
412        cx: &Context<Self>,
413        status_port: PortRef<(usize, bool)>,
414    ) -> Result<(), anyhow::Error> {
415        match &self.state {
416            State::ConfiguredV0 { rank, .. } => {
417                // v0 path: configured with a concrete rank
418                status_port.send(cx, (*rank, true))?;
419                Ok(())
420            }
421            State::UnconfiguredV0 { .. } => {
422                // v0 path but not configured yet
423                Err(anyhow::anyhow!(
424                    "status unavailable: v0 agent not configured (waiting for Configure)"
425                ))
426            }
427            State::V1 => {
428                // v1/owned path does not support status (no rank semantics)
429                Err(anyhow::anyhow!(
430                    "status unsupported in v1/owned path (no rank)"
431                ))
432            }
433            State::Invalid => Err(anyhow::anyhow!(
434                "status unavailable: agent in invalid state"
435            )),
436        }
437    }
438}
439
440#[async_trait]
441impl Handler<ActorSupervisionEvent> for ProcMeshAgent {
442    async fn handle(
443        &mut self,
444        cx: &Context<Self>,
445        event: ActorSupervisionEvent,
446    ) -> anyhow::Result<()> {
447        let event = update_event_actor_id(event);
448        if self.record_supervision_events {
449            if event.is_error() {
450                tracing::warn!(
451                    name = "SupervisionEvent",
452                    proc_id = %self.proc.proc_id(),
453                    %event,
454                    "recording supervision error",
455                );
456            } else {
457                tracing::debug!(
458                    name = "SupervisionEvent",
459                    proc_id = %self.proc.proc_id(),
460                    %event,
461                    "recording non-error supervision event",
462                );
463            }
464            self.supervision_events
465                .entry(event.actor_id.clone())
466                .or_default()
467                .push(event.clone());
468        }
469        if let Some(supervisor) = self.state.supervisor() {
470            supervisor.send(cx, event)?;
471        } else if !self.record_supervision_events {
472            // If there is no supervisor, and nothing is recording these, crash
473            // the whole process.
474            tracing::error!(
475                name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
476                proc_id = %cx.self_id().proc_id(),
477                %event,
478                "could not propagate supervision event, crashing",
479            );
480
481            // We should have a custom "crash" function here, so that this works
482            // in testing of the LocalAllocator, etc.
483            std::process::exit(1);
484        }
485        Ok(())
486    }
487}
488
489// Implement the resource behavior for managing actors:
490
/// Actor spec: what the agent needs in order to spawn an actor through the
/// resource (`CreateOrUpdate`) behavior.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub struct ActorSpec {
    /// registered actor type
    pub actor_type: String,
    /// serialized parameters
    pub params_data: Data,
}
wirevalue::register_type!(ActorSpec);
500
/// Actor state, as reported in replies to `GetState` requests for actors
/// that were spawned successfully.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct ActorState {
    /// The actor's ID.
    pub actor_id: ActorId,
    /// The rank of the proc that created the actor. This is before any slicing.
    pub create_rank: usize,
    // TODO status: ActorStatus,
    /// The supervision events the agent has recorded for this actor; empty
    /// if none were recorded.
    pub supervision_events: Vec<ActorSupervisionEvent>,
}
wirevalue::register_type!(ActorState);
512
513#[async_trait]
514impl Handler<resource::CreateOrUpdate<ActorSpec>> for ProcMeshAgent {
515    async fn handle(
516        &mut self,
517        _cx: &Context<Self>,
518        create_or_update: resource::CreateOrUpdate<ActorSpec>,
519    ) -> anyhow::Result<()> {
520        if self.actor_states.contains_key(&create_or_update.name) {
521            // There is no update.
522            return Ok(());
523        }
524        let create_rank = create_or_update.rank.unwrap();
525        // If there have been supervision events for any actors on this proc,
526        // we disallow spawning new actors on it, as this proc may be in an
527        // invalid state.
528        if !self.supervision_events.is_empty() {
529            self.actor_states.insert(
530                create_or_update.name.clone(),
531                ActorInstanceState {
532                    spawn: Err(anyhow::anyhow!(
533                        "Cannot spawn new actors on mesh with supervision events"
534                    )),
535                    create_rank,
536                    stopped: false,
537                },
538            );
539            return Ok(());
540        }
541
542        let ActorSpec {
543            actor_type,
544            params_data,
545        } = create_or_update.spec;
546        self.actor_states.insert(
547            create_or_update.name.clone(),
548            ActorInstanceState {
549                create_rank,
550                spawn: self
551                    .remote
552                    .gspawn(
553                        &self.proc,
554                        &actor_type,
555                        &create_or_update.name.to_string(),
556                        params_data,
557                    )
558                    .await,
559                stopped: false,
560            },
561        );
562
563        Ok(())
564    }
565}
566
567#[async_trait]
568impl Handler<resource::Stop> for ProcMeshAgent {
569    async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
570        // We don't remove the actor from the state map, instead we just store
571        // its state as Stopped.
572        let actor = self.actor_states.get_mut(&message.name);
573        // Have to separate stop_actor from setting "stopped" because it borrows
574        // as mutable and cannot have self borrowed mutably twice.
575        let actor_id = match actor {
576            Some(actor_state) => {
577                match &actor_state.spawn {
578                    Ok(actor_id) => {
579                        if actor_state.stopped {
580                            None
581                        } else {
582                            actor_state.stopped = true;
583                            Some(actor_id.clone())
584                        }
585                    }
586                    // If the original spawn had failed, the actor is still considered
587                    // successfully stopped.
588                    Err(_) => None,
589                }
590            }
591            // TODO: represent unknown rank
592            None => None,
593        };
594        let timeout = hyperactor_config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
595        if let Some(actor_id) = actor_id {
596            // While this function returns a Result, it never returns an Err
597            // value so we can simply expect without any failure handling.
598            self.stop_actor(cx, actor_id, timeout.as_millis() as u64)
599                .await
600                .expect("stop_actor cannot fail");
601        }
602
603        Ok(())
604    }
605}
606
607/// Handles `StopAll` by coordinating an orderly stop of child actors and then
608/// exiting the process. This handler never returns to the caller: it calls
609/// `std::process::exit(0/1)` after shutdown. Any sender must *not* expect a
610/// reply or send any further message, and should watch `ProcStatus` instead.
611#[async_trait]
612impl Handler<resource::StopAll> for ProcMeshAgent {
613    async fn handle(
614        &mut self,
615        cx: &Context<Self>,
616        _message: resource::StopAll,
617    ) -> anyhow::Result<()> {
618        let timeout = hyperactor_config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
619        // By passing in the self context, destroy_and_wait will stop this agent
620        // last, after all others are stopped.
621        let stop_result = self.destroy_and_wait_except_current(cx, timeout).await;
622        // Exit here to cleanup all remaining resources held by the process.
623        // This means ProcMeshAgent will never run cleanup or any other code
624        // from exiting its root actor. Senders of this message should never
625        // send any further messages or expect a reply.
626        match stop_result {
627            Ok((stopped_actors, aborted_actors)) => {
628                // No need to clean up any state, the process is exiting.
629                tracing::info!(
630                    actor = %cx.self_id(),
631                    "exiting process after receiving StopAll message on ProcMeshAgent. \
632                    stopped actors = {:?}, aborted actors = {:?}",
633                    stopped_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
634                    aborted_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
635                );
636                std::process::exit(0);
637            }
638            Err(e) => {
639                tracing::error!(actor = %cx.self_id(), "failed to stop all actors on ProcMeshAgent: {:?}", e);
640                std::process::exit(1);
641            }
642        }
643    }
644}
645
646#[async_trait]
647impl Handler<resource::GetRankStatus> for ProcMeshAgent {
648    async fn handle(
649        &mut self,
650        cx: &Context<Self>,
651        get_rank_status: resource::GetRankStatus,
652    ) -> anyhow::Result<()> {
653        use crate::resource::Status;
654        use crate::v1::StatusOverlay;
655
656        let (rank, status) = match self.actor_states.get(&get_rank_status.name) {
657            Some(ActorInstanceState {
658                spawn: Ok(actor_id),
659                create_rank,
660                stopped,
661            }) => {
662                if *stopped {
663                    (*create_rank, resource::Status::Stopped)
664                } else {
665                    let supervision_events = self
666                        .supervision_events
667                        .get(actor_id)
668                        .map_or_else(Vec::new, |a| a.clone());
669                    (
670                        *create_rank,
671                        if supervision_events.is_empty() {
672                            resource::Status::Running
673                        } else {
674                            resource::Status::Failed(format!(
675                                "because of supervision events: {:?}",
676                                supervision_events
677                            ))
678                        },
679                    )
680                }
681            }
682            Some(ActorInstanceState {
683                spawn: Err(e),
684                create_rank,
685                ..
686            }) => (*create_rank, Status::Failed(e.to_string())),
687            // TODO: represent unknown rank
688            None => (usize::MAX, Status::NotExist),
689        };
690
691        // Send a sparse overlay update. If rank is unknown, emit an
692        // empty overlay.
693        let overlay = if rank == usize::MAX {
694            StatusOverlay::new()
695        } else {
696            StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
697                .expect("valid single-run overlay")
698        };
699        let result = get_rank_status.reply.send(cx, overlay);
700        // Ignore errors, because returning Err from here would cause the ProcMeshAgent
701        // to be stopped, which would prevent querying and spawning other actors.
702        // This only means some actor that requested the state of an actor failed to receive it.
703        if let Err(e) = result {
704            tracing::warn!(
705                actor = %cx.self_id(),
706                "failed to send GetRankStatus reply to {} due to error: {}",
707                get_rank_status.reply.port_id().actor_id(),
708                e
709            );
710        }
711        Ok(())
712    }
713}
714
715#[async_trait]
716impl Handler<resource::GetState<ActorState>> for ProcMeshAgent {
717    async fn handle(
718        &mut self,
719        cx: &Context<Self>,
720        get_state: resource::GetState<ActorState>,
721    ) -> anyhow::Result<()> {
722        let state = match self.actor_states.get(&get_state.name) {
723            Some(ActorInstanceState {
724                create_rank,
725                spawn: Ok(actor_id),
726                stopped,
727            }) => {
728                let supervision_events = self
729                    .supervision_events
730                    .get(actor_id)
731                    .map_or_else(Vec::new, |a| a.clone());
732                let status = if *stopped {
733                    resource::Status::Stopped
734                } else if supervision_events.is_empty() {
735                    resource::Status::Running
736                } else {
737                    resource::Status::Failed(format!(
738                        "because of supervision events: {:?}",
739                        supervision_events
740                    ))
741                };
742                resource::State {
743                    name: get_state.name.clone(),
744                    status,
745                    state: Some(ActorState {
746                        actor_id: actor_id.clone(),
747                        create_rank: *create_rank,
748                        supervision_events,
749                    }),
750                }
751            }
752            Some(ActorInstanceState { spawn: Err(e), .. }) => resource::State {
753                name: get_state.name.clone(),
754                status: resource::Status::Failed(e.to_string()),
755                state: None,
756            },
757            None => resource::State {
758                name: get_state.name.clone(),
759                status: resource::Status::NotExist,
760                state: None,
761            },
762        };
763
764        let result = get_state.reply.send(cx, state);
765        // Ignore errors, because returning Err from here would cause the ProcMeshAgent
766        // to be stopped, which would prevent querying and spawning other actors.
767        // This only means some actor that requested the state of an actor failed to receive it.
768        if let Err(e) = result {
769            tracing::warn!(
770                actor = %cx.self_id(),
771                "failed to send GetState reply to {} due to error: {}",
772                get_state.reply.port_id().actor_id(),
773                e
774            );
775        }
776        Ok(())
777    }
778}
779
/// A local message asking the agent for a new client instance on the proc.
/// This is used to create root client instances.
#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)]
pub struct NewClientInstance {
    /// Reply port on which the newly created instance is delivered.
    #[reply]
    pub client_instance: PortHandle<Instance<()>>,
}
787
788#[async_trait]
789impl Handler<NewClientInstance> for ProcMeshAgent {
790    async fn handle(
791        &mut self,
792        _cx: &Context<Self>,
793        NewClientInstance { client_instance }: NewClientInstance,
794    ) -> anyhow::Result<()> {
795        let (instance, _handle) = self.proc.instance("client")?;
796        client_instance.send(instance)?;
797        Ok(())
798    }
799}
800
/// A local message asking the agent for a clone of the proc it manages.
/// This is used to obtain the local proc from a host mesh.
#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)]
pub struct GetProc {
    /// Reply port on which the proc clone is delivered.
    #[reply]
    pub proc: PortHandle<Proc>,
}
808
809#[async_trait]
810impl Handler<GetProc> for ProcMeshAgent {
811    async fn handle(
812        &mut self,
813        _cx: &Context<Self>,
814        GetProc { proc }: GetProc,
815    ) -> anyhow::Result<()> {
816        proc.send(self.proc.clone())?;
817        Ok(())
818    }
819}
820
/// A mailbox sender that initially queues messages, and then relays them to
/// an underlying sender once configured.
///
/// Cloning is cheap and yields a handle to the same underlying state.
#[derive(Clone)]
pub(crate) struct ReconfigurableMailboxSender {
    /// Shared queueing-or-configured state, guarded by a read-write lock.
    state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
}
827
828impl std::fmt::Debug for ReconfigurableMailboxSender {
829    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
830        // Not super helpful, but we definitely don't wan to acquire any locks
831        // in a Debug formatter.
832        f.debug_struct("ReconfigurableMailboxSender").finish()
833    }
834}
835
/// A capability wrapper granting access to the configured mailbox
/// sender.
///
/// This type exists to tie the lifetime of any `&BoxedMailboxSender`
/// reference to a lock guard, so the underlying state cannot be
/// reconfigured while the reference is in use.
///
/// A **read** guard is sufficient because we only need to *observe*
/// and borrow the configured sender, not mutate state. While a
/// `RwLockReadGuard` is held, `configure()` cannot acquire the write
/// lock, so the state cannot transition from `Configured(..)` to any
/// other variant during the guard’s lifetime.
pub(crate) struct ReconfigurableMailboxSenderInner<'a> {
    /// Read guard over the sender state; held for as long as this wrapper
    /// is alive.
    guard: RwLockReadGuard<'a, ReconfigurableMailboxSenderState>,
}
851
852impl<'a> ReconfigurableMailboxSenderInner<'a> {
853    pub(crate) fn as_configured(&self) -> Option<&BoxedMailboxSender> {
854        self.guard.as_configured()
855    }
856}
857
858type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
859
860#[derive(EnumAsInner, Debug)]
861enum ReconfigurableMailboxSenderState {
862    Queueing(Mutex<Vec<Post>>),
863    Configured(BoxedMailboxSender),
864}
865
866impl ReconfigurableMailboxSender {
867    pub(crate) fn new() -> Self {
868        Self {
869            state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
870                Mutex::new(Vec::new()),
871            ))),
872        }
873    }
874
875    /// Configure this mailbox with the provided sender. This will first
876    /// enqueue any pending messages onto the sender; future messages are
877    /// posted directly to the configured sender.
878    pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
879        // Hold the write lock until all queued messages are flushed.
880        let mut state = self.state.write().unwrap();
881        if state.is_configured() {
882            return false;
883        }
884
885        // Install the configured sender exactly once.
886        let queued = std::mem::replace(
887            &mut *state,
888            ReconfigurableMailboxSenderState::Configured(sender),
889        );
890
891        // Borrow the configured sender from the state (stable while
892        // we hold the lock).
893        let configured_sender = state.as_configured().expect("just configured");
894
895        // Flush the old queue while still holding the write lock.
896        for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
897            configured_sender.post(envelope, return_handle);
898        }
899
900        true
901    }
902
903    pub(crate) fn as_inner<'a>(
904        &'a self,
905    ) -> Result<ReconfigurableMailboxSenderInner<'a>, anyhow::Error> {
906        let state = self.state.read().unwrap();
907        if state.is_configured() {
908            Ok(ReconfigurableMailboxSenderInner { guard: state })
909        } else {
910            Err(anyhow::anyhow!("cannot get inner sender: not configured"))
911        }
912    }
913}
914
915impl MailboxSender for ReconfigurableMailboxSender {
916    fn post(
917        &self,
918        envelope: MessageEnvelope,
919        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
920    ) {
921        match &*self.state.read().unwrap() {
922            ReconfigurableMailboxSenderState::Queueing(queue) => {
923                queue.lock().unwrap().push((envelope, return_handle));
924            }
925            ReconfigurableMailboxSenderState::Configured(sender) => {
926                sender.post(envelope, return_handle);
927            }
928        }
929    }
930
931    fn post_unchecked(
932        &self,
933        envelope: MessageEnvelope,
934        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
935    ) {
936        match &*self.state.read().unwrap() {
937            ReconfigurableMailboxSenderState::Queueing(queue) => {
938                queue.lock().unwrap().push((envelope, return_handle));
939            }
940            ReconfigurableMailboxSenderState::Configured(sender) => {
941                sender.post_unchecked(envelope, return_handle);
942            }
943        }
944    }
945}
946
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;
    use hyperactor_config::attrs::Attrs;

    use super::*;

    /// Test double that records every envelope posted to it.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            let messages = Arc::new(Mutex::new(Vec::new()));
            Self { messages }
        }

        /// Snapshot of the envelopes recorded so far, in arrival order.
        fn get_messages(&self) -> Vec<MessageEnvelope> {
            let recorded = self.messages.lock().unwrap();
            recorded.clone()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post_unchecked(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            let mut recorded = self.messages.lock().unwrap();
            recorded.push(envelope);
        }
    }

    /// Build a test envelope carrying `data` as its payload.
    fn envelope(data: u64) -> MessageEnvelope {
        let from = id!(world[0].sender);
        let to = id!(world[0].receiver[0][1]);
        MessageEnvelope::serialize(from, to, &data, Attrs::new()).unwrap()
    }

    /// Open a port on a detached mailbox to serve as the return handle.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    /// Decode the `u64` payload of every envelope `recorder` has received.
    fn delivered_payloads(recorder: &QueueingMailboxSender) -> Vec<u64> {
        recorder
            .get_messages()
            .iter()
            .map(|message| message.deserialized::<u64>().unwrap())
            .collect()
    }

    #[test]
    fn test_queueing_before_configure() {
        let sender = ReconfigurableMailboxSender::new();
        let recorder = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(recorder.clone());

        let port = return_handle();
        for value in 1..=2 {
            sender.post(envelope(value), port.clone());
        }

        // Nothing reaches the recorder until a sender is configured.
        assert!(recorder.get_messages().is_empty());

        sender.configure(boxed_sender);

        // The backlog is flushed in posting order.
        assert_eq!(delivered_payloads(&recorder), vec![1, 2]);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let sender = ReconfigurableMailboxSender::new();
        let recorder = QueueingMailboxSender::new();
        sender.configure(BoxedMailboxSender::new(recorder.clone()));

        let port = return_handle();
        for value in 3..=4 {
            sender.post(envelope(value), port.clone());
        }

        assert_eq!(delivered_payloads(&recorder), vec![3, 4]);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(QueueingMailboxSender::new());

        // Only the first configuration takes effect.
        let first_attempt = sender.configure(boxed_sender.clone());
        let second_attempt = sender.configure(boxed_sender);
        assert!(first_attempt);
        assert!(!second_attempt);
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let sender = ReconfigurableMailboxSender::new();
        let recorder = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(recorder.clone());

        let port = return_handle();

        // Queued before configuration.
        for value in 5..=6 {
            sender.post(envelope(value), port.clone());
        }

        sender.configure(boxed_sender);

        // Delivered directly after configuration.
        for value in 7..=8 {
            sender.post(envelope(value), port.clone());
        }

        assert_eq!(delivered_payloads(&recorder), vec![5, 6, 7, 8]);
    }
}