hyperactor/channel/
sim.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
// `SimTx` sends messages through the simulated network.
// `SimRx` receives messages delivered by the simulated network.
11
12//! Local simulated channel implementation.
13use std::any::Any;
// Posting on a `SimTx` adds a message delivery event to the simulated network.
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use dashmap::DashMap;
19use regex::Regex;
20
21use super::*;
22use crate::channel;
23use crate::clock::Clock;
24use crate::clock::RealClock;
25use crate::mailbox::MessageEnvelope;
26use crate::simnet::Dispatcher;
27use crate::simnet::Event;
28use crate::simnet::ScheduledEvent;
29use crate::simnet::SimNetError;
30use crate::simnet::simnet_handle;
31
32lazy_static! {
33    static ref SENDER: SimDispatcher = SimDispatcher::default();
34}
35static SIM_LINK_BUF_SIZE: usize = 256;
36
37/// An address for a simulated channel.
38#[derive(
39    Clone,
40    Debug,
41    PartialEq,
42    Eq,
43    Serialize,
44    Deserialize,
45    Ord,
46    PartialOrd,
47    Hash
48)]
49pub struct SimAddr {
50    src: Option<Box<ChannelAddr>>,
51    /// The address.
52    addr: Box<ChannelAddr>,
53    /// If source is the client
54    client: bool,
55}
56
57impl SimAddr {
58    /// Creates a new SimAddr.
59    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
60    /// Creates a new SimAddr without a source to be served
61    pub fn new(addr: ChannelAddr) -> Result<Self, SimNetError> {
62        Self::new_impl(None, addr, false)
63    }
64
65    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
66    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
67    pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
68        Self::new_impl(Some(Box::new(src)), addr, false)
69    }
70
71    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
72    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
73    fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
74        Self::new_impl(Some(Box::new(src)), addr, true)
75    }
76
77    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
78    fn new_impl(
79        src: Option<Box<ChannelAddr>>,
80        addr: ChannelAddr,
81        client: bool,
82    ) -> Result<Self, SimNetError> {
83        if let ChannelAddr::Sim(_) = &addr {
84            return Err(SimNetError::InvalidArg(format!(
85                "addr cannot be a sim address, found {}",
86                addr
87            )));
88        }
89        Ok(Self {
90            src,
91            addr: Box::new(addr),
92            client,
93        })
94    }
95
96    /// Returns the address.
97    pub fn addr(&self) -> &ChannelAddr {
98        &self.addr
99    }
100
101    /// Returns the source address
102    pub fn src(&self) -> &Option<Box<ChannelAddr>> {
103        &self.src
104    }
105
106    /// The underlying transport we are simulating
107    pub fn transport(&self) -> ChannelTransport {
108        self.addr.transport()
109    }
110}
111
112impl fmt::Display for SimAddr {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match &self.src {
115            None => write!(f, "{}", self.addr),
116            Some(src) => write!(f, "{},{}", src, self.addr),
117        }
118    }
119}
120
121/// Message Event that can be passed around in the simnet.
122#[derive(Debug)]
123pub(crate) struct MessageDeliveryEvent {
124    dest_addr: ChannelAddr,
125    data: wirevalue::Any,
126    latency: tokio::time::Duration,
127}
128
129impl MessageDeliveryEvent {
130    /// Creates a new MessageDeliveryEvent.
131    pub fn new(
132        dest_addr: ChannelAddr,
133        data: wirevalue::Any,
134        latency: tokio::time::Duration,
135    ) -> Self {
136        Self {
137            dest_addr,
138            data,
139            latency,
140        }
141    }
142}
143
144#[async_trait]
145impl Event for MessageDeliveryEvent {
146    async fn handle(&self) -> Result<(), SimNetError> {
147        // Send the message to the correct receiver.
148        SENDER
149            .send(self.dest_addr.clone(), self.data.clone())
150            .await?;
151        Ok(())
152    }
153
154    fn duration(&self) -> tokio::time::Duration {
155        self.latency
156    }
157
158    fn summary(&self) -> String {
159        format!("Sending message to {}", self.dest_addr.clone())
160    }
161}
162
163/// Bind a channel address to the simnet. It will register the address as a node in simnet,
164/// and configure default latencies between this node and all other existing nodes.
165pub async fn bind(addr: ChannelAddr) -> anyhow::Result<(), SimNetError> {
166    simnet_handle()?.bind(addr)
167}
168
169/// Returns a simulated channel address that is bound to "any" channel address.
170pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr {
171    ChannelAddr::Sim(SimAddr {
172        src: None,
173        addr: Box::new(ChannelAddr::any(transport)),
174        client: false,
175    })
176}
177
178/// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma.
179#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
180pub fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
181    let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| {
182        ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err))
183    })?;
184
185    let result = re.captures(addr_string);
186    if let Some(caps) = result {
187        let parts = caps
188            .iter()
189            .skip(1)
190            .map(|cap| cap.map_or("", |m| m.as_str()))
191            .filter(|m| !m.is_empty())
192            .collect::<Vec<_>>();
193
194        if parts.iter().any(|part| part.starts_with("sim!")) {
195            return Err(ChannelError::InvalidAddress(addr_string.to_string()));
196        }
197
198        match parts.len() {
199            1 => {
200                let addr = parts[0].parse::<ChannelAddr>()?;
201
202                Ok(ChannelAddr::Sim(SimAddr::new(addr)?))
203            }
204            3 => {
205                let src_addr = parts[0].parse::<ChannelAddr>()?;
206                let addr = parts[2].parse::<ChannelAddr>()?;
207                Ok(ChannelAddr::Sim(if parts[0] == "client" {
208                    SimAddr::new_with_client_src(src_addr, addr)
209                } else {
210                    SimAddr::new_with_src(src_addr, addr)
211                }?))
212            }
213            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
214        }
215    } else {
216        Err(ChannelError::InvalidAddress(addr_string.to_string()))
217    }
218}
219
220impl<M: RemoteMessage> Drop for SimRx<M> {
221    fn drop(&mut self) {
222        // Remove the sender from the dispatchers.
223        SENDER.dispatchers.remove(&self.addr);
224    }
225}
226
227/// Primarily used for dispatching messages to the correct sender.
228pub struct SimDispatcher {
229    dispatchers: DashMap<ChannelAddr, mpsc::Sender<wirevalue::Any>>,
230    sender_cache: DashMap<ChannelAddr, Arc<dyn Tx<MessageEnvelope> + Send + Sync>>,
231}
232
233fn create_egress_sender(
234    addr: ChannelAddr,
235) -> anyhow::Result<Arc<dyn Tx<MessageEnvelope> + Send + Sync>> {
236    let tx = channel::dial(addr)?;
237    Ok(Arc::new(tx))
238}
239
240#[async_trait]
241impl Dispatcher<ChannelAddr> for SimDispatcher {
242    async fn send(&self, addr: ChannelAddr, data: wirevalue::Any) -> Result<(), SimNetError> {
243        self.dispatchers
244            .get(&addr)
245            .ok_or_else(|| {
246                SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found"))
247            })?
248            .send(data)
249            .await
250            .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into()))
251    }
252}
253
254impl Default for SimDispatcher {
255    fn default() -> Self {
256        Self {
257            dispatchers: DashMap::new(),
258            sender_cache: DashMap::new(),
259        }
260    }
261}
262
263pub(crate) struct SimTx<M: RemoteMessage> {
264    src_addr: Option<ChannelAddr>,
265    dst_addr: ChannelAddr,
266    status: watch::Receiver<TxStatus>, // Default impl. Always reports `Active`.
267    client: bool,
268    _phantom: PhantomData<M>,
269}
270
271pub(crate) struct SimRx<M: RemoteMessage> {
272    /// The destination address, not the full SimAddr.
273    addr: ChannelAddr,
274    rx: mpsc::Receiver<wirevalue::Any>,
275    _phantom: PhantomData<M>,
276}
277
278#[async_trait]
279impl<M: RemoteMessage + Any> Tx<M> for SimTx<M> {
280    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
281        let data = match wirevalue::Any::serialize(&message) {
282            Ok(data) => data,
283            Err(err) => {
284                if let Some(return_channel) = return_channel {
285                    return_channel
286                        .send(SendError {
287                            error: err.into(),
288                            message,
289                            reason: None,
290                        })
291                        .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
292                }
293
294                return;
295            }
296        };
297
298        let envelope = (&message as &dyn Any)
299            .downcast_ref::<MessageEnvelope>()
300            .expect("RemoteMessage should always be a MessageEnvelope");
301
302        let (sender, dest) = (envelope.sender().clone(), envelope.dest().0.clone());
303
304        match simnet_handle() {
305            Ok(handle) => {
306                let event = Box::new(MessageDeliveryEvent::new(
307                    self.dst_addr.clone(),
308                    data,
309                    handle.sample_latency(sender.proc_id(), dest.proc_id()),
310                ));
311
312                let result = match &self.src_addr {
313                    Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent {
314                        event,
315                        time: RealClock.now(),
316                    }),
317                    _ => handle.send_event(event),
318                };
319                if let Err(err) = result {
320                    if let Some(return_channel) = return_channel {
321                        return_channel
322                            .send(SendError {
323                                error: err.into(),
324                                message,
325                                reason: None,
326                            })
327                            .unwrap_or_else(|m| {
328                                tracing::warn!("failed to deliver SendError: {}", m)
329                            });
330                    }
331                }
332            }
333            Err(err) => {
334                if let Some(return_channel) = return_channel {
335                    return_channel
336                        .send(SendError {
337                            error: err.into(),
338                            message,
339                            reason: None,
340                        })
341                        .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
342                }
343            }
344        }
345    }
346
347    fn addr(&self) -> ChannelAddr {
348        self.dst_addr.clone()
349    }
350
351    fn status(&self) -> &watch::Receiver<TxStatus> {
352        &self.status
353    }
354}
355
356/// Dial a peer and return a transmitter. The transmitter can retrieve from the
357/// network the link latency.
358#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
359pub(crate) fn dial<M: RemoteMessage>(addr: SimAddr) -> Result<SimTx<M>, ChannelError> {
360    // This watch channel always reports active. The sender is
361    // dropped.
362    let (_, status) = watch::channel(TxStatus::Active);
363    let dialer = addr.src().clone().map(|src| *src);
364
365    Ok(SimTx {
366        src_addr: dialer,
367        dst_addr: addr.addr().clone(),
368        status,
369        client: addr.client,
370        _phantom: PhantomData,
371    })
372}
373
374/// Serve a sim channel. Set up the right simulated sender and receivers
375/// The mpsc tx will be used to dispatch messages when it's time while
376/// the mpsc rx will be used by the above applications to handle received messages
377/// like any other channel.
378/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address.
379pub(crate) fn serve<M: RemoteMessage>(
380    sim_addr: SimAddr,
381) -> anyhow::Result<(ChannelAddr, SimRx<M>)> {
382    let (tx, rx) = mpsc::channel::<wirevalue::Any>(SIM_LINK_BUF_SIZE);
383    // Add tx to sender dispatch.
384    SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx);
385    // Return the sender.
386    Ok((
387        ChannelAddr::Sim(sim_addr.clone()),
388        SimRx {
389            addr: *sim_addr.addr.clone(),
390            rx,
391            _phantom: PhantomData,
392        },
393    ))
394}
395
396#[async_trait]
397impl<M: RemoteMessage> Rx<M> for SimRx<M> {
398    async fn recv(&mut self) -> Result<M, ChannelError> {
399        let data = self.rx.recv().await.ok_or(ChannelError::Closed)?;
400        data.deserialized().map_err(ChannelError::from)
401    }
402
403    fn addr(&self) -> ChannelAddr {
404        self.addr.clone()
405    }
406}
407
// Tests for the simulated channel: address parsing, serve/dial round trips,
// and latency-driven delivery timing.
#[cfg(test)]
mod tests {
    use std::iter::zip;

    use hyperactor_config::attrs::Attrs;
    use ndslice::extent;

    use super::*;
    use crate::PortId;
    use crate::clock::Clock;
    use crate::clock::RealClock;
    use crate::clock::SimClock;
    use crate::id;
    use crate::simnet;
    use crate::simnet::BetaDistribution;
    use crate::simnet::LatencyConfig;
    use crate::simnet::LatencyDistribution;
    use crate::simnet::start;
    use crate::simnet::start_with_config;

    // Round-trips one message per (src, dst) pair through serve/dial.
    #[tokio::test]
    async fn test_sim_basic() {
        let dst_ok = vec!["tcp:[::1]:1234", "tcp:127.0.0.1:8080", "local:123"];
        let srcs_ok = vec!["tcp:[::2]:1234", "tcp:127.0.0.2:8080", "local:124"];

        start();
        let handle = simnet_handle().unwrap();

        // TODO: New NodeAdd event should do this for you.
        for addr in dst_ok.iter().chain(srcs_ok.iter()) {
            // Add to network along with its edges.
            simnet_handle()
                .unwrap()
                .bind(addr.parse::<ChannelAddr>().unwrap())
                .unwrap();
        }
        for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) {
            let dst_addr = SimAddr::new_with_src(
                src_addr.parse::<ChannelAddr>().unwrap(),
                dst_addr.parse::<ChannelAddr>().unwrap(),
            )
            .unwrap();

            let (_, mut rx) = sim::serve::<MessageEnvelope>(dst_addr.clone()).unwrap();
            let tx = sim::dial::<MessageEnvelope>(dst_addr).unwrap();
            let data = wirevalue::Any::serialize(&456).unwrap();
            let sender = id!(world[0].hello);
            let dest = id!(world[1].hello);
            let ext = extent!(region = 1, dc = 1, rack = 4, host = 4, gpu = 8);
            handle.register_proc(
                sender.proc_id().clone(),
                ext.point(vec![0, 0, 0, 0, 0]).unwrap(),
            );
            handle.register_proc(
                dest.proc_id().clone(),
                ext.point(vec![0, 0, 0, 1, 0]).unwrap(),
            );

            let msg = MessageEnvelope::new(sender, PortId(dest, 0), data.clone(), Attrs::new());
            tx.post(msg);
            assert_eq!(*rx.recv().await.unwrap().data(), data);
        }

        let records = sim::simnet_handle().unwrap().close().await.unwrap();
        eprintln!("records: {:#?}", records);
    }

    #[tokio::test]
    async fn test_invalid_sim_addr() {
        // Both components are themselves sim addresses, which `parse` rejects.
        let src = "sim!src";
        let dst = "sim!dst";
        let sim_addr = format!("{},{}", src, dst);
        let result = parse(&sim_addr);
        assert!(matches!(result, Err(ChannelError::InvalidAddress(_))));
    }

    #[tokio::test]
    async fn test_parse_sim_addr() {
        // Destination only: no source.
        let sim_addr = "sim!unix:@dst";
        let result = sim_addr.parse();
        assert!(result.is_ok());
        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
            panic!("Expected a sim address");
        };
        assert!(sim_addr.src().is_none());
        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");

        // Source and destination separated by a comma.
        let sim_addr = "sim!unix:@src,unix:@dst";
        let result = sim_addr.parse();
        assert!(result.is_ok());
        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
            panic!("Expected a sim address");
        };
        assert!(sim_addr.src().is_some());
        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");
    }

    #[tokio::test]
    async fn test_realtime_frontier() {
        tokio::time::pause();
        // Inter-zone latency: a Beta distribution whose two duration parameters
        // are both 100ms.
        start_with_config(LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(100),
                    tokio::time::Duration::from_millis(100),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        });

        // Serve the bare destination, and dial it from `unix:@src` (not a client).
        let sim_addr = SimAddr::new("unix:@dst".parse::<ChannelAddr>().unwrap()).unwrap();
        let sim_addr_with_src = SimAddr::new_with_src(
            "unix:@src".parse::<ChannelAddr>().unwrap(),
            "unix:@dst".parse::<ChannelAddr>().unwrap(),
        )
        .unwrap();
        let (_, mut rx) = sim::serve::<MessageEnvelope>(sim_addr.clone()).unwrap();
        let tx = sim::dial::<MessageEnvelope>(sim_addr_with_src).unwrap();

        let controller = id!(world[0].controller);
        let dest = id!(world[1].dest);
        let handle = simnet::simnet_handle().unwrap();

        let ext = extent!(region = 1, dc = 1, zone = 2, rack = 4, host = 4, gpu = 8);
        handle.register_proc(
            controller.proc_id().clone(),
            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
        );
        handle.register_proc(
            dest.proc_id().clone(),
            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
        );

        // The controller (zone 1) and dest (zone 0) are in different zones, so
        // this message is delayed by the inter-zone latency configured above.
        tx.post(MessageEnvelope::new(
            controller,
            PortId(dest, 0),
            wirevalue::Any::serialize(&456).unwrap(),
            Attrs::new(),
        ));
        {
            // Allow simnet to run
            tokio::task::yield_now().await;
            // The message has not been received yet: no (paused) time has elapsed.
            assert!(rx.rx.try_recv().is_err());
        }
        // Advance "real" time by 100 seconds, well past the configured latency.
        tokio::time::advance(tokio::time::Duration::from_secs(100)).await;
        {
            // Allow some time for simnet to run
            tokio::task::yield_now().await;
            // The message has now been received.
            assert!(rx.rx.try_recv().is_ok());
        }
    }

    #[tokio::test]
    async fn test_client_message_scheduled_realtime() {
        tokio::time::pause();
        // Inter-zone latency: a Beta distribution whose two duration parameters
        // are both 1000ms (1 second).
        start_with_config(LatencyConfig {
            inter_zone_distribution: LatencyDistribution::Beta(
                BetaDistribution::new(
                    tokio::time::Duration::from_millis(1000),
                    tokio::time::Duration::from_millis(1000),
                    1.0,
                    1.0,
                )
                .unwrap(),
            ),
            ..Default::default()
        });

        // Note: no receiver is served for `unix:@dst` in this test; only the
        // simnet records returned by `close` are checked.
        let controller_to_dst = SimAddr::new_with_src(
            "unix:@controller".parse::<ChannelAddr>().unwrap(),
            "unix:@dst".parse::<ChannelAddr>().unwrap(),
        )
        .unwrap();

        let controller_tx = sim::dial::<MessageEnvelope>(controller_to_dst.clone()).unwrap();

        let client_to_dst = SimAddr::new_with_client_src(
            "unix:@client".parse::<ChannelAddr>().unwrap(),
            "unix:@dst".parse::<ChannelAddr>().unwrap(),
        )
        .unwrap();
        let client_tx = sim::dial::<MessageEnvelope>(client_to_dst).unwrap();

        let controller = id!(world[0].controller);
        let dest = id!(world[1].dest);
        let client = id!(world[2].client);

        // The controller is in zone 1; the client and dest are both in zone 0.
        let handle = simnet::simnet_handle().unwrap();
        let ext = extent!(region = 1, dc = 1, zone = 2, rack = 4, host = 4, gpu = 8);
        handle.register_proc(
            controller.proc_id().clone(),
            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
        );
        handle.register_proc(
            client.proc_id().clone(),
            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
        );
        handle.register_proc(
            dest.proc_id().clone(),
            ext.point(vec![0, 0, 0, 0, 1, 0]).unwrap(),
        );

        // Sanity check: no time has elapsed yet relative to the simulation start.
        assert_eq!(
            SimClock.duration_since_start(RealClock.now()),
            tokio::time::Duration::ZERO
        );
        // Fast forward real time to 5 seconds
        tokio::time::advance(tokio::time::Duration::from_secs(5)).await;
        {
            // Send client message
            client_tx.post(MessageEnvelope::new(
                client.clone(),
                PortId(dest.clone(), 0),
                wirevalue::Any::serialize(&456).unwrap(),
                Attrs::new(),
            ));
            // Send system message
            controller_tx.post(MessageEnvelope::new(
                controller.clone(),
                PortId(dest.clone(), 0),
                wirevalue::Any::serialize(&456).unwrap(),
                Attrs::new(),
            ));
            // Allow some time for simnet to run
            RealClock.sleep(tokio::time::Duration::from_secs(1)).await;
        }
        let recs = simnet::simnet_handle().unwrap().close().await.unwrap();
        assert_eq!(recs.len(), 2);
        let end_times = recs.iter().map(|rec| rec.end_at).collect::<Vec<_>>();
        // client message was delivered at "real" time = 5 seconds
        assert!(end_times.contains(&5000));
        // system message was delivered at simulated time = 1 second (the
        // configured inter-zone latency between zone 1 and zone 0)
        assert!(end_times.contains(&1000));
    }
}