hyperactor_mesh/
alloc.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module defines a proc allocator interface as well as a multi-process
10//! (local) allocator, [`ProcessAllocator`].
11
12pub mod local;
13pub mod process;
14pub mod remoteprocess;
15pub mod sim;
16
17use std::collections::HashMap;
18use std::fmt;
19use std::net::IpAddr;
20use std::net::Ipv4Addr;
21use std::net::Ipv6Addr;
22use std::net::SocketAddr;
23use std::net::TcpListener;
24use std::ops::Range;
25use std::sync::Mutex;
26use std::sync::OnceLock;
27use std::sync::atomic::AtomicUsize;
28use std::sync::atomic::Ordering;
29
30use async_trait::async_trait;
31use enum_as_inner::EnumAsInner;
32use hyperactor::ActorRef;
33use hyperactor::Named;
34use hyperactor::ProcId;
35use hyperactor::RemoteMessage;
36use hyperactor::WorldId;
37use hyperactor::attrs::declare_attrs;
38use hyperactor::channel;
39use hyperactor::channel::ChannelAddr;
40use hyperactor::channel::ChannelRx;
41use hyperactor::channel::ChannelTransport;
42use hyperactor::channel::MetaTlsAddr;
43use hyperactor::config;
44use hyperactor::config::CONFIG;
45use hyperactor::config::ConfigAttr;
46pub use local::LocalAlloc;
47pub use local::LocalAllocator;
48use mockall::predicate::*;
49use mockall::*;
50use ndslice::Shape;
51use ndslice::Slice;
52use ndslice::view::Extent;
53use ndslice::view::Point;
54pub use process::ProcessAlloc;
55pub use process::ProcessAllocator;
56use serde::Deserialize;
57use serde::Serialize;
58use strum::AsRefStr;
59
60use crate::alloc::test_utils::MockAllocWrapper;
61use crate::assign::Ranks;
62use crate::proc_mesh::mesh_agent::ProcMeshAgent;
63use crate::shortuuid::ShortUuid;
64
declare_attrs! {
    /// For Tcp channel types, if true, bind the IP address to INADDR_ANY
    /// (0.0.0.0 or [::]) for frontend ports.
    ///
    /// This config is useful in environments where we cannot bind the port to
    /// the given IP address. For example, in an AWS setting, it might not allow
    /// us to bind the port to the host's public IP address.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_REMOTE_ALLOC_BIND_TO_INADDR_ANY".to_string()),
        py_name: None,
    })
    pub attr REMOTE_ALLOC_BIND_TO_INADDR_ANY: bool = false;

    /// Specify the address alloc uses as its bootstrap address. e.g.:
    ///
    /// * "tcp:142.250.81.228:0" means serve at a random port with IPv4 address
    ///   142.250.81.228.
    /// * "tcp:[2401:db00:eef0:1120:3520:0:7812:4eca]:27001" means serve at port
    ///   27001 with IPv6 address 2401:db00:eef0:1120:3520:0:7812:4eca.
    ///
    /// The IP address must be an IP address of the host running the alloc.
    ///
    /// This config is useful when we want the alloc to use a particular IP
    /// address. For example, in an AWS setting, we might want to use the host's
    /// public IP address.
    // TODO: remove this env var, and make it part of alloc spec instead.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_REMOTE_ALLOC_BOOTSTRAP_ADDR".to_string()),
        py_name: None,
    })
    pub attr REMOTE_ALLOC_BOOTSTRAP_ADDR: String;

    /// For Tcp channel types, if set, only use ports in this range for the
    /// frontend ports. The input should be in the format "<start>..<end>",
    /// where <end> is exclusive. e.g.:
    ///
    /// * "26601..26611" means only use the 10 ports in the range [26601, 26610],
    ///   including 26601 and 26610.
    ///
    /// This config is useful in environments where only a certain range of
    /// ports is allowed to be used.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_REMOTE_ALLOC_ALLOWED_PORT_RANGE".to_string()),
        py_name: None,
    })
    pub attr REMOTE_ALLOC_ALLOWED_PORT_RANGE: Range<u16>;
}
112
/// Errors that occur during allocation operations.
#[derive(Debug, thiserror::Error)]
pub enum AllocatorError {
    /// The alloc's event stream ended before every rank was running.
    /// Carries the extent that was expected.
    #[error("incomplete allocation; expected: {0}")]
    Incomplete(Extent),

    /// The requested shape is too large for the allocator.
    #[error("not enough resources; requested: {requested}, available: {available}")]
    NotEnoughResources { requested: Extent, available: usize },

    /// An uncategorized error from an underlying system.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
127
/// Constraints on the allocation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AllocConstraints {
    /// Arbitrary name/value pairs that are interpreted by individual
    /// allocators to control the allocation process.
    pub match_labels: HashMap<String, String>,
}
135
136/// Specifies how to interpret the extent dimensions for allocation.
137#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
138pub enum ProcAllocationMode {
139    /// Proc-level allocation: splits extent to allocate multiple processes per host.
140    /// Requires at least 2 dimensions (e.g., [hosts: N, gpus: M]).
141    /// Splits by second-to-last dimension, creating N regions with M processes each.
142    /// Used by MastAllocator.
143    ProcLevel,
144    /// Host-level allocation: each point in the extent is a host (no sub-host splitting).
145    /// For extent!(region = 2, host = 4), create 8 regions, each representing 1 host.
146    /// Used by MastHostAllocator.
147    HostLevel,
148}
149
150impl Default for ProcAllocationMode {
151    fn default() -> Self {
152        // Default to ProcLevel for backward compatibility
153        Self::ProcLevel
154    }
155}
156
/// A specification (desired state) of an alloc.
///
/// This is what callers hand to [`Allocator::allocate`], and what an
/// [`Alloc`] reports back through [`Alloc::spec`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocSpec {
    /// The requested extent of the alloc.
    // We currently assume that this shape is dense.
    // This should be validated, or even enforced by
    // way of types.
    pub extent: Extent,

    /// Constraints on the allocation.
    pub constraints: AllocConstraints,

    /// If specified, return procs using direct addressing with
    /// the provided proc name.
    pub proc_name: Option<String>,

    /// The transport to use for the procs in this alloc.
    pub transport: ChannelTransport,

    /// Specifies how to interpret the extent dimensions for allocation.
    /// Defaults to ProcLevel for backward compatibility (the serde default
    /// applies when the field is missing from the serialized form).
    #[serde(default = "default_proc_allocation_mode")]
    pub proc_allocation_mode: ProcAllocationMode,
}
181
182fn default_proc_allocation_mode() -> ProcAllocationMode {
183    ProcAllocationMode::ProcLevel
184}
185
/// The core allocator trait, implemented by all allocators.
// `automock` generates a mock implementation of this trait for tests; the
// mock's associated `Alloc` type is pinned to `MockAllocWrapper`.
#[automock(type Alloc=MockAllocWrapper;)]
#[async_trait]
pub trait Allocator {
    /// The type of [`Alloc`] produced by this allocator.
    type Alloc: Alloc;

    /// Create a new allocation. The allocation itself is generally
    /// returned immediately (after validating parameters, etc.);
    /// the caller is expected to respond to allocation events as
    /// the underlying procs are incrementally allocated.
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if the allocation cannot be created.
    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError>;
}
199
/// A proc's status. A proc can only monotonically move from
/// `Created` to `Running` to `Stopped`.
///
/// `Failed` is not tied to a single proc: it reports that the allocation as a
/// whole hit an irrecoverable error.
#[derive(
    Clone,
    Debug,
    PartialEq,
    EnumAsInner,
    Serialize,
    Deserialize,
    AsRefStr,
    Named
)]
pub enum ProcState {
    /// A proc was added to the alloc.
    Created {
        /// A key to uniquely identify a created proc. The key is used again
        /// to identify the created proc as Running.
        create_key: ShortUuid,
        /// Its assigned point (in the alloc's extent).
        point: Point,
        /// The system process ID of the created child process.
        pid: u32,
    },
    /// A proc was started.
    Running {
        /// The key used to identify the created proc.
        create_key: ShortUuid,
        /// The proc's assigned ID.
        proc_id: ProcId,
        /// Reference to this proc's mesh agent. In the future, we'll reserve a
        /// 'well known' PID (0) for this purpose.
        mesh_agent: ActorRef<ProcMeshAgent>,
        /// The address of this proc. The endpoint of this address is
        /// the proc's mailbox, which accepts [`hyperactor::mailbox::MessageEnvelope`]s.
        addr: ChannelAddr,
    },
    /// A proc was stopped.
    Stopped {
        /// The key used to identify the created proc.
        create_key: ShortUuid,
        /// Why the proc stopped.
        reason: ProcStopReason,
    },
    /// Allocation process encountered an irrecoverable error. Depending on the
    /// implementation, the allocation process may continue transiently and calls
    /// to next() may return some events. But eventually the allocation will not
    /// be complete. Callers can use the `description` to determine the reason for
    /// the failure.
    /// Allocation can then be cleaned up by calling `stop()` on the `Alloc` and
    /// draining the iterator for clean shutdown.
    Failed {
        /// The world ID of the failed alloc.
        ///
        /// TODO: this is not meaningful with direct addressing.
        world_id: WorldId,
        /// A description of the failure.
        description: String,
    },
}
257
258impl fmt::Display for ProcState {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        match self {
261            ProcState::Created {
262                create_key,
263                point,
264                pid,
265            } => {
266                write!(f, "{}: created at ({}) with PID {}", create_key, point, pid)
267            }
268            ProcState::Running { proc_id, addr, .. } => {
269                write!(f, "{}: running at {}", proc_id, addr)
270            }
271            ProcState::Stopped { create_key, reason } => {
272                write!(f, "{}: stopped: {}", create_key, reason)
273            }
274            ProcState::Failed {
275                description,
276                world_id,
277            } => {
278                write!(f, "{}: failed: {}", world_id, description)
279            }
280        }
281    }
282}
283
/// The reason a proc stopped.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, EnumAsInner)]
pub enum ProcStopReason {
    /// The proc stopped gracefully, e.g., with exit code 0.
    Stopped,
    /// The proc exited with the provided error code and stderr output.
    Exited(i32, String),
    /// The proc was killed. The signal number is indicated;
    /// the flag indicates whether there was a core dump.
    Killed(i32, bool),
    /// The proc failed to respond to a watchdog request within a timeout.
    Watchdog,
    /// The host running the proc failed to respond to a watchdog request
    /// within a timeout.
    HostWatchdog,
    /// The proc failed for an unknown reason.
    Unknown,
}
302
303impl fmt::Display for ProcStopReason {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        match self {
306            Self::Stopped => write!(f, "stopped"),
307            Self::Exited(code, stderr) => {
308                if stderr.is_empty() {
309                    write!(f, "exited with code {}", code)
310                } else {
311                    write!(f, "exited with code {}: {}", code, stderr)
312                }
313            }
314            Self::Killed(signal, dumped) => {
315                write!(f, "killed with signal {} (core dumped={})", signal, dumped)
316            }
317            Self::Watchdog => write!(f, "proc watchdog failure"),
318            Self::HostWatchdog => write!(f, "host watchdog failure"),
319            Self::Unknown => write!(f, "unknown"),
320        }
321    }
322}
323
/// An alloc is a specific allocation, returned by an [`Allocator`].
// `automock` generates `MockAlloc` from this trait for use in tests.
#[automock]
#[async_trait]
pub trait Alloc {
    /// Return the next proc event. `None` indicates that there are
    /// no more events, and that the alloc is stopped.
    async fn next(&mut self) -> Option<ProcState>;

    /// The spec against which this alloc is executing.
    fn spec(&self) -> &AllocSpec;

    /// The extent of the alloc.
    fn extent(&self) -> &Extent;

    /// The shape of the alloc. (Deprecated.)
    ///
    /// Derived from [`Alloc::extent`]: a row-major slice over the extent's
    /// sizes, labeled with the extent's labels.
    fn shape(&self) -> Shape {
        let slice = Slice::new_row_major(self.extent().sizes());
        Shape::new(self.extent().labels().to_vec(), slice).unwrap()
    }

    /// The world id of this alloc, uniquely identifying the alloc.
    /// Note: This will be removed in favor of a different naming scheme,
    /// once we excise "worlds" from hyperactor core.
    fn world_id(&self) -> &WorldId;

    /// The channel transport used by the procs in this alloc.
    fn transport(&self) -> ChannelTransport {
        self.spec().transport.clone()
    }

    /// Stop this alloc, shutting down all of its procs. A clean
    /// shutdown should result in Stop events from all allocs,
    /// followed by the end of the event stream.
    async fn stop(&mut self) -> Result<(), AllocatorError>;

    /// Stop this alloc and wait for all procs to stop. Call will
    /// block until all ProcState events have been drained.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Alloc::stop`]; in that case the
    /// event stream is not drained.
    async fn stop_and_wait(&mut self) -> Result<(), AllocatorError> {
        // NOTE(review): the two status events below are emitted at `error`
        // level although they are not failures — presumably for visibility;
        // confirm this is intentional.
        tracing::error!(
            name = "AllocStatus",
            alloc_name = %self.world_id(),
            status = "StopAndWait",
        );
        self.stop().await?;
        // Drain the remaining events until the stream ends.
        while let Some(event) = self.next().await {
            tracing::debug!(
                alloc_name = %self.world_id(),
                "drained event: {event:?}"
            );
        }
        tracing::error!(
            name = "AllocStatus",
            alloc_name = %self.world_id(),
            status = "Stopped",
        );
        Ok(())
    }

    /// Returns whether the alloc is a local alloc: that is, its procs are
    /// not independent processes, but just threads in the selfsame process.
    /// Defaults to `false`.
    fn is_local(&self) -> bool {
        false
    }

    /// The address that should be used to serve the client's router.
    /// By default, this is an "any" address for the alloc's transport.
    fn client_router_addr(&self) -> ChannelAddr {
        ChannelAddr::any(self.transport())
    }
}
393
/// A proc that has reached the running state, as collected by
/// [`AllocExt::initialize`]. The fields mirror those of
/// [`ProcState::Running`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct AllocatedProc {
    /// The key under which the proc was created.
    pub create_key: ShortUuid,
    /// The proc's assigned ID.
    pub proc_id: ProcId,
    /// The address of the proc.
    pub addr: ChannelAddr,
    /// Reference to the proc's mesh agent.
    pub mesh_agent: ActorRef<ProcMeshAgent>,
}
401
402impl fmt::Display for AllocatedProc {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        write!(
405            f,
406            "AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, mesh_agent: {} }}",
407            self.create_key, self.proc_id, self.addr, self.mesh_agent
408        )
409    }
410}
411
/// Crate-internal extension methods available on every [`Alloc`].
#[async_trait]
pub(crate) trait AllocExt {
    /// Perform initial allocation, consuming events until the alloc is fully
    /// running. Returns the ranked procs.
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if the alloc cannot reach the fully
    /// running state.
    async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError>;
}
418
419#[async_trait]
420impl<A: ?Sized + Send + Alloc> AllocExt for A {
421    async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError> {
422        // We wait for the full allocation to be running before returning the mesh.
423        let shape = self.shape().clone();
424
425        let mut created = Ranks::new(shape.slice().len());
426        let mut running = Ranks::new(shape.slice().len());
427
428        while !running.is_full() {
429            let Some(state) = self.next().await else {
430                // Alloc finished before it was fully allocated.
431                return Err(AllocatorError::Incomplete(self.extent().clone()));
432            };
433
434            let name = tracing::Span::current()
435                .metadata()
436                .map(|m| m.name())
437                .unwrap_or("initialize");
438            let status = format!("ProcState:{}", state.arm().unwrap_or("unknown"));
439
440            match state {
441                ProcState::Created {
442                    create_key, point, ..
443                } => {
444                    let rank = point.rank();
445                    if let Some(old_create_key) = created.insert(rank, create_key.clone()) {
446                        tracing::warn!(
447                            name,
448                            status,
449                            rank,
450                            "rank {rank} reassigned from {old_create_key} to {create_key}"
451                        );
452                    }
453                    tracing::info!(
454                        name,
455                        status,
456                        rank,
457                        "proc with create key {}, rank {}: created",
458                        create_key,
459                        rank
460                    );
461                }
462                ProcState::Running {
463                    create_key,
464                    proc_id,
465                    mesh_agent,
466                    addr,
467                } => {
468                    let Some(rank) = created.rank(&create_key) else {
469                        tracing::warn!(
470                            name,
471                            %proc_id,
472                            status,
473                            "proc id {proc_id} with create key {create_key} \
474                            is running, but was not created"
475                        );
476                        continue;
477                    };
478
479                    let allocated_proc = AllocatedProc {
480                        create_key,
481                        proc_id: proc_id.clone(),
482                        addr: addr.clone(),
483                        mesh_agent: mesh_agent.clone(),
484                    };
485                    if let Some(old_allocated_proc) = running.insert(*rank, allocated_proc.clone())
486                    {
487                        tracing::warn!(
488                            name,
489                            %proc_id,
490                            status,
491                            rank,
492                            "duplicate running notifications for {rank}: \
493                            old:{old_allocated_proc}; \
494                            new:{allocated_proc}"
495                        )
496                    }
497                    tracing::info!(
498                        name,
499                        %proc_id,
500                        status,
501                        "proc {} rank {}: running at addr:{addr} mesh_agent:{mesh_agent}",
502                        proc_id,
503                        rank
504                    );
505                }
506                // TODO: We should push responsibility to the allocator, which
507                // can choose to either provide a new proc or emit a
508                // ProcState::Failed to fail the whole allocation.
509                ProcState::Stopped { create_key, reason } => {
510                    tracing::error!(
511                        name,
512                        status,
513                        "allocation failed for proc with create key {}: {}",
514                        create_key,
515                        reason
516                    );
517                    return Err(AllocatorError::Other(anyhow::Error::msg(reason)));
518                }
519                ProcState::Failed {
520                    world_id,
521                    description,
522                } => {
523                    tracing::error!(
524                        name,
525                        status,
526                        "allocation failed for world {}: {}",
527                        world_id,
528                        description
529                    );
530                    return Err(AllocatorError::Other(anyhow::Error::msg(description)));
531                }
532            }
533        }
534
535        // We collect all the ranks at this point of completion, so that we can
536        // avoid holding Rcs across awaits.
537        Ok(running.into_iter().map(Option::unwrap).collect())
538    }
539}
540
541/// If addr is Tcp or Metatls, use its IP address or hostname to create
542/// a new addr with port unspecified.
543///
544/// for other types of addr, return "any" address.
545pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> ChannelAddr {
546    match addr {
547        ChannelAddr::Tcp(socket) => {
548            let mut new_socket = socket.clone();
549            new_socket.set_port(0);
550            ChannelAddr::Tcp(new_socket)
551        }
552        ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => {
553            let mut new_socket = socket.clone();
554            new_socket.set_port(0);
555            ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket))
556        }
557        ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => {
558            ChannelAddr::MetaTls(MetaTlsAddr::Host {
559                hostname: hostname.clone(),
560                port: 0,
561            })
562        }
563        _ => addr.transport().any(),
564    }
565}
566
567pub(crate) fn serve_with_config<M: RemoteMessage>(
568    mut serve_addr: ChannelAddr,
569) -> anyhow::Result<(ChannelAddr, ChannelRx<M>)> {
570    fn set_as_inaddr_any(original: &mut SocketAddr) {
571        let inaddr_any: IpAddr = match &original {
572            SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(),
573            SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(),
574        };
575        original.set_ip(inaddr_any);
576    }
577
578    let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY);
579    let mut original_ip: Option<IpAddr> = None;
580    match &mut serve_addr {
581        ChannelAddr::Tcp(socket) => {
582            original_ip = Some(socket.ip().clone());
583            if use_inaddr_any {
584                set_as_inaddr_any(socket);
585                tracing::debug!("binding {} to INADDR_ANY", original_ip.as_ref().unwrap(),);
586            }
587            if socket.port() == 0 {
588                socket.set_port(next_allowed_port(socket.ip().clone())?);
589            }
590        }
591        _ => {
592            if use_inaddr_any {
593                tracing::debug!(
594                    "can only bind to INADDR_ANY for TCP; got transport {}, addr {}",
595                    serve_addr.transport(),
596                    serve_addr
597                );
598            }
599        }
600    };
601
602    let (mut bound, rx) = channel::serve(serve_addr)?;
603
604    // Restore the original IP address if we used INADDR_ANY.
605    match &mut bound {
606        ChannelAddr::Tcp(socket) => {
607            if use_inaddr_any {
608                socket.set_ip(original_ip.unwrap());
609            }
610        }
611        _ => (),
612    }
613
614    Ok((bound, rx))
615}
616
/// The set of ports that may be handed out for frontend TCP sockets.
enum AllowedPorts {
    /// Only the listed ports may be used. `next` is a running counter used to
    /// cycle through `range`.
    Config { range: Vec<u16>, next: AtomicUsize },
    /// No restriction.
    Any,
}
621
622impl AllowedPorts {
623    fn next(&self, ip: IpAddr) -> anyhow::Result<u16> {
624        match self {
625            Self::Config { range, next } => {
626                let mut count = 0;
627                loop {
628                    let i = next.fetch_add(1, Ordering::Relaxed);
629                    count += 1;
630                    // Since we do not have a good way to put release ports back to the list,
631                    // we opportunistically hope ports previously took already released. If
632                    // not, we'll just see error when binding to it later. This
633                    // is not much different from raising error here.
634                    let port = range.get(i % range.len()).cloned().unwrap();
635                    let socket = SocketAddr::new(ip, port);
636                    if TcpListener::bind(socket).is_ok() {
637                        tracing::debug!("taking port {port} from the allowed list",);
638                        return Ok(port);
639                    }
640                    if count == range.len() {
641                        anyhow::bail!(
642                            "fail to find a port because all ports in the allowed list are already bound"
643                        );
644                    }
645                }
646            }
647            Self::Any => Ok(0),
648        }
649    }
650}
651
652static ALLOWED_PORTS: OnceLock<Mutex<AllowedPorts>> = OnceLock::new();
653fn next_allowed_port(ip: IpAddr) -> anyhow::Result<u16> {
654    let mutex = ALLOWED_PORTS.get_or_init(|| {
655        let ports = match config::global::try_get_cloned(REMOTE_ALLOC_ALLOWED_PORT_RANGE) {
656            Some(range) => AllowedPorts::Config {
657                range: range.into_iter().collect(),
658                next: AtomicUsize::new(0),
659            },
660            None => AllowedPorts::Any,
661        };
662        Mutex::new(ports)
663    });
664    mutex.lock().unwrap().next(ip)
665}
666
/// Test helpers: an actor whose host process never exits once asked to
/// `Wait`, and a wrapper around the generated `MockAlloc` that can block
/// `next()` calls.
pub mod test_utils {
    use std::time::Duration;

    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use libc::atexit;
    use tokio::sync::broadcast::Receiver;
    use tokio::sync::broadcast::Sender;

    use super::*;

    /// `atexit` callback that never returns: it sleeps in 60-second
    /// intervals forever.
    extern "C" fn exit_handler() {
        loop {
            #[allow(clippy::disallowed_methods)]
            std::thread::sleep(Duration::from_secs(60));
        }
    }

    // This can't be defined under a `#[cfg(test)]` because there needs to
    // be an entry in the spawnable actor registry in the executable
    // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor
    // mesh test suite.
    /// Test actor that handles [`Wait`].
    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            Wait
        ],
    )]
    pub struct TestActor;

    /// Message asking a [`TestActor`] to install an exit handler that never
    /// returns.
    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
    pub struct Wait;

    #[async_trait]
    impl Handler<Wait> for TestActor {
        /// Registers `exit_handler` with `atexit` and returns immediately.
        async fn handle(&mut self, _: &Context<Self>, _: Wait) -> Result<(), anyhow::Error> {
            // SAFETY:
            // This is in order to simulate a process in tests that never exits.
            unsafe {
                atexit(exit_handler);
            }
            Ok(())
        }
    }

    /// Test wrapper around MockAlloc to allow us to block next() calls since
    /// mockall doesn't support returning futures.
    pub struct MockAllocWrapper {
        /// The wrapped mock; all `Alloc` calls are forwarded to it.
        pub alloc: MockAlloc,
        /// Number of `next()` calls that may still pass through before
        /// `next()` blocks waiting for a notification.
        pub block_next_after: usize,
        // Sender half handed out via `notify_tx()` to unblock `next()`.
        notify_tx: Sender<()>,
        // Receiver half awaited by `next()` when it is blocked.
        notify_rx: Receiver<()>,
        // Set once a notification has been received; `next()` no longer
        // blocks afterwards.
        next_unblocked: bool,
    }

    impl MockAllocWrapper {
        /// Wrap `alloc` with a `usize::MAX` budget of unblocked `next()` calls.
        pub fn new(alloc: MockAlloc) -> Self {
            Self::new_block_next(alloc, usize::MAX)
        }

        /// Wrap `alloc` so that `next()` passes through `count` times and then
        /// blocks until a message is sent on [`MockAllocWrapper::notify_tx`].
        pub fn new_block_next(alloc: MockAlloc, count: usize) -> Self {
            let (tx, rx) = tokio::sync::broadcast::channel(1);
            Self {
                alloc,
                block_next_after: count,
                notify_tx: tx,
                notify_rx: rx,
                next_unblocked: false,
            }
        }

        /// A sender that can be used to unblock a blocked `next()` call.
        pub fn notify_tx(&self) -> Sender<()> {
            self.notify_tx.clone()
        }
    }

    #[async_trait]
    impl Alloc for MockAllocWrapper {
        /// Forward to the mock's `next()`, first waiting for a notification if
        /// the pass-through budget is exhausted and we have not been unblocked.
        async fn next(&mut self) -> Option<ProcState> {
            match self.block_next_after {
                0 => {
                    // Budget exhausted: block once, then stay unblocked.
                    if !self.next_unblocked {
                        self.notify_rx.recv().await.unwrap();
                        self.next_unblocked = true;
                    }
                }
                1.. => {
                    self.block_next_after -= 1;
                }
            }

            self.alloc.next().await
        }

        fn spec(&self) -> &AllocSpec {
            self.alloc.spec()
        }

        fn extent(&self) -> &Extent {
            self.alloc.extent()
        }

        fn world_id(&self) -> &WorldId {
            self.alloc.world_id()
        }

        async fn stop(&mut self) -> Result<(), AllocatorError> {
            self.alloc.stop().await
        }
    }
}
781
782#[cfg(test)]
783pub(crate) mod testing {
784    use core::panic;
785    use std::collections::HashMap;
786    use std::collections::HashSet;
787    use std::time::Duration;
788
789    use hyperactor::Instance;
790    use hyperactor::actor::remote::Remote;
791    use hyperactor::channel;
792    use hyperactor::context;
793    use hyperactor::mailbox;
794    use hyperactor::mailbox::BoxedMailboxSender;
795    use hyperactor::mailbox::DialMailboxRouter;
796    use hyperactor::mailbox::IntoBoxedMailboxSender;
797    use hyperactor::mailbox::MailboxServer;
798    use hyperactor::mailbox::UndeliverableMailboxSender;
799    use hyperactor::proc::Proc;
800    use hyperactor::reference::Reference;
801    use ndslice::extent;
802    use tokio::process::Command;
803
804    use super::*;
805    use crate::alloc::test_utils::TestActor;
806    use crate::alloc::test_utils::Wait;
807    use crate::proc_mesh::default_transport;
808    use crate::proc_mesh::mesh_agent::GspawnResult;
809    use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
810
    /// Instantiate the shared allocator test suite for `$allocator`.
    ///
    /// Expands to a `#[tokio::test]` named `test_allocator_basic` that passes
    /// the given allocator expression to
    /// `$crate::alloc::testing::test_allocator_basic`.
    #[macro_export]
    macro_rules! alloc_test_suite {
        ($allocator:expr) => {
            #[tokio::test]
            async fn test_allocator_basic() {
                $crate::alloc::testing::test_allocator_basic($allocator).await;
            }
        };
    }
820
821    pub(crate) async fn test_allocator_basic(mut allocator: impl Allocator) {
822        let extent = extent!(replica = 4);
823        let mut alloc = allocator
824            .allocate(AllocSpec {
825                extent: extent.clone(),
826                constraints: Default::default(),
827                proc_name: None,
828                transport: default_transport(),
829                proc_allocation_mode: Default::default(),
830            })
831            .await
832            .unwrap();
833
834        // Get everything up into running state. We require that we get
835        // procs 0..4.
836        let mut procs = HashMap::new();
837        let mut created = HashMap::new();
838        let mut running = HashSet::new();
839        while running.len() != 4 {
840            match alloc.next().await.unwrap() {
841                ProcState::Created {
842                    create_key, point, ..
843                } => {
844                    created.insert(create_key, point);
845                }
846                ProcState::Running {
847                    create_key,
848                    proc_id,
849                    ..
850                } => {
851                    assert!(running.insert(create_key.clone()));
852                    procs.insert(proc_id, created.remove(&create_key).unwrap());
853                }
854                event => panic!("unexpected event: {:?}", event),
855            }
856        }
857
858        // We should have complete coverage of all points.
859        let points: HashSet<_> = procs.values().collect();
860        for x in 0..4 {
861            assert!(points.contains(&extent.point(vec![x]).unwrap()));
862        }
863
864        // Every proc should belong to the same "world" (alloc).
865        let worlds: HashSet<_> = procs.keys().map(|proc_id| proc_id.world_id()).collect();
866        assert_eq!(worlds.len(), 1);
867
868        // Now, stop the alloc and make sure it shuts down cleanly.
869
870        alloc.stop().await.unwrap();
871        let mut stopped = HashSet::new();
872        while let Some(ProcState::Stopped {
873            create_key, reason, ..
874        }) = alloc.next().await
875        {
876            assert_eq!(reason, ProcStopReason::Stopped);
877            stopped.insert(create_key);
878        }
879        assert!(alloc.next().await.is_none());
880        assert_eq!(stopped, running);
881    }
882
883    async fn spawn_proc(
884        transport: ChannelTransport,
885    ) -> (DialMailboxRouter, Instance<()>, Proc, ChannelAddr) {
886        let (router_channel_addr, router_rx) =
887            channel::serve(ChannelAddr::any(transport.clone())).unwrap();
888        let router =
889            DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed());
890        router.clone().serve(router_rx);
891
892        let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0);
893        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(transport)).unwrap();
894        let client_proc = Proc::new(
895            client_proc_id.clone(),
896            BoxedMailboxSender::new(router.clone()),
897        );
898        client_proc.clone().serve(client_rx);
899        router.bind(client_proc_id.clone().into(), client_proc_addr);
900        (
901            router,
902            client_proc.instance("test_proc").unwrap().0,
903            client_proc,
904            router_channel_addr,
905        )
906    }
907
908    async fn spawn_test_actor(
909        rank: usize,
910        client_proc: &Proc,
911        cx: &impl context::Actor,
912        router_channel_addr: ChannelAddr,
913        mesh_agent: ActorRef<ProcMeshAgent>,
914    ) -> ActorRef<TestActor> {
915        let (supervisor, _supervisor_handle) = client_proc.instance("supervisor").unwrap();
916        let (supervison_port, _) = supervisor.open_port();
917        let (config_handle, _) = cx.mailbox().open_port();
918        mesh_agent
919            .configure(
920                cx,
921                rank,
922                router_channel_addr,
923                Some(supervison_port.bind()),
924                HashMap::new(),
925                config_handle.bind(),
926                false,
927            )
928            .await
929            .unwrap();
930        let remote = Remote::collect();
931        let actor_type = remote
932            .name_of::<TestActor>()
933            .ok_or(anyhow::anyhow!("actor not registered"))
934            .unwrap()
935            .to_string();
936        let params = &();
937        let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
938        // gspawn actor
939        mesh_agent
940            .gspawn(
941                cx,
942                actor_type,
943                "Stuck".to_string(),
944                bincode::serialize(params).unwrap(),
945                completed_handle.bind(),
946            )
947            .await
948            .unwrap();
949        let result = completed_receiver.recv().await.unwrap();
950        match result {
951            GspawnResult::Success { actor_id, .. } => ActorRef::attest(actor_id),
952            GspawnResult::Error(error_msg) => {
953                panic!("gspawn failed: {}", error_msg);
954            }
955        }
956    }
957
958    /// In order to simulate stuckness, we have to do two things:
959    /// An actor that is blocked forever AND
960    /// a proc that does not time out when it is asked to wait for
961    /// a stuck actor.
962    #[tokio::test]
963    #[cfg(fbcode_build)]
964    async fn test_allocator_stuck_task() {
965        // Override config.
966        // Use temporary config for this test
967        let config = hyperactor::config::global::lock();
968        let _guard = config.override_key(
969            hyperactor::config::PROCESS_EXIT_TIMEOUT,
970            Duration::from_secs(1),
971        );
972
973        let command = Command::new(crate::testresource::get(
974            "monarch/hyperactor_mesh/bootstrap",
975        ));
976        let mut allocator = ProcessAllocator::new(command);
977        let mut alloc = allocator
978            .allocate(AllocSpec {
979                extent: extent! { replica = 1 },
980                constraints: Default::default(),
981                proc_name: None,
982                transport: ChannelTransport::Unix,
983                proc_allocation_mode: Default::default(),
984            })
985            .await
986            .unwrap();
987
988        // Get everything up into running state. We require that we get
989        let mut procs = HashMap::new();
990        let mut running = HashSet::new();
991        let mut actor_ref = None;
992        let (router, client, client_proc, router_addr) = spawn_proc(alloc.transport()).await;
993        while running.is_empty() {
994            match alloc.next().await.unwrap() {
995                ProcState::Created {
996                    create_key, point, ..
997                } => {
998                    procs.insert(create_key, point);
999                }
1000                ProcState::Running {
1001                    create_key,
1002                    proc_id,
1003                    mesh_agent,
1004                    addr,
1005                } => {
1006                    router.bind(Reference::Proc(proc_id.clone()), addr.clone());
1007
1008                    assert!(procs.contains_key(&create_key));
1009                    assert!(!running.contains(&create_key));
1010
1011                    actor_ref = Some(
1012                        spawn_test_actor(0, &client_proc, &client, router_addr, mesh_agent).await,
1013                    );
1014                    running.insert(create_key.clone());
1015                    break;
1016                }
1017                event => panic!("unexpected event: {:?}", event),
1018            }
1019        }
1020        assert!(actor_ref.unwrap().send(&client, Wait).is_ok());
1021
1022        // There is a stuck actor! We should get a watchdog failure.
1023        alloc.stop().await.unwrap();
1024        let mut stopped = HashSet::new();
1025        while let Some(ProcState::Stopped {
1026            create_key, reason, ..
1027        }) = alloc.next().await
1028        {
1029            assert_eq!(reason, ProcStopReason::Watchdog);
1030            stopped.insert(create_key);
1031        }
1032        assert!(alloc.next().await.is_none());
1033        assert_eq!(stopped, running);
1034    }
1035}