1use std::collections::HashMap;
12use std::mem::take;
13use std::sync::Arc;
14use std::sync::Mutex;
15use std::sync::RwLock;
16use std::sync::RwLockReadGuard;
17
18use async_trait::async_trait;
19use enum_as_inner::EnumAsInner;
20use hyperactor::Actor;
21use hyperactor::ActorHandle;
22use hyperactor::ActorId;
23use hyperactor::Bind;
24use hyperactor::Context;
25use hyperactor::Data;
26use hyperactor::HandleClient;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::OncePortRef;
30use hyperactor::PortHandle;
31use hyperactor::PortRef;
32use hyperactor::ProcId;
33use hyperactor::RefClient;
34use hyperactor::Unbind;
35use hyperactor::WorldId;
36use hyperactor::actor::ActorStatus;
37use hyperactor::actor::remote::Remote;
38use hyperactor::channel;
39use hyperactor::channel::ChannelAddr;
40use hyperactor::clock::Clock;
41use hyperactor::clock::RealClock;
42use hyperactor::mailbox::BoxedMailboxSender;
43use hyperactor::mailbox::DialMailboxRouter;
44use hyperactor::mailbox::IntoBoxedMailboxSender;
45use hyperactor::mailbox::MailboxClient;
46use hyperactor::mailbox::MailboxSender;
47use hyperactor::mailbox::MessageEnvelope;
48use hyperactor::mailbox::Undeliverable;
49use hyperactor::proc::Proc;
50use hyperactor::supervision::ActorSupervisionEvent;
51use serde::Deserialize;
52use serde::Serialize;
53use typeuri::Named;
54
55use crate::actor_mesh::CAST_ACTOR_MESH_ID;
56use crate::proc_mesh::SupervisionEventState;
57use crate::reference::ActorMeshId;
58use crate::resource;
59use crate::v1::Name;
60
61#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
62pub enum GspawnResult {
63 Success { rank: usize, actor_id: ActorId },
64 Error(String),
65}
66wirevalue::register_type!(GspawnResult);
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
69pub enum StopActorResult {
70 Success,
71 Timeout,
72 NotFound,
73}
74wirevalue::register_type!(StopActorResult);
75
76#[derive(
77 Debug,
78 Clone,
79 PartialEq,
80 Serialize,
81 Deserialize,
82 Handler,
83 HandleClient,
84 RefClient,
85 Named
86)]
87pub(crate) enum MeshAgentMessage {
88 Configure {
90 rank: usize,
92 forwarder: ChannelAddr,
94 supervisor: Option<PortRef<ActorSupervisionEvent>>,
96 address_book: HashMap<ProcId, ChannelAddr>,
98 configured: PortRef<usize>,
101 record_supervision_events: bool,
103 },
104
105 Status {
106 status: PortRef<(usize, bool)>,
110 },
111
112 Gspawn {
114 actor_type: String,
116 actor_name: String,
118 params_data: Data,
120 status_port: PortRef<GspawnResult>,
122 },
123
124 StopActor {
126 actor_id: ActorId,
128 timeout_ms: u64,
130 #[reply]
132 stopped: OncePortRef<StopActorResult>,
133 },
134}
135
136#[derive(Debug, EnumAsInner, Default)]
138enum State {
139 UnconfiguredV0 {
140 sender: ReconfigurableMailboxSender,
141 },
142
143 ConfiguredV0 {
144 sender: ReconfigurableMailboxSender,
145 rank: usize,
146 supervisor: Option<PortRef<ActorSupervisionEvent>>,
147 },
148
149 V1,
150
151 #[default]
152 Invalid,
153}
154
155impl State {
156 fn rank(&self) -> Option<usize> {
157 match self {
158 State::ConfiguredV0 { rank, .. } => Some(*rank),
159 _ => None,
160 }
161 }
162
163 fn supervisor(&self) -> Option<PortRef<ActorSupervisionEvent>> {
164 match self {
165 State::ConfiguredV0 { supervisor, .. } => supervisor.clone(),
166 _ => None,
167 }
168 }
169}
170
171#[derive(Debug)]
173struct ActorInstanceState {
174 create_rank: usize,
175 spawn: Result<ActorId, anyhow::Error>,
176 stopped: bool,
179}
180
181pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSupervisionEvent {
184 if let Some(headers) = &event.message_headers {
185 if let Some(actor_mesh_id) = headers.get(CAST_ACTOR_MESH_ID) {
186 match actor_mesh_id {
187 ActorMeshId::V0(proc_mesh_id, actor_name) => {
188 let old_actor = event.actor_id.clone();
189 event.actor_id = ActorId(
190 ProcId::Ranked(WorldId(proc_mesh_id.0.clone()), 0),
191 actor_name.clone(),
192 0,
193 );
194 tracing::debug!(
195 actor_id = %old_actor,
196 "proc supervision: remapped comm-actor id to mesh id from CAST_ACTOR_MESH_ID {}", event.actor_id
197 );
198 }
199 ActorMeshId::V1(_) => {
200 tracing::debug!(
201 "proc supervision: headers present but V1 ActorMeshId; leaving actor_id unchanged"
202 );
203 }
204 }
205 } else {
206 tracing::debug!(
207 "proc supervision: headers present but no CAST_ACTOR_MESH_ID; leaving actor_id unchanged"
208 );
209 }
210 } else {
211 tracing::debug!("proc supervision: no headers attached; leaving actor_id unchanged");
212 }
213 event
214}
215
216#[hyperactor::export(
218 handlers=[
219 MeshAgentMessage,
220 resource::CreateOrUpdate<ActorSpec> { cast = true },
221 resource::Stop { cast = true },
222 resource::StopAll { cast = true },
223 resource::GetState<ActorState> { cast = true },
224 resource::GetRankStatus { cast = true },
225 ]
226)]
227pub struct ProcMeshAgent {
228 proc: Proc,
229 remote: Remote,
230 state: State,
231 actor_states: HashMap<Name, ActorInstanceState>,
233 record_supervision_events: bool,
236 supervision_events: HashMap<ActorId, Vec<ActorSupervisionEvent>>,
239}
240
241impl ProcMeshAgent {
242 #[hyperactor::observe_result("MeshAgent")]
243 pub(crate) async fn bootstrap(
244 proc_id: ProcId,
245 ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
246 let sender = ReconfigurableMailboxSender::new();
247 let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
248
249 super::router::global().bind(proc_id.into(), proc.clone());
252
253 let agent = ProcMeshAgent {
254 proc: proc.clone(),
255 remote: Remote::collect(),
256 state: State::UnconfiguredV0 { sender },
257 actor_states: HashMap::new(),
258 record_supervision_events: false,
259 supervision_events: HashMap::new(),
260 };
261 let handle = proc.spawn::<Self>("mesh", agent)?;
262 Ok((proc, handle))
263 }
264
265 pub(crate) fn boot_v1(proc: Proc) -> Result<ActorHandle<Self>, anyhow::Error> {
266 let agent = ProcMeshAgent {
267 proc: proc.clone(),
268 remote: Remote::collect(),
269 state: State::V1,
270 actor_states: HashMap::new(),
271 record_supervision_events: true,
272 supervision_events: HashMap::new(),
273 };
274 proc.spawn::<Self>("agent", agent)
275 }
276
277 async fn destroy_and_wait_except_current<'a>(
278 &mut self,
279 cx: &Context<'a, Self>,
280 timeout: tokio::time::Duration,
281 ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
282 self.proc
283 .destroy_and_wait_except_current::<Self>(timeout, Some(cx), true)
284 .await
285 }
286}
287
288#[async_trait]
289impl Actor for ProcMeshAgent {
290 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
291 self.proc.set_supervision_coordinator(this.port())?;
292 Ok(())
293 }
294}
295
296#[async_trait]
297#[hyperactor::forward(MeshAgentMessage)]
298impl MeshAgentMessageHandler for ProcMeshAgent {
299 async fn configure(
300 &mut self,
301 cx: &Context<Self>,
302 rank: usize,
303 forwarder: ChannelAddr,
304 supervisor: Option<PortRef<ActorSupervisionEvent>>,
305 address_book: HashMap<ProcId, ChannelAddr>,
306 configured: PortRef<usize>,
307 record_supervision_events: bool,
308 ) -> Result<(), anyhow::Error> {
309 anyhow::ensure!(
310 self.state.is_unconfigured_v0(),
311 "mesh agent cannot be (re-)configured"
312 );
313 self.record_supervision_events = record_supervision_events;
314
315 let client = MailboxClient::new(channel::dial(forwarder)?);
318
319 let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
323 let default = super::router::global().fallback(client.into_boxed());
324 DialMailboxRouter::new_with_default_direct_addressed_remote_only(default.into_boxed())
325 } else {
326 DialMailboxRouter::new_with_default_direct_addressed_remote_only(client.into_boxed())
327 };
328
329 for (proc_id, addr) in address_book {
330 router.bind(proc_id.into(), addr);
331 }
332
333 let sender = take(&mut self.state).into_unconfigured_v0().unwrap();
334 assert!(sender.configure(router.into_boxed()));
335
336 self.state = State::ConfiguredV0 {
340 sender,
341 rank,
342 supervisor,
343 };
344 configured.send(cx, rank)?;
345
346 Ok(())
347 }
348
349 async fn gspawn(
350 &mut self,
351 cx: &Context<Self>,
352 actor_type: String,
353 actor_name: String,
354 params_data: Data,
355 status_port: PortRef<GspawnResult>,
356 ) -> Result<(), anyhow::Error> {
357 anyhow::ensure!(
358 self.state.is_configured_v0(),
359 "mesh agent is not v0 configured"
360 );
361 let actor_id = match self
362 .remote
363 .gspawn(&self.proc, &actor_type, &actor_name, params_data)
364 .await
365 {
366 Ok(id) => id,
367 Err(err) => {
368 status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
369 return Err(anyhow::anyhow!("gspawn failed"));
370 }
371 };
372 status_port.send(
373 cx,
374 GspawnResult::Success {
375 rank: self.state.rank().unwrap(),
376 actor_id,
377 },
378 )?;
379 Ok(())
380 }
381
382 async fn stop_actor(
383 &mut self,
384 _cx: &Context<Self>,
385 actor_id: ActorId,
386 timeout_ms: u64,
387 ) -> Result<StopActorResult, anyhow::Error> {
388 tracing::info!(
389 name = "StopActor",
390 actor_id = %actor_id,
391 actor_name = actor_id.name(),
392 );
393
394 if let Some(mut status) = self.proc.stop_actor(&actor_id) {
395 match RealClock
396 .timeout(
397 tokio::time::Duration::from_millis(timeout_ms),
398 status.wait_for(|state: &ActorStatus| state.is_terminal()),
399 )
400 .await
401 {
402 Ok(_) => Ok(StopActorResult::Success),
403 Err(_) => Ok(StopActorResult::Timeout),
404 }
405 } else {
406 Ok(StopActorResult::NotFound)
407 }
408 }
409
410 async fn status(
411 &mut self,
412 cx: &Context<Self>,
413 status_port: PortRef<(usize, bool)>,
414 ) -> Result<(), anyhow::Error> {
415 match &self.state {
416 State::ConfiguredV0 { rank, .. } => {
417 status_port.send(cx, (*rank, true))?;
419 Ok(())
420 }
421 State::UnconfiguredV0 { .. } => {
422 Err(anyhow::anyhow!(
424 "status unavailable: v0 agent not configured (waiting for Configure)"
425 ))
426 }
427 State::V1 => {
428 Err(anyhow::anyhow!(
430 "status unsupported in v1/owned path (no rank)"
431 ))
432 }
433 State::Invalid => Err(anyhow::anyhow!(
434 "status unavailable: agent in invalid state"
435 )),
436 }
437 }
438}
439
440#[async_trait]
441impl Handler<ActorSupervisionEvent> for ProcMeshAgent {
442 async fn handle(
443 &mut self,
444 cx: &Context<Self>,
445 event: ActorSupervisionEvent,
446 ) -> anyhow::Result<()> {
447 let event = update_event_actor_id(event);
448 if self.record_supervision_events {
449 if event.is_error() {
450 tracing::warn!(
451 name = "SupervisionEvent",
452 proc_id = %self.proc.proc_id(),
453 %event,
454 "recording supervision error",
455 );
456 } else {
457 tracing::debug!(
458 name = "SupervisionEvent",
459 proc_id = %self.proc.proc_id(),
460 %event,
461 "recording non-error supervision event",
462 );
463 }
464 self.supervision_events
465 .entry(event.actor_id.clone())
466 .or_default()
467 .push(event.clone());
468 }
469 if let Some(supervisor) = self.state.supervisor() {
470 supervisor.send(cx, event)?;
471 } else if !self.record_supervision_events {
472 tracing::error!(
475 name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
476 proc_id = %cx.self_id().proc_id(),
477 %event,
478 "could not propagate supervision event, crashing",
479 );
480
481 std::process::exit(1);
484 }
485 Ok(())
486 }
487}
488
489#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
493pub struct ActorSpec {
494 pub actor_type: String,
496 pub params_data: Data,
498}
499wirevalue::register_type!(ActorSpec);
500
501#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
503pub struct ActorState {
504 pub actor_id: ActorId,
506 pub create_rank: usize,
508 pub supervision_events: Vec<ActorSupervisionEvent>,
510}
511wirevalue::register_type!(ActorState);
512
513#[async_trait]
514impl Handler<resource::CreateOrUpdate<ActorSpec>> for ProcMeshAgent {
515 async fn handle(
516 &mut self,
517 _cx: &Context<Self>,
518 create_or_update: resource::CreateOrUpdate<ActorSpec>,
519 ) -> anyhow::Result<()> {
520 if self.actor_states.contains_key(&create_or_update.name) {
521 return Ok(());
523 }
524 let create_rank = create_or_update.rank.unwrap();
525 if !self.supervision_events.is_empty() {
529 self.actor_states.insert(
530 create_or_update.name.clone(),
531 ActorInstanceState {
532 spawn: Err(anyhow::anyhow!(
533 "Cannot spawn new actors on mesh with supervision events"
534 )),
535 create_rank,
536 stopped: false,
537 },
538 );
539 return Ok(());
540 }
541
542 let ActorSpec {
543 actor_type,
544 params_data,
545 } = create_or_update.spec;
546 self.actor_states.insert(
547 create_or_update.name.clone(),
548 ActorInstanceState {
549 create_rank,
550 spawn: self
551 .remote
552 .gspawn(
553 &self.proc,
554 &actor_type,
555 &create_or_update.name.to_string(),
556 params_data,
557 )
558 .await,
559 stopped: false,
560 },
561 );
562
563 Ok(())
564 }
565}
566
567#[async_trait]
568impl Handler<resource::Stop> for ProcMeshAgent {
569 async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
570 let actor = self.actor_states.get_mut(&message.name);
573 let actor_id = match actor {
576 Some(actor_state) => {
577 match &actor_state.spawn {
578 Ok(actor_id) => {
579 if actor_state.stopped {
580 None
581 } else {
582 actor_state.stopped = true;
583 Some(actor_id.clone())
584 }
585 }
586 Err(_) => None,
589 }
590 }
591 None => None,
593 };
594 let timeout = hyperactor_config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
595 if let Some(actor_id) = actor_id {
596 self.stop_actor(cx, actor_id, timeout.as_millis() as u64)
599 .await
600 .expect("stop_actor cannot fail");
601 }
602
603 Ok(())
604 }
605}
606
607#[async_trait]
612impl Handler<resource::StopAll> for ProcMeshAgent {
613 async fn handle(
614 &mut self,
615 cx: &Context<Self>,
616 _message: resource::StopAll,
617 ) -> anyhow::Result<()> {
618 let timeout = hyperactor_config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
619 let stop_result = self.destroy_and_wait_except_current(cx, timeout).await;
622 match stop_result {
627 Ok((stopped_actors, aborted_actors)) => {
628 tracing::info!(
630 actor = %cx.self_id(),
631 "exiting process after receiving StopAll message on ProcMeshAgent. \
632 stopped actors = {:?}, aborted actors = {:?}",
633 stopped_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
634 aborted_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
635 );
636 std::process::exit(0);
637 }
638 Err(e) => {
639 tracing::error!(actor = %cx.self_id(), "failed to stop all actors on ProcMeshAgent: {:?}", e);
640 std::process::exit(1);
641 }
642 }
643 }
644}
645
646#[async_trait]
647impl Handler<resource::GetRankStatus> for ProcMeshAgent {
648 async fn handle(
649 &mut self,
650 cx: &Context<Self>,
651 get_rank_status: resource::GetRankStatus,
652 ) -> anyhow::Result<()> {
653 use crate::resource::Status;
654 use crate::v1::StatusOverlay;
655
656 let (rank, status) = match self.actor_states.get(&get_rank_status.name) {
657 Some(ActorInstanceState {
658 spawn: Ok(actor_id),
659 create_rank,
660 stopped,
661 }) => {
662 if *stopped {
663 (*create_rank, resource::Status::Stopped)
664 } else {
665 let supervision_events = self
666 .supervision_events
667 .get(actor_id)
668 .map_or_else(Vec::new, |a| a.clone());
669 (
670 *create_rank,
671 if supervision_events.is_empty() {
672 resource::Status::Running
673 } else {
674 resource::Status::Failed(format!(
675 "because of supervision events: {:?}",
676 supervision_events
677 ))
678 },
679 )
680 }
681 }
682 Some(ActorInstanceState {
683 spawn: Err(e),
684 create_rank,
685 ..
686 }) => (*create_rank, Status::Failed(e.to_string())),
687 None => (usize::MAX, Status::NotExist),
689 };
690
691 let overlay = if rank == usize::MAX {
694 StatusOverlay::new()
695 } else {
696 StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
697 .expect("valid single-run overlay")
698 };
699 let result = get_rank_status.reply.send(cx, overlay);
700 if let Err(e) = result {
704 tracing::warn!(
705 actor = %cx.self_id(),
706 "failed to send GetRankStatus reply to {} due to error: {}",
707 get_rank_status.reply.port_id().actor_id(),
708 e
709 );
710 }
711 Ok(())
712 }
713}
714
715#[async_trait]
716impl Handler<resource::GetState<ActorState>> for ProcMeshAgent {
717 async fn handle(
718 &mut self,
719 cx: &Context<Self>,
720 get_state: resource::GetState<ActorState>,
721 ) -> anyhow::Result<()> {
722 let state = match self.actor_states.get(&get_state.name) {
723 Some(ActorInstanceState {
724 create_rank,
725 spawn: Ok(actor_id),
726 stopped,
727 }) => {
728 let supervision_events = self
729 .supervision_events
730 .get(actor_id)
731 .map_or_else(Vec::new, |a| a.clone());
732 let status = if *stopped {
733 resource::Status::Stopped
734 } else if supervision_events.is_empty() {
735 resource::Status::Running
736 } else {
737 resource::Status::Failed(format!(
738 "because of supervision events: {:?}",
739 supervision_events
740 ))
741 };
742 resource::State {
743 name: get_state.name.clone(),
744 status,
745 state: Some(ActorState {
746 actor_id: actor_id.clone(),
747 create_rank: *create_rank,
748 supervision_events,
749 }),
750 }
751 }
752 Some(ActorInstanceState { spawn: Err(e), .. }) => resource::State {
753 name: get_state.name.clone(),
754 status: resource::Status::Failed(e.to_string()),
755 state: None,
756 },
757 None => resource::State {
758 name: get_state.name.clone(),
759 status: resource::Status::NotExist,
760 state: None,
761 },
762 };
763
764 let result = get_state.reply.send(cx, state);
765 if let Err(e) = result {
769 tracing::warn!(
770 actor = %cx.self_id(),
771 "failed to send GetState reply to {} due to error: {}",
772 get_state.reply.port_id().actor_id(),
773 e
774 );
775 }
776 Ok(())
777 }
778}
779
780#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)]
783pub struct NewClientInstance {
784 #[reply]
785 pub client_instance: PortHandle<Instance<()>>,
786}
787
788#[async_trait]
789impl Handler<NewClientInstance> for ProcMeshAgent {
790 async fn handle(
791 &mut self,
792 _cx: &Context<Self>,
793 NewClientInstance { client_instance }: NewClientInstance,
794 ) -> anyhow::Result<()> {
795 let (instance, _handle) = self.proc.instance("client")?;
796 client_instance.send(instance)?;
797 Ok(())
798 }
799}
800
801#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)]
804pub struct GetProc {
805 #[reply]
806 pub proc: PortHandle<Proc>,
807}
808
809#[async_trait]
810impl Handler<GetProc> for ProcMeshAgent {
811 async fn handle(
812 &mut self,
813 _cx: &Context<Self>,
814 GetProc { proc }: GetProc,
815 ) -> anyhow::Result<()> {
816 proc.send(self.proc.clone())?;
817 Ok(())
818 }
819}
820
821#[derive(Clone)]
824pub(crate) struct ReconfigurableMailboxSender {
825 state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
826}
827
828impl std::fmt::Debug for ReconfigurableMailboxSender {
829 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
830 f.debug_struct("ReconfigurableMailboxSender").finish()
833 }
834}
835
836pub(crate) struct ReconfigurableMailboxSenderInner<'a> {
849 guard: RwLockReadGuard<'a, ReconfigurableMailboxSenderState>,
850}
851
852impl<'a> ReconfigurableMailboxSenderInner<'a> {
853 pub(crate) fn as_configured(&self) -> Option<&BoxedMailboxSender> {
854 self.guard.as_configured()
855 }
856}
857
858type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
859
860#[derive(EnumAsInner, Debug)]
861enum ReconfigurableMailboxSenderState {
862 Queueing(Mutex<Vec<Post>>),
863 Configured(BoxedMailboxSender),
864}
865
866impl ReconfigurableMailboxSender {
867 pub(crate) fn new() -> Self {
868 Self {
869 state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
870 Mutex::new(Vec::new()),
871 ))),
872 }
873 }
874
875 pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
879 let mut state = self.state.write().unwrap();
881 if state.is_configured() {
882 return false;
883 }
884
885 let queued = std::mem::replace(
887 &mut *state,
888 ReconfigurableMailboxSenderState::Configured(sender),
889 );
890
891 let configured_sender = state.as_configured().expect("just configured");
894
895 for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
897 configured_sender.post(envelope, return_handle);
898 }
899
900 true
901 }
902
903 pub(crate) fn as_inner<'a>(
904 &'a self,
905 ) -> Result<ReconfigurableMailboxSenderInner<'a>, anyhow::Error> {
906 let state = self.state.read().unwrap();
907 if state.is_configured() {
908 Ok(ReconfigurableMailboxSenderInner { guard: state })
909 } else {
910 Err(anyhow::anyhow!("cannot get inner sender: not configured"))
911 }
912 }
913}
914
915impl MailboxSender for ReconfigurableMailboxSender {
916 fn post(
917 &self,
918 envelope: MessageEnvelope,
919 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
920 ) {
921 match &*self.state.read().unwrap() {
922 ReconfigurableMailboxSenderState::Queueing(queue) => {
923 queue.lock().unwrap().push((envelope, return_handle));
924 }
925 ReconfigurableMailboxSenderState::Configured(sender) => {
926 sender.post(envelope, return_handle);
927 }
928 }
929 }
930
931 fn post_unchecked(
932 &self,
933 envelope: MessageEnvelope,
934 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
935 ) {
936 match &*self.state.read().unwrap() {
937 ReconfigurableMailboxSenderState::Queueing(queue) => {
938 queue.lock().unwrap().push((envelope, return_handle));
939 }
940 ReconfigurableMailboxSenderState::Configured(sender) => {
941 sender.post_unchecked(envelope, return_handle);
942 }
943 }
944 }
945}
946
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;
    use hyperactor_config::attrs::Attrs;

    use super::*;

    /// Test sender that records every envelope posted to it.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            Self {
                messages: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn get_messages(&self) -> Vec<MessageEnvelope> {
            self.messages.lock().unwrap().clone()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post_unchecked(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            self.messages.lock().unwrap().push(envelope);
        }
    }

    /// Builds an envelope carrying `data` between two fixed test actors.
    fn envelope(data: u64) -> MessageEnvelope {
        MessageEnvelope::serialize(
            id!(world[0].sender),
            id!(world[0].receiver[0][1]),
            &data,
            Attrs::new(),
        )
        .unwrap()
    }

    /// Returns a return handle backed by a detached test mailbox.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    #[test]
    fn test_queueing_before_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let test_sender = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(test_sender.clone());

        let return_handle = return_handle();
        sender.post(envelope(1), return_handle.clone());
        sender.post(envelope(2), return_handle.clone());

        assert_eq!(test_sender.get_messages().len(), 0);

        // The first configuration must succeed.
        assert!(sender.configure(boxed_sender));

        let messages = test_sender.get_messages();
        assert_eq!(messages.len(), 2);

        assert_eq!(messages[0].deserialized::<u64>().unwrap(), 1);
        assert_eq!(messages[1].deserialized::<u64>().unwrap(), 2);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let test_sender = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(test_sender.clone());
        assert!(sender.configure(boxed_sender));

        let return_handle = return_handle();
        sender.post(envelope(3), return_handle.clone());
        sender.post(envelope(4), return_handle.clone());

        let messages = test_sender.get_messages();
        assert_eq!(messages.len(), 2);

        assert_eq!(messages[0].deserialized::<u64>().unwrap(), 3);
        assert_eq!(messages[1].deserialized::<u64>().unwrap(), 4);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(QueueingMailboxSender::new());

        assert!(sender.configure(boxed_sender.clone()));
        assert!(!sender.configure(boxed_sender));
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let sender = ReconfigurableMailboxSender::new();

        let test_sender = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(test_sender.clone());

        let return_handle = return_handle();
        sender.post(envelope(5), return_handle.clone());
        sender.post(envelope(6), return_handle.clone());

        assert!(sender.configure(boxed_sender));

        sender.post(envelope(7), return_handle.clone());
        sender.post(envelope(8), return_handle.clone());

        let messages = test_sender.get_messages();
        assert_eq!(messages.len(), 4);

        assert_eq!(messages[0].deserialized::<u64>().unwrap(), 5);
        assert_eq!(messages[1].deserialized::<u64>().unwrap(), 6);
        assert_eq!(messages[2].deserialized::<u64>().unwrap(), 7);
        assert_eq!(messages[3].deserialized::<u64>().unwrap(), 8);
    }

    #[test]
    fn test_as_inner_requires_configuration() {
        let sender = ReconfigurableMailboxSender::new();
        // Not configured yet: no inner sender is available.
        assert!(sender.as_inner().is_err());

        let boxed_sender = BoxedMailboxSender::new(QueueingMailboxSender::new());
        assert!(sender.configure(boxed_sender));

        let inner = sender
            .as_inner()
            .expect("a configured sender exposes its inner sender");
        assert!(inner.as_configured().is_some());
    }
}