1#[cfg(test)]
15use std::collections::HashMap;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::Bind;
24use hyperactor::Context;
25use hyperactor::Handler;
26use hyperactor::Instance;
27use hyperactor::RefClient;
28use hyperactor::Unbind;
29#[cfg(test)]
30use hyperactor::context;
31use hyperactor::ordering::SEQ_INFO;
32use hyperactor::ordering::SeqInfo;
33use hyperactor::reference as hyperactor_reference;
34use hyperactor::supervision::ActorSupervisionEvent;
35use hyperactor_config::Flattrs;
36use hyperactor_config::global::Source;
37use ndslice::Point;
38#[cfg(test)]
39use ndslice::ViewExt as _;
40use serde::Deserialize;
41use serde::Serialize;
42use typeuri::Named;
43#[cfg(test)]
44use uuid::Uuid;
45
46use crate::ActorMesh;
47#[cfg(test)]
48use crate::ActorMeshRef;
49use crate::Name;
50use crate::ProcMeshRef;
51use crate::comm::multicast::CastInfo;
52use crate::supervision::MeshFailure;
53#[cfg(test)]
54use crate::testing;
55
/// General-purpose test actor.
///
/// It answers identity/cast-info queries, forwards [`Forward`] chains, reads
/// and writes the global config, and can be told to fail on purpose via
/// [`CauseSupervisionEvent`].
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        () { cast = true },
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
        GetConfigAttrs { cast = true },
        SetConfigAttrs { cast = true },
    ]
)]
pub struct TestActor;

impl Actor for TestActor {}
73
/// Asks a [`TestActor`] to reply on the contained port with its own actor id,
/// paired with the `SEQ_INFO` header (if any) attached to this request.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(
    #[binding(include)]
    pub hyperactor_reference::PortRef<(hyperactor_reference::ActorId, Option<SeqInfo>)>,
);
80
/// How an actor handling [`CauseSupervisionEvent`] should fail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, intended to crash the process with SIGSEGV.
    SigSEGV,
    /// Terminate the whole process with `std::process::exit(code)`.
    ProcessExit(i32),
}
87
/// Instructs the receiving actor to fail deliberately, as described by `kind`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent {
    /// The kind of failure to trigger.
    pub kind: SupervisionEventType,
    /// Only consulted by `WrapperActor`: when `true` the message is cast to its
    /// child mesh instead of failing the wrapper itself. `TestActor` ignores it.
    pub send_to_children: bool,
}
95
impl CauseSupervisionEvent {
    /// Deliberately fails the current actor or process as described by `self.kind`.
    ///
    /// Never returns: it panics, crashes the process, or exits it.
    fn cause_event(&self) -> ! {
        match self.kind {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: none. This write through a null pointer is intentional
                // undefined behavior, used to crash the process with a
                // segmentation fault.
                // NOTE(review): because this is UB, the optimizer may lower it to
                // something other than a faulting store (e.g. a trap instruction,
                // giving SIGILL). Confirm the observed signal in optimized builds.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
                // Fallback in case the write above did not terminate the process.
                panic!("should have segfaulted");
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                // Exits the entire OS process, not just this actor.
                std::process::exit(code);
            }
        }
    }
}
117
118#[async_trait]
119impl Handler<()> for TestActor {
120 async fn handle(&mut self, _cx: &Context<Self>, _: ()) -> Result<(), anyhow::Error> {
121 Ok(())
122 }
123}
124
125#[async_trait]
126impl Handler<GetActorId> for TestActor {
127 async fn handle(
128 &mut self,
129 cx: &Context<Self>,
130 GetActorId(reply): GetActorId,
131 ) -> Result<(), anyhow::Error> {
132 let seq_info = cx.headers().get(SEQ_INFO);
133 reply.send(cx, (cx.self_id().clone(), seq_info))?;
134 Ok(())
135 }
136}
137
138#[async_trait]
139impl Handler<CauseSupervisionEvent> for TestActor {
140 async fn handle(
141 &mut self,
142 _cx: &Context<Self>,
143 msg: CauseSupervisionEvent,
144 ) -> Result<(), anyhow::Error> {
145 msg.cause_event();
146 }
147}
148
/// Test actor that overrides `handle_supervision_event`: it logs each
/// supervision event and returns `Ok(true)` for it. It also accepts
/// `ActorSupervisionEvent` as an ordinary message.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
157
158#[async_trait]
159impl Actor for TestActorWithSupervisionHandling {
160 async fn handle_supervision_event(
161 &mut self,
162 _this: &Instance<Self>,
163 event: &ActorSupervisionEvent,
164 ) -> Result<bool, anyhow::Error> {
165 tracing::error!("supervision event: {:?}", event);
166 Ok(true)
168 }
169}
170
171#[async_trait]
172impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
173 async fn handle(
174 &mut self,
175 _cx: &Context<Self>,
176 _msg: ActorSupervisionEvent,
177 ) -> Result<(), anyhow::Error> {
178 Ok(())
179 }
180}
181
/// Test actor that, on receiving a [`std::time::Duration`], asynchronously
/// sleeps for that long before its handler returns.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [std::time::Duration],
)]
pub struct SleepActor;

impl Actor for SleepActor {}
192
193#[async_trait]
194impl Handler<std::time::Duration> for SleepActor {
195 async fn handle(
196 &mut self,
197 _cx: &Context<Self>,
198 duration: std::time::Duration,
199 ) -> Result<(), anyhow::Error> {
200 tokio::time::sleep(duration).await;
201 Ok(())
202 }
203}
204
/// A message that hops along a chain of ports.
///
/// Each [`TestActor`] that handles it does three things:
/// - pops the front of `to_visit` and appends it to `visited`;
/// - forwards the message to the new front of `to_visit`;
/// - returns an error if the chain runs out before it can forward.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to visit. The front is presumably the current recipient's
    /// own port.
    pub to_visit: VecDeque<hyperactor_reference::PortRef<Forward>>,
    /// Ports already popped from `to_visit`, in visit order.
    pub visited: Vec<hyperactor_reference::PortRef<Forward>>,
}
213
214#[async_trait]
215impl Handler<Forward> for TestActor {
216 async fn handle(
217 &mut self,
218 cx: &Context<Self>,
219 Forward {
220 mut to_visit,
221 mut visited,
222 }: Forward,
223 ) -> Result<(), anyhow::Error> {
224 let Some(this) = to_visit.pop_front() else {
225 anyhow::bail!("unexpected forward chain termination");
226 };
227 visited.push(this);
228 let next = to_visit.front().cloned();
229 anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
230 next.unwrap().send(cx, Forward { to_visit, visited })?;
231 Ok(())
232 }
233}
234
/// Asks a [`TestActor`] to reply with details of the cast it received.
///
/// The reply on `cast_info` is a tuple of: the actor's point in the cast, a
/// reference to the actor itself, and the actor id returned by the handler
/// context's `sender()`.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    #[reply]
    pub cast_info: hyperactor_reference::PortRef<(
        Point,
        hyperactor_reference::ActorRef<TestActor>,
        hyperactor_reference::ActorId,
    )>,
}
256
257#[async_trait]
258impl Handler<GetCastInfo> for TestActor {
259 async fn handle(
260 &mut self,
261 cx: &Context<Self>,
262 GetCastInfo { cast_info }: GetCastInfo,
263 ) -> Result<(), anyhow::Error> {
264 cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
265 Ok(())
266 }
267}
268
/// Test actor whose `RemoteSpawn::new` always returns an error
/// ("test failure"), so it can never be constructed remotely.
#[derive(Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;

#[async_trait]
impl Actor for FailingCreateTestActor {}
275
276#[async_trait]
277impl hyperactor::RemoteSpawn for FailingCreateTestActor {
278 type Params = ();
279
280 async fn new(
281 _params: Self::Params,
282 _environment: Flattrs,
283 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
284 Err(anyhow::anyhow!("test failure"))
285 }
286}
287
/// Bincode-encoded config attrs. The receiving [`TestActor`] decodes them and
/// installs them as the `Source::Runtime` layer of the global
/// `hyperactor_config` in its process.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
290
291#[async_trait]
292impl Handler<SetConfigAttrs> for TestActor {
293 async fn handle(
294 &mut self,
295 _cx: &Context<Self>,
296 SetConfigAttrs(attrs): SetConfigAttrs,
297 ) -> Result<(), anyhow::Error> {
298 let attrs = bincode::deserialize(&attrs)?;
299 hyperactor_config::global::set(Source::Runtime, attrs);
300 Ok(())
301 }
302}
303
/// Asks a [`TestActor`] to reply with a bincode encoding of
/// `hyperactor_config::global::attrs()` as seen in its process.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub hyperactor_reference::PortRef<Vec<u8>>);
306
307#[async_trait]
308impl Handler<GetConfigAttrs> for TestActor {
309 async fn handle(
310 &mut self,
311 cx: &Context<Self>,
312 GetConfigAttrs(reply): GetConfigAttrs,
313 ) -> Result<(), anyhow::Error> {
314 let attrs = bincode::serialize(&hyperactor_config::global::attrs())?;
315 reply.send(cx, attrs)?;
316 Ok(())
317 }
318}
319
/// Asks a `WrapperActor` to wait for the next supervision failure of its child
/// mesh and reply with it.
///
/// The reply is `None` in three cases: there is no child mesh, waiting for the
/// event returns an error, or the wait times out.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct NextSupervisionFailure(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
325
/// Test actor that, during `init`, spawns a child mesh of [`TestActor`]s on
/// `proc_mesh`. It forwards every [`MeshFailure`] it receives to `supervisor`.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CauseSupervisionEvent { cast = true },
        MeshFailure { cast = true },
        NextSupervisionFailure { cast = true },
    ]
)]
pub struct WrapperActor {
    /// Proc mesh on which the child mesh is spawned.
    proc_mesh: ProcMeshRef,
    /// The child mesh; `None` until `init` has spawned it.
    mesh: Option<ActorMesh<TestActor>>,
    /// Port to which every received [`MeshFailure`] is forwarded.
    supervisor: hyperactor_reference::PortRef<MeshFailure>,
    /// Name given to the child mesh when it is spawned.
    test_name: Name,
}
345
346#[async_trait]
347impl hyperactor::RemoteSpawn for WrapperActor {
348 type Params = (
349 ProcMeshRef,
350 hyperactor_reference::PortRef<MeshFailure>,
351 Name,
352 );
353
354 async fn new(
355 (proc_mesh, supervisor, test_name): Self::Params,
356 _environment: Flattrs,
357 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
358 Ok(Self {
359 proc_mesh,
360 mesh: None,
361 supervisor,
362 test_name,
363 })
364 }
365}
366
367#[async_trait]
368impl Actor for WrapperActor {
369 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
370 self.mesh = Some(
371 self.proc_mesh
372 .spawn_with_name(this, self.test_name.clone(), &(), None, false)
373 .await?,
374 );
375 Ok(())
376 }
377}
378
379#[async_trait]
380impl Handler<CauseSupervisionEvent> for WrapperActor {
381 async fn handle(
382 &mut self,
383 cx: &Context<Self>,
384 msg: CauseSupervisionEvent,
385 ) -> Result<(), anyhow::Error> {
386 if msg.send_to_children {
388 self.mesh
390 .as_ref()
391 .unwrap()
392 .cast(cx, msg)
393 .map_err(|e| e.into())
394 } else {
395 msg.cause_event()
396 }
397 }
398}
399
400#[async_trait]
401impl Handler<NextSupervisionFailure> for WrapperActor {
402 async fn handle(
403 &mut self,
404 cx: &Context<Self>,
405 msg: NextSupervisionFailure,
406 ) -> Result<(), anyhow::Error> {
407 let mesh = if let Some(mesh) = self.mesh.as_ref() {
408 mesh.deref()
409 } else {
410 msg.0.send(cx, None)?;
411 return Ok(());
412 };
413 let failure = match tokio::time::timeout(
414 tokio::time::Duration::from_secs(20),
415 mesh.next_supervision_event(cx),
416 )
417 .await
418 {
419 Ok(Ok(failure)) => Some(failure),
420 Ok(Err(_)) => None,
422 Err(_) => None,
424 };
425 msg.0.send(cx, failure)?;
426 Ok(())
427 }
428}
429
430#[async_trait]
431impl Handler<MeshFailure> for WrapperActor {
432 async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
433 tracing::info!("got supervision event from child: {}", msg);
436 let _ = self.supervisor.send(cx, msg.clone());
439 Ok(())
440 }
441}
442
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();

    // Broadcast to the whole mesh: every actor must answer exactly once.
    assert_casting_correctness(&actor_mesh, instance, None).await;

    // Repeat on the first half of the leading dimension to exercise slicing.
    let extent = actor_mesh.extent();
    let leading_label = &extent.labels()[0];
    let half = extent.sizes()[0] / 2;
    let first_half = actor_mesh.range(leading_label, 0..half).unwrap();
    assert_casting_correctness(&first_half, instance, None).await;
}
461
/// Casts [`GetActorId`] to `actor_mesh` and asserts that every actor in the
/// mesh replies exactly once, with no extra replies within one second afterwards.
///
/// Every reply must carry a `SeqInfo`. When `expected_seqs` is
/// `Some((session_id, seqs))`, `seqs[i]` is the sequence number expected from
/// the i-th actor of `actor_mesh.values()`. Each reply's `SeqInfo` must then
/// equal `SeqInfo::Session { session_id, seq }`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `seqs` does not have one entry per actor;
/// - a reply comes from an unexpected actor, lacks a `SeqInfo`, or carries the
///   wrong sequence number;
/// - an extra reply arrives after all expected ones.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
    expected_seqs: Option<(Uuid, Vec<u64>)>,
) {
    let (port, mut rx) = instance.mailbox().open_port();
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
    let expected_actor_ids = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect::<Vec<_>>();
    let mut expected: HashMap<&hyperactor_reference::ActorId, Option<SeqInfo>> = match expected_seqs
    {
        None => expected_actor_ids
            .iter()
            .map(|actor_id| (actor_id, None))
            .collect(),
        Some((session_id, seqs)) => {
            // `zip` would silently drop entries on a length mismatch, producing
            // confusing failures later; reject it up front instead.
            assert_eq!(
                seqs.len(),
                expected_actor_ids.len(),
                "expected exactly one seq per actor in the mesh"
            );
            expected_actor_ids
                .iter()
                .zip(
                    seqs.into_iter()
                        .map(|seq| Some(SeqInfo::Session { session_id, seq })),
                )
                .collect()
        }
    };

    while !expected.is_empty() {
        let (actor_id, rcved) = rx.recv().await.unwrap();
        let rcv_seq_info = rcved.expect("GetActorId reply is missing SeqInfo");
        let Some(want) = expected.remove(&actor_id) else {
            panic!("got {actor_id}, expect {expected_actor_ids:?}");
        };
        if let Some(want) = want {
            assert_eq!(want, rcv_seq_info, "got different seq for {actor_id}");
        }
    }

    // Give stragglers a chance to arrive, then require that none did.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    assert!(matches!(result, Ok(None)), "got {result:?}");
}
508}