hyperactor_mesh/
testactor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a test actor. It is defined in a separate module
//! (outside of [`crate::testing`]) to ensure that it is compiled into
//! the bootstrap binary, which is not built in test mode (and anyway, test mode
//! does not work across crate boundaries).
13
14#[cfg(test)]
15use std::collections::HashMap;
16use std::collections::VecDeque;
17use std::ops::Deref;
18#[cfg(test)]
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::Bind;
24use hyperactor::Context;
25use hyperactor::Handler;
26use hyperactor::Instance;
27use hyperactor::RefClient;
28use hyperactor::Unbind;
29#[cfg(test)]
30use hyperactor::context;
31use hyperactor::ordering::SEQ_INFO;
32use hyperactor::ordering::SeqInfo;
33use hyperactor::reference as hyperactor_reference;
34use hyperactor::supervision::ActorSupervisionEvent;
35use hyperactor_config::Flattrs;
36use hyperactor_config::global::Source;
37use ndslice::Point;
38#[cfg(test)]
39use ndslice::ViewExt as _;
40use serde::Deserialize;
41use serde::Serialize;
42use typeuri::Named;
43#[cfg(test)]
44use uuid::Uuid;
45
46use crate::ActorMesh;
47#[cfg(test)]
48use crate::ActorMeshRef;
49use crate::Name;
50use crate::ProcMeshRef;
51use crate::comm::multicast::CastInfo;
52use crate::supervision::MeshFailure;
53#[cfg(test)]
54use crate::testing;
55
56/// A simple test actor used by various unit tests.
57#[derive(Default, Debug)]
58#[hyperactor::export(
59    spawn = true,
60    handlers = [
61        () { cast = true },
62        GetActorId { cast = true },
63        GetCastInfo { cast = true },
64        CauseSupervisionEvent { cast = true },
65        Forward,
66        GetConfigAttrs { cast = true },
67        SetConfigAttrs { cast = true },
68    ]
69)]
70pub struct TestActor;
71
72impl Actor for TestActor {}
73
74/// A message that returns the recipient actor's id and cast message's seq info.
75#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
76pub struct GetActorId(
77    #[binding(include)]
78    pub  hyperactor_reference::PortRef<(hyperactor_reference::ActorId, Option<SeqInfo>)>,
79);
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum SupervisionEventType {
83    Panic,
84    SigSEGV,
85    ProcessExit(i32),
86}
87
88/// A message that causes a supervision event. The one argument determines what
89/// kind of supervision event it'll be.
90#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
91pub struct CauseSupervisionEvent {
92    pub kind: SupervisionEventType,
93    pub send_to_children: bool,
94}
95
96impl CauseSupervisionEvent {
97    fn cause_event(&self) -> ! {
98        match self.kind {
99            SupervisionEventType::Panic => {
100                panic!("for testing");
101            }
102            SupervisionEventType::SigSEGV => {
103                tracing::error!("exiting with SIGSEGV");
104                // SAFETY: This is for testing code that explicitly causes a SIGSEGV.
105                unsafe { std::ptr::null_mut::<i32>().write(42) };
106                // While the above should always segfault, we need a hard exit
107                // for the compiler's sake.
108                panic!("should have segfaulted");
109            }
110            SupervisionEventType::ProcessExit(code) => {
111                tracing::error!("exiting process {} with code {}", std::process::id(), code);
112                std::process::exit(code);
113            }
114        }
115    }
116}
117
118#[async_trait]
119impl Handler<()> for TestActor {
120    async fn handle(&mut self, _cx: &Context<Self>, _: ()) -> Result<(), anyhow::Error> {
121        Ok(())
122    }
123}
124
125#[async_trait]
126impl Handler<GetActorId> for TestActor {
127    async fn handle(
128        &mut self,
129        cx: &Context<Self>,
130        GetActorId(reply): GetActorId,
131    ) -> Result<(), anyhow::Error> {
132        let seq_info = cx.headers().get(SEQ_INFO);
133        reply.send(cx, (cx.self_id().clone(), seq_info))?;
134        Ok(())
135    }
136}
137
138#[async_trait]
139impl Handler<CauseSupervisionEvent> for TestActor {
140    async fn handle(
141        &mut self,
142        _cx: &Context<Self>,
143        msg: CauseSupervisionEvent,
144    ) -> Result<(), anyhow::Error> {
145        msg.cause_event();
146    }
147}
148
149/// A test actor that handles supervision events.
150/// It should be the parent of TestActor who can panic or cause a SIGSEGV.
151#[derive(Default, Debug)]
152#[hyperactor::export(
153    spawn = true,
154    handlers = [ActorSupervisionEvent],
155)]
156pub struct TestActorWithSupervisionHandling;
157
158#[async_trait]
159impl Actor for TestActorWithSupervisionHandling {
160    async fn handle_supervision_event(
161        &mut self,
162        _this: &Instance<Self>,
163        event: &ActorSupervisionEvent,
164    ) -> Result<bool, anyhow::Error> {
165        tracing::error!("supervision event: {:?}", event);
166        // Swallow the supervision error to avoid crashing the process.
167        Ok(true)
168    }
169}
170
171#[async_trait]
172impl Handler<ActorSupervisionEvent> for TestActorWithSupervisionHandling {
173    async fn handle(
174        &mut self,
175        _cx: &Context<Self>,
176        _msg: ActorSupervisionEvent,
177    ) -> Result<(), anyhow::Error> {
178        Ok(())
179    }
180}
181
182/// A test actor that sleeps when it receives a Duration message.
183/// Used for testing timeout and abort behavior.
184#[derive(Default, Debug)]
185#[hyperactor::export(
186    spawn = true,
187    handlers = [std::time::Duration],
188)]
189pub struct SleepActor;
190
191impl Actor for SleepActor {}
192
193#[async_trait]
194impl Handler<std::time::Duration> for SleepActor {
195    async fn handle(
196        &mut self,
197        _cx: &Context<Self>,
198        duration: std::time::Duration,
199    ) -> Result<(), anyhow::Error> {
200        tokio::time::sleep(duration).await;
201        Ok(())
202    }
203}
204
205/// A message to forward to a visit list of ports.
206/// Each port removes the next entry, and adds it to the
207/// 'visited' list.
208#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
209pub struct Forward {
210    pub to_visit: VecDeque<hyperactor_reference::PortRef<Forward>>,
211    pub visited: Vec<hyperactor_reference::PortRef<Forward>>,
212}
213
214#[async_trait]
215impl Handler<Forward> for TestActor {
216    async fn handle(
217        &mut self,
218        cx: &Context<Self>,
219        Forward {
220            mut to_visit,
221            mut visited,
222        }: Forward,
223    ) -> Result<(), anyhow::Error> {
224        let Some(this) = to_visit.pop_front() else {
225            anyhow::bail!("unexpected forward chain termination");
226        };
227        visited.push(this);
228        let next = to_visit.front().cloned();
229        anyhow::ensure!(next.is_some(), "unexpected forward chain termination");
230        next.unwrap().send(cx, Forward { to_visit, visited })?;
231        Ok(())
232    }
233}
234
235/// Just return the cast info of the sender.
236#[derive(
237    Debug,
238    Clone,
239    Named,
240    Bind,
241    Unbind,
242    Serialize,
243    Deserialize,
244    Handler,
245    RefClient
246)]
247pub struct GetCastInfo {
248    /// Originating actor, point, sender.
249    #[reply]
250    pub cast_info: hyperactor_reference::PortRef<(
251        Point,
252        hyperactor_reference::ActorRef<TestActor>,
253        hyperactor_reference::ActorId,
254    )>,
255}
256
257#[async_trait]
258impl Handler<GetCastInfo> for TestActor {
259    async fn handle(
260        &mut self,
261        cx: &Context<Self>,
262        GetCastInfo { cast_info }: GetCastInfo,
263    ) -> Result<(), anyhow::Error> {
264        cast_info.send(cx, (cx.cast_point(), cx.bind(), cx.sender().clone()))?;
265        Ok(())
266    }
267}
268
269#[derive(Debug)]
270#[hyperactor::export(spawn = true)]
271pub struct FailingCreateTestActor;
272
273#[async_trait]
274impl Actor for FailingCreateTestActor {}
275
276#[async_trait]
277impl hyperactor::RemoteSpawn for FailingCreateTestActor {
278    type Params = ();
279
280    async fn new(
281        _params: Self::Params,
282        _environment: Flattrs,
283    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
284        Err(anyhow::anyhow!("test failure"))
285    }
286}
287
288#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
289pub struct SetConfigAttrs(pub Vec<u8>);
290
291#[async_trait]
292impl Handler<SetConfigAttrs> for TestActor {
293    async fn handle(
294        &mut self,
295        _cx: &Context<Self>,
296        SetConfigAttrs(attrs): SetConfigAttrs,
297    ) -> Result<(), anyhow::Error> {
298        let attrs = bincode::deserialize(&attrs)?;
299        hyperactor_config::global::set(Source::Runtime, attrs);
300        Ok(())
301    }
302}
303
304#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
305pub struct GetConfigAttrs(pub hyperactor_reference::PortRef<Vec<u8>>);
306
307#[async_trait]
308impl Handler<GetConfigAttrs> for TestActor {
309    async fn handle(
310        &mut self,
311        cx: &Context<Self>,
312        GetConfigAttrs(reply): GetConfigAttrs,
313    ) -> Result<(), anyhow::Error> {
314        let attrs = bincode::serialize(&hyperactor_config::global::attrs())?;
315        reply.send(cx, attrs)?;
316        Ok(())
317    }
318}
319
320/// A message to request the next supervision event delivered to WrapperActor.
321/// Replies with None if no supervision event is encountered within a timeout
322/// (10 seconds).
323#[derive(Clone, Debug, Serialize, Deserialize, Named, Bind, Unbind)]
324pub struct NextSupervisionFailure(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
325
326/// A small wrapper to handle supervision messages so they don't
327/// need to reach the client. This just wraps and forwards all messages to TestActor.
328/// The supervision events are sent back to "supervisor".
329#[derive(Debug)]
330#[hyperactor::export(
331    spawn = true,
332    handlers = [
333        CauseSupervisionEvent { cast = true },
334        MeshFailure { cast = true },
335        NextSupervisionFailure { cast = true },
336    ]
337)]
338pub struct WrapperActor {
339    proc_mesh: ProcMeshRef,
340    // Needs to be a mesh so we own this actor and have a controller for it.
341    mesh: Option<ActorMesh<TestActor>>,
342    supervisor: hyperactor_reference::PortRef<MeshFailure>,
343    test_name: Name,
344}
345
346#[async_trait]
347impl hyperactor::RemoteSpawn for WrapperActor {
348    type Params = (
349        ProcMeshRef,
350        hyperactor_reference::PortRef<MeshFailure>,
351        Name,
352    );
353
354    async fn new(
355        (proc_mesh, supervisor, test_name): Self::Params,
356        _environment: Flattrs,
357    ) -> Result<Self, hyperactor::internal_macro_support::anyhow::Error> {
358        Ok(Self {
359            proc_mesh,
360            mesh: None,
361            supervisor,
362            test_name,
363        })
364    }
365}
366
367#[async_trait]
368impl Actor for WrapperActor {
369    async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
370        self.mesh = Some(
371            self.proc_mesh
372                .spawn_with_name(this, self.test_name.clone(), &(), None, false)
373                .await?,
374        );
375        Ok(())
376    }
377}
378
379#[async_trait]
380impl Handler<CauseSupervisionEvent> for WrapperActor {
381    async fn handle(
382        &mut self,
383        cx: &Context<Self>,
384        msg: CauseSupervisionEvent,
385    ) -> Result<(), anyhow::Error> {
386        // No reply to wait for.
387        if msg.send_to_children {
388            // Send only to children, don't cause the event itself.
389            self.mesh
390                .as_ref()
391                .unwrap()
392                .cast(cx, msg)
393                .map_err(|e| e.into())
394        } else {
395            msg.cause_event()
396        }
397    }
398}
399
400#[async_trait]
401impl Handler<NextSupervisionFailure> for WrapperActor {
402    async fn handle(
403        &mut self,
404        cx: &Context<Self>,
405        msg: NextSupervisionFailure,
406    ) -> Result<(), anyhow::Error> {
407        let mesh = if let Some(mesh) = self.mesh.as_ref() {
408            mesh.deref()
409        } else {
410            msg.0.send(cx, None)?;
411            return Ok(());
412        };
413        let failure = match tokio::time::timeout(
414            tokio::time::Duration::from_secs(20),
415            mesh.next_supervision_event(cx),
416        )
417        .await
418        {
419            Ok(Ok(failure)) => Some(failure),
420            // Any error in next_supervision_event is treated the same.
421            Ok(Err(_)) => None,
422            // If we timeout, send back None.
423            Err(_) => None,
424        };
425        msg.0.send(cx, failure)?;
426        Ok(())
427    }
428}
429
430#[async_trait]
431impl Handler<MeshFailure> for WrapperActor {
432    async fn handle(&mut self, cx: &Context<Self>, msg: MeshFailure) -> Result<(), anyhow::Error> {
433        // All supervision events are considered handled so they don't bubble up
434        // to the client (who isn't listening for MeshFailure).
435        tracing::info!("got supervision event from child: {}", msg);
436        // Send to a port so the client can view the messages.
437        // Ignore the error if there is one.
438        let _ = self.supervisor.send(cx, msg.clone());
439        Ok(())
440    }
441}
442
/// Checks that casting to `actor_mesh` reaches every actor, then repeats the
/// check on a slice of the mesh.
///
/// The slice keeps indices `0..n/2` of the mesh's first dimension, where `n`
/// is that dimension's size.
///
/// # Panics
///
/// Panics if any casting check fails or the slice cannot be taken.
#[cfg(test)]
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
    let instance = testing::instance();

    // Verify casting to the root actor mesh.
    assert_casting_correctness(&actor_mesh, instance, None).await;

    // Just pick the first dimension. Slice half of it off.
    let extent = actor_mesh.extent();
    let first_label = extent.labels()[0].clone();
    let half = extent.sizes()[0] / 2;

    // Verify casting to the sliced actor mesh.
    let sliced_actor_mesh = actor_mesh.range(&first_label, 0..half).unwrap();
    assert_casting_correctness(&sliced_actor_mesh, instance, None).await;
}
461
/// Casts a [`GetActorId`] request to `actor_mesh` and checks the replies.
///
/// Every actor of the mesh must answer. When `expected_seqs` is given, its
/// sequence numbers are paired with the mesh's actors in iteration order and
/// each reply's sequencing info must equal
/// `SeqInfo::Session { session_id, seq }` for its actor. After all expected
/// replies have arrived, the function waits one second and checks that no
/// further reply is pending.
///
/// # Panics
///
/// Panics on any mismatch, on a reply from an unexpected actor, on a reply
/// without sequencing info, or if casting or receiving fails.
#[cfg(test)]
pub async fn assert_casting_correctness(
    actor_mesh: &ActorMeshRef<TestActor>,
    instance: &impl context::Actor,
    expected_seqs: Option<(Uuid, Vec<u64>)>,
) {
    let (port, mut rx) = instance.mailbox().open_port();
    actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();

    let expected_actor_ids: Vec<_> = actor_mesh
        .values()
        .map(|actor_ref| actor_ref.actor_id().clone())
        .collect();

    // Actors that still owe a reply, each with the sequencing info to check
    // (or `None` when no particular value is required).
    let mut pending: HashMap<&hyperactor_reference::ActorId, Option<SeqInfo>> = HashMap::new();
    match expected_seqs {
        None => pending.extend(expected_actor_ids.iter().map(|actor_id| (actor_id, None))),
        Some((session_id, seqs)) => {
            for (actor_id, seq) in expected_actor_ids.iter().zip(seqs) {
                pending.insert(actor_id, Some(SeqInfo::Session { session_id, seq }));
            }
        }
    }

    while !pending.is_empty() {
        let (actor_id, rcved) = rx.recv().await.unwrap();
        let rcv_seq_info = rcved.unwrap();
        let Some(wanted) = pending.remove(&actor_id) else {
            panic!("got {actor_id}, expect {expected_actor_ids:?}");
        };
        if let Some(wanted_seq) = wanted {
            assert_eq!(wanted_seq, rcv_seq_info, "got different seq for {actor_id}");
        }
    }

    // No more messages
    tokio::time::sleep(Duration::from_secs(1)).await;
    let result = rx.try_recv();
    assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
}