1use std::collections::HashMap;
10use std::future::IntoFuture;
11
12use futures::FutureExt;
13use futures::future::BoxFuture;
14use hyperactor::Instance;
15use hyperactor::actor::ActorError;
16use hyperactor::actor::ActorHandle;
17use hyperactor::channel;
18use hyperactor::channel::ChannelAddr;
19use hyperactor::clock::Clock;
20use hyperactor::clock::ClockKind;
21use hyperactor::context::Mailbox as _;
22use hyperactor::id;
23use hyperactor::mailbox::BoxedMailboxSender;
24use hyperactor::mailbox::MailboxClient;
25use hyperactor::mailbox::MailboxSender;
26use hyperactor::mailbox::MailboxServer;
27use hyperactor::mailbox::MailboxServerHandle;
28use hyperactor::proc::Proc;
29use system_actor::SystemActor;
30use system_actor::SystemActorParams;
31use system_actor::SystemMessageClient;
32use tokio::join;
33
34use crate::proc_actor::ProcMessage;
35use crate::system_actor;
36use crate::system_actor::ProcLifecycleMode;
37
38#[derive(Debug)]
40pub struct System {
41 addr: ChannelAddr,
42}
43
44impl System {
45 #[tracing::instrument]
49 pub async fn serve(
50 addr: ChannelAddr,
51 supervision_update_timeout: tokio::time::Duration,
52 world_eviction_timeout: tokio::time::Duration,
53 ) -> Result<ServerHandle, anyhow::Error> {
54 let clock = ClockKind::for_channel_addr(&addr);
55 let params = SystemActorParams::new(supervision_update_timeout, world_eviction_timeout);
56 let (actor_handle, system_proc) = SystemActor::bootstrap_with_clock(params, clock).await?;
57 actor_handle.bind::<SystemActor>();
58
59 let (local_addr, rx) = channel::serve(addr)?;
60 let mailbox_handle = system_proc.clone().serve(rx);
61
62 Ok(ServerHandle {
63 actor_handle,
64 mailbox_handle,
65 local_addr,
66 })
67 }
68
69 pub fn new(addr: ChannelAddr) -> Self {
71 Self { addr }
72 }
73
74 async fn sender(&self) -> Result<impl MailboxSender + use<>, anyhow::Error> {
76 let tx = channel::dial(self.addr.clone())?;
77 Ok(MailboxClient::new(tx))
78 }
79
80 pub async fn attach(&mut self) -> Result<Instance<()>, anyhow::Error> {
86 let world_id = id!(user);
89 let proc = Proc::new(
90 world_id.random_user_proc(),
91 BoxedMailboxSender::new(self.sender().await?),
92 );
93
94 let (proc_addr, proc_rx) = channel::serve(ChannelAddr::any(self.addr.transport())).unwrap();
95
96 let _proc_serve_handle: MailboxServerHandle = proc.clone().serve(proc_rx);
97
98 let (instance, _handle) = proc.instance("proc")?;
100 let (proc_tx, mut proc_rx) = instance.mailbox().open_port();
101
102 system_actor::SYSTEM_ACTOR_REF
103 .join(
104 &instance,
105 world_id,
106 proc.proc_id().clone(),
107 proc_tx.bind(),
108 proc_addr,
109 HashMap::new(),
110 ProcLifecycleMode::Detached,
111 )
112 .await
113 .unwrap();
114 let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT);
115 loop {
116 let result = proc.clock().timeout(timeout, proc_rx.recv()).await?;
117 match result? {
118 ProcMessage::Joined() => break,
119 message => tracing::info!("proc message while joining: {:?}", message),
120 }
121 }
122
123 proc.instance("user").map(|(instance, _)| instance)
124 }
125}
126
127#[derive(Debug)]
129pub struct ServerHandle {
130 actor_handle: ActorHandle<SystemActor>,
131 mailbox_handle: MailboxServerHandle,
132 local_addr: ChannelAddr,
133}
134
135impl ServerHandle {
136 pub async fn stop(&self) -> Result<(), ActorError> {
138 self.actor_handle.drain_and_stop()?;
140 self.mailbox_handle.stop("system server stopped");
141 Ok(())
142 }
143
144 pub fn local_addr(&self) -> &ChannelAddr {
146 &self.local_addr
147 }
148
149 pub fn system_actor_handle(&self) -> &ActorHandle<SystemActor> {
151 &self.actor_handle
152 }
153}
154
155impl IntoFuture for ServerHandle {
158 type Output = ();
159 type IntoFuture = BoxFuture<'static, Self::Output>;
160
161 fn into_future(self) -> Self::IntoFuture {
162 let future = async move {
163 let _ = join!(self.actor_handle.into_future(), self.mailbox_handle);
164 };
165 future.boxed()
166 }
167}
168
169#[cfg(test)]
170mod tests {
171 use std::collections::HashMap;
172 use std::collections::HashSet;
173 use std::time::Duration;
174
175 use hyperactor::ActorRef;
176 use hyperactor::ProcId;
177 use hyperactor::WorldId;
178 use hyperactor::channel::ChannelAddr;
179 use hyperactor::channel::ChannelTransport;
180 use hyperactor::channel::TcpMode;
181 use hyperactor::clock::Clock;
182 use hyperactor::clock::RealClock;
183 use hyperactor_telemetry::env::execution_id;
184 use maplit::hashset;
185 use timed_test::async_timed_test;
186
187 use super::*;
188 use crate::System;
189 use crate::proc_actor::Environment;
190 use crate::proc_actor::ProcActor;
191 use crate::supervision::ProcSupervisor;
192 use crate::system_actor::ProcLifecycleMode;
193 use crate::system_actor::SYSTEM_ACTOR_REF;
194 use crate::system_actor::Shape;
195 use crate::system_actor::SystemMessageClient;
196 use crate::system_actor::SystemSnapshot;
197 use crate::system_actor::SystemSnapshotFilter;
198 use crate::system_actor::WorldSnapshot;
199 use crate::system_actor::WorldSnapshotProcInfo;
200 use crate::system_actor::WorldStatus;
201
202 #[tokio::test]
203 async fn test_join() {
204 for transport in ChannelTransport::all() {
205 #[cfg(not(target_os = "linux"))]
207 if matches!(transport, ChannelTransport::Unix) {
208 continue;
209 }
210
211 let system_handle = System::serve(
212 ChannelAddr::any(transport),
213 Duration::from_secs(10),
214 Duration::from_secs(10),
215 )
216 .await
217 .unwrap();
218
219 let mut system = System::new(system_handle.local_addr().clone());
220 let client1 = system.attach().await.unwrap();
221 let client2 = system.attach().await.unwrap();
222
223 let (port, mut port_rx) = client2.open_port();
224
225 port.bind().send(&client1, 123u64).unwrap();
226 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
227
228 system_handle.stop().await.unwrap();
229 system_handle.await;
230 }
231 }
232
233 #[tokio::test]
234 async fn test_system_snapshot() {
235 let system_handle = System::serve(
236 ChannelAddr::any(ChannelTransport::Local),
237 Duration::from_secs(10),
238 Duration::from_secs(10),
239 )
240 .await
241 .unwrap();
242
243 let mut system = System::new(system_handle.local_addr().clone());
244 let client = system.attach().await.unwrap();
245
246 let sys_actor_handle = system_handle.system_actor_handle();
247 {
249 let snapshot = sys_actor_handle
250 .snapshot(&client, SystemSnapshotFilter::all())
251 .await
252 .unwrap();
253 assert_eq!(
254 snapshot,
255 SystemSnapshot {
256 worlds: HashMap::new(),
257 execution_id: execution_id(),
258 }
259 );
260 }
261
262 let foo_world = {
264 let foo_world_id = WorldId("foo_world".to_string());
265 sys_actor_handle
266 .upsert_world(
267 &client,
268 foo_world_id.clone(),
269 Shape::Definite(vec![2]),
270 5,
271 Environment::Local,
272 HashMap::new(),
273 )
274 .await
275 .unwrap();
276 {
277 let snapshot = sys_actor_handle
278 .snapshot(&client, SystemSnapshotFilter::all())
279 .await
280 .unwrap();
281 let time = snapshot
282 .worlds
283 .get(&foo_world_id)
284 .unwrap()
285 .status
286 .as_unhealthy()
287 .unwrap()
288 .clone();
289 assert_eq!(
290 snapshot,
291 SystemSnapshot {
292 worlds: HashMap::from([(
293 foo_world_id.clone(),
294 WorldSnapshot {
295 host_procs: HashSet::new(),
296 procs: HashMap::new(),
297 status: WorldStatus::Unhealthy(time),
298 labels: HashMap::new(),
299 }
300 ),]),
301 execution_id: execution_id(),
302 }
303 );
304 }
305
306 {
308 let test_labels =
309 HashMap::from([("test_name".to_string(), "test_value".to_string())]);
310 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
311 let proc_id = ProcId::Ranked(foo_world_id.clone(), 1);
312 ProcActor::try_bootstrap(
313 proc_id.clone(),
314 foo_world_id.clone(),
315 listen_addr,
316 system_handle.local_addr().clone(),
317 ActorRef::attest(proc_id.actor_id("supervision", 0)),
318 Duration::from_secs(30),
319 test_labels.clone(),
320 ProcLifecycleMode::ManagedBySystem,
321 )
322 .await
323 .unwrap();
324
325 let snapshot = sys_actor_handle
326 .snapshot(&client, SystemSnapshotFilter::all())
327 .await
328 .unwrap();
329 let time = snapshot
330 .worlds
331 .get(&foo_world_id)
332 .unwrap()
333 .status
334 .as_unhealthy()
335 .unwrap()
336 .clone();
337 let foo_world = (
338 foo_world_id.clone(),
339 WorldSnapshot {
340 host_procs: HashSet::new(),
341 procs: HashMap::from([(
342 proc_id.clone(),
343 WorldSnapshotProcInfo {
344 labels: test_labels.clone(),
345 },
346 )]),
347 status: WorldStatus::Unhealthy(time),
348 labels: HashMap::new(),
349 },
350 );
351
352 assert_eq!(
353 snapshot,
354 SystemSnapshot {
355 worlds: HashMap::from([foo_world.clone(),]),
356 execution_id: execution_id(),
357 },
358 );
359
360 let snapshot = sys_actor_handle
362 .snapshot(
363 &client,
364 SystemSnapshotFilter {
365 worlds: vec![WorldId("none".to_string())],
366 world_labels: HashMap::new(),
367 proc_labels: HashMap::new(),
368 },
369 )
370 .await
371 .unwrap();
372 assert!(snapshot.worlds.is_empty());
373 let snapshot = sys_actor_handle
375 .snapshot(
376 &client,
377 SystemSnapshotFilter {
378 worlds: vec![],
379 world_labels: HashMap::new(),
380 proc_labels: test_labels.clone(),
381 },
382 )
383 .await
384 .unwrap();
385 assert_eq!(snapshot.worlds.get(&foo_world_id).unwrap(), &foo_world.1);
386 foo_world
387 }
388 };
389
390 {
392 let worker_world_id = WorldId("worker_world".to_string());
393 let host_world_id = WorldId(("hostworker_world").to_string());
394 let listen_addr: ChannelAddr = ChannelAddr::any(ChannelTransport::Local);
395 let host_proc_id_1 = ProcId::Ranked(host_world_id.clone(), 1);
397 ProcActor::try_bootstrap(
398 host_proc_id_1.clone(),
399 host_world_id.clone(),
400 listen_addr.clone(),
401 system_handle.local_addr().clone(),
402 ActorRef::attest(host_proc_id_1.actor_id("supervision", 0)),
403 Duration::from_secs(30),
404 HashMap::new(),
405 ProcLifecycleMode::ManagedBySystem,
406 )
407 .await
408 .unwrap();
409 {
410 let snapshot = sys_actor_handle
411 .snapshot(&client, SystemSnapshotFilter::all())
412 .await
413 .unwrap();
414 assert_eq!(
415 snapshot,
416 SystemSnapshot {
417 worlds: HashMap::from([
418 foo_world.clone(),
419 (
420 worker_world_id.clone(),
421 WorldSnapshot {
422 host_procs: HashSet::from([host_proc_id_1.clone()]),
423 procs: HashMap::new(),
424 status: WorldStatus::AwaitingCreation,
425 labels: HashMap::new(),
426 }
427 ),
428 ]),
429 execution_id: execution_id(),
430 },
431 );
432 }
433
434 sys_actor_handle
436 .upsert_world(
437 &client,
438 worker_world_id.clone(),
439 Shape::Definite(vec![3, 4]),
442 8,
443 Environment::Local,
444 HashMap::new(),
445 )
446 .await
447 .unwrap();
448 RealClock.sleep(Duration::from_secs(2)).await;
450 {
451 let snapshot = sys_actor_handle
452 .snapshot(&client, SystemSnapshotFilter::all())
453 .await
454 .unwrap();
455 let time = snapshot
456 .worlds
457 .get(&worker_world_id)
458 .unwrap()
459 .status
460 .as_unhealthy()
461 .unwrap()
462 .clone();
463 assert_eq!(
464 snapshot,
465 SystemSnapshot {
466 worlds: HashMap::from([
467 foo_world.clone(),
468 (
469 worker_world_id.clone(),
470 WorldSnapshot {
471 host_procs: HashSet::from([host_proc_id_1.clone()]),
472 procs: (8..12)
473 .map(|i| (
474 ProcId::Ranked(worker_world_id.clone(), i),
475 WorldSnapshotProcInfo {
476 labels: HashMap::new()
477 }
478 ))
479 .collect(),
480 status: WorldStatus::Unhealthy(time),
481 labels: HashMap::new(),
482 }
483 ),
484 ]),
485 execution_id: execution_id(),
486 },
487 );
488 }
489
490 let host_proc_id_0 = ProcId::Ranked(host_world_id.clone(), 0);
491 ProcActor::try_bootstrap(
492 host_proc_id_0.clone(),
493 host_world_id.clone(),
494 listen_addr,
495 system_handle.local_addr().clone(),
496 ActorRef::attest(host_proc_id_0.actor_id("supervision", 0)),
497 Duration::from_secs(30),
498 HashMap::new(),
499 ProcLifecycleMode::ManagedBySystem,
500 )
501 .await
502 .unwrap();
503
504 RealClock.sleep(Duration::from_secs(2)).await;
506 {
507 let snapshot = sys_actor_handle
508 .snapshot(&client, SystemSnapshotFilter::all())
509 .await
510 .unwrap();
511 assert_eq!(
512 snapshot,
513 SystemSnapshot {
514 worlds: HashMap::from([
515 foo_world,
516 (
517 worker_world_id.clone(),
518 WorldSnapshot {
519 host_procs: HashSet::from([host_proc_id_0, host_proc_id_1]),
520 procs: HashMap::from_iter((0..12).map(|i| (
521 ProcId::Ranked(worker_world_id.clone(), i),
522 WorldSnapshotProcInfo {
523 labels: HashMap::new()
524 }
525 ))),
526 status: WorldStatus::Live,
528 labels: HashMap::new(),
529 }
530 ),
531 ]),
532 execution_id: execution_id(),
533 }
534 );
535 }
536 }
537 }
538
539 #[tracing_test::traced_test]
545 #[async_timed_test(timeout_secs = 60)]
546 async fn test_system_shutdown() {
547 let system_handle = System::serve(
548 ChannelAddr::any(ChannelTransport::Local),
549 Duration::from_secs(10),
550 Duration::from_secs(10),
551 )
552 .await
553 .unwrap();
554 let system_supervision_ref: ActorRef<ProcSupervisor> =
555 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
556
557 let mut system = System::new(system_handle.local_addr().clone());
558 let client = system.attach().await.unwrap();
559
560 let sys_actor_handle = system_handle.system_actor_handle();
561
562 let worker_world_id = WorldId("worker_world".to_string());
564 let shape = vec![2, 2, 4];
565 let host_proc_actors = {
566 let host_world_id = WorldId(("hostworker_world").to_string());
567 sys_actor_handle
569 .upsert_world(
570 &client,
571 worker_world_id.clone(),
572 Shape::Definite(shape.clone()),
574 8,
575 Environment::Local,
576 HashMap::new(),
577 )
578 .await
579 .unwrap();
580
581 let futs = (0..2).map(|i| {
583 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
584 ProcActor::try_bootstrap(
585 host_proc_id.clone(),
586 host_world_id.clone(),
587 ChannelAddr::any(ChannelTransport::Local),
588 system_handle.local_addr().clone(),
589 system_supervision_ref.clone(),
590 Duration::from_secs(30),
591 HashMap::new(),
592 ProcLifecycleMode::ManagedBySystem,
593 )
594 });
595 futures::future::try_join_all(futs).await.unwrap()
596 };
597 RealClock.sleep(Duration::from_secs(2)).await;
599
600 let foo_proc_actors = {
602 let foo_world_id = WorldId("foo_world".to_string());
603 sys_actor_handle
604 .upsert_world(
605 &client,
606 foo_world_id.clone(),
607 Shape::Definite(vec![2]),
608 2,
609 Environment::Local,
610 HashMap::new(),
611 )
612 .await
613 .unwrap();
614 let foo_futs = (0..2).map(|i| {
616 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
617 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
618 ProcActor::try_bootstrap(
619 proc_id.clone(),
620 foo_world_id.clone(),
621 listen_addr,
622 system_handle.local_addr().clone(),
623 system_supervision_ref.clone(),
624 Duration::from_secs(30),
625 HashMap::new(),
626 ProcLifecycleMode::ManagedBySystem,
627 )
628 });
629 futures::future::try_join_all(foo_futs).await.unwrap()
630 };
631
632 let (port, receiver) = client.open_once_port::<()>();
633 sys_actor_handle
635 .stop(&client, None, Duration::from_secs(5), port.bind())
636 .await
637 .unwrap();
638 receiver.recv().await.unwrap();
639 RealClock.sleep(Duration::from_secs(5)).await;
640
641 for bootstrap in host_proc_actors {
643 bootstrap.proc_actor.into_future().await;
644 }
645
646 for bootstrap in foo_proc_actors {
648 bootstrap.proc_actor.into_future().await;
649 }
650 system_handle.actor_handle.into_future().await;
652
653 for m in 0..(shape.iter().product()) {
656 let proc_id = worker_world_id.proc_id(m);
657 assert!(tracing_test::internal::logs_with_scope_contain(
658 "hyperactor::proc",
659 format!("{proc_id}: proc stopped").as_str()
660 ));
661 }
662 }
663
664 #[async_timed_test(timeout_secs = 60)]
665 async fn test_single_world_shutdown() {
666 let system_handle = System::serve(
667 ChannelAddr::any(ChannelTransport::Local),
668 Duration::from_secs(10),
669 Duration::from_secs(10),
670 )
671 .await
672 .unwrap();
673 let system_supervision_ref: ActorRef<ProcSupervisor> =
674 ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone());
675
676 let mut system = System::new(system_handle.local_addr().clone());
677 let client = system.attach().await.unwrap();
678
679 let sys_actor_handle = system_handle.system_actor_handle();
680
681 let host_world_id = WorldId(("host_world").to_string());
682 let worker_world_id = WorldId("worker_world".to_string());
683 let foo_world_id = WorldId("foo_world".to_string());
684
685 let shape = vec![2, 2, 4];
687 let host_proc_actors = {
688 sys_actor_handle
690 .upsert_world(
691 &client,
692 worker_world_id.clone(),
693 Shape::Definite(shape.clone()),
695 8,
696 Environment::Local,
697 HashMap::new(),
698 )
699 .await
700 .unwrap();
701
702 let futs = (0..2).map(|i| {
704 let host_proc_id = ProcId::Ranked(host_world_id.clone(), i);
705 ProcActor::try_bootstrap(
706 host_proc_id.clone(),
707 host_world_id.clone(),
708 ChannelAddr::any(ChannelTransport::Local),
709 system_handle.local_addr().clone(),
710 system_supervision_ref.clone(),
711 Duration::from_secs(30),
712 HashMap::new(),
713 ProcLifecycleMode::ManagedBySystem,
714 )
715 });
716 futures::future::try_join_all(futs).await.unwrap()
717 };
718 RealClock.sleep(Duration::from_secs(2)).await;
720
721 let foo_proc_actors = {
723 sys_actor_handle
724 .upsert_world(
725 &client,
726 foo_world_id.clone(),
727 Shape::Definite(vec![2]),
728 2,
729 Environment::Local,
730 HashMap::new(),
731 )
732 .await
733 .unwrap();
734 let foo_futs = (0..2).map(|i| {
736 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
737 let proc_id = ProcId::Ranked(foo_world_id.clone(), i);
738 ProcActor::try_bootstrap(
739 proc_id.clone(),
740 foo_world_id.clone(),
741 listen_addr,
742 system_handle.local_addr().clone(),
743 system_supervision_ref.clone(),
744 Duration::from_secs(30),
745 HashMap::new(),
746 ProcLifecycleMode::ManagedBySystem,
747 )
748 });
749 futures::future::try_join_all(foo_futs).await.unwrap()
750 };
751
752 {
753 let snapshot = sys_actor_handle
754 .snapshot(&client, SystemSnapshotFilter::all())
755 .await
756 .unwrap();
757 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
758 assert_eq!(
759 snapshot_world_ids,
760 hashset! {worker_world_id.clone(), foo_world_id.clone(), WorldId("_world".to_string())}
761 );
762 }
763
764 let (port, receiver) = client.open_once_port::<()>();
765 sys_actor_handle
767 .stop(
768 &client,
769 Some(vec![WorldId("foo_world".into())]),
770 Duration::from_secs(5),
771 port.bind(),
772 )
773 .await
774 .unwrap();
775 receiver.recv().await.unwrap();
776 RealClock.sleep(Duration::from_secs(5)).await;
777
778 for bootstrap in foo_proc_actors {
780 bootstrap.proc_actor.into_future().await;
781 }
782
783 for bootstrap in host_proc_actors {
785 match RealClock
786 .timeout(Duration::from_secs(5), bootstrap.proc_actor.into_future())
787 .await
788 {
789 Ok(_) => {
790 panic!("foo actor shouldn't be stopped");
791 }
792 Err(_) => {}
793 }
794 }
795
796 match RealClock
798 .timeout(
799 Duration::from_secs(3),
800 system_handle.actor_handle.clone().into_future(),
801 )
802 .await
803 {
804 Ok(_) => {
805 panic!("system actor shouldn't be stopped");
806 }
807 Err(_) => {}
808 }
809
810 {
811 let snapshot = sys_actor_handle
812 .snapshot(&client, SystemSnapshotFilter::all())
813 .await
814 .unwrap();
815 let snapshot_world_ids: HashSet<WorldId> = snapshot.worlds.keys().cloned().collect();
816 assert_eq!(
818 snapshot_world_ids,
819 hashset! {worker_world_id, WorldId("_world".to_string())}
820 );
821 }
822 }
823
824 #[tracing_test::traced_test]
827 #[tokio::test]
828 async fn test_channel_dial_count() {
829 let system_handle = System::serve(
830 ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
831 Duration::from_secs(10),
832 Duration::from_secs(10),
833 )
834 .await
835 .unwrap();
836
837 let system_addr = system_handle.local_addr();
838 let mut system = System::new(system_addr.clone());
839 let client1 = system.attach().await.unwrap();
852
853 let client2 = system.attach().await.unwrap();
866
867 let (port, mut port_rx) = client2.open_port();
875 port.bind().send(&client1, 123u64).unwrap();
876 assert_eq!(port_rx.recv().await.unwrap(), 123u64);
877
878 logs_assert(|logs| {
880 let dial_count = logs
881 .iter()
882 .filter(|log| log.contains("dialing channel tcp"))
883 .count();
884 if dial_count == 4 {
885 Ok(())
886 } else {
887 Err(format!("unexpected tcp channel dial count: {}", dial_count))
888 }
889 });
890
891 system_handle.stop().await.unwrap();
892 system_handle.await;
893 }
894}