Skip to main content

hyperactor_mesh/
testactor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines test actors. They are defined in a separate module
//! (outside of [`crate::testing`]) to ensure that they are compiled into
//! the bootstrap binary, which is not built in test mode (and, in any case, test
//! mode does not work across crate boundaries).
13
14#[cfg(test)]
15use std::collections::HashMap;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::ActorRef;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Endpoint as _;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::RefClient;
30use hyperactor::Unbind;
31#[cfg(test)]
32use hyperactor::context;
33use hyperactor::ordering::SEQ_INFO;
34use hyperactor::ordering::SeqInfo;
35use hyperactor::supervision::ActorSupervisionEvent;
36use hyperactor_config::Flattrs;
37use hyperactor_config::global::Source;
38use ndslice::Point;
39#[cfg(test)]
40use ndslice::ViewExt as _;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44#[cfg(test)]
45use uuid::Uuid;
46
47use crate::ActorMesh;
48#[cfg(test)]
49use crate::ActorMeshRef;
50use crate::ProcMeshRef;
51use crate::comm::multicast::CastInfo;
52use crate::mesh_id::ActorMeshId;
53use crate::supervision::MeshFailure;
54#[cfg(test)]
55use crate::testing;
56
/// A simple test actor used by various unit tests.
///
/// It replies to address and cast-info queries, forwards [`Forward`] messages
/// along a chain of ports, reads and writes the process-global config, and can
/// be told to fail on purpose via [`CauseSupervisionEvent`].
#[derive(Default, Debug)]
#[hyperactor::export(
    () { cast = true },
    GetActorId { cast = true },
    GetCastInfo { cast = true },
    CauseSupervisionEvent { cast = true },
    Forward,
    GetConfigAttrs { cast = true },
    SetConfigAttrs { cast = true },
)]
#[hyperactor::spawnable]
pub struct TestActor;

/// `TestActor` overrides none of the `Actor` hooks.
impl Actor for TestActor {}
72
/// Asks the recipient to reply with its own actor address together with the
/// sequence info read from the incoming message's `SEQ_INFO` header, if any.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct GetActorId(
    // Reply port. `#[binding(include)]` opts this field into the `Bind`/`Unbind`
    // derives — presumably so the port is rebound when the message is cast;
    // confirm against the derive docs.
    #[binding(include)] pub hyperactor::PortRef<(hyperactor::ActorAddr, Option<SeqInfo>)>,
);
78
/// The kind of failure a [`CauseSupervisionEvent`] message triggers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisionEventType {
    /// Panic inside the handler.
    Panic,
    /// Write through a null pointer, intended to crash the process with SIGSEGV.
    SigSEGV,
    /// Exit the whole OS process immediately with the given exit code.
    ProcessExit(i32),
}
85
/// A message that causes a supervision event. `kind` determines what kind of
/// supervision event it will be.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct CauseSupervisionEvent {
    /// Which failure to trigger.
    pub kind: SupervisionEventType,
    /// Only consulted by `WrapperActor`: when `true`, the message is cast to
    /// its child `TestActor` mesh instead of failing the wrapper itself.
    /// `TestActor` ignores this flag and always fails itself.
    pub send_to_children: bool,
}
93
impl CauseSupervisionEvent {
    /// Deliberately fails according to `self.kind`.
    ///
    /// Never returns: it panics, crashes the process through an invalid
    /// memory write, or exits the whole process.
    fn cause_event(&self) -> ! {
        match self.kind {
            SupervisionEventType::Panic => {
                panic!("for testing");
            }
            SupervisionEventType::SigSEGV => {
                tracing::error!("exiting with SIGSEGV");
                // SAFETY: This is for testing code that explicitly causes a SIGSEGV.
                // NOTE(review): writing through a null pointer is UB rather than a
                // guaranteed SIGSEGV; builds with debug assertions may instead abort
                // via std's `ptr::write` precondition check — confirm the observed
                // signal in the tests that rely on this.
                unsafe { std::ptr::null_mut::<i32>().write(42) };
                // While the above should always segfault, the write is not typed as
                // diverging, so the `!` return type needs an explicit diverging
                // expression here.
                panic!("should have segfaulted");
            }
            SupervisionEventType::ProcessExit(code) => {
                tracing::error!("exiting process {} with code {}", std::process::id(), code);
                // Terminates the whole OS process, not just this actor.
                std::process::exit(code);
            }
        }
    }
}
115
116#[async_trait]
117impl Handler<()> for TestActor {
118    async fn handle(&mut self, _cx: &Context<Self>, _: ()) -> Result<(), anyhow::Error> {
119        Ok(())
120    }
121}
122
123#[async_trait]
124impl Handler<GetActorId> for TestActor {
125    async fn handle(
126        &mut self,
127        cx: &Context<Self>,
128        GetActorId(reply): GetActorId,
129    ) -> Result<(), anyhow::Error> {
130        let seq_info = cx.headers().get(SEQ_INFO);
131        reply.post(cx, (cx.self_addr().clone(), seq_info));
132        Ok(())
133    }
134}
135
136#[async_trait]
137impl Handler<CauseSupervisionEvent> for TestActor {
138    async fn handle(
139        &mut self,
140        _cx: &Context<Self>,
141        msg: CauseSupervisionEvent,
142    ) -> Result<(), anyhow::Error> {
143        msg.cause_event();
144    }
145}
146
/// A test actor that handles supervision events.
///
/// It is meant to be the parent of a `TestActor` that may panic or cause a
/// SIGSEGV. Every supervision event it sees is logged and reported as handled.
#[derive(Default, Debug)]
#[hyperactor::export(ActorSupervisionEvent)]
#[hyperactor::spawnable]
pub struct TestActorWithSupervisionHandling;
153
154#[async_trait]
155impl Actor for TestActorWithSupervisionHandling {
156    async fn handle_supervision_event(
157        &mut self,
158        _this: &Instance<Self>,
159        event: &ActorSupervisionEvent,
160    ) -> Result<bool, anyhow::Error> {
161        tracing::error!("supervision event: {:?}", event);
162        // Swallow the supervision error to avoid crashing the process.
163        Ok(true)
164    }
165}
166
167#[async_trait]
168impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
169    async fn handle(
170        &mut self,
171        _cx: &Context<Self>,
172        _msg: ActorSupervisionEvent,
173    ) -> Result<(), anyhow::Error> {
174        Ok(())
175    }
176}
177
/// A test actor that sleeps when it receives a Duration message.
/// Used for testing timeout and abort behavior.
///
/// The sleep happens inside the message handler, so handling that message does
/// not complete until the requested duration has elapsed.
#[derive(Default, Debug)]
#[hyperactor::export(std::time::Duration)]
#[hyperactor::spawnable]
pub struct SleepActor;

/// `SleepActor` overrides none of the `Actor` hooks.
impl Actor for SleepActor {}
186
187#[async_trait]
188impl Handler<std::time::Duration> for SleepActor {
189    async fn handle(
190        &mut self,
191        _cx: &Context<Self>,
192        duration: std::time::Duration,
193    ) -> Result<(), anyhow::Error> {
194        tokio::time::sleep(duration).await;
195        Ok(())
196    }
197}
198
/// A message that is forwarded along a chain of ports.
///
/// Each hop removes the front entry of `to_visit` (the entry for the current
/// hop), appends it to `visited`, and posts the message to the new front of
/// `to_visit`.
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
pub struct Forward {
    /// Ports still to be visited; the front entry is the current hop.
    pub to_visit: VecDeque<hyperactor::PortRef<Forward>>,
    /// Ports that have already handled the message, in visit order.
    pub visited: Vec<hyperactor::PortRef<Forward>>,
}
207
208#[async_trait]
209impl Handler<Forward> for TestActor {
210    async fn handle(
211        &mut self,
212        cx: &Context<Self>,
213        Forward {
214            mut to_visit,
215            mut visited,
216        }: Forward,
217    ) -> Result<(), anyhow::Error> {
218        let Some(this) = to_visit.pop_front() else {
219            anyhow::bail!("unexpected forward chain termination");
220        };
221        visited.push(this);
222        let next = to_visit.front().cloned();
223        anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
224        next.unwrap().post(cx, Forward { to_visit, visited });
225        Ok(())
226    }
227}
228
/// Asks the recipient to reply with information about the cast that delivered
/// this message to it.
#[derive(
    Debug,
    Clone,
    Named,
    Bind,
    Unbind,
    Serialize,
    Deserialize,
    Handler,
    RefClient
)]
pub struct GetCastInfo {
    /// Reply port for `(cast point, handling actor's ref, sender address)`.
    ///
    /// The `TestActor` handler fills these from `cx.cast_point()`, `cx.bind()`
    /// and `cx.sender()`, in that order.
    #[reply]
    pub cast_info: hyperactor::PortRef<(Point, ActorRef<TestActor>, hyperactor::ActorAddr)>,
}
246
247#[async_trait]
248impl Handler<GetCastInfo> for TestActor {
249    async fn handle(
250        &mut self,
251        cx: &Context<Self>,
252        GetCastInfo { cast_info }: GetCastInfo,
253    ) -> Result<(), anyhow::Error> {
254        cast_info.post(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()));
255        Ok(())
256    }
257}
258
/// A test actor whose construction always fails (see its `RemoteSpawn::new`),
/// presumably used to exercise spawn-failure handling.
#[derive(Debug)]
#[hyperactor::export]
#[hyperactor::spawnable]
pub struct FailingCreateTestActor;

/// No `Actor` hooks are overridden; the actor is never actually constructed.
#[async_trait]
impl Actor for FailingCreateTestActor {}
266
267#[async_trait]
268impl hyperactor::RemoteSpawn for FailingCreateTestActor {
269    type Params = ();
270
271    async fn new(
272        _params: Self::Params,
273        _environment: Flattrs,
274    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
275        Err(anyhow::anyhow!("test failure"))
276    }
277}
278
/// Asks the recipient to install the given attributes into its process-global
/// config under `Source::Runtime`. The payload is the attribute set serialized
/// with bincode's legacy configuration, as `GetConfigAttrs` replies with.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct SetConfigAttrs(pub Vec<u8>);
281
282#[async_trait]
283impl Handler<SetConfigAttrs> for TestActor {
284    async fn handle(
285        &mut self,
286        _cx: &Context<Self>,
287        SetConfigAttrs(attrs): SetConfigAttrs,
288    ) -> Result<(), anyhow::Error> {
289        let attrs =
290            bincode::serde::decode_from_slice(&attrs, bincode::config::legacy()).map(|(v, _)| v)?;
291        hyperactor_config::global::set(Source::Runtime, attrs);
292        Ok(())
293    }
294}
295
/// Asks the recipient to reply with its process-global config attributes,
/// serialized with bincode's legacy configuration.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetConfigAttrs(pub hyperactor::PortRef<Vec<u8>>);
298
299#[async_trait]
300impl Handler<GetConfigAttrs> for TestActor {
301    async fn handle(
302        &mut self,
303        cx: &Context<Self>,
304        GetConfigAttrs(reply): GetConfigAttrs,
305    ) -> Result<(), anyhow::Error> {
306        let attrs = bincode::serde::encode_to_vec(
307            hyperactor_config::global::attrs(),
308            bincode::config::legacy(),
309        )?;
310        reply.post(cx, attrs);
311        Ok(())
312    }
313}
314
/// A message to request the next supervision event delivered to WrapperActor.
/// Replies with `None` if no supervision event is encountered within a timeout
/// (20 seconds), if the wrapper's child mesh has not been spawned, or if waiting
/// for the event fails.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct NextSupervisionFailure(pub hyperactor::PortRef<Option<MeshFailure>>);
320
/// A small wrapper that owns a `TestActor` mesh and handles its supervision
/// messages so they don't need to reach the client.
///
/// On `init` it spawns a `TestActor` mesh named `test_name` on `proc_mesh`.
/// A `CauseSupervisionEvent` either fails the wrapper itself or is cast to that
/// mesh (depending on `send_to_children`), and every `MeshFailure` it receives
/// is posted back to `supervisor`.
#[derive(Debug)]
#[hyperactor::export(
    CauseSupervisionEvent { cast = true },
    MeshFailure { cast = true },
    NextSupervisionFailure { cast = true },
)]
#[hyperactor::spawnable]
pub struct WrapperActor {
    // Proc mesh on which the child `TestActor` mesh is spawned.
    proc_mesh: ProcMeshRef,
    // Needs to be a mesh so we own this actor and have a controller for it.
    // `None` until `init` has spawned it.
    mesh: Option<ActorMesh<TestActor>>,
    // Port that receives every `MeshFailure` this wrapper handles.
    supervisor: hyperactor::PortRef<MeshFailure>,
    // Name given to the spawned `TestActor` mesh.
    test_name: ActorMeshId,
}
338
339#[async_trait]
340impl hyperactor::RemoteSpawn for WrapperActor {
341    type Params = (ProcMeshRef, hyperactor::PortRef<MeshFailure>, ActorMeshId);
342
343    async fn new(
344        (proc_mesh, supervisor, test_name): Self::Params,
345        _environment: Flattrs,
346    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
347        Ok(Self {
348            proc_mesh,
349            mesh: None,
350            supervisor,
351            test_name,
352        })
353    }
354}
355
356#[async_trait]
357impl Actor for WrapperActor {
358    async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
359        self.mesh = Some(
360            self.proc_mesh
361                .spawn_with_name(this, self.test_name.clone(), &(), None, false)
362                .await?,
363        );
364        Ok(())
365    }
366}
367
368#[async_trait]
369impl Handler<CauseSupervisionEvent> for WrapperActor {
370    async fn handle(
371        &mut self,
372        cx: &Context<Self>,
373        msg: CauseSupervisionEvent,
374    ) -> Result<(), anyhow::Error> {
375        // No reply to wait for.
376        if msg.send_to_children {
377            // Send only to children, don't cause the event itself.
378            self.mesh
379                .as_ref()
380                .unwrap()
381                .cast(cx, msg)
382                .map_err(|e| e.into())
383        } else {
384            msg.cause_event()
385        }
386    }
387}
388
389#[async_trait]
390impl Handler<NextSupervisionFailure> for WrapperActor {
391    async fn handle(
392        &mut self,
393        cx: &Context<Self>,
394        msg: NextSupervisionFailure,
395    ) -> Result<(), anyhow::Error> {
396        let mesh = if let Some(mesh) = self.mesh.as_ref() {
397            mesh.deref()
398        } else {
399            msg.0.post(cx, None);
400            return Ok(());
401        };
402        let failure = match tokio::time::timeout(
403            tokio::time::Duration::from_secs(20),
404            mesh.next_supervision_event(cx),
405        )
406        .await
407        {
408            Ok(Ok(failure)) => Some(failure),
409            // Any error in next_supervision_event is treated the same.
410            Ok(Err(_)) => None,
411            // If we timeout, send back None.
412            Err(_) => None,
413        };
414        msg.0.post(cx, failure);
415        Ok(())
416    }
417}
418
419#[async_trait]
420impl Handler<MeshFailure> for WrapperActor {
421    async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
422        // All supervision events are considered handled so they don't bubble up
423        // to the client (who isn't listening for MeshFailure).
424        tracing::info!("got supervision event from child: {}", msg);
425        // Send to a port so the client can view the messages.
426        // Ignore the error if there is one.
427        let _ = self.supervisor.post(cx, msg.clone());
428        Ok(())
429    }
430}
431
/// Casts to the provided actor mesh and asserts that every actor replies
/// exactly once. It then does the same for a slice covering the first half of
/// the mesh's first dimension; the slice always keeps at least one element.
///
/// # Panics
///
/// Panics if any check fails or if the mesh has no dimensions.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();
    // Verify casting to the root actor mesh.
    assert_casting_correctness(&actor_mesh, instance, None).await;

    // Slice the first dimension in half. Keep at least one element so a
    // dimension of size 1 still yields a non-empty slice to cast to.
    let extent = actor_mesh.extent();
    let label = extent.labels()[0].clone();
    let size = (extent.sizes()[0] / 2).max(1);

    // Verify casting to the sliced actor mesh.
    let sliced_actor_mesh = actor_mesh.range(&label, 0..size).unwrap();
    assert_casting_correctness(&sliced_actor_mesh, instance, None).await;
}
450
/// Casts `GetActorId` to the actor mesh and verifies that every actor replies
/// exactly once.
///
/// If `expected_seqs` is `Some((session_id, seqs))`, then `seqs[i]` is the
/// expected session sequence number for the i-th actor in `actor_mesh.values()`
/// order. Each reply's `SeqInfo` must match it.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `seqs` does not have exactly one entry per actor.
/// - A reply comes from an unexpected or already-seen actor.
/// - A reply lacks seq info, or carries the wrong seq info.
/// - Extra replies arrive within one second after the expected ones.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
    expected_seqs: Option<(Uuid, Vec<u64>)>,
) {
    let (port, mut rx) = instance.mailbox().open_port();
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
    let expected_actor_ids = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_addr().clone())
        .collect::<Vec<_>>();
    let mut expected: HashMap<&hyperactor::ActorAddr, Option<SeqInfo>> = match expected_seqs {
        None => expected_actor_ids
            .iter()
            .map(|actor_id| (actor_id, None))
            .collect(),
        Some((session_id, seqs)) => {
            // `zip` would silently drop the excess of the longer side, so insist
            // on a one-to-one mapping up front.
            assert_eq!(
                seqs.len(),
                expected_actor_ids.len(),
                "expected exactly one seq per actor in the mesh"
            );
            expected_actor_ids
                .iter()
                .zip(
                    seqs.into_iter()
                        .map(|seq| Some(SeqInfo::Session { session_id, seq })),
                )
                .collect()
        }
    };

    while !expected.is_empty() {
        let (actor_id, rcved) = rx.recv().await.unwrap();
        let rcv_seq_info = rcved.unwrap();
        let Some(expected_seq) = expected.remove(&actor_id) else {
            panic!("got {actor_id}, expect {expected_actor_ids:?}");
        };
        if let Some(expected_seq) = expected_seq {
            assert_eq!(expected_seq, rcv_seq_info, "got different seq for {actor_id}");
        }
    }

    // No more messages.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}