hyperactor/
simnet.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)]
10
11//! A simulator capable of simulating Hyperactor's network channels (see: [`channel`]).
12//! The simulator can simulate message delivery delays and failures, and is used for
13//! testing and development of message distribution techniques.
14
15use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::sync::oneshot;
42use tokio::task::JoinError;
43use tokio::task::JoinHandle;
44use tokio::time::interval;
45
46use crate::ActorId;
47use crate::ProcId;
48use crate::channel::ChannelAddr;
49use crate::clock::Clock;
50use crate::clock::RealClock;
51use crate::clock::SimClock;
52
/// Process-wide simulator handle. It is populated once by
/// `start_with_config` (through `get_or_init`) and read by `simnet_handle`.
static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
54
55/// A handle for SimNet through which you can send and schedule events in the
56/// network.
57///
58/// Return the \[`NotStarted`\] error when called before `simnet::start()` has been called
59#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
60pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
61    match HANDLE.get() {
62        Some(handle) => Ok(handle),
63        None => Err(SimNetError::Closed("SimNet not started".to_string())),
64    }
65}
66
/// Presumably the capacity of a buffer for simulator operational messages.
// NOTE(review): this constant is not referenced anywhere in the visible part
// of this file; confirm whether it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
68
/// The address type of the simulated network; addresses are bound to nodes
/// in the network.
///
/// This is a marker trait with no methods of its own. The blanket
/// implementation below makes every type that is hashable, debuggable,
/// totally ordered, comparable for equality and cloneable an `Address`.
pub trait Address
where
    Self: Clone + Debug + Hash + Eq + PartialEq + Ord + PartialOrd,
{
}

impl<A> Address for A where A: Clone + Debug + Hash + Eq + PartialEq + Ord + PartialOrd {}
73
/// A point in time as seen by the simulator. This is an alias for
/// [`tokio::time::Instant`] and is the key type of the scheduled-event queues.
pub type SimulatorTimeInstant = tokio::time::Instant;
76
/// The unit of execution for the simulator.
///
/// Events are handed to the simulator through a [`SimNetHandle`]. Unless an
/// explicit time is supplied, an event is scheduled [`Event::duration`] after
/// the current simulated time, and [`Event::handle_network`] is invoked once
/// the simulator advances to that instant.
///
/// If you want to send a message for example, you would want to implement
/// a MessageDeliveryEvent much on the lines expressed in simnet tests.
/// You can also model more advanced concepts such as node churn,
/// or even simulate process spawns in a distributed system. For example,
/// one can implement a SystemActorSimEvent in order to spawn a system
/// actor.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// This is the method that will be called when the simulator fires the event
    /// at a particular time instant. Examples:
    /// For messages, it will be delivering the message to the dst's receiver queue.
    /// For a proc spawn, it will be creating the proc object and instantiating it.
    /// For any event that manipulates the network (like adding/removing nodes etc.)
    /// implement handle_network().
    async fn handle(&self) -> Result<(), SimNetError>;

    /// This is the method the simulator actually calls when it fires the event.
    /// The default implementation simply delegates to [`Event::handle`].
    /// Unless you need to make changes to the network, you do not have to implement this.
    /// Only implement handle() method for all non-simnet requirements.
    ///
    /// Note that the event loop stops when this returns an error.
    async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// The latency of the event. This could be network latency, induced latency (sleep), or
    /// GPU work latency.
    fn duration(&self) -> tokio::time::Duration;

    /// A user-friendly summary of the event. It is stored in the
    /// [`SimulatorEventRecord`] created when the event is scheduled.
    fn summary(&self) -> String;
}
109
/// This is a simple event that is used to join a node to the network.
/// It is created by `SimNetHandle::bind` for the address being bound.
#[derive(Debug)]
struct NodeJoinEvent {
    /// The channel address being bound.
    // NOTE(review): the visible `Event` impl never reads this field.
    channel_addr: ChannelAddr,
}
116
117#[async_trait]
118impl Event for NodeJoinEvent {
119    async fn handle(&self) -> Result<(), SimNetError> {
120        Ok(())
121    }
122
123    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
124        self.handle().await
125    }
126
127    fn duration(&self) -> tokio::time::Duration {
128        tokio::time::Duration::ZERO
129    }
130
131    fn summary(&self) -> String {
132        "Node join".into()
133    }
134}
135
/// A pytorch operation
#[derive(Debug)]
pub struct TorchOpEvent {
    /// Name of the operation, shown in the event summary.
    op: String,
    /// One-shot sender that is fired (at most once) when the event is handled.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Pre-rendered positional arguments, shown in the event summary.
    args_string: String,
    /// Pre-rendered keyword arguments, shown in the event summary when non-empty.
    kwargs_string: String,
    /// Id printed as the `[...]` prefix of the event summary.
    worker_actor_id: ActorId,
}
145
146#[async_trait]
147impl Event for TorchOpEvent {
148    async fn handle(&self) -> Result<(), SimNetError> {
149        Ok(())
150    }
151
152    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
153        let mut guard = self.done_tx.lock().await;
154        match guard.take() {
155            Some(done_tx) => {
156                let _ = done_tx
157                    .send(())
158                    .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
159            }
160            None => {
161                return Err(SimNetError::Closed("already sent once".to_string()));
162            }
163        }
164        Ok(())
165    }
166
167    fn duration(&self) -> tokio::time::Duration {
168        tokio::time::Duration::from_millis(100)
169    }
170
171    fn summary(&self) -> String {
172        let kwargs_string = if self.kwargs_string.is_empty() {
173            "".to_string()
174        } else {
175            format!(", {}", self.kwargs_string)
176        };
177        format!(
178            "[{}] Torch Op: {}({}{})",
179            self.worker_actor_id, self.op, self.args_string, kwargs_string
180        )
181    }
182}
183
184impl TorchOpEvent {
185    /// Creates a new TorchOpEvent.
186    pub fn new(
187        op: String,
188        done_tx: oneshot::Sender<()>,
189        args_string: String,
190        kwargs_string: String,
191        worker_actor_id: ActorId,
192    ) -> Box<Self> {
193        Box::new(Self {
194            op,
195            done_tx: Arc::new(Mutex::new(Some(done_tx))),
196            args_string,
197            kwargs_string,
198            worker_actor_id,
199        })
200    }
201}
202
/// An event that models a sleep: it fires after `duration` of simulated time
/// and then signals `done_tx`.
#[derive(Debug)]
pub(crate) struct SleepEvent {
    /// One-shot sender that is fired (at most once) when the event is handled.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// How long the sleep lasts; reported as the event's duration.
    duration: tokio::time::Duration,
}
208
209impl SleepEvent {
210    pub(crate) fn new(done_tx: oneshot::Sender<()>, duration: tokio::time::Duration) -> Box<Self> {
211        Box::new(Self {
212            done_tx: Arc::new(Mutex::new(Some(done_tx))),
213            duration,
214        })
215    }
216}
217
218#[async_trait]
219impl Event for SleepEvent {
220    async fn handle(&self) -> Result<(), SimNetError> {
221        Ok(())
222    }
223
224    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
225        let mut guard = self.done_tx.lock().await;
226        match guard.take() {
227            Some(done_tx) => {
228                let _ = done_tx
229                    .send(())
230                    .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
231            }
232            None => {
233                return Err(SimNetError::Closed("already sent once".to_string()));
234            }
235        }
236        Ok(())
237    }
238
239    fn duration(&self) -> tokio::time::Duration {
240        self.duration
241    }
242
243    fn summary(&self) -> String {
244        format!("Sleeping for {} ms", self.duration.as_millis())
245    }
246}
247
/// An event paired with the simulated instant at which it should fire.
///
/// The simulator keys its event queues by `time` and always pops the
/// earliest entry first, so the timestamp determines the order in which
/// events are fired.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// The simulated instant at which the event fires.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to fire.
    pub(crate) event: Box<dyn Event>,
}
257
/// Dispatcher is a trait that defines the send operation.
/// The send operation takes a target address and a data buffer.
/// This method is called when the simulator is ready for the message to be received
/// by the target address.
///
/// `A` is the address type used to identify the target.
#[async_trait]
pub trait Dispatcher<A> {
    /// Send a raw data blob to the given target.
    ///
    /// Returns a [`SimNetError`] when the data cannot be delivered.
    async fn send(&self, target: A, data: wirevalue::Any) -> Result<(), SimNetError>;
}
267
/// SimNetError is used to indicate errors that occur during
/// network simulation.
///
/// The enum is `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An invalid address was encountered.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// An invalid node was encountered. The second field is the underlying
    /// cause and is exposed as the error's source.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An invalid parameter was encountered.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// The simulator has been closed. In this file it is also used when the
    /// event channel can no longer accept events, when a one-shot completion
    /// signal has already been sent, and by `simnet_handle()` when the
    /// simulator has not been started.
    #[error("closed: {0}")]
    Closed(String),

    /// Timeout when waiting for something. Carries the timeout that elapsed
    /// and a description of what was being waited for.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// Cannot deliver the message because destination address is missing.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// SimnetHandle being accessed without starting simnet
    // NOTE(review): not constructed anywhere in the visible part of this
    // file; `simnet_handle()` currently reports `Closed` for that situation.
    #[error("simnet not started")]
    NotStarted,
}
301
/// The event queues of the simulator. Both map a scheduled instant to the
/// events that fire at that instant.
struct State {
    /// The simnet is allowed to advance to the time of the earliest event in this queue at any time
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    /// The simnet is allowed to advance to the time of the earliest event in this queue
    /// only if it occurs before the earliest event in `scheduled_events`,
    /// or some debounce period has passed where there are only events in this queue.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
310
/// The state of the python training script.
///
/// The event loop consults this state (through the `is_running()` /
/// `is_waiting()` predicates) to decide whether simulated time may advance and
/// whether the advanced time counts as training-script waiting time.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// The training script is issuing commands
    Running,
    /// The training script is waiting for the backend to return a future result
    Waiting,
}
319
/// A distribution of latencies that can be sampled from.
///
/// Currently the only supported shape is a scaled beta distribution.
pub enum LatencyDistribution {
    /// A beta distribution scaled to a given range of values
    Beta(BetaDistribution),
}
325
326impl LatencyDistribution {
327    fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
328        match &self {
329            LatencyDistribution::Beta(sampler) => sampler.sample(rng),
330        }
331    }
332}
333
/// A beta distribution scaled to a given range of values.
///
/// Samples fall between `min_duration` and `max_duration`; the constructor
/// rejects a `min_duration` greater than `max_duration`.
pub struct BetaDistribution {
    /// Lower bound of the sampled durations.
    min_duration: tokio::time::Duration,
    /// Upper bound of the sampled durations.
    max_duration: tokio::time::Duration,
    /// The underlying beta distribution that is scaled onto the range above.
    dist: rand_distr::Beta<f64>,
}
340
341impl BetaDistribution {
342    /// Sample a sclaed value from the distribution.
343    pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
344        let sample = self.dist.sample(rng);
345
346        self.min_duration
347            + tokio::time::Duration::from_micros(
348                (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
349            )
350    }
351
352    /// Create a new beta distribution.
353    pub fn new(
354        min_duration: tokio::time::Duration,
355        max_duration: tokio::time::Duration,
356        alpha: f64,
357        beta: f64,
358    ) -> anyhow::Result<Self> {
359        if min_duration > max_duration {
360            return Err(anyhow::anyhow!(
361                "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
362                min_duration,
363                max_duration
364            ));
365        }
366        Ok(Self {
367            min_duration,
368            max_duration,
369            dist: rand_distr::Beta::new(alpha, beta)?,
370        })
371    }
372}
/// Configuration for latencies between distances for the simulator.
///
/// Only region, data-center and zone distances have a distribution; every
/// closer distance is treated as zero latency.
pub struct LatencyConfig {
    /// inter-region latency distribution
    pub inter_region_distribution: LatencyDistribution,
    /// inter-data center latency distribution
    pub inter_dc_distribution: LatencyDistribution,
    /// inter-zone latency distribution
    pub inter_zone_distribution: LatencyDistribution,
    /// Single random number generator for all distributions to ensure deterministic sampling
    pub rng: StdRng,
}
384
385impl LatencyConfig {
386    fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
387        match distance {
388            Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
389            Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
390            Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
391            Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
392        }
393    }
394}
395
396impl Default for LatencyConfig {
397    fn default() -> Self {
398        let seed: u64 = 0000;
399        let mut seed_bytes = [0u8; 32];
400        seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
401
402        Self {
403            inter_region_distribution: LatencyDistribution::Beta(
404                BetaDistribution::new(
405                    tokio::time::Duration::from_millis(500),
406                    tokio::time::Duration::from_millis(1000),
407                    2.0,
408                    1.0,
409                )
410                .unwrap(),
411            ),
412            inter_dc_distribution: LatencyDistribution::Beta(
413                BetaDistribution::new(
414                    tokio::time::Duration::from_millis(50),
415                    tokio::time::Duration::from_millis(100),
416                    2.0,
417                    1.0,
418                )
419                .unwrap(),
420            ),
421            inter_zone_distribution: LatencyDistribution::Beta(
422                BetaDistribution::new(
423                    tokio::time::Duration::from_millis(5),
424                    tokio::time::Duration::from_millis(10),
425                    2.0,
426                    1.0,
427                )
428                .unwrap(),
429            ),
430            rng: StdRng::from_seed(seed_bytes),
431        }
432    }
433}
434
/// A handle to a running [`SimNet`] instance.
pub struct SimNetHandle {
    /// Join handle of the spawned event-loop task. It is taken out by the
    /// first call to `close`, which awaits it to obtain the event records.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Sending side of the event queue. Each item is the event, whether it is
    /// advanceable, and an explicit scheduled time if it already has one.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Number of events that have been sent but not yet fired. Incremented
    /// when an event is sent and decremented by the event loop right before
    /// the event is handled.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the training script's state to the simnet event loop.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Signal to stop the simnet loop
    stop_signal: Arc<AtomicBool>,
    /// Location of each registered proc, consulted when sampling latencies.
    resources: DashMap<ProcId, Point>,
    /// Latency distributions together with their shared RNG.
    latencies: std::sync::Mutex<LatencyConfig>,
}
448
449impl SimNetHandle {
450    /// Sends an event to be scheduled onto the simnet's event loop
451    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
452    pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
453        self.send_event_impl(event, true)
454    }
455
456    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
457    fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
458        self.pending_event_count
459            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
460        self.event_tx
461            .send((event, advanceable, None))
462            .map_err(|err| SimNetError::Closed(err.to_string()))
463    }
464
465    /// Sends an non-advanceable event to be scheduled onto the simnet's event loop
466    /// A non-advanceable event is an event that cannot advance the simnet's time unless
467    /// the earliest event in the simnet's advancing event queue occurs after the earliest
468    /// event in the simnet's non-advancing event queue, or some debounce period has passed
469    /// where there are only events in the simnet's non-advancing event queue.
470    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
471    pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
472        self.send_event_impl(event, false)
473    }
474
475    /// Sends an event that already has a scheduled time onto the simnet's event loop
476    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
477    pub(crate) fn send_scheduled_event(
478        &self,
479        ScheduledEvent { event, time }: ScheduledEvent,
480    ) -> Result<(), SimNetError> {
481        self.pending_event_count
482            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
483        self.event_tx
484            .send((event, true, Some(time)))
485            .map_err(|err| SimNetError::Closed(err.to_string()))
486    }
487
488    /// Let the simnet know if the training script is running or waiting for the backend
489    /// to return a future result.
490    pub fn set_training_script_state(&self, state: TrainingScriptState) {
491        self.training_script_state_tx.send(state).unwrap();
492    }
493
494    /// Bind the given address to this simulator instance.
495    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
496    pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
497        self.pending_event_count
498            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
499        self.event_tx
500            .send((
501                Box::new(NodeJoinEvent {
502                    channel_addr: address,
503                }),
504                true,
505                None,
506            ))
507            .map_err(|err| SimNetError::Closed(err.to_string()))
508    }
509
510    /// Close the simulator, processing pending messages before
511    /// completing the returned future.
512    pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
513        // Signal the simnet loop to stop
514        self.stop_signal.store(true, Ordering::SeqCst);
515
516        let mut guard = self.join_handle.lock().await;
517        if let Some(handle) = guard.take() {
518            handle.await
519        } else {
520            Ok(vec![])
521        }
522    }
523
524    /// Wait for all of the received events to be scheduled for flight.
525    /// It ticks the simnet time till all of the scheduled events are processed.
526    pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
527        let pending_event_count = self.pending_event_count.clone();
528        // poll for the pending event count to be zero.
529        let mut interval = interval(Duration::from_millis(10));
530        let deadline = RealClock.now() + timeout;
531        while RealClock.now() < deadline {
532            interval.tick().await;
533            if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
534                return Ok(());
535            }
536        }
537        Err(SimNetError::Timeout(
538            timeout,
539            "timeout waiting for received events to be scheduled".to_string(),
540        ))
541    }
542
543    /// Register the location in resource space for a Proc
544    pub fn register_proc(&self, proc_id: ProcId, point: Point) {
545        self.resources.insert(proc_id, point);
546    }
547
548    /// Sample a latency between two procs
549    pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
550        let distances = [
551            Distance::Region,
552            Distance::DataCenter,
553            Distance::Zone,
554            Distance::Rack,
555            Distance::Host,
556            Distance::Same,
557        ];
558
559        let src_coords = self
560            .resources
561            .get(src)
562            .map(|point| point.coords().clone())
563            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
564
565        let dest_coords = self
566            .resources
567            .get(dest)
568            .map(|point| point.coords().clone())
569            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
570
571        for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
572            if src != dest {
573                let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
574                return guard.from_distance(&distance);
575            }
576        }
577
578        let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
579        guard.from_distance(&Distance::Same)
580    }
581}
582
/// How far apart two procs are, listed from the coarsest to the finest level.
///
/// `Region`, `DataCenter` and `Zone` sample their latency from the configured
/// distributions; `Rack`, `Host` and `Same` are treated as zero latency.
#[derive(Debug)]
enum Distance {
    /// The procs differ in the first (coarsest) coordinate.
    Region,
    /// The procs differ first in the second coordinate.
    DataCenter,
    /// The procs differ first in the third coordinate.
    Zone,
    /// The procs differ first in the fourth coordinate.
    Rack,
    /// The procs differ first in the fifth coordinate.
    Host,
    /// The procs differ only in the sixth coordinate, or not at all.
    Same,
}
592
/// SimNet is the simulator itself: it owns the queues of scheduled events and
/// runs the event loop that advances simulated time and fires them.
pub struct SimNet {
    /// Set of addresses known to the network.
    // NOTE(review): only initialised (empty) in the visible code; confirm
    // where it is populated.
    address_book: DashSet<ChannelAddr>,
    /// The advanceable and non-advanceable event queues.
    state: State,
    /// Upper bound on latency; set to 10 seconds by `start_with_config`.
    // NOTE(review): not read anywhere in the visible part of this file.
    max_latency: Duration,
    /// One record per scheduled event, returned when the event loop stops.
    records: Vec<SimulatorEventRecord>,
    // number of events that has been received but not yet processed.
    pending_event_count: Arc<AtomicUsize>,
}
606
607/// Starts a sim net.
608pub fn start() {
609    start_with_config(LatencyConfig::default())
610}
611
612/// Starts a sim net with configured latencies between distances
613pub fn start_with_config(config: LatencyConfig) {
614    let max_duration_ms = 1000 * 10;
615    // Construct a topology with one node: the default A.
616    let address_book: DashSet<ChannelAddr> = DashSet::new();
617
618    let (training_script_state_tx, training_script_state_rx) =
619        tokio::sync::watch::channel(TrainingScriptState::Waiting);
620    let (event_tx, event_rx) =
621        mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
622    let pending_event_count = Arc::new(AtomicUsize::new(0));
623    let stop_signal = Arc::new(AtomicBool::new(false));
624
625    let join_handle = Mutex::new(Some({
626        let pending_event_count = pending_event_count.clone();
627        let stop_signal = stop_signal.clone();
628
629        tokio::spawn(async move {
630            SimNet {
631                address_book,
632                state: State {
633                    scheduled_events: BTreeMap::new(),
634                    unadvanceable_scheduled_events: BTreeMap::new(),
635                },
636                max_latency: Duration::from_millis(max_duration_ms),
637                records: Vec::new(),
638                pending_event_count,
639            }
640            .run(event_rx, training_script_state_rx, stop_signal)
641            .await
642        })
643    }));
644
645    HANDLE.get_or_init(|| SimNetHandle {
646        join_handle,
647        event_tx,
648        pending_event_count,
649        training_script_state_tx,
650        stop_signal,
651        resources: DashMap::new(),
652        latencies: std::sync::Mutex::new(config),
653    });
654}
655
impl SimNet {
    /// Wraps an event that has no explicit time into a [`ScheduledEvent`]
    /// that fires `event.duration()` after the current simulated time.
    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
        // The event's own duration is used as its latency.
        ScheduledEvent {
            time: SimClock.now() + event.duration(),
            event,
        }
    }

    /// Schedule the event into the network.
    ///
    /// A [`SimulatorEventRecord`] spanning from the current simulated time to
    /// the event's scheduled time is stored, and the event is placed into the
    /// advanceable or the non-advanceable queue depending on `advanceable`.
    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
        let start_at = SimClock.now();
        let end_at = scheduled_event.time;

        // Both ends of the record are stored as milliseconds since the
        // simulated clock's start.
        self.records.push(SimulatorEventRecord {
            summary: scheduled_event.event.summary(),
            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
        });

        if advanceable {
            self.state
                .scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        } else {
            self.state
                .unadvanceable_scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        }
    }

    /// Run the simulation. This will dispatch all the messages in the network.
    /// And wait for new ones.
    ///
    /// Each iteration of the loop:
    /// 1. stops if `stop_signal` is set,
    /// 2. drains incoming events from `event_rx` and schedules them,
    /// 3. picks the earliest batch of events that is allowed to fire,
    /// 4. advances the simulated clock to that batch's time and fires it.
    ///
    /// Returns a copy of the event records collected so far when the loop is
    /// stopped or when an event handler fails.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // The simulated number of milliseconds the training script
        // has spent waiting for the backend to resolve a future
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Real-time instant since which only the non-advanceable queue has had events
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        // How long (in real milliseconds) to wait for another incoming event
        // before going on to process the queues. Overridable through the
        // `SIM_DEBOUNCE` environment variable; defaults to 1.
        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            // Check if we should stop
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Drain incoming events until none arrives within the debounce
            // duration (or the channel is closed).
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                // Events with an explicit time are shifted by the time the
                // training script has spent waiting; all others are scheduled
                // relative to the current simulated time.
                let scheduled_event = match time {
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // If the training script is running and issuing commands
                // it is not safe to advance past the training script time
                // otherwise a command issued by the training script may
                // be scheduled for a time in the past
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        // Start debounce timer when only the non-advanceable
                        // queue has events and the timer has not already started
                        debounce_timer = Some(RealClock.now());
                    }
                    // Timer already active
                    (None, Some(_)) => {}
                    // Reset timer when non-advanceable queue is not the only queue with events
                    _ => {
                        debounce_timer = None;
                    }
                }
                // process for next delivery time.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    // Both queues have events: take the earlier batch; on a
                    // tie the advanceable queue wins.
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    // Only non-advanceable events: they may fire only after
                    // the debounce timer has been running for more than one
                    // second of real time.
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                // debounce interval has elapsed, reset timer
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing may fire right now: wait for a new event, but
                    // for at most 10 ms of real time, then start over.
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // While the training script is waiting, the simulated time
                // skipped here counts towards its waiting time.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for scheduled_event in scheduled_events {
                    // The event is no longer pending once it is about to fire.
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    // A failing handler stops the whole simulation.
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        break 'outer self.records.clone(); //TODO
                    }
                }
            }
        }
    }
}
824
825fn serialize_optional_channel_addr<S>(
826    addr: &Option<ChannelAddr>,
827    serializer: S,
828) -> Result<S::Ok, S::Error>
829where
830    S: Serializer,
831{
832    match addr {
833        Some(addr) => serializer.serialize_str(&addr.to_string()),
834        None => serializer.serialize_none(),
835    }
836}
837
838fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
839where
840    D: Deserializer<'de>,
841{
842    let s = String::deserialize(deserializer)?;
843    s.parse().map_err(serde::de::Error::custom)
844}
845
/// SimulatorEventRecord is a structure to bookkeep simulated events.
///
/// NOTE(review): the unit of `start_at` / `end_at` is not stated on this type;
/// the tests in this file compare `end_at` against millisecond values — confirm.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// Event dependent summary for user
    pub summary: String,
    /// The time at which the event (e.g. a message delivery) was started.
    pub start_at: u64,
    /// The time at which the event ended (e.g. the message was delivered to
    /// the receiver).
    pub end_at: u64,
}
856
857#[cfg(test)]
858mod tests {
859    use std::collections::HashMap;
860    use std::sync::Arc;
861
862    use async_trait::async_trait;
863    use ndslice::extent;
864    use tokio::sync::Mutex;
865
866    use super::*;
867    use crate::channel::sim::SimAddr;
868    use crate::clock::Clock;
869    use crate::clock::RealClock;
870    use crate::clock::SimClock;
871    use crate::id;
872    use crate::simnet;
873    use crate::simnet::Dispatcher;
874    use crate::simnet::Event;
875    use crate::simnet::SimNetError;
876
    /// Test-only event describing the delivery of `data` from `src_addr` to
    /// `dest_addr`.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        /// Address the message is sent from.
        src_addr: SimAddr,
        /// Address the message is delivered to.
        dest_addr: SimAddr,
        /// Serialized message payload.
        data: wirevalue::Any,
        /// How long the event is reported to take.
        duration: tokio::time::Duration,
        /// Dispatcher that receives the payload when the event is handled;
        /// with `None`, nothing is dispatched.
        dispatcher: Option<TestDispatcher>,
    }
885
886    #[async_trait]
887    impl Event for MessageDeliveryEvent {
888        async fn handle(&self) -> Result<(), simnet::SimNetError> {
889            if let Some(dispatcher) = &self.dispatcher {
890                dispatcher
891                    .send(self.dest_addr.clone(), self.data.clone())
892                    .await?;
893            }
894            Ok(())
895        }
896        fn duration(&self) -> tokio::time::Duration {
897            self.duration
898        }
899
900        fn summary(&self) -> String {
901            format!(
902                "Sending message from {} to {}",
903                self.src_addr.addr().clone(),
904                self.dest_addr.addr().clone()
905            )
906        }
907    }
908
909    impl MessageDeliveryEvent {
910        fn new(
911            src_addr: SimAddr,
912            dest_addr: SimAddr,
913            data: wirevalue::Any,
914            dispatcher: Option<TestDispatcher>,
915            duration: tokio::time::Duration,
916        ) -> Self {
917            Self {
918                src_addr,
919                dest_addr,
920                data,
921                duration,
922                dispatcher,
923            }
924        }
925    }
926
    /// Test dispatcher that keeps payloads in per-address buffers. The buffers
    /// sit behind an `Arc`, so clones of a dispatcher share the same storage.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        /// Buffered payloads, keyed by address.
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<wirevalue::Any>>>>,
    }
931
932    impl Default for TestDispatcher {
933        fn default() -> Self {
934            Self {
935                mbuffers: Arc::new(Mutex::new(HashMap::new())),
936            }
937        }
938    }
939
940    #[async_trait]
941    impl Dispatcher<SimAddr> for TestDispatcher {
942        async fn send(&self, target: SimAddr, data: wirevalue::Any) -> Result<(), SimNetError> {
943            let mut buf = self.mbuffers.lock().await;
944            buf.entry(target).or_default().push(data);
945            Ok(())
946        }
947    }
948
949    #[cfg(target_os = "linux")]
950    fn random_abstract_addr() -> ChannelAddr {
951        use rand::Rng;
952        use rand::distributions::Alphanumeric;
953
954        let random_string = rand::thread_rng()
955            .sample_iter(&Alphanumeric)
956            .take(24)
957            .map(char::from)
958            .collect::<String>();
959        format!("unix!@{random_string}").parse().unwrap()
960    }
961
962    #[tokio::test]
963    async fn test_handle_instantiation() {
964        start();
965        simnet_handle().unwrap().close().await.unwrap();
966    }
967
968    #[tokio::test]
969    async fn test_simnet_config() {
970        // Tests that we can create a simnet, config latency between distances and sample latencies between procs.
971        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);
972
973        let alice = id!(world[0]);
974        let bob = id!(world[1]);
975        let charlie = id!(world[2]);
976
977        let config = LatencyConfig {
978            inter_zone_distribution: LatencyDistribution::Beta(
979                BetaDistribution::new(
980                    tokio::time::Duration::from_millis(1000),
981                    tokio::time::Duration::from_millis(1000),
982                    1.0,
983                    1.0,
984                )
985                .unwrap(),
986            ),
987            inter_dc_distribution: LatencyDistribution::Beta(
988                BetaDistribution::new(
989                    tokio::time::Duration::from_millis(2000),
990                    tokio::time::Duration::from_millis(2000),
991                    1.0,
992                    1.0,
993                )
994                .unwrap(),
995            ),
996            ..Default::default()
997        };
998        start_with_config(config);
999
1000        let handle = simnet_handle().unwrap();
1001        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
1002        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
1003        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
1004        assert_eq!(
1005            handle.sample_latency(&alice, &bob),
1006            tokio::time::Duration::from_millis(1000)
1007        );
1008        assert_eq!(
1009            handle.sample_latency(&alice, &charlie),
1010            tokio::time::Duration::from_millis(2000)
1011        );
1012    }
1013
    /// Sends ten delivery events in quick succession, then checks that ten
    /// records are produced and that the last one ends at the 10 s event
    /// latency.
    #[tokio::test]
    async fn test_simnet_debounce() {
        // NOTE(review): this latency config is handed to the simnet, but the
        // assertions below only depend on the explicit `latency` given to each
        // event — confirm whether the custom config is needed here.
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        // Duration carried by every event sent below.
        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        // Rapidly send 10 messages expecting that each one debounces the processing
        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    wirevalue::Any::serialize(&"123".to_string()).unwrap(),
                    // No dispatcher: handling the event does not deliver anywhere.
                    None,
                    latency,
                )))
                .unwrap();
            // Short real-time pause between consecutive sends.
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await;
        // One record per event sent.
        assert_eq!(records.as_ref().unwrap().len(), 10);

        // If debounce is successful, the simnet will not advance to the delivery of any of
        // the messages before all are received
        assert_eq!(
            records.unwrap().last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }
1070
1071    #[tokio::test]
1072    async fn test_sim_dispatch() {
1073        start();
1074        let sender = Some(TestDispatcher::default());
1075        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
1076        // // Create a simple network of 4 nodes.
1077        for i in 0..4 {
1078            addresses.push(
1079                format!("local:{}", i)
1080                    .parse::<simnet::ChannelAddr>()
1081                    .unwrap(),
1082            );
1083        }
1084
1085        let messages: Vec<wirevalue::Any> = vec!["First 0 1", "First 2 3", "Second 0 1"]
1086            .into_iter()
1087            .map(|s| wirevalue::Any::serialize(&s.to_string()).unwrap())
1088            .collect();
1089
1090        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
1091        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
1092        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
1093        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
1094        let one = Box::new(MessageDeliveryEvent::new(
1095            addr_0.clone(),
1096            addr_1.clone(),
1097            messages[0].clone(),
1098            sender.clone(),
1099            tokio::time::Duration::ZERO,
1100        ));
1101        let two = Box::new(MessageDeliveryEvent::new(
1102            addr_2.clone(),
1103            addr_3.clone(),
1104            messages[1].clone(),
1105            sender.clone(),
1106            tokio::time::Duration::ZERO,
1107        ));
1108        let three = Box::new(MessageDeliveryEvent::new(
1109            addr_0.clone(),
1110            addr_1.clone(),
1111            messages[2].clone(),
1112            sender.clone(),
1113            tokio::time::Duration::ZERO,
1114        ));
1115
1116        simnet_handle().unwrap().send_event(one).unwrap();
1117        simnet_handle().unwrap().send_event(two).unwrap();
1118        simnet_handle().unwrap().send_event(three).unwrap();
1119
1120        simnet_handle()
1121            .unwrap()
1122            .flush(Duration::from_millis(1000))
1123            .await
1124            .unwrap();
1125        let records = simnet_handle().unwrap().close().await.unwrap();
1126        eprintln!("Records: {:?}", records);
1127        // Close the channel
1128        simnet_handle().unwrap().close().await.unwrap();
1129
1130        // Check results
1131        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
1132        assert_eq!(buf.len(), 2);
1133        assert_eq!(buf[&addr_1].len(), 2);
1134        assert_eq!(buf[&addr_3].len(), 1);
1135
1136        assert_eq!(buf[&addr_1][0], messages[0]);
1137        assert_eq!(buf[&addr_1][1], messages[2]);
1138        assert_eq!(buf[&addr_3][0], messages[1]);
1139    }
1140
1141    #[tokio::test]
1142    async fn test_sim_sleep() {
1143        start();
1144
1145        let start = SimClock.now();
1146        assert_eq!(
1147            SimClock.duration_since_start(start),
1148            tokio::time::Duration::ZERO
1149        );
1150
1151        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;
1152
1153        let end = SimClock.now();
1154        assert_eq!(
1155            SimClock.duration_since_start(end),
1156            tokio::time::Duration::from_secs(10)
1157        );
1158    }
1159
1160    #[tokio::test]
1161    async fn test_torch_op() {
1162        start();
1163        let args_string = "1, 2".to_string();
1164        let kwargs_string = "a=2".to_string();
1165
1166        let (tx, rx) = oneshot::channel();
1167
1168        simnet_handle()
1169            .unwrap()
1170            .send_event(TorchOpEvent::new(
1171                "torch.ops.aten.ones.default".to_string(),
1172                tx,
1173                args_string,
1174                kwargs_string,
1175                id!(mesh_0_worker[0].worker_0),
1176            ))
1177            .unwrap();
1178
1179        rx.await.unwrap();
1180
1181        simnet_handle()
1182            .unwrap()
1183            .flush(Duration::from_millis(1000))
1184            .await
1185            .unwrap();
1186        let records = simnet_handle().unwrap().close().await;
1187        let expected_record = SimulatorEventRecord {
1188            summary:
1189                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
1190                    .to_string(),
1191            start_at: 0,
1192            end_at: 100,
1193        };
1194        assert!(records.as_ref().unwrap().len() == 1);
1195        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
1196    }
1197}