hyperactor/channel/
sim.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9// SimTx contains a way to send through the network.
10// SimRx contains a way to receive messages.
11
12//! Local simulated channel implementation.
13use std::any::Any;
// Sending a message adds a delivery event to the simulated network.
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use dashmap::DashMap;
19use regex::Regex;
20
21use super::*;
22use crate::channel;
23use crate::clock::Clock;
24use crate::clock::RealClock;
25use crate::data::Serialized;
26use crate::mailbox::MessageEnvelope;
27use crate::simnet::Dispatcher;
28use crate::simnet::Event;
29use crate::simnet::ScheduledEvent;
30use crate::simnet::SimNetError;
31use crate::simnet::simnet_handle;
32
33lazy_static! {
34    static ref SENDER: SimDispatcher = SimDispatcher::default();
35}
36static SIM_LINK_BUF_SIZE: usize = 256;
37
38/// An address for a simulated channel.
39#[derive(
40    Clone,
41    Debug,
42    PartialEq,
43    Eq,
44    Serialize,
45    Deserialize,
46    Ord,
47    PartialOrd,
48    Hash
49)]
50pub struct SimAddr {
51    src: Option<Box<ChannelAddr>>,
52    /// The address.
53    addr: Box<ChannelAddr>,
54    /// If source is the client
55    client: bool,
56}
57
58impl SimAddr {
59    /// Creates a new SimAddr.
60    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
61    /// Creates a new SimAddr without a source to be served
62    pub fn new(addr: ChannelAddr) -> Result<Self, SimNetError> {
63        Self::new_impl(None, addr, false)
64    }
65
66    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
67    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
68    pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
69        Self::new_impl(Some(Box::new(src)), addr, false)
70    }
71
72    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
73    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
74    fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
75        Self::new_impl(Some(Box::new(src)), addr, true)
76    }
77
78    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
79    fn new_impl(
80        src: Option<Box<ChannelAddr>>,
81        addr: ChannelAddr,
82        client: bool,
83    ) -> Result<Self, SimNetError> {
84        if let ChannelAddr::Sim(_) = &addr {
85            return Err(SimNetError::InvalidArg(format!(
86                "addr cannot be a sim address, found {}",
87                addr
88            )));
89        }
90        Ok(Self {
91            src,
92            addr: Box::new(addr),
93            client,
94        })
95    }
96
97    /// Returns the address.
98    pub fn addr(&self) -> &ChannelAddr {
99        &self.addr
100    }
101
102    /// Returns the source address
103    pub fn src(&self) -> &Option<Box<ChannelAddr>> {
104        &self.src
105    }
106
107    /// The underlying transport we are simulating
108    pub fn transport(&self) -> ChannelTransport {
109        self.addr.transport()
110    }
111}
112
113impl fmt::Display for SimAddr {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match &self.src {
116            None => write!(f, "{}", self.addr),
117            Some(src) => write!(f, "{},{}", src, self.addr),
118        }
119    }
120}
121
122/// Message Event that can be passed around in the simnet.
123#[derive(Debug)]
124pub(crate) struct MessageDeliveryEvent {
125    dest_addr: ChannelAddr,
126    data: Serialized,
127    latency: tokio::time::Duration,
128}
129
130impl MessageDeliveryEvent {
131    /// Creates a new MessageDeliveryEvent.
132    pub fn new(dest_addr: ChannelAddr, data: Serialized, latency: tokio::time::Duration) -> Self {
133        Self {
134            dest_addr,
135            data,
136            latency,
137        }
138    }
139}
140
141#[async_trait]
142impl Event for MessageDeliveryEvent {
143    async fn handle(&mut self) -> Result<(), SimNetError> {
144        // Send the message to the correct receiver.
145        SENDER
146            .send(self.dest_addr.clone(), self.data.clone())
147            .await?;
148        Ok(())
149    }
150
151    fn duration(&self) -> tokio::time::Duration {
152        self.latency
153    }
154
155    fn summary(&self) -> String {
156        format!("Sending message to {}", self.dest_addr.clone())
157    }
158}
159
160/// Bind a channel address to the simnet. It will register the address as a node in simnet,
161/// and configure default latencies between this node and all other existing nodes.
162pub async fn bind(addr: ChannelAddr) -> anyhow::Result<(), SimNetError> {
163    simnet_handle()?.bind(addr)
164}
165
166/// Returns a simulated channel address that is bound to "any" channel address.
167pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr {
168    ChannelAddr::Sim(SimAddr {
169        src: None,
170        addr: Box::new(ChannelAddr::any(transport)),
171        client: false,
172    })
173}
174
175/// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma.
176#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
177pub fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
178    let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| {
179        ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err))
180    })?;
181
182    let result = re.captures(addr_string);
183    if let Some(caps) = result {
184        let parts = caps
185            .iter()
186            .skip(1)
187            .map(|cap| cap.map_or("", |m| m.as_str()))
188            .filter(|m| !m.is_empty())
189            .collect::<Vec<_>>();
190
191        if parts.iter().any(|part| part.starts_with("sim!")) {
192            return Err(ChannelError::InvalidAddress(addr_string.to_string()));
193        }
194
195        match parts.len() {
196            1 => {
197                let addr = parts[0].parse::<ChannelAddr>()?;
198
199                Ok(ChannelAddr::Sim(SimAddr::new(addr)?))
200            }
201            3 => {
202                let src_addr = parts[0].parse::<ChannelAddr>()?;
203                let addr = parts[2].parse::<ChannelAddr>()?;
204                Ok(ChannelAddr::Sim(if parts[0] == "client" {
205                    SimAddr::new_with_client_src(src_addr, addr)
206                } else {
207                    SimAddr::new_with_src(src_addr, addr)
208                }?))
209            }
210            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
211        }
212    } else {
213        Err(ChannelError::InvalidAddress(addr_string.to_string()))
214    }
215}
216
217impl<M: RemoteMessage> Drop for SimRx<M> {
218    fn drop(&mut self) {
219        // Remove the sender from the dispatchers.
220        SENDER.dispatchers.remove(&self.addr);
221    }
222}
223
224/// Primarily used for dispatching messages to the correct sender.
225#[derive(Debug)]
226pub struct SimDispatcher {
227    dispatchers: DashMap<ChannelAddr, mpsc::Sender<Serialized>>,
228    sender_cache: DashMap<ChannelAddr, Arc<dyn Tx<MessageEnvelope> + Send + Sync>>,
229}
230
231fn create_egress_sender(
232    addr: ChannelAddr,
233) -> anyhow::Result<Arc<dyn Tx<MessageEnvelope> + Send + Sync>> {
234    let tx = channel::dial(addr)?;
235    Ok(Arc::new(tx))
236}
237
238#[async_trait]
239impl Dispatcher<ChannelAddr> for SimDispatcher {
240    async fn send(&self, addr: ChannelAddr, data: Serialized) -> Result<(), SimNetError> {
241        self.dispatchers
242            .get(&addr)
243            .ok_or_else(|| {
244                SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found"))
245            })?
246            .send(data)
247            .await
248            .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into()))
249    }
250}
251
252impl Default for SimDispatcher {
253    fn default() -> Self {
254        Self {
255            dispatchers: DashMap::new(),
256            sender_cache: DashMap::new(),
257        }
258    }
259}
260
261#[derive(Debug)]
262pub(crate) struct SimTx<M: RemoteMessage> {
263    src_addr: Option<ChannelAddr>,
264    dst_addr: ChannelAddr,
265    status: watch::Receiver<TxStatus>, // Default impl. Always reports `Active`.
266    client: bool,
267    _phantom: PhantomData<M>,
268}
269
270#[derive(Debug)]
271pub(crate) struct SimRx<M: RemoteMessage> {
272    /// The destination address, not the full SimAddr.
273    addr: ChannelAddr,
274    rx: mpsc::Receiver<Serialized>,
275    _phantom: PhantomData<M>,
276}
277
278#[async_trait]
279impl<M: RemoteMessage + Any> Tx<M> for SimTx<M> {
280    fn try_post(&self, message: M, _return_handle: oneshot::Sender<M>) -> Result<(), SendError<M>> {
281        let data = match Serialized::serialize(&message) {
282            Ok(data) => data,
283            Err(err) => return Err(SendError(err.into(), message)),
284        };
285
286        let envelope = (&message as &dyn Any)
287            .downcast_ref::<MessageEnvelope>()
288            .expect("RemoteMessage should always be a MessageEnvelope");
289
290        let (sender, dest) = (envelope.sender().clone(), envelope.dest().0.clone());
291
292        match simnet_handle() {
293            Ok(handle) => {
294                let event = Box::new(MessageDeliveryEvent::new(
295                    self.dst_addr.clone(),
296                    data,
297                    handle.sample_latency(sender.proc_id(), dest.proc_id()),
298                ));
299
300                match &self.src_addr {
301                    Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent {
302                        event,
303                        time: RealClock.now(),
304                    }),
305                    _ => handle.send_event(event),
306                }
307            }
308            .map_err(|err: SimNetError| SendError(ChannelError::from(err), message)),
309            Err(err) => Err(SendError(ChannelError::from(err), message)),
310        }
311    }
312
313    fn addr(&self) -> ChannelAddr {
314        self.dst_addr.clone()
315    }
316
317    fn status(&self) -> &watch::Receiver<TxStatus> {
318        &self.status
319    }
320}
321
322/// Dial a peer and return a transmitter. The transmitter can retrieve from the
323/// network the link latency.
324#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
325pub(crate) fn dial<M: RemoteMessage>(addr: SimAddr) -> Result<SimTx<M>, ChannelError> {
326    // This watch channel always reports active. The sender is
327    // dropped.
328    let (_, status) = watch::channel(TxStatus::Active);
329    let dialer = addr.src().clone().map(|src| *src);
330
331    Ok(SimTx {
332        src_addr: dialer,
333        dst_addr: addr.addr().clone(),
334        status,
335        client: addr.client,
336        _phantom: PhantomData,
337    })
338}
339
340/// Serve a sim channel. Set up the right simulated sender and receivers
341/// The mpsc tx will be used to dispatch messages when it's time while
342/// the mpsc rx will be used by the above applications to handle received messages
343/// like any other channel.
344/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address.
345pub(crate) fn serve<M: RemoteMessage>(
346    sim_addr: SimAddr,
347) -> anyhow::Result<(ChannelAddr, SimRx<M>)> {
348    let (tx, rx) = mpsc::channel::<Serialized>(SIM_LINK_BUF_SIZE);
349    // Add tx to sender dispatch.
350    SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx);
351    // Return the sender.
352    Ok((
353        ChannelAddr::Sim(sim_addr.clone()),
354        SimRx {
355            addr: *sim_addr.addr.clone(),
356            rx,
357            _phantom: PhantomData,
358        },
359    ))
360}
361
362#[async_trait]
363impl<M: RemoteMessage> Rx<M> for SimRx<M> {
364    async fn recv(&mut self) -> Result<M, ChannelError> {
365        let data = self.rx.recv().await.ok_or(ChannelError::Closed)?;
366        data.deserialized().map_err(ChannelError::from)
367    }
368
369    fn addr(&self) -> ChannelAddr {
370        self.addr.clone()
371    }
372}
373
374#[cfg(test)]
375mod tests {
376    use std::iter::zip;
377
378    use ndslice::extent;
379
380    use super::*;
381    use crate::PortId;
382    use crate::attrs::Attrs;
383    use crate::clock::Clock;
384    use crate::clock::RealClock;
385    use crate::clock::SimClock;
386    use crate::id;
387    use crate::simnet;
388    use crate::simnet::BetaDistribution;
389    use crate::simnet::LatencyConfig;
390    use crate::simnet::LatencyDistribution;
391    use crate::simnet::start;
392    use crate::simnet::start_with_config;
393
394    #[tokio::test]
395    async fn test_sim_basic() {
396        let dst_ok = vec!["tcp:[::1]:1234", "tcp:127.0.0.1:8080", "local:123"];
397        let srcs_ok = vec!["tcp:[::2]:1234", "tcp:127.0.0.2:8080", "local:124"];
398
399        start();
400        let handle = simnet_handle().unwrap();
401
402        // TODO: New NodeAdd event should do this for you..
403        for addr in dst_ok.iter().chain(srcs_ok.iter()) {
404            // Add to network along with its edges.
405            simnet_handle()
406                .unwrap()
407                .bind(addr.parse::<ChannelAddr>().unwrap())
408                .unwrap();
409        }
410        for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) {
411            let dst_addr = SimAddr::new_with_src(
412                src_addr.parse::<ChannelAddr>().unwrap(),
413                dst_addr.parse::<ChannelAddr>().unwrap(),
414            )
415            .unwrap();
416
417            let (_, mut rx) = sim::serve::<MessageEnvelope>(dst_addr.clone()).unwrap();
418            let tx = sim::dial::<MessageEnvelope>(dst_addr).unwrap();
419            let data = Serialized::serialize(&456).unwrap();
420            let sender = id!(world[0].hello);
421            let dest = id!(world[1].hello);
422            let ext = extent!(region = 1, dc = 1, rack = 4, host = 4, gpu = 8);
423            handle.register_proc(
424                sender.proc_id().clone(),
425                ext.point(vec![0, 0, 0, 0, 0]).unwrap(),
426            );
427            handle.register_proc(
428                dest.proc_id().clone(),
429                ext.point(vec![0, 0, 0, 1, 0]).unwrap(),
430            );
431
432            let msg = MessageEnvelope::new(sender, PortId(dest, 0), data.clone(), Attrs::new());
433            tx.try_post(msg, oneshot::channel().0).unwrap();
434            assert_eq!(*rx.recv().await.unwrap().data(), data);
435        }
436
437        let records = sim::simnet_handle().unwrap().close().await.unwrap();
438        eprintln!("records: {:#?}", records);
439    }
440
441    #[tokio::test]
442    async fn test_invalid_sim_addr() {
443        let src = "sim!src";
444        let dst = "sim!dst";
445        let sim_addr = format!("{},{}", src, dst);
446        let result = parse(&sim_addr);
447        assert!(matches!(result, Err(ChannelError::InvalidAddress(_))));
448    }
449
450    #[tokio::test]
451    async fn test_parse_sim_addr() {
452        let sim_addr = "sim!unix:@dst";
453        let result = sim_addr.parse();
454        assert!(result.is_ok());
455        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
456            panic!("Expected a sim address");
457        };
458        assert!(sim_addr.src().is_none());
459        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");
460
461        let sim_addr = "sim!unix:@src,unix:@dst";
462        let result = sim_addr.parse();
463        assert!(result.is_ok());
464        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
465            panic!("Expected a sim address");
466        };
467        assert!(sim_addr.src().is_some());
468        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");
469    }
470
471    #[tokio::test]
472    async fn test_realtime_frontier() {
473        tokio::time::pause();
474        // 1 second of latency
475        start_with_config(LatencyConfig {
476            inter_zone_distribution: LatencyDistribution::Beta(
477                BetaDistribution::new(
478                    tokio::time::Duration::from_millis(100),
479                    tokio::time::Duration::from_millis(100),
480                    1.0,
481                    1.0,
482                )
483                .unwrap(),
484            ),
485            ..Default::default()
486        });
487
488        let sim_addr = SimAddr::new("unix:@dst".parse::<ChannelAddr>().unwrap()).unwrap();
489        let sim_addr_with_src = SimAddr::new_with_src(
490            "unix:@src".parse::<ChannelAddr>().unwrap(),
491            "unix:@dst".parse::<ChannelAddr>().unwrap(),
492        )
493        .unwrap();
494        let (_, mut rx) = sim::serve::<MessageEnvelope>(sim_addr.clone()).unwrap();
495        let tx = sim::dial::<MessageEnvelope>(sim_addr_with_src).unwrap();
496
497        let controller = id!(world[0].controller);
498        let dest = id!(world[1].dest);
499        let handle = simnet::simnet_handle().unwrap();
500
501        let ext = extent!(region = 1, dc = 1, zone = 1, rack = 4, host = 4, gpu = 8);
502        handle.register_proc(
503            controller.proc_id().clone(),
504            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
505        );
506        handle.register_proc(
507            dest.proc_id().clone(),
508            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
509        );
510
511        // This message will be delievered at simulator time = 100 seconds
512        tx.try_post(
513            MessageEnvelope::new(
514                controller,
515                PortId(dest, 0),
516                Serialized::serialize(&456).unwrap(),
517                Attrs::new(),
518            ),
519            oneshot::channel().0,
520        )
521        .unwrap();
522        {
523            // Allow simnet to run
524            tokio::task::yield_now().await;
525            // Messages have not been receive since 10 seconds have not elapsed
526            assert!(rx.rx.try_recv().is_err());
527        }
528        // Advance "real" time by 100 seconds
529        tokio::time::advance(tokio::time::Duration::from_secs(100)).await;
530        {
531            // Allow some time for simnet to run
532            tokio::task::yield_now().await;
533            // Messages are received
534            assert!(rx.rx.try_recv().is_ok());
535        }
536    }
537
538    #[tokio::test]
539    async fn test_client_message_scheduled_realtime() {
540        tokio::time::pause();
541        // 1 second of latency
542        start_with_config(LatencyConfig {
543            inter_zone_distribution: LatencyDistribution::Beta(
544                BetaDistribution::new(
545                    tokio::time::Duration::from_millis(1000),
546                    tokio::time::Duration::from_millis(1000),
547                    1.0,
548                    1.0,
549                )
550                .unwrap(),
551            ),
552            ..Default::default()
553        });
554
555        let controller_to_dst = SimAddr::new_with_src(
556            "unix:@controller".parse::<ChannelAddr>().unwrap(),
557            "unix:@dst".parse::<ChannelAddr>().unwrap(),
558        )
559        .unwrap();
560
561        let controller_tx = sim::dial::<MessageEnvelope>(controller_to_dst.clone()).unwrap();
562
563        let client_to_dst = SimAddr::new_with_client_src(
564            "unix:@client".parse::<ChannelAddr>().unwrap(),
565            "unix:@dst".parse::<ChannelAddr>().unwrap(),
566        )
567        .unwrap();
568        let client_tx = sim::dial::<MessageEnvelope>(client_to_dst).unwrap();
569
570        let controller = id!(world[0].controller);
571        let dest = id!(world[1].dest);
572        let client = id!(world[2].client);
573
574        let handle = simnet::simnet_handle().unwrap();
575        let ext = extent!(region = 1, dc = 1, zone = 1, rack = 4, host = 4, gpu = 8);
576        handle.register_proc(
577            controller.proc_id().clone(),
578            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
579        );
580        handle.register_proc(
581            client.proc_id().clone(),
582            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
583        );
584        handle.register_proc(
585            dest.proc_id().clone(),
586            ext.point(vec![0, 0, 0, 0, 1, 0]).unwrap(),
587        );
588
589        assert_eq!(
590            SimClock.duration_since_start(RealClock.now()),
591            tokio::time::Duration::ZERO
592        );
593        // Fast forward real time to 5 seconds
594        tokio::time::advance(tokio::time::Duration::from_secs(5)).await;
595        {
596            // Send client message
597            client_tx
598                .try_post(
599                    MessageEnvelope::new(
600                        client.clone(),
601                        PortId(dest.clone(), 0),
602                        Serialized::serialize(&456).unwrap(),
603                        Attrs::new(),
604                    ),
605                    oneshot::channel().0,
606                )
607                .unwrap();
608            // Send system message
609            controller_tx
610                .try_post(
611                    MessageEnvelope::new(
612                        controller.clone(),
613                        PortId(dest.clone(), 0),
614                        Serialized::serialize(&456).unwrap(),
615                        Attrs::new(),
616                    ),
617                    oneshot::channel().0,
618                )
619                .unwrap();
620            // Allow some time for simnet to run
621            RealClock.sleep(tokio::time::Duration::from_secs(1)).await;
622        }
623        let recs = simnet::simnet_handle().unwrap().close().await.unwrap();
624        assert_eq!(recs.len(), 2);
625        let end_times = recs.iter().map(|rec| rec.end_at).collect::<Vec<_>>();
626        // client message was delivered at "real" time = 5 seconds
627        assert!(end_times.contains(&5000));
628        // system message was delivered at simulated time = 1 second
629        assert!(end_times.contains(&1000));
630    }
631}