hyperactor_mesh/v1/
testactor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a test actor. It is defined in a separate module
//! (outside of [`crate::v1::testing`]) to ensure that it is compiled into
//! the bootstrap binary, which is not built in test mode (and, in any case,
//! test mode does not work across crate boundaries).
13
14#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17#[cfg(test)]
18use std::time::Duration;
19
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::Named;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32#[cfg(test)]
33use hyperactor::clock::Clock as _;
34#[cfg(test)]
35use hyperactor::clock::RealClock;
36#[cfg(test)]
37use hyperactor::mailbox;
38use hyperactor::supervision::ActorSupervisionEvent;
39use ndslice::Point;
40#[cfg(test)]
41use ndslice::ViewExt as _;
42use serde::Deserialize;
43use serde::Serialize;
44
45use crate::comm::multicast::CastInfo;
46#[cfg(test)]
47use crate::v1::ActorMesh;
48#[cfg(test)]
49use crate::v1::ActorMeshRef;
50#[cfg(test)]
51use crate::v1::testing;
52
/// A simple test actor used by various unit tests.
///
/// Exported with `spawn = true`. It handles [`GetActorId`], [`GetCastInfo`]
/// and [`CauseSupervisionEvent`] (each marked `cast = true`), plus
/// [`Forward`].
#[derive(Actor, Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
    ]
)]
pub struct TestActor;
65
/// A message that returns the recipient actor's id.
///
/// The receiving [`TestActor`] sends its own [`ActorId`] to the enclosed port.
/// The port is marked `#[binding(include)]`; presumably this lets it be
/// rebound when the message is cast to a mesh (TODO confirm).
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
69
/// The kind of failure that a [`CauseSupervisionEvent`] message makes the
/// receiving [`TestActor`] trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, intended to crash with SIGSEGV.
    SigSEGV,
    /// Exit the whole process via `std::process::exit` with the given code.
    ProcessExit(i32),
}
76
/// A message that makes the receiving [`TestActor`] fail, causing a
/// supervision event. The wrapped [`SupervisionEventType`] selects how it
/// fails.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent(pub SupervisionEventType);
81
82#[async_trait]
83impl Handler<GetActorId> for TestActor {
84    async fn handle(
85        &mut self,
86        cx: &Context<Self>,
87        GetActorId(reply): GetActorId,
88    ) -> Result<(), anyhow::Error> {
89        reply.send(cx, cx.self_id().clone())?;
90        Ok(())
91    }
92}
93
/// Deliberately makes this actor (or its process) fail in the way selected by
/// the message's [`SupervisionEventType`].
#[async_trait]
impl Handler<CauseSupervisionEvent> for TestActor {
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        msg: CauseSupervisionEvent,
    ) -> Result<(), anyhow::Error> {
        match msg.0 {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                // SAFETY: This is for testing code that explicitly causes a SIGSEGV.
                // NOTE(review): a write through a null pointer is undefined
                // behavior, not a guaranteed fault. Optimized builds may turn
                // this store into a trap (e.g. SIGILL) or elide it. Consider
                // `std::ptr::write_volatile` or raising SIGSEGV explicitly (TODO confirm).
                unsafe { std::ptr::null_mut::<i32>().write(42) };
            }
            SupervisionEventType::ProcessExit(code) => {
                // Terminates the entire process, not just this actor.
                std::process::exit(code);
            }
        }
        // The other arms diverge, so this is only reached if the SigSEGV
        // arm's write did not fault.
        Ok(())
    }
}
116
/// A test actor that handles supervision events.
///
/// It is meant to be the parent of a [`TestActor`] that may panic or cause a
/// SIGSEGV. Its `handle_supervision_event` logs each event and reports it as
/// handled, so the child's failure is swallowed.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
125
126#[async_trait]
127impl Actor for TestActorWithSupervisionHandling {
128    type Params = ();
129
130    async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
131        Ok(Self {})
132    }
133
134    async fn handle_supervision_event(
135        &mut self,
136        _this: &Instance<Self>,
137        event: &ActorSupervisionEvent,
138    ) -> Result<bool, anyhow::Error> {
139        tracing::error!("supervision event: {:?}", event);
140        // Swallow the supervision error to avoid crashing the process.
141        Ok(true)
142    }
143}
144
145#[async_trait]
146impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
147    async fn handle(
148        &mut self,
149        _cx: &Context<Self>,
150        _msg: ActorSupervisionEvent,
151    ) -> Result<(), anyhow::Error> {
152        Ok(())
153    }
154}
155
/// A message that is forwarded along a list of ports.
///
/// When a [`TestActor`] receives it, it pops the front entry of `to_visit`,
/// appends that entry to `visited`, and sends the message on to the new front
/// of `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be visited. The front entry is treated by the handler as
    /// the port this message was just delivered to.
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Ports that have already handled this message, in visit order.
    pub visited: Vec<PortRef<Forward>>,
}
164
165#[async_trait]
166impl Handler<Forward> for TestActor {
167    async fn handle(
168        &mut self,
169        cx: &Context<Self>,
170        Forward {
171            mut to_visit,
172            mut visited,
173        }: Forward,
174    ) -> Result<(), anyhow::Error> {
175        let Some(this) = to_visit.pop_front() else {
176            anyhow::bail!("unexpected forward chain termination");
177        };
178        visited.push(this);
179        let next = to_visit.front().cloned();
180        anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
181        next.unwrap().send(cx, Forward { to_visit, visited })?;
182        Ok(())
183    }
184}
185
/// Asks the recipient to report how this (cast) message reached it.
///
/// The receiving [`TestActor`] replies with its own cast point, a reference to
/// itself, and the id of the actor that sent the message.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    /// Reply port receiving `(recipient's cast point, recipient actor ref,
    /// sender actor id)`.
    #[reply]
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
203
204#[async_trait]
205impl Handler<GetCastInfo> for TestActor {
206    async fn handle(
207        &mut self,
208        cx: &Context<Self>,
209        GetCastInfo { cast_info }: GetCastInfo,
210    ) -> Result<(), anyhow::Error> {
211        cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
212        Ok(())
213    }
214}
215
/// A test actor whose construction always fails.
///
/// Its `Actor::new` unconditionally returns an error. Presumably it is used to
/// exercise spawn-failure handling (TODO confirm against callers).
#[derive(Default, Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;
219
220#[async_trait]
221impl Actor for FailingCreateTestActor {
222    type Params = ();
223
224    async fn new(params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
225        Err(anyhow::anyhow!("test failure"))
226    }
227}
228
/// Asserts that a cast to the provided actor mesh reaches every actor in it,
/// and that the same holds for a slice of the mesh.
///
/// The slice keeps the first half (rounded down) of the mesh's first
/// dimension. Shape and ranks are not checked directly; they are only
/// exercised through the casts.
///
/// # Panics
///
/// Panics if either cast check fails, if the mesh's extent has no
/// dimensions, or if slicing the mesh fails.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance().await;

    // Verify casting to the root actor mesh.
    assert_casting_correctness(&actor_mesh, instance).await;

    // Pick the first dimension and keep its first half.
    let extent = actor_mesh.extent();
    let label = extent.labels()[0].clone();
    let size = extent.sizes()[0] / 2;

    // Verify casting to the sliced actor mesh.
    let sliced_actor_mesh = actor_mesh.range(&label, 0..size).unwrap();
    assert_casting_correctness(&sliced_actor_mesh, instance).await;
}
247
/// Casts [`GetActorId`] to `actor_mesh` and checks that every actor in the
/// mesh replies exactly once with its own id.
///
/// After all expected replies have arrived, it waits one second and asserts
/// that no further reply is pending. It waits on each reply without a
/// timeout, so an actor that never answers makes this hang.
///
/// # Panics
///
/// Panics if the cast fails, if the reply channel fails, if a reply comes
/// from an unexpected or already-seen actor, or if an extra reply arrives.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &Instance<()>,
) {
    let (reply_port, mut reply_rx) = mailbox::open_port(instance);
    let request = GetActorId(reply_port.bind());
    actor_mesh.cast(instance, request).unwrap();

    let mut expected_actor_ids = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect::<HashSet<_>>();

    // Each non-panicking iteration removes exactly one id, so after
    // `outstanding` iterations the set has drained completely.
    let outstanding = expected_actor_ids.len();
    for _ in 0..outstanding {
        let actor_id = reply_rx.recv().await.unwrap();
        assert!(
            expected_actor_ids.remove(&actor_id),
            "got {actor_id}, expect {expected_actor_ids:?}"
        );
    }

    // Give any stray replies time to arrive, then make sure there are none.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = reply_rx.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}