hyperactor_mesh/v1/
testactor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines actors used by tests. They are defined in a separate
//! module (outside of [`crate::v1::testing`]) so that they are compiled into
//! the bootstrap binary, which is not built in test mode (and, in any case,
//! `#[cfg(test)]` code is not visible across crate boundaries).
13
14#[cfg(test)]
15use std::collections::HashSet;
16use std::collections::VecDeque;
17#[cfg(test)]
18use std::time::Duration;
19
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::Named;
29use hyperactor::PortRef;
30use hyperactor::RefClient;
31use hyperactor::Unbind;
32#[cfg(test)]
33use hyperactor::clock::Clock as _;
34#[cfg(test)]
35use hyperactor::clock::RealClock;
36use hyperactor::config;
37use hyperactor::config::global::Source;
38#[cfg(test)]
39use hyperactor::mailbox;
40use hyperactor::supervision::ActorSupervisionEvent;
41use ndslice::Point;
42#[cfg(test)]
43use ndslice::ViewExt as _;
44use serde::Deserialize;
45use serde::Serialize;
46
47use crate::comm::multicast::CastInfo;
48#[cfg(test)]
49use crate::v1::ActorMesh;
50#[cfg(test)]
51use crate::v1::ActorMeshRef;
52#[cfg(test)]
53use crate::v1::testing;
54
55/// A simple test actor used by various unit tests.
/// A simple test actor used by various unit tests.
///
/// It handles [`GetActorId`], [`GetCastInfo`], [`CauseSupervisionEvent`],
/// [`Forward`], [`GetConfigAttrs`] and [`SetConfigAttrs`]. All of these
/// handlers except `Forward` are exported with `cast = true`.
#[derive(Actor, Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetActorId { cast = true },
        GetCastInfo { cast = true },
        CauseSupervisionEvent { cast = true },
        Forward,
        GetConfigAttrs { cast = true },
        SetConfigAttrs { cast = true },
    ]
)]
pub struct TestActor;
69
/// A message that returns the recipient actor's id.
///
/// The recipient replies on the enclosed port with its own [`ActorId`].
/// The port is marked `#[binding(include)]` for the `Bind`/`Unbind` derives.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
73
/// Selects how the recipient of a [`CauseSupervisionEvent`] deliberately fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the message handler.
    Panic,
    /// Write through a null pointer, in an attempt to crash the process with SIGSEGV.
    SigSEGV,
    /// Terminate the whole process via `std::process::exit` with the given code.
    ProcessExit(i32),
}
80
/// A message that makes the recipient fail on purpose, producing a supervision
/// event. The [`SupervisionEventType`] argument selects how it fails.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent(pub SupervisionEventType);
85
86#[async_trait]
87impl Handler<GetActorId> for TestActor {
88    async fn handle(
89        &mut self,
90        cx: &Context<Self>,
91        GetActorId(reply): GetActorId,
92    ) -> Result<(), anyhow::Error> {
93        reply.send(cx, cx.self_id().clone())?;
94        Ok(())
95    }
96}
97
#[async_trait]
impl Handler<CauseSupervisionEvent> for TestActor {
    /// Deliberately fails in the way selected by the message:
    ///
    /// - [`SupervisionEventType::Panic`]: panics inside the handler.
    /// - [`SupervisionEventType::SigSEGV`]: logs, then writes through a null
    ///   pointer in an attempt to crash the process with SIGSEGV.
    /// - [`SupervisionEventType::ProcessExit`]: logs the process id and exit
    ///   code, then terminates the whole process via `std::process::exit`.
    ///
    /// The `Panic` and `ProcessExit` arms never return. The trailing `Ok(())`
    /// is reached only if the null-pointer write in the `SigSEGV` arm does not
    /// terminate the process.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        msg: CauseSupervisionEvent,
    ) -> Result<(), anyhow::Error> {
        match msg.0 {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: This is intentionally NOT sound: it writes through a
                // null pointer to crash the process for testing.
                // NOTE(review): this is undefined behavior, so the resulting
                // failure may depend on the build. It could be a trap, or an
                // abort from debug UB checks, rather than SIGSEGV — confirm.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                std::process::exit(code);
            }
        }
        Ok(())
    }
}
122
/// A test actor that handles supervision events.
///
/// It is meant to be the parent of a [`TestActor`] that panics or causes a
/// SIGSEGV (e.g. via [`CauseSupervisionEvent`]). Supervision events from that
/// child are logged and swallowed rather than crashing the process.
#[derive(Default, Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [ActorSupervisionEvent],
)]
pub struct TestActorWithSupervisionHandling;
131
132#[async_trait]
133impl Actor for TestActorWithSupervisionHandling {
134    type Params = ();
135
136    async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
137        Ok(Self {})
138    }
139
140    async fn handle_supervision_event(
141        &mut self,
142        _this: &Instance<Self>,
143        event: &ActorSupervisionEvent,
144    ) -> Result<bool, anyhow::Error> {
145        tracing::error!("supervision event: {:?}", event);
146        // Swallow the supervision error to avoid crashing the process.
147        Ok(true)
148    }
149}
150
151#[async_trait]
152impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
153    async fn handle(
154        &mut self,
155        _cx: &Context<Self>,
156        _msg: ActorSupervisionEvent,
157    ) -> Result<(), anyhow::Error> {
158        Ok(())
159    }
160}
161
/// A message that is relayed along a chain of ports.
///
/// Each [`TestActor`] that receives it does three things:
///
/// 1. It removes the front entry of `to_visit`.
/// 2. It appends that entry to `visited`.
/// 3. It sends the message to the port that is now at the front of `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be visited; the front entry is the current hop.
    pub to_visit: VecDeque<PortRef<Forward>>,
    /// Ports already visited, in the order they were visited.
    pub visited: Vec<PortRef<Forward>>,
}
170
171#[async_trait]
172impl Handler<Forward> for TestActor {
173    async fn handle(
174        &mut self,
175        cx: &Context<Self>,
176        Forward {
177            mut to_visit,
178            mut visited,
179        }: Forward,
180    ) -> Result<(), anyhow::Error> {
181        let Some(this) = to_visit.pop_front() else {
182            anyhow::bail!("unexpected forward chain termination");
183        };
184        visited.push(this);
185        let next = to_visit.front().cloned();
186        anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
187        next.unwrap().send(cx, Forward { to_visit, visited })?;
188        Ok(())
189    }
190}
191
/// A message asking the recipient to reply with information about the cast
/// that delivered it.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    /// Reply port. The recipient sends the tuple `(point, actor, sender)`:
    /// - `point`: its own point, as reported by `cx.cast_point()`.
    /// - `actor`: an `ActorRef` to itself, from `cx.bind()`.
    /// - `sender`: the id of the message's sender.
    #[reply]
    pub cast_info: PortRef<(Point, ActorRef<TestActor>, ActorId)>,
}
209
210#[async_trait]
211impl Handler<GetCastInfo> for TestActor {
212    async fn handle(
213        &mut self,
214        cx: &Context<Self>,
215        GetCastInfo { cast_info }: GetCastInfo,
216    ) -> Result<(), anyhow::Error> {
217        cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
218        Ok(())
219    }
220}
221
/// A test actor whose constructor always fails.
///
/// Its [`Actor::new`] returns an error ("test failure"), so the actor can never
/// be successfully created.
#[derive(Default, Debug)]
#[hyperactor::export(spawn = true)]
pub struct FailingCreateTestActor;
225
226#[async_trait]
227impl Actor for FailingCreateTestActor {
228    type Params = ();
229
230    async fn new(_params: Self::Params) -> Result<Self, hyperactor::anyhow::Error> {
231        Err(anyhow::anyhow!("test failure"))
232    }
233}
234
/// A message carrying bincode-encoded configuration attributes.
///
/// The recipient decodes the payload and installs it in the global
/// configuration with [`Source::Runtime`].
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
237
238#[async_trait]
239impl Handler<SetConfigAttrs> for TestActor {
240    async fn handle(
241        &mut self,
242        _cx: &Context<Self>,
243        SetConfigAttrs(attrs): SetConfigAttrs,
244    ) -> Result<(), anyhow::Error> {
245        let attrs = bincode::deserialize(&attrs)?;
246        config::global::set(Source::Runtime, attrs);
247        Ok(())
248    }
249}
250
/// A message asking the recipient to reply, on the enclosed port, with its
/// view of the global configuration attributes, bincode-encoded.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub PortRef<Vec<u8>>);
253
254#[async_trait]
255impl Handler<GetConfigAttrs> for TestActor {
256    async fn handle(
257        &mut self,
258        cx: &Context<Self>,
259        GetConfigAttrs(reply): GetConfigAttrs,
260    ) -> Result<(), anyhow::Error> {
261        let attrs = bincode::serialize(&config::global::attrs())?;
262        reply.send(cx, attrs)?;
263        Ok(())
264    }
265}
266
/// Verifies that casting reaches every actor exactly once.
///
/// The check runs twice: first on the full `actor_mesh`, then on a mesh sliced
/// to the first half of its first dimension.
///
/// The slicing step is skipped in two cases: the mesh has no dimensions, or
/// the first dimension has a single element (its first half would be empty).
///
/// # Panics
///
/// Panics if any cast check fails or if slicing the mesh fails.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance().await;
    // Verify casting to the root actor mesh.
    assert_casting_correctness(&actor_mesh, instance).await;

    // Pick the first dimension and keep only its first half.
    let extent = actor_mesh.extent();
    let (Some(label), Some(first_size)) = (
        extent.labels().first().cloned(),
        extent.sizes().first().copied(),
    ) else {
        // A zero-dimensional mesh has nothing to slice.
        return;
    };
    let size = first_size / 2;
    if size == 0 {
        // Halving a size-1 dimension would yield an empty slice.
        return;
    }

    // Verify casting to the sliced actor mesh.
    let sliced_actor_mesh = actor_mesh
        .range(&label, 0..size)
        .expect("failed to slice actor mesh");
    assert_casting_correctness(&sliced_actor_mesh, instance).await;
}
285
/// Casts [`GetActorId`] to every actor in `actor_mesh` and verifies that each
/// actor replies exactly once with its own id.
///
/// After all expected replies have arrived, it waits one second and checks
/// that no further replies were received.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the cast fails or the reply port fails to receive;
/// - a reply comes from an unexpected actor or duplicates an earlier reply;
/// - extra replies arrive after every actor has answered.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &Instance<()>,
) {
    let (port, mut rx) = mailbox::open_port(instance);
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();

    let mut expected_actor_ids: HashSet<_> = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect();

    while !expected_actor_ids.is_empty() {
        let actor_id = rx.recv().await.unwrap();
        assert!(
            expected_actor_ids.remove(&actor_id),
            "got {actor_id}, expect {expected_actor_ids:?}"
        );
    }

    // Give any stray extra replies time to arrive, then check that none did.
    RealClock.sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    // Matching on `Ok(None)` routes a `try_recv` error through the diagnostic
    // message below instead of an opaque `unwrap` panic.
    assert!(matches!(result, Ok(None)), "got {result:?}");
}