1#![allow(dead_code)]
10
11use std::collections::BTreeMap;
16use std::fmt::Debug;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::OnceLock;
20use std::sync::atomic::AtomicBool;
21use std::sync::atomic::AtomicUsize;
22use std::sync::atomic::Ordering;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dashmap::DashMap;
27use dashmap::DashSet;
28use enum_as_inner::EnumAsInner;
29use ndslice::view::Point;
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand_distr::Distribution;
33use serde::Deserialize;
34use serde::Deserializer;
35use serde::Serialize;
36use serde::Serializer;
37use tokio::sync::Mutex;
38use tokio::sync::mpsc;
39use tokio::sync::mpsc::UnboundedReceiver;
40use tokio::sync::mpsc::UnboundedSender;
41use tokio::task::JoinError;
42use tokio::task::JoinHandle;
43use tokio::time::interval;
44
45use crate::ActorId;
47use crate::Mailbox;
48use crate::OncePortRef;
49use crate::ProcId;
50use crate::channel::ChannelAddr;
51use crate::clock::Clock;
52use crate::clock::RealClock;
53use crate::clock::SimClock;
54use crate::data::Serialized;
55
/// Process-wide simulator handle. It is installed by [`start_with_config`] and
/// read through [`simnet_handle`].
static HANDLE: OnceLock<SimNetHandle> = OnceLock::new();
57
58#[allow(clippy::result_large_err)] pub fn simnet_handle() -> Result<&'static SimNetHandle, SimNetError> {
64 match HANDLE.get() {
65 Some(handle) => Ok(handle),
66 None => Err(SimNetError::Closed("SimNet not started".to_string())),
67 }
68}
69
// NOTE(review): nothing in the visible part of this file reads this constant;
// confirm whether it is still needed.
const OPERATIONAL_MESSAGE_BUFFER_SIZE: usize = 8;
71
/// Shorthand for the bounds required of an address type: hashable, debuggable,
/// comparable for equality and ordering, and cloneable.
pub trait Address: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}

/// Blanket implementation: every type that satisfies the bounds is an [`Address`].
impl<A> Address for A where A: Hash + Debug + Eq + PartialEq + Ord + PartialOrd + Clone {}
76
/// Instant type used for simulator timestamps (an alias of [`tokio::time::Instant`]).
pub type SimulatorTimeInstant = tokio::time::Instant;
79
/// A unit of simulated work.
///
/// Events are handed to the simulator through [`SimNetHandle`]. Unless an explicit
/// time is supplied, the simulator schedules an event at the current simulated time
/// plus [`Event::duration`], and calls [`Event::handle_network`] when that time is
/// reached.
#[async_trait]
pub trait Event: Send + Sync + Debug {
    /// Performs the event's effect.
    async fn handle(&mut self) -> Result<(), SimNetError>;

    /// Variant of [`Event::handle`] that also receives the running [`SimNet`].
    /// This is the method the simulator's run loop calls; the default
    /// implementation ignores the simulator and delegates to `handle`.
    async fn handle_network(&mut self, _phantom: &SimNet) -> Result<(), SimNetError> {
        self.handle().await
    }

    /// Simulated time between the event being received and it being handled.
    fn duration(&self) -> tokio::time::Duration;

    /// Human-readable description, stored in the event's [`SimulatorEventRecord`].
    fn summary(&self) -> String;
}
112
/// Event enqueued by [`SimNetHandle::bind`]; carries the address being bound.
#[derive(Debug)]
struct NodeJoinEvent {
    /// Address passed to `bind`.
    channel_addr: ChannelAddr,
}
119
120#[async_trait]
121impl Event for NodeJoinEvent {
122 async fn handle(&mut self) -> Result<(), SimNetError> {
123 Ok(())
124 }
125
126 async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
127 self.handle().await
128 }
129
130 fn duration(&self) -> tokio::time::Duration {
131 tokio::time::Duration::ZERO
132 }
133
134 fn summary(&self) -> String {
135 "Node join".into()
136 }
137}
138
/// Event describing a torch operation. When the simulator handles it, `()` is sent
/// on `done_tx` through `mailbox`.
#[derive(Debug)]
pub struct TorchOpEvent {
    /// Operation name, shown in the summary.
    op: String,
    /// Port on which completion is signalled.
    done_tx: OncePortRef<()>,
    /// Mailbox used to send the completion message.
    mailbox: Mailbox,
    /// Pre-rendered positional arguments, shown in the summary.
    args_string: String,
    /// Pre-rendered keyword arguments; omitted from the summary when empty.
    kwargs_string: String,
    /// Actor id printed as the summary's prefix.
    worker_actor_id: ActorId,
}
149
150#[async_trait]
151impl Event for TorchOpEvent {
152 async fn handle(&mut self) -> Result<(), SimNetError> {
153 Ok(())
154 }
155
156 async fn handle_network(&mut self, _simnet: &SimNet) -> Result<(), SimNetError> {
157 self.done_tx
158 .clone()
159 .send(&self.mailbox, ())
160 .map_err(|err| SimNetError::Closed(err.to_string()))?;
161 Ok(())
162 }
163
164 fn duration(&self) -> tokio::time::Duration {
165 tokio::time::Duration::from_millis(100)
166 }
167
168 fn summary(&self) -> String {
169 let kwargs_string = if self.kwargs_string.is_empty() {
170 "".to_string()
171 } else {
172 format!(", {}", self.kwargs_string)
173 };
174 format!(
175 "[{}] Torch Op: {}({}{})",
176 self.worker_actor_id, self.op, self.args_string, kwargs_string
177 )
178 }
179}
180
181impl TorchOpEvent {
182 pub fn new(
184 op: String,
185 done_tx: OncePortRef<()>,
186 mailbox: Mailbox,
187 args_string: String,
188 kwargs_string: String,
189 worker_actor_id: ActorId,
190 ) -> Box<Self> {
191 Box::new(Self {
192 op,
193 done_tx,
194 mailbox,
195 args_string,
196 kwargs_string,
197 worker_actor_id,
198 })
199 }
200}
201
/// An event paired with the simulated instant at which it is scheduled.
#[derive(Debug)]
pub(crate) struct ScheduledEvent {
    /// Instant the event is scheduled for.
    pub(crate) time: SimulatorTimeInstant,
    /// The event to handle at `time`.
    pub(crate) event: Box<dyn Event>,
}
211
/// Delivers serialized data to a target of address type `A`.
#[async_trait]
pub trait Dispatcher<A> {
    /// Sends `data` to `target`.
    async fn send(&self, target: A, data: Serialized) -> Result<(), SimNetError>;
}
221
/// Errors reported by the simulator and its handle.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum SimNetError {
    /// An address was invalid; the payload is the message to display.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// A node was invalid; carries a message and the underlying error as source.
    #[error("invalid node: {0}")]
    InvalidNode(String, #[source] anyhow::Error),

    /// An argument was invalid; the payload is the message to display.
    #[error("invalid arg: {0}")]
    InvalidArg(String),

    /// The simulator is unavailable: in this file it is returned when the handle has
    /// not been installed or when sending on a channel fails.
    #[error("closed: {0}")]
    Closed(String),

    /// An operation did not finish within the given duration; the string says which.
    #[error("timeout after {} ms: {}", .0.as_millis(), .1)]
    Timeout(Duration, String),

    /// No destination address was available.
    #[error("missing destination address")]
    MissingDestinationAddress,

    /// The simulator has not been started.
    #[error("simnet not started")]
    NotStarted,

    /// A task panicked.
    #[error("panicked task")]
    PanickedTask,
}
259
/// Pending events held by the simulator, keyed by their scheduled instant.
struct State {
    /// Events scheduled with `advanceable == true`.
    scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
    /// Events scheduled with `advanceable == false`. When no advanceable events are
    /// pending, the run loop only pops these after its debounce timer has run for
    /// more than one second.
    unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
}
268
/// State of the training script as reported through
/// [`SimNetHandle::set_training_script_state`]. The simulator starts in `Waiting`.
#[derive(EnumAsInner, Debug, Serialize, Deserialize, PartialEq, Clone)]
pub enum TrainingScriptState {
    /// While in this state the run loop does not process advanceable events that are
    /// scheduled later than the real clock plus the accumulated waiting time.
    Running,
    /// While in this state, simulated time advanced by the run loop is added to the
    /// accumulated waiting time.
    Waiting,
}
277
/// A distribution from which latencies are sampled.
pub enum LatencyDistribution {
    /// Beta distribution scaled to a duration range; see [`BetaDistribution`].
    Beta(BetaDistribution),
}
283
284impl LatencyDistribution {
285 fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
286 match &self {
287 LatencyDistribution::Beta(sampler) => sampler.sample(rng),
288 }
289 }
290}
291
/// A Beta distribution whose samples are scaled onto a duration range.
pub struct BetaDistribution {
    /// Lower bound; added to every scaled sample.
    min_duration: tokio::time::Duration,
    /// Upper bound; `max_duration - min_duration` is the scaling span.
    max_duration: tokio::time::Duration,
    /// Underlying Beta distribution.
    dist: rand_distr::Beta<f64>,
}
298
299impl BetaDistribution {
300 pub fn sample(&self, rng: &mut StdRng) -> tokio::time::Duration {
302 let sample = self.dist.sample(rng);
303
304 self.min_duration
305 + tokio::time::Duration::from_micros(
306 (sample * (self.max_duration - self.min_duration).as_micros() as f64) as u64,
307 )
308 }
309
310 pub fn new(
312 min_duration: tokio::time::Duration,
313 max_duration: tokio::time::Duration,
314 alpha: f64,
315 beta: f64,
316 ) -> anyhow::Result<Self> {
317 if min_duration > max_duration {
318 return Err(anyhow::anyhow!(
319 "min_duration must not be greater than max_duration, got min_duration: {:?}, max_duration: {:?}",
320 min_duration,
321 max_duration
322 ));
323 }
324 Ok(Self {
325 min_duration,
326 max_duration,
327 dist: rand_distr::Beta::new(alpha, beta)?,
328 })
329 }
330}
/// Latency distributions per distance class, plus the RNG used to sample them.
pub struct LatencyConfig {
    /// Sampled when two procs first differ at the region coordinate.
    pub inter_region_distribution: LatencyDistribution,
    /// Sampled when two procs first differ at the data-center coordinate.
    pub inter_dc_distribution: LatencyDistribution,
    /// Sampled when two procs first differ at the zone coordinate.
    pub inter_zone_distribution: LatencyDistribution,
    /// Random number generator passed to the distributions.
    pub rng: StdRng,
}
342
343impl LatencyConfig {
344 fn from_distance(&mut self, distance: &Distance) -> tokio::time::Duration {
345 match distance {
346 Distance::Region => self.inter_region_distribution.sample(&mut self.rng),
347 Distance::DataCenter => self.inter_dc_distribution.sample(&mut self.rng),
348 Distance::Zone => self.inter_zone_distribution.sample(&mut self.rng),
349 Distance::Rack | Distance::Host | Distance::Same => tokio::time::Duration::ZERO,
350 }
351 }
352}
353
354impl Default for LatencyConfig {
355 fn default() -> Self {
356 let seed: u64 = 0000;
357 let mut seed_bytes = [0u8; 32];
358 seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
359
360 Self {
361 inter_region_distribution: LatencyDistribution::Beta(
362 BetaDistribution::new(
363 tokio::time::Duration::from_millis(500),
364 tokio::time::Duration::from_millis(1000),
365 2.0,
366 1.0,
367 )
368 .unwrap(),
369 ),
370 inter_dc_distribution: LatencyDistribution::Beta(
371 BetaDistribution::new(
372 tokio::time::Duration::from_millis(50),
373 tokio::time::Duration::from_millis(100),
374 2.0,
375 1.0,
376 )
377 .unwrap(),
378 ),
379 inter_zone_distribution: LatencyDistribution::Beta(
380 BetaDistribution::new(
381 tokio::time::Duration::from_millis(5),
382 tokio::time::Duration::from_millis(10),
383 2.0,
384 1.0,
385 )
386 .unwrap(),
387 ),
388 rng: StdRng::from_seed(seed_bytes),
389 }
390 }
391}
392
/// Handle through which the rest of the process talks to the simulator task.
pub struct SimNetHandle {
    /// Join handle of the spawned run loop; taken (and awaited) by `close`.
    join_handle: Mutex<Option<JoinHandle<Vec<SimulatorEventRecord>>>>,
    /// Channel to the run loop: `(event, advanceable, optional explicit time)`.
    event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
    /// Incremented when an event is sent, decremented by the run loop right before
    /// the event is handled; polled by `flush`.
    pending_event_count: Arc<AtomicUsize>,
    /// Publishes the training script state to the run loop.
    training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
    /// Set by `close` to make the run loop exit.
    stop_signal: Arc<AtomicBool>,
    /// Proc locations registered through `register_proc`.
    resources: DashMap<ProcId, Point>,
    /// Latency configuration used by `sample_latency`.
    latencies: std::sync::Mutex<LatencyConfig>,
}
406
407impl SimNetHandle {
408 #[allow(clippy::result_large_err)] pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
411 self.send_event_impl(event, true)
412 }
413
414 #[allow(clippy::result_large_err)] fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
416 self.pending_event_count
417 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
418 self.event_tx
419 .send((event, advanceable, None))
420 .map_err(|err| SimNetError::Closed(err.to_string()))
421 }
422
423 #[allow(clippy::result_large_err)] pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
430 self.send_event_impl(event, false)
431 }
432
433 #[allow(clippy::result_large_err)] pub(crate) fn send_scheduled_event(
436 &self,
437 ScheduledEvent { event, time }: ScheduledEvent,
438 ) -> Result<(), SimNetError> {
439 self.pending_event_count
440 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
441 self.event_tx
442 .send((event, true, Some(time)))
443 .map_err(|err| SimNetError::Closed(err.to_string()))
444 }
445
446 pub fn set_training_script_state(&self, state: TrainingScriptState) {
449 self.training_script_state_tx.send(state).unwrap();
450 }
451
452 #[allow(clippy::result_large_err)] pub fn bind(&self, address: ChannelAddr) -> Result<(), SimNetError> {
455 self.pending_event_count
456 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
457 self.event_tx
458 .send((
459 Box::new(NodeJoinEvent {
460 channel_addr: address,
461 }),
462 true,
463 None,
464 ))
465 .map_err(|err| SimNetError::Closed(err.to_string()))
466 }
467
468 pub async fn close(&self) -> Result<Vec<SimulatorEventRecord>, JoinError> {
471 self.stop_signal.store(true, Ordering::SeqCst);
473
474 let mut guard = self.join_handle.lock().await;
475 if let Some(handle) = guard.take() {
476 handle.await
477 } else {
478 Ok(vec![])
479 }
480 }
481
482 pub async fn flush(&self, timeout: Duration) -> Result<(), SimNetError> {
485 let pending_event_count = self.pending_event_count.clone();
486 let mut interval = interval(Duration::from_millis(10));
488 let deadline = RealClock.now() + timeout;
489 while RealClock.now() < deadline {
490 interval.tick().await;
491 if pending_event_count.load(std::sync::atomic::Ordering::SeqCst) == 0 {
492 return Ok(());
493 }
494 }
495 Err(SimNetError::Timeout(
496 timeout,
497 "timeout waiting for received events to be scheduled".to_string(),
498 ))
499 }
500
501 pub fn register_proc(&self, proc_id: ProcId, point: Point) {
503 self.resources.insert(proc_id, point);
504 }
505
506 pub fn sample_latency(&self, src: &ProcId, dest: &ProcId) -> tokio::time::Duration {
508 let distances = [
509 Distance::Region,
510 Distance::DataCenter,
511 Distance::Zone,
512 Distance::Rack,
513 Distance::Host,
514 Distance::Same,
515 ];
516
517 let src_coords = self
518 .resources
519 .get(src)
520 .map(|point| point.coords().clone())
521 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
522
523 let dest_coords = self
524 .resources
525 .get(dest)
526 .map(|point| point.coords().clone())
527 .unwrap_or(distances.iter().map(|_| 0).collect::<Vec<usize>>());
528
529 for ((src, dest), distance) in src_coords.into_iter().zip(dest_coords).zip(distances) {
530 if src != dest {
531 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
532 return guard.from_distance(&distance);
533 }
534 }
535
536 let mut guard = self.latencies.lock().unwrap_or_else(|e| e.into_inner());
537 guard.from_distance(&Distance::Same)
538 }
539}
540
/// Distance class between two procs, listed in the order in which
/// `SimNetHandle::sample_latency` matches them against coordinate positions.
#[derive(Debug)]
enum Distance {
    /// Sampled from the inter-region distribution.
    Region,
    /// Sampled from the inter-data-center distribution.
    DataCenter,
    /// Sampled from the inter-zone distribution.
    Zone,
    /// Zero latency.
    Rack,
    /// Zero latency.
    Host,
    /// Zero latency; also used when no compared coordinate differs.
    Same,
}
550
/// The simulator itself; it is created and driven by the task spawned in
/// `start_with_config`.
pub struct SimNet {
    /// Set of channel addresses.
    // NOTE(review): not read or written in the visible part of this file.
    address_book: DashSet<ChannelAddr>,
    /// Pending events, split into advanceable and non-advanceable.
    state: State,
    /// Initialised to 10 s by `start_with_config`.
    // NOTE(review): not read in the visible part of this file.
    max_latency: Duration,
    /// One record per scheduled event; returned when the run loop exits.
    records: Vec<SimulatorEventRecord>,
    /// Shared with [`SimNetHandle`]; decremented right before an event is handled.
    pending_event_count: Arc<AtomicUsize>,
}
564
565pub fn start() {
567 start_with_config(LatencyConfig::default())
568}
569
570pub fn start_with_config(config: LatencyConfig) {
572 let max_duration_ms = 1000 * 10;
573 let address_book: DashSet<ChannelAddr> = DashSet::new();
575
576 let (training_script_state_tx, training_script_state_rx) =
577 tokio::sync::watch::channel(TrainingScriptState::Waiting);
578 let (event_tx, event_rx) =
579 mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
580 let pending_event_count = Arc::new(AtomicUsize::new(0));
581 let stop_signal = Arc::new(AtomicBool::new(false));
582
583 let join_handle = Mutex::new(Some({
584 let pending_event_count = pending_event_count.clone();
585 let stop_signal = stop_signal.clone();
586
587 tokio::spawn(async move {
588 SimNet {
589 address_book,
590 state: State {
591 scheduled_events: BTreeMap::new(),
592 unadvanceable_scheduled_events: BTreeMap::new(),
593 },
594 max_latency: Duration::from_millis(max_duration_ms),
595 records: Vec::new(),
596 pending_event_count,
597 }
598 .run(event_rx, training_script_state_rx, stop_signal)
599 .await
600 })
601 }));
602
603 HANDLE.get_or_init(|| SimNetHandle {
604 join_handle,
605 event_tx,
606 pending_event_count,
607 training_script_state_tx,
608 stop_signal,
609 resources: DashMap::new(),
610 latencies: std::sync::Mutex::new(config),
611 });
612}
613
impl SimNet {
    /// Wraps `event` in a [`ScheduledEvent`] due at the current simulated time plus
    /// the event's own duration.
    fn create_scheduled_event(&mut self, event: Box<dyn Event>) -> ScheduledEvent {
        ScheduledEvent {
            time: SimClock.now() + event.duration(),
            event,
        }
    }

    /// Records the event and queues it in the advanceable or non-advanceable map.
    ///
    /// The record stores the event summary together with the current simulated time
    /// and the scheduled time, both as whole milliseconds since the start of the
    /// simulated clock.
    fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
        let start_at = SimClock.now();
        let end_at = scheduled_event.time;

        self.records.push(SimulatorEventRecord {
            summary: scheduled_event.event.summary(),
            start_at: SimClock.duration_since_start(start_at).as_millis() as u64,
            end_at: SimClock.duration_since_start(end_at).as_millis() as u64,
        });

        // Events sharing a scheduled instant are kept in arrival order.
        if advanceable {
            self.state
                .scheduled_events
                .entry(scheduled_event.time)
                .or_insert_with(Vec::new)
                .push(scheduled_event);
        } else {
            self.state
                .unadvanceable_scheduled_events
                .entry(scheduled_event.time)
                .or_insert_with(Vec::new)
                .push(scheduled_event);
        }
    }

    /// The simulator's main loop.
    ///
    /// It repeatedly drains incoming events, picks the earliest scheduled batch,
    /// advances the simulated clock to that batch's time and handles its events.
    /// The loop returns a copy of the collected records when `stop_signal` is set
    /// or when handling an event fails.
    ///
    /// * `event_rx` - incoming `(event, advanceable, optional explicit time)` tuples.
    /// * `training_script_state_rx` - current training script state.
    /// * `stop_signal` - checked at the top of every iteration.
    async fn run(
        &mut self,
        mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
        training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
        stop_signal: Arc<AtomicBool>,
    ) -> Vec<SimulatorEventRecord> {
        // Total simulated time advanced while the training script was waiting; it
        // offsets explicitly timed events and the "is this event due" check below.
        let mut training_script_waiting_time = tokio::time::Duration::from_millis(0);
        // Set while only non-advanceable events are pending.
        let mut debounce_timer: Option<tokio::time::Instant> = None;

        // Milliseconds to wait for another incoming event before moving on;
        // overridable through the SIM_DEBOUNCE environment variable, default 1.
        let debounce_duration = std::env::var("SIM_DEBOUNCE")
            .ok()
            .and_then(|val| val.parse::<u64>().ok())
            .unwrap_or(1);

        'outer: loop {
            if stop_signal.load(Ordering::SeqCst) {
                break 'outer self.records.clone();
            }

            // Drain incoming events for as long as each one arrives within the
            // debounce window.
            while let Ok(Some((event, advanceable, time))) = RealClock
                .timeout(
                    tokio::time::Duration::from_millis(debounce_duration),
                    event_rx.recv(),
                )
                .await
            {
                let scheduled_event = match time {
                    // Explicit times are shifted by the accumulated waiting time.
                    Some(time) => ScheduledEvent {
                        time: time + training_script_waiting_time,
                        event,
                    },
                    None => self.create_scheduled_event(event),
                };
                self.schedule_event(scheduled_event, advanceable);
            }

            {
                // While the training script is running, do not advance to an
                // advanceable event that lies beyond the real clock plus the
                // accumulated waiting time; yield and check again.
                if training_script_state_rx.borrow().is_running()
                    && self
                        .state
                        .scheduled_events
                        .first_key_value()
                        .is_some_and(|(time, _)| {
                            *time > RealClock.now() + training_script_waiting_time
                        })
                {
                    tokio::task::yield_now().await;
                    continue;
                }
                // Maintain the debounce timer: started when only non-advanceable
                // events are pending, left alone while that stays true, and cleared
                // in every other situation.
                match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    (None, Some(_)) if debounce_timer.is_none() => {
                        debounce_timer = Some(RealClock.now());
                    }
                    (None, Some(_)) => {}
                    _ => {
                        debounce_timer = None;
                    }
                }
                // Pick the next batch of events sharing the earliest scheduled time.
                let Some((scheduled_time, scheduled_events)) = (match (
                    self.state.scheduled_events.first_key_value(),
                    self.state.unadvanceable_scheduled_events.first_key_value(),
                ) {
                    // Both kinds pending: the earlier one wins; ties go to the
                    // advanceable events.
                    (Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
                        if unadvanceable_time < advanceable_time {
                            self.state.unadvanceable_scheduled_events.pop_first()
                        } else {
                            self.state.scheduled_events.pop_first()
                        }
                    }
                    (Some(_), None) => self.state.scheduled_events.pop_first(),
                    // Only non-advanceable events: release them once the debounce
                    // timer has been running for more than one second.
                    (None, Some(_)) => match debounce_timer {
                        Some(time) => {
                            if time.elapsed() > tokio::time::Duration::from_millis(1000) {
                                debounce_timer = None;
                                self.state.unadvanceable_scheduled_events.pop_first()
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                    (None, None) => None,
                }) else {
                    // Nothing to handle yet: wait for a new event or for 10 ms,
                    // whichever comes first, then start the next iteration.
                    tokio::select! {
                        Some((event, advanceable, time)) = event_rx.recv() => {
                            let scheduled_event = match time {
                                Some(time) => ScheduledEvent {
                                    time: time + training_script_waiting_time,
                                    event,
                                },
                                None => self.create_scheduled_event(event),
                            };
                            self.schedule_event(scheduled_event, advanceable);
                        },
                        _ = RealClock.sleep(Duration::from_millis(10)) => {}
                    }
                    continue;
                };
                // Time advanced while the training script is waiting is accumulated.
                if training_script_state_rx.borrow().is_waiting() {
                    let advanced_time = scheduled_time - SimClock.now();
                    training_script_waiting_time += advanced_time;
                }
                SimClock.advance_to(scheduled_time);
                for mut scheduled_event in scheduled_events {
                    // The event stops counting as pending right before it is handled.
                    self.pending_event_count
                        .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                    // A failing event ends the simulation; the error is not propagated.
                    if scheduled_event.event.handle_network(self).await.is_err() {
                        break 'outer self.records.clone();
                    }
                }
            }
        }
    }
}
782
783fn serialize_optional_channel_addr<S>(
784 addr: &Option<ChannelAddr>,
785 serializer: S,
786) -> Result<S::Ok, S::Error>
787where
788 S: Serializer,
789{
790 match addr {
791 Some(addr) => serializer.serialize_str(&addr.to_string()),
792 None => serializer.serialize_none(),
793 }
794}
795
796fn deserialize_channel_addr<'de, D>(deserializer: D) -> Result<ChannelAddr, D::Error>
797where
798 D: Deserializer<'de>,
799{
800 let s = String::deserialize(deserializer)?;
801 s.parse().map_err(serde::de::Error::custom)
802}
803
/// Record of one event scheduled by the simulator.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SimulatorEventRecord {
    /// The event's summary string.
    pub summary: String,
    /// Simulated time at which the event was scheduled, in milliseconds since the
    /// start of the simulated clock.
    pub start_at: u64,
    /// Simulated time the event was scheduled for, in milliseconds since the start
    /// of the simulated clock.
    pub end_at: u64,
}
814
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;
    use crate::channel::sim::SimAddr;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::data::Serialized;
    use crate::id;
    use crate::simnet;
    use crate::simnet::Dispatcher;
    use crate::simnet::Event;
    use crate::simnet::SimNetError;

    /// Test event that forwards `data` to `dest_addr` through the dispatcher (when
    /// one is set) and takes a configurable simulated duration.
    #[derive(Debug)]
    struct MessageDeliveryEvent {
        src_addr: SimAddr,
        dest_addr: SimAddr,
        data: Serialized,
        duration: tokio::time::Duration,
        dispatcher: Option<TestDispatcher>,
    }

    #[async_trait]
    impl Event for MessageDeliveryEvent {
        async fn handle(&mut self) -> Result<(), simnet::SimNetError> {
            if let Some(dispatcher) = &self.dispatcher {
                dispatcher
                    .send(self.dest_addr.clone(), self.data.clone())
                    .await?;
            }
            Ok(())
        }
        fn duration(&self) -> tokio::time::Duration {
            self.duration
        }

        fn summary(&self) -> String {
            format!(
                "Sending message from {} to {}",
                self.src_addr.addr().clone(),
                self.dest_addr.addr().clone()
            )
        }
    }

    impl MessageDeliveryEvent {
        fn new(
            src_addr: SimAddr,
            dest_addr: SimAddr,
            data: Serialized,
            dispatcher: Option<TestDispatcher>,
            duration: tokio::time::Duration,
        ) -> Self {
            Self {
                src_addr,
                dest_addr,
                data,
                duration,
                dispatcher,
            }
        }
    }

    /// Dispatcher that appends every payload to a per-target buffer so tests can
    /// inspect what was delivered and in which order.
    #[derive(Debug, Clone)]
    struct TestDispatcher {
        pub mbuffers: Arc<Mutex<HashMap<SimAddr, Vec<Serialized>>>>,
    }

    impl Default for TestDispatcher {
        fn default() -> Self {
            Self {
                mbuffers: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl Dispatcher<SimAddr> for TestDispatcher {
        async fn send(&self, target: SimAddr, data: Serialized) -> Result<(), SimNetError> {
            let mut buf = self.mbuffers.lock().await;
            buf.entry(target).or_default().push(data);
            Ok(())
        }
    }

    /// Builds a `unix!@<name>` address from 24 random alphanumeric characters.
    // NOTE(review): not called anywhere in the visible part of this module.
    #[cfg(target_os = "linux")]
    fn random_abstract_addr() -> ChannelAddr {
        use rand::Rng;
        use rand::distributions::Alphanumeric;

        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

    /// The simulator can be started and closed without sending any event.
    #[tokio::test]
    async fn test_handle_instantiation() {
        start();
        simnet_handle().unwrap().close().await.unwrap();
    }

    /// Latency sampling picks the distribution of the first differing coordinate.
    #[tokio::test]
    async fn test_simnet_config() {
        let ext = extent!(region = 1, dc = 1, zone = 1, rack = 4, host = 4, gpu = 8);

        let alice = id!(world[0]);
        let bob = id!(world[1]);
        let charlie = id!(world[2]);

        // min == max makes every sample exactly that duration.
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            inter_dc_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(2000),
                    tokio::time::Duration::from_millis(2000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);

        let handle = simnet_handle().unwrap();
        // bob differs from alice at index 2 (zone), charlie at index 1 (dc).
        handle.register_proc(alice.clone(), ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap());
        handle.register_proc(bob.clone(), ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap());
        handle.register_proc(charlie.clone(), ext.point(vec![0, 1, 0, 0, 0, 0]).unwrap());
        assert_eq!(
            handle.sample_latency(&alice, &bob),
            tokio::time::Duration::from_millis(1000)
        );
        assert_eq!(
            handle.sample_latency(&alice, &charlie),
            tokio::time::Duration::from_millis(2000)
        );
    }

    /// Ten events sent 500 µs apart, each lasting 10 s, all produce a record and the
    /// last one ends at 10 000 ms of simulated time.
    #[tokio::test]
    async fn test_simnet_debounce() {
        let config = LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        };
        start_with_config(config);
        let alice = "local:1".parse::<simnet::ChannelAddr>().unwrap();
        let bob = "local:2".parse::<simnet::ChannelAddr>().unwrap();

        let latency = Duration::from_millis(10000);

        let alice = SimAddr::new(alice).unwrap();
        let bob = SimAddr::new(bob).unwrap();

        for _ in 0..10 {
            simnet_handle()
                .unwrap()
                .send_event(Box::new(MessageDeliveryEvent::new(
                    alice.clone(),
                    bob.clone(),
                    Serialized::serialize(&"123".to_string()).unwrap(),
                    None,
                    latency,
                )))
                .unwrap();
            RealClock
                .sleep(tokio::time::Duration::from_micros(500))
                .await;
        }

        simnet_handle()
            .unwrap()
            .flush(Duration::from_secs(20))
            .await
            .unwrap();

        let records = simnet_handle().unwrap().close().await;
        assert_eq!(records.as_ref().unwrap().len(), 10);

        assert_eq!(
            records.unwrap().last().unwrap().end_at,
            latency.as_millis() as u64
        );
    }

    /// Three zero-duration deliveries reach the dispatcher, grouped by destination
    /// and in send order.
    #[tokio::test]
    async fn test_sim_dispatch() {
        start();
        let sender = Some(TestDispatcher::default());
        let mut addresses: Vec<simnet::ChannelAddr> = Vec::new();
        for i in 0..4 {
            addresses.push(
                format!("local:{}", i)
                    .parse::<simnet::ChannelAddr>()
                    .unwrap(),
            );
        }

        let messages: Vec<Serialized> = vec!["First 0 1", "First 2 3", "Second 0 1"]
            .into_iter()
            .map(|s| Serialized::serialize(&s.to_string()).unwrap())
            .collect();

        let addr_0 = SimAddr::new(addresses[0].clone()).unwrap();
        let addr_1 = SimAddr::new(addresses[1].clone()).unwrap();
        let addr_2 = SimAddr::new(addresses[2].clone()).unwrap();
        let addr_3 = SimAddr::new(addresses[3].clone()).unwrap();
        let one = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[0].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let two = Box::new(MessageDeliveryEvent::new(
            addr_2.clone(),
            addr_3.clone(),
            messages[1].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));
        let three = Box::new(MessageDeliveryEvent::new(
            addr_0.clone(),
            addr_1.clone(),
            messages[2].clone(),
            sender.clone(),
            tokio::time::Duration::ZERO,
        ));

        simnet_handle().unwrap().send_event(one).unwrap();
        simnet_handle().unwrap().send_event(two).unwrap();
        simnet_handle().unwrap().send_event(three).unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await.unwrap();
        eprintln!("Records: {:?}", records);
        // Closing a second time is allowed; the join handle was already taken.
        simnet_handle().unwrap().close().await.unwrap();

        let buf = sender.as_ref().unwrap().mbuffers.lock().await;
        // Two distinct destinations received messages.
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[&addr_1].len(), 2);
        assert_eq!(buf[&addr_3].len(), 1);

        assert_eq!(buf[&addr_1][0], messages[0]);
        assert_eq!(buf[&addr_1][1], messages[2]);
        assert_eq!(buf[&addr_3][0], messages[1]);
    }

    /// Sleeping on the simulated clock advances it by exactly the requested amount.
    #[tokio::test]
    async fn test_sim_sleep() {
        start();

        let start = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(start),
            tokio::time::Duration::ZERO
        );

        SimClock.sleep(tokio::time::Duration::from_secs(10)).await;

        let end = SimClock.now();
        assert_eq!(
            SimClock.duration_since_start(end),
            tokio::time::Duration::from_secs(10)
        );
    }

    /// A torch op event signals its done port and yields a single record spanning
    /// 0 to 100 ms with the expected summary.
    #[tokio::test]
    async fn test_torch_op() {
        start();
        let args_string = "1, 2".to_string();
        let kwargs_string = "a=2".to_string();

        let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
        let (tx, rx) = mailbox.open_once_port::<()>();

        simnet_handle()
            .unwrap()
            .send_event(TorchOpEvent::new(
                "torch.ops.aten.ones.default".to_string(),
                tx.bind(),
                mailbox,
                args_string,
                kwargs_string,
                id!(mesh_0_worker[0].worker_0),
            ))
            .unwrap();

        // Completes once the simulator has handled the event.
        rx.recv().await.unwrap();

        simnet_handle()
            .unwrap()
            .flush(Duration::from_millis(1000))
            .await
            .unwrap();
        let records = simnet_handle().unwrap().close().await;
        let expected_record = SimulatorEventRecord {
            summary:
                "[mesh_0_worker[0].worker_0[0]] Torch Op: torch.ops.aten.ones.default(1, 2, a=2)"
                    .to_string(),
            start_at: 0,
            end_at: 100,
        };
        assert!(records.as_ref().unwrap().len() == 1);
        assert_eq!(records.unwrap().first().unwrap(), &expected_record);
    }
}