hyperactor_mesh/
alloc.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! This module defines a proc allocator interface as well as a multi-process
//! (local) allocator, [`ProcessAllocator`].

pub mod local;
pub(crate) mod logtailer;
pub mod process;
pub mod remoteprocess;
pub mod sim;

use std::collections::HashMap;
use std::fmt;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ActorRef;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
pub use local::LocalAlloc;
pub use local::LocalAllocator;
use mockall::predicate::*;
use mockall::*;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::view::Extent;
use ndslice::view::Point;
pub use process::ProcessAlloc;
pub use process::ProcessAllocator;
use serde::Deserialize;
use serde::Serialize;

use crate::alloc::test_utils::MockAllocWrapper;
use crate::proc_mesh::mesh_agent::MeshAgent;

/// Errors that occur during allocation operations.
#[derive(Debug, thiserror::Error)]
pub enum AllocatorError {
    /// The allocation failed to produce procs for the full requested extent.
    #[error("incomplete allocation; expected: {0}")]
    Incomplete(Extent),

    /// The requested shape is too large for the allocator.
    #[error("not enough resources; requested: {requested}, available: {available}")]
    NotEnoughResources { requested: Extent, available: usize },

    /// An uncategorized error from an underlying system.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Constraints on the allocation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AllocConstraints {
    /// Arbitrary name/value pairs that are interpreted by individual
    /// allocators to control the allocation process.
    pub match_labels: HashMap<String, String>,
}

/// A specification (desired state) of an alloc.
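///
/// A hedged, illustrative sketch of constructing a spec (the extent labels,
/// sizes, and match label below are arbitrary examples, not defaults):
///
/// ```ignore
/// use ndslice::extent;
///
/// let spec = AllocSpec {
///     // Request a 2 x 8 extent of procs.
///     extent: extent!(host = 2, gpu = 8),
///     constraints: AllocConstraints {
///         // Name/value pairs interpreted by individual allocators.
///         match_labels: HashMap::from([("region".to_string(), "east".to_string())]),
///     },
/// };
/// ```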
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocSpec {
    /// The requested extent of the alloc.
    // We currently assume that this shape is dense.
    // This should be validated, or even enforced by
    // way of types.
    pub extent: Extent,
    /// Constraints on the allocation.
    pub constraints: AllocConstraints,
}

/// The core allocator trait, implemented by all allocators.
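///
/// A minimal usage sketch (illustrative only; it assumes a `ProcessAllocator`
/// built from a bootstrap `Command`, as in the tests below):
///
/// ```ignore
/// let mut allocator = ProcessAllocator::new(command);
/// let mut alloc = allocator
///     .allocate(AllocSpec {
///         extent: extent!(replica = 4),
///         constraints: Default::default(),
///     })
///     .await?;
///
/// // The alloc is returned immediately; procs come up incrementally and are
/// // reported as `ProcState` events.
/// while let Some(state) = alloc.next().await {
///     tracing::info!("proc event: {}", state);
/// }
/// ```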
#[automock(type Alloc=MockAllocWrapper;)]
#[async_trait]
pub trait Allocator {
    /// The type of [`Alloc`] produced by this allocator.
    type Alloc: Alloc;

    /// Create a new allocation. The allocation itself is generally
    /// returned immediately (after validating parameters, etc.);
    /// the caller is expected to respond to allocation events as
    /// the underlying procs are incrementally allocated.
    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError>;
}

/// A proc's status. A proc can only monotonically move from
/// `Created` to `Running` to `Stopped`.
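///
/// Callers typically match on these events while draining [`Alloc::next`]; a
/// hedged, illustrative sketch:
///
/// ```ignore
/// while let Some(state) = alloc.next().await {
///     match state {
///         ProcState::Created { proc_id, point, .. } => { /* record the assignment */ }
///         ProcState::Running { proc_id, addr, .. } => { /* proc is reachable at `addr` */ }
///         ProcState::Stopped { proc_id, reason } => { /* proc is gone */ }
///         ProcState::Failed { world_id, description } => { /* allocation failed */ }
///     }
/// }
/// ```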
#[derive(Clone, Debug, PartialEq, EnumAsInner, Serialize, Deserialize)]
pub enum ProcState {
    /// A proc was added to the alloc.
    Created {
        /// The proc's id.
        proc_id: ProcId,
        /// Its assigned point (in the alloc's extent).
        point: Point,
        /// The system process ID of the created child process.
        pid: u32,
    },
    /// A proc was started.
    Running {
        /// The proc's id.
        proc_id: ProcId,
        /// Reference to this proc's mesh agent. In the future, we'll reserve a
        /// 'well known' PID (0) for this purpose.
        mesh_agent: ActorRef<MeshAgent>,
        /// The address of this proc. The endpoint of this address is
        /// the proc's mailbox, which accepts [`hyperactor::mailbox::MessageEnvelope`]s.
        addr: ChannelAddr,
    },
    /// A proc was stopped.
    Stopped {
        /// The proc's id.
        proc_id: ProcId,
        /// The reason the proc stopped.
        reason: ProcStopReason,
    },
    /// The allocation process encountered an irrecoverable error. Depending on
    /// the implementation, allocation may continue transiently and calls to
    /// next() may still return some events, but eventually the allocation will
    /// not complete. Callers can use the `description` to determine the reason
    /// for the failure. The allocation can then be cleaned up by calling
    /// `stop()` on the `Alloc` and draining the iterator for a clean shutdown.
    Failed {
        /// The world ID of the failed alloc.
        world_id: WorldId,
        /// A description of the failure.
        description: String,
    },
}

impl fmt::Display for ProcState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProcState::Created {
                proc_id,
                point,
                pid,
            } => {
                write!(f, "{}: created at ({}) with PID {}", proc_id, point, pid)
            }
            ProcState::Running { proc_id, addr, .. } => {
                write!(f, "{}: running at {}", proc_id, addr)
            }
            ProcState::Stopped { proc_id, reason } => {
                write!(f, "{}: stopped: {}", proc_id, reason)
            }
            ProcState::Failed {
                description,
                world_id,
            } => {
                write!(f, "{}: failed: {}", world_id, description)
            }
        }
    }
}

/// The reason a proc stopped.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, EnumAsInner)]
pub enum ProcStopReason {
    /// The proc stopped gracefully, e.g., with exit code 0.
    Stopped,
    /// The proc exited with the provided exit code and stderr output.
    Exited(i32, String),
    /// The proc was killed. The signal number is indicated;
    /// the flag indicates whether there was a core dump.
    Killed(i32, bool),
    /// The proc failed to respond to a watchdog request within a timeout.
    Watchdog,
    /// The host running the proc failed to respond to a watchdog request
    /// within a timeout.
    HostWatchdog,
    /// The proc failed for an unknown reason.
    Unknown,
}

impl fmt::Display for ProcStopReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Stopped => write!(f, "stopped"),
            Self::Exited(code, stderr) => {
                if stderr.is_empty() {
                    write!(f, "exited with code {}", code)
                } else {
                    write!(f, "exited with code {}: {}", code, stderr)
                }
            }
            Self::Killed(signal, dumped) => {
                write!(f, "killed with signal {} (core dumped={})", signal, dumped)
            }
            Self::Watchdog => write!(f, "proc watchdog failure"),
            Self::HostWatchdog => write!(f, "host watchdog failure"),
            Self::Unknown => write!(f, "unknown"),
        }
    }
}

/// An alloc is a specific allocation, returned by an [`Allocator`].
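///
/// A hedged sketch of a clean shutdown (`stop_and_wait` below packages the
/// same pattern):
///
/// ```ignore
/// // Ask the alloc to stop its procs, then drain the remaining `ProcState`
/// // events until the stream ends.
/// alloc.stop().await?;
/// while let Some(event) = alloc.next().await {
///     tracing::debug!("drained event: {:?}", event);
/// }
/// ```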
#[automock]
#[async_trait]
pub trait Alloc {
    /// Return the next proc event. `None` indicates that there are
    /// no more events, and that the alloc is stopped.
    async fn next(&mut self) -> Option<ProcState>;

    /// The extent of the alloc.
    fn extent(&self) -> &Extent;

    /// The shape of the alloc. (Deprecated.)
    fn shape(&self) -> Shape {
        let slice = Slice::new_row_major(self.extent().sizes());
        Shape::new(self.extent().labels().to_vec(), slice).unwrap()
    }

    /// The world id of this alloc, uniquely identifying the alloc.
    /// Note: This will be removed in favor of a different naming scheme,
    /// once we excise "worlds" from hyperactor core.
    fn world_id(&self) -> &WorldId;

    /// The channel transport used by the procs in this alloc.
    fn transport(&self) -> ChannelTransport;

    /// Stop this alloc, shutting down all of its procs. A clean
    /// shutdown should result in Stopped events for all procs,
    /// followed by the end of the event stream.
    async fn stop(&mut self) -> Result<(), AllocatorError>;

    /// Stop this alloc and wait for all procs to stop. The call will
    /// block until all ProcState events have been drained.
    async fn stop_and_wait(&mut self) -> Result<(), AllocatorError> {
        self.stop().await?;
        while let Some(event) = self.next().await {
            tracing::debug!("drained event: {:?}", event);
        }
        Ok(())
    }
}

pub mod test_utils {
    use std::time::Duration;

    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use libc::atexit;
    use tokio::sync::broadcast::Receiver;
    use tokio::sync::broadcast::Sender;

    use super::*;

    extern "C" fn exit_handler() {
        loop {
            #[allow(clippy::disallowed_methods)]
            std::thread::sleep(Duration::from_secs(60));
        }
    }

    // This can't be defined under a `#[cfg(test)]` because there needs to
    // be an entry in the spawnable actor registry in the executable
    // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor
    // mesh test suite.
    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            Wait
        ],
    )]
    pub struct TestActor;

    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
    pub struct Wait;

    #[async_trait]
    impl Handler<Wait> for TestActor {
        async fn handle(&mut self, _: &Context<Self>, _: Wait) -> Result<(), anyhow::Error> {
            // SAFETY:
            // This is in order to simulate, in tests, a process that never exits.
            unsafe {
                atexit(exit_handler);
            }
            Ok(())
        }
    }

    /// Test wrapper around MockAlloc to allow us to block next() calls, since
    /// mockall doesn't support returning futures.
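    ///
    /// A hedged usage sketch, from a test's point of view:
    ///
    /// ```ignore
    /// // Let two `next()` calls through, then block until the test notifies.
    /// let mut alloc = MockAllocWrapper::new_block_next(MockAlloc::new(), 2);
    /// let unblock = alloc.notify_tx();
    /// // ... later, release the pending `next()` call:
    /// unblock.send(()).unwrap();
    /// ```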
    pub struct MockAllocWrapper {
        pub alloc: MockAlloc,
        pub block_next_after: usize,
        notify_tx: Sender<()>,
        notify_rx: Receiver<()>,
        next_unblocked: bool,
    }

    impl MockAllocWrapper {
        pub fn new(alloc: MockAlloc) -> Self {
            Self::new_block_next(alloc, usize::MAX)
        }

        pub fn new_block_next(alloc: MockAlloc, count: usize) -> Self {
            let (tx, rx) = tokio::sync::broadcast::channel(1);
            Self {
                alloc,
                block_next_after: count,
                notify_tx: tx,
                notify_rx: rx,
                next_unblocked: false,
            }
        }

        pub fn notify_tx(&self) -> Sender<()> {
            self.notify_tx.clone()
        }
    }

    #[async_trait]
    impl Alloc for MockAllocWrapper {
        async fn next(&mut self) -> Option<ProcState> {
            match self.block_next_after {
                0 => {
                    if !self.next_unblocked {
                        self.notify_rx.recv().await.unwrap();
                        self.next_unblocked = true;
                    }
                }
                1.. => {
                    self.block_next_after -= 1;
                }
            }

            self.alloc.next().await
        }

        fn extent(&self) -> &Extent {
            self.alloc.extent()
        }

        fn world_id(&self) -> &WorldId {
            self.alloc.world_id()
        }

        fn transport(&self) -> ChannelTransport {
            self.alloc.transport()
        }

        async fn stop(&mut self) -> Result<(), AllocatorError> {
            self.alloc.stop().await
        }
    }
}

#[cfg(test)]
pub(crate) mod testing {
    use core::panic;
    use std::collections::HashMap;
    use std::collections::HashSet;
    use std::time::Duration;

    use hyperactor::Mailbox;
    use hyperactor::actor::remote::Remote;
    use hyperactor::channel;
    use hyperactor::mailbox;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::IntoBoxedMailboxSender;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::mailbox::UndeliverableMailboxSender;
    use hyperactor::proc::Proc;
    use hyperactor::reference::Reference;
    use ndslice::extent;
    use tokio::process::Command;

    use super::*;
    use crate::alloc::test_utils::TestActor;
    use crate::alloc::test_utils::Wait;
    use crate::proc_mesh::mesh_agent::GspawnResult;
    use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;

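    /// Expands to a `test_allocator_basic` test case for the given allocator
    /// expression, so each allocator implementation can instantiate the
    /// shared suite from its own test module.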
    #[macro_export]
    macro_rules! alloc_test_suite {
        ($allocator:expr_2021) => {
            #[tokio::test]
            async fn test_allocator_basic() {
                $crate::alloc::testing::test_allocator_basic($allocator).await;
            }
        };
    }

    pub(crate) async fn test_allocator_basic(mut allocator: impl Allocator) {
        let extent = extent!(replica = 4);
        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        // Get everything up into running state. We require that we get
        // procs 0..4.
        let mut procs = HashMap::new();
        let mut running = HashSet::new();
        while running.len() != 4 {
            match alloc.next().await.unwrap() {
                ProcState::Created { proc_id, point, .. } => {
                    procs.insert(proc_id, point);
                }
                ProcState::Running { proc_id, .. } => {
                    assert!(procs.contains_key(&proc_id));
                    assert!(!running.contains(&proc_id));
                    running.insert(proc_id);
                }
                event => panic!("unexpected event: {:?}", event),
            }
        }

        // We should have complete coverage of all points.
        let points: HashSet<_> = procs.values().collect();
        for x in 0..4 {
            assert!(points.contains(&extent.point(vec![x]).unwrap()));
        }

        // Every proc should belong to the same "world" (alloc).
        let worlds: HashSet<_> = procs.keys().map(|proc_id| proc_id.world_id()).collect();
        assert_eq!(worlds.len(), 1);

        // Now, stop the alloc and make sure it shuts down cleanly.

        alloc.stop().await.unwrap();
        let mut stopped = HashSet::new();
        while let Some(ProcState::Stopped { proc_id, reason }) = alloc.next().await {
            assert_eq!(reason, ProcStopReason::Stopped);
            stopped.insert(proc_id);
        }
        assert!(alloc.next().await.is_none());
        assert_eq!(stopped, running);
    }

    async fn spawn_proc(
        transport: ChannelTransport,
    ) -> (DialMailboxRouter, Mailbox, Proc, ChannelAddr) {
        let (router_channel_addr, router_rx) = channel::serve(ChannelAddr::any(transport.clone()))
            .await
            .unwrap();
        let router =
            DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed());
        router.clone().serve(router_rx);

        let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0);
        let (client_proc_addr, client_rx) =
            channel::serve(ChannelAddr::any(transport)).await.unwrap();
        let client_proc = Proc::new(
            client_proc_id.clone(),
            BoxedMailboxSender::new(router.clone()),
        );
        client_proc.clone().serve(client_rx);
        router.bind(client_proc_id.clone().into(), client_proc_addr);
        (
            router,
            client_proc.attach("test_proc").unwrap(),
            client_proc,
            router_channel_addr,
        )
    }

    async fn spawn_test_actor(
        rank: usize,
        client_proc: &Proc,
        client: &Mailbox,
        router_channel_addr: ChannelAddr,
        mesh_agent: ActorRef<MeshAgent>,
    ) -> ActorRef<TestActor> {
        let supervisor = client_proc.attach("supervisor").unwrap();
        let (supervision_port, _) = supervisor.open_port();
        let (config_handle, _) = client.open_port();
        mesh_agent
            .configure(
                client,
                rank,
                router_channel_addr,
                supervision_port.bind(),
                HashMap::new(),
                config_handle.bind(),
            )
            .await
            .unwrap();
        let remote = Remote::collect();
        let actor_type = remote
            .name_of::<TestActor>()
            .ok_or(anyhow::anyhow!("actor not registered"))
            .unwrap()
            .to_string();
        let params = &();
        let (completed_handle, mut completed_receiver) = mailbox::open_port(client);
        // gspawn the test actor and wait for the spawn result.
        mesh_agent
            .gspawn(
                client,
                actor_type,
                "Stuck".to_string(),
                bincode::serialize(params).unwrap(),
                completed_handle.bind(),
            )
            .await
            .unwrap();
        let result = completed_receiver.recv().await.unwrap();
        match result {
            GspawnResult::Success { actor_id, .. } => ActorRef::attest(actor_id),
            GspawnResult::Error(error_msg) => {
                panic!("gspawn failed: {}", error_msg);
            }
        }
    }

    /// In order to simulate stuckness, we need two things: an actor that is
    /// blocked forever, and a proc that does not time out when it is asked
    /// to wait for a stuck actor.
    #[tokio::test]
    async fn test_allocator_stuck_task() {
        // Use a temporary config override for this test.
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::PROCESS_EXIT_TIMEOUT,
            Duration::from_secs(1),
        );

        let command =
            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
        let mut allocator = ProcessAllocator::new(command);
        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: extent! { replica = 1 },
                constraints: Default::default(),
            })
            .await
            .unwrap();

        // Get everything up into running state.
        let mut procs = HashMap::new();
        let mut running = HashSet::new();
        let mut actor_ref = None;
        let (router, client, client_proc, router_addr) = spawn_proc(alloc.transport()).await;
        while running.is_empty() {
            match alloc.next().await.unwrap() {
                ProcState::Created { proc_id, point, .. } => {
                    procs.insert(proc_id, point);
                }
                ProcState::Running {
                    proc_id,
                    mesh_agent,
                    addr,
                } => {
                    router.bind(Reference::Proc(proc_id.clone()), addr.clone());

                    assert!(procs.contains_key(&proc_id));
                    assert!(!running.contains(&proc_id));

                    actor_ref = Some(
                        spawn_test_actor(0, &client_proc, &client, router_addr, mesh_agent).await,
                    );
                    running.insert(proc_id);
                    break;
                }
                event => panic!("unexpected event: {:?}", event),
            }
        }
        assert!(actor_ref.unwrap().send(&client, Wait).is_ok());

        // There is a stuck actor! We should get a watchdog failure.
        alloc.stop().await.unwrap();
        let mut stopped = HashSet::new();
        while let Some(ProcState::Stopped { proc_id, reason }) = alloc.next().await {
            assert_eq!(reason, ProcStopReason::Watchdog);
            stopped.insert(proc_id);
        }
        assert!(alloc.next().await.is_none());
        assert_eq!(stopped, running);
    }
}