hyperactor/
simnet.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)]
10
11//! A simulator capable of simulating Hyperactor's network channels (see: [`channel`]).
12//! The simulator can simulate message delivery delays and failures, and is used for
13//! testing and development of message distribution techniques.
14
15use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::sync::oneshot;
42use tokio::task::JoinError;
43use tokio::task::JoinHandle;
44use tokio::time::interval;
45
46use crate::ActorId;
47use crate::ProcId;
48use crate::channel::ChannelAddr;
49use crate::clock::Clock;
50use crate::clock::RealClock;
51use crate::clock::SimClock;
52use crate::data::Serialized;
53
54static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
55
56/// A handle for SimNet through which you can send and schedule events in the
57/// network.
58///
59/// Return the \[`NotStarted`\] error when called before `simnet::start()` has been called
60#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
61pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
62    match HANDLE.get() {
63        Some(handle) => Ok(handle),
64        None => Err(SimNetError::Closed("SimNet not started".to_string())),
65    }
66}
67
/// Presumably the capacity of a buffer for simulator operational messages
/// (inferred from the name only).
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
69
/// Marker trait for the address type used by the network.
/// Addresses are bound to nodes in the network.
///
/// Every type that is clonable, debuggable, hashable and totally ordered
/// is an `Address` automatically through the blanket implementation below.
pub trait Address: Clone + Debug + Hash + PartialEq + Eq + PartialOrd + Ord {}

impl<T> Address for T where T: Clone + Debug + Hash + PartialEq + Eq + PartialOrd + Ord {}
74
/// The instant type the simulator uses to timestamp scheduled events
/// (an alias for [`tokio::time::Instant`]).
pub type SimulatorTimeInstant = tokio::time::Instant;
77
/// The unit of execution for the simulator.
///
/// Events are handed to the simulator through a [`SimNetHandle`]; the
/// simulator assigns each one a fire time and invokes it once simulated time
/// reaches that instant.
/// If you want to send a message, for example, you would implement
/// a `MessageDeliveryEvent` along the lines of the one in the simnet tests.
/// More advanced concepts can be modeled the same way, such as node churn
/// or process spawns in a distributed system. For example,
/// one can implement a `SystemActorSimEvent` in order to spawn a system
/// actor.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// The action to perform when the simulator fires the event
    /// at its scheduled time instant. Examples:
    /// For messages, it will be delivering the message to the dst's receiver queue.
    /// For a proc spawn, it will be creating the proc object and instantiating it.
    /// For any event that manipulates the network (like adding/removing nodes etc.)
    /// implement handle_network().
    async fn handle(&self) -> Result<(), SimNetError>;

    /// The entry point the simulator actually invokes when the event fires.
    /// The default implementation simply delegates to [`Event::handle`].
    /// Unless you need to make changes to the network, you do not have to implement this.
    /// Only implement handle() method for all non-simnet requirements.
    async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// The latency of the event. This could be network latency, induced latency (sleep), or
    /// GPU work latency. When the simulator assigns the fire time itself, it is the
    /// simulated "now" plus this duration.
    fn duration(&self) -> tokio::time::Duration;

    /// A user-friendly summary of the event; it is stored in the
    /// [`SimulatorEventRecord`] created when the event is scheduled.
    fn summary(&self) -> String;
}
110
/// A simple event used to join a node to the network.
/// It is enqueued by [`SimNetHandle::bind`] and carries the channel address
/// being bound; firing it currently performs no work.
#[derive(Debug)]
struct NodeJoinEvent {
    /// The address passed to [`SimNetHandle::bind`].
    channel_addr: ChannelAddr,
}
117
118#[async_trait]
119impl Event for NodeJoinEvent {
120    async fn handle(&self) -> Result<(), SimNetError> {
121        Ok(())
122    }
123
124    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
125        self.handle().await
126    }
127
128    fn duration(&self) -> tokio::time::Duration {
129        tokio::time::Duration::ZERO
130    }
131
132    fn summary(&self) -> String {
133        "Node join".into()
134    }
135}
136
#[derive(Debug)]
/// A pytorch operation
///
/// When the event fires, the one-shot `done_tx` is signalled so that whoever
/// holds the matching receiver learns that the simulated op has completed.
pub struct TorchOpEvent {
    /// Name of the op; shown in the event summary.
    op: String,
    /// One-shot completion signal. It sits behind a mutex and an `Option` so it
    /// can be taken exactly once through the `&self` the handler receives.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Pre-rendered positional arguments; shown in the event summary.
    args_string: String,
    /// Pre-rendered keyword arguments; shown in the event summary when non-empty.
    kwargs_string: String,
    /// Actor id printed as the prefix of the event summary.
    worker_actor_id: ActorId,
}
146
147#[async_trait]
148impl Event for TorchOpEvent {
149    async fn handle(&self) -> Result<(), SimNetError> {
150        Ok(())
151    }
152
153    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
154        let mut guard = self.done_tx.lock().await;
155        match guard.take() {
156            Some(done_tx) => {
157                let _ = done_tx
158                    .send(())
159                    .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
160            }
161            None => {
162                return Err(SimNetError::Closed("already sent once".to_string()));
163            }
164        }
165        Ok(())
166    }
167
168    fn duration(&self) -> tokio::time::Duration {
169        tokio::time::Duration::from_millis(100)
170    }
171
172    fn summary(&self) -> String {
173        let kwargs_string = if self.kwargs_string.is_empty() {
174            "".to_string()
175        } else {
176            format!(", {}", self.kwargs_string)
177        };
178        format!(
179            "[{}] Torch Op: {}({}{})",
180            self.worker_actor_id, self.op, self.args_string, kwargs_string
181        )
182    }
183}
184
185impl TorchOpEvent {
186    /// Creates a new TorchOpEvent.
187    pub fn new(
188        op: String,
189        done_tx: oneshot::Sender<()>,
190        args_string: String,
191        kwargs_string: String,
192        worker_actor_id: ActorId,
193    ) -> Box<Self> {
194        Box::new(Self {
195            op,
196            done_tx: Arc::new(Mutex::new(Some(done_tx))),
197            args_string,
198            kwargs_string,
199            worker_actor_id,
200        })
201    }
202}
203
/// An event that models a sleep: it occupies `duration` of simulated time and
/// then signals `done_tx`.
#[derive(Debug)]
pub(crate) struct SleepEvent {
    /// One-shot completion signal. It sits behind a mutex and an `Option` so it
    /// can be taken exactly once through the `&self` the handler receives.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// How long the sleep lasts; reported as the event's latency.
    duration: tokio::time::Duration,
}
209
210impl SleepEvent {
211    pub(crate) fn new(done_tx: oneshot::Sender<()>, duration: tokio::time::Duration) -> Box<Self> {
212        Box::new(Self {
213            done_tx: Arc::new(Mutex::new(Some(done_tx))),
214            duration,
215        })
216    }
217}
218
219#[async_trait]
220impl Event for SleepEvent {
221    async fn handle(&self) -> Result<(), SimNetError> {
222        Ok(())
223    }
224
225    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
226        let mut guard = self.done_tx.lock().await;
227        match guard.take() {
228            Some(done_tx) => {
229                let _ = done_tx
230                    .send(())
231                    .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
232            }
233            None => {
234                return Err(SimNetError::Closed("already sent once".to_string()));
235            }
236        }
237        Ok(())
238    }
239
240    fn duration(&self) -> tokio::time::Duration {
241        self.duration
242    }
243
244    fn summary(&self) -> String {
245        format!("Sleeping for {} ms", self.duration.as_millis())
246    }
247}
248
/// An event paired with the simulated instant at which it should fire.
/// The timestamp determines the order in which the simulator processes
/// events: the event loop always takes the earliest scheduled time first.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// The simulated instant at which `event` fires.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to invoke at `time`.
    pub(crate) event: Box<dyn Event>,
}
258
/// Dispatcher is a trait that defines the send operation.
/// The send operation takes a target address and a data buffer.
/// This method is called when the simulator is ready for the message to be received
/// by the target address.
///
/// `A` is the address type used to identify the target.
#[async_trait]
pub trait Dispatcher<A> {
    /// Send a raw data blob to the given target.
    ///
    /// Returns a [`SimNetError`] when the data cannot be delivered.
    async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
}
268
/// SimNetError is used to indicate errors that occur during
/// network simulation.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An invalid address was encountered.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// An invalid node was encountered. The string is the displayed message;
    /// the `anyhow::Error` is exposed as the error source.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An invalid parameter was encountered.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// The simulator has been closed. Within this file it is also the variant
    /// returned when the event channel is closed, when the global handle has
    /// not been initialized, and when a one-shot completion signal was
    /// already consumed.
    #[error("closed: {0}")]
    Closed(String),

    /// Timeout when waiting for something. Carries the timeout that elapsed
    /// and a description of what was being waited for.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// Cannot deliver the message because destination address is missing.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// SimnetHandle being accessed without starting simnet
    #[error("simnet not started")]
    NotStarted,
}
302
/// The two event queues of the simulator. Both are keyed by fire time, so the
/// first entry of each map is its earliest pending batch of events.
struct State {
    // The simnet is allowed to advance to the time of the earliest event in this queue at any time
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    // The simnet is allowed to advance to the time of the earliest event in this queue
    // only if the earliest event in `scheduled_events` occurs after the earliest event in this queue
    // or some debounce period has passed where there are only events in this queue.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
311
/// The state of the python training script.
///
/// The simulator event loop reads this state to decide whether it may advance
/// simulated time and whether an advance counts as time the script spent
/// waiting.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// The training script is issuing commands
    Running,
    /// The training script is waiting for the backend to return a future result
    Waiting,
}
320
/// A distribution of latencies that can be sampled from.
/// Currently the only supported shape is a scaled beta distribution.
pub enum LatencyDistribution {
    /// A beta distribution scaled to a given range of values
    Beta(BetaDistribution),
}
326
327impl LatencyDistribution {
328    fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
329        match &self {
330            LatencyDistribution::Beta(sampler) => sampler.sample(rng),
331        }
332    }
333}
334
/// A beta distribution scaled to a given range of values.
///
/// Samples are mapped linearly onto `[min_duration, max_duration]`.
pub struct BetaDistribution {
    /// Lower bound of the sampled durations.
    min_duration: tokio::time::Duration,
    /// Upper bound of the sampled durations.
    max_duration: tokio::time::Duration,
    /// The underlying (unscaled) beta distribution.
    dist: rand_distr::Beta<f64>,
}
341
342impl BetaDistribution {
343    /// Sample a sclaed value from the distribution.
344    pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
345        let sample = self.dist.sample(rng);
346
347        self.min_duration
348            + tokio::time::Duration::from_micros(
349                (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
350            )
351    }
352
353    /// Create a new beta distribution.
354    pub fn new(
355        min_duration: tokio::time::Duration,
356        max_duration: tokio::time::Duration,
357        alpha: f64,
358        beta: f64,
359    ) -> anyhow::Result<Self> {
360        if min_duration > max_duration {
361            return Err(anyhow::anyhow!(
362                "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
363                min_duration,
364                max_duration
365            ));
366        }
367        Ok(Self {
368            min_duration,
369            max_duration,
370            dist: rand_distr::Beta::new(alpha, beta)?,
371        })
372    }
373}
/// Configuration for latencies between distances for the simulator.
///
/// Only the region, data-center and zone levels have a distribution; every
/// closer distance is treated as zero latency.
pub struct LatencyConfig {
    /// inter-region latency distribution
    pub inter_region_distribution: LatencyDistribution,
    /// inter-data center latency distribution
    pub inter_dc_distribution: LatencyDistribution,
    /// inter-zone latency distribution
    pub inter_zone_distribution: LatencyDistribution,
    /// Single random number generator for all distributions to ensure deterministic sampling
    pub rng: StdRng,
}
385
386impl LatencyConfig {
387    fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
388        match distance {
389            Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
390            Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
391            Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
392            Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
393        }
394    }
395}
396
397impl Default for LatencyConfig {
398    fn default() -> Self {
399        let seed: u64 = 0000;
400        let mut seed_bytes = [0u8; 32];
401        seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
402
403        Self {
404            inter_region_distribution: LatencyDistribution::Beta(
405                BetaDistribution::new(
406                    tokio::time::Duration::from_millis(500),
407                    tokio::time::Duration::from_millis(1000),
408                    2.0,
409                    1.0,
410                )
411                .unwrap(),
412            ),
413            inter_dc_distribution: LatencyDistribution::Beta(
414                BetaDistribution::new(
415                    tokio::time::Duration::from_millis(50),
416                    tokio::time::Duration::from_millis(100),
417                    2.0,
418                    1.0,
419                )
420                .unwrap(),
421            ),
422            inter_zone_distribution: LatencyDistribution::Beta(
423                BetaDistribution::new(
424                    tokio::time::Duration::from_millis(5),
425                    tokio::time::Duration::from_millis(10),
426                    2.0,
427                    1.0,
428                )
429                .unwrap(),
430            ),
431            rng: StdRng::from_seed(seed_bytes),
432        }
433    }
434}
435
/// A handle to a running [`SimNet`] instance.
pub struct SimNetHandle {
    /// Join handle of the task running the simnet event loop; it is taken and
    /// awaited by `close`.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Channel into the event loop. Each item is the event, whether it is
    /// advanceable, and an optional fire time assigned by the sender.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Number of events that have been sent but not yet processed by the loop.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the current state of the training script to the simnet
    /// event loop.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Signal to stop the simnet loop
    stop_signal: Arc<AtomicBool>,
    /// Location in resource space of each registered proc.
    resources: DashMap<ProcId, Point>,
    /// Latency configuration (including its RNG) used to sample latencies
    /// between procs.
    latencies: std::sync::Mutex<LatencyConfig>,
}
449
450impl SimNetHandle {
451    /// Sends an event to be scheduled onto the simnet's event loop
452    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
453    pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
454        self.send_event_impl(event, true)
455    }
456
457    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
458    fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
459        self.pending_event_count
460            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
461        self.event_tx
462            .send((event, advanceable, None))
463            .map_err(|err| SimNetError::Closed(err.to_string()))
464    }
465
466    /// Sends an non-advanceable event to be scheduled onto the simnet's event loop
467    /// A non-advanceable event is an event that cannot advance the simnet's time unless
468    /// the earliest event in the simnet's advancing event queue occurs after the earliest
469    /// event in the simnet's non-advancing event queue, or some debounce period has passed
470    /// where there are only events in the simnet's non-advancing event queue.
471    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
472    pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
473        self.send_event_impl(event, false)
474    }
475
476    /// Sends an event that already has a scheduled time onto the simnet's event loop
477    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
478    pub(crate) fn send_scheduled_event(
479        &self,
480        ScheduledEvent { event, time }: ScheduledEvent,
481    ) -> Result<(), SimNetError> {
482        self.pending_event_count
483            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
484        self.event_tx
485            .send((event, true, Some(time)))
486            .map_err(|err| SimNetError::Closed(err.to_string()))
487    }
488
489    /// Let the simnet know if the training script is running or waiting for the backend
490    /// to return a future result.
491    pub fn set_training_script_state(&self, state: TrainingScriptState) {
492        self.training_script_state_tx.send(state).unwrap();
493    }
494
495    /// Bind the given address to this simulator instance.
496    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
497    pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
498        self.pending_event_count
499            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
500        self.event_tx
501            .send((
502                Box::new(NodeJoinEvent {
503                    channel_addr: address,
504                }),
505                true,
506                None,
507            ))
508            .map_err(|err| SimNetError::Closed(err.to_string()))
509    }
510
511    /// Close the simulator, processing pending messages before
512    /// completing the returned future.
513    pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
514        // Signal the simnet loop to stop
515        self.stop_signal.store(true, Ordering::SeqCst);
516
517        let mut guard = self.join_handle.lock().await;
518        if let Some(handle) = guard.take() {
519            handle.await
520        } else {
521            Ok(vec![])
522        }
523    }
524
525    /// Wait for all of the received events to be scheduled for flight.
526    /// It ticks the simnet time till all of the scheduled events are processed.
527    pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
528        let pending_event_count = self.pending_event_count.clone();
529        // poll for the pending event count to be zero.
530        let mut interval = interval(Duration::from_millis(10));
531        let deadline = RealClock.now() + timeout;
532        while RealClock.now() < deadline {
533            interval.tick().await;
534            if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
535                return Ok(());
536            }
537        }
538        Err(SimNetError::Timeout(
539            timeout,
540            "timeout waiting for received events to be scheduled".to_string(),
541        ))
542    }
543
544    /// Register the location in resource space for a Proc
545    pub fn register_proc(&self, proc_id: ProcId, point: Point) {
546        self.resources.insert(proc_id, point);
547    }
548
549    /// Sample a latency between two procs
550    pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
551        let distances = [
552            Distance::Region,
553            Distance::DataCenter,
554            Distance::Zone,
555            Distance::Rack,
556            Distance::Host,
557            Distance::Same,
558        ];
559
560        let src_coords = self
561            .resources
562            .get(src)
563            .map(|point| point.coords().clone())
564            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
565
566        let dest_coords = self
567            .resources
568            .get(dest)
569            .map(|point| point.coords().clone())
570            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
571
572        for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
573            if src != dest {
574                let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
575                return guard.from_distance(&distance);
576            }
577        }
578
579        let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
580        guard.from_distance(&Distance::Same)
581    }
582}
583
/// How far apart two procs are in resource space, from farthest (`Region`)
/// to closest (`Same`).
#[derive(Debug)]
enum Distance {
    /// The procs differ at the region level.
    Region,
    /// The procs differ at the data-center level.
    DataCenter,
    /// The procs differ at the zone level.
    Zone,
    /// The procs differ at the rack level.
    Rack,
    /// The procs differ at the host level.
    Host,
    /// The procs differ only at the innermost level, or not at all.
    Same,
}
593
/// SimNet defines a network of nodes.
/// Each node is identified by a unique id.
/// The network is represented as a graph of nodes.
/// The graph is represented as a map of edges.
/// The network also has a cloud of inflight messages
pub struct SimNet {
    /// Addresses known to the network.
    // NOTE(review): never read or written in the visible part of this file.
    address_book: DashSet<ChannelAddr>,
    /// The queues of scheduled (in-flight) events.
    state: State,
    /// Presumably an upper bound on latency (inferred from the name).
    // NOTE(review): not read in the visible part of this file — confirm.
    max_latency: Duration,
    /// One record per event scheduled so far; returned when the event loop exits.
    records: Vec<SimulatorEventRecord>,
    // number of events that has been received but not yet processed.
    pending_event_count: Arc<AtomicUsize>,
}
607
608/// Starts a sim net.
609pub fn start() {
610    start_with_config(LatencyConfig::default())
611}
612
613/// Starts a sim net with configured latencies between distances
614pub fn start_with_config(config: LatencyConfig) {
615    let max_duration_ms = 1000 * 10;
616    // Construct a topology with one node: the default A.
617    let address_book: DashSet<ChannelAddr> = DashSet::new();
618
619    let (training_script_state_tx, training_script_state_rx) =
620        tokio::sync::watch::channel(TrainingScriptState::Waiting);
621    let (event_tx, event_rx) =
622        mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
623    let pending_event_count = Arc::new(AtomicUsize::new(0));
624    let stop_signal = Arc::new(AtomicBool::new(false));
625
626    let join_handle = Mutex::new(Some({
627        let pending_event_count = pending_event_count.clone();
628        let stop_signal = stop_signal.clone();
629
630        tokio::spawn(async move {
631            SimNet {
632                address_book,
633                state: State {
634                    scheduled_events: BTreeMap::new(),
635                    unadvanceable_scheduled_events: BTreeMap::new(),
636                },
637                max_latency: Duration::from_millis(max_duration_ms),
638                records: Vec::new(),
639                pending_event_count,
640            }
641            .run(event_rx, training_script_state_rx, stop_signal)
642            .await
643        })
644    }));
645
646    HANDLE.get_or_init(|| SimNetHandle {
647        join_handle,
648        event_tx,
649        pending_event_count,
650        training_script_state_tx,
651        stop_signal,
652        resources: DashMap::new(),
653        latencies: std::sync::Mutex::new(config),
654    });
655}
656
impl SimNet {
    /// Wraps `event` in a [`ScheduledEvent`] that fires at the current
    /// simulated time plus the event's own duration.
    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
        // Get latency
        ScheduledEvent {
            time: SimClock.now() + event.duration(),
            event,
        }
    }

    /// Schedule the event into the network.
    ///
    /// Appends a [`SimulatorEventRecord`] for the event and then queues it,
    /// keyed by its fire time, in either the advanceable or the
    /// non-advanceable queue.
    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
        let start_at = SimClock.now();
        let end_at = scheduled_event.time;

        // Record times are milliseconds since the start of the simulated clock.
        self.records.push(SimulatorEventRecord {
            summary: scheduled_event.event.summary(),
            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
        });

        // Events that share a fire time are kept together in insertion order.
        if advanceable {
            self.state
                .scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        } else {
            self.state
                .unadvanceable_scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        }
    }

    /// Run the simulation. This will dispatch all the messages in the network.
    /// And wait for new ones.
    ///
    /// Each iteration of the loop drains newly received events into the
    /// queues, picks the earliest batch that is allowed to fire, advances the
    /// simulated clock to that batch's time and handles its events.
    ///
    /// The loop ends, returning a copy of the records collected so far, when
    /// `stop_signal` is observed at the top of an iteration or when handling
    /// an event fails.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // The simulated number of milliseconds the training script
        // has spent waiting for the backend to resolve a future
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Real-time instant at which only the non-advanceable queue started to
        // have events; `None` while the debounce timer is not running
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        // Milliseconds of real time to wait for another incoming event before
        // moving on to processing; overridable through the SIM_DEBOUNCE
        // environment variable (default: 1).
        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            // Check if we should stop
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Drain incoming events until none arrives within the debounce
            // window (or the channel is closed).
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                let scheduled_event = match time {
                    // A caller-assigned fire time is shifted by the time the
                    // training script has spent waiting so far.
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // If the training script is running and issuing commands
                // it is not safe to advance past the training script time
                // otherwise a command issued by the training script may
                // be scheduled for a time in the past
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        // Start debounce timer when only the non-advanceable
                        // queue has events and the timer has not already started
                        debounce_timer = Some(RealClock.now());
                    }
                    // Timer already active
                    (None, Some(_)) => {}
                    // Reset timer when non-advanceable queue is not the only queue with events
                    _ => {
                        debounce_timer = None;
                    }
                }
                // process for next delivery time.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    // Both queues have events: take whichever batch is earlier,
                    // preferring the advanceable queue on a tie.
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    // Only non-advanceable events: release them only after the
                    // debounce timer has been running for more than 1000 ms.
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                // debounce interval has elapsed, reset timer
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing may fire right now: wait for a new event, but for
                    // at most 10 ms so the stop signal and debounce timer are
                    // re-checked regularly.
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // While the training script is waiting, the jump in simulated
                // time counts as time it spent waiting.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for scheduled_event in scheduled_events {
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        // TODO: the error itself is dropped here and the whole
                        // loop stops on the first failing event.
                        break 'outer self.records.clone(); //TODO
                    }
                }
            }
        }
    }
}
825
826fn serialize_optional_channel_addr<S>(
827    addr: &Option<ChannelAddr>,
828    serializer: S,
829) -> Result<S::Ok, S::Error>
830where
831    S: Serializer,
832{
833    match addr {
834        Some(addr) => serializer.serialize_str(&addr.to_string()),
835        None => serializer.serialize_none(),
836    }
837}
838
839fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
840where
841    D: Deserializer<'de>,
842{
843    let s = String::deserialize(deserializer)?;
844    s.parse().map_err(serde::de::Error::custom)
845}
846
/// SimulatorEventRecord is a structure to bookkeep the events handled by the
/// simulator (message deliveries as well as other event kinds).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// Event dependent summary for user
    pub summary: String,
    /// The time at which the event was started.
    // NOTE(review): the unit looks like milliseconds of simulated time, because the
    // tests in this file compare `end_at` against `Duration::as_millis()` -- confirm.
    pub start_at: u64,
    /// The time at which the event finished (for a message: when it was
    /// delivered to the receiver).
    pub end_at: u64,
}
857
858#[cfg(test)]
859mod tests {
860    use std::collections::HashMap;
861    use std::sync::Arc;
862
863    use async_trait::async_trait;
864    use ndslice::extent;
865    use tokio::sync::Mutex;
866
867    use super::*;
868    use crate::channel::sim::SimAddr;
869    use crate::clock::Clock;
870    use crate::clock::RealClock;
871    use crate::clock::SimClock;
872    use crate::data::Serialized;
873    use crate::id;
874    use crate::simnet;
875    use crate::simnet::Dispatcher;
876    use crate::simnet::Event;
877    use crate::simnet::SimNetError;
878
    /// Test-only event describing the delivery of one serialized message
    /// between two simulated addresses.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        /// Address the message originates from; it appears in the summary text.
        src_addr: SimAddr,
        /// Address passed to the dispatcher as the delivery target.
        dest_addr: SimAddr,
        /// Payload forwarded to the dispatcher.
        data: Serialized,
        /// Value this event reports from `duration()`.
        duration: tokio::time::Duration,
        /// Receiver of the payload when the event is handled; with `None`
        /// handling the event sends nothing.
        dispatcher: Option<TestDispatcher>,
    }
887
888    #[async_trait]
889    impl Event for MessageDeliveryEvent {
890        async fn handle(&self) -> Result<(), simnet::SimNetError> {
891            if let Some(dispatcher) = &self.dispatcher {
892                dispatcher
893                    .send(self.dest_addr.clone(), self.data.clone())
894                    .await?;
895            }
896            Ok(())
897        }
898        fn duration(&self) -> tokio::time::Duration {
899            self.duration
900        }
901
902        fn summary(&self) -> String {
903            format!(
904                "Sending message from {} to {}",
905                self.src_addr.addr().clone(),
906                self.dest_addr.addr().clone()
907            )
908        }
909    }
910
911    impl MessageDeliveryEvent {
912        fn new(
913            src_addr: SimAddr,
914            dest_addr: SimAddr,
915            data: Serialized,
916            dispatcher: Option<TestDispatcher>,
917            duration: tokio::time::Duration,
918        ) -> Self {
919            Self {
920                src_addr,
921                dest_addr,
922                data,
923                duration,
924                dispatcher,
925            }
926        }
927    }
928
    /// Test dispatcher that keeps every sent payload in memory, grouped by
    /// the address it was sent to.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        /// Payloads received so far, keyed by target address. The map sits
        /// behind an `Arc`, so clones of the dispatcher share the same buffers.
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }
933
934    impl Default for TestDispatcher {
935        fn default() -> Self {
936            Self {
937                mbuffers: Arc::new(Mutex::new(HashMap::new())),
938            }
939        }
940    }
941
942    #[async_trait]
943    impl Dispatcher<SimAddr> for TestDispatcher {
944        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
945            let mut buf = self.mbuffers.lock().await;
946            buf.entry(target).or_default().push(data);
947            Ok(())
948        }
949    }
950
951    #[cfg(target_os = "linux")]
952    fn random_abstract_addr() -> ChannelAddr {
953        use rand::Rng;
954        use rand::distributions::Alphanumeric;
955
956        let random_string = rand::thread_rng()
957            .sample_iter(&Alphanumeric)
958            .take(24)
959            .map(char::from)
960            .collect::<String>();
961        format!("unix!@{random_string}").parse().unwrap()
962    }
963
964    #[tokio::test]
965    async fn test_handle_instantiation() {
966        start();
967        simnet_handle().unwrap().close().await.unwrap();
968    }
969
970    #[tokio::test]
971    async fn test_simnet_config() {
972        // Tests that we can create a simnet, config latency between distances and sample latencies between procs.
973        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);
974
975        let alice = id!(world[0]);
976        let bob = id!(world[1]);
977        let charlie = id!(world[2]);
978
979        let config = LatencyConfig {
980            inter_zone_distribution: LatencyDistribution::Beta(
981                BetaDistribution::new(
982                    tokio::time::Duration::from_millis(1000),
983                    tokio::time::Duration::from_millis(1000),
984                    1.0,
985                    1.0,
986                )
987                .unwrap(),
988            ),
989            inter_dc_distribution: LatencyDistribution::Beta(
990                BetaDistribution::new(
991                    tokio::time::Duration::from_millis(2000),
992                    tokio::time::Duration::from_millis(2000),
993                    1.0,
994                    1.0,
995                )
996                .unwrap(),
997            ),
998            ..Default::default()
999        };
1000        start_with_config(config);
1001
1002        let handle = simnet_handle().unwrap();
1003        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
1004        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
1005        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
1006        assert_eq!(
1007            handle.sample_latency(&alice, &bob),
1008            tokio::time::Duration::from_millis(1000)
1009        );
1010        assert_eq!(
1011            handle.sample_latency(&alice, &charlie),
1012            tokio::time::Duration::from_millis(2000)
1013        );
1014    }
1015
1016    #[tokio::test]
1017    async fn test_simnet_debounce() {
1018        let config = LatencyConfig {
1019            inter_zone_distribution: LatencyDistribution::Beta(
1020                BetaDistribution::new(
1021                    tokio::time::Duration::from_millis(1000),
1022                    tokio::time::Duration::from_millis(1000),
1023                    1.0,
1024                    1.0,
1025                )
1026                .unwrap(),
1027            ),
1028            ..Default::default()
1029        };
1030        start_with_config(config);
1031        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
1032        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();
1033
1034        let latency = Duration::from_millis(10000);
1035
1036        let alice = SimAddr::new(alice).unwrap();
1037        let bob = SimAddr::new(bob).unwrap();
1038
1039        // Rapidly send 10 messages expecting that each one debounces the processing
1040        for _ in 0..10 {
1041            simnet_handle()
1042                .unwrap()
1043                .send_event(Box::new(MessageDeliveryEvent::new(
1044                    alice.clone(),
1045                    bob.clone(),
1046                    Serialized::serialize(&"123".to_string()).unwrap(),
1047                    None,
1048                    latency,
1049                )))
1050                .unwrap();
1051            RealClock
1052                .sleep(tokio::time::Duration::from_micros(500))
1053                .await;
1054        }
1055
1056        simnet_handle()
1057            .unwrap()
1058            .flush(Duration::from_secs(20))
1059            .await
1060            .unwrap();
1061
1062        let records = simnet_handle().unwrap().close().await;
1063        assert_eq!(records.as_ref().unwrap().len(), 10);
1064
1065        // If debounce is successful, the simnet will not advance to the delivery of any of
1066        // the messages before all are received
1067        assert_eq!(
1068            records.unwrap().last().unwrap().end_at,
1069            latency.as_millis() as u64
1070        );
1071    }
1072
1073    #[tokio::test]
1074    async fn test_sim_dispatch() {
1075        start();
1076        let sender = Some(TestDispatcher::default());
1077        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
1078        // // Create a simple network of 4 nodes.
1079        for i in 0..4 {
1080            addresses.push(
1081                format!("local:{}", i)
1082                    .parse::<simnet::ChannelAddr>()
1083                    .unwrap(),
1084            );
1085        }
1086
1087        let messages: Vec<Serialized> = vec!["First 0 1", "First 2 3", "Second 0 1"]
1088            .into_iter()
1089            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
1090            .collect();
1091
1092        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
1093        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
1094        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
1095        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
1096        let one = Box::new(MessageDeliveryEvent::new(
1097            addr_0.clone(),
1098            addr_1.clone(),
1099            messages[0].clone(),
1100            sender.clone(),
1101            tokio::time::Duration::ZERO,
1102        ));
1103        let two = Box::new(MessageDeliveryEvent::new(
1104            addr_2.clone(),
1105            addr_3.clone(),
1106            messages[1].clone(),
1107            sender.clone(),
1108            tokio::time::Duration::ZERO,
1109        ));
1110        let three = Box::new(MessageDeliveryEvent::new(
1111            addr_0.clone(),
1112            addr_1.clone(),
1113            messages[2].clone(),
1114            sender.clone(),
1115            tokio::time::Duration::ZERO,
1116        ));
1117
1118        simnet_handle().unwrap().send_event(one).unwrap();
1119        simnet_handle().unwrap().send_event(two).unwrap();
1120        simnet_handle().unwrap().send_event(three).unwrap();
1121
1122        simnet_handle()
1123            .unwrap()
1124            .flush(Duration::from_millis(1000))
1125            .await
1126            .unwrap();
1127        let records = simnet_handle().unwrap().close().await.unwrap();
1128        eprintln!("Records: {:?}", records);
1129        // Close the channel
1130        simnet_handle().unwrap().close().await.unwrap();
1131
1132        // Check results
1133        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
1134        assert_eq!(buf.len(), 2);
1135        assert_eq!(buf[&addr_1].len(), 2);
1136        assert_eq!(buf[&addr_3].len(), 1);
1137
1138        assert_eq!(buf[&addr_1][0], messages[0]);
1139        assert_eq!(buf[&addr_1][1], messages[2]);
1140        assert_eq!(buf[&addr_3][0], messages[1]);
1141    }
1142
1143    #[tokio::test]
1144    async fn test_sim_sleep() {
1145        start();
1146
1147        let start = SimClock.now();
1148        assert_eq!(
1149            SimClock.duration_since_start(start),
1150            tokio::time::Duration::ZERO
1151        );
1152
1153        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;
1154
1155        let end = SimClock.now();
1156        assert_eq!(
1157            SimClock.duration_since_start(end),
1158            tokio::time::Duration::from_secs(10)
1159        );
1160    }
1161
1162    #[tokio::test]
1163    async fn test_torch_op() {
1164        start();
1165        let args_string = "1, 2".to_string();
1166        let kwargs_string = "a=2".to_string();
1167
1168        let (tx, rx) = oneshot::channel();
1169
1170        simnet_handle()
1171            .unwrap()
1172            .send_event(TorchOpEvent::new(
1173                "torch.ops.aten.ones.default".to_string(),
1174                tx,
1175                args_string,
1176                kwargs_string,
1177                id!(mesh_0_worker[0].worker_0),
1178            ))
1179            .unwrap();
1180
1181        rx.await.unwrap();
1182
1183        simnet_handle()
1184            .unwrap()
1185            .flush(Duration::from_millis(1000))
1186            .await
1187            .unwrap();
1188        let records = simnet_handle().unwrap().close().await;
1189        let expected_record = SimulatorEventRecord {
1190            summary:
1191                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
1192                    .to_string(),
1193            start_at: 0,
1194            end_at: 100,
1195        };
1196        assert!(records.as_ref().unwrap().len() == 1);
1197        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
1198    }
1199}