hyperactor/channel/
sim.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9// SimTx contains a way to send through the network.
10// SimRx contains a way to receive messages.
11
12//! Local simulated channel implementation.
13use std::any::Any;
// Posting on a `SimTx` adds a message delivery event to the simulated network.
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use dashmap::DashMap;
19use regex::Regex;
20
21use super::*;
22use crate::channel;
23use crate::clock::Clock;
24use crate::clock::RealClock;
25use crate::data::Serialized;
26use crate::mailbox::MessageEnvelope;
27use crate::simnet::Dispatcher;
28use crate::simnet::Event;
29use crate::simnet::ScheduledEvent;
30use crate::simnet::SimNetError;
31use crate::simnet::simnet_handle;
32
33lazy_static! {
34    static ref SENDER: SimDispatcher = SimDispatcher::default();
35}
36static SIM_LINK_BUF_SIZE: usize = 256;
37
38/// An address for a simulated channel.
39#[derive(
40    Clone,
41    Debug,
42    PartialEq,
43    Eq,
44    Serialize,
45    Deserialize,
46    Ord,
47    PartialOrd,
48    Hash
49)]
50pub struct SimAddr {
51    src: Option<Box<ChannelAddr>>,
52    /// The address.
53    addr: Box<ChannelAddr>,
54    /// If source is the client
55    client: bool,
56}
57
58impl SimAddr {
59    /// Creates a new SimAddr.
60    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
61    /// Creates a new SimAddr without a source to be served
62    pub fn new(addr: ChannelAddr) -> Result<Self, SimNetError> {
63        Self::new_impl(None, addr, false)
64    }
65
66    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
67    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
68    pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
69        Self::new_impl(Some(Box::new(src)), addr, false)
70    }
71
72    /// Creates a new directional SimAddr meant to convey a channel between two addresses.
73    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
74    fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result<Self, SimNetError> {
75        Self::new_impl(Some(Box::new(src)), addr, true)
76    }
77
78    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
79    fn new_impl(
80        src: Option<Box<ChannelAddr>>,
81        addr: ChannelAddr,
82        client: bool,
83    ) -> Result<Self, SimNetError> {
84        if let ChannelAddr::Sim(_) = &addr {
85            return Err(SimNetError::InvalidArg(format!(
86                "addr cannot be a sim address, found {}",
87                addr
88            )));
89        }
90        Ok(Self {
91            src,
92            addr: Box::new(addr),
93            client,
94        })
95    }
96
97    /// Returns the address.
98    pub fn addr(&self) -> &ChannelAddr {
99        &self.addr
100    }
101
102    /// Returns the source address
103    pub fn src(&self) -> &Option<Box<ChannelAddr>> {
104        &self.src
105    }
106
107    /// The underlying transport we are simulating
108    pub fn transport(&self) -> ChannelTransport {
109        self.addr.transport()
110    }
111}
112
113impl fmt::Display for SimAddr {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match &self.src {
116            None => write!(f, "{}", self.addr),
117            Some(src) => write!(f, "{},{}", src, self.addr),
118        }
119    }
120}
121
122/// Message Event that can be passed around in the simnet.
123#[derive(Debug)]
124pub(crate) struct MessageDeliveryEvent {
125    dest_addr: ChannelAddr,
126    data: Serialized,
127    latency: tokio::time::Duration,
128}
129
130impl MessageDeliveryEvent {
131    /// Creates a new MessageDeliveryEvent.
132    pub fn new(dest_addr: ChannelAddr, data: Serialized, latency: tokio::time::Duration) -> Self {
133        Self {
134            dest_addr,
135            data,
136            latency,
137        }
138    }
139}
140
141#[async_trait]
142impl Event for MessageDeliveryEvent {
143    async fn handle(&self) -> Result<(), SimNetError> {
144        // Send the message to the correct receiver.
145        SENDER
146            .send(self.dest_addr.clone(), self.data.clone())
147            .await?;
148        Ok(())
149    }
150
151    fn duration(&self) -> tokio::time::Duration {
152        self.latency
153    }
154
155    fn summary(&self) -> String {
156        format!("Sending message to {}", self.dest_addr.clone())
157    }
158}
159
160/// Bind a channel address to the simnet. It will register the address as a node in simnet,
161/// and configure default latencies between this node and all other existing nodes.
162pub async fn bind(addr: ChannelAddr) -> anyhow::Result<(), SimNetError> {
163    simnet_handle()?.bind(addr)
164}
165
166/// Returns a simulated channel address that is bound to "any" channel address.
167pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr {
168    ChannelAddr::Sim(SimAddr {
169        src: None,
170        addr: Box::new(ChannelAddr::any(transport)),
171        client: false,
172    })
173}
174
175/// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma.
176#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
177pub fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
178    let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| {
179        ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err))
180    })?;
181
182    let result = re.captures(addr_string);
183    if let Some(caps) = result {
184        let parts = caps
185            .iter()
186            .skip(1)
187            .map(|cap| cap.map_or("", |m| m.as_str()))
188            .filter(|m| !m.is_empty())
189            .collect::<Vec<_>>();
190
191        if parts.iter().any(|part| part.starts_with("sim!")) {
192            return Err(ChannelError::InvalidAddress(addr_string.to_string()));
193        }
194
195        match parts.len() {
196            1 => {
197                let addr = parts[0].parse::<ChannelAddr>()?;
198
199                Ok(ChannelAddr::Sim(SimAddr::new(addr)?))
200            }
201            3 => {
202                let src_addr = parts[0].parse::<ChannelAddr>()?;
203                let addr = parts[2].parse::<ChannelAddr>()?;
204                Ok(ChannelAddr::Sim(if parts[0] == "client" {
205                    SimAddr::new_with_client_src(src_addr, addr)
206                } else {
207                    SimAddr::new_with_src(src_addr, addr)
208                }?))
209            }
210            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
211        }
212    } else {
213        Err(ChannelError::InvalidAddress(addr_string.to_string()))
214    }
215}
216
217impl<M: RemoteMessage> Drop for SimRx<M> {
218    fn drop(&mut self) {
219        // Remove the sender from the dispatchers.
220        SENDER.dispatchers.remove(&self.addr);
221    }
222}
223
224/// Primarily used for dispatching messages to the correct sender.
225#[derive(Debug)]
226pub struct SimDispatcher {
227    dispatchers: DashMap<ChannelAddr, mpsc::Sender<Serialized>>,
228    sender_cache: DashMap<ChannelAddr, Arc<dyn Tx<MessageEnvelope> + Send + Sync>>,
229}
230
231fn create_egress_sender(
232    addr: ChannelAddr,
233) -> anyhow::Result<Arc<dyn Tx<MessageEnvelope> + Send + Sync>> {
234    let tx = channel::dial(addr)?;
235    Ok(Arc::new(tx))
236}
237
238#[async_trait]
239impl Dispatcher<ChannelAddr> for SimDispatcher {
240    async fn send(&self, addr: ChannelAddr, data: Serialized) -> Result<(), SimNetError> {
241        self.dispatchers
242            .get(&addr)
243            .ok_or_else(|| {
244                SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found"))
245            })?
246            .send(data)
247            .await
248            .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into()))
249    }
250}
251
252impl Default for SimDispatcher {
253    fn default() -> Self {
254        Self {
255            dispatchers: DashMap::new(),
256            sender_cache: DashMap::new(),
257        }
258    }
259}
260
261#[derive(Debug)]
262pub(crate) struct SimTx<M: RemoteMessage> {
263    src_addr: Option<ChannelAddr>,
264    dst_addr: ChannelAddr,
265    status: watch::Receiver<TxStatus>, // Default impl. Always reports `Active`.
266    client: bool,
267    _phantom: PhantomData<M>,
268}
269
270#[derive(Debug)]
271pub(crate) struct SimRx<M: RemoteMessage> {
272    /// The destination address, not the full SimAddr.
273    addr: ChannelAddr,
274    rx: mpsc::Receiver<Serialized>,
275    _phantom: PhantomData<M>,
276}
277
278#[async_trait]
279impl<M: RemoteMessage + Any> Tx<M> for SimTx<M> {
280    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
281        let data = match Serialized::serialize(&message) {
282            Ok(data) => data,
283            Err(err) => {
284                if let Some(return_channel) = return_channel {
285                    return_channel
286                        .send(SendError(err.into(), message))
287                        .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
288                }
289
290                return;
291            }
292        };
293
294        let envelope = (&message as &dyn Any)
295            .downcast_ref::<MessageEnvelope>()
296            .expect("RemoteMessage should always be a MessageEnvelope");
297
298        let (sender, dest) = (envelope.sender().clone(), envelope.dest().0.clone());
299
300        match simnet_handle() {
301            Ok(handle) => {
302                let event = Box::new(MessageDeliveryEvent::new(
303                    self.dst_addr.clone(),
304                    data,
305                    handle.sample_latency(sender.proc_id(), dest.proc_id()),
306                ));
307
308                let result = match &self.src_addr {
309                    Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent {
310                        event,
311                        time: RealClock.now(),
312                    }),
313                    _ => handle.send_event(event),
314                };
315                if let Err(err) = result {
316                    if let Some(return_channel) = return_channel {
317                        return_channel
318                            .send(SendError(err.into(), message))
319                            .unwrap_or_else(|m| {
320                                tracing::warn!("failed to deliver SendError: {}", m)
321                            });
322                    }
323                }
324            }
325            Err(err) => {
326                if let Some(return_channel) = return_channel {
327                    return_channel
328                        .send(SendError(err.into(), message))
329                        .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
330                }
331            }
332        }
333    }
334
335    fn addr(&self) -> ChannelAddr {
336        self.dst_addr.clone()
337    }
338
339    fn status(&self) -> &watch::Receiver<TxStatus> {
340        &self.status
341    }
342}
343
344/// Dial a peer and return a transmitter. The transmitter can retrieve from the
345/// network the link latency.
346#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
347pub(crate) fn dial<M: RemoteMessage>(addr: SimAddr) -> Result<SimTx<M>, ChannelError> {
348    // This watch channel always reports active. The sender is
349    // dropped.
350    let (_, status) = watch::channel(TxStatus::Active);
351    let dialer = addr.src().clone().map(|src| *src);
352
353    Ok(SimTx {
354        src_addr: dialer,
355        dst_addr: addr.addr().clone(),
356        status,
357        client: addr.client,
358        _phantom: PhantomData,
359    })
360}
361
362/// Serve a sim channel. Set up the right simulated sender and receivers
363/// The mpsc tx will be used to dispatch messages when it's time while
364/// the mpsc rx will be used by the above applications to handle received messages
365/// like any other channel.
366/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address.
367pub(crate) fn serve<M: RemoteMessage>(
368    sim_addr: SimAddr,
369) -> anyhow::Result<(ChannelAddr, SimRx<M>)> {
370    let (tx, rx) = mpsc::channel::<Serialized>(SIM_LINK_BUF_SIZE);
371    // Add tx to sender dispatch.
372    SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx);
373    // Return the sender.
374    Ok((
375        ChannelAddr::Sim(sim_addr.clone()),
376        SimRx {
377            addr: *sim_addr.addr.clone(),
378            rx,
379            _phantom: PhantomData,
380        },
381    ))
382}
383
384#[async_trait]
385impl<M: RemoteMessage> Rx<M> for SimRx<M> {
386    async fn recv(&mut self) -> Result<M, ChannelError> {
387        let data = self.rx.recv().await.ok_or(ChannelError::Closed)?;
388        data.deserialized().map_err(ChannelError::from)
389    }
390
391    fn addr(&self) -> ChannelAddr {
392        self.addr.clone()
393    }
394}
395
396#[cfg(test)]
397mod tests {
398    use std::iter::zip;
399
400    use ndslice::extent;
401
402    use super::*;
403    use crate::PortId;
404    use crate::attrs::Attrs;
405    use crate::clock::Clock;
406    use crate::clock::RealClock;
407    use crate::clock::SimClock;
408    use crate::id;
409    use crate::simnet;
410    use crate::simnet::BetaDistribution;
411    use crate::simnet::LatencyConfig;
412    use crate::simnet::LatencyDistribution;
413    use crate::simnet::start;
414    use crate::simnet::start_with_config;
415
416    #[tokio::test]
417    async fn test_sim_basic() {
418        let dst_ok = vec!["tcp:[::1]:1234", "tcp:127.0.0.1:8080", "local:123"];
419        let srcs_ok = vec!["tcp:[::2]:1234", "tcp:127.0.0.2:8080", "local:124"];
420
421        start();
422        let handle = simnet_handle().unwrap();
423
424        // TODO: New NodeAdd event should do this for you..
425        for addr in dst_ok.iter().chain(srcs_ok.iter()) {
426            // Add to network along with its edges.
427            simnet_handle()
428                .unwrap()
429                .bind(addr.parse::<ChannelAddr>().unwrap())
430                .unwrap();
431        }
432        for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) {
433            let dst_addr = SimAddr::new_with_src(
434                src_addr.parse::<ChannelAddr>().unwrap(),
435                dst_addr.parse::<ChannelAddr>().unwrap(),
436            )
437            .unwrap();
438
439            let (_, mut rx) = sim::serve::<MessageEnvelope>(dst_addr.clone()).unwrap();
440            let tx = sim::dial::<MessageEnvelope>(dst_addr).unwrap();
441            let data = Serialized::serialize(&456).unwrap();
442            let sender = id!(world[0].hello);
443            let dest = id!(world[1].hello);
444            let ext = extent!(region = 1, dc = 1, rack = 4, host = 4, gpu = 8);
445            handle.register_proc(
446                sender.proc_id().clone(),
447                ext.point(vec![0, 0, 0, 0, 0]).unwrap(),
448            );
449            handle.register_proc(
450                dest.proc_id().clone(),
451                ext.point(vec![0, 0, 0, 1, 0]).unwrap(),
452            );
453
454            let msg = MessageEnvelope::new(sender, PortId(dest, 0), data.clone(), Attrs::new());
455            tx.post(msg);
456            assert_eq!(*rx.recv().await.unwrap().data(), data);
457        }
458
459        let records = sim::simnet_handle().unwrap().close().await.unwrap();
460        eprintln!("records: {:#?}", records);
461    }
462
463    #[tokio::test]
464    async fn test_invalid_sim_addr() {
465        let src = "sim!src";
466        let dst = "sim!dst";
467        let sim_addr = format!("{},{}", src, dst);
468        let result = parse(&sim_addr);
469        assert!(matches!(result, Err(ChannelError::InvalidAddress(_))));
470    }
471
472    #[tokio::test]
473    async fn test_parse_sim_addr() {
474        let sim_addr = "sim!unix:@dst";
475        let result = sim_addr.parse();
476        assert!(result.is_ok());
477        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
478            panic!("Expected a sim address");
479        };
480        assert!(sim_addr.src().is_none());
481        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");
482
483        let sim_addr = "sim!unix:@src,unix:@dst";
484        let result = sim_addr.parse();
485        assert!(result.is_ok());
486        let ChannelAddr::Sim(sim_addr) = result.unwrap() else {
487            panic!("Expected a sim address");
488        };
489        assert!(sim_addr.src().is_some());
490        assert_eq!(sim_addr.addr().to_string(), "unix:@dst");
491    }
492
493    #[tokio::test]
494    async fn test_realtime_frontier() {
495        tokio::time::pause();
496        // 1 second of latency
497        start_with_config(LatencyConfig {
498            inter_zone_distribution: LatencyDistribution::Beta(
499                BetaDistribution::new(
500                    tokio::time::Duration::from_millis(100),
501                    tokio::time::Duration::from_millis(100),
502                    1.0,
503                    1.0,
504                )
505                .unwrap(),
506            ),
507            ..Default::default()
508        });
509
510        let sim_addr = SimAddr::new("unix:@dst".parse::<ChannelAddr>().unwrap()).unwrap();
511        let sim_addr_with_src = SimAddr::new_with_src(
512            "unix:@src".parse::<ChannelAddr>().unwrap(),
513            "unix:@dst".parse::<ChannelAddr>().unwrap(),
514        )
515        .unwrap();
516        let (_, mut rx) = sim::serve::<MessageEnvelope>(sim_addr.clone()).unwrap();
517        let tx = sim::dial::<MessageEnvelope>(sim_addr_with_src).unwrap();
518
519        let controller = id!(world[0].controller);
520        let dest = id!(world[1].dest);
521        let handle = simnet::simnet_handle().unwrap();
522
523        let ext = extent!(region = 1, dc = 1, zone = 2, rack = 4, host = 4, gpu = 8);
524        handle.register_proc(
525            controller.proc_id().clone(),
526            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
527        );
528        handle.register_proc(
529            dest.proc_id().clone(),
530            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
531        );
532
533        // This message will be delievered at simulator time = 100 seconds
534        tx.post(MessageEnvelope::new(
535            controller,
536            PortId(dest, 0),
537            Serialized::serialize(&456).unwrap(),
538            Attrs::new(),
539        ));
540        {
541            // Allow simnet to run
542            tokio::task::yield_now().await;
543            // Messages have not been receive since 10 seconds have not elapsed
544            assert!(rx.rx.try_recv().is_err());
545        }
546        // Advance "real" time by 100 seconds
547        tokio::time::advance(tokio::time::Duration::from_secs(100)).await;
548        {
549            // Allow some time for simnet to run
550            tokio::task::yield_now().await;
551            // Messages are received
552            assert!(rx.rx.try_recv().is_ok());
553        }
554    }
555
556    #[tokio::test]
557    async fn test_client_message_scheduled_realtime() {
558        tokio::time::pause();
559        // 1 second of latency
560        start_with_config(LatencyConfig {
561            inter_zone_distribution: LatencyDistribution::Beta(
562                BetaDistribution::new(
563                    tokio::time::Duration::from_millis(1000),
564                    tokio::time::Duration::from_millis(1000),
565                    1.0,
566                    1.0,
567                )
568                .unwrap(),
569            ),
570            ..Default::default()
571        });
572
573        let controller_to_dst = SimAddr::new_with_src(
574            "unix:@controller".parse::<ChannelAddr>().unwrap(),
575            "unix:@dst".parse::<ChannelAddr>().unwrap(),
576        )
577        .unwrap();
578
579        let controller_tx = sim::dial::<MessageEnvelope>(controller_to_dst.clone()).unwrap();
580
581        let client_to_dst = SimAddr::new_with_client_src(
582            "unix:@client".parse::<ChannelAddr>().unwrap(),
583            "unix:@dst".parse::<ChannelAddr>().unwrap(),
584        )
585        .unwrap();
586        let client_tx = sim::dial::<MessageEnvelope>(client_to_dst).unwrap();
587
588        let controller = id!(world[0].controller);
589        let dest = id!(world[1].dest);
590        let client = id!(world[2].client);
591
592        let handle = simnet::simnet_handle().unwrap();
593        let ext = extent!(region = 1, dc = 1, zone = 2, rack = 4, host = 4, gpu = 8);
594        handle.register_proc(
595            controller.proc_id().clone(),
596            ext.point(vec![0, 0, 1, 0, 0, 0]).unwrap(),
597        );
598        handle.register_proc(
599            client.proc_id().clone(),
600            ext.point(vec![0, 0, 0, 0, 0, 0]).unwrap(),
601        );
602        handle.register_proc(
603            dest.proc_id().clone(),
604            ext.point(vec![0, 0, 0, 0, 1, 0]).unwrap(),
605        );
606
607        assert_eq!(
608            SimClock.duration_since_start(RealClock.now()),
609            tokio::time::Duration::ZERO
610        );
611        // Fast forward real time to 5 seconds
612        tokio::time::advance(tokio::time::Duration::from_secs(5)).await;
613        {
614            // Send client message
615            client_tx.post(MessageEnvelope::new(
616                client.clone(),
617                PortId(dest.clone(), 0),
618                Serialized::serialize(&456).unwrap(),
619                Attrs::new(),
620            ));
621            // Send system message
622            controller_tx.post(MessageEnvelope::new(
623                controller.clone(),
624                PortId(dest.clone(), 0),
625                Serialized::serialize(&456).unwrap(),
626                Attrs::new(),
627            ));
628            // Allow some time for simnet to run
629            RealClock.sleep(tokio::time::Duration::from_secs(1)).await;
630        }
631        let recs = simnet::simnet_handle().unwrap().close().await.unwrap();
632        assert_eq!(recs.len(), 2);
633        let end_times = recs.iter().map(|rec| rec.end_at).collect::<Vec<_>>();
634        // client message was delivered at "real" time = 5 seconds
635        assert!(end_times.contains(&5000));
636        // system message was delivered at simulated time = 1 second
637        assert!(end_times.contains(&1000));
638    }
639}