1#![allow(dead_code)]
10
11use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::sync::oneshot;
42use tokio::task::JoinError;
43use tokio::task::JoinHandle;
44use tokio::time::interval;
45
46use crate::ActorId;
47use crate::ProcId;
48use crate::channel::ChannelAddr;
49use crate::clock::Clock;
50use crate::clock::RealClock;
51use crate::clock::SimClock;
52
53static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
54
55#[allow(clippy::result_large_err)] pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
61 match HANDLE.get() {
62 Some(handle) => Ok(handle),
63 None => Err(SimNetError::Closed("SimNet not started".to_string())),
64 }
65}
66
/// Buffer-size constant for operational messages; not referenced by the code shown here.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;

/// Marker trait for types usable as simulator addresses.
///
/// `Ord` already implies `Eq`, `PartialEq` and `PartialOrd`, so the bound set is
/// hashable + debuggable + totally ordered + cloneable.
pub trait Address: Hash + Debug + Ord + Clone {}

/// Every type that satisfies the bounds is automatically an [`Address`].
impl<T> Address for T where T: Hash + Debug + Ord + Clone {}
73
74pub type SimulatorTimeInstant = tokio::time::Instant;
76
77#[async_trait]
86pub trait Event: Send + Sync + Debug {
87 async fn handle(&self) -> Result<(), SimNetError>;
94
95 async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
99 self.handle().await
100 }
101
102 fn duration(&self) -> tokio::time::Duration;
105
106 fn summary(&self) -> String;
108}
109
110#[derive(Debug)]
113struct NodeJoinEvent {
114 channel_addr: ChannelAddr,
115}
116
117#[async_trait]
118impl Event for NodeJoinEvent {
119 async fn handle(&self) -> Result<(), SimNetError> {
120 Ok(())
121 }
122
123 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
124 self.handle().await
125 }
126
127 fn duration(&self) -> tokio::time::Duration {
128 tokio::time::Duration::ZERO
129 }
130
131 fn summary(&self) -> String {
132 "Node join".into()
133 }
134}
135
136#[derive(Debug)]
137pub struct TorchOpEvent {
139 op: String,
140 done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
141 args_string: String,
142 kwargs_string: String,
143 worker_actor_id: ActorId,
144}
145
146#[async_trait]
147impl Event for TorchOpEvent {
148 async fn handle(&self) -> Result<(), SimNetError> {
149 Ok(())
150 }
151
152 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
153 let mut guard = self.done_tx.lock().await;
154 match guard.take() {
155 Some(done_tx) => {
156 let _ = done_tx
157 .send(())
158 .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
159 }
160 None => {
161 return Err(SimNetError::Closed("already sent once".to_string()));
162 }
163 }
164 Ok(())
165 }
166
167 fn duration(&self) -> tokio::time::Duration {
168 tokio::time::Duration::from_millis(100)
169 }
170
171 fn summary(&self) -> String {
172 let kwargs_string = if self.kwargs_string.is_empty() {
173 "".to_string()
174 } else {
175 format!(", {}", self.kwargs_string)
176 };
177 format!(
178 "[{}] Torch Op: {}({}{})",
179 self.worker_actor_id, self.op, self.args_string, kwargs_string
180 )
181 }
182}
183
184impl TorchOpEvent {
185 pub fn new(
187 op: String,
188 done_tx: oneshot::Sender<()>,
189 args_string: String,
190 kwargs_string: String,
191 worker_actor_id: ActorId,
192 ) -> Box<Self> {
193 Box::new(Self {
194 op,
195 done_tx: Arc::new(Mutex::new(Some(done_tx))),
196 args_string,
197 kwargs_string,
198 worker_actor_id,
199 })
200 }
201}
202
203#[derive(Debug)]
204pub(crate) struct SleepEvent {
205 done_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
206 duration: tokio::time::Duration,
207}
208
209impl SleepEvent {
210 pub(crate) fn new(done_tx: oneshot::Sender<()>, duration: tokio::time::Duration) -> Box<Self> {
211 Box::new(Self {
212 done_tx: Arc::new(Mutex::new(Some(done_tx))),
213 duration,
214 })
215 }
216}
217
218#[async_trait]
219impl Event for SleepEvent {
220 async fn handle(&self) -> Result<(), SimNetError> {
221 Ok(())
222 }
223
224 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
225 let mut guard = self.done_tx.lock().await;
226 match guard.take() {
227 Some(done_tx) => {
228 let _ = done_tx
229 .send(())
230 .map_err(|_| SimNetError::Closed("done channel is closed".to_string()));
231 }
232 None => {
233 return Err(SimNetError::Closed("already sent once".to_string()));
234 }
235 }
236 Ok(())
237 }
238
239 fn duration(&self) -> tokio::time::Duration {
240 self.duration
241 }
242
243 fn summary(&self) -> String {
244 format!("Sleeping for {} ms", self.duration.as_millis())
245 }
246}
247
/// An [`Event`] paired with the simulated instant at which the event loop should
/// handle it.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// Simulated time at which the event fires; also its key in the scheduler's queues.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to handle.
    pub(crate) event: Box<dyn Event>,
}
257
/// Delivers serialized payloads to targets addressed by `A`.
#[async_trait]
pub trait Dispatcher<A> {
    /// Sends `data` to `target`.
    async fn send(&self, target: A, data: wirevalue::Any) -> Result<(), SimNetError>;
}
267
/// Errors produced by the network simulator.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An address could not be used; carries a description.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// A node is invalid; carries a description and the underlying cause.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An argument was rejected; carries a description.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// A channel or the simulator itself is closed / unavailable.
    #[error("closed: {0}")]
    Closed(String),

    /// An operation did not complete within the given duration; carries a description.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// No destination address was provided.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// The simulator has not been started.
    #[error("simnet not started")]
    NotStarted,
}
301
/// Pending events of the simulator, bucketed by the simulated time they fire at.
struct State {
    /// Events sent as advanceable; the event loop advances the simulated clock to them.
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    /// Events sent as non-advanceable. The event loop dispatches them only when
    /// they are earlier than the next advanceable event, or after a real-time
    /// debounce with no advanceable events pending (see `SimNet::run`).
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
310
/// State of the training script driving the simulation, published through
/// [`SimNetHandle::set_training_script_state`].
///
/// `SimNet::run` reads it as follows.
/// - While `Running`, advanceable events due later than real "now" (plus the
///   accumulated waiting offset) are held back.
/// - While `Waiting`, the simulated time the loop skips over is accumulated into
///   that offset.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// The training script is executing.
    Running,
    /// The training script is not executing (presumably blocked on the simulator;
    /// confirm with callers). This is the initial state set by `start_with_config`.
    Waiting,
}
319
320pub enum LatencyDistribution {
322 Beta(BetaDistribution),
324}
325
326impl LatencyDistribution {
327 fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
328 match &self {
329 LatencyDistribution::Beta(sampler) => sampler.sample(rng),
330 }
331 }
332}
333
334pub struct BetaDistribution {
336 min_duration: tokio::time::Duration,
337 max_duration: tokio::time::Duration,
338 dist: rand_distr::Beta<f64>,
339}
340
341impl BetaDistribution {
342 pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
344 let sample = self.dist.sample(rng);
345
346 self.min_duration
347 + tokio::time::Duration::from_micros(
348 (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
349 )
350 }
351
352 pub fn new(
354 min_duration: tokio::time::Duration,
355 max_duration: tokio::time::Duration,
356 alpha: f64,
357 beta: f64,
358 ) -> anyhow::Result<Self> {
359 if min_duration > max_duration {
360 return Err(anyhow::anyhow!(
361 "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
362 min_duration,
363 max_duration
364 ));
365 }
366 Ok(Self {
367 min_duration,
368 max_duration,
369 dist: rand_distr::Beta::new(alpha, beta)?,
370 })
371 }
372}
373pub struct LatencyConfig {
375 pub inter_region_distribution: LatencyDistribution,
377 pub inter_dc_distribution: LatencyDistribution,
379 pub inter_zone_distribution: LatencyDistribution,
381 pub rng: StdRng,
383}
384
385impl LatencyConfig {
386 fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
387 match distance {
388 Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
389 Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
390 Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
391 Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
392 }
393 }
394}
395
396impl Default for LatencyConfig {
397 fn default() -> Self {
398 let seed: u64 = 0000;
399 let mut seed_bytes = [0u8; 32];
400 seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
401
402 Self {
403 inter_region_distribution: LatencyDistribution::Beta(
404 BetaDistribution::new(
405 tokio::time::Duration::from_millis(500),
406 tokio::time::Duration::from_millis(1000),
407 2.0,
408 1.0,
409 )
410 .unwrap(),
411 ),
412 inter_dc_distribution: LatencyDistribution::Beta(
413 BetaDistribution::new(
414 tokio::time::Duration::from_millis(50),
415 tokio::time::Duration::from_millis(100),
416 2.0,
417 1.0,
418 )
419 .unwrap(),
420 ),
421 inter_zone_distribution: LatencyDistribution::Beta(
422 BetaDistribution::new(
423 tokio::time::Duration::from_millis(5),
424 tokio::time::Duration::from_millis(10),
425 2.0,
426 1.0,
427 )
428 .unwrap(),
429 ),
430 rng: StdRng::from_seed(seed_bytes),
431 }
432 }
433}
434
/// Handle to the running simulator, installed globally by `start_with_config` and
/// obtained with [`simnet_handle`].
pub struct SimNetHandle {
    /// Event-loop task; taken and awaited by [`SimNetHandle::close`], which yields
    /// the collected records.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Sends `(event, advanceable, explicit time)` to the event loop.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Incremented when an event is sent and decremented by the event loop as it
    /// dispatches each event; [`SimNetHandle::flush`] waits for it to reach zero.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the training script's state to the event loop.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Set by [`SimNetHandle::close`]; the event loop checks it on every iteration.
    stop_signal: Arc<AtomicBool>,
    /// Topology coordinates per proc, used by [`SimNetHandle::sample_latency`].
    resources: DashMap<ProcId, Point>,
    /// Latency model; locked synchronously inside `sample_latency`.
    latencies: std::sync::Mutex<LatencyConfig>,
}
448
449impl SimNetHandle {
450 #[allow(clippy::result_large_err)] pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
453 self.send_event_impl(event, true)
454 }
455
456 #[allow(clippy::result_large_err)] fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
458 self.pending_event_count
459 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
460 self.event_tx
461 .send((event, advanceable, None))
462 .map_err(|err| SimNetError::Closed(err.to_string()))
463 }
464
465 #[allow(clippy::result_large_err)] pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
472 self.send_event_impl(event, false)
473 }
474
475 #[allow(clippy::result_large_err)] pub(crate) fn send_scheduled_event(
478 &self,
479 ScheduledEvent { event, time }: ScheduledEvent,
480 ) -> Result<(), SimNetError> {
481 self.pending_event_count
482 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
483 self.event_tx
484 .send((event, true, Some(time)))
485 .map_err(|err| SimNetError::Closed(err.to_string()))
486 }
487
488 pub fn set_training_script_state(&self, state: TrainingScriptState) {
491 self.training_script_state_tx.send(state).unwrap();
492 }
493
494 #[allow(clippy::result_large_err)] pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
497 self.pending_event_count
498 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
499 self.event_tx
500 .send((
501 Box::new(NodeJoinEvent {
502 channel_addr: address,
503 }),
504 true,
505 None,
506 ))
507 .map_err(|err| SimNetError::Closed(err.to_string()))
508 }
509
510 pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
513 self.stop_signal.store(true, Ordering::SeqCst);
515
516 let mut guard = self.join_handle.lock().await;
517 if let Some(handle) = guard.take() {
518 handle.await
519 } else {
520 Ok(vec![])
521 }
522 }
523
524 pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
527 let pending_event_count = self.pending_event_count.clone();
528 let mut interval = interval(Duration::from_millis(10));
530 let deadline = RealClock.now() + timeout;
531 while RealClock.now() < deadline {
532 interval.tick().await;
533 if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
534 return Ok(());
535 }
536 }
537 Err(SimNetError::Timeout(
538 timeout,
539 "timeout waiting for received events to be scheduled".to_string(),
540 ))
541 }
542
543 pub fn register_proc(&self, proc_id: ProcId, point: Point) {
545 self.resources.insert(proc_id, point);
546 }
547
548 pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
550 let distances = [
551 Distance::Region,
552 Distance::DataCenter,
553 Distance::Zone,
554 Distance::Rack,
555 Distance::Host,
556 Distance::Same,
557 ];
558
559 let src_coords = self
560 .resources
561 .get(src)
562 .map(|point| point.coords().clone())
563 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
564
565 let dest_coords = self
566 .resources
567 .get(dest)
568 .map(|point| point.coords().clone())
569 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
570
571 for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
572 if src != dest {
573 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
574 return guard.from_distance(&distance);
575 }
576 }
577
578 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
579 guard.from_distance(&Distance::Same)
580 }
581}
582
/// Topology distance between two procs.
///
/// `SimNetHandle::sample_latency` picks the variant for the first coordinate
/// axis on which the procs differ.
#[derive(Debug)]
enum Distance {
    /// Different regions.
    Region,
    /// Same region, different data centers.
    DataCenter,
    /// Same data center, different zones.
    Zone,
    /// Same zone, different racks (zero latency in `LatencyConfig::from_distance`).
    Rack,
    /// Same rack, different hosts (zero latency).
    Host,
    /// No compared coordinate differs (zero latency).
    Same,
}
592
/// State owned by the simulator's event-loop task (created in `start_with_config`).
pub struct SimNet {
    /// Set of channel addresses. NOTE(review): only initialized in the code shown here.
    address_book: DashSet<ChannelAddr>,
    /// Pending event queues.
    state: State,
    /// Latency cap, set to 10 s by `start_with_config`. NOTE(review): not read in
    /// the code shown here.
    max_latency: Duration,
    /// One record per scheduled event; returned when the loop stops.
    records: Vec<SimulatorEventRecord>,
    /// Shared with the handle; decremented as each event is dispatched.
    pending_event_count: Arc<AtomicUsize>,
}
606
607pub fn start() {
609 start_with_config(LatencyConfig::default())
610}
611
612pub fn start_with_config(config: LatencyConfig) {
614 let max_duration_ms = 1000 * 10;
615 let address_book: DashSet<ChannelAddr> = DashSet::new();
617
618 let (training_script_state_tx, training_script_state_rx) =
619 tokio::sync::watch::channel(TrainingScriptState::Waiting);
620 let (event_tx, event_rx) =
621 mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
622 let pending_event_count = Arc::new(AtomicUsize::new(0));
623 let stop_signal = Arc::new(AtomicBool::new(false));
624
625 let join_handle = Mutex::new(Some({
626 let pending_event_count = pending_event_count.clone();
627 let stop_signal = stop_signal.clone();
628
629 tokio::spawn(async move {
630 SimNet {
631 address_book,
632 state: State {
633 scheduled_events: BTreeMap::new(),
634 unadvanceable_scheduled_events: BTreeMap::new(),
635 },
636 max_latency: Duration::from_millis(max_duration_ms),
637 records: Vec::new(),
638 pending_event_count,
639 }
640 .run(event_rx, training_script_state_rx, stop_signal)
641 .await
642 })
643 }));
644
645 HANDLE.get_or_init(|| SimNetHandle {
646 join_handle,
647 event_tx,
648 pending_event_count,
649 training_script_state_tx,
650 stop_signal,
651 resources: DashMap::new(),
652 latencies: std::sync::Mutex::new(config),
653 });
654}
655
impl SimNet {
    /// Wraps `event` in a [`ScheduledEvent`] due at the current simulated time plus
    /// the event's [`Event::duration`].
    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
        ScheduledEvent {
            time: SimClock.now() + event.duration(),
            event,
        }
    }

    /// Appends a record for `scheduled_event` to the history and queues the event.
    ///
    /// The record's `start_at` is the current simulated time and its `end_at` is
    /// the event's due time, both in ms since simulation start. The event goes
    /// into the advanceable or non-advanceable queue according to `advanceable`.
    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
        let start_at = SimClock.now();
        let end_at = scheduled_event.time;

        self.records.push(SimulatorEventRecord {
            summary: scheduled_event.event.summary(),
            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
        });

        if advanceable {
            self.state
                .scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        } else {
            self.state
                .unadvanceable_scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        }
    }

    /// The simulator's event loop.
    ///
    /// Each iteration does the following:
    /// 1. Stops if `stop_signal` is set.
    /// 2. Drains incoming events.
    /// 3. Picks the earliest due batch.
    /// 4. Advances the simulated clock to that batch's time.
    /// 5. Handles every event in the batch.
    ///
    /// It returns a copy of all records once it stops, either because the stop
    /// signal was set or because an event's `handle_network` failed.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // Simulated time skipped while the training script was in `Waiting`.
        // Explicitly-timed events are shifted by this offset.
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Real time at which only non-advanceable events started being pending.
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        // Real-time window (ms) for waiting on each incoming event while draining
        // the channel; taken from `SIM_DEBOUNCE`, default 1.
        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Keep scheduling events for as long as each one arrives within the
            // debounce window; stop on timeout or when the channel is closed.
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                let scheduled_event = match time {
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // While the training script runs, do not jump to an advanceable
                // event that is later than real "now" plus the waiting offset;
                // yield and retry instead.
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                // Start the debounce timer when only non-advanceable events remain;
                // reset it in every other case.
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        debounce_timer = Some(RealClock.now());
                    }
                    (None, Some(_)) => {}
                    _ => {
                        debounce_timer = None;
                    }
                }
                // Pick the batch to run. Rules:
                // - When both queues have events, take the earlier batch; ties go to
                //   the advanceable queue.
                // - Non-advanceable events alone are released only after more than
                //   1000 ms of real time have passed since the debounce timer started.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing is ready: wait for a new event or 10 ms of real time,
                    // then re-check (including the stop signal).
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // Simulated time skipped while the script waits is accumulated so
                // that later explicitly-timed events are shifted accordingly.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for scheduled_event in scheduled_events {
                    // Decremented before handling, so `flush` counts an event as done
                    // once it has been dispatched.
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    // NOTE(review): any handler error stops the whole loop, and the
                    // remaining events in this batch are never decremented.
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        break 'outer self.records.clone();
                    }
                }
            }
        }
    }
}
824
825fn serialize_optional_channel_addr<S>(
826 addr: &Option<ChannelAddr>,
827 serializer: S,
828) -> Result<S::Ok, S::Error>
829where
830 S: Serializer,
831{
832 match addr {
833 Some(addr) => serializer.serialize_str(&addr.to_string()),
834 None => serializer.serialize_none(),
835 }
836}
837
838fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
839where
840 D: Deserializer<'de>,
841{
842 let s = String::deserialize(deserializer)?;
843 s.parse().map_err(serde::de::Error::custom)
844}
845
/// One entry in the simulator's event log, written when an event is scheduled.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// The event's [`Event::summary`].
    pub summary: String,
    /// Simulated time at which the event was scheduled, in ms since simulation start.
    pub start_at: u64,
    /// Simulated time at which the event is due, in ms since simulation start.
    pub end_at: u64,
}
856
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;
    use crate::channel::sim::SimAddr;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::id;
    use crate::simnet;
    use crate::simnet::Dispatcher;
    use crate::simnet::Event;
    use crate::simnet::SimNetError;

    /// Test event that, when handled, delivers `data` to `dest_addr` through an
    /// optional [`TestDispatcher`].
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        src_addr: SimAddr,
        dest_addr: SimAddr,
        data: wirevalue::Any,
        duration: tokio::time::Duration,
        dispatcher: Option<TestDispatcher>,
    }

    #[async_trait]
    impl Event for MessageDeliveryEvent {
        async fn handle(&self) -> Result<(), simnet::SimNetError> {
            if let Some(dispatcher) = &self.dispatcher {
                dispatcher
                    .send(self.dest_addr.clone(), self.data.clone())
                    .await?;
            }
            Ok(())
        }

        fn duration(&self) -> tokio::time::Duration {
            self.duration
        }

        fn summary(&self) -> String {
            format!(
                "Sending message from {} to {}",
                self.src_addr.addr().clone(),
                self.dest_addr.addr().clone()
            )
        }
    }

    impl MessageDeliveryEvent {
        fn new(
            src_addr: SimAddr,
            dest_addr: SimAddr,
            data: wirevalue::Any,
            dispatcher: Option<TestDispatcher>,
            duration: tokio::time::Duration,
        ) -> Self {
            Self {
                src_addr,
                dest_addr,
                data,
                duration,
                dispatcher,
            }
        }
    }

    /// Dispatcher that buffers delivered payloads per destination for inspection.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<wirevalue::Any>>>>,
    }

    impl Default for TestDispatcher {
        fn default() -> Self {
            Self {
                mbuffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl Dispatcher<SimAddr> for TestDispatcher {
        async fn send(&self, target: SimAddr, data: wirevalue::Any) -> Result<(), SimNetError> {
            let mut buf = self.mbuffers.lock().await;
            buf.entry(target).or_default().push(data);
            Ok(())
        }
    }

    #[cfg(target_os = "linux")]
    fn random_abstract_addr() -> ChannelAddr {
        use rand::Rng;
        use rand::distributions::Alphanumeric;

        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

    #[tokio::test]
    async fn test_handle_instantiation() {
        start();
        simnet_handle().unwrap().close().await.unwrap();
    }

    #[tokio::test]
    async fn test_simnet_config() {
        // Axes are ordered outermost first, matching `sample_latency`'s tiers.
        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);

        let alice = id!(world[0]);
        let bob = id!(world[1]);
        let charlie = id!(world[2]);

        // Degenerate ranges (min == max) make the sampled latencies exact.
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            inter_dc_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(2000),
                    tokio::time::Duration::from_millis(2000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);

        let handle = simnet_handle().unwrap();
        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
        // alice/bob differ first on the zone axis.
        assert_eq!(
            handle.sample_latency(&alice, &bob),
            tokio::time::Duration::from_millis(1000)
        );
        // alice/charlie differ first on the data-center axis.
        assert_eq!(
            handle.sample_latency(&alice, &charlie),
            tokio::time::Duration::from_millis(2000)
        );
    }

    #[tokio::test]
    async fn test_simnet_debounce() {
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        // Send events 500 µs apart so that they keep arriving within the debounce
        // window and are all scheduled before the simulated clock advances.
        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    wirevalue::Any::serialize(&"123".to_string()).unwrap(),
                    None,
                    latency,
                )))
                .unwrap();
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await;
        assert_eq!(records.as_ref().unwrap().len(), 10);

        // Every event was scheduled at simulated time 0, so even the last one ends
        // exactly `latency` after the simulation started.
        assert_eq!(
            records.unwrap().last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }

    #[tokio::test]
    async fn test_sim_dispatch() {
        start();
        let sender = Some(TestDispatcher::default());
        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
        for i in 0..4 {
            addresses.push(
                format!("local:{}", i)
                    .parse::<simnet::ChannelAddr>()
                    .unwrap(),
            );
        }

        let messages: Vec<wirevalue::Any> = vec!["First 0 1", "First 2 3", "Second 0 1"]
            .into_iter()
            .map(|s| wirevalue::Any::serialize(&s.to_string()).unwrap())
            .collect();

        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
        let one = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[0].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let two = Box::new(MessageDeliveryEvent::new(
            addr_2.clone(),
            addr_3.clone(),
            messages[1].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let three = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[2].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));

        simnet_handle().unwrap().send_event(one).unwrap();
        simnet_handle().unwrap().send_event(two).unwrap();
        simnet_handle().unwrap().send_event(three).unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        // One close is enough: it stops the event loop and joins it.
        simnet_handle().unwrap().close().await.unwrap();

        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[&addr_1].len(), 2);
        assert_eq!(buf[&addr_3].len(), 1);

        // Messages to the same destination keep their send order.
        assert_eq!(buf[&addr_1][0], messages[0]);
        assert_eq!(buf[&addr_1][1], messages[2]);
        assert_eq!(buf[&addr_3][0], messages[1]);
    }

    #[tokio::test]
    async fn test_sim_sleep() {
        start();

        let start = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(start),
            tokio::time::Duration::ZERO
        );

        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;

        let end = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(end),
            tokio::time::Duration::from_secs(10)
        );
    }

    #[tokio::test]
    async fn test_torch_op() {
        start();
        let args_string = "1, 2".to_string();
        let kwargs_string = "a=2".to_string();

        let (tx, rx) = oneshot::channel();

        simnet_handle()
            .unwrap()
            .send_event(TorchOpEvent::new(
                "torch.ops.aten.ones.default".to_string(),
                tx,
                args_string,
                kwargs_string,
                id!(mesh_0_worker[0].worker_0),
            ))
            .unwrap();

        rx.await.unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await;
        let expected_record = SimulatorEventRecord {
            summary:
                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
                    .to_string(),
            start_at: 0,
            end_at: 100,
        };
        assert_eq!(records.as_ref().unwrap().len(), 1);
        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
    }
}