hyperactor/
simnet.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)]
10
11//! A simulator capable of simulating Hyperactor's network channels (see: [`channel`]).
12//! The simulator can simulate message delivery delays and failures, and is used for
13//! testing and development of message distribution techniques.
14
15use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::task::JoinError;
42use tokio::task::JoinHandle;
43use tokio::time::interval;
44
45// for macros
46use crate::ActorId;
47use crate::Mailbox;
48use crate::OncePortRef;
49use crate::ProcId;
50use crate::attrs::Attrs;
51use crate::channel::ChannelAddr;
52use crate::clock::Clock;
53use crate::clock::RealClock;
54use crate::clock::SimClock;
55use crate::context::MailboxExt;
56use crate::data::Serialized;
57
58static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
59
60/// A handle for SimNet through which you can send and schedule events in the
61/// network.
62///
63/// Return the \[`NotStarted`\] error when called before `simnet::start()` has been called
64#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
65pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
66    match HANDLE.get() {
67        Some(handle) => Ok(handle),
68        None => Err(SimNetError::Closed("SimNet not started".to_string())),
69    }
70}
71
// NOTE(review): presumably the capacity of an operational-message channel; it
// is not referenced in the visible part of this file — confirm it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;

/// Marker trait for types that can serve as node addresses in the network.
///
/// There is nothing to implement by hand: every type that is cloneable,
/// debuggable, hashable and totally ordered is an `Address` through the
/// blanket implementation below.
pub trait Address: Clone + Debug + Eq + PartialEq + Ord + PartialOrd + Hash {}

impl<A> Address for A where A: Clone + Debug + Eq + PartialEq + Ord + PartialOrd + Hash {}
78
/// A point on the simulator's timeline; an alias for [`tokio::time::Instant`].
/// Scheduled events are keyed and ordered by values of this type.
pub type SimulatorTimeInstant = tokio::time::Instant;
81
/// The unit of execution for the simulator.
///
/// Events are handed to the simulator through a [`SimNetHandle`] (see
/// [`simnet_handle`]) and are fired once the simulated clock has been advanced
/// to their scheduled time.
/// If you want to send a message, for example, you would implement a
/// `MessageDeliveryEvent` much along the lines expressed in the simnet tests.
/// You can also model more advanced concepts such as node churn, or even
/// simulate process spawns in a distributed system. For example, one can
/// implement a `SystemActorSimEvent` in order to spawn a system actor.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// Performs the event's effect. Examples:
    /// For messages, it will be delivering the message to the dst's receiver queue.
    /// For a proc spawn, it will be creating the proc object and instantiating it.
    /// For any event that manipulates the network (like adding/removing nodes etc.)
    /// implement `handle_network()` instead.
    ///
    /// The simulator loop does not call this method directly; it is reached
    /// through the default implementation of [`Event::handle_network`].
    async fn handle(&self) -> Result<(), SimNetError>;

    /// The method the simulator loop calls when the event fires, with access
    /// to the [`SimNet`] itself.
    ///
    /// The default implementation simply forwards to [`Event::handle`].
    /// Unless you need to make changes to the network, you do not have to
    /// implement this; implement only `handle()` for all non-simnet requirements.
    /// An `Err` returned from here makes the simulator loop stop.
    async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// The latency of the event. This could be network latency, induced latency (sleep), or
    /// GPU work latency. When an event arrives without a pre-assigned time, it
    /// is scheduled this long after the current simulated time.
    fn duration(&self) -> tokio::time::Duration;

    /// A user-friendly summary of the event; stored in the
    /// [`SimulatorEventRecord`] created when the event is scheduled.
    fn summary(&self) -> String;
}
114
115/// This is a simple event that is used to join a node to the network.
116/// It is used to bind a node to a channel address.
117#[derive(Debug)]
118struct NodeJoinEvent {
119    channel_addr: ChannelAddr,
120}
121
122#[async_trait]
123impl Event for NodeJoinEvent {
124    async fn handle(&self) -> Result<(), SimNetError> {
125        Ok(())
126    }
127
128    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
129        self.handle().await
130    }
131
132    fn duration(&self) -> tokio::time::Duration {
133        tokio::time::Duration::ZERO
134    }
135
136    fn summary(&self) -> String {
137        "Node join".into()
138    }
139}
140
141#[derive(Debug)]
142/// A pytorch operation
143pub struct TorchOpEvent {
144    op: String,
145    done_tx: OncePortRef<()>,
146    mailbox: Mailbox,
147    args_string: String,
148    kwargs_string: String,
149    worker_actor_id: ActorId,
150}
151
152#[async_trait]
153impl Event for TorchOpEvent {
154    async fn handle(&self) -> Result<(), SimNetError> {
155        Ok(())
156    }
157
158    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
159        let mut headers = Attrs::new();
160        crate::mailbox::headers::set_send_timestamp(&mut headers);
161        let serialized =
162            Serialized::serialize(&()).map_err(|err| SimNetError::Closed(err.to_string()))?;
163        self.mailbox
164            .post(self.done_tx.port_id().clone(), headers, serialized);
165        Ok(())
166    }
167
168    fn duration(&self) -> tokio::time::Duration {
169        tokio::time::Duration::from_millis(100)
170    }
171
172    fn summary(&self) -> String {
173        let kwargs_string = if self.kwargs_string.is_empty() {
174            "".to_string()
175        } else {
176            format!(", {}", self.kwargs_string)
177        };
178        format!(
179            "[{}] Torch Op: {}({}{})",
180            self.worker_actor_id, self.op, self.args_string, kwargs_string
181        )
182    }
183}
184
185impl TorchOpEvent {
186    /// Creates a new TorchOpEvent.
187    pub fn new(
188        op: String,
189        done_tx: OncePortRef<()>,
190        mailbox: Mailbox,
191        args_string: String,
192        kwargs_string: String,
193        worker_actor_id: ActorId,
194    ) -> Box<Self> {
195        Box::new(Self {
196            op,
197            done_tx,
198            mailbox,
199            args_string,
200            kwargs_string,
201            worker_actor_id,
202        })
203    }
204}
205
206#[derive(Debug)]
207pub(crate) struct SleepEvent {
208    done_tx: OncePortRef<()>,
209    mailbox: Mailbox,
210    duration: tokio::time::Duration,
211}
212
213impl SleepEvent {
214    pub(crate) fn new(
215        done_tx: OncePortRef<()>,
216        mailbox: Mailbox,
217        duration: tokio::time::Duration,
218    ) -> Box<Self> {
219        Box::new(Self {
220            done_tx,
221            mailbox,
222            duration,
223        })
224    }
225}
226
227#[async_trait]
228impl Event for SleepEvent {
229    async fn handle(&self) -> Result<(), SimNetError> {
230        Ok(())
231    }
232
233    async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
234        let mut headers = Attrs::new();
235        crate::mailbox::headers::set_send_timestamp(&mut headers);
236        let serialized =
237            Serialized::serialize(&()).map_err(|err| SimNetError::Closed(err.to_string()))?;
238        self.mailbox
239            .post(self.done_tx.port_id().clone(), headers, serialized);
240        Ok(())
241    }
242
243    fn duration(&self) -> tokio::time::Duration {
244        self.duration
245    }
246
247    fn summary(&self) -> String {
248        format!("Sleeping for {} ms", self.duration.as_millis())
249    }
250}
251
/// An event paired with the simulated instant at which it is due to fire.
///
/// The simulator queues scheduled events keyed by `time` and fires them in
/// time order, so the timestamp determines the order in which events are
/// handled.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// The simulated instant at which the event fires.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to run at `time`.
    pub(crate) event: Box<dyn Event>,
}
261
/// Dispatcher is a trait that defines the send operation.
/// The send operation takes a target address and a data buffer.
/// This method is called when the simulator is ready for the message to be received
/// by the target address.
///
/// `A` is the address type used to identify the target.
#[async_trait]
pub trait Dispatcher<A> {
    /// Send a raw data blob to the given target.
    ///
    /// Returns a [`SimNetError`] when the data cannot be delivered.
    async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
}
271
/// SimNetError is used to indicate errors that occur during
/// network simulation.
///
/// The enum is `#[non_exhaustive]`: code matching on it from outside this
/// crate must include a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An invalid address was encountered. The payload is a description.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// An invalid node was encountered. Carries a description and the
    /// underlying cause.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An invalid parameter was encountered. The payload is a description.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// The simulator has been closed. Within this file the variant is also
    /// used to wrap channel-send and serialization failures, and for access
    /// to the handle before the simulator was started.
    #[error("closed: {0}")]
    Closed(String),

    /// Timeout when waiting for something. Carries the timeout that elapsed
    /// and a description of what was being waited for.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// Cannot deliver the message because destination address is missing.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// SimnetHandle being accessed without starting simnet
    #[error("simnet not started")]
    NotStarted,
}
305
/// The simulator's two event queues, each mapping a fire time to the events
/// due at that time. `BTreeMap` keeps the earliest time first.
struct State {
    /// Advanceable events: the simnet is allowed to advance to the time of
    /// the earliest event in this queue at any time.
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    /// Non-advanceable events: the simnet is allowed to advance to the time
    /// of the earliest event in this queue only if the earliest event in
    /// `scheduled_events` occurs after the earliest event in this queue,
    /// or some debounce period has passed where there are only events in this queue.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
314
/// The state of the python training script.
///
/// The simulator loop reads this to decide whether simulated time may be
/// advanced past the training script, and whether advanced time counts as
/// time the script spent waiting.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// The training script is issuing commands
    Running,
    /// The training script is waiting for the backend to return a future result
    Waiting,
}
323
324/// A distribution of latencies that can be sampled from
325pub enum LatencyDistribution {
326    /// A beta distribution scaled to a given range of values
327    Beta(BetaDistribution),
328}
329
330impl LatencyDistribution {
331    fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
332        match &self {
333            LatencyDistribution::Beta(sampler) => sampler.sample(rng),
334        }
335    }
336}
337
338/// A beta distribution scaled to a given range of values.
339pub struct BetaDistribution {
340    min_duration: tokio::time::Duration,
341    max_duration: tokio::time::Duration,
342    dist: rand_distr::Beta<f64>,
343}
344
345impl BetaDistribution {
346    /// Sample a sclaed value from the distribution.
347    pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
348        let sample = self.dist.sample(rng);
349
350        self.min_duration
351            + tokio::time::Duration::from_micros(
352                (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
353            )
354    }
355
356    /// Create a new beta distribution.
357    pub fn new(
358        min_duration: tokio::time::Duration,
359        max_duration: tokio::time::Duration,
360        alpha: f64,
361        beta: f64,
362    ) -> anyhow::Result<Self> {
363        if min_duration > max_duration {
364            return Err(anyhow::anyhow!(
365                "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
366                min_duration,
367                max_duration
368            ));
369        }
370        Ok(Self {
371            min_duration,
372            max_duration,
373            dist: rand_distr::Beta::new(alpha, beta)?,
374        })
375    }
376}
377/// Configuration for latencies between distances for the simulator
378pub struct LatencyConfig {
379    /// inter-region latency distribution
380    pub inter_region_distribution: LatencyDistribution,
381    /// inter-data center latency distribution
382    pub inter_dc_distribution: LatencyDistribution,
383    /// inter-zone latency distribution
384    pub inter_zone_distribution: LatencyDistribution,
385    /// Single random number generator for all distributions to ensure deterministic sampling
386    pub rng: StdRng,
387}
388
389impl LatencyConfig {
390    fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
391        match distance {
392            Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
393            Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
394            Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
395            Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
396        }
397    }
398}
399
400impl Default for LatencyConfig {
401    fn default() -> Self {
402        let seed: u64 = 0000;
403        let mut seed_bytes = [0u8; 32];
404        seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
405
406        Self {
407            inter_region_distribution: LatencyDistribution::Beta(
408                BetaDistribution::new(
409                    tokio::time::Duration::from_millis(500),
410                    tokio::time::Duration::from_millis(1000),
411                    2.0,
412                    1.0,
413                )
414                .unwrap(),
415            ),
416            inter_dc_distribution: LatencyDistribution::Beta(
417                BetaDistribution::new(
418                    tokio::time::Duration::from_millis(50),
419                    tokio::time::Duration::from_millis(100),
420                    2.0,
421                    1.0,
422                )
423                .unwrap(),
424            ),
425            inter_zone_distribution: LatencyDistribution::Beta(
426                BetaDistribution::new(
427                    tokio::time::Duration::from_millis(5),
428                    tokio::time::Duration::from_millis(10),
429                    2.0,
430                    1.0,
431                )
432                .unwrap(),
433            ),
434            rng: StdRng::from_seed(seed_bytes),
435        }
436    }
437}
438
439/// A handle to a running [`SimNet`] instance.
440pub struct SimNetHandle {
441    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
442    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
443    pending_event_count: Arc<AtomicUsize>,
444    /// A receiver to receive simulator operational messages.
445    /// The receiver can be moved out of the simnet handle.
446    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
447    /// Signal to stop the simnet loop
448    stop_signal: Arc<AtomicBool>,
449    resources: DashMap<ProcId, Point>,
450    latencies: std::sync::Mutex<LatencyConfig>,
451}
452
453impl SimNetHandle {
454    /// Sends an event to be scheduled onto the simnet's event loop
455    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
456    pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
457        self.send_event_impl(event, true)
458    }
459
460    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
461    fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
462        self.pending_event_count
463            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
464        self.event_tx
465            .send((event, advanceable, None))
466            .map_err(|err| SimNetError::Closed(err.to_string()))
467    }
468
469    /// Sends an non-advanceable event to be scheduled onto the simnet's event loop
470    /// A non-advanceable event is an event that cannot advance the simnet's time unless
471    /// the earliest event in the simnet's advancing event queue occurs after the earliest
472    /// event in the simnet's non-advancing event queue, or some debounce period has passed
473    /// where there are only events in the simnet's non-advancing event queue.
474    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
475    pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
476        self.send_event_impl(event, false)
477    }
478
479    /// Sends an event that already has a scheduled time onto the simnet's event loop
480    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
481    pub(crate) fn send_scheduled_event(
482        &self,
483        ScheduledEvent { event, time }: ScheduledEvent,
484    ) -> Result<(), SimNetError> {
485        self.pending_event_count
486            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
487        self.event_tx
488            .send((event, true, Some(time)))
489            .map_err(|err| SimNetError::Closed(err.to_string()))
490    }
491
492    /// Let the simnet know if the training script is running or waiting for the backend
493    /// to return a future result.
494    pub fn set_training_script_state(&self, state: TrainingScriptState) {
495        self.training_script_state_tx.send(state).unwrap();
496    }
497
498    /// Bind the given address to this simulator instance.
499    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
500    pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
501        self.pending_event_count
502            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
503        self.event_tx
504            .send((
505                Box::new(NodeJoinEvent {
506                    channel_addr: address,
507                }),
508                true,
509                None,
510            ))
511            .map_err(|err| SimNetError::Closed(err.to_string()))
512    }
513
514    /// Close the simulator, processing pending messages before
515    /// completing the returned future.
516    pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
517        // Signal the simnet loop to stop
518        self.stop_signal.store(true, Ordering::SeqCst);
519
520        let mut guard = self.join_handle.lock().await;
521        if let Some(handle) = guard.take() {
522            handle.await
523        } else {
524            Ok(vec![])
525        }
526    }
527
528    /// Wait for all of the received events to be scheduled for flight.
529    /// It ticks the simnet time till all of the scheduled events are processed.
530    pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
531        let pending_event_count = self.pending_event_count.clone();
532        // poll for the pending event count to be zero.
533        let mut interval = interval(Duration::from_millis(10));
534        let deadline = RealClock.now() + timeout;
535        while RealClock.now() < deadline {
536            interval.tick().await;
537            if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
538                return Ok(());
539            }
540        }
541        Err(SimNetError::Timeout(
542            timeout,
543            "timeout waiting for received events to be scheduled".to_string(),
544        ))
545    }
546
547    /// Register the location in resource space for a Proc
548    pub fn register_proc(&self, proc_id: ProcId, point: Point) {
549        self.resources.insert(proc_id, point);
550    }
551
552    /// Sample a latency between two procs
553    pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
554        let distances = [
555            Distance::Region,
556            Distance::DataCenter,
557            Distance::Zone,
558            Distance::Rack,
559            Distance::Host,
560            Distance::Same,
561        ];
562
563        let src_coords = self
564            .resources
565            .get(src)
566            .map(|point| point.coords().clone())
567            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
568
569        let dest_coords = self
570            .resources
571            .get(dest)
572            .map(|point| point.coords().clone())
573            .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
574
575        for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
576            if src != dest {
577                let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
578                return guard.from_distance(&distance);
579            }
580        }
581
582        let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
583        guard.from_distance(&Distance::Same)
584    }
585}
586
/// How far apart two procs are, from farthest (`Region`) to co-located
/// (`Same`). The variant is what selects the latency to sample: only
/// `Region`, `DataCenter` and `Zone` have a latency distribution, the
/// remaining variants map to zero latency.
#[derive(Debug)]
enum Distance {
    /// The procs are in different regions.
    Region,
    /// The procs are in different data centers.
    DataCenter,
    /// The procs are in different zones.
    Zone,
    /// The procs are in different racks.
    Rack,
    /// The procs are on different hosts.
    Host,
    /// No distance between the procs.
    Same,
}
596
/// SimNet is the simulator itself: it owns the queues of scheduled events,
/// fires them in order of simulated time, and keeps a record of every event
/// it scheduled.
pub struct SimNet {
    /// Set of channel addresses known to the network.
    address_book: DashSet<ChannelAddr>,
    /// The advanceable and non-advanceable event queues.
    state: State,
    // NOTE(review): set at construction but not read in the visible part of
    // this file — presumably an upper bound on event latency; confirm.
    max_latency: Duration,
    /// One record per scheduled event, in scheduling order.
    records: Vec<SimulatorEventRecord>,
    // number of events that has been received but not yet processed.
    pending_event_count: Arc<AtomicUsize>,
}
610
/// Starts a sim net using the default latency configuration
/// ([`LatencyConfig::default`]). See [`start_with_config`].
pub fn start() {
    start_with_config(LatencyConfig::default())
}
615
616/// Starts a sim net with configured latencies between distances
617pub fn start_with_config(config: LatencyConfig) {
618    let max_duration_ms = 1000 * 10;
619    // Construct a topology with one node: the default A.
620    let address_book: DashSet<ChannelAddr> = DashSet::new();
621
622    let (training_script_state_tx, training_script_state_rx) =
623        tokio::sync::watch::channel(TrainingScriptState::Waiting);
624    let (event_tx, event_rx) =
625        mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
626    let pending_event_count = Arc::new(AtomicUsize::new(0));
627    let stop_signal = Arc::new(AtomicBool::new(false));
628
629    let join_handle = Mutex::new(Some({
630        let pending_event_count = pending_event_count.clone();
631        let stop_signal = stop_signal.clone();
632
633        tokio::spawn(async move {
634            SimNet {
635                address_book,
636                state: State {
637                    scheduled_events: BTreeMap::new(),
638                    unadvanceable_scheduled_events: BTreeMap::new(),
639                },
640                max_latency: Duration::from_millis(max_duration_ms),
641                records: Vec::new(),
642                pending_event_count,
643            }
644            .run(event_rx, training_script_state_rx, stop_signal)
645            .await
646        })
647    }));
648
649    HANDLE.get_or_init(|| SimNetHandle {
650        join_handle,
651        event_tx,
652        pending_event_count,
653        training_script_state_tx,
654        stop_signal,
655        resources: DashMap::new(),
656        latencies: std::sync::Mutex::new(config),
657    });
658}
659
660impl SimNet {
661    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
662        // Get latency
663        ScheduledEvent {
664            time: SimClock.now() + event.duration(),
665            event,
666        }
667    }
668
669    /// Schedule the event into the network.
670    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
671        let start_at = SimClock.now();
672        let end_at = scheduled_event.time;
673
674        self.records.push(SimulatorEventRecord {
675            summary: scheduled_event.event.summary(),
676            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
677            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
678        });
679
680        if advanceable {
681            self.state
682                .scheduled_events
683                .entry(scheduled_event.time)
684                .or_default()
685                .push(scheduled_event);
686        } else {
687            self.state
688                .unadvanceable_scheduled_events
689                .entry(scheduled_event.time)
690                .or_default()
691                .push(scheduled_event);
692        }
693    }
694
695    /// Run the simulation. This will dispatch all the messages in the network.
696    /// And wait for new ones.
697    async fn run(
698        &mut self,
699        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
700        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
701        stop_signal: Arc<AtomicBool>,
702    ) -> Vec<SimulatorEventRecord> {
703        // The simulated number of milliseconds the training script
704        // has spent waiting for the backend to resolve a future
705        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
706        // Duration elapsed while only non_advanceable_events has events
707        let mut debounce_timer: Option<tokio::time::Instant> = None;
708
709        let debounce_duration = std::env::var("SIM_DEBOUNCE")
710            .ok()
711            .and_then(|val| val.parse::<u64>().ok())
712            .unwrap_or(1);
713
714        'outer: loop {
715            // Check if we should stop
716            if stop_signal.load(Ordering::SeqCst) {
717                break 'outer self.records.clone();
718            }
719
720            while let Ok(Some((event, advanceable, time))) = RealClock
721                .timeout(
722                    tokio::time::Duration::from_millis(debounce_duration),
723                    event_rx.recv(),
724                )
725                .await
726            {
727                let scheduled_event = match time {
728                    Some(time) => ScheduledEvent {
729                        time: time + training_script_waiting_time,
730                        event,
731                    },
732                    None => self.create_scheduled_event(event),
733                };
734                self.schedule_event(scheduled_event, advanceable);
735            }
736
737            {
738                // If the training script is running and issuing commands
739                // it is not safe to advance past the training script time
740                // otherwise a command issued by the training script may
741                // be scheduled for a time in the past
742                if training_script_state_rx.borrow().is_running()
743                    && self
744                        .state
745                        .scheduled_events
746                        .first_key_value()
747                        .is_some_and(|(time, _)| {
748                            *time > RealClock.now() + training_script_waiting_time
749                        })
750                {
751                    tokio::task::yield_now().await;
752                    continue;
753                }
754                match (
755                    self.state.scheduled_events.first_key_value(),
756                    self.state.unadvanceable_scheduled_events.first_key_value(),
757                ) {
758                    (None, Some(_)) if debounce_timer.is_none() => {
759                        // Start debounce timer when only the non-advancedable
760                        // queue has events and the timer has not already started
761                        debounce_timer = Some(RealClock.now());
762                    }
763                    // Timer already active
764                    (None, Some(_)) => {}
765                    // Reset timer when non-advanceable queue is not the only queue with events
766                    _ => {
767                        debounce_timer = None;
768                    }
769                }
770                // process for next delivery time.
771                let Some((scheduled_time, scheduled_events)) = (match (
772                    self.state.scheduled_events.first_key_value(),
773                    self.state.unadvanceable_scheduled_events.first_key_value(),
774                ) {
775                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
776                        if unadvanceable_time < advanceable_time {
777                            self.state.unadvanceable_scheduled_events.pop_first()
778                        } else {
779                            self.state.scheduled_events.pop_first()
780                        }
781                    }
782                    (Some(_), None) => self.state.scheduled_events.pop_first(),
783                    (None, Some(_)) => match debounce_timer {
784                        Some(time) => {
785                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
786                                // debounce interval has elapsed, reset timer
787                                debounce_timer = None;
788                                self.state.unadvanceable_scheduled_events.pop_first()
789                            } else {
790                                None
791                            }
792                        }
793                        None => None,
794                    },
795                    (None, None) => None,
796                }) else {
797                    tokio::select! {
798                        Some((event, advanceable, time)) = event_rx.recv() => {
799                            let scheduled_event = match time {
800                                Some(time) => ScheduledEvent {
801                                    time: time + training_script_waiting_time,
802                                    event,
803                                },
804                                None => self.create_scheduled_event(event),
805                            };
806                            self.schedule_event(scheduled_event, advanceable);
807                        },
808                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
809                    }
810                    continue;
811                };
812                if training_script_state_rx.borrow().is_waiting() {
813                    let advanced_time = scheduled_time - SimClock.now();
814                    training_script_waiting_time += advanced_time;
815                }
816                SimClock.advance_to(scheduled_time);
817                for scheduled_event in scheduled_events {
818                    self.pending_event_count
819                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
820                    if scheduled_event.event.handle_network(self).await.is_err() {
821                        break 'outer self.records.clone(); //TODO
822                    }
823                }
824            }
825        }
826    }
827}
828
829fn serialize_optional_channel_addr<S>(
830    addr: &Option<ChannelAddr>,
831    serializer: S,
832) -> Result<S::Ok, S::Error>
833where
834    S: Serializer,
835{
836    match addr {
837        Some(addr) => serializer.serialize_str(&addr.to_string()),
838        None => serializer.serialize_none(),
839    }
840}
841
842fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
843where
844    D: Deserializer<'de>,
845{
846    let s = String::deserialize(deserializer)?;
847    s.parse().map_err(serde::de::Error::custom)
848}
849
/// SimulatorEventRecord is a structure to bookkeep the events processed by
/// the simulator.
///
/// The tests in this file obtain a list of these records from the simnet
/// handle's `close()` call.
///
/// NOTE(review): the tests compare `end_at` against `Duration::as_millis()`
/// values, so both timestamps look like milliseconds of simulated time —
/// confirm against the code that creates the records.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// Event dependent summary for user
    pub summary: String,
    /// The time at which the event (e.g. a message delivery) was started.
    pub start_at: u64,
    /// The time at which the event completed (e.g. the message was delivered
    /// to the receiver).
    pub end_at: u64,
}
860
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;
    use crate::channel::sim::SimAddr;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::data::Serialized;
    use crate::id;
    use crate::simnet;
    use crate::simnet::Dispatcher;
    use crate::simnet::Event;
    use crate::simnet::SimNetError;

    /// Test event that models delivering `data` from `src_addr` to `dest_addr`.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        /// Address the message is sent from (only used in the summary).
        src_addr: SimAddr,
        /// Address the message is delivered to.
        dest_addr: SimAddr,
        /// Payload handed to the dispatcher on delivery.
        data: Serialized,
        /// Duration reported to the simulator for this event.
        duration: tokio::time::Duration,
        /// Optional sink for delivered payloads; `None` makes `handle` a no-op.
        dispatcher: Option<TestDispatcher>,
    }

    #[async_trait]
    impl Event for MessageDeliveryEvent {
        /// Forwards the payload to the dispatcher, if one was configured.
        async fn handle(&self) -> Result<(), simnet::SimNetError> {
            if let Some(dispatcher) = &self.dispatcher {
                dispatcher
                    .send(self.dest_addr.clone(), self.data.clone())
                    .await?;
            }
            Ok(())
        }

        fn duration(&self) -> tokio::time::Duration {
            self.duration
        }

        fn summary(&self) -> String {
            // `format!` only borrows its arguments, so no clone is needed here.
            format!(
                "Sending message from {} to {}",
                self.src_addr.addr(),
                self.dest_addr.addr()
            )
        }
    }

    impl MessageDeliveryEvent {
        /// Creates a delivery event. Note that `dispatcher` comes before
        /// `duration` in the argument list, unlike the field order.
        fn new(
            src_addr: SimAddr,
            dest_addr: SimAddr,
            data: Serialized,
            dispatcher: Option<TestDispatcher>,
            duration: tokio::time::Duration,
        ) -> Self {
            Self {
                src_addr,
                dest_addr,
                data,
                duration,
                dispatcher,
            }
        }
    }

    /// Dispatcher that records every payload it is asked to send, keyed by
    /// target address, so tests can inspect what was delivered and in which
    /// order.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }

    impl Default for TestDispatcher {
        fn default() -> Self {
            Self {
                mbuffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl Dispatcher<SimAddr> for TestDispatcher {
        /// Appends `data` to the buffer of `target`, creating it on first use.
        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
            let mut buf = self.mbuffers.lock().await;
            buf.entry(target).or_default().push(data);
            Ok(())
        }
    }

    /// Builds a `unix!@<random>` channel address from 24 random alphanumeric
    /// characters.
    #[cfg(target_os = "linux")]
    fn random_abstract_addr() -> ChannelAddr {
        use rand::Rng;
        use rand::distributions::Alphanumeric;

        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

    /// Builds a Beta latency distribution that is given the same duration
    /// (`millis` milliseconds) for both of its duration arguments and `1.0`
    /// for both shape arguments. `test_simnet_config` relies on sampling such
    /// a distribution yielding exactly that duration.
    fn constant_latency(millis: u64) -> LatencyDistribution {
        let bound = tokio::time::Duration::from_millis(millis);
        LatencyDistribution::Beta(BetaDistribution::new(bound, bound, 1.0, 1.0).unwrap())
    }

    /// The simnet can be started and closed without scheduling any event.
    #[tokio::test]
    async fn test_handle_instantiation() {
        start();
        simnet_handle().unwrap().close().await.unwrap();
    }

    #[tokio::test]
    async fn test_simnet_config() {
        // Tests that we can create a simnet, config latency between distances and sample latencies between procs.
        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);

        let alice = id!(world[0]);
        let bob = id!(world[1]);
        let charlie = id!(world[2]);

        let config = LatencyConfig {
            inter_zone_distribution: constant_latency(1000),
            inter_dc_distribution: constant_latency(2000),
            ..Default::default()
        };
        start_with_config(config);

        let handle = simnet_handle().unwrap();
        // alice and bob differ only in the `zone` coordinate; alice and
        // charlie differ in the `dc` coordinate.
        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
        assert_eq!(
            handle.sample_latency(&alice, &bob),
            tokio::time::Duration::from_millis(1000)
        );
        assert_eq!(
            handle.sample_latency(&alice, &charlie),
            tokio::time::Duration::from_millis(2000)
        );
    }

    #[tokio::test]
    async fn test_simnet_debounce() {
        let config = LatencyConfig {
            inter_zone_distribution: constant_latency(1000),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        // Rapidly send 10 messages expecting that each one debounces the processing
        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    Serialized::serialize(&"123".to_string()).unwrap(),
                    None,
                    latency,
                )))
                .unwrap();
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await.unwrap();
        assert_eq!(records.len(), 10);

        // If debounce is successful, the simnet will not advance to the delivery of any of
        // the messages before all are received
        assert_eq!(
            records.last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }

    #[tokio::test]
    async fn test_sim_dispatch() {
        start();
        let sender = Some(TestDispatcher::default());
        // Create a simple network of 4 nodes.
        let addresses: Vec<simnet::ChannelAddr> = (0..4)
            .map(|i| {
                format!("local:{i}")
                    .parse::<simnet::ChannelAddr>()
                    .unwrap()
            })
            .collect();

        let messages: Vec<Serialized> = ["First 0 1", "First 2 3", "Second 0 1"]
            .iter()
            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
            .collect();

        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
        // Two messages go 0 -> 1 (first and third), one goes 2 -> 3.
        let one = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[0].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let two = Box::new(MessageDeliveryEvent::new(
            addr_2.clone(),
            addr_3.clone(),
            messages[1].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let three = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[2].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));

        simnet_handle().unwrap().send_event(one).unwrap();
        simnet_handle().unwrap().send_event(two).unwrap();
        simnet_handle().unwrap().send_event(three).unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        // Close the simnet once; the call yields the recorded events.
        let records = simnet_handle().unwrap().close().await.unwrap();
        eprintln!("Records: {:?}", records);

        // Check results
        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[&addr_1].len(), 2);
        assert_eq!(buf[&addr_3].len(), 1);

        // Messages to the same target must keep their send order.
        assert_eq!(buf[&addr_1][0], messages[0]);
        assert_eq!(buf[&addr_1][1], messages[2]);
        assert_eq!(buf[&addr_3][0], messages[1]);
    }

    /// Sleeping on the simulated clock advances it by exactly the slept time.
    #[tokio::test]
    async fn test_sim_sleep() {
        start();

        let start = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(start),
            tokio::time::Duration::ZERO
        );

        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;

        let end = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(end),
            tokio::time::Duration::from_secs(10)
        );
    }

    #[tokio::test]
    async fn test_torch_op() {
        start();
        let args_string = "1, 2".to_string();
        let kwargs_string = "a=2".to_string();

        let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
        let (tx, rx) = mailbox.open_once_port::<()>();

        simnet_handle()
            .unwrap()
            .send_event(TorchOpEvent::new(
                "torch.ops.aten.ones.default".to_string(),
                tx.bind(),
                mailbox,
                args_string,
                kwargs_string,
                id!(mesh_0_worker[0].worker_0),
            ))
            .unwrap();

        // Wait for the once-port bound into the event to receive its message
        // before flushing.
        rx.recv().await.unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await.unwrap();
        let expected_record = SimulatorEventRecord {
            summary:
                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
                    .to_string(),
            start_at: 0,
            end_at: 100,
        };
        assert_eq!(records.len(), 1);
        assert_eq!(records.first().unwrap(), &expected_record);
    }
}