1use std::collections::HashMap;
10use std::future::IntoFuture;
11
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use hyperactor::actor::ActorError;
15use hyperactor::actor::ActorHandle;
16use hyperactor::channel;
17use hyperactor::channel::ChannelAddr;
18use hyperactor::clock::Clock;
19use hyperactor::clock::ClockKind;
20use hyperactor::id;
21use hyperactor::mailbox::BoxedMailboxSender;
22use hyperactor::mailbox::Mailbox;
23use hyperactor::mailbox::MailboxClient;
24use hyperactor::mailbox::MailboxSender;
25use hyperactor::mailbox::MailboxServer;
26use hyperactor::mailbox::MailboxServerHandle;
27use hyperactor::proc::Proc;
28use system_actor::SystemActor;
29use system_actor::SystemActorParams;
30use system_actor::SystemMessageClient;
31use tokio::join;
32
33use crate::proc_actor::ProcMessage;
34use crate::system_actor;
35use crate::system_actor::ProcLifecycleMode;
36
/// Client-side view of a system, identified solely by the channel address
/// at which the system is reachable.
#[derive(Debug)]
pub struct System {
    // Address dialed to reach the system; `attach` also reuses its transport
    // when serving the channel of the newly attached proc.
    addr: ChannelAddr,
}
42
43impl System {
44 pub async fn serve(
48 addr: ChannelAddr,
49 supervision_update_timeout: tokio::time::Duration,
50 world_eviction_timeout: tokio::time::Duration,
51 ) -> Result<ServerHandle, anyhow::Error> {
52 let clock = ClockKind::for_channel_addr(&addr);
53 let params = SystemActorParams::new(supervision_update_timeout, world_eviction_timeout);
54 let (actor_handle, system_proc) = SystemActor::bootstrap_with_clock(params, clock).await?;
55 actor_handle.bind::<SystemActor>();
56
57 let (local_addr, rx) = channel::serve(addr).await?;
58 let mailbox_handle = system_proc.clone().serve(rx);
59
60 Ok(ServerHandle {
61 actor_handle,
62 mailbox_handle,
63 local_addr,
64 })
65 }
66
67 pub fn new(addr: ChannelAddr) -> Self {
69 Self { addr }
70 }
71
72 async fn sender(&self) -> Result<impl MailboxSender + use<>, anyhow::Error> {
74 let tx = channel::dial(self.addr.clone())?;
75 Ok(MailboxClient::new(tx))
76 }
77
78 pub async fn attach(&mut self) -> Result<Mailbox, anyhow::Error> {
84 let world_id = id!(user);
87 let proc = Proc::new(
88 world_id.random_user_proc(),
89 BoxedMailboxSender::new(self.sender().await?),
90 );
91
92 let (proc_addr, proc_rx) = channel::serve(ChannelAddr::any(self.addr.transport()))
93 .await
94 .unwrap();
95
96 let _proc_serve_handle: MailboxServerHandle = proc.clone().serve(proc_rx);
97
98 let proc_inst = proc.attach("proc")?;
100 let (proc_tx, mut proc_rx) = proc_inst.open_port();
101
102 system_actor::SYSTEM_ACTOR_REF
103 .join(
104 &proc_inst,
105 world_id,
106 proc.proc_id().clone(),
107 proc_tx.bind(),
108 proc_addr,
109 HashMap::new(),
110 ProcLifecycleMode::Detached,
111 )
112 .await
113 .unwrap();
114 let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT);
115 loop {
116 let result = proc.clock().timeout(timeout, proc_rx.recv()).await?;
117 match result? {
118 ProcMessage::Joined() => break,
119 message => tracing::info!("proc message while joining: {:?}", message),
120 }
121 }
122
123 proc.attach("user")
124 }
125}
126
/// Handle on a running system server, as returned by `System::serve`.
/// It can be used to stop the server, or awaited (via `IntoFuture`) to wait
/// for both the system actor and the mailbox server.
#[derive(Debug)]
pub struct ServerHandle {
    // Handle of the bootstrapped system actor.
    actor_handle: ActorHandle<SystemActor>,
    // Handle of the mailbox server that serves the system proc.
    mailbox_handle: MailboxServerHandle,
    // Address returned by `channel::serve` for the system's channel.
    local_addr: ChannelAddr,
}
134
135impl ServerHandle {
136 pub async fn stop(&self) -> Result<(), ActorError> {
138 self.actor_handle.drain_and_stop()?;
140 self.mailbox_handle.stop("system server stopped");
141 Ok(())
142 }
143
144 pub fn local_addr(&self) -> &ChannelAddr {
146 &self.local_addr
147 }
148
149 pub fn system_actor_handle(&self) -> &ActorHandle<SystemActor> {
151 &self.actor_handle
152 }
153}
154
155impl IntoFuture for ServerHandle {
158 type Output = ();
159 type IntoFuture = BoxFuture<'static, Self::Output>;
160
161 fn into_future(self) -> Self::IntoFuture {
162 let future = async move {
163 let _ = join!(self.actor_handle.into_future(), self.mailbox_handle);
164 };
165 future.boxed()
166 }
167}
168
169#[cfg(test)]
170mod tests {
171 use std::collections::HashMap;
172 use std::collections::HashSet;
173 use std::time::Duration;
174
175 use hyperactor::ActorRef;
176 use hyperactor::ProcId;
177 use hyperactor::WorldId;
178 use hyperactor::channel::ChannelAddr;
179 use hyperactor::channel::ChannelTransport;
180 use hyperactor::clock::Clock;
181 use hyperactor::clock::RealClock;
182 use hyperactor_telemetry::env::execution_id;
183 use maplit::hashset;
184 use timed_test::async_timed_test;
185
186 use super::*;
187 use crate::System;
188 use crate::proc_actor::Environment;
189 use crate::proc_actor::ProcActor;
190 use crate::supervision::ProcSupervisor;
191 use crate::system_actor::ProcLifecycleMode;
192 use crate::system_actor::SYSTEM_ACTOR_REF;
193 use crate::system_actor::Shape;
194 use crate::system_actor::SystemMessageClient;
195 use crate::system_actor::SystemSnapshot;
196 use crate::system_actor::SystemSnapshotFilter;
197 use crate::system_actor::WorldSnapshot;
198 use crate::system_actor::WorldSnapshotProcInfo;
199 use crate::system_actor::WorldStatus;
200
201 #[tokio::test]
202 async fn test_join() {
203 for transport in ChannelTransport::all() {
204 #[cfg(not(target_os = "linux"))]
206 if matches!(transport, ChannelTransport::Unix) {
207 continue;
208 }
209
210 let system_handle = System::serve(
211 ChannelAddr::any(transport),
212 Duration::from_secs(10),
213 Duration::from_secs(10),
214 )
215 .await
216 .unwrap();
217
218 let mut system = System::new(system_handle.local_addr().clone());
219 let client1 = system.attach().await.unwrap();
220 let client2 = system.attach().await.unwrap();
221
222 let (port, mut port_rx) = client2.open_port();
223
224 port.bind().send(&client1, 123u64).unwrap();
225 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
226
227 system_handle.stop().await.unwrap();
228 system_handle.await;
229 }
230 }
231
/// Walks a system through several states and checks the full snapshot
/// returned by the system actor after each step: empty system, a world
/// without procs, a world with one labelled proc, filtered snapshots, and a
/// worker world that becomes live once both of its host procs have joined.
#[tokio::test]
async fn test_system_snapshot() {
    let system_handle = System::serve(
        ChannelAddr::any(ChannelTransport::Local),
        Duration::from_secs(10),
        Duration::from_secs(10),
    )
    .await
    .unwrap();

    let mut system = System::new(system_handle.local_addr().clone());
    let client = system.attach().await.unwrap();

    let sys_actor_handle = system_handle.system_actor_handle();
    {
        // A freshly served system reports no worlds.
        let snapshot = sys_actor_handle
            .snapshot(&client, SystemSnapshotFilter::all())
            .await
            .unwrap();
        assert_eq!(
            snapshot,
            SystemSnapshot {
                worlds: HashMap::new(),
                execution_id: execution_id(),
            }
        );
    }

    // Build up `foo_world` and keep its final expected (id, snapshot) pair
    // so later assertions can reuse it.
    let foo_world = {
        let foo_world_id = WorldId("foo_world".to_string());
        // NOTE(review): the numeric argument is presumably the number of
        // procs per host — confirm against `upsert_world`'s signature.
        sys_actor_handle
            .upsert_world(
                &client,
                foo_world_id.clone(),
                Shape::Definite(vec![2]),
                5,
                Environment::Local,
                HashMap::new(),
            )
            .await
            .unwrap();
        {
            // The world exists but has no procs yet; it is expected to be
            // unhealthy. The timestamp carried by the status is copied from
            // the actual snapshot so that the comparison ignores it.
            let snapshot = sys_actor_handle
                .snapshot(&client, SystemSnapshotFilter::all())
                .await
                .unwrap();
            let time = snapshot
                .worlds
                .get(&foo_world_id)
                .unwrap()
                .status
                .as_unhealthy()
                .unwrap()
                .clone();
            assert_eq!(
                snapshot,
                SystemSnapshot {
                    worlds: HashMap::from([(
                        foo_world_id.clone(),
                        WorldSnapshot {
                            host_procs: HashSet::new(),
                            procs: HashMap::new(),
                            status: WorldStatus::Unhealthy(time),
                            labels: HashMap::new(),
                        }
                    ),]),
                    execution_id: execution_id(),
                }
            );
        }

        {
            // Bootstrap a single labelled proc (rank 1) into the world.
            let test_labels =
                HashMap::from([("test_name".to_string(), "test_value".to_string())]);
            let listen_addr = ChannelAddr::any(ChannelTransport::Local);
            let proc_id = ProcId::Ranked(foo_world_id.clone(), 1);
            ProcActor::try_bootstrap(
                proc_id.clone(),
                foo_world_id.clone(),
                listen_addr,
                system_handle.local_addr().clone(),
                ActorRef::attest(proc_id.actor_id("supervision", 0)),
                Duration::from_secs(30),
                test_labels.clone(),
                ProcLifecycleMode::ManagedBySystem,
            )
            .await
            .unwrap();

            // The proc and its labels show up; the world is still expected
            // to be unhealthy.
            let snapshot = sys_actor_handle
                .snapshot(&client, SystemSnapshotFilter::all())
                .await
                .unwrap();
            let time = snapshot
                .worlds
                .get(&foo_world_id)
                .unwrap()
                .status
                .as_unhealthy()
                .unwrap()
                .clone();
            let foo_world = (
                foo_world_id.clone(),
                WorldSnapshot {
                    host_procs: HashSet::new(),
                    procs: HashMap::from([(
                        proc_id.clone(),
                        WorldSnapshotProcInfo {
                            labels: test_labels.clone(),
                        },
                    )]),
                    status: WorldStatus::Unhealthy(time),
                    labels: HashMap::new(),
                },
            );

            assert_eq!(
                snapshot,
                SystemSnapshot {
                    worlds: HashMap::from([foo_world.clone(),]),
                    execution_id: execution_id(),
                },
            );

            // Filtering on a world id that was never created yields no worlds.
            let snapshot = sys_actor_handle
                .snapshot(
                    &client,
                    SystemSnapshotFilter {
                        worlds: vec![WorldId("none".to_string())],
                        world_labels: HashMap::new(),
                        proc_labels: HashMap::new(),
                    },
                )
                .await
                .unwrap();
            assert!(snapshot.worlds.is_empty());
            // Filtering on the proc's labels returns the world unchanged.
            let snapshot = sys_actor_handle
                .snapshot(
                    &client,
                    SystemSnapshotFilter {
                        worlds: vec![],
                        world_labels: HashMap::new(),
                        proc_labels: test_labels.clone(),
                    },
                )
                .await
                .unwrap();
            assert_eq!(snapshot.worlds.get(&foo_world_id).unwrap(), &foo_world.1);
            foo_world
        }
    };

    {
        let worker_world_id = WorldId("worker_world".to_string());
        let host_world_id = WorldId(("hostworker_world").to_string());
        let listen_addr: ChannelAddr = ChannelAddr::any(ChannelTransport::Local);
        let host_proc_id_1 = ProcId::Ranked(host_world_id.clone(), 1);
        // Host proc 1 joins before the worker world has been upserted.
        ProcActor::try_bootstrap(
            host_proc_id_1.clone(),
            host_world_id.clone(),
            listen_addr.clone(),
            system_handle.local_addr().clone(),
            ActorRef::attest(host_proc_id_1.actor_id("supervision", 0)),
            Duration::from_secs(30),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();
        {
            // The host proc is expected to be listed under `worker_world`,
            // which is still awaiting creation.
            let snapshot = sys_actor_handle
                .snapshot(&client, SystemSnapshotFilter::all())
                .await
                .unwrap();
            assert_eq!(
                snapshot,
                SystemSnapshot {
                    worlds: HashMap::from([
                        foo_world.clone(),
                        (
                            worker_world_id.clone(),
                            WorldSnapshot {
                                host_procs: HashSet::from([host_proc_id_1.clone()]),
                                procs: HashMap::new(),
                                status: WorldStatus::AwaitingCreation,
                                labels: HashMap::new(),
                            }
                        ),
                    ]),
                    execution_id: execution_id(),
                },
            );
        }

        // Create the worker world itself.
        sys_actor_handle
            .upsert_world(
                &client,
                worker_world_id.clone(),
                Shape::Definite(vec![3, 4]),
                8,
                Environment::Local,
                HashMap::new(),
            )
            .await
            .unwrap();
        // Give the system time to act on the upsert before taking a snapshot.
        RealClock.sleep(Duration::from_secs(2)).await;
        {
            // With only host proc 1 present, worker procs 8..12 are expected
            // and the world is expected to be unhealthy.
            let snapshot = sys_actor_handle
                .snapshot(&client, SystemSnapshotFilter::all())
                .await
                .unwrap();
            let time = snapshot
                .worlds
                .get(&worker_world_id)
                .unwrap()
                .status
                .as_unhealthy()
                .unwrap()
                .clone();
            assert_eq!(
                snapshot,
                SystemSnapshot {
                    worlds: HashMap::from([
                        foo_world.clone(),
                        (
                            worker_world_id.clone(),
                            WorldSnapshot {
                                host_procs: HashSet::from([host_proc_id_1.clone()]),
                                procs: (8..12)
                                    .map(|i| (
                                        ProcId::Ranked(worker_world_id.clone(), i),
                                        WorldSnapshotProcInfo {
                                            labels: HashMap::new()
                                        }
                                    ))
                                    .collect(),
                                status: WorldStatus::Unhealthy(time),
                                labels: HashMap::new(),
                            }
                        ),
                    ]),
                    execution_id: execution_id(),
                },
            );
        }

        // Host proc 0 joins as well.
        let host_proc_id_0 = ProcId::Ranked(host_world_id.clone(), 0);
        ProcActor::try_bootstrap(
            host_proc_id_0.clone(),
            host_world_id.clone(),
            listen_addr,
            system_handle.local_addr().clone(),
            ActorRef::attest(host_proc_id_0.actor_id("supervision", 0)),
            Duration::from_secs(30),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();

        RealClock.sleep(Duration::from_secs(2)).await;
        {
            // Both hosts are present: all 12 worker procs are expected and
            // the world is expected to be live.
            let snapshot = sys_actor_handle
                .snapshot(&client, SystemSnapshotFilter::all())
                .await
                .unwrap();
            assert_eq!(
                snapshot,
                SystemSnapshot {
                    worlds: HashMap::from([
                        foo_world,
                        (
                            worker_world_id.clone(),
                            WorldSnapshot {
                                host_procs: HashSet::from([host_proc_id_0, host_proc_id_1]),
                                procs: HashMap::from_iter((0..12).map(|i| (
                                    ProcId::Ranked(worker_world_id.clone(), i),
                                    WorldSnapshotProcInfo {
                                        labels: HashMap::new()
                                    }
                                ))),
                                status: WorldStatus::Live,
                                labels: HashMap::new(),
                            }
                        ),
                    ]),
                    execution_id: execution_id(),
                }
            );
        }
    }
}
537
538 #[tracing_test::traced_test]
544 #[async_timed_test(timeout_secs = 60)]
545 async fn test_system_shutdown() {
546 let system_handle = System::serve(
547 ChannelAddr::any(ChannelTransport::Local),
548 Duration::from_secs(10),
549 Duration::from_secs(10),
550 )
551 .await
552 .unwrap();
553 let system_supervision_ref: ActorRef<ProcSupervisor> =
554 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
555
556 let mut system = System::new(system_handle.local_addr().clone());
557 let client = system.attach().await.unwrap();
558
559 let sys_actor_handle = system_handle.system_actor_handle();
560
561 let worker_world_id = WorldId("worker_world".to_string());
563 let shape = vec![2, 2, 4];
564 let host_proc_actors = {
565 let host_world_id = WorldId(("hostworker_world").to_string());
566 sys_actor_handle
568 .upsert_world(
569 &client,
570 worker_world_id.clone(),
571 Shape::Definite(shape.clone()),
573 8,
574 Environment::Local,
575 HashMap::new(),
576 )
577 .await
578 .unwrap();
579
580 let futs = (0..2).map(|i| {
582 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
583 ProcActor::try_bootstrap(
584 host_proc_id.clone(),
585 host_world_id.clone(),
586 ChannelAddr::any(ChannelTransport::Local),
587 system_handle.local_addr().clone(),
588 system_supervision_ref.clone(),
589 Duration::from_secs(30),
590 HashMap::new(),
591 ProcLifecycleMode::ManagedBySystem,
592 )
593 });
594 futures::future::try_join_all(futs).await.unwrap()
595 };
596 RealClock.sleep(Duration::from_secs(2)).await;
598
599 let foo_proc_actors = {
601 let foo_world_id = WorldId("foo_world".to_string());
602 sys_actor_handle
603 .upsert_world(
604 &client,
605 foo_world_id.clone(),
606 Shape::Definite(vec![2]),
607 2,
608 Environment::Local,
609 HashMap::new(),
610 )
611 .await
612 .unwrap();
613 let foo_futs = (0..2).map(|i| {
615 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
616 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
617 ProcActor::try_bootstrap(
618 proc_id.clone(),
619 foo_world_id.clone(),
620 listen_addr,
621 system_handle.local_addr().clone(),
622 system_supervision_ref.clone(),
623 Duration::from_secs(30),
624 HashMap::new(),
625 ProcLifecycleMode::ManagedBySystem,
626 )
627 });
628 futures::future::try_join_all(foo_futs).await.unwrap()
629 };
630
631 let (port, receiver) = client.open_once_port::<()>();
632 sys_actor_handle
634 .stop(&client, None, Duration::from_secs(5), port.bind())
635 .await
636 .unwrap();
637 receiver.recv().await.unwrap();
638 RealClock.sleep(Duration::from_secs(5)).await;
639
640 for bootstrap in host_proc_actors {
642 bootstrap.proc_actor.into_future().await;
643 }
644
645 for bootstrap in foo_proc_actors {
647 bootstrap.proc_actor.into_future().await;
648 }
649 system_handle.actor_handle.into_future().await;
651
652 for m in 0..(shape.iter().product()) {
655 let proc_id = worker_world_id.proc_id(m);
656 assert!(tracing_test::internal::logs_with_scope_contain(
657 "hyperactor::proc",
658 format!("{proc_id}: proc stopped").as_str()
659 ));
660 }
661 }
662
663 #[async_timed_test(timeout_secs = 60)]
664 async fn test_single_world_shutdown() {
665 let system_handle = System::serve(
666 ChannelAddr::any(ChannelTransport::Local),
667 Duration::from_secs(10),
668 Duration::from_secs(10),
669 )
670 .await
671 .unwrap();
672 let system_supervision_ref: ActorRef<ProcSupervisor> =
673 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
674
675 let mut system = System::new(system_handle.local_addr().clone());
676 let client = system.attach().await.unwrap();
677
678 let sys_actor_handle = system_handle.system_actor_handle();
679
680 let host_world_id = WorldId(("host_world").to_string());
681 let worker_world_id = WorldId("worker_world".to_string());
682 let foo_world_id = WorldId("foo_world".to_string());
683
684 let shape = vec![2, 2, 4];
686 let host_proc_actors = {
687 sys_actor_handle
689 .upsert_world(
690 &client,
691 worker_world_id.clone(),
692 Shape::Definite(shape.clone()),
694 8,
695 Environment::Local,
696 HashMap::new(),
697 )
698 .await
699 .unwrap();
700
701 let futs = (0..2).map(|i| {
703 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
704 ProcActor::try_bootstrap(
705 host_proc_id.clone(),
706 host_world_id.clone(),
707 ChannelAddr::any(ChannelTransport::Local),
708 system_handle.local_addr().clone(),
709 system_supervision_ref.clone(),
710 Duration::from_secs(30),
711 HashMap::new(),
712 ProcLifecycleMode::ManagedBySystem,
713 )
714 });
715 futures::future::try_join_all(futs).await.unwrap()
716 };
717 RealClock.sleep(Duration::from_secs(2)).await;
719
720 let foo_proc_actors = {
722 sys_actor_handle
723 .upsert_world(
724 &client,
725 foo_world_id.clone(),
726 Shape::Definite(vec![2]),
727 2,
728 Environment::Local,
729 HashMap::new(),
730 )
731 .await
732 .unwrap();
733 let foo_futs = (0..2).map(|i| {
735 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
736 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
737 ProcActor::try_bootstrap(
738 proc_id.clone(),
739 foo_world_id.clone(),
740 listen_addr,
741 system_handle.local_addr().clone(),
742 system_supervision_ref.clone(),
743 Duration::from_secs(30),
744 HashMap::new(),
745 ProcLifecycleMode::ManagedBySystem,
746 )
747 });
748 futures::future::try_join_all(foo_futs).await.unwrap()
749 };
750
751 {
752 let snapshot = sys_actor_handle
753 .snapshot(&client, SystemSnapshotFilter::all())
754 .await
755 .unwrap();
756 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
757 assert_eq!(
758 snapshot_world_ids,
759 hashset! {worker_world_id.clone(), foo_world_id.clone(), WorldId("_world".to_string())}
760 );
761 }
762
763 let (port, receiver) = client.open_once_port::<()>();
764 sys_actor_handle
766 .stop(
767 &client,
768 Some(vec![WorldId("foo_world".into())]),
769 Duration::from_secs(5),
770 port.bind(),
771 )
772 .await
773 .unwrap();
774 receiver.recv().await.unwrap();
775 RealClock.sleep(Duration::from_secs(5)).await;
776
777 for bootstrap in foo_proc_actors {
779 bootstrap.proc_actor.into_future().await;
780 }
781
782 for bootstrap in host_proc_actors {
784 match RealClock
785 .timeout(Duration::from_secs(5), bootstrap.proc_actor.into_future())
786 .await
787 {
788 Ok(_) => {
789 panic!("foo actor shouldn't be stopped");
790 }
791 Err(_) => {}
792 }
793 }
794
795 match RealClock
797 .timeout(
798 Duration::from_secs(3),
799 system_handle.actor_handle.clone().into_future(),
800 )
801 .await
802 {
803 Ok(_) => {
804 panic!("system actor shouldn't be stopped");
805 }
806 Err(_) => {}
807 }
808
809 {
810 let snapshot = sys_actor_handle
811 .snapshot(&client, SystemSnapshotFilter::all())
812 .await
813 .unwrap();
814 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
815 assert_eq!(
817 snapshot_world_ids,
818 hashset! {worker_world_id, WorldId("_world".to_string())}
819 );
820 }
821 }
822
823 #[tracing_test::traced_test]
826 #[tokio::test]
827 async fn test_channel_dial_count() {
828 let system_handle = System::serve(
829 ChannelAddr::any(ChannelTransport::Tcp),
830 Duration::from_secs(10),
831 Duration::from_secs(10),
832 )
833 .await
834 .unwrap();
835
836 let system_addr = system_handle.local_addr();
837 let mut system = System::new(system_addr.clone());
838 let client1 = system.attach().await.unwrap();
851
852 let client2 = system.attach().await.unwrap();
865
866 let (port, mut port_rx) = client2.open_port();
874 port.bind().send(&client1, 123u64).unwrap();
875 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
876
877 logs_assert(|logs| {
879 let dial_count = logs
880 .iter()
881 .filter(|log| log.contains("dialing channel tcp"))
882 .count();
883 if dial_count == 4 {
884 Ok(())
885 } else {
886 Err(format!("unexpected tcp channel dial count: {}", dial_count))
887 }
888 });
889
890 system_handle.stop().await.unwrap();
891 system_handle.await;
892 }
893}