1#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17#[cfg(test)]
18use std::time::Duration;
19
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::Named;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32#[cfg(test)]
33use hyperactor::clock::Clock as _;
34#[cfg(test)]
35use hyperactor::clock::RealClock;
36#[cfg(test)]
37use hyperactor::mailbox;
38use hyperactor::supervision::ActorSupervisionEvent;
39use ndslice::Point;
40#[cfg(test)]
41use ndslice::ViewExt as _;
42use serde::Deserialize;
43use serde::Serialize;
44
45use crate::comm::multicast::CastInfo;
46#[cfg(test)]
47use crate::v1::ActorMesh;
48#[cfg(test)]
49use crate::v1::ActorMeshRef;
50#[cfg(test)]
51use crate::v1::testing;
52
/// Stateless actor used by mesh tests.
///
/// It replies with its own id ([`GetActorId`]), reports how it received a
/// message ([`GetCastInfo`]), deliberately fails on request
/// ([`CauseSupervisionEvent`]), and relays [`Forward`] messages along a chain
/// of ports.
///
/// The first three handlers are exported with `cast = true`; `Forward` is
/// exported without it.
#[derive(Actor, Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
    ]
)]
pub struct TestActor;
65
/// Asks the receiving [`TestActor`] to send its own [`ActorId`] back on the
/// enclosed port.
///
/// The port is marked `#[binding(include)]`, so it takes part in the
/// `Bind`/`Unbind` derivations (presumably so it survives being cast).
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
69
/// The way a [`TestActor`] should fail when it receives a
/// [`CauseSupervisionEvent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, intended to crash the process with SIGSEGV.
    SigSEGV,
    /// Terminate the whole process with `std::process::exit` and this code.
    ProcessExit(i32),
}
76
/// Instructs a [`TestActor`] to fail in the manner described by the wrapped
/// [`SupervisionEventType`], so that tests can observe the resulting
/// supervision event.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent(pub SupervisionEventType);
81
82#[async_trait]
83impl Handler<GetActorId> for TestActor {
84 async fn handle(
85 &mut self,
86 cx: &Context<Self>,
87 GetActorId(reply): GetActorId,
88 ) -> Result<(), anyhow::Error> {
89 reply.send(cx, cx.self_id().clone())?;
90 Ok(())
91 }
92}
93
94#[async_trait]
95impl Handler<CauseSupervisionEvent> for TestActor {
96 async fn handle(
97 &mut self,
98 _cx: &Context<Self>,
99 msg: CauseSupervisionEvent,
100 ) -> Result<(), anyhow::Error> {
101 match msg.0 {
102 SupervisionEventType::Panic => {
103 panic!("for testing");
104 }
105 SupervisionEventType::SigSEGV => {
106 unsafe { std::ptr::null_mut::<i32>().write(42) };
108 }
109 SupervisionEventType::ProcessExit(code) => {
110 std::process::exit(code);
111 }
112 }
113 Ok(())
114 }
115}
116
/// Test actor that overrides `handle_supervision_event`: it logs every
/// supervision event it is notified of and returns `Ok(true)`.
///
/// It also exports a no-op handler for [`ActorSupervisionEvent`] delivered
/// as an ordinary message.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
125
126#[async_trait]
127impl Actor for TestActorWithSupervisionHandling {
128 type Params = ();
129
130 async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
131 Ok(Self {})
132 }
133
134 async fn handle_supervision_event(
135 &mut self,
136 _this: &Instance<Self>,
137 event: &ActorSupervisionEvent,
138 ) -> Result<bool, anyhow::Error> {
139 tracing::error!("supervision event: {:?}", event);
140 Ok(true)
142 }
143}
144
145#[async_trait]
146impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
147 async fn handle(
148 &mut self,
149 _cx: &Context<Self>,
150 _msg: ActorSupervisionEvent,
151 ) -> Result<(), anyhow::Error> {
152 Ok(())
153 }
154}
155
/// A message that hops along a chain of ports.
///
/// A [`TestActor`] that handles it moves the front of `to_visit` onto
/// `visited`, then sends the message to the new front of `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be visited. The handler treats the front entry as its
    /// own port and the entry after it as the next hop.
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Ports that have already handled the message, in visiting order.
    pub visited: Vec<PortRef<Forward>>,
}
164
165#[async_trait]
166impl Handler<Forward> for TestActor {
167 async fn handle(
168 &mut self,
169 cx: &Context<Self>,
170 Forward {
171 mut to_visit,
172 mut visited,
173 }: Forward,
174 ) -> Result<(), anyhow::Error> {
175 let Some(this) = to_visit.pop_front() else {
176 anyhow::bail!("unexpected forward chain termination");
177 };
178 visited.push(this);
179 let next = to_visit.front().cloned();
180 anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
181 next.unwrap().send(cx, Forward { to_visit, visited })?;
182 Ok(())
183 }
184}
185
/// Asks a [`TestActor`] to describe how it received this message.
///
/// The handler replies with the actor's cast point, a reference to the actor
/// itself, and the [`ActorId`] reported by `cx.sender()`.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    // Reply port for the `(cast point, self reference, sender id)` tuple.
    #[reply]
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
203
204#[async_trait]
205impl Handler<GetCastInfo> for TestActor {
206 async fn handle(
207 &mut self,
208 cx: &Context<Self>,
209 GetCastInfo { cast_info }: GetCastInfo,
210 ) -> Result<(), anyhow::Error> {
211 cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
212 Ok(())
213 }
214}
215
/// Test actor whose constructor always returns an error.
///
/// Presumably used to exercise spawn-failure handling; confirm against its
/// callers.
#[derive(Default, Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;
219
220#[async_trait]
221impl Actor for FailingCreateTestActor {
222 type Params = ();
223
224 async fn new(params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
225 Err(anyhow::anyhow!("test failure"))
226 }
227}
228
/// Checks cast delivery on `actor_mesh`, then on the sub-mesh made of the
/// first half (rounded down) of its first dimension.
///
/// # Panics
///
/// Panics if any casting assertion fails or if the slice cannot be taken.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance().await;

    // The whole mesh first.
    assert_casting_correctness(&actor_mesh, instance).await;

    // Then the lower half of the first dimension.
    let extent = actor_mesh.extent();
    let first_label = extent.labels()[0].clone();
    let half = extent.sizes()[0] / 2;
    let lower_half = actor_mesh.range(&first_label, 0..half).unwrap();
    assert_casting_correctness(&lower_half, instance).await;
}
247
/// Casts [`GetActorId`] to every actor in `actor_mesh` and asserts that each
/// actor replies exactly once.
///
/// After all expected replies arrive, it waits one second and asserts that
/// the reply port holds nothing further. This catches duplicate or stray
/// deliveries.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the cast fails;
/// - the reply channel fails;
/// - a reply comes from an unexpected or already-seen actor;
/// - after the grace period, the port holds anything other than "empty",
///   including a receive error.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &Instance<()>,
) {
    let (port, mut rx) = mailbox::open_port(instance);
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();

    let mut expected_actor_ids: HashSet<_> = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect();

    while !expected_actor_ids.is_empty() {
        let actor_id = rx.recv().await.unwrap();
        assert!(
            expected_actor_ids.remove(&actor_id),
            "got {actor_id}, expect {expected_actor_ids:?}"
        );
    }

    // Give any duplicate deliveries time to arrive before checking for them.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    // Match `Ok(None)` directly. A receive error then fails with the same
    // `got …` diagnostic instead of a bare `unwrap` panic.
    assert!(matches!(result, Ok(None)), "got {result:?}");
}
274}