1#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17#[cfg(test)]
18use std::time::Duration;
19
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::Named;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32#[cfg(test)]
33use hyperactor::clock::Clock as _;
34#[cfg(test)]
35use hyperactor::clock::RealClock;
36use hyperactor::config;
37use hyperactor::config::global::Source;
38#[cfg(test)]
39use hyperactor::mailbox;
40use hyperactor::supervision::ActorSupervisionEvent;
41use ndslice::Point;
42#[cfg(test)]
43use ndslice::ViewExt as _;
44use serde::Deserialize;
45use serde::Serialize;
46
47use crate::comm::multicast::CastInfo;
48#[cfg(test)]
49use crate::v1::ActorMesh;
50#[cfg(test)]
51use crate::v1::ActorMeshRef;
52#[cfg(test)]
53use crate::v1::testing;
54
/// A stateless actor (it has no fields) used by the v1 mesh tests.
///
/// It is remotely spawnable (`spawn = true`) and exports a handler for every
/// test message defined in this module. All handlers except [`Forward`] are
/// registered with `cast = true`, which presumably allows them to be delivered
/// via mesh casts (confirm against the hyperactor `export` documentation).
#[derive(Actor, Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
        GetConfigAttrs { cast = true },
        SetConfigAttrs { cast = true },
    ]
)]
pub struct TestActor;
69
/// Asks a [`TestActor`] to reply with its own [`ActorId`] on the enclosed port.
///
/// `#[binding(include)]` presumably marks the port for (re)binding by the
/// `Bind`/`Unbind` derives when the message is cast — confirm in hyperactor docs.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
73
/// The kind of failure a [`CauseSupervisionEvent`] message asks the actor to trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, intended to crash the process with SIGSEGV.
    SigSEGV,
    /// Terminate the whole process via `std::process::exit` with the given code.
    ProcessExit(i32),
}
80
/// Instructs a [`TestActor`] to fail in the way described by the wrapped
/// [`SupervisionEventType`].
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent(pub SupervisionEventType);
85
86#[async_trait]
87impl Handler<GetActorId> for TestActor {
88 async fn handle(
89 &mut self,
90 cx: &Context<Self>,
91 GetActorId(reply): GetActorId,
92 ) -> Result<(), anyhow::Error> {
93 reply.send(cx, cx.self_id().clone())?;
94 Ok(())
95 }
96}
97
#[async_trait]
impl Handler<CauseSupervisionEvent> for TestActor {
    /// Deliberately fails in the way selected by the message, intended to
    /// trigger a supervision event:
    ///
    /// - `Panic`: panics inside this handler.
    /// - `SigSEGV`: logs, then writes through a null pointer (intended to crash
    ///   the process with SIGSEGV).
    /// - `ProcessExit(code)`: logs, then exits the whole process with `code`.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        msg: CauseSupervisionEvent,
    ) -> Result<(), anyhow::Error> {
        match msg.0 {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: none — this is intentionally undefined behavior used to
                // crash the process. NOTE(review): because a null write is UB, an
                // optimizing build is not guaranteed to emit a faulting store;
                // consider raising SIGSEGV explicitly if this mode must be reliable.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                std::process::exit(code);
            }
        }
        // The other arms diverge (`panic!` / `exit`), so this is only reached if
        // the SigSEGV write above did not actually fault.
        Ok(())
    }
}
122
/// Test actor that overrides `Actor::handle_supervision_event`: it logs each
/// supervision event at error level and returns `Ok(true)`.
///
/// It also exports a no-op handler for [`ActorSupervisionEvent`] messages.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
131
132#[async_trait]
133impl Actor for TestActorWithSupervisionHandling {
134 type Params = ();
135
136 async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
137 Ok(Self {})
138 }
139
140 async fn handle_supervision_event(
141 &mut self,
142 _this: &Instance<Self>,
143 event: &ActorSupervisionEvent,
144 ) -> Result<bool, anyhow::Error> {
145 tracing::error!("supervision event: {:?}", event);
146 Ok(true)
148 }
149}
150
151#[async_trait]
152impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
153 async fn handle(
154 &mut self,
155 _cx: &Context<Self>,
156 _msg: ActorSupervisionEvent,
157 ) -> Result<(), anyhow::Error> {
158 Ok(())
159 }
160}
161
/// A message relayed hop-by-hop along a chain of ports.
///
/// Each [`TestActor`] that receives it moves the front of `to_visit` (its own
/// hop) onto `visited` and sends the message to the new front of `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Remaining hops; the front entry is the port currently receiving the message.
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Hops already taken, in order.
    pub visited: Vec<PortRef<Forward>>,
}
170
171#[async_trait]
172impl Handler<Forward> for TestActor {
173 async fn handle(
174 &mut self,
175 cx: &Context<Self>,
176 Forward {
177 mut to_visit,
178 mut visited,
179 }: Forward,
180 ) -> Result<(), anyhow::Error> {
181 let Some(this) = to_visit.pop_front() else {
182 anyhow::bail!("unexpected forward chain termination");
183 };
184 visited.push(this);
185 let next = to_visit.front().cloned();
186 anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
187 next.unwrap().send(cx, Forward { to_visit, visited })?;
188 Ok(())
189 }
190}
191
/// Asks a [`TestActor`] how a cast message reached it.
///
/// The handler replies with the actor's cast [`Point`], a bound
/// [`ActorRef`] to itself, and the [`ActorId`] of the message's sender.
/// The `Handler`/`RefClient` derives presumably generate a handler trait and
/// client helpers for this message — confirm against hyperactor docs.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    // `#[reply]` presumably designates this field as the reply port for the
    // generated code — confirm against hyperactor docs.
    #[reply]
    /// Receives `(cast point, bound self reference, sender id)`.
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
209
210#[async_trait]
211impl Handler<GetCastInfo> for TestActor {
212 async fn handle(
213 &mut self,
214 cx: &Context<Self>,
215 GetCastInfo { cast_info }: GetCastInfo,
216 ) -> Result<(), anyhow::Error> {
217 cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
218 Ok(())
219 }
220}
221
/// Test actor whose `Actor::new` always returns an error, for exercising the
/// actor-creation failure path.
#[derive(Default, Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;
225
226#[async_trait]
227impl Actor for FailingCreateTestActor {
228 type Params = ();
229
230 async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
231 Err(anyhow::anyhow!("test failure"))
232 }
233}
234
/// Replaces the receiving process's runtime configuration.
///
/// The payload is bincode-encoded configuration attributes, in the same
/// encoding that [`GetConfigAttrs`] replies with.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
237
238#[async_trait]
239impl Handler<SetConfigAttrs> for TestActor {
240 async fn handle(
241 &mut self,
242 _cx: &Context<Self>,
243 SetConfigAttrs(attrs): SetConfigAttrs,
244 ) -> Result<(), anyhow::Error> {
245 let attrs = bincode::deserialize(&attrs)?;
246 config::global::set(Source::Runtime, attrs);
247 Ok(())
248 }
249}
250
/// Requests the receiving process's current global configuration.
///
/// The reply is the bincode encoding of `config::global::attrs()`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub PortRef<Vec<u8>>);
253
254#[async_trait]
255impl Handler<GetConfigAttrs> for TestActor {
256 async fn handle(
257 &mut self,
258 cx: &Context<Self>,
259 GetConfigAttrs(reply): GetConfigAttrs,
260 ) -> Result<(), anyhow::Error> {
261 let attrs = bincode::serialize(&config::global::attrs())?;
262 reply.send(cx, attrs)?;
263 Ok(())
264 }
265}
266
/// Checks cast delivery on `actor_mesh` and on a slice of it.
///
/// Runs [`assert_casting_correctness`] on the full mesh, then on the sub-mesh
/// formed by the first half of the first dimension (at least one element).
///
/// # Panics
///
/// Panics if the mesh has no dimensions, if slicing fails, or if either
/// casting check fails.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance().await;
    assert_casting_correctness(&actor_mesh, instance).await;

    let extent = actor_mesh.extent();
    let label = extent
        .labels()
        .first()
        .cloned()
        .expect("mesh must have at least one dimension");
    // Halve the first dimension, but never below one element: a dimension of
    // size 1 would otherwise yield the empty range `0..0`.
    let size = (extent.sizes()[0] / 2).max(1);

    let sliced_actor_mesh = actor_mesh
        .range(&label, 0..size)
        .expect("failed to slice mesh along its first dimension");
    assert_casting_correctness(&sliced_actor_mesh, instance).await;
}
285
/// Casts [`GetActorId`] to `actor_mesh` and asserts every actor replies exactly once.
///
/// The expected set is the ids of all actors in the mesh. Each reply must
/// remove a still-expected id. After all expected replies arrive, the function
/// waits one second and asserts that no further reply is queued.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the cast or a receive fails;
/// - a reply comes from an unexpected or already-seen actor;
/// - an extra reply arrives after the wait.
///
/// NOTE(review): receiving has no timeout, so a missing reply hangs rather
/// than fails.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &Instance<()>,
) {
    let (reply_port, mut receiver) = mailbox::open_port(instance);
    actor_mesh
        .cast(instance, GetActorId(reply_port.bind()))
        .unwrap();

    let mut expected_actor_ids = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect::<HashSet<_>>();

    // Exactly one reply per expected id; each must match a still-pending id.
    let pending = expected_actor_ids.len();
    for _ in 0..pending {
        let actor_id = receiver.recv().await.unwrap();
        let was_expected = expected_actor_ids.remove(&actor_id);
        assert!(was_expected, "got {actor_id}, expect {expected_actor_ids:?}");
    }

    // Allow time for any stray extra reply to arrive, then require none did.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = receiver.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}
312}