1use std::collections::HashMap;
12use std::mem::replace;
13use std::mem::take;
14use std::sync::Arc;
15use std::sync::Mutex;
16use std::sync::RwLock;
17use std::sync::RwLockWriteGuard;
18
19use async_trait::async_trait;
20use enum_as_inner::EnumAsInner;
21use hyperactor::Actor;
22use hyperactor::ActorHandle;
23use hyperactor::ActorId;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Data;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::OncePortRef;
32use hyperactor::PortHandle;
33use hyperactor::PortRef;
34use hyperactor::ProcId;
35use hyperactor::RefClient;
36use hyperactor::Unbind;
37use hyperactor::WorldId;
38use hyperactor::actor::ActorStatus;
39use hyperactor::actor::remote::Remote;
40use hyperactor::channel;
41use hyperactor::channel::ChannelAddr;
42use hyperactor::clock::Clock;
43use hyperactor::clock::RealClock;
44use hyperactor::mailbox::BoxedMailboxSender;
45use hyperactor::mailbox::DialMailboxRouter;
46use hyperactor::mailbox::IntoBoxedMailboxSender;
47use hyperactor::mailbox::MailboxClient;
48use hyperactor::mailbox::MailboxSender;
49use hyperactor::mailbox::MessageEnvelope;
50use hyperactor::mailbox::Undeliverable;
51use hyperactor::proc::Proc;
52use hyperactor::supervision::ActorSupervisionEvent;
53use serde::Deserialize;
54use serde::Serialize;
55
56use crate::actor_mesh::CAST_ACTOR_MESH_ID;
57use crate::proc_mesh::SupervisionEventState;
58use crate::reference::ActorMeshId;
59use crate::resource;
60use crate::v1::Name;
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
63pub enum GspawnResult {
64 Success { rank: usize, actor_id: ActorId },
65 Error(String),
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
69pub enum StopActorResult {
70 Success,
71 Timeout,
72 NotFound,
73}
74
75#[derive(
76 Debug,
77 Clone,
78 PartialEq,
79 Serialize,
80 Deserialize,
81 Handler,
82 HandleClient,
83 RefClient,
84 Named
85)]
86pub(crate) enum MeshAgentMessage {
87 Configure {
89 rank: usize,
91 forwarder: ChannelAddr,
93 supervisor: Option<PortRef<ActorSupervisionEvent>>,
95 address_book: HashMap<ProcId, ChannelAddr>,
97 configured: PortRef<usize>,
100 record_supervision_events: bool,
102 },
103
104 Status {
105 status: PortRef<(usize, bool)>,
109 },
110
111 Gspawn {
113 actor_type: String,
115 actor_name: String,
117 params_data: Data,
119 status_port: PortRef<GspawnResult>,
121 },
122
123 StopActor {
125 actor_id: ActorId,
127 timeout_ms: u64,
129 #[reply]
131 stopped: OncePortRef<StopActorResult>,
132 },
133}
134
135#[derive(Debug, EnumAsInner, Default)]
137enum State {
138 UnconfiguredV0 {
139 sender: ReconfigurableMailboxSender,
140 },
141
142 ConfiguredV0 {
143 sender: ReconfigurableMailboxSender,
144 rank: usize,
145 supervisor: Option<PortRef<ActorSupervisionEvent>>,
146 },
147
148 V1,
149
150 #[default]
151 Invalid,
152}
153
154impl State {
155 fn rank(&self) -> Option<usize> {
156 match self {
157 State::ConfiguredV0 { rank, .. } => Some(*rank),
158 _ => None,
159 }
160 }
161
162 fn supervisor(&self) -> Option<PortRef<ActorSupervisionEvent>> {
163 match self {
164 State::ConfiguredV0 { supervisor, .. } => supervisor.clone(),
165 _ => None,
166 }
167 }
168}
169
170#[derive(Debug)]
172struct ActorInstanceState {
173 create_rank: usize,
174 spawn: Result<ActorId, anyhow::Error>,
175 stopped: bool,
178}
179
180pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSupervisionEvent {
183 if let Some(headers) = &event.message_headers {
184 if let Some(actor_mesh_id) = headers.get(CAST_ACTOR_MESH_ID) {
185 match actor_mesh_id {
186 ActorMeshId::V0(proc_mesh_id, actor_name) => {
187 let old_actor = event.actor_id.clone();
188 event.actor_id = ActorId(
189 ProcId::Ranked(WorldId(proc_mesh_id.0.clone()), 0),
190 actor_name.clone(),
191 0,
192 );
193 tracing::debug!(
194 actor_id = %old_actor,
195 "proc supervision: remapped comm-actor id to mesh id from CAST_ACTOR_MESH_ID {}", event.actor_id
196 );
197 }
198 ActorMeshId::V1(_) => {
199 tracing::debug!(
200 "proc supervision: headers present but V1 ActorMeshId; leaving actor_id unchanged"
201 );
202 }
203 }
204 } else {
205 tracing::debug!(
206 "proc supervision: headers present but no CAST_ACTOR_MESH_ID; leaving actor_id unchanged"
207 );
208 }
209 } else {
210 tracing::debug!("proc supervision: no headers attached; leaving actor_id unchanged");
211 }
212 event
213}
214
215#[derive(Debug)]
217#[hyperactor::export(
218 handlers=[
219 MeshAgentMessage,
220 resource::CreateOrUpdate<ActorSpec> { cast = true },
221 resource::Stop { cast = true },
222 resource::StopAll { cast = true },
223 resource::GetState<ActorState> { cast = true },
224 resource::GetRankStatus { cast = true },
225 ]
226)]
227pub struct ProcMeshAgent {
228 proc: Proc,
229 remote: Remote,
230 state: State,
231 actor_states: HashMap<Name, ActorInstanceState>,
233 record_supervision_events: bool,
236 supervision_events: HashMap<ActorId, Vec<ActorSupervisionEvent>>,
239}
240
241impl ProcMeshAgent {
242 #[hyperactor::observe_result("MeshAgent")]
243 pub(crate) async fn bootstrap(
244 proc_id: ProcId,
245 ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
246 let sender = ReconfigurableMailboxSender::new();
247 let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
248
249 super::router::global().bind(proc_id.into(), proc.clone());
252
253 let agent = ProcMeshAgent {
254 proc: proc.clone(),
255 remote: Remote::collect(),
256 state: State::UnconfiguredV0 { sender },
257 actor_states: HashMap::new(),
258 record_supervision_events: false,
259 supervision_events: HashMap::new(),
260 };
261 let handle = proc.spawn::<Self>("mesh", agent).await?;
262 Ok((proc, handle))
263 }
264
265 pub(crate) async fn boot_v1(proc: Proc) -> Result<ActorHandle<Self>, anyhow::Error> {
266 let agent = ProcMeshAgent {
267 proc: proc.clone(),
268 remote: Remote::collect(),
269 state: State::V1,
270 actor_states: HashMap::new(),
271 record_supervision_events: true,
272 supervision_events: HashMap::new(),
273 };
274 proc.spawn::<Self>("agent", agent).await
275 }
276
277 async fn destroy_and_wait_except_current<'a>(
278 &mut self,
279 cx: &Context<'a, Self>,
280 timeout: tokio::time::Duration,
281 ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
282 self.proc
283 .destroy_and_wait_except_current::<Self>(timeout, Some(cx), true)
284 .await
285 }
286}
287
288#[async_trait]
289impl Actor for ProcMeshAgent {
290 type Params = Self;
291
292 async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
293 Ok(params)
294 }
295
296 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
297 self.proc.set_supervision_coordinator(this.port())?;
298 Ok(())
299 }
300}
301
302#[async_trait]
303#[hyperactor::forward(MeshAgentMessage)]
304impl MeshAgentMessageHandler for ProcMeshAgent {
305 async fn configure(
306 &mut self,
307 cx: &Context<Self>,
308 rank: usize,
309 forwarder: ChannelAddr,
310 supervisor: Option<PortRef<ActorSupervisionEvent>>,
311 address_book: HashMap<ProcId, ChannelAddr>,
312 configured: PortRef<usize>,
313 record_supervision_events: bool,
314 ) -> Result<(), anyhow::Error> {
315 anyhow::ensure!(
316 self.state.is_unconfigured_v0(),
317 "mesh agent cannot be (re-)configured"
318 );
319 self.record_supervision_events = record_supervision_events;
320
321 let client = MailboxClient::new(channel::dial(forwarder)?);
324
325 let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
329 let default = super::router::global().fallback(client.into_boxed());
330 DialMailboxRouter::new_with_default_direct_addressed_remote_only(default.into_boxed())
331 } else {
332 DialMailboxRouter::new_with_default_direct_addressed_remote_only(client.into_boxed())
333 };
334
335 for (proc_id, addr) in address_book {
336 router.bind(proc_id.into(), addr);
337 }
338
339 let sender = take(&mut self.state).into_unconfigured_v0().unwrap();
340 assert!(sender.configure(router.into_boxed()));
341
342 self.state = State::ConfiguredV0 {
346 sender,
347 rank,
348 supervisor,
349 };
350 configured.send(cx, rank)?;
351
352 Ok(())
353 }
354
355 async fn gspawn(
356 &mut self,
357 cx: &Context<Self>,
358 actor_type: String,
359 actor_name: String,
360 params_data: Data,
361 status_port: PortRef<GspawnResult>,
362 ) -> Result<(), anyhow::Error> {
363 anyhow::ensure!(
364 self.state.is_configured_v0(),
365 "mesh agent is not v0 configured"
366 );
367 let actor_id = match self
368 .remote
369 .gspawn(&self.proc, &actor_type, &actor_name, params_data)
370 .await
371 {
372 Ok(id) => id,
373 Err(err) => {
374 status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
375 return Err(anyhow::anyhow!("gspawn failed"));
376 }
377 };
378 status_port.send(
379 cx,
380 GspawnResult::Success {
381 rank: self.state.rank().unwrap(),
382 actor_id,
383 },
384 )?;
385 Ok(())
386 }
387
388 async fn stop_actor(
389 &mut self,
390 _cx: &Context<Self>,
391 actor_id: ActorId,
392 timeout_ms: u64,
393 ) -> Result<StopActorResult, anyhow::Error> {
394 tracing::info!(
395 name = "StopActor",
396 actor_id = %actor_id,
397 actor_name = actor_id.name(),
398 );
399
400 if let Some(mut status) = self.proc.stop_actor(&actor_id) {
401 match RealClock
402 .timeout(
403 tokio::time::Duration::from_millis(timeout_ms),
404 status.wait_for(|state: &ActorStatus| state.is_terminal()),
405 )
406 .await
407 {
408 Ok(_) => Ok(StopActorResult::Success),
409 Err(_) => Ok(StopActorResult::Timeout),
410 }
411 } else {
412 Ok(StopActorResult::NotFound)
413 }
414 }
415
416 async fn status(
417 &mut self,
418 cx: &Context<Self>,
419 status_port: PortRef<(usize, bool)>,
420 ) -> Result<(), anyhow::Error> {
421 match &self.state {
422 State::ConfiguredV0 { rank, .. } => {
423 status_port.send(cx, (*rank, true))?;
425 Ok(())
426 }
427 State::UnconfiguredV0 { .. } => {
428 Err(anyhow::anyhow!(
430 "status unavailable: v0 agent not configured (waiting for Configure)"
431 ))
432 }
433 State::V1 => {
434 Err(anyhow::anyhow!(
436 "status unsupported in v1/owned path (no rank)"
437 ))
438 }
439 State::Invalid => Err(anyhow::anyhow!(
440 "status unavailable: agent in invalid state"
441 )),
442 }
443 }
444}
445
446#[async_trait]
447impl Handler<ActorSupervisionEvent> for ProcMeshAgent {
448 async fn handle(
449 &mut self,
450 cx: &Context<Self>,
451 event: ActorSupervisionEvent,
452 ) -> anyhow::Result<()> {
453 let event = update_event_actor_id(event);
454 if self.record_supervision_events {
455 if event.is_error() {
456 tracing::warn!(
457 name = "SupervisionEvent",
458 proc_id = %self.proc.proc_id(),
459 %event,
460 "recording supervision error",
461 );
462 } else {
463 tracing::debug!(
464 name = "SupervisionEvent",
465 proc_id = %self.proc.proc_id(),
466 %event,
467 "recording non-error supervision event",
468 );
469 }
470 self.supervision_events
471 .entry(event.actor_id.clone())
472 .or_default()
473 .push(event.clone());
474 }
475 if let Some(supervisor) = self.state.supervisor() {
476 supervisor.send(cx, event)?;
477 } else if !self.record_supervision_events {
478 tracing::error!(
481 name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
482 proc_id = %cx.self_id().proc_id(),
483 %event,
484 "could not propagate supervision event, crashing",
485 );
486
487 std::process::exit(1);
490 }
491 Ok(())
492 }
493}
494
495#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
499pub struct ActorSpec {
500 pub actor_type: String,
502 pub params_data: Data,
504}
505
506#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
508pub struct ActorState {
509 pub actor_id: ActorId,
511 pub create_rank: usize,
513 pub supervision_events: Vec<ActorSupervisionEvent>,
515}
516
517#[async_trait]
518impl Handler<resource::CreateOrUpdate<ActorSpec>> for ProcMeshAgent {
519 async fn handle(
520 &mut self,
521 _cx: &Context<Self>,
522 create_or_update: resource::CreateOrUpdate<ActorSpec>,
523 ) -> anyhow::Result<()> {
524 if self.actor_states.contains_key(&create_or_update.name) {
525 return Ok(());
527 }
528 let create_rank = create_or_update.rank.unwrap();
529 if !self.supervision_events.is_empty() {
533 self.actor_states.insert(
534 create_or_update.name.clone(),
535 ActorInstanceState {
536 spawn: Err(anyhow::anyhow!(
537 "Cannot spawn new actors on mesh with supervision events"
538 )),
539 create_rank,
540 stopped: false,
541 },
542 );
543 return Ok(());
544 }
545
546 let ActorSpec {
547 actor_type,
548 params_data,
549 } = create_or_update.spec;
550 self.actor_states.insert(
551 create_or_update.name.clone(),
552 ActorInstanceState {
553 create_rank,
554 spawn: self
555 .remote
556 .gspawn(
557 &self.proc,
558 &actor_type,
559 &create_or_update.name.to_string(),
560 params_data,
561 )
562 .await,
563 stopped: false,
564 },
565 );
566
567 Ok(())
568 }
569}
570
571#[async_trait]
572impl Handler<resource::Stop> for ProcMeshAgent {
573 async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
574 let actor = self.actor_states.get_mut(&message.name);
577 let actor_id = match actor {
580 Some(actor_state) => {
581 match &actor_state.spawn {
582 Ok(actor_id) => {
583 if actor_state.stopped {
584 None
585 } else {
586 actor_state.stopped = true;
587 Some(actor_id.clone())
588 }
589 }
590 Err(_) => None,
593 }
594 }
595 None => None,
597 };
598 let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
599 if let Some(actor_id) = actor_id {
600 self.stop_actor(cx, actor_id, timeout.as_millis() as u64)
603 .await
604 .expect("stop_actor cannot fail");
605 }
606
607 Ok(())
608 }
609}
610
611#[async_trait]
616impl Handler<resource::StopAll> for ProcMeshAgent {
617 async fn handle(
618 &mut self,
619 cx: &Context<Self>,
620 _message: resource::StopAll,
621 ) -> anyhow::Result<()> {
622 let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
623 let stop_result = self.destroy_and_wait_except_current(cx, timeout).await;
626 match stop_result {
631 Ok((stopped_actors, aborted_actors)) => {
632 tracing::info!(
634 actor = %cx.self_id(),
635 "exiting process after receiving StopAll message on ProcMeshAgent. \
636 stopped actors = {:?}, aborted actors = {:?}",
637 stopped_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
638 aborted_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
639 );
640 std::process::exit(0);
641 }
642 Err(e) => {
643 tracing::error!(actor = %cx.self_id(), "failed to stop all actors on ProcMeshAgent: {:?}", e);
644 std::process::exit(1);
645 }
646 }
647 }
648}
649
650#[async_trait]
651impl Handler<resource::GetRankStatus> for ProcMeshAgent {
652 async fn handle(
653 &mut self,
654 cx: &Context<Self>,
655 get_rank_status: resource::GetRankStatus,
656 ) -> anyhow::Result<()> {
657 use crate::resource::Status;
658 use crate::v1::StatusOverlay;
659
660 let (rank, status) = match self.actor_states.get(&get_rank_status.name) {
661 Some(ActorInstanceState {
662 spawn: Ok(actor_id),
663 create_rank,
664 stopped,
665 }) => {
666 if *stopped {
667 (*create_rank, resource::Status::Stopped)
668 } else {
669 let supervision_events = self
670 .supervision_events
671 .get(actor_id)
672 .map_or_else(Vec::new, |a| a.clone());
673 (
674 *create_rank,
675 if supervision_events.is_empty() {
676 resource::Status::Running
677 } else {
678 resource::Status::Failed(format!(
679 "because of supervision events: {:?}",
680 supervision_events
681 ))
682 },
683 )
684 }
685 }
686 Some(ActorInstanceState {
687 spawn: Err(e),
688 create_rank,
689 ..
690 }) => (*create_rank, Status::Failed(e.to_string())),
691 None => (usize::MAX, Status::NotExist),
693 };
694
695 let overlay = if rank == usize::MAX {
698 StatusOverlay::new()
699 } else {
700 StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
701 .expect("valid single-run overlay")
702 };
703 let result = get_rank_status.reply.send(cx, overlay);
704 if let Err(e) = result {
708 tracing::warn!(
709 actor = %cx.self_id(),
710 "failed to send GetRankStatus reply to {} due to error: {}",
711 get_rank_status.reply.port_id().actor_id(),
712 e
713 );
714 }
715 Ok(())
716 }
717}
718
719#[async_trait]
720impl Handler<resource::GetState<ActorState>> for ProcMeshAgent {
721 async fn handle(
722 &mut self,
723 cx: &Context<Self>,
724 get_state: resource::GetState<ActorState>,
725 ) -> anyhow::Result<()> {
726 let state = match self.actor_states.get(&get_state.name) {
727 Some(ActorInstanceState {
728 create_rank,
729 spawn: Ok(actor_id),
730 stopped,
731 }) => {
732 let supervision_events = self
733 .supervision_events
734 .get(actor_id)
735 .map_or_else(Vec::new, |a| a.clone());
736 let status = if *stopped {
737 resource::Status::Stopped
738 } else if supervision_events.is_empty() {
739 resource::Status::Running
740 } else {
741 resource::Status::Failed(format!(
742 "because of supervision events: {:?}",
743 supervision_events
744 ))
745 };
746 resource::State {
747 name: get_state.name.clone(),
748 status,
749 state: Some(ActorState {
750 actor_id: actor_id.clone(),
751 create_rank: *create_rank,
752 supervision_events,
753 }),
754 }
755 }
756 Some(ActorInstanceState { spawn: Err(e), .. }) => resource::State {
757 name: get_state.name.clone(),
758 status: resource::Status::Failed(e.to_string()),
759 state: None,
760 },
761 None => resource::State {
762 name: get_state.name.clone(),
763 status: resource::Status::NotExist,
764 state: None,
765 },
766 };
767
768 let result = get_state.reply.send(cx, state);
769 if let Err(e) = result {
773 tracing::warn!(
774 actor = %cx.self_id(),
775 "failed to send GetState reply to {} due to error: {}",
776 get_state.reply.port_id().actor_id(),
777 e
778 );
779 }
780 Ok(())
781 }
782}
783
784#[derive(Clone)]
787pub(crate) struct ReconfigurableMailboxSender {
788 state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
789}
790
791impl std::fmt::Debug for ReconfigurableMailboxSender {
792 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
793 f.debug_struct("ReconfigurableMailboxSender").finish()
796 }
797}
798
799pub(crate) struct ReconfigurableMailboxSenderInner<'a> {
800 guard: RwLockWriteGuard<'a, ReconfigurableMailboxSenderState>,
801}
802
803impl<'a> ReconfigurableMailboxSenderInner<'a> {
804 pub(crate) fn as_configured(&self) -> Option<&BoxedMailboxSender> {
805 self.guard.as_configured()
806 }
807}
808
809type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
810
811#[derive(EnumAsInner, Debug)]
812enum ReconfigurableMailboxSenderState {
813 Queueing(Mutex<Vec<Post>>),
814 Configured(BoxedMailboxSender),
815}
816
817impl ReconfigurableMailboxSender {
818 pub(crate) fn new() -> Self {
819 Self {
820 state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
821 Mutex::new(Vec::new()),
822 ))),
823 }
824 }
825
826 pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
830 let mut state = self.state.write().unwrap();
831 if state.is_configured() {
832 return false;
833 }
834
835 let queued = replace(
836 &mut *state,
837 ReconfigurableMailboxSenderState::Configured(sender.clone()),
838 );
839
840 for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
841 sender.post(envelope, return_handle);
842 }
843 *state = ReconfigurableMailboxSenderState::Configured(sender);
844 true
845 }
846
847 pub(crate) fn as_inner<'a>(
848 &'a self,
849 ) -> Result<ReconfigurableMailboxSenderInner<'a>, anyhow::Error> {
850 let state = self.state.write().unwrap();
851 if state.is_configured() {
852 Ok(ReconfigurableMailboxSenderInner { guard: state })
853 } else {
854 Err(anyhow::anyhow!("cannot get inner sender: not configured"))
855 }
856 }
857}
858
859impl MailboxSender for ReconfigurableMailboxSender {
860 fn post(
861 &self,
862 envelope: MessageEnvelope,
863 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
864 ) {
865 match *self.state.read().unwrap() {
866 ReconfigurableMailboxSenderState::Queueing(ref queue) => {
867 queue.lock().unwrap().push((envelope, return_handle));
868 }
869 ReconfigurableMailboxSenderState::Configured(ref sender) => {
870 sender.post(envelope, return_handle);
871 }
872 }
873 }
874
875 fn post_unchecked(
876 &self,
877 envelope: MessageEnvelope,
878 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
879 ) {
880 match *self.state.read().unwrap() {
881 ReconfigurableMailboxSenderState::Queueing(ref queue) => {
882 queue.lock().unwrap().push((envelope, return_handle));
883 }
884 ReconfigurableMailboxSenderState::Configured(ref sender) => {
885 sender.post_unchecked(envelope, return_handle);
886 }
887 }
888 }
889}
890
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::attrs::Attrs;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;

    use super::*;

    /// Test sink that records every envelope it is asked to deliver.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            Self {
                messages: Arc::default(),
            }
        }

        /// Snapshot of the envelopes recorded so far.
        fn get_messages(&self) -> Vec<MessageEnvelope> {
            self.messages.lock().unwrap().clone()
        }

        /// Decoded `u64` payloads of the recorded envelopes, in delivery order.
        fn payloads(&self) -> Vec<u64> {
            self.get_messages()
                .iter()
                .map(|message| message.deserialized::<u64>().unwrap())
                .collect()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post_unchecked(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            self.messages.lock().unwrap().push(envelope);
        }
    }

    /// Envelope carrying `data` between two fixed test actors.
    fn envelope(data: u64) -> MessageEnvelope {
        MessageEnvelope::serialize(
            id!(world[0].sender),
            id!(world[0].receiver[0][1]),
            &data,
            Attrs::new(),
        )
        .unwrap()
    }

    /// Undeliverable-return port backed by a detached mailbox.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    #[test]
    fn test_queueing_before_configure() {
        let sender = ReconfigurableMailboxSender::new();
        let sink = QueueingMailboxSender::new();
        let handle = return_handle();

        for payload in [1, 2] {
            sender.post(envelope(payload), handle.clone());
        }
        assert!(sink.payloads().is_empty());

        sender.configure(BoxedMailboxSender::new(sink.clone()));
        assert_eq!(sink.payloads(), vec![1, 2]);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let sender = ReconfigurableMailboxSender::new();
        let sink = QueueingMailboxSender::new();
        sender.configure(BoxedMailboxSender::new(sink.clone()));

        let handle = return_handle();
        for payload in [3, 4] {
            sender.post(envelope(payload), handle.clone());
        }

        assert_eq!(sink.payloads(), vec![3, 4]);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let boxed = BoxedMailboxSender::new(QueueingMailboxSender::new());

        assert!(sender.configure(boxed.clone()));
        assert!(!sender.configure(boxed));
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let sender = ReconfigurableMailboxSender::new();
        let sink = QueueingMailboxSender::new();
        let handle = return_handle();

        for payload in [5, 6] {
            sender.post(envelope(payload), handle.clone());
        }
        sender.configure(BoxedMailboxSender::new(sink.clone()));
        for payload in [7, 8] {
            sender.post(envelope(payload), handle.clone());
        }

        assert_eq!(sink.payloads(), vec![5, 6, 7, 8]);
    }
}