1#![allow(dead_code)]
10
11use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::sync::oneshot;
42use tokio::task::JoinError;
43use tokio::task::JoinHandle;
44use tokio::time::interval;
45
46use crate::ActorId;
47use crate::ProcId;
48use crate::channel::ChannelAddr;
49use crate::clock::Clock;
50use crate::clock::RealClock;
51use crate::clock::SimClock;
52use crate::data::Serialized;
53
/// Process-wide handle to the simulated network.
///
/// Set at most once by `start_with_config` (through `get_or_init`) and read
/// through `simnet_handle`.
static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
55
56#[allow(clippy::result_large_err)] pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
62 match HANDLE.get() {
63 Some(handle) => Ok(handle),
64 None => Err(SimNetError::Closed("SimNet not started".to_string())),
65 }
66}
67
/// Buffer size, presumably for a channel carrying operational messages.
// NOTE(review): nothing in the visible part of this file reads this constant —
// confirm it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
69
/// Marker trait bundling the bounds required of an address type: hashable,
/// debuggable, totally ordered, comparable for equality and cloneable.
pub trait Address: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}
// Blanket implementation: every type satisfying the bounds is an `Address`.
impl<A: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone> Address for A {}
74
/// Point in time used to key and order scheduled events in the simulator.
pub type SimulatorTimeInstant = tokio::time::Instant;
77
/// A unit of simulated work processed by [`SimNet`].
///
/// Events are boxed and sent to the simulator task. The simulator schedules
/// each one and, once simulated time has been advanced to the scheduled
/// instant, calls [`Event::handle_network`] on it.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// Executes the event without access to the simulator.
    async fn handle(&self) -> Result<(), SimNetError>;

    /// Executes the event with access to the running [`SimNet`].
    ///
    /// This is the method the simulator loop invokes. The default
    /// implementation ignores the simulator and delegates to
    /// [`Event::handle`].
    async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// Simulated time the event takes. For events submitted without an
    /// explicit time, the simulator schedules them at `SimClock.now()` plus
    /// this duration.
    fn duration(&self) -> tokio::time::Duration;

    /// Human-readable description stored in the event's
    /// [`SimulatorEventRecord`].
    fn summary(&self) -> String;
}
110
/// Event enqueued by `SimNetHandle::bind` when an address joins the simulation.
#[derive(Debug)]
struct NodeJoinEvent {
    /// Address passed to `bind`.
    channel_addr: ChannelAddr,
}
117
118#[async_trait]
119impl Event for NodeJoinEvent {
120 async fn handle(&self) -> Result<(), SimNetError> {
121 Ok(())
122 }
123
124 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
125 self.handle().await
126 }
127
128 fn duration(&self) -> tokio::time::Duration {
129 tokio::time::Duration::ZERO
130 }
131
132 fn summary(&self) -> String {
133 "Node join".into()
134 }
135}
136
/// Event representing a torch operation executed by a worker.
#[derive(Debug)]
pub struct TorchOpEvent {
    /// Name of the operation, used in the summary.
    op: String,
    /// One-shot sender fired when the simulator handles the event. Wrapped in
    /// `Option` so it can be taken exactly once through `&self`.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Pre-rendered positional arguments, used in the summary.
    args_string: String,
    /// Pre-rendered keyword arguments, used in the summary; may be empty.
    kwargs_string: String,
    /// Actor the summary attributes the operation to.
    worker_actor_id: ActorId,
}
146
147#[async_trait]
148impl Event for TorchOpEvent {
149 async fn handle(&self) -> Result<(), SimNetError> {
150 Ok(())
151 }
152
153 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
154 let mut guard = self.done_tx.lock().await;
155 match guard.take() {
156 Some(done_tx) => {
157 let _ = done_tx
158 .send(())
159 .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
160 }
161 None => {
162 return Err(SimNetError::Closed("already sent once".to_string()));
163 }
164 }
165 Ok(())
166 }
167
168 fn duration(&self) -> tokio::time::Duration {
169 tokio::time::Duration::from_millis(100)
170 }
171
172 fn summary(&self) -> String {
173 let kwargs_string = if self.kwargs_string.is_empty() {
174 "".to_string()
175 } else {
176 format!(", {}", self.kwargs_string)
177 };
178 format!(
179 "[{}] Torch Op: {}({}{})",
180 self.worker_actor_id, self.op, self.args_string, kwargs_string
181 )
182 }
183}
184
185impl TorchOpEvent {
186 pub fn new(
188 op: String,
189 done_tx: oneshot::Sender<()>,
190 args_string: String,
191 kwargs_string: String,
192 worker_actor_id: ActorId,
193 ) -> Box<Self> {
194 Box::new(Self {
195 op,
196 done_tx: Arc::new(Mutex::new(Some(done_tx))),
197 args_string,
198 kwargs_string,
199 worker_actor_id,
200 })
201 }
202}
203
/// Event that completes after a given amount of simulated time.
#[derive(Debug)]
pub(crate) struct SleepEvent {
    /// One-shot sender fired when the simulator handles the event. Wrapped in
    /// `Option` so it can be taken exactly once through `&self`.
    done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Length of the sleep; reported as the event's duration.
    duration: tokio::time::Duration,
}
209
210impl SleepEvent {
211 pub(crate) fn new(done_tx: oneshot::Sender<()>, duration: tokio::time::Duration) -> Box<Self> {
212 Box::new(Self {
213 done_tx: Arc::new(Mutex::new(Some(done_tx))),
214 duration,
215 })
216 }
217}
218
219#[async_trait]
220impl Event for SleepEvent {
221 async fn handle(&self) -> Result<(), SimNetError> {
222 Ok(())
223 }
224
225 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
226 let mut guard = self.done_tx.lock().await;
227 match guard.take() {
228 Some(done_tx) => {
229 let _ = done_tx
230 .send(())
231 .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
232 }
233 None => {
234 return Err(SimNetError::Closed("already sent once".to_string()));
235 }
236 }
237 Ok(())
238 }
239
240 fn duration(&self) -> tokio::time::Duration {
241 self.duration
242 }
243
244 fn summary(&self) -> String {
245 format!("Sleeping for {} ms", self.duration.as_millis())
246 }
247}
248
/// An event paired with the simulated instant at which it is scheduled.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// Instant the simulator clock is advanced to before the event is handled.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to handle.
    pub(crate) event: Box<dyn Event>,
}
258
/// Delivers serialized payloads to targets identified by an address of type `A`.
#[async_trait]
pub trait Dispatcher<A> {
    /// Sends `data` to `target`.
    async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
}
268
/// Errors produced by the simulated network.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An address was rejected; the payload is the message to display.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// A node was rejected; carries a message and the underlying cause.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An argument was rejected; the payload is the message to display.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// A channel or the simulator itself is unavailable. In this file it is
    /// used for a handle that was never started, a failed send to the
    /// simulator task, and a completion sender that was already consumed.
    #[error("closed: {0}")]
    Closed(String),

    /// An operation did not finish within the given duration.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// No destination address was supplied.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// The simulator has not been started.
    // NOTE(review): `simnet_handle` reports the not-started case as `Closed`
    // rather than with this variant.
    #[error("simnet not started")]
    NotStarted,
}
302
/// Pending events held by the simulator, grouped by scheduled instant.
struct State {
    /// Events scheduled with `advanceable == true`.
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    /// Events scheduled with `advanceable == false`. When no advanceable
    /// events are queued, the run loop only pops these after a debounce
    /// period has elapsed.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
311
/// State of the training script, published to the simulator through a watch
/// channel (see `SimNetHandle::set_training_script_state`).
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// While in this state the run loop does not advance to an advanceable
    /// event that lies later than the real clock plus the accumulated
    /// waiting time.
    Running,
    /// While in this state every advance of the simulated clock is added to
    /// the accumulated waiting time. This is the initial state.
    Waiting,
}
320
/// Probability distribution from which latencies are sampled.
pub enum LatencyDistribution {
    /// Beta distribution scaled to a `[min, max]` duration range.
    Beta(BetaDistribution),
}
326
327impl LatencyDistribution {
328 fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
329 match &self {
330 LatencyDistribution::Beta(sampler) => sampler.sample(rng),
331 }
332 }
333}
334
/// Beta distribution whose samples are mapped onto
/// `[min_duration, max_duration]`.
pub struct BetaDistribution {
    /// Lower bound of the sampled duration.
    min_duration: tokio::time::Duration,
    /// Upper bound of the sampled duration.
    max_duration: tokio::time::Duration,
    /// Underlying Beta distribution.
    dist: rand_distr::Beta<f64>,
}
341
342impl BetaDistribution {
343 pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
345 let sample = self.dist.sample(rng);
346
347 self.min_duration
348 + tokio::time::Duration::from_micros(
349 (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
350 )
351 }
352
353 pub fn new(
355 min_duration: tokio::time::Duration,
356 max_duration: tokio::time::Duration,
357 alpha: f64,
358 beta: f64,
359 ) -> anyhow::Result<Self> {
360 if min_duration > max_duration {
361 return Err(anyhow::anyhow!(
362 "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
363 min_duration,
364 max_duration
365 ));
366 }
367 Ok(Self {
368 min_duration,
369 max_duration,
370 dist: rand_distr::Beta::new(alpha, beta)?,
371 })
372 }
373}
/// Latency model of the simulated network: one distribution per distance
/// class plus the random number generator used to sample them.
pub struct LatencyConfig {
    /// Latency between procs whose coordinates first differ at region level.
    pub inter_region_distribution: LatencyDistribution,
    /// Latency between procs whose coordinates first differ at data-center level.
    pub inter_dc_distribution: LatencyDistribution,
    /// Latency between procs whose coordinates first differ at zone level.
    pub inter_zone_distribution: LatencyDistribution,
    /// Random number generator shared by all distributions.
    pub rng: StdRng,
}
385
386impl LatencyConfig {
387 fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
388 match distance {
389 Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
390 Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
391 Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
392 Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
393 }
394 }
395}
396
397impl Default for LatencyConfig {
398 fn default() -> Self {
399 let seed: u64 = 0000;
400 let mut seed_bytes = [0u8; 32];
401 seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
402
403 Self {
404 inter_region_distribution: LatencyDistribution::Beta(
405 BetaDistribution::new(
406 tokio::time::Duration::from_millis(500),
407 tokio::time::Duration::from_millis(1000),
408 2.0,
409 1.0,
410 )
411 .unwrap(),
412 ),
413 inter_dc_distribution: LatencyDistribution::Beta(
414 BetaDistribution::new(
415 tokio::time::Duration::from_millis(50),
416 tokio::time::Duration::from_millis(100),
417 2.0,
418 1.0,
419 )
420 .unwrap(),
421 ),
422 inter_zone_distribution: LatencyDistribution::Beta(
423 BetaDistribution::new(
424 tokio::time::Duration::from_millis(5),
425 tokio::time::Duration::from_millis(10),
426 2.0,
427 1.0,
428 )
429 .unwrap(),
430 ),
431 rng: StdRng::from_seed(seed_bytes),
432 }
433 }
434}
435
/// Client-side handle used to talk to the simulator task.
pub struct SimNetHandle {
    /// Join handle of the simulator task; taken by the first `close` call.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Channel to the simulator task. Each message carries the event, whether
    /// it is advanceable, and an optional explicit scheduled time.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Number of events sent but not yet handled; shared with the simulator
    /// task, which decrements it as events are handled.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the training script state to the simulator task.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Set by `close` to ask the simulator task to stop.
    stop_signal: Arc<AtomicBool>,
    /// Location of each registered proc, used to derive distances.
    resources: DashMap<ProcId, Point>,
    /// Latency model sampled by `sample_latency`.
    latencies: std::sync::Mutex<LatencyConfig>,
}
449
450impl SimNetHandle {
451 #[allow(clippy::result_large_err)] pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
454 self.send_event_impl(event, true)
455 }
456
457 #[allow(clippy::result_large_err)] fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
459 self.pending_event_count
460 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
461 self.event_tx
462 .send((event, advanceable, None))
463 .map_err(|err| SimNetError::Closed(err.to_string()))
464 }
465
466 #[allow(clippy::result_large_err)] pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
473 self.send_event_impl(event, false)
474 }
475
476 #[allow(clippy::result_large_err)] pub(crate) fn send_scheduled_event(
479 &self,
480 ScheduledEvent { event, time }: ScheduledEvent,
481 ) -> Result<(), SimNetError> {
482 self.pending_event_count
483 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
484 self.event_tx
485 .send((event, true, Some(time)))
486 .map_err(|err| SimNetError::Closed(err.to_string()))
487 }
488
489 pub fn set_training_script_state(&self, state: TrainingScriptState) {
492 self.training_script_state_tx.send(state).unwrap();
493 }
494
495 #[allow(clippy::result_large_err)] pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
498 self.pending_event_count
499 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
500 self.event_tx
501 .send((
502 Box::new(NodeJoinEvent {
503 channel_addr: address,
504 }),
505 true,
506 None,
507 ))
508 .map_err(|err| SimNetError::Closed(err.to_string()))
509 }
510
511 pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
514 self.stop_signal.store(true, Ordering::SeqCst);
516
517 let mut guard = self.join_handle.lock().await;
518 if let Some(handle) = guard.take() {
519 handle.await
520 } else {
521 Ok(vec![])
522 }
523 }
524
525 pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
528 let pending_event_count = self.pending_event_count.clone();
529 let mut interval = interval(Duration::from_millis(10));
531 let deadline = RealClock.now() + timeout;
532 while RealClock.now() < deadline {
533 interval.tick().await;
534 if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
535 return Ok(());
536 }
537 }
538 Err(SimNetError::Timeout(
539 timeout,
540 "timeout waiting for received events to be scheduled".to_string(),
541 ))
542 }
543
544 pub fn register_proc(&self, proc_id: ProcId, point: Point) {
546 self.resources.insert(proc_id, point);
547 }
548
549 pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
551 let distances = [
552 Distance::Region,
553 Distance::DataCenter,
554 Distance::Zone,
555 Distance::Rack,
556 Distance::Host,
557 Distance::Same,
558 ];
559
560 let src_coords = self
561 .resources
562 .get(src)
563 .map(|point| point.coords().clone())
564 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
565
566 let dest_coords = self
567 .resources
568 .get(dest)
569 .map(|point| point.coords().clone())
570 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
571
572 for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
573 if src != dest {
574 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
575 return guard.from_distance(&distance);
576 }
577 }
578
579 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
580 guard.from_distance(&Distance::Same)
581 }
582}
583
/// Distance class between two procs, ordered from farthest to nearest.
/// `SimNetHandle::sample_latency` pairs these, in this order, with successive
/// coordinates of a proc's location.
#[derive(Debug)]
enum Distance {
    /// The procs differ at the first coordinate.
    Region,
    /// The procs first differ at the second coordinate.
    DataCenter,
    /// The procs first differ at the third coordinate.
    Zone,
    /// The procs first differ at the fourth coordinate; sampled latency is zero.
    Rack,
    /// The procs first differ at the fifth coordinate; sampled latency is zero.
    Host,
    /// The procs first differ at the sixth coordinate, or not at all; sampled
    /// latency is zero.
    Same,
}
593
/// The simulator itself: owns the event queues and runs the scheduling loop
/// (see `SimNet::run`) on a dedicated task.
pub struct SimNet {
    /// Set of known addresses.
    // NOTE(review): created empty in `start_with_config`; no visible code
    // inserts into or reads from it.
    address_book: DashSet<ChannelAddr>,
    /// Queues of scheduled events.
    state: State,
    /// Maximum latency, initialised to 10 seconds.
    // NOTE(review): not read by any code visible in this file.
    max_latency: Duration,
    /// One record per scheduled event, returned when the run loop exits.
    records: Vec<SimulatorEventRecord>,
    /// Number of events sent but not yet handled; shared with `SimNetHandle`.
    pending_event_count: Arc<AtomicUsize>,
}
607
608pub fn start() {
610 start_with_config(LatencyConfig::default())
611}
612
613pub fn start_with_config(config: LatencyConfig) {
615 let max_duration_ms = 1000 * 10;
616 let address_book: DashSet<ChannelAddr> = DashSet::new();
618
619 let (training_script_state_tx, training_script_state_rx) =
620 tokio::sync::watch::channel(TrainingScriptState::Waiting);
621 let (event_tx, event_rx) =
622 mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
623 let pending_event_count = Arc::new(AtomicUsize::new(0));
624 let stop_signal = Arc::new(AtomicBool::new(false));
625
626 let join_handle = Mutex::new(Some({
627 let pending_event_count = pending_event_count.clone();
628 let stop_signal = stop_signal.clone();
629
630 tokio::spawn(async move {
631 SimNet {
632 address_book,
633 state: State {
634 scheduled_events: BTreeMap::new(),
635 unadvanceable_scheduled_events: BTreeMap::new(),
636 },
637 max_latency: Duration::from_millis(max_duration_ms),
638 records: Vec::new(),
639 pending_event_count,
640 }
641 .run(event_rx, training_script_state_rx, stop_signal)
642 .await
643 })
644 }));
645
646 HANDLE.get_or_init(|| SimNetHandle {
647 join_handle,
648 event_tx,
649 pending_event_count,
650 training_script_state_tx,
651 stop_signal,
652 resources: DashMap::new(),
653 latencies: std::sync::Mutex::new(config),
654 });
655}
656
657impl SimNet {
658 fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
659 ScheduledEvent {
661 time: SimClock.now() + event.duration(),
662 event,
663 }
664 }
665
666 fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
668 let start_at = SimClock.now();
669 let end_at = scheduled_event.time;
670
671 self.records.push(SimulatorEventRecord {
672 summary: scheduled_event.event.summary(),
673 start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
674 end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
675 });
676
677 if advanceable {
678 self.state
679 .scheduled_events
680 .entry(scheduled_event.time)
681 .or_default()
682 .push(scheduled_event);
683 } else {
684 self.state
685 .unadvanceable_scheduled_events
686 .entry(scheduled_event.time)
687 .or_default()
688 .push(scheduled_event);
689 }
690 }
691
692 async fn run(
695 &mut self,
696 mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
697 training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
698 stop_signal: Arc<AtomicBool>,
699 ) -> Vec<SimulatorEventRecord> {
700 let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
703 let mut debounce_timer: Option<tokio::time::Instant> = None;
705
706 let debounce_duration = std::env::var("SIM_DEBOUNCE")
707 .ok()
708 .and_then(|val| val.parse::<u64>().ok())
709 .unwrap_or(1);
710
711 'outer: loop {
712 if stop_signal.load(Ordering::SeqCst) {
714 break 'outer self.records.clone();
715 }
716
717 while let Ok(Some((event, advanceable, time))) = RealClock
718 .timeout(
719 tokio::time::Duration::from_millis(debounce_duration),
720 event_rx.recv(),
721 )
722 .await
723 {
724 let scheduled_event = match time {
725 Some(time) => ScheduledEvent {
726 time: time + training_script_waiting_time,
727 event,
728 },
729 None => self.create_scheduled_event(event),
730 };
731 self.schedule_event(scheduled_event, advanceable);
732 }
733
734 {
735 if training_script_state_rx.borrow().is_running()
740 && self
741 .state
742 .scheduled_events
743 .first_key_value()
744 .is_some_and(|(time, _)| {
745 *time > RealClock.now() + training_script_waiting_time
746 })
747 {
748 tokio::task::yield_now().await;
749 continue;
750 }
751 match (
752 self.state.scheduled_events.first_key_value(),
753 self.state.unadvanceable_scheduled_events.first_key_value(),
754 ) {
755 (None, Some(_)) if debounce_timer.is_none() => {
756 debounce_timer = Some(RealClock.now());
759 }
760 (None, Some(_)) => {}
762 _ => {
764 debounce_timer = None;
765 }
766 }
767 let Some((scheduled_time, scheduled_events)) = (match (
769 self.state.scheduled_events.first_key_value(),
770 self.state.unadvanceable_scheduled_events.first_key_value(),
771 ) {
772 (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
773 if unadvanceable_time < advanceable_time {
774 self.state.unadvanceable_scheduled_events.pop_first()
775 } else {
776 self.state.scheduled_events.pop_first()
777 }
778 }
779 (Some(_), None) => self.state.scheduled_events.pop_first(),
780 (None, Some(_)) => match debounce_timer {
781 Some(time) => {
782 if time.elapsed() > tokio::time::Duration::from_millis(1000) {
783 debounce_timer = None;
785 self.state.unadvanceable_scheduled_events.pop_first()
786 } else {
787 None
788 }
789 }
790 None => None,
791 },
792 (None, None) => None,
793 }) else {
794 tokio::select! {
795 Some((event, advanceable, time)) = event_rx.recv() => {
796 let scheduled_event = match time {
797 Some(time) => ScheduledEvent {
798 time: time + training_script_waiting_time,
799 event,
800 },
801 None => self.create_scheduled_event(event),
802 };
803 self.schedule_event(scheduled_event, advanceable);
804 },
805 _ = RealClock.sleep(Duration::from_millis(10)) => {}
806 }
807 continue;
808 };
809 if training_script_state_rx.borrow().is_waiting() {
810 let advanced_time = scheduled_time - SimClock.now();
811 training_script_waiting_time += advanced_time;
812 }
813 SimClock.advance_to(scheduled_time);
814 for scheduled_event in scheduled_events {
815 self.pending_event_count
816 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
817 if scheduled_event.event.handle_network(self).await.is_err() {
818 break 'outer self.records.clone(); }
820 }
821 }
822 }
823 }
824}
825
826fn serialize_optional_channel_addr<S>(
827 addr: &Option<ChannelAddr>,
828 serializer: S,
829) -> Result<S::Ok, S::Error>
830where
831 S: Serializer,
832{
833 match addr {
834 Some(addr) => serializer.serialize_str(&addr.to_string()),
835 None => serializer.serialize_none(),
836 }
837}
838
839fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
840where
841 D: Deserializer<'de>,
842{
843 let s = String::deserialize(deserializer)?;
844 s.parse().map_err(serde::de::Error::custom)
845}
846
/// Record of one event scheduled by the simulator.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// The event's summary text.
    pub summary: String,
    /// Simulated time at which the event was scheduled, in milliseconds since
    /// the simulated clock's start.
    pub start_at: u64,
    /// Simulated time the event was scheduled for, in milliseconds since the
    /// simulated clock's start.
    pub end_at: u64,
}
857
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;
    use crate::channel::sim::SimAddr;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::data::Serialized;
    use crate::id;
    use crate::simnet;
    use crate::simnet::Dispatcher;
    use crate::simnet::Event;
    use crate::simnet::SimNetError;

    /// Test event that delivers `data` to `dest_addr` through an optional
    /// dispatcher after `duration` of simulated time.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        src_addr: SimAddr,
        dest_addr: SimAddr,
        data: Serialized,
        duration: tokio::time::Duration,
        dispatcher: Option<TestDispatcher>,
    }

    #[async_trait]
    impl Event for MessageDeliveryEvent {
        /// Forwards the payload to the dispatcher, if one was supplied.
        async fn handle(&self) -> Result<(), simnet::SimNetError> {
            if let Some(dispatcher) = &self.dispatcher {
                dispatcher
                    .send(self.dest_addr.clone(), self.data.clone())
                    .await?;
            }
            Ok(())
        }
        fn duration(&self) -> tokio::time::Duration {
            self.duration
        }

        fn summary(&self) -> String {
            format!(
                "Sending message from {} to {}",
                self.src_addr.addr().clone(),
                self.dest_addr.addr().clone()
            )
        }
    }

    impl MessageDeliveryEvent {
        /// Note the argument order: `dispatcher` comes before `duration`.
        fn new(
            src_addr: SimAddr,
            dest_addr: SimAddr,
            data: Serialized,
            dispatcher: Option<TestDispatcher>,
            duration: tokio::time::Duration,
        ) -> Self {
            Self {
                src_addr,
                dest_addr,
                data,
                duration,
                dispatcher,
            }
        }
    }

    /// Dispatcher that appends every message to a per-destination buffer so
    /// tests can inspect what was delivered and in which order.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }

    impl Default for TestDispatcher {
        fn default() -> Self {
            Self {
                mbuffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl Dispatcher<SimAddr> for TestDispatcher {
        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
            let mut buf = self.mbuffers.lock().await;
            buf.entry(target).or_default().push(data);
            Ok(())
        }
    }

    /// Builds a `unix!@<24 random alphanumeric characters>` channel address.
    // NOTE(review): no test in this module calls this helper.
    #[cfg(target_os = "linux")]
    fn random_abstract_addr() -> ChannelAddr {
        use rand::Rng;
        use rand::distributions::Alphanumeric;

        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

    /// The simulator can be started and closed without sending any event.
    #[tokio::test]
    async fn test_handle_instantiation() {
        start();
        simnet_handle().unwrap().close().await.unwrap();
    }

    /// Latency is picked by the first coordinate at which two procs differ.
    #[tokio::test]
    async fn test_simnet_config() {
        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);

        let alice = id!(world[0]);
        let bob = id!(world[1]);
        let charlie = id!(world[2]);

        // min == max makes each sampled latency deterministic.
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            inter_dc_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(2000),
                    tokio::time::Duration::from_millis(2000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);

        let handle = simnet_handle().unwrap();
        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
        // alice and bob first differ at the zone coordinate.
        assert_eq!(
            handle.sample_latency(&alice, &bob),
            tokio::time::Duration::from_millis(1000)
        );
        // alice and charlie first differ at the data-center coordinate.
        assert_eq!(
            handle.sample_latency(&alice, &charlie),
            tokio::time::Duration::from_millis(2000)
        );
    }

    /// Ten events sent 500 µs apart, each lasting 10 s, yield ten records and
    /// the last one ends at 10 000 ms.
    #[tokio::test]
    async fn test_simnet_debounce() {
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    Serialized::serialize(&"123".to_string()).unwrap(),
                    None,
                    latency,
                )))
                .unwrap();
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await;
        assert_eq!(records.as_ref().unwrap().len(), 10);

        assert_eq!(
            records.unwrap().last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }

    /// Messages reach the right destination buffers, in send order.
    #[tokio::test]
    async fn test_sim_dispatch() {
        start();
        let sender = Some(TestDispatcher::default());
        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
        for i in 0..4 {
            addresses.push(
                format!("local:{}", i)
                    .parse::<simnet::ChannelAddr>()
                    .unwrap(),
            );
        }

        let messages: Vec<Serialized> = vec!["First 0 1", "First 2 3", "Second 0 1"]
            .into_iter()
            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
            .collect();

        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
        let one = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[0].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let two = Box::new(MessageDeliveryEvent::new(
            addr_2.clone(),
            addr_3.clone(),
            messages[1].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let three = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[2].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));

        simnet_handle().unwrap().send_event(one).unwrap();
        simnet_handle().unwrap().send_event(two).unwrap();
        simnet_handle().unwrap().send_event(three).unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await.unwrap();
        eprintln!("Records: {:?}", records);
        // NOTE(review): second `close`; the join handle was already taken by
        // the call above, so this returns an empty vector.
        simnet_handle().unwrap().close().await.unwrap();

        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
        // Two destinations received messages: addr_1 (two) and addr_3 (one).
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[&addr_1].len(), 2);
        assert_eq!(buf[&addr_3].len(), 1);

        assert_eq!(buf[&addr_1][0], messages[0]);
        assert_eq!(buf[&addr_1][1], messages[2]);
        assert_eq!(buf[&addr_3][0], messages[1]);
    }

    /// Sleeping on the simulated clock advances it by exactly the sleep length.
    #[tokio::test]
    async fn test_sim_sleep() {
        start();

        // Shadows the `start` function from here on.
        let start = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(start),
            tokio::time::Duration::ZERO
        );

        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;

        let end = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(end),
            tokio::time::Duration::from_secs(10)
        );
    }

    /// A torch-op event fires its completion channel and is recorded with the
    /// expected summary and a 0–100 ms span.
    #[tokio::test]
    async fn test_torch_op() {
        start();
        let args_string = "1, 2".to_string();
        let kwargs_string = "a=2".to_string();

        let (tx, rx) = oneshot::channel();

        simnet_handle()
            .unwrap()
            .send_event(TorchOpEvent::new(
                "torch.ops.aten.ones.default".to_string(),
                tx,
                args_string,
                kwargs_string,
                id!(mesh_0_worker[0].worker_0),
            ))
            .unwrap();

        // Resolves once the simulator has handled the event.
        rx.await.unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await;
        let expected_record = SimulatorEventRecord {
            summary:
                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
                    .to_string(),
            start_at: 0,
            end_at: 100,
        };
        assert!(records.as_ref().unwrap().len() == 1);
        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
    }
}