hyperactor/
simnet.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)]
10
11//! A simulator capable of simulating Hyperactor's network channels (see: [`channel`]).
12//! The simulator can simulate message delivery delays and failures, and is used for
13//! testing and development of message distribution techniques.
14
15use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::task::JoinError;
42use tokio::task::JoinHandle;
43use tokio::time::interval;
44
45// for macros
46use crate::ActorId;
47use crate::Mailbox;
48use crate::OncePortRef;
49use crate::ProcId;
50use crate::channel::ChannelAddr;
51use crate::clock::Clock;
52use crate::clock::RealClock;
53use crate::clock::SimClock;
54use crate::data::Serialized;
55
// Process-wide simulator handle. It is populated at most once, by
// `start_with_config`, and read through `simnet_handle`.
static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
57
58/// A handle for SimNet through which you can send and schedule events in the
59/// network.
60///
61/// Return the \[`NotStarted`\] error when called before `simnet::start()` has been called
62#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
63pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
64    match HANDLE.get() {
65        Some(handle) => Ok(handle),
66        None => Err(SimNetError::Closed("SimNet not started".to_string())),
67    }
68}
69
// NOTE(review): presumably the capacity of a channel carrying simulator
// operational messages; nothing in the visible part of this file reads it,
// so confirm it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
71
/// Marker trait for the address type of a simulated network.
/// Addresses are bound to nodes in the network.
///
/// The trait has no methods of its own: it only bundles the bounds an address
/// needs (hashable, debuggable, totally ordered, comparable and clonable).
pub trait Address: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}

/// Every type that satisfies the bounds is automatically an [`Address`].
impl<A> Address for A where A: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}
76
/// A point in time on the simulator's timeline.
///
/// This is an alias for [`tokio::time::Instant`]; it is the key type of the
/// simulator's event queues and the timestamp carried by scheduled events.
pub type SimulatorTimeInstant = tokio::time::Instant;
79
/// The unit of execution for the simulator.
///
/// Events are submitted through a [`SimNetHandle`]; the simulator schedules
/// each one [`duration`](Event::duration) after the current simulated time and
/// fires it by calling [`handle_network`](Event::handle_network).
///
/// If you want to send a message for example, you would want to implement
/// a MessageDeliveryEvent much on the lines expressed in simnet tests.
/// You can also do other more advanced concepts such as node churn,
/// or even simulate process spawns in a distributed system. For example,
/// one can implement a SystemActorSimEvent in order to spawn a system
/// actor.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// This is the method that will be called when the simulator fires the event
    /// at a particular time instant. Examples:
    /// For messages, it will be delivering the message to the dst's receiver queue.
    /// For a proc spawn, it will be creating the proc object and instantiating it.
    /// For any event that manipulates the network (like adding/removing nodes etc.)
    /// implement handle_network().
    ///
    /// With the default `handle_network`, this is reached indirectly: the
    /// simulator loop calls `handle_network`, which delegates here.
    async fn handle(&mut self) -> Result<(), SimNetError>;

    /// This is the method the simulator loop actually calls when the event fires.
    /// Unless you need to make changes to the network, you do not have to implement this.
    /// Only implement handle() method for all non-simnet requirements.
    ///
    /// The default implementation ignores the network and delegates to
    /// [`handle`](Event::handle).
    async fn handle_network(&mut self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// The latency of the event. This could be network latency, induced latency (sleep), or
    /// GPU work latency.
    ///
    /// For events submitted without an explicit time, the simulator fires the
    /// event this long after the simulated time at which it was scheduled.
    fn duration(&self) -> tokio::time::Duration;

    /// A user-friendly summary of the event; stored in the event's
    /// [`SimulatorEventRecord`].
    fn summary(&self) -> String;
}
112
113/// This is a simple event that is used to join a node to the network.
114/// It is used to bind a node to a channel address.
115#[derive(Debug)]
116struct NodeJoinEvent {
117    channel_addr: ChannelAddr,
118}
119
120#[async_trait]
121impl Event for NodeJoinEvent {
122    async fn handle(&mut self) -> Result<(), SimNetError> {
123        Ok(())
124    }
125
126    async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
127        self.handle().await
128    }
129
130    fn duration(&self) -> tokio::time::Duration {
131        tokio::time::Duration::ZERO
132    }
133
134    fn summary(&self) -> String {
135        "Node join".into()
136    }
137}
138
/// A pytorch operation.
///
/// When the event fires, a unit value is sent on `done_tx` (using `mailbox`)
/// to signal completion; the remaining fields are only used to build the
/// human-readable summary.
#[derive(Debug)]
pub struct TorchOpEvent {
    // Name of the operation, as shown in the summary.
    op: String,
    // Port that receives `()` once the simulated operation has finished.
    done_tx: OncePortRef<()>,
    // Mailbox used to send on `done_tx`.
    mailbox: Mailbox,
    // Pre-formatted positional arguments, as shown in the summary.
    args_string: String,
    // Pre-formatted keyword arguments; may be empty.
    kwargs_string: String,
    // Actor the operation is attributed to in the summary.
    worker_actor_id: ActorId,
}
149
150#[async_trait]
151impl Event for TorchOpEvent {
152    async fn handle(&mut self) -> Result<(), SimNetError> {
153        Ok(())
154    }
155
156    async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
157        self.done_tx
158            .clone()
159            .send(&self.mailbox, ())
160            .map_err(|err| SimNetError::Closed(err.to_string()))?;
161        Ok(())
162    }
163
164    fn duration(&self) -> tokio::time::Duration {
165        tokio::time::Duration::from_millis(100)
166    }
167
168    fn summary(&self) -> String {
169        let kwargs_string = if self.kwargs_string.is_empty() {
170            "".to_string()
171        } else {
172            format!(", {}", self.kwargs_string)
173        };
174        format!(
175            "[{}] Torch Op: {}({}{})",
176            self.worker_actor_id, self.op, self.args_string, kwargs_string
177        )
178    }
179}
180
impl TorchOpEvent {
    /// Creates a new TorchOpEvent.
    ///
    /// The event is returned boxed so it can be passed directly to the
    /// `SimNetHandle` send methods, which take a `Box<dyn Event>`.
    pub fn new(
        op: String,
        done_tx: OncePortRef<()>,
        mailbox: Mailbox,
        args_string: String,
        kwargs_string: String,
        worker_actor_id: ActorId,
    ) -> Box<Self> {
        Box::new(Self {
            op,
            done_tx,
            mailbox,
            args_string,
            kwargs_string,
            worker_actor_id,
        })
    }
}
201
/// An event paired with the simulated time at which it should fire.
///
/// The timestamp is the key under which the event is queued, so it determines
/// the order in which events are fired.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// Simulated time at which the event fires.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to fire.
    pub(crate) event: Box<dyn Event>,
}
211
/// Dispatcher is a trait that defines the send operation.
/// The send operation takes a target address and a data buffer.
/// This method is called when the simulator is ready for the message to be received
/// by the target address.
///
/// `A` is the address type used to identify the target.
#[async_trait]
pub trait Dispatcher<A> {
    /// Send a raw data blob to the given target.
    ///
    /// Returns a [`SimNetError`] if the data could not be delivered.
    async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
}
221
/// SimNetError is used to indicate errors that occur during
/// network simulation.
///
/// The enum is `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An invalid address was encountered.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// An invalid node was encountered. The second field is the underlying
    /// cause, exposed as the error source.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An invalid parameter was encountered.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// The simulator has been closed. Also used in this file when sending an
    /// event to the simulator loop fails, and when the handle is requested
    /// before the simulator was started.
    #[error("closed: {0}")]
    Closed(String),

    /// Timeout when waiting for something. Carries the timeout that elapsed
    /// and a description of what was being waited for.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// Cannot deliver the message because destination address is missing.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// SimnetHandle being accessed without starting simnet.
    // NOTE(review): `simnet_handle` in this file reports that condition as
    // `Closed("SimNet not started")` rather than with this variant.
    #[error("simnet not started")]
    NotStarted,

    /// A task has panicked.
    #[error("panicked task")]
    PanickedTask,
}
259
// The simulator's pending events, split into two queues keyed by fire time.
// Events that share a fire time are kept in insertion order in the same Vec.
struct State {
    // The simnet is allowed to advance to the time of the earliest event in this queue at any time
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    // The simnet is allowed to advance to the time of the earliest event in this queue
    // only if the earliest event in `scheduled_events` occurs after the earliest event in this queue
    // or some debounce period has passed where there are only events in this queue.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
268
/// The state of the python training script.
///
/// The simulator loop consults this state to decide whether it may advance
/// simulated time and whether advanced time counts as script waiting time.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// The training script is issuing commands
    Running,
    /// The training script is waiting for the backend to return a future result
    Waiting,
}
277
278/// A distribution of latencies that can be sampled from
279pub enum LatencyDistribution {
280    /// A beta distribution scaled to a given range of values
281    Beta(BetaDistribution),
282}
283
284impl LatencyDistribution {
285    fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
286        match &self {
287            LatencyDistribution::Beta(sampler) => sampler.sample(rng),
288        }
289    }
290}
291
292/// A beta distribution scaled to a given range of values.
293pub struct BetaDistribution {
294    min_duration: tokio::time::Duration,
295    max_duration: tokio::time::Duration,
296    dist: rand_distr::Beta<f64>,
297}
298
299impl BetaDistribution {
300    /// Sample a sclaed value from the distribution.
301    pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
302        let sample = self.dist.sample(rng);
303
304        self.min_duration
305            + tokio::time::Duration::from_micros(
306                (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
307            )
308    }
309
310    /// Create a new beta distribution.
311    pub fn new(
312        min_duration: tokio::time::Duration,
313        max_duration: tokio::time::Duration,
314        alpha: f64,
315        beta: f64,
316    ) -> anyhow::Result<Self> {
317        if min_duration > max_duration {
318            return Err(anyhow::anyhow!(
319                "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
320                min_duration,
321                max_duration
322            ));
323        }
324        Ok(Self {
325            min_duration,
326            max_duration,
327            dist: rand_distr::Beta::new(alpha, beta)?,
328        })
329    }
330}
331/// Configuration for latencies between distances for the simulator
332pub struct LatencyConfig {
333    /// inter-region latency distribution
334    pub inter_region_distribution: LatencyDistribution,
335    /// inter-data center latency distribution
336    pub inter_dc_distribution: LatencyDistribution,
337    /// inter-zone latency distribution
338    pub inter_zone_distribution: LatencyDistribution,
339    /// Single random number generator for all distributions to ensure deterministic sampling
340    pub rng: StdRng,
341}
342
343impl LatencyConfig {
344    fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
345        match distance {
346            Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
347            Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
348            Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
349            Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
350        }
351    }
352}
353
354impl Default for LatencyConfig {
355    fn default() -> Self {
356        let seed: u64 = 0000;
357        let mut seed_bytes = [0u8; 32];
358        seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
359
360        Self {
361            inter_region_distribution: LatencyDistribution::Beta(
362                BetaDistribution::new(
363                    tokio::time::Duration::from_millis(500),
364                    tokio::time::Duration::from_millis(1000),
365                    2.0,
366                    1.0,
367                )
368                .unwrap(),
369            ),
370            inter_dc_distribution: LatencyDistribution::Beta(
371                BetaDistribution::new(
372                    tokio::time::Duration::from_millis(50),
373                    tokio::time::Duration::from_millis(100),
374                    2.0,
375                    1.0,
376                )
377                .unwrap(),
378            ),
379            inter_zone_distribution: LatencyDistribution::Beta(
380                BetaDistribution::new(
381                    tokio::time::Duration::from_millis(5),
382                    tokio::time::Duration::from_millis(10),
383                    2.0,
384                    1.0,
385                )
386                .unwrap(),
387            ),
388            rng: StdRng::from_seed(seed_bytes),
389        }
390    }
391}
392
/// A handle to a running [`SimNet`] instance.
pub struct SimNetHandle {
    /// Handle of the spawned simulator loop task; it resolves to the recorded
    /// events. Taken (left as `None`) by the first call to `close`.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Channel into the simulator loop. Each item is the event, whether it is
    /// advanceable, and an optional pre-assigned fire time.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Number of events sent to the loop that have not been fired yet; shared
    /// with the `SimNet`, which decrements it.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the training script's state to the simulator loop.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Signal to stop the simnet loop
    stop_signal: Arc<AtomicBool>,
    /// Location of each registered proc in resource space.
    resources: DashMap<ProcId, Point>,
    /// Latency distributions (and their RNG) used by `sample_latency`.
    latencies: std::sync::Mutex<LatencyConfig>,
}
406
407impl SimNetHandle {
    /// Sends an event to be scheduled onto the simnet's event loop.
    ///
    /// The event is sent as advanceable and without a pre-assigned time.
    ///
    /// # Errors
    ///
    /// Returns [`SimNetError::Closed`] if the event could not be sent to the loop.
    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
    pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
        self.send_event_impl(event, true)
    }
413
414    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
415    fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
416        self.pending_event_count
417            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
418        self.event_tx
419            .send((event, advanceable, None))
420            .map_err(|err| SimNetError::Closed(err.to_string()))
421    }
422
    /// Sends a non-advanceable event to be scheduled onto the simnet's event loop.
    /// A non-advanceable event is an event that cannot advance the simnet's time unless
    /// the earliest event in the simnet's advancing event queue occurs after the earliest
    /// event in the simnet's non-advancing event queue, or some debounce period has passed
    /// where there are only events in the simnet's non-advancing event queue.
    ///
    /// # Errors
    ///
    /// Returns [`SimNetError::Closed`] if the event could not be sent to the loop.
    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
    pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
        self.send_event_impl(event, false)
    }
432
433    /// Sends an event that already has a scheduled time onto the simnet's event loop
434    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
435    pub(crate) fn send_scheduled_event(
436        &self,
437        ScheduledEvent { event, time }: ScheduledEvent,
438    ) -> Result<(), SimNetError> {
439        self.pending_event_count
440            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
441        self.event_tx
442            .send((event, true, Some(time)))
443            .map_err(|err| SimNetError::Closed(err.to_string()))
444    }
445
446    /// Let the simnet know if the training script is running or waiting for the backend
447    /// to return a future result.
448    pub fn set_training_script_state(&self, state: TrainingScriptState) {
449        self.training_script_state_tx.send(state).unwrap();
450    }
451
452    /// Bind the given address to this simulator instance.
453    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
454    pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
455        self.pending_event_count
456            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
457        self.event_tx
458            .send((
459                Box::new(NodeJoinEvent {
460                    channel_addr: address,
461                }),
462                true,
463                None,
464            ))
465            .map_err(|err| SimNetError::Closed(err.to_string()))
466    }
467
468    /// Close the simulator, processing pending messages before
469    /// completing the returned future.
470    pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
471        // Signal the simnet loop to stop
472        self.stop_signal.store(true, Ordering::SeqCst);
473
474        let mut guard = self.join_handle.lock().await;
475        if let Some(handle) = guard.take() {
476            handle.await
477        } else {
478            Ok(vec![])
479        }
480    }
481
482    /// Wait for all of the received events to be scheduled for flight.
483    /// It ticks the simnet time till all of the scheduled events are processed.
484    pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
485        let pending_event_count = self.pending_event_count.clone();
486        // poll for the pending event count to be zero.
487        let mut interval = interval(Duration::from_millis(10));
488        let deadline = RealClock.now() + timeout;
489        while RealClock.now() < deadline {
490            interval.tick().await;
491            if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
492                return Ok(());
493            }
494        }
495        Err(SimNetError::Timeout(
496            timeout,
497            "timeout waiting for received events to be scheduled".to_string(),
498        ))
499    }
500
    /// Register the location in resource space for a Proc.
    ///
    /// A later registration for the same `proc_id` replaces the earlier one.
    /// The registered point is what `sample_latency` uses to measure the
    /// distance between two procs.
    pub fn register_proc(&self, proc_id: ProcId, point: Point) {
        self.resources.insert(proc_id, point);
    }
505
506    /// Sample a latency between two procs
507    pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
508        let distances = [
509            Distance::Region,
510            Distance::DataCenter,
511            Distance::Zone,
512            Distance::Rack,
513            Distance::Host,
514            Distance::Same,
515        ];
516
517        let src_coords = self
518            .resources
519            .get(src)
520            .map(|point| point.coords().clone())
521            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
522
523        let dest_coords = self
524            .resources
525            .get(dest)
526            .map(|point| point.coords().clone())
527            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
528
529        for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
530            if src != dest {
531                let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
532                return guard.from_distance(&distance);
533            }
534        }
535
536        let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
537        guard.from_distance(&Distance::Same)
538    }
539}
540
// How far apart two procs are, from widest to narrowest. `sample_latency`
// maps the first coordinate dimension in which two procs differ onto these
// variants in declaration order.
#[derive(Debug)]
enum Distance {
    // The procs are in different regions.
    Region,
    // Same region, different data centers.
    DataCenter,
    // Same data center, different zones.
    Zone,
    // Same zone, different racks; sampled as zero latency.
    Rack,
    // Same rack, different hosts; sampled as zero latency.
    Host,
    // No difference in any compared dimension; sampled as zero latency.
    Same,
}
550
/// SimNet is the simulator's event loop state: the queues of scheduled
/// events, the record of everything that has been scheduled, and a counter
/// of events that are still pending.
pub struct SimNet {
    // NOTE(review): presumably the set of bound addresses; nothing in the
    // visible part of this file inserts into or reads from it.
    address_book: DashSet<ChannelAddr>,
    // Advanceable and non-advanceable event queues.
    state: State,
    // Set to 10 seconds by `start_with_config`.
    // NOTE(review): not read anywhere in the visible part of this file.
    max_latency: Duration,
    // One record per scheduled event, in scheduling order; returned by `run`.
    records: Vec<SimulatorEventRecord>,
    // number of events that has been received but not yet processed.
    pending_event_count: Arc<AtomicUsize>,
}
564
/// Starts a sim net with the default latency configuration
/// ([`LatencyConfig::default`]).
pub fn start() {
    start_with_config(LatencyConfig::default())
}
569
570/// Starts a sim net with configured latencies between distances
571pub fn start_with_config(config: LatencyConfig) {
572    let max_duration_ms = 1000 * 10;
573    // Construct a topology with one node: the default A.
574    let address_book: DashSet<ChannelAddr> = DashSet::new();
575
576    let (training_script_state_tx, training_script_state_rx) =
577        tokio::sync::watch::channel(TrainingScriptState::Waiting);
578    let (event_tx, event_rx) =
579        mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
580    let pending_event_count = Arc::new(AtomicUsize::new(0));
581    let stop_signal = Arc::new(AtomicBool::new(false));
582
583    let join_handle = Mutex::new(Some({
584        let pending_event_count = pending_event_count.clone();
585        let stop_signal = stop_signal.clone();
586
587        tokio::spawn(async move {
588            SimNet {
589                address_book,
590                state: State {
591                    scheduled_events: BTreeMap::new(),
592                    unadvanceable_scheduled_events: BTreeMap::new(),
593                },
594                max_latency: Duration::from_millis(max_duration_ms),
595                records: Vec::new(),
596                pending_event_count,
597            }
598            .run(event_rx, training_script_state_rx, stop_signal)
599            .await
600        })
601    }));
602
603    HANDLE.get_or_init(|| SimNetHandle {
604        join_handle,
605        event_tx,
606        pending_event_count,
607        training_script_state_tx,
608        stop_signal,
609        resources: DashMap::new(),
610        latencies: std::sync::Mutex::new(config),
611    });
612}
613
614impl SimNet {
615    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
616        // Get latency
617        ScheduledEvent {
618            time: SimClock.now() + event.duration(),
619            event,
620        }
621    }
622
623    /// Schedule the event into the network.
624    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
625        let start_at = SimClock.now();
626        let end_at = scheduled_event.time;
627
628        self.records.push(SimulatorEventRecord {
629            summary: scheduled_event.event.summary(),
630            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
631            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
632        });
633
634        if advanceable {
635            self.state
636                .scheduled_events
637                .entry(scheduled_event.time)
638                .or_insert_with(Vec::new)
639                .push(scheduled_event);
640        } else {
641            self.state
642                .unadvanceable_scheduled_events
643                .entry(scheduled_event.time)
644                .or_insert_with(Vec::new)
645                .push(scheduled_event);
646        }
647    }
648
    /// Run the simulation. This will dispatch all the messages in the network.
    /// And wait for new ones.
    ///
    /// Each iteration of the loop:
    /// 1. returns the recorded events if the stop signal is set;
    /// 2. drains incoming events from `event_rx` into the queues;
    /// 3. picks the batch of events with the earliest fire time (subject to
    ///    the training-script and debounce rules below), advances the
    ///    simulated clock to that time and fires the batch.
    ///
    /// The loop also ends, returning the records collected so far, as soon
    /// as any event handler returns an error.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // The simulated amount of time the training script
        // has spent waiting for the backend to resolve a future
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Real-time instant at which only the non-advanceable queue was first
        // seen to hold events; `None` while the timer is not running.
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        // Milliseconds of real time to wait for the next incoming event while
        // draining the channel; taken from SIM_DEBOUNCE, defaulting to 1.
        // NOTE(review): despite the name, the debounce threshold applied to
        // the non-advanceable queue further down is a hard-coded 1000 ms.
        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            // Check if we should stop
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Drain incoming events until none arrives within the timeout or
            // the channel is closed.
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                // Events that arrive with a time are shifted by the script's
                // accumulated waiting time; the rest are timed from now.
                let scheduled_event = match time {
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // If the training script is running and issuing commands
                // it is not safe to advance past the training script time
                // otherwise a command issued by the training script may
                // be scheduled for a time in the past
                // NOTE(review): the comparison uses `RealClock.now()` while
                // unscheduled events are timed with `SimClock.now()`; confirm
                // that mixing the two clocks here is intended.
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        // Start debounce timer when only the non-advanceable
                        // queue has events and the timer has not already started
                        debounce_timer = Some(RealClock.now());
                    }
                    // Timer already active
                    (None, Some(_)) => {}
                    // Reset timer when non-advanceable queue is not the only queue with events
                    _ => {
                        debounce_timer = None;
                    }
                }
                // process for next delivery time.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    // Both queues have events: take the earlier batch, with
                    // ties going to the advanceable queue.
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    // Only non-advanceable events: release a batch only after
                    // the debounce timer has run for more than a second.
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                // debounce interval has elapsed, reset timer
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing may fire yet: wait up to 10 ms of real time for
                    // a new event, then start the next iteration.
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // While the script is waiting on the backend, the simulated
                // time being skipped counts as script waiting time.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for mut scheduled_event in scheduled_events {
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    // NOTE(review): on error the rest of this batch is dropped
                    // without decrementing `pending_event_count`.
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        break 'outer self.records.clone(); //TODO
                    }
                }
            }
        }
    }
781}
782
783fn serialize_optional_channel_addr<S>(
784    addr: &Option<ChannelAddr>,
785    serializer: S,
786) -> Result<S::Ok, S::Error>
787where
788    S: Serializer,
789{
790    match addr {
791        Some(addr) => serializer.serialize_str(&addr.to_string()),
792        None => serializer.serialize_none(),
793    }
794}
795
796fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
797where
798    D: Deserializer<'de>,
799{
800    let s = String::deserialize(deserializer)?;
801    s.parse().map_err(serde::de::Error::custom)
802}
803
/// SimulatorEventRecord is a structure to bookkeep the events scheduled by
/// the simulator. One record is created each time an event is scheduled.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// Event dependent summary for user
    pub summary: String,
    /// The simulated time at which the event was scheduled, in milliseconds
    /// as reported by `SimClock::duration_since_start`.
    pub start_at: u64,
    /// The simulated time at which the event is due to fire, in milliseconds
    /// as reported by `SimClock::duration_since_start`.
    pub end_at: u64,
}
814
815#[cfg(test)]
816mod tests {
817    use std::collections::HashMap;
818    use std::sync::Arc;
819
820    use async_trait::async_trait;
821    use ndslice::extent;
822    use tokio::sync::Mutex;
823
824    use super::*;
825    use crate::channel::sim::SimAddr;
826    use crate::clock::Clock;
827    use crate::clock::RealClock;
828    use crate::clock::SimClock;
829    use crate::data::Serialized;
830    use crate::id;
831    use crate::simnet;
832    use crate::simnet::Dispatcher;
833    use crate::simnet::Event;
834    use crate::simnet::SimNetError;
835
    /// Test-only event that models delivering one serialized message from one
    /// simulated address to another.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        /// Address the message is sent from; only read by `summary()`.
        src_addr: SimAddr,
        /// Address the message is delivered to.
        dest_addr: SimAddr,
        /// Serialized payload handed to the dispatcher on delivery.
        data: Serialized,
        /// Value reported through `Event::duration()`.
        duration: tokio::time::Duration,
        /// Where `handle()` forwards the payload; with `None`, `handle()` does
        /// nothing and succeeds.
        dispatcher: Option<TestDispatcher>,
    }
844
845    #[async_trait]
846    impl Event for MessageDeliveryEvent {
847        async fn handle(&mut self) -> Result<(), simnet::SimNetError> {
848            if let Some(dispatcher) = &self.dispatcher {
849                dispatcher
850                    .send(self.dest_addr.clone(), self.data.clone())
851                    .await?;
852            }
853            Ok(())
854        }
855        fn duration(&self) -> tokio::time::Duration {
856            self.duration
857        }
858
859        fn summary(&self) -> String {
860            format!(
861                "Sending message from {} to {}",
862                self.src_addr.addr().clone(),
863                self.dest_addr.addr().clone()
864            )
865        }
866    }
867
868    impl MessageDeliveryEvent {
869        fn new(
870            src_addr: SimAddr,
871            dest_addr: SimAddr,
872            data: Serialized,
873            dispatcher: Option<TestDispatcher>,
874            duration: tokio::time::Duration,
875        ) -> Self {
876            Self {
877                src_addr,
878                dest_addr,
879                data,
880                duration,
881                dispatcher,
882            }
883        }
884    }
885
    /// Test-only dispatcher that records every payload it is asked to send,
    /// grouped by target address, instead of delivering it anywhere.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        /// Payloads received so far, keyed by target address and kept in the
        /// order they were sent. The map sits behind an `Arc`, so clones of the
        /// dispatcher share (and observe) the same buffers.
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }
890
891    impl Default for TestDispatcher {
892        fn default() -> Self {
893            Self {
894                mbuffers: Arc::new(Mutex::new(HashMap::new())),
895            }
896        }
897    }
898
899    #[async_trait]
900    impl Dispatcher<SimAddr> for TestDispatcher {
901        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
902            let mut buf = self.mbuffers.lock().await;
903            buf.entry(target).or_default().push(data);
904            Ok(())
905        }
906    }
907
908    #[cfg(target_os = "linux")]
909    fn random_abstract_addr() -> ChannelAddr {
910        use rand::Rng;
911        use rand::distributions::Alphanumeric;
912
913        let random_string = rand::thread_rng()
914            .sample_iter(&Alphanumeric)
915            .take(24)
916            .map(char::from)
917            .collect::<String>();
918        format!("unix!@{random_string}").parse().unwrap()
919    }
920
921    #[tokio::test]
922    async fn test_handle_instantiation() {
923        start();
924        simnet_handle().unwrap().close().await.unwrap();
925    }
926
927    #[tokio::test]
928    async fn test_simnet_config() {
929        // Tests that we can create a simnet, config latency between distances and sample latencies between procs.
930        let ext = extent!(region = 1, dc = 1, zone = 1, rack = 4, host = 4, gpu = 8);
931
932        let alice = id!(world[0]);
933        let bob = id!(world[1]);
934        let charlie = id!(world[2]);
935
936        let config = LatencyConfig {
937            inter_zone_distribution: LatencyDistribution::Beta(
938                BetaDistribution::new(
939                    tokio::time::Duration::from_millis(1000),
940                    tokio::time::Duration::from_millis(1000),
941                    1.0,
942                    1.0,
943                )
944                .unwrap(),
945            ),
946            inter_dc_distribution: LatencyDistribution::Beta(
947                BetaDistribution::new(
948                    tokio::time::Duration::from_millis(2000),
949                    tokio::time::Duration::from_millis(2000),
950                    1.0,
951                    1.0,
952                )
953                .unwrap(),
954            ),
955            ..Default::default()
956        };
957        start_with_config(config);
958
959        let handle = simnet_handle().unwrap();
960        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
961        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
962        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
963        assert_eq!(
964            handle.sample_latency(&alice, &bob),
965            tokio::time::Duration::from_millis(1000)
966        );
967        assert_eq!(
968            handle.sample_latency(&alice, &charlie),
969            tokio::time::Duration::from_millis(2000)
970        );
971    }
972
973    #[tokio::test]
974    async fn test_simnet_debounce() {
975        let config = LatencyConfig {
976            inter_zone_distribution: LatencyDistribution::Beta(
977                BetaDistribution::new(
978                    tokio::time::Duration::from_millis(1000),
979                    tokio::time::Duration::from_millis(1000),
980                    1.0,
981                    1.0,
982                )
983                .unwrap(),
984            ),
985            ..Default::default()
986        };
987        start_with_config(config);
988        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
989        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();
990
991        let latency = Duration::from_millis(10000);
992
993        let alice = SimAddr::new(alice).unwrap();
994        let bob = SimAddr::new(bob).unwrap();
995
996        // Rapidly send 10 messages expecting that each one debounces the processing
997        for _ in 0..10 {
998            simnet_handle()
999                .unwrap()
1000                .send_event(Box::new(MessageDeliveryEvent::new(
1001                    alice.clone(),
1002                    bob.clone(),
1003                    Serialized::serialize(&"123".to_string()).unwrap(),
1004                    None,
1005                    latency,
1006                )))
1007                .unwrap();
1008            RealClock
1009                .sleep(tokio::time::Duration::from_micros(500))
1010                .await;
1011        }
1012
1013        simnet_handle()
1014            .unwrap()
1015            .flush(Duration::from_secs(20))
1016            .await
1017            .unwrap();
1018
1019        let records = simnet_handle().unwrap().close().await;
1020        assert_eq!(records.as_ref().unwrap().len(), 10);
1021
1022        // If debounce is successful, the simnet will not advance to the delivery of any of
1023        // the messages before all are received
1024        assert_eq!(
1025            records.unwrap().last().unwrap().end_at,
1026            latency.as_millis() as u64
1027        );
1028    }
1029
1030    #[tokio::test]
1031    async fn test_sim_dispatch() {
1032        start();
1033        let sender = Some(TestDispatcher::default());
1034        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
1035        // // Create a simple network of 4 nodes.
1036        for i in 0..4 {
1037            addresses.push(
1038                format!("local:{}", i)
1039                    .parse::<simnet::ChannelAddr>()
1040                    .unwrap(),
1041            );
1042        }
1043
1044        let messages: Vec<Serialized> = vec!["First 0 1", "First 2 3", "Second 0 1"]
1045            .into_iter()
1046            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
1047            .collect();
1048
1049        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
1050        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
1051        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
1052        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
1053        let one = Box::new(MessageDeliveryEvent::new(
1054            addr_0.clone(),
1055            addr_1.clone(),
1056            messages[0].clone(),
1057            sender.clone(),
1058            tokio::time::Duration::ZERO,
1059        ));
1060        let two = Box::new(MessageDeliveryEvent::new(
1061            addr_2.clone(),
1062            addr_3.clone(),
1063            messages[1].clone(),
1064            sender.clone(),
1065            tokio::time::Duration::ZERO,
1066        ));
1067        let three = Box::new(MessageDeliveryEvent::new(
1068            addr_0.clone(),
1069            addr_1.clone(),
1070            messages[2].clone(),
1071            sender.clone(),
1072            tokio::time::Duration::ZERO,
1073        ));
1074
1075        simnet_handle().unwrap().send_event(one).unwrap();
1076        simnet_handle().unwrap().send_event(two).unwrap();
1077        simnet_handle().unwrap().send_event(three).unwrap();
1078
1079        simnet_handle()
1080            .unwrap()
1081            .flush(Duration::from_millis(1000))
1082            .await
1083            .unwrap();
1084        let records = simnet_handle().unwrap().close().await.unwrap();
1085        eprintln!("Records: {:?}", records);
1086        // Close the channel
1087        simnet_handle().unwrap().close().await.unwrap();
1088
1089        // Check results
1090        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
1091        assert_eq!(buf.len(), 2);
1092        assert_eq!(buf[&addr_1].len(), 2);
1093        assert_eq!(buf[&addr_3].len(), 1);
1094
1095        assert_eq!(buf[&addr_1][0], messages[0]);
1096        assert_eq!(buf[&addr_1][1], messages[2]);
1097        assert_eq!(buf[&addr_3][0], messages[1]);
1098    }
1099
1100    #[tokio::test]
1101    async fn test_sim_sleep() {
1102        start();
1103
1104        let start = SimClock.now();
1105        assert_eq!(
1106            SimClock.duration_since_start(start),
1107            tokio::time::Duration::ZERO
1108        );
1109
1110        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;
1111
1112        let end = SimClock.now();
1113        assert_eq!(
1114            SimClock.duration_since_start(end),
1115            tokio::time::Duration::from_secs(10)
1116        );
1117    }
1118
1119    #[tokio::test]
1120    async fn test_torch_op() {
1121        start();
1122        let args_string = "1, 2".to_string();
1123        let kwargs_string = "a=2".to_string();
1124
1125        let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
1126        let (tx, rx) = mailbox.open_once_port::<()>();
1127
1128        simnet_handle()
1129            .unwrap()
1130            .send_event(TorchOpEvent::new(
1131                "torch.ops.aten.ones.default".to_string(),
1132                tx.bind(),
1133                mailbox,
1134                args_string,
1135                kwargs_string,
1136                id!(mesh_0_worker[0].worker_0),
1137            ))
1138            .unwrap();
1139
1140        rx.recv().await.unwrap();
1141
1142        simnet_handle()
1143            .unwrap()
1144            .flush(Duration::from_millis(1000))
1145            .await
1146            .unwrap();
1147        let records = simnet_handle().unwrap().close().await;
1148        let expected_record = SimulatorEventRecord {
1149            summary:
1150                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
1151                    .to_string(),
1152            start_at: 0,
1153            end_at: 100,
1154        };
1155        assert!(records.as_ref().unwrap().len() == 1);
1156        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
1157    }
1158}