1#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::ActorId;
24use hyperactor::ActorRef;
25use hyperactor::Bind;
26use hyperactor::Context;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32use hyperactor::clock::Clock as _;
33use hyperactor::clock::RealClock;
34#[cfg(test)]
35use hyperactor::context;
36#[cfg(test)]
37use hyperactor::mailbox;
38use hyperactor::supervision::ActorSupervisionEvent;
39use hyperactor_config::global::Source;
40use ndslice::Point;
41#[cfg(test)]
42use ndslice::ViewExt as _;
43use serde::Deserialize;
44use serde::Serialize;
45use typeuri::Named;
46
47use crate::comm::multicast::CastInfo;
48use crate::supervision::MeshFailure;
49use crate::v1::ActorMesh;
50#[cfg(test)]
51use crate::v1::ActorMeshRef;
52use crate::v1::Name;
53use crate::v1::ProcMeshRef;
54#[cfg(test)]
55use crate::v1::testing;
56
/// Stateless actor used to exercise actor meshes in tests.
///
/// All of its behavior comes from the message handlers listed in the
/// `handlers` export below; each one has a matching `Handler<_>` impl in this
/// file.
// NOTE(review): `cast = true` presumably makes a message deliverable through
// mesh casting; `Forward` is the only handler exported without it — confirm
// against the `hyperactor::export` documentation.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
        GetConfigAttrs { cast = true },
        SetConfigAttrs { cast = true },
    ]
)]
pub struct TestActor;

// No `Actor` methods are overridden; the trait's defaults apply.
impl Actor for TestActor {}
73
/// Request asking the receiving actor to send its own [`ActorId`] to the
/// enclosed reply port.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
77
/// The kind of deliberate failure a [`CauseSupervisionEvent`] should trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, intended to crash the process with
    /// SIGSEGV.
    SigSEGV,
    /// Exit the whole process with the given exit code.
    ProcessExit(i32),
}
84
/// Message instructing the receiving actor to fail on purpose.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent {
    /// Which failure to trigger.
    pub kind: SupervisionEventType,
    /// When `true`, `WrapperActor` casts this message to its child mesh
    /// instead of failing itself. `TestActor`'s handler does not read it.
    pub send_to_children: bool,
}
92
impl CauseSupervisionEvent {
    /// Triggers the failure selected by `self.kind`. Never returns.
    ///
    /// # Panics
    ///
    /// Panics with `"for testing"` for [`SupervisionEventType::Panic`], and
    /// with `"should have segfaulted"` if execution ever continues past the
    /// null write performed for [`SupervisionEventType::SigSEGV`].
    fn cause_event(&self) -> ! {
        match self.kind {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: deliberately unsound. This test-only path writes
                // through a null pointer on purpose in order to crash the
                // process.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
                // The function is declared `-> !`, so this arm must diverge
                // even if the write above somehow returns.
                panic!("should have segfaulted");
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                std::process::exit(code);
            }
        }
    }
}
114
115#[async_trait]
116impl Handler<GetActorId> for TestActor {
117 async fn handle(
118 &mut self,
119 cx: &Context<Self>,
120 GetActorId(reply): GetActorId,
121 ) -> Result<(), anyhow::Error> {
122 reply.send(cx, cx.self_id().clone())?;
123 Ok(())
124 }
125}
126
#[async_trait]
impl Handler<CauseSupervisionEvent> for TestActor {
    /// Triggers the failure requested by `msg`.
    ///
    /// This handler never returns normally: `cause_event` is declared
    /// `-> !`, so it panics, crashes, or exits the process. The
    /// `send_to_children` flag is not consulted here.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        msg: CauseSupervisionEvent,
    ) -> Result<(), anyhow::Error> {
        msg.cause_event();
    }
}
137
/// Stateless test actor that overrides `handle_supervision_event` in its
/// `Actor` impl and exports a handler for [`ActorSupervisionEvent`] messages.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
146
#[async_trait]
impl Actor for TestActorWithSupervisionHandling {
    /// Logs the supervision event at error level and returns `Ok(true)`.
    ///
    /// NOTE(review): returning `true` presumably marks the event as handled
    /// so that it is not propagated further — confirm against the `Actor`
    /// trait documentation.
    async fn handle_supervision_event(
        &mut self,
        _this: &Instance<Self>,
        event: &ActorSupervisionEvent,
    ) -> Result<bool, anyhow::Error> {
        tracing::error!("supervision event: {:?}", event);
        Ok(true)
    }
}
159
#[async_trait]
impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
    /// Accepts an [`ActorSupervisionEvent`] message and ignores it.
    ///
    /// Always returns `Ok(())`; neither the context nor the message is read.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        _msg: ActorSupervisionEvent,
    ) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
170
/// Stateless test actor whose only exported message is a
/// `std::time::Duration`; its handler sleeps for that long.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [std::time::Duration],
)]
pub struct SleepActor;

// No `Actor` methods are overridden; the trait's defaults apply.
impl Actor for SleepActor {}
181
#[async_trait]
impl Handler<std::time::Duration> for SleepActor {
    /// Sleeps for `duration` using `RealClock`, then returns `Ok(())`.
    ///
    /// The handler does not complete until the sleep has finished.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        duration: std::time::Duration,
    ) -> Result<(), anyhow::Error> {
        RealClock.sleep(duration).await;
        Ok(())
    }
}
193
/// Message that is relayed hop by hop along a chain of `Forward` ports.
///
/// `TestActor`'s handler moves the front of `to_visit` onto `visited` and
/// then sends the message to the port that is at the front afterwards.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be processed, in order.
    // NOTE(review): the front entry is presumably the port this message was
    // delivered on (the handler names it `this`) — confirm with callers.
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Ports already popped from `to_visit`, in the order they were popped.
    pub visited: Vec<PortRef<Forward>>,
}
202
203#[async_trait]
204impl Handler<Forward> for TestActor {
205 async fn handle(
206 &mut self,
207 cx: &Context<Self>,
208 Forward {
209 mut to_visit,
210 mut visited,
211 }: Forward,
212 ) -> Result<(), anyhow::Error> {
213 let Some(this) = to_visit.pop_front() else {
214 anyhow::bail!("unexpected forward chain termination");
215 };
216 visited.push(this);
217 let next = to_visit.front().cloned();
218 anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
219 next.unwrap().send(cx, Forward { to_visit, visited })?;
220 Ok(())
221 }
222}
223
/// Request asking a `TestActor` to report how the message reached it.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    /// Reply port. `TestActor` answers with the tuple
    /// `(cx.cast_point(), cx.bind(), cx.sender().clone())`: the cast point,
    /// a reference to the replying actor, and the id of the sender.
    #[reply]
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
241
242#[async_trait]
243impl Handler<GetCastInfo> for TestActor {
244 async fn handle(
245 &mut self,
246 cx: &Context<Self>,
247 GetCastInfo { cast_info }: GetCastInfo,
248 ) -> Result<(), anyhow::Error> {
249 cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
250 Ok(())
251 }
252}
253
/// Test actor that can never be constructed: its `RemoteSpawn::new`
/// implementation always returns the error `"test failure"`.
#[derive(Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;

// No `Actor` methods are overridden; the trait's defaults apply.
#[async_trait]
impl Actor for FailingCreateTestActor {}
260
261#[async_trait]
262impl hyperactor::RemoteSpawn for FailingCreateTestActor {
263 type Params = ();
264
265 async fn new(
266 _params: Self::Params,
267 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
268 Err(anyhow::anyhow!("test failure"))
269 }
270}
271
/// Message carrying bincode-encoded configuration attributes.
///
/// `TestActor` decodes the bytes and installs the result into the global
/// configuration under `Source::Runtime`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
274
275#[async_trait]
276impl Handler<SetConfigAttrs> for TestActor {
277 async fn handle(
278 &mut self,
279 _cx: &Context<Self>,
280 SetConfigAttrs(attrs): SetConfigAttrs,
281 ) -> Result<(), anyhow::Error> {
282 let attrs = bincode::deserialize(&attrs)?;
283 hyperactor_config::global::set(Source::Runtime, attrs);
284 Ok(())
285 }
286}
287
/// Request asking a `TestActor` to send the bincode-encoded result of
/// `hyperactor_config::global::attrs()` to the enclosed reply port.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub PortRef<Vec<u8>>);
290
291#[async_trait]
292impl Handler<GetConfigAttrs> for TestActor {
293 async fn handle(
294 &mut self,
295 cx: &Context<Self>,
296 GetConfigAttrs(reply): GetConfigAttrs,
297 ) -> Result<(), anyhow::Error> {
298 let attrs = bincode::serialize(&hyperactor_config::global::attrs())?;
299 reply.send(cx, attrs)?;
300 Ok(())
301 }
302}
303
/// Request asking a `WrapperActor` for the next supervision failure observed
/// on its child mesh.
///
/// The reply is `Some(failure)` when one is received, and `None` when the
/// wrapper has no mesh, waiting for the event fails, or 10 seconds elapse.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct NextSupervisionFailure(pub PortRef<Option<MeshFailure>>);
309
/// Test actor that owns a child mesh of `TestActor`s.
///
/// The child mesh is spawned in `init`. The wrapper can forward
/// [`CauseSupervisionEvent`] to its children, relays every [`MeshFailure`] it
/// receives to `supervisor`, and answers [`NextSupervisionFailure`] requests.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CauseSupervisionEvent { cast = true },
        MeshFailure { cast = true },
        NextSupervisionFailure { cast = true },
    ]
)]
pub struct WrapperActor {
    /// Proc mesh on which the child actor mesh is spawned.
    proc_mesh: ProcMeshRef,
    /// The child mesh; `None` until `init` has spawned it.
    mesh: Option<ActorMesh<TestActor>>,
    /// Port to which received `MeshFailure` messages are relayed.
    supervisor: PortRef<MeshFailure>,
    /// Name passed to `spawn_with_name` for the child mesh.
    test_name: Name,
}
329
330#[async_trait]
331impl hyperactor::RemoteSpawn for WrapperActor {
332 type Params = (ProcMeshRef, PortRef<MeshFailure>, Name);
333
334 async fn new(
335 (proc_mesh, supervisor, test_name): Self::Params,
336 ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
337 Ok(Self {
338 proc_mesh,
339 mesh: None,
340 supervisor,
341 test_name,
342 })
343 }
344}
345
346#[async_trait]
347impl Actor for WrapperActor {
348 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
349 self.mesh = Some(
350 self.proc_mesh
351 .spawn_with_name(this, self.test_name.clone(), &(), None, false)
352 .await?,
353 );
354 Ok(())
355 }
356}
357
358#[async_trait]
359impl Handler<CauseSupervisionEvent> for WrapperActor {
360 async fn handle(
361 &mut self,
362 cx: &Context<Self>,
363 msg: CauseSupervisionEvent,
364 ) -> Result<(), anyhow::Error> {
365 if msg.send_to_children {
367 self.mesh
369 .as_ref()
370 .unwrap()
371 .cast(cx, msg)
372 .map_err(|e| e.into())
373 } else {
374 msg.cause_event()
375 }
376 }
377}
378
379#[async_trait]
380impl Handler<NextSupervisionFailure> for WrapperActor {
381 async fn handle(
382 &mut self,
383 cx: &Context<Self>,
384 msg: NextSupervisionFailure,
385 ) -> Result<(), anyhow::Error> {
386 let mesh = if let Some(mesh) = self.mesh.as_ref() {
387 mesh.deref()
388 } else {
389 msg.0.send(cx, None)?;
390 return Ok(());
391 };
392 let failure = match RealClock
393 .timeout(
394 tokio::time::Duration::from_secs(10),
395 mesh.next_supervision_event(cx),
396 )
397 .await
398 {
399 Ok(Ok(failure)) => Some(failure),
400 Ok(Err(_)) => None,
402 Err(_) => None,
404 };
405 msg.0.send(cx, failure)?;
406 Ok(())
407 }
408}
409
410#[async_trait]
411impl Handler<MeshFailure> for WrapperActor {
412 async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
413 tracing::info!("got supervision event from child: {}", msg);
416 let _ = self.supervisor.send(cx, msg.clone());
419 Ok(())
420 }
421}
422
/// Checks casting on the whole mesh and on a slice of it.
///
/// Casting correctness is first verified against `actor_mesh` itself, then
/// against the sub-mesh covering the first half (`0..size / 2`) of the
/// mesh's first dimension.
///
/// # Panics
///
/// Panics if the mesh extent has no dimensions, if slicing fails, or if any
/// of the casting assertions fail.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();

    // Whole mesh first.
    assert_casting_correctness(&actor_mesh, instance).await;

    // Pick the first dimension and keep only its first half. The extent is
    // scoped so it is not held across the await below.
    let (first_label, half) = {
        let extent = actor_mesh.extent();
        (extent.labels()[0].clone(), extent.sizes()[0] / 2)
    };

    // Then the sliced mesh.
    let front_half = actor_mesh.range(&first_label, 0..half).unwrap();
    assert_casting_correctness(&front_half, instance).await;
}
441
/// Casts [`GetActorId`] to `actor_mesh` and asserts that every actor in the
/// mesh replies exactly once and that nothing else replies.
///
/// After all expected ids have arrived, the function waits one second and
/// asserts that the reply port holds no further message.
///
/// # Panics
///
/// Panics if the cast fails, if receiving fails, if an unexpected or
/// duplicate id is received, or if an extra reply shows up after the wait.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
) {
    let (reply_port, mut replies) = mailbox::open_port(instance);
    let request = GetActorId(reply_port.bind());
    actor_mesh.cast(instance, request).unwrap();

    let mut expected_actor_ids: HashSet<_> = actor_mesh
        .values()
        .map(|member| member.actor_id().clone())
        .collect();

    // Each successful iteration removes exactly one id (otherwise the assert
    // panics), so one pass per expected id drains the set.
    for _ in 0..expected_actor_ids.len() {
        let actor_id = replies.recv().await.unwrap();
        assert!(
            expected_actor_ids.remove(&actor_id),
            "got {actor_id}, expect {expected_actor_ids:?}"
        );
    }

    // Give stray replies a chance to arrive, then require that none did.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = replies.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}