1use std::collections::HashMap;
10use std::future::IntoFuture;
11
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use hyperactor::Instance;
15use hyperactor::actor::ActorError;
16use hyperactor::actor::ActorHandle;
17use hyperactor::channel;
18use hyperactor::channel::ChannelAddr;
19use hyperactor::clock::Clock;
20use hyperactor::clock::ClockKind;
21use hyperactor::context::Mailbox as _;
22use hyperactor::id;
23use hyperactor::mailbox::BoxedMailboxSender;
24use hyperactor::mailbox::MailboxClient;
25use hyperactor::mailbox::MailboxSender;
26use hyperactor::mailbox::MailboxServer;
27use hyperactor::mailbox::MailboxServerHandle;
28use hyperactor::proc::Proc;
29use system_actor::SystemActor;
30use system_actor::SystemActorParams;
31use system_actor::SystemMessageClient;
32use tokio::join;
33
34use crate::proc_actor::ProcMessage;
35use crate::system_actor;
36use crate::system_actor::ProcLifecycleMode;
37
38#[derive(Debug)]
40pub struct System {
41 addr: ChannelAddr,
42}
43
44impl System {
45 pub async fn serve(
49 addr: ChannelAddr,
50 supervision_update_timeout: tokio::time::Duration,
51 world_eviction_timeout: tokio::time::Duration,
52 ) -> Result<ServerHandle, anyhow::Error> {
53 let clock = ClockKind::for_channel_addr(&addr);
54 let params = SystemActorParams::new(supervision_update_timeout, world_eviction_timeout);
55 let (actor_handle, system_proc) = SystemActor::bootstrap_with_clock(params, clock).await?;
56 actor_handle.bind::<SystemActor>();
57
58 let (local_addr, rx) = channel::serve(addr)?;
59 let mailbox_handle = system_proc.clone().serve(rx);
60
61 Ok(ServerHandle {
62 actor_handle,
63 mailbox_handle,
64 local_addr,
65 })
66 }
67
68 pub fn new(addr: ChannelAddr) -> Self {
70 Self { addr }
71 }
72
73 async fn sender(&self) -> Result<impl MailboxSender + use<>, anyhow::Error> {
75 let tx = channel::dial(self.addr.clone())?;
76 Ok(MailboxClient::new(tx))
77 }
78
79 pub async fn attach(&mut self) -> Result<Instance<()>, anyhow::Error> {
85 let world_id = id!(user);
88 let proc = Proc::new(
89 world_id.random_user_proc(),
90 BoxedMailboxSender::new(self.sender().await?),
91 );
92
93 let (proc_addr, proc_rx) = channel::serve(ChannelAddr::any(self.addr.transport())).unwrap();
94
95 let _proc_serve_handle: MailboxServerHandle = proc.clone().serve(proc_rx);
96
97 let (instance, _handle) = proc.instance("proc")?;
99 let (proc_tx, mut proc_rx) = instance.mailbox().open_port();
100
101 system_actor::SYSTEM_ACTOR_REF
102 .join(
103 &instance,
104 world_id,
105 proc.proc_id().clone(),
106 proc_tx.bind(),
107 proc_addr,
108 HashMap::new(),
109 ProcLifecycleMode::Detached,
110 )
111 .await
112 .unwrap();
113 let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT);
114 loop {
115 let result = proc.clock().timeout(timeout, proc_rx.recv()).await?;
116 match result? {
117 ProcMessage::Joined() => break,
118 message => tracing::info!("proc message while joining: {:?}", message),
119 }
120 }
121
122 proc.instance("user").map(|(instance, _)| instance)
123 }
124}
125
126#[derive(Debug)]
128pub struct ServerHandle {
129 actor_handle: ActorHandle<SystemActor>,
130 mailbox_handle: MailboxServerHandle,
131 local_addr: ChannelAddr,
132}
133
134impl ServerHandle {
135 pub async fn stop(&self) -> Result<(), ActorError> {
137 self.actor_handle.drain_and_stop()?;
139 self.mailbox_handle.stop("system server stopped");
140 Ok(())
141 }
142
143 pub fn local_addr(&self) -> &ChannelAddr {
145 &self.local_addr
146 }
147
148 pub fn system_actor_handle(&self) -> &ActorHandle<SystemActor> {
150 &self.actor_handle
151 }
152}
153
154impl IntoFuture for ServerHandle {
157 type Output = ();
158 type IntoFuture = BoxFuture<'static, Self::Output>;
159
160 fn into_future(self) -> Self::IntoFuture {
161 let future = async move {
162 let _ = join!(self.actor_handle.into_future(), self.mailbox_handle);
163 };
164 future.boxed()
165 }
166}
167
168#[cfg(test)]
169mod tests {
170 use std::collections::HashMap;
171 use std::collections::HashSet;
172 use std::time::Duration;
173
174 use hyperactor::ActorRef;
175 use hyperactor::ProcId;
176 use hyperactor::WorldId;
177 use hyperactor::channel::ChannelAddr;
178 use hyperactor::channel::ChannelTransport;
179 use hyperactor::clock::Clock;
180 use hyperactor::clock::RealClock;
181 use hyperactor_telemetry::env::execution_id;
182 use maplit::hashset;
183 use timed_test::async_timed_test;
184
185 use super::*;
186 use crate::System;
187 use crate::proc_actor::Environment;
188 use crate::proc_actor::ProcActor;
189 use crate::supervision::ProcSupervisor;
190 use crate::system_actor::ProcLifecycleMode;
191 use crate::system_actor::SYSTEM_ACTOR_REF;
192 use crate::system_actor::Shape;
193 use crate::system_actor::SystemMessageClient;
194 use crate::system_actor::SystemSnapshot;
195 use crate::system_actor::SystemSnapshotFilter;
196 use crate::system_actor::WorldSnapshot;
197 use crate::system_actor::WorldSnapshotProcInfo;
198 use crate::system_actor::WorldStatus;
199
200 #[tokio::test]
201 async fn test_join() {
202 for transport in ChannelTransport::all() {
203 #[cfg(not(target_os = "linux"))]
205 if matches!(transport, ChannelTransport::Unix) {
206 continue;
207 }
208
209 let system_handle = System::serve(
210 ChannelAddr::any(transport),
211 Duration::from_secs(10),
212 Duration::from_secs(10),
213 )
214 .await
215 .unwrap();
216
217 let mut system = System::new(system_handle.local_addr().clone());
218 let client1 = system.attach().await.unwrap();
219 let client2 = system.attach().await.unwrap();
220
221 let (port, mut port_rx) = client2.open_port();
222
223 port.bind().send(&client1, 123u64).unwrap();
224 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
225
226 system_handle.stop().await.unwrap();
227 system_handle.await;
228 }
229 }
230
231 #[tokio::test]
232 async fn test_system_snapshot() {
233 let system_handle = System::serve(
234 ChannelAddr::any(ChannelTransport::Local),
235 Duration::from_secs(10),
236 Duration::from_secs(10),
237 )
238 .await
239 .unwrap();
240
241 let mut system = System::new(system_handle.local_addr().clone());
242 let client = system.attach().await.unwrap();
243
244 let sys_actor_handle = system_handle.system_actor_handle();
245 {
247 let snapshot = sys_actor_handle
248 .snapshot(&client, SystemSnapshotFilter::all())
249 .await
250 .unwrap();
251 assert_eq!(
252 snapshot,
253 SystemSnapshot {
254 worlds: HashMap::new(),
255 execution_id: execution_id(),
256 }
257 );
258 }
259
260 let foo_world = {
262 let foo_world_id = WorldId("foo_world".to_string());
263 sys_actor_handle
264 .upsert_world(
265 &client,
266 foo_world_id.clone(),
267 Shape::Definite(vec![2]),
268 5,
269 Environment::Local,
270 HashMap::new(),
271 )
272 .await
273 .unwrap();
274 {
275 let snapshot = sys_actor_handle
276 .snapshot(&client, SystemSnapshotFilter::all())
277 .await
278 .unwrap();
279 let time = snapshot
280 .worlds
281 .get(&foo_world_id)
282 .unwrap()
283 .status
284 .as_unhealthy()
285 .unwrap()
286 .clone();
287 assert_eq!(
288 snapshot,
289 SystemSnapshot {
290 worlds: HashMap::from([(
291 foo_world_id.clone(),
292 WorldSnapshot {
293 host_procs: HashSet::new(),
294 procs: HashMap::new(),
295 status: WorldStatus::Unhealthy(time),
296 labels: HashMap::new(),
297 }
298 ),]),
299 execution_id: execution_id(),
300 }
301 );
302 }
303
304 {
306 let test_labels =
307 HashMap::from([("test_name".to_string(), "test_value".to_string())]);
308 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
309 let proc_id = ProcId::Ranked(foo_world_id.clone(), 1);
310 ProcActor::try_bootstrap(
311 proc_id.clone(),
312 foo_world_id.clone(),
313 listen_addr,
314 system_handle.local_addr().clone(),
315 ActorRef::attest(proc_id.actor_id("supervision", 0)),
316 Duration::from_secs(30),
317 test_labels.clone(),
318 ProcLifecycleMode::ManagedBySystem,
319 )
320 .await
321 .unwrap();
322
323 let snapshot = sys_actor_handle
324 .snapshot(&client, SystemSnapshotFilter::all())
325 .await
326 .unwrap();
327 let time = snapshot
328 .worlds
329 .get(&foo_world_id)
330 .unwrap()
331 .status
332 .as_unhealthy()
333 .unwrap()
334 .clone();
335 let foo_world = (
336 foo_world_id.clone(),
337 WorldSnapshot {
338 host_procs: HashSet::new(),
339 procs: HashMap::from([(
340 proc_id.clone(),
341 WorldSnapshotProcInfo {
342 labels: test_labels.clone(),
343 },
344 )]),
345 status: WorldStatus::Unhealthy(time),
346 labels: HashMap::new(),
347 },
348 );
349
350 assert_eq!(
351 snapshot,
352 SystemSnapshot {
353 worlds: HashMap::from([foo_world.clone(),]),
354 execution_id: execution_id(),
355 },
356 );
357
358 let snapshot = sys_actor_handle
360 .snapshot(
361 &client,
362 SystemSnapshotFilter {
363 worlds: vec![WorldId("none".to_string())],
364 world_labels: HashMap::new(),
365 proc_labels: HashMap::new(),
366 },
367 )
368 .await
369 .unwrap();
370 assert!(snapshot.worlds.is_empty());
371 let snapshot = sys_actor_handle
373 .snapshot(
374 &client,
375 SystemSnapshotFilter {
376 worlds: vec![],
377 world_labels: HashMap::new(),
378 proc_labels: test_labels.clone(),
379 },
380 )
381 .await
382 .unwrap();
383 assert_eq!(snapshot.worlds.get(&foo_world_id).unwrap(), &foo_world.1);
384 foo_world
385 }
386 };
387
388 {
390 let worker_world_id = WorldId("worker_world".to_string());
391 let host_world_id = WorldId(("hostworker_world").to_string());
392 let listen_addr: ChannelAddr = ChannelAddr::any(ChannelTransport::Local);
393 let host_proc_id_1 = ProcId::Ranked(host_world_id.clone(), 1);
395 ProcActor::try_bootstrap(
396 host_proc_id_1.clone(),
397 host_world_id.clone(),
398 listen_addr.clone(),
399 system_handle.local_addr().clone(),
400 ActorRef::attest(host_proc_id_1.actor_id("supervision", 0)),
401 Duration::from_secs(30),
402 HashMap::new(),
403 ProcLifecycleMode::ManagedBySystem,
404 )
405 .await
406 .unwrap();
407 {
408 let snapshot = sys_actor_handle
409 .snapshot(&client, SystemSnapshotFilter::all())
410 .await
411 .unwrap();
412 assert_eq!(
413 snapshot,
414 SystemSnapshot {
415 worlds: HashMap::from([
416 foo_world.clone(),
417 (
418 worker_world_id.clone(),
419 WorldSnapshot {
420 host_procs: HashSet::from([host_proc_id_1.clone()]),
421 procs: HashMap::new(),
422 status: WorldStatus::AwaitingCreation,
423 labels: HashMap::new(),
424 }
425 ),
426 ]),
427 execution_id: execution_id(),
428 },
429 );
430 }
431
432 sys_actor_handle
434 .upsert_world(
435 &client,
436 worker_world_id.clone(),
437 Shape::Definite(vec![3, 4]),
440 8,
441 Environment::Local,
442 HashMap::new(),
443 )
444 .await
445 .unwrap();
446 RealClock.sleep(Duration::from_secs(2)).await;
448 {
449 let snapshot = sys_actor_handle
450 .snapshot(&client, SystemSnapshotFilter::all())
451 .await
452 .unwrap();
453 let time = snapshot
454 .worlds
455 .get(&worker_world_id)
456 .unwrap()
457 .status
458 .as_unhealthy()
459 .unwrap()
460 .clone();
461 assert_eq!(
462 snapshot,
463 SystemSnapshot {
464 worlds: HashMap::from([
465 foo_world.clone(),
466 (
467 worker_world_id.clone(),
468 WorldSnapshot {
469 host_procs: HashSet::from([host_proc_id_1.clone()]),
470 procs: (8..12)
471 .map(|i| (
472 ProcId::Ranked(worker_world_id.clone(), i),
473 WorldSnapshotProcInfo {
474 labels: HashMap::new()
475 }
476 ))
477 .collect(),
478 status: WorldStatus::Unhealthy(time),
479 labels: HashMap::new(),
480 }
481 ),
482 ]),
483 execution_id: execution_id(),
484 },
485 );
486 }
487
488 let host_proc_id_0 = ProcId::Ranked(host_world_id.clone(), 0);
489 ProcActor::try_bootstrap(
490 host_proc_id_0.clone(),
491 host_world_id.clone(),
492 listen_addr,
493 system_handle.local_addr().clone(),
494 ActorRef::attest(host_proc_id_0.actor_id("supervision", 0)),
495 Duration::from_secs(30),
496 HashMap::new(),
497 ProcLifecycleMode::ManagedBySystem,
498 )
499 .await
500 .unwrap();
501
502 RealClock.sleep(Duration::from_secs(2)).await;
504 {
505 let snapshot = sys_actor_handle
506 .snapshot(&client, SystemSnapshotFilter::all())
507 .await
508 .unwrap();
509 assert_eq!(
510 snapshot,
511 SystemSnapshot {
512 worlds: HashMap::from([
513 foo_world,
514 (
515 worker_world_id.clone(),
516 WorldSnapshot {
517 host_procs: HashSet::from([host_proc_id_0, host_proc_id_1]),
518 procs: HashMap::from_iter((0..12).map(|i| (
519 ProcId::Ranked(worker_world_id.clone(), i),
520 WorldSnapshotProcInfo {
521 labels: HashMap::new()
522 }
523 ))),
524 status: WorldStatus::Live,
526 labels: HashMap::new(),
527 }
528 ),
529 ]),
530 execution_id: execution_id(),
531 }
532 );
533 }
534 }
535 }
536
537 #[tracing_test::traced_test]
543 #[async_timed_test(timeout_secs = 60)]
544 async fn test_system_shutdown() {
545 let system_handle = System::serve(
546 ChannelAddr::any(ChannelTransport::Local),
547 Duration::from_secs(10),
548 Duration::from_secs(10),
549 )
550 .await
551 .unwrap();
552 let system_supervision_ref: ActorRef<ProcSupervisor> =
553 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
554
555 let mut system = System::new(system_handle.local_addr().clone());
556 let client = system.attach().await.unwrap();
557
558 let sys_actor_handle = system_handle.system_actor_handle();
559
560 let worker_world_id = WorldId("worker_world".to_string());
562 let shape = vec![2, 2, 4];
563 let host_proc_actors = {
564 let host_world_id = WorldId(("hostworker_world").to_string());
565 sys_actor_handle
567 .upsert_world(
568 &client,
569 worker_world_id.clone(),
570 Shape::Definite(shape.clone()),
572 8,
573 Environment::Local,
574 HashMap::new(),
575 )
576 .await
577 .unwrap();
578
579 let futs = (0..2).map(|i| {
581 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
582 ProcActor::try_bootstrap(
583 host_proc_id.clone(),
584 host_world_id.clone(),
585 ChannelAddr::any(ChannelTransport::Local),
586 system_handle.local_addr().clone(),
587 system_supervision_ref.clone(),
588 Duration::from_secs(30),
589 HashMap::new(),
590 ProcLifecycleMode::ManagedBySystem,
591 )
592 });
593 futures::future::try_join_all(futs).await.unwrap()
594 };
595 RealClock.sleep(Duration::from_secs(2)).await;
597
598 let foo_proc_actors = {
600 let foo_world_id = WorldId("foo_world".to_string());
601 sys_actor_handle
602 .upsert_world(
603 &client,
604 foo_world_id.clone(),
605 Shape::Definite(vec![2]),
606 2,
607 Environment::Local,
608 HashMap::new(),
609 )
610 .await
611 .unwrap();
612 let foo_futs = (0..2).map(|i| {
614 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
615 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
616 ProcActor::try_bootstrap(
617 proc_id.clone(),
618 foo_world_id.clone(),
619 listen_addr,
620 system_handle.local_addr().clone(),
621 system_supervision_ref.clone(),
622 Duration::from_secs(30),
623 HashMap::new(),
624 ProcLifecycleMode::ManagedBySystem,
625 )
626 });
627 futures::future::try_join_all(foo_futs).await.unwrap()
628 };
629
630 let (port, receiver) = client.open_once_port::<()>();
631 sys_actor_handle
633 .stop(&client, None, Duration::from_secs(5), port.bind())
634 .await
635 .unwrap();
636 receiver.recv().await.unwrap();
637 RealClock.sleep(Duration::from_secs(5)).await;
638
639 for bootstrap in host_proc_actors {
641 bootstrap.proc_actor.into_future().await;
642 }
643
644 for bootstrap in foo_proc_actors {
646 bootstrap.proc_actor.into_future().await;
647 }
648 system_handle.actor_handle.into_future().await;
650
651 for m in 0..(shape.iter().product()) {
654 let proc_id = worker_world_id.proc_id(m);
655 assert!(tracing_test::internal::logs_with_scope_contain(
656 "hyperactor::proc",
657 format!("{proc_id}: proc stopped").as_str()
658 ));
659 }
660 }
661
662 #[async_timed_test(timeout_secs = 60)]
663 async fn test_single_world_shutdown() {
664 let system_handle = System::serve(
665 ChannelAddr::any(ChannelTransport::Local),
666 Duration::from_secs(10),
667 Duration::from_secs(10),
668 )
669 .await
670 .unwrap();
671 let system_supervision_ref: ActorRef<ProcSupervisor> =
672 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
673
674 let mut system = System::new(system_handle.local_addr().clone());
675 let client = system.attach().await.unwrap();
676
677 let sys_actor_handle = system_handle.system_actor_handle();
678
679 let host_world_id = WorldId(("host_world").to_string());
680 let worker_world_id = WorldId("worker_world".to_string());
681 let foo_world_id = WorldId("foo_world".to_string());
682
683 let shape = vec![2, 2, 4];
685 let host_proc_actors = {
686 sys_actor_handle
688 .upsert_world(
689 &client,
690 worker_world_id.clone(),
691 Shape::Definite(shape.clone()),
693 8,
694 Environment::Local,
695 HashMap::new(),
696 )
697 .await
698 .unwrap();
699
700 let futs = (0..2).map(|i| {
702 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
703 ProcActor::try_bootstrap(
704 host_proc_id.clone(),
705 host_world_id.clone(),
706 ChannelAddr::any(ChannelTransport::Local),
707 system_handle.local_addr().clone(),
708 system_supervision_ref.clone(),
709 Duration::from_secs(30),
710 HashMap::new(),
711 ProcLifecycleMode::ManagedBySystem,
712 )
713 });
714 futures::future::try_join_all(futs).await.unwrap()
715 };
716 RealClock.sleep(Duration::from_secs(2)).await;
718
719 let foo_proc_actors = {
721 sys_actor_handle
722 .upsert_world(
723 &client,
724 foo_world_id.clone(),
725 Shape::Definite(vec![2]),
726 2,
727 Environment::Local,
728 HashMap::new(),
729 )
730 .await
731 .unwrap();
732 let foo_futs = (0..2).map(|i| {
734 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
735 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
736 ProcActor::try_bootstrap(
737 proc_id.clone(),
738 foo_world_id.clone(),
739 listen_addr,
740 system_handle.local_addr().clone(),
741 system_supervision_ref.clone(),
742 Duration::from_secs(30),
743 HashMap::new(),
744 ProcLifecycleMode::ManagedBySystem,
745 )
746 });
747 futures::future::try_join_all(foo_futs).await.unwrap()
748 };
749
750 {
751 let snapshot = sys_actor_handle
752 .snapshot(&client, SystemSnapshotFilter::all())
753 .await
754 .unwrap();
755 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
756 assert_eq!(
757 snapshot_world_ids,
758 hashset! {worker_world_id.clone(), foo_world_id.clone(), WorldId("_world".to_string())}
759 );
760 }
761
762 let (port, receiver) = client.open_once_port::<()>();
763 sys_actor_handle
765 .stop(
766 &client,
767 Some(vec![WorldId("foo_world".into())]),
768 Duration::from_secs(5),
769 port.bind(),
770 )
771 .await
772 .unwrap();
773 receiver.recv().await.unwrap();
774 RealClock.sleep(Duration::from_secs(5)).await;
775
776 for bootstrap in foo_proc_actors {
778 bootstrap.proc_actor.into_future().await;
779 }
780
781 for bootstrap in host_proc_actors {
783 match RealClock
784 .timeout(Duration::from_secs(5), bootstrap.proc_actor.into_future())
785 .await
786 {
787 Ok(_) => {
788 panic!("foo actor shouldn't be stopped");
789 }
790 Err(_) => {}
791 }
792 }
793
794 match RealClock
796 .timeout(
797 Duration::from_secs(3),
798 system_handle.actor_handle.clone().into_future(),
799 )
800 .await
801 {
802 Ok(_) => {
803 panic!("system actor shouldn't be stopped");
804 }
805 Err(_) => {}
806 }
807
808 {
809 let snapshot = sys_actor_handle
810 .snapshot(&client, SystemSnapshotFilter::all())
811 .await
812 .unwrap();
813 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
814 assert_eq!(
816 snapshot_world_ids,
817 hashset! {worker_world_id, WorldId("_world".to_string())}
818 );
819 }
820 }
821
822 #[tracing_test::traced_test]
825 #[tokio::test]
826 async fn test_channel_dial_count() {
827 let system_handle = System::serve(
828 ChannelAddr::any(ChannelTransport::Tcp),
829 Duration::from_secs(10),
830 Duration::from_secs(10),
831 )
832 .await
833 .unwrap();
834
835 let system_addr = system_handle.local_addr();
836 let mut system = System::new(system_addr.clone());
837 let client1 = system.attach().await.unwrap();
850
851 let client2 = system.attach().await.unwrap();
864
865 let (port, mut port_rx) = client2.open_port();
873 port.bind().send(&client1, 123u64).unwrap();
874 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
875
876 logs_assert(|logs| {
878 let dial_count = logs
879 .iter()
880 .filter(|log| log.contains("dialing channel tcp"))
881 .count();
882 if dial_count == 4 {
883 Ok(())
884 } else {
885 Err(format!("unexpected tcp channel dial count: {}", dial_count))
886 }
887 });
888
889 system_handle.stop().await.unwrap();
890 system_handle.await;
891 }
892}