hyperactor_mesh/v1/
testactor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines test actors. They are defined in a separate module
//! (outside of [`crate::v1::testing`]) to ensure that they are compiled into
//! the bootstrap binary, which is not built in test mode (and test mode does
//! not work across crate boundaries anyway).
13
14#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::ActorId;
24use hyperactor::ActorRef;
25use hyperactor::Bind;
26use hyperactor::Context;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32use hyperactor::clock::Clock as _;
33use hyperactor::clock::RealClock;
34#[cfg(test)]
35use hyperactor::context;
36#[cfg(test)]
37use hyperactor::mailbox;
38use hyperactor::supervision::ActorSupervisionEvent;
39use hyperactor_config::global::Source;
40use ndslice::Point;
41#[cfg(test)]
42use ndslice::ViewExt as _;
43use serde::Deserialize;
44use serde::Serialize;
45use typeuri::Named;
46
47use crate::comm::multicast::CastInfo;
48use crate::supervision::MeshFailure;
49use crate::v1::ActorMesh;
50#[cfg(test)]
51use crate::v1::ActorMeshRef;
52use crate::v1::Name;
53use crate::v1::ProcMeshRef;
54#[cfg(test)]
55use crate::v1::testing;
56
/// A simple test actor used by various unit tests.
///
/// Exported with `spawn = true`. It handles [`GetActorId`], [`GetCastInfo`],
/// [`CauseSupervisionEvent`], [`GetConfigAttrs`] and [`SetConfigAttrs`]
/// (all registered with `cast = true`), plus [`Forward`] (registered
/// without `cast = true`).
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
        GetConfigAttrs { cast = true },
        SetConfigAttrs { cast = true },
    ]
)]
pub struct TestActor;

// Default `Actor` implementation; all behavior lives in the `Handler` impls.
impl Actor for TestActor {}
73
/// A message that returns the recipient actor's id.
///
/// The recipient replies on the enclosed port with its own [`ActorId`]. The
/// reply port is marked `#[binding(include)]` for the `Bind`/`Unbind` derives.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
77
/// The kind of failure a [`CauseSupervisionEvent`] message triggers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic in the handling actor.
    Panic,
    /// Crash the process with a segmentation fault via an invalid memory write.
    SigSEGV,
    /// Exit the whole process with the given exit code.
    ProcessExit(i32),
}
84
/// A message that causes a supervision event. The `kind` field determines
/// what kind of supervision event it will be.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent {
    /// Which failure to trigger.
    pub kind: SupervisionEventType,
    /// Only consulted by [`WrapperActor`]: when `true` it casts this message to
    /// its child `TestActor` mesh instead of failing itself. `TestActor`
    /// ignores this flag and always fails.
    pub send_to_children: bool,
}
92
93impl CauseSupervisionEvent {
94    fn cause_event(&self) -> ! {
95        match self.kind {
96            SupervisionEventType::Panic => {
97                panic!("for testing");
98            }
99            SupervisionEventType::SigSEGV => {
100                tracing::error!("exiting with SIGSEGV");
101                // SAFETY: This is for testing code that explicitly causes a SIGSEGV.
102                unsafe { std::ptr::null_mut::<i32>().write(42) };
103                // While the above should always segfault, we need a hard exit
104                // for the compiler's sake.
105                panic!("should have segfaulted");
106            }
107            SupervisionEventType::ProcessExit(code) => {
108                tracing::error!("exiting process {} with code {}", std::process::id(), code);
109                std::process::exit(code);
110            }
111        }
112    }
113}
114
115#[async_trait]
116impl Handler<GetActorId> for TestActor {
117    async fn handle(
118        &mut self,
119        cx: &Context<Self>,
120        GetActorId(reply): GetActorId,
121    ) -> Result<(), anyhow::Error> {
122        reply.send(cx, cx.self_id().clone())?;
123        Ok(())
124    }
125}
126
127#[async_trait]
128impl Handler<CauseSupervisionEvent> for TestActor {
129    async fn handle(
130        &mut self,
131        _cx: &Context<Self>,
132        msg: CauseSupervisionEvent,
133    ) -> Result<(), anyhow::Error> {
134        msg.cause_event();
135    }
136}
137
/// A test actor that handles supervision events.
///
/// It is meant to be the parent of a `TestActor` that may panic or cause a
/// SIGSEGV. Supervision events reported to it are logged and treated as
/// handled instead of being propagated.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
146
147#[async_trait]
148impl Actor for TestActorWithSupervisionHandling {
149    async fn handle_supervision_event(
150        &mut self,
151        _this: &Instance<Self>,
152        event: &ActorSupervisionEvent,
153    ) -> Result<bool, anyhow::Error> {
154        tracing::error!("supervision event: {:?}", event);
155        // Swallow the supervision error to avoid crashing the process.
156        Ok(true)
157    }
158}
159
160#[async_trait]
161impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
162    async fn handle(
163        &mut self,
164        _cx: &Context<Self>,
165        _msg: ActorSupervisionEvent,
166    ) -> Result<(), anyhow::Error> {
167        Ok(())
168    }
169}
170
/// A test actor that sleeps when it receives a `Duration` message.
/// Used for testing timeout and abort behavior.
///
/// The sleep is performed with [`RealClock`].
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [std::time::Duration],
)]
pub struct SleepActor;

// Default `Actor` implementation; all behavior lives in the `Duration` handler.
impl Actor for SleepActor {}
181
182#[async_trait]
183impl Handler<std::time::Duration> for SleepActor {
184    async fn handle(
185        &mut self,
186        _cx: &Context<Self>,
187        duration: std::time::Duration,
188    ) -> Result<(), anyhow::Error> {
189        RealClock.sleep(duration).await;
190        Ok(())
191    }
192}
193
/// A message that is forwarded along a list of ports to visit.
///
/// Each `TestActor` that handles it pops the front entry of `to_visit`,
/// appends that entry to `visited`, and sends the message on to the new
/// front of `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be visited. The front entry is popped by the handler
    /// that receives the message (presumably that recipient's own port).
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Ports already visited, in visiting order.
    pub visited: Vec<PortRef<Forward>>,
}
202
203#[async_trait]
204impl Handler<Forward> for TestActor {
205    async fn handle(
206        &mut self,
207        cx: &Context<Self>,
208        Forward {
209            mut to_visit,
210            mut visited,
211        }: Forward,
212    ) -> Result<(), anyhow::Error> {
213        let Some(this) = to_visit.pop_front() else {
214            anyhow::bail!("unexpected forward chain termination");
215        };
216        visited.push(this);
217        let next = to_visit.front().cloned();
218        anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
219        next.unwrap().send(cx, Forward { to_visit, visited })?;
220        Ok(())
221    }
222}
223
/// Requests cast information from the recipient.
///
/// `TestActor` replies with the cast point at which it received the message,
/// its own `ActorRef` (via `cx.bind()`), and the id of the message's sender.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    /// Reply port for `(cast point, receiving actor, sender id)`.
    #[reply]
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
241
242#[async_trait]
243impl Handler<GetCastInfo> for TestActor {
244    async fn handle(
245        &mut self,
246        cx: &Context<Self>,
247        GetCastInfo { cast_info }: GetCastInfo,
248    ) -> Result<(), anyhow::Error> {
249        cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
250        Ok(())
251    }
252}
253
/// A test actor whose `RemoteSpawn::new` always fails with "test failure".
#[derive(Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;

// Default `Actor` implementation; construction never succeeds anyway.
#[async_trait]
impl Actor for FailingCreateTestActor {}
260
261#[async_trait]
262impl hyperactor::RemoteSpawn for FailingCreateTestActor {
263    type Params = ();
264
265    async fn new(
266        _params: Self::Params,
267    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
268        Err(anyhow::anyhow!("test failure"))
269    }
270}
271
/// Sets the recipient's global config from `bincode`-encoded attrs.
///
/// `TestActor` decodes the bytes and installs them with
/// `hyperactor_config::global::set` under `Source::Runtime`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
274
275#[async_trait]
276impl Handler<SetConfigAttrs> for TestActor {
277    async fn handle(
278        &mut self,
279        _cx: &Context<Self>,
280        SetConfigAttrs(attrs): SetConfigAttrs,
281    ) -> Result<(), anyhow::Error> {
282        let attrs = bincode::deserialize(&attrs)?;
283        hyperactor_config::global::set(Source::Runtime, attrs);
284        Ok(())
285    }
286}
287
/// Requests the recipient's current global config attrs.
///
/// `TestActor` replies with `hyperactor_config::global::attrs()` encoded with
/// `bincode`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub PortRef<Vec<u8>>);
290
291#[async_trait]
292impl Handler<GetConfigAttrs> for TestActor {
293    async fn handle(
294        &mut self,
295        cx: &Context<Self>,
296        GetConfigAttrs(reply): GetConfigAttrs,
297    ) -> Result<(), anyhow::Error> {
298        let attrs = bincode::serialize(&hyperactor_config::global::attrs())?;
299        reply.send(cx, attrs)?;
300        Ok(())
301    }
302}
303
/// A message to request the next supervision event from `WrapperActor`'s child
/// mesh.
///
/// Replies with `None` in any of these cases:
/// - no supervision event arrives within a timeout (10 seconds);
/// - waiting for the event returns an error;
/// - the wrapper has no child mesh.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct NextSupervisionFailure(pub PortRef<Option<MeshFailure>>);
309
/// A small wrapper that handles supervision messages so they don't need to
/// reach the client.
///
/// On `init` it spawns a `TestActor` mesh as its child. `CauseSupervisionEvent`
/// messages with `send_to_children` set are cast to that mesh; otherwise the
/// wrapper fails itself. Supervision events (`MeshFailure`) from the children
/// are forwarded to the `supervisor` port.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CauseSupervisionEvent { cast = true },
        MeshFailure { cast = true },
        NextSupervisionFailure { cast = true },
    ]
)]
pub struct WrapperActor {
    /// Proc mesh on which the child `TestActor` mesh is spawned.
    proc_mesh: ProcMeshRef,
    /// The child `TestActor` mesh; `None` until `init` has run. It needs to be
    /// an owned mesh so this actor owns the children and has a controller for
    /// them.
    mesh: Option<ActorMesh<TestActor>>,
    /// Port that receives every `MeshFailure` this actor handles.
    supervisor: PortRef<MeshFailure>,
    /// Name used when spawning the child mesh.
    test_name: Name,
}
329
330#[async_trait]
331impl hyperactor::RemoteSpawn for WrapperActor {
332    type Params = (ProcMeshRef, PortRef<MeshFailure>, Name);
333
334    async fn new(
335        (proc_mesh, supervisor, test_name): Self::Params,
336    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
337        Ok(Self {
338            proc_mesh,
339            mesh: None,
340            supervisor,
341            test_name,
342        })
343    }
344}
345
346#[async_trait]
347impl Actor for WrapperActor {
348    async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
349        self.mesh = Some(
350            self.proc_mesh
351                .spawn_with_name(this, self.test_name.clone(), &(), None, false)
352                .await?,
353        );
354        Ok(())
355    }
356}
357
358#[async_trait]
359impl Handler<CauseSupervisionEvent> for WrapperActor {
360    async fn handle(
361        &mut self,
362        cx: &Context<Self>,
363        msg: CauseSupervisionEvent,
364    ) -> Result<(), anyhow::Error> {
365        // No reply to wait for.
366        if msg.send_to_children {
367            // Send only to children, don't cause the event itself.
368            self.mesh
369                .as_ref()
370                .unwrap()
371                .cast(cx, msg)
372                .map_err(|e| e.into())
373        } else {
374            msg.cause_event()
375        }
376    }
377}
378
379#[async_trait]
380impl Handler<NextSupervisionFailure> for WrapperActor {
381    async fn handle(
382        &mut self,
383        cx: &Context<Self>,
384        msg: NextSupervisionFailure,
385    ) -> Result<(), anyhow::Error> {
386        let mesh = if let Some(mesh) = self.mesh.as_ref() {
387            mesh.deref()
388        } else {
389            msg.0.send(cx, None)?;
390            return Ok(());
391        };
392        let failure = match RealClock
393            .timeout(
394                tokio::time::Duration::from_secs(10),
395                mesh.next_supervision_event(cx),
396            )
397            .await
398        {
399            Ok(Ok(failure)) => Some(failure),
400            // Any error in next_supervision_event is treated the same.
401            Ok(Err(_)) => None,
402            // If we timeout, send back None.
403            Err(_) => None,
404        };
405        msg.0.send(cx, failure)?;
406        Ok(())
407    }
408}
409
410#[async_trait]
411impl Handler<MeshFailure> for WrapperActor {
412    async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
413        // All supervision events are considered handled so they don't bubble up
414        // to the client (who isn't listening for MeshFailure).
415        tracing::info!("got supervision event from child: {}", msg);
416        // Send to a port so the client can view the messages.
417        // Ignore the error if there is one.
418        let _ = self.supervisor.send(cx, msg.clone());
419        Ok(())
420    }
421}
422
/// Asserts that casting to `actor_mesh` reaches every actor in it. The check
/// runs on the whole mesh and on a slice covering the first half of its first
/// dimension.
///
/// # Panics
///
/// Panics if any casting check fails, if the mesh has no dimensions, or if
/// slicing fails. When the first dimension has size 1 the slice range is
/// `0..0`; whether `range` accepts an empty range should be confirmed.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();
    // Verify casting to the root actor mesh.
    assert_casting_correctness(&actor_mesh, instance).await;

    // Just pick the first dimension and slice half of it off.
    let extent = actor_mesh.extent();
    let label = extent.labels()[0].clone();
    let half = extent.sizes()[0] / 2;

    // Verify casting to the sliced actor mesh.
    let sliced_actor_mesh = actor_mesh.range(&label, 0..half).unwrap();
    assert_casting_correctness(&sliced_actor_mesh, instance).await;
}
441
/// Casts [`GetActorId`] to every actor in `actor_mesh` and asserts that each
/// actor replies exactly once with its own id.
///
/// After all expected replies arrive, waits one second and asserts that no
/// further replies are queued on the reply port.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the cast fails;
/// - the reply port returns an error;
/// - a reply comes from an actor outside the mesh, or duplicates an earlier one;
/// - anything other than "no message" is observed after the expected replies.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
) {
    let (port, mut rx) = mailbox::open_port(instance);
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();

    let mut expected_actor_ids: HashSet<_> = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect();

    while !expected_actor_ids.is_empty() {
        let actor_id = rx.recv().await.unwrap();
        assert!(
            expected_actor_ids.remove(&actor_id),
            "got {actor_id}, expect {expected_actor_ids:?}"
        );
    }

    // No more messages. Match on the whole result so that an `Err` from
    // `try_recv` is reported with the same diagnostic instead of a bare
    // unwrap panic.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    assert!(matches!(result, Ok(None)), "got {result:?}");
}