1#![allow(dead_code)]
10
11use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::task::JoinError;
42use tokio::task::JoinHandle;
43use tokio::time::interval;
44
45use crate::ActorId;
47use crate::Mailbox;
48use crate::OncePortRef;
49use crate::ProcId;
50use crate::attrs::Attrs;
51use crate::channel::ChannelAddr;
52use crate::clock::Clock;
53use crate::clock::RealClock;
54use crate::clock::SimClock;
55use crate::context::MailboxExt;
56use crate::data::Serialized;
57
58static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
59
60#[allow(clippy::result_large_err)] pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
66 match HANDLE.get() {
67 Some(handle) => Ok(handle),
68 None => Err(SimNetError::Closed("SimNet not started".to_string())),
69 }
70}
71
72const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
73
74pub trait Address: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}
77impl<A: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone> Address for A {}
78
79pub type SimulatorTimeInstant = tokio::time::Instant;
81
82#[async_trait]
91pub trait Event: Send + Sync + Debug {
92 async fn handle(&self) -> Result<(), SimNetError>;
99
100 async fn handle_network(&self, _phantom: &SimNet) -> Result<(), SimNetError> {
104 self.handle().await
105 }
106
107 fn duration(&self) -> tokio::time::Duration;
110
111 fn summary(&self) -> String;
113}
114
115#[derive(Debug)]
118struct NodeJoinEvent {
119 channel_addr: ChannelAddr,
120}
121
122#[async_trait]
123impl Event for NodeJoinEvent {
124 async fn handle(&self) -> Result<(), SimNetError> {
125 Ok(())
126 }
127
128 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
129 self.handle().await
130 }
131
132 fn duration(&self) -> tokio::time::Duration {
133 tokio::time::Duration::ZERO
134 }
135
136 fn summary(&self) -> String {
137 "Node join".into()
138 }
139}
140
141#[derive(Debug)]
142pub struct TorchOpEvent {
144 op: String,
145 done_tx: OncePortRef<()>,
146 mailbox: Mailbox,
147 args_string: String,
148 kwargs_string: String,
149 worker_actor_id: ActorId,
150}
151
152#[async_trait]
153impl Event for TorchOpEvent {
154 async fn handle(&self) -> Result<(), SimNetError> {
155 Ok(())
156 }
157
158 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
159 let mut headers = Attrs::new();
160 crate::mailbox::headers::set_send_timestamp(&mut headers);
161 let serialized =
162 Serialized::serialize(&()).map_err(|err| SimNetError::Closed(err.to_string()))?;
163 self.mailbox
164 .post(self.done_tx.port_id().clone(), headers, serialized);
165 Ok(())
166 }
167
168 fn duration(&self) -> tokio::time::Duration {
169 tokio::time::Duration::from_millis(100)
170 }
171
172 fn summary(&self) -> String {
173 let kwargs_string = if self.kwargs_string.is_empty() {
174 "".to_string()
175 } else {
176 format!(", {}", self.kwargs_string)
177 };
178 format!(
179 "[{}] Torch Op: {}({}{})",
180 self.worker_actor_id, self.op, self.args_string, kwargs_string
181 )
182 }
183}
184
185impl TorchOpEvent {
186 pub fn new(
188 op: String,
189 done_tx: OncePortRef<()>,
190 mailbox: Mailbox,
191 args_string: String,
192 kwargs_string: String,
193 worker_actor_id: ActorId,
194 ) -> Box<Self> {
195 Box::new(Self {
196 op,
197 done_tx,
198 mailbox,
199 args_string,
200 kwargs_string,
201 worker_actor_id,
202 })
203 }
204}
205
206#[derive(Debug)]
207pub(crate) struct SleepEvent {
208 done_tx: OncePortRef<()>,
209 mailbox: Mailbox,
210 duration: tokio::time::Duration,
211}
212
213impl SleepEvent {
214 pub(crate) fn new(
215 done_tx: OncePortRef<()>,
216 mailbox: Mailbox,
217 duration: tokio::time::Duration,
218 ) -> Box<Self> {
219 Box::new(Self {
220 done_tx,
221 mailbox,
222 duration,
223 })
224 }
225}
226
227#[async_trait]
228impl Event for SleepEvent {
229 async fn handle(&self) -> Result<(), SimNetError> {
230 Ok(())
231 }
232
233 async fn handle_network(&self, _simnet: &SimNet) -> Result<(), SimNetError> {
234 let mut headers = Attrs::new();
235 crate::mailbox::headers::set_send_timestamp(&mut headers);
236 let serialized =
237 Serialized::serialize(&()).map_err(|err| SimNetError::Closed(err.to_string()))?;
238 self.mailbox
239 .post(self.done_tx.port_id().clone(), headers, serialized);
240 Ok(())
241 }
242
243 fn duration(&self) -> tokio::time::Duration {
244 self.duration
245 }
246
247 fn summary(&self) -> String {
248 format!("Sleeping for {} ms", self.duration.as_millis())
249 }
250}
251
252#[derive(Debug)]
257pub(crate) struct ScheduledEvent {
258 pub(crate) time: SimulatorTimeInstant,
259 pub(crate) event: Box<dyn Event>,
260}
261
262#[async_trait]
267pub trait Dispatcher<A> {
268 async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
270}
271
272#[derive(thiserror::Error, Debug)]
275#[non_exhaustive]
276pub enum SimNetError {
277 #[error("invalid address: {0}")]
279 InvalidAddress(String),
280
281 #[error("invalid node: {0}")]
283 InvalidNode(String, #[source] anyhow::Error),
284
285 #[error("invalid arg: {0}")]
287 InvalidArg(String),
288
289 #[error("closed: {0}")]
291 Closed(String),
292
293 #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
295 Timeout(Duration, String),
296
297 #[error("missing destination address")]
299 MissingDestinationAddress,
300
301 #[error("simnet not started")]
303 NotStarted,
304}
305
306struct State {
307 scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
309 unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
313}
314
315#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
317pub enum TrainingScriptState {
318 Running,
320 Waiting,
322}
323
324pub enum LatencyDistribution {
326 Beta(BetaDistribution),
328}
329
330impl LatencyDistribution {
331 fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
332 match &self {
333 LatencyDistribution::Beta(sampler) => sampler.sample(rng),
334 }
335 }
336}
337
338pub struct BetaDistribution {
340 min_duration: tokio::time::Duration,
341 max_duration: tokio::time::Duration,
342 dist: rand_distr::Beta<f64>,
343}
344
345impl BetaDistribution {
346 pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
348 let sample = self.dist.sample(rng);
349
350 self.min_duration
351 + tokio::time::Duration::from_micros(
352 (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
353 )
354 }
355
356 pub fn new(
358 min_duration: tokio::time::Duration,
359 max_duration: tokio::time::Duration,
360 alpha: f64,
361 beta: f64,
362 ) -> anyhow::Result<Self> {
363 if min_duration > max_duration {
364 return Err(anyhow::anyhow!(
365 "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
366 min_duration,
367 max_duration
368 ));
369 }
370 Ok(Self {
371 min_duration,
372 max_duration,
373 dist: rand_distr::Beta::new(alpha, beta)?,
374 })
375 }
376}
377pub struct LatencyConfig {
379 pub inter_region_distribution: LatencyDistribution,
381 pub inter_dc_distribution: LatencyDistribution,
383 pub inter_zone_distribution: LatencyDistribution,
385 pub rng: StdRng,
387}
388
389impl LatencyConfig {
390 fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
391 match distance {
392 Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
393 Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
394 Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
395 Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
396 }
397 }
398}
399
400impl Default for LatencyConfig {
401 fn default() -> Self {
402 let seed: u64 = 0000;
403 let mut seed_bytes = [0u8; 32];
404 seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
405
406 Self {
407 inter_region_distribution: LatencyDistribution::Beta(
408 BetaDistribution::new(
409 tokio::time::Duration::from_millis(500),
410 tokio::time::Duration::from_millis(1000),
411 2.0,
412 1.0,
413 )
414 .unwrap(),
415 ),
416 inter_dc_distribution: LatencyDistribution::Beta(
417 BetaDistribution::new(
418 tokio::time::Duration::from_millis(50),
419 tokio::time::Duration::from_millis(100),
420 2.0,
421 1.0,
422 )
423 .unwrap(),
424 ),
425 inter_zone_distribution: LatencyDistribution::Beta(
426 BetaDistribution::new(
427 tokio::time::Duration::from_millis(5),
428 tokio::time::Duration::from_millis(10),
429 2.0,
430 1.0,
431 )
432 .unwrap(),
433 ),
434 rng: StdRng::from_seed(seed_bytes),
435 }
436 }
437}
438
439pub struct SimNetHandle {
441 join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
442 event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
443 pending_event_count: Arc<AtomicUsize>,
444 training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
447 stop_signal: Arc<AtomicBool>,
449 resources: DashMap<ProcId, Point>,
450 latencies: std::sync::Mutex<LatencyConfig>,
451}
452
453impl SimNetHandle {
454 #[allow(clippy::result_large_err)] pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
457 self.send_event_impl(event, true)
458 }
459
460 #[allow(clippy::result_large_err)] fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
462 self.pending_event_count
463 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
464 self.event_tx
465 .send((event, advanceable, None))
466 .map_err(|err| SimNetError::Closed(err.to_string()))
467 }
468
469 #[allow(clippy::result_large_err)] pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
476 self.send_event_impl(event, false)
477 }
478
479 #[allow(clippy::result_large_err)] pub(crate) fn send_scheduled_event(
482 &self,
483 ScheduledEvent { event, time }: ScheduledEvent,
484 ) -> Result<(), SimNetError> {
485 self.pending_event_count
486 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
487 self.event_tx
488 .send((event, true, Some(time)))
489 .map_err(|err| SimNetError::Closed(err.to_string()))
490 }
491
492 pub fn set_training_script_state(&self, state: TrainingScriptState) {
495 self.training_script_state_tx.send(state).unwrap();
496 }
497
498 #[allow(clippy::result_large_err)] pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
501 self.pending_event_count
502 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
503 self.event_tx
504 .send((
505 Box::new(NodeJoinEvent {
506 channel_addr: address,
507 }),
508 true,
509 None,
510 ))
511 .map_err(|err| SimNetError::Closed(err.to_string()))
512 }
513
514 pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
517 self.stop_signal.store(true, Ordering::SeqCst);
519
520 let mut guard = self.join_handle.lock().await;
521 if let Some(handle) = guard.take() {
522 handle.await
523 } else {
524 Ok(vec![])
525 }
526 }
527
528 pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
531 let pending_event_count = self.pending_event_count.clone();
532 let mut interval = interval(Duration::from_millis(10));
534 let deadline = RealClock.now() + timeout;
535 while RealClock.now() < deadline {
536 interval.tick().await;
537 if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
538 return Ok(());
539 }
540 }
541 Err(SimNetError::Timeout(
542 timeout,
543 "timeout waiting for received events to be scheduled".to_string(),
544 ))
545 }
546
547 pub fn register_proc(&self, proc_id: ProcId, point: Point) {
549 self.resources.insert(proc_id, point);
550 }
551
552 pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
554 let distances = [
555 Distance::Region,
556 Distance::DataCenter,
557 Distance::Zone,
558 Distance::Rack,
559 Distance::Host,
560 Distance::Same,
561 ];
562
563 let src_coords = self
564 .resources
565 .get(src)
566 .map(|point| point.coords().clone())
567 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
568
569 let dest_coords = self
570 .resources
571 .get(dest)
572 .map(|point| point.coords().clone())
573 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
574
575 for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
576 if src != dest {
577 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
578 return guard.from_distance(&distance);
579 }
580 }
581
582 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
583 guard.from_distance(&Distance::Same)
584 }
585}
586
587#[derive(Debug)]
588enum Distance {
589 Region,
590 DataCenter,
591 Zone,
592 Rack,
593 Host,
594 Same,
595}
596
597pub struct SimNet {
603 address_book: DashSet<ChannelAddr>,
604 state: State,
605 max_latency: Duration,
606 records: Vec<SimulatorEventRecord>,
607 pending_event_count: Arc<AtomicUsize>,
609}
610
611pub fn start() {
613 start_with_config(LatencyConfig::default())
614}
615
616pub fn start_with_config(config: LatencyConfig) {
618 let max_duration_ms = 1000 * 10;
619 let address_book: DashSet<ChannelAddr> = DashSet::new();
621
622 let (training_script_state_tx, training_script_state_rx) =
623 tokio::sync::watch::channel(TrainingScriptState::Waiting);
624 let (event_tx, event_rx) =
625 mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
626 let pending_event_count = Arc::new(AtomicUsize::new(0));
627 let stop_signal = Arc::new(AtomicBool::new(false));
628
629 let join_handle = Mutex::new(Some({
630 let pending_event_count = pending_event_count.clone();
631 let stop_signal = stop_signal.clone();
632
633 tokio::spawn(async move {
634 SimNet {
635 address_book,
636 state: State {
637 scheduled_events: BTreeMap::new(),
638 unadvanceable_scheduled_events: BTreeMap::new(),
639 },
640 max_latency: Duration::from_millis(max_duration_ms),
641 records: Vec::new(),
642 pending_event_count,
643 }
644 .run(event_rx, training_script_state_rx, stop_signal)
645 .await
646 })
647 }));
648
649 HANDLE.get_or_init(|| SimNetHandle {
650 join_handle,
651 event_tx,
652 pending_event_count,
653 training_script_state_tx,
654 stop_signal,
655 resources: DashMap::new(),
656 latencies: std::sync::Mutex::new(config),
657 });
658}
659
impl SimNet {
    /// Wraps `event` so that it is due at the current simulated time plus its duration.
    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
        ScheduledEvent {
            time: SimClock.now() + event.duration(),
            event,
        }
    }

    /// Records `scheduled_event` and queues it on the advanceable or unadvanceable map.
    ///
    /// The record stores the event's summary. It also stores the current and due simulated
    /// times, as whole milliseconds since simulation start.
    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
        let start_at = SimClock.now();
        let end_at = scheduled_event.time;

        self.records.push(SimulatorEventRecord {
            summary: scheduled_event.event.summary(),
            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
        });

        // Events due at the same instant are batched into one Vec under that key.
        if advanceable {
            self.state
                .scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        } else {
            self.state
                .unadvanceable_scheduled_events
                .entry(scheduled_event.time)
                .or_default()
                .push(scheduled_event);
        }
    }

    /// The simulator event loop; returns all records when it stops.
    ///
    /// Each iteration performs these steps:
    /// 1. Drain incoming events. Each receive waits at most `SIM_DEBOUNCE` ms of real time
    ///    (env var, default 1).
    /// 2. If the training script is `Running` and the next advanceable event lies in the
    ///    future, yield and retry.
    /// 3. Pop the earliest batch of events, advance `SimClock` to it, and handle each event.
    ///
    /// The loop exits when `stop_signal` is observed at the top of an iteration, or when any
    /// handler returns an error.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // Total simulated time advanced while the training script was `Waiting`. It is added
        // to the explicit time of every explicitly timed event.
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Real-time instant at which only unadvanceable events started being queued.
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Keep receiving until no event arrives within the debounce window (or the
            // channel is closed).
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                let scheduled_event = match time {
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // While the script runs, do not jump ahead to an advanceable event that is
                // still in the future.
                // NOTE(review): this compares a SimClock-scheduled instant against
                // RealClock.now(). Presumably both clocks share the tokio Instant base —
                // confirm.
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                // Start the debounce timer when only unadvanceable events are queued, and
                // reset it otherwise.
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        debounce_timer = Some(RealClock.now());
                    }
                    (None, Some(_)) => {}
                    _ => {
                        debounce_timer = None;
                    }
                }
                // Pick the earliest batch. Ties go to the advanceable queue, because the
                // comparison is strict. Unadvanceable events alone fire only after 1000 ms of
                // real-time debounce.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing ready yet: wait for a new event or 10 ms of real time, whichever
                    // comes first.
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // Simulated time skipped while the script waits is accumulated so that later
                // explicitly timed events are shifted by it.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for scheduled_event in scheduled_events {
                    // Decremented before handling, so `flush` may return while the last
                    // handler is still running.
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    // NOTE(review): on error the rest of this batch is dropped and its pending
                    // count is never decremented.
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        break 'outer self.records.clone();
                    }
                }
            }
        }
    }
}
828
829fn serialize_optional_channel_addr<S>(
830 addr: &Option<ChannelAddr>,
831 serializer: S,
832) -> Result<S::Ok, S::Error>
833where
834 S: Serializer,
835{
836 match addr {
837 Some(addr) => serializer.serialize_str(&addr.to_string()),
838 None => serializer.serialize_none(),
839 }
840}
841
842fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
843where
844 D: Deserializer<'de>,
845{
846 let s = String::deserialize(deserializer)?;
847 s.parse().map_err(serde::de::Error::custom)
848}
849
850#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
852pub struct SimulatorEventRecord {
853 pub summary: String,
855 pub start_at: u64,
857 pub end_at: u64,
859}
860
// Tests for the simulator.
// NOTE(review): `start`/`start_with_config` install a process-global handle that can be set
// only once. These tests presumably rely on each test running in its own process — confirm.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;
    use crate::channel::sim::SimAddr;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::data::Serialized;
    use crate::id;
    use crate::simnet;
    use crate::simnet::Dispatcher;
    use crate::simnet::Event;
    use crate::simnet::SimNetError;

    /// Test event that, when handled, forwards `data` to `dest_addr` through an optional
    /// dispatcher.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        src_addr: SimAddr,
        dest_addr: SimAddr,
        data: Serialized,
        duration: tokio::time::Duration,
        dispatcher: Option<TestDispatcher>,
    }

    #[async_trait]
    impl Event for MessageDeliveryEvent {
        async fn handle(&self) -> Result<(), simnet::SimNetError> {
            if let Some(dispatcher) = &self.dispatcher {
                dispatcher
                    .send(self.dest_addr.clone(), self.data.clone())
                    .await?;
            }
            Ok(())
        }
        fn duration(&self) -> tokio::time::Duration {
            self.duration
        }

        fn summary(&self) -> String {
            format!(
                "Sending message from {} to {}",
                self.src_addr.addr().clone(),
                self.dest_addr.addr().clone()
            )
        }
    }

    impl MessageDeliveryEvent {
        fn new(
            src_addr: SimAddr,
            dest_addr: SimAddr,
            data: Serialized,
            dispatcher: Option<TestDispatcher>,
            duration: tokio::time::Duration,
        ) -> Self {
            Self {
                src_addr,
                dest_addr,
                data,
                duration,
                dispatcher,
            }
        }
    }

    /// Dispatcher that appends every delivered payload to a per-target buffer.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }

    impl Default for TestDispatcher {
        fn default() -> Self {
            Self {
                mbuffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl Dispatcher<SimAddr> for TestDispatcher {
        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
            let mut buf = self.mbuffers.lock().await;
            buf.entry(target).or_default().push(data);
            Ok(())
        }
    }

    /// Returns a random Linux abstract-namespace unix socket address.
    #[cfg(target_os = "linux")]
    fn random_abstract_addr() -> ChannelAddr {
        use rand::Rng;
        use rand::distributions::Alphanumeric;

        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

    /// The simulator can be started and closed cleanly.
    #[tokio::test]
    async fn test_handle_instantiation() {
        start();
        simnet_handle().unwrap().close().await.unwrap();
    }

    /// Latency comes from the outermost differing placement dimension.
    ///
    /// Alice and Bob differ at the zone level, so the inter-zone tier gives exactly 1000 ms.
    /// Alice and Charlie differ at the DC level, so the inter-DC tier gives exactly 2000 ms.
    /// The min == max bounds make the samples deterministic.
    #[tokio::test]
    async fn test_simnet_config() {
        let ext = extent!(region = 1, dc = 2, zone = 2, rack = 4, host = 4, gpu = 8);

        let alice = id!(world[0]);
        let bob = id!(world[1]);
        let charlie = id!(world[2]);

        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            inter_dc_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(2000),
                    tokio::time::Duration::from_millis(2000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);

        let handle = simnet_handle().unwrap();
        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
        assert_eq!(
            handle.sample_latency(&alice, &bob),
            tokio::time::Duration::from_millis(1000)
        );
        assert_eq!(
            handle.sample_latency(&alice, &charlie),
            tokio::time::Duration::from_millis(2000)
        );
    }

    /// Events arriving faster than the receive debounce window (500 µs apart versus the 1 ms
    /// default) are expected to be scheduled before simulated time advances. Each of the ten
    /// records should therefore end exactly `latency` after start.
    #[tokio::test]
    async fn test_simnet_debounce() {
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    Serialized::serialize(&"123".to_string()).unwrap(),
                    None,
                    latency,
                )))
                .unwrap();
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await;
        assert_eq!(records.as_ref().unwrap().len(), 10);

        assert_eq!(
            records.unwrap().last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }

    /// Zero-duration delivery events reach their dispatcher, and per-target order matches
    /// the order in which they were sent.
    #[tokio::test]
    async fn test_sim_dispatch() {
        start();
        let sender = Some(TestDispatcher::default());
        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
        for i in 0..4 {
            addresses.push(
                format!("local:{}", i)
                    .parse::<simnet::ChannelAddr>()
                    .unwrap(),
            );
        }

        let messages: Vec<Serialized> = vec!["First 0 1", "First 2 3", "Second 0 1"]
            .into_iter()
            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
            .collect();

        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
        let one = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[0].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let two = Box::new(MessageDeliveryEvent::new(
            addr_2.clone(),
            addr_3.clone(),
            messages[1].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let three = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[2].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));

        simnet_handle().unwrap().send_event(one).unwrap();
        simnet_handle().unwrap().send_event(two).unwrap();
        simnet_handle().unwrap().send_event(three).unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await.unwrap();
        eprintln!("Records: {:?}", records);
        // A second close is harmless: the join handle is already taken, so it returns an
        // empty list.
        simnet_handle().unwrap().close().await.unwrap();

        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[&addr_1].len(), 2);
        assert_eq!(buf[&addr_3].len(), 1);

        assert_eq!(buf[&addr_1][0], messages[0]);
        assert_eq!(buf[&addr_1][1], messages[2]);
        assert_eq!(buf[&addr_3][0], messages[1]);
    }

    /// Sleeping on the simulated clock advances simulated time by exactly the sleep duration.
    #[tokio::test]
    async fn test_sim_sleep() {
        start();

        let start = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(start),
            tokio::time::Duration::ZERO
        );

        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;

        let end = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(end),
            tokio::time::Duration::from_secs(10)
        );
    }

    /// A Torch op event notifies its done port. It leaves a single record spanning
    /// 0–100 ms with the formatted summary.
    #[tokio::test]
    async fn test_torch_op() {
        start();
        let args_string = "1, 2".to_string();
        let kwargs_string = "a=2".to_string();

        let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
        let (tx, rx) = mailbox.open_once_port::<()>();

        simnet_handle()
            .unwrap()
            .send_event(TorchOpEvent::new(
                "torch.ops.aten.ones.default".to_string(),
                tx.bind(),
                mailbox,
                args_string,
                kwargs_string,
                id!(mesh_0_worker[0].worker_0),
            ))
            .unwrap();

        rx.recv().await.unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await;
        let expected_record = SimulatorEventRecord {
            summary:
                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
                    .to_string(),
            start_at: 0,
            end_at: 100,
        };
        assert!(records.as_ref().unwrap().len() == 1);
        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
    }
}