1#[cfg(test)]
15use std::collections::HashMap;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Endpoint as _;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::RefClient;
30use hyperactor::Unbind;
31#[cfg(test)]
32use hyperactor::context;
33use hyperactor::ordering::SEQ_INFO;
34use hyperactor::ordering::SeqInfo;
35use hyperactor::supervision::ActorSupervisionEvent;
36use hyperactor_config::Flattrs;
37use hyperactor_config::global::Source;
38use ndslice::Point;
39#[cfg(test)]
40use ndslice::ViewExt as _;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44#[cfg(test)]
45use uuid::Uuid;
46
47use crate::ActorMesh;
48#[cfg(test)]
49use crate::ActorMeshRef;
50use crate::ProcMeshRef;
51use crate::comm::multicast::CastInfo;
52use crate::mesh_id::ActorMeshId;
53use crate::supervision::MeshFailure;
54#[cfg(test)]
55use crate::testing;
56
57#[derive(Default, Debug)]
59#[hyperactor::export(
60 () { cast = true },
61 GetActorId { cast = true },
62 GetCastInfo { cast = true },
63 CauseSupervisionEvent { cast = true },
64 Forward,
65 GetConfigAttrs { cast = true },
66 SetConfigAttrs { cast = true },
67)]
68#[hyperactor::spawnable]
69pub struct TestActor;
70
71impl Actor for TestActor {}
72
73#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
75pub struct GetActorId(
76 #[binding(include)] pub hyperactor::PortRef<(hyperactor::ActorAddr, Option<SeqInfo>)>,
77);
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum SupervisionEventType {
81 Panic,
82 SigSEGV,
83 ProcessExit(i32),
84}
85
86#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
89pub struct CauseSupervisionEvent {
90 pub kind: SupervisionEventType,
91 pub send_to_children: bool,
92}
93
impl CauseSupervisionEvent {
    /// Deliberately fails in the way selected by `self.kind`. Never returns.
    ///
    /// - `Panic`: panics with the message "for testing".
    /// - `SigSEGV`: logs an error, then writes through a null pointer.
    /// - `ProcessExit(code)`: logs the process id and code, then calls
    ///   `std::process::exit(code)`.
    fn cause_event(&self) -> ! {
        match self.kind {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: none. Writing through a null pointer is intentionally
                // undefined behavior, used only to crash the process in tests.
                // NOTE(review): whether this reliably raises SIGSEGV depends on the
                // build. Debug-mode precondition checks on `ptr::write` may abort
                // first, and the optimizer may treat the write as unreachable. Confirm.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
                // Fallback so the `!` return type holds if the write did not
                // terminate the process.
                panic!("should have segfaulted");
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                std::process::exit(code);
            }
        }
    }
}
115
116#[async_trait]
117impl Handler<()> for TestActor {
118 async fn handle(&mut self, _cx: &Context<Self>, _: ()) -> Result<(), anyhow::Error> {
119 Ok(())
120 }
121}
122
123#[async_trait]
124impl Handler<GetActorId> for TestActor {
125 async fn handle(
126 &mut self,
127 cx: &Context<Self>,
128 GetActorId(reply): GetActorId,
129 ) -> Result<(), anyhow::Error> {
130 let seq_info = cx.headers().get(SEQ_INFO);
131 reply.post(cx, (cx.self_addr().clone(), seq_info));
132 Ok(())
133 }
134}
135
136#[async_trait]
137impl Handler<CauseSupervisionEvent> for TestActor {
138 async fn handle(
139 &mut self,
140 _cx: &Context<Self>,
141 msg: CauseSupervisionEvent,
142 ) -> Result<(), anyhow::Error> {
143 msg.cause_event();
144 }
145}
146
147#[derive(Default, Debug)]
150#[hyperactor::export(ActorSupervisionEvent)]
151#[hyperactor::spawnable]
152pub struct TestActorWithSupervisionHandling;
153
154#[async_trait]
155impl Actor for TestActorWithSupervisionHandling {
156 async fn handle_supervision_event(
157 &mut self,
158 _this: &Instance<Self>,
159 event: &ActorSupervisionEvent,
160 ) -> Result<bool, anyhow::Error> {
161 tracing::error!("supervision event: {:?}", event);
162 Ok(true)
164 }
165}
166
167#[async_trait]
168impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
169 async fn handle(
170 &mut self,
171 _cx: &Context<Self>,
172 _msg: ActorSupervisionEvent,
173 ) -> Result<(), anyhow::Error> {
174 Ok(())
175 }
176}
177
178#[derive(Default, Debug)]
181#[hyperactor::export(std::time::Duration)]
182#[hyperactor::spawnable]
183pub struct SleepActor;
184
185impl Actor for SleepActor {}
186
187#[async_trait]
188impl Handler<std::time::Duration> for SleepActor {
189 async fn handle(
190 &mut self,
191 _cx: &Context<Self>,
192 duration: std::time::Duration,
193 ) -> Result<(), anyhow::Error> {
194 tokio::time::sleep(duration).await;
195 Ok(())
196 }
197}
198
199#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
203pub struct Forward {
204 pub to_visit: VecDeque<hyperactor::PortRef<Forward>>,
205 pub visited: Vec<hyperactor::PortRef<Forward>>,
206}
207
208#[async_trait]
209impl Handler<Forward> for TestActor {
210 async fn handle(
211 &mut self,
212 cx: &Context<Self>,
213 Forward {
214 mut to_visit,
215 mut visited,
216 }: Forward,
217 ) -> Result<(), anyhow::Error> {
218 let Some(this) = to_visit.pop_front() else {
219 anyhow::bail!("unexpected forward chain termination");
220 };
221 visited.push(this);
222 let next = to_visit.front().cloned();
223 anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
224 next.unwrap().post(cx, Forward { to_visit, visited });
225 Ok(())
226 }
227}
228
229#[derive(
231 Debug,
232 Clone,
233 Named,
234 Bind,
235 Unbind,
236 Serialize,
237 Deserialize,
238 Handler,
239 RefClient
240)]
241pub struct GetCastInfo {
242 #[reply]
244 pub cast_info: hyperactor::PortRef<(Point, ActorRef<TestActor>, hyperactor::ActorAddr)>,
245}
246
247#[async_trait]
248impl Handler<GetCastInfo> for TestActor {
249 async fn handle(
250 &mut self,
251 cx: &Context<Self>,
252 GetCastInfo { cast_info }: GetCastInfo,
253 ) -> Result<(), anyhow::Error> {
254 cast_info.post(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()));
255 Ok(())
256 }
257}
258
259#[derive(Debug)]
260#[hyperactor::export]
261#[hyperactor::spawnable]
262pub struct FailingCreateTestActor;
263
264#[async_trait]
265impl Actor for FailingCreateTestActor {}
266
267#[async_trait]
268impl hyperactor::RemoteSpawn for FailingCreateTestActor {
269 type Params = ();
270
271 async fn new(
272 _params: Self::Params,
273 _environment: Flattrs,
274 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
275 Err(anyhow::anyhow!("test failure"))
276 }
277}
278
279#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
280pub struct SetConfigAttrs(pub Vec<u8>);
281
282#[async_trait]
283impl Handler<SetConfigAttrs> for TestActor {
284 async fn handle(
285 &mut self,
286 _cx: &Context<Self>,
287 SetConfigAttrs(attrs): SetConfigAttrs,
288 ) -> Result<(), anyhow::Error> {
289 let attrs =
290 bincode::serde::decode_from_slice(&attrs, bincode::config::legacy()).map(|(v, _)| v)?;
291 hyperactor_config::global::set(Source::Runtime, attrs);
292 Ok(())
293 }
294}
295
296#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
297pub struct GetConfigAttrs(pub hyperactor::PortRef<Vec<u8>>);
298
299#[async_trait]
300impl Handler<GetConfigAttrs> for TestActor {
301 async fn handle(
302 &mut self,
303 cx: &Context<Self>,
304 GetConfigAttrs(reply): GetConfigAttrs,
305 ) -> Result<(), anyhow::Error> {
306 let attrs = bincode::serde::encode_to_vec(
307 hyperactor_config::global::attrs(),
308 bincode::config::legacy(),
309 )?;
310 reply.post(cx, attrs);
311 Ok(())
312 }
313}
314
315#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
319pub struct NextSupervisionFailure(pub hyperactor::PortRef<Option<MeshFailure>>);
320
321#[derive(Debug)]
325#[hyperactor::export(
326 CauseSupervisionEvent { cast = true },
327 MeshFailure { cast = true },
328 NextSupervisionFailure { cast = true },
329)]
330#[hyperactor::spawnable]
331pub struct WrapperActor {
332 proc_mesh: ProcMeshRef,
333 mesh: Option<ActorMesh<TestActor>>,
335 supervisor: hyperactor::PortRef<MeshFailure>,
336 test_name: ActorMeshId,
337}
338
339#[async_trait]
340impl hyperactor::RemoteSpawn for WrapperActor {
341 type Params = (ProcMeshRef, hyperactor::PortRef<MeshFailure>, ActorMeshId);
342
343 async fn new(
344 (proc_mesh, supervisor, test_name): Self::Params,
345 _environment: Flattrs,
346 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
347 Ok(Self {
348 proc_mesh,
349 mesh: None,
350 supervisor,
351 test_name,
352 })
353 }
354}
355
356#[async_trait]
357impl Actor for WrapperActor {
358 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
359 self.mesh = Some(
360 self.proc_mesh
361 .spawn_with_name(this, self.test_name.clone(), &(), None, false)
362 .await?,
363 );
364 Ok(())
365 }
366}
367
368#[async_trait]
369impl Handler<CauseSupervisionEvent> for WrapperActor {
370 async fn handle(
371 &mut self,
372 cx: &Context<Self>,
373 msg: CauseSupervisionEvent,
374 ) -> Result<(), anyhow::Error> {
375 if msg.send_to_children {
377 self.mesh
379 .as_ref()
380 .unwrap()
381 .cast(cx, msg)
382 .map_err(|e| e.into())
383 } else {
384 msg.cause_event()
385 }
386 }
387}
388
389#[async_trait]
390impl Handler<NextSupervisionFailure> for WrapperActor {
391 async fn handle(
392 &mut self,
393 cx: &Context<Self>,
394 msg: NextSupervisionFailure,
395 ) -> Result<(), anyhow::Error> {
396 let mesh = if let Some(mesh) = self.mesh.as_ref() {
397 mesh.deref()
398 } else {
399 msg.0.post(cx, None);
400 return Ok(());
401 };
402 let failure = match tokio::time::timeout(
403 tokio::time::Duration::from_secs(20),
404 mesh.next_supervision_event(cx),
405 )
406 .await
407 {
408 Ok(Ok(failure)) => Some(failure),
409 Ok(Err(_)) => None,
411 Err(_) => None,
413 };
414 msg.0.post(cx, failure);
415 Ok(())
416 }
417}
418
419#[async_trait]
420impl Handler<MeshFailure> for WrapperActor {
421 async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
422 tracing::info!("got supervision event from child: {}", msg);
425 let _ = self.supervisor.post(cx, msg.clone());
428 Ok(())
429 }
430}
431
/// Checks cast delivery on `actor_mesh`, then on a slice covering the first
/// half of its first dimension.
///
/// # Panics
/// Panics if either check fails or the slice cannot be built.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();
    assert_casting_correctness(&actor_mesh, instance, None).await;

    let extent = actor_mesh.extent();
    let first_label = extent.labels()[0].clone();
    let half = extent.sizes()[0] / 2;

    let first_half = actor_mesh.range(&first_label, 0..half).unwrap();
    assert_casting_correctness(&first_half, instance, None).await;
}
450
/// Casts `GetActorId` to every actor in `actor_mesh` and checks that each
/// actor replies exactly once and that nothing else arrives afterwards.
///
/// When `expected_seqs` is `Some((session_id, seqs))`, `seqs[i]` is the
/// expected session sequence number for the i-th actor in `actor_mesh.values()`
/// order. Each reply's `SeqInfo` must then equal `SeqInfo::Session`.
///
/// # Panics
/// Panics in any of these cases:
/// - the cast fails;
/// - `seqs` does not have exactly one entry per actor;
/// - a reply comes from an unexpected (or already-seen) actor;
/// - a reply lacks `SeqInfo`;
/// - a sequence number does not match;
/// - another message is pending one second after all expected replies arrived.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
    expected_seqs: Option<(Uuid, Vec<u64>)>,
) {
    let (port, mut rx) = instance.mailbox().open_port();
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
    let expected_actor_ids = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_addr().clone())
        .collect::<Vec<_>>();
    let mut expected: HashMap<&hyperactor::ActorAddr, Option<SeqInfo>> = match expected_seqs {
        None => expected_actor_ids
            .iter()
            .map(|actor_id| (actor_id, None))
            .collect(),
        Some((session_id, seqs)) => {
            // `zip` would silently drop unmatched entries on a length mismatch.
            assert_eq!(
                seqs.len(),
                expected_actor_ids.len(),
                "expected exactly one seq per actor in the mesh"
            );
            expected_actor_ids
                .iter()
                .zip(
                    seqs.into_iter()
                        .map(|seq| Some(SeqInfo::Session { session_id, seq })),
                )
                .collect()
        }
    };

    while !expected.is_empty() {
        let (actor_id, rcved) = rx.recv().await.unwrap();
        let rcv_seq_info = rcved.unwrap();
        let Some(expected_seq) = expected.remove(&actor_id) else {
            panic!("got {actor_id}, expect {expected_actor_ids:?}");
        };
        if let Some(expected_seq) = expected_seq {
            assert_eq!(expected_seq, rcv_seq_info, "got different seq for {actor_id}");
        }
    }

    // Give stray or duplicate replies time to arrive before checking that none did.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}