hyperactor_example_pingpong/
pingpong.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Ping-pong latency benchmark using hyperactor channels.
10//!
11//! Equivalent to the ZMQ-based Python ping-pong benchmark, but uses hyperactor
12//! channels directly. Supports TCP, Unix, and Local transports, as well as
13//! duplex mode (single connection, both directions).
14//!
//! Usage:
//!   (no args)                        run locally (subprocesses for TCP, in-process otherwise)
//!   --duplex                         run using duplex channels (in-process)
//!   --server [--transport tcp]       run as echo server
//!   --client <ADDR>                  run as client connecting to ADDR (e.g. tcp:[::1]:5555)
//!   --suite <PATH>                   run the benchmark suite, writing CSV results to PATH
//!   --diff <BASELINE> <CURRENT>      compare two suite CSV files
20
21use std::io::Write;
22use std::net::IpAddr;
23use std::net::Ipv6Addr;
24use std::net::SocketAddr;
25use std::path::PathBuf;
26use std::process::Command;
27use std::time::Duration;
28use std::time::Instant;
29
30use clap::Parser;
31use enum_as_inner::EnumAsInner;
32use hyperactor::channel;
33use hyperactor::channel::ChannelAddr;
34use hyperactor::channel::ChannelRx;
35use hyperactor::channel::ChannelTransport;
36use hyperactor::channel::Rx;
37use hyperactor::channel::TcpMode;
38use hyperactor::channel::Tx;
39use hyperactor::channel::duplex;
40use serde::Deserialize;
41use serde::Serialize;
42use typeuri::Named;
43
/// Wire message exchanged between the ping-pong client and server.
///
/// `EnumAsInner` generates the `into_hello()` accessor that the server loops
/// use to unpack the handshake.
///
/// NOTE(review): the variant order presumably determines the serialized
/// discriminant, so reordering variants would likely break interop between
/// binaries built from different revisions — confirm against the serializer.
#[derive(Clone, Debug, Named, Serialize, Deserialize, EnumAsInner)]
enum Message {
    /// Initial handshake carrying the client's reply address.
    Hello(ChannelAddr),
    /// Payload echoed back and forth.
    Echo(serde_multipart::Part),
}
51
52impl Message {
53    fn payload_len(&self) -> usize {
54        match self {
55            Message::Hello(_) => 0,
56            Message::Echo(part) => part.len(),
57        }
58    }
59}
60
// Command-line interface. clap turns the `///` field comments below into the
// `--help` text, so maintainer notes use plain `//` comments instead.
//
// Mode precedence in `main`: --suite, then --diff, then --server, then
// --client, then --duplex, otherwise a local run (subprocesses for TCP,
// in-process for other transports).
#[derive(Parser)]
#[command(about = "Hyperactor channel ping-pong benchmark")]
struct Cli {
    /// Run as server.
    #[arg(long)]
    server: bool,

    /// Run as client connecting to the given channel address (e.g. "tcp:[::1]:5555").
    #[arg(long)]
    client: Option<ChannelAddr>,

    /// Channel transport.
    // Also selects the server's listen address (`server_listen_addr`) and the
    // local-mode strategy (TCP => subprocesses). `--client` does not use it:
    // the client replies on the transport of the address it dials.
    #[arg(long, default_value = "tcp")]
    transport: ChannelTransport,

    /// Number of ping-pong iterations.
    // In suite mode this is the iteration count for every (transport, size) cell.
    #[arg(long, default_value_t = 1000)]
    iterations: usize,

    /// Port for TCP transport (server and local mode).
    // Ignored for non-TCP transports (see `server_listen_addr`).
    #[arg(long, default_value_t = 5555)]
    port: u16,

    /// Payload size in bytes.
    // Not used by --suite, which sweeps SUITE_SIZES instead.
    #[arg(long, default_value_t = 100)]
    message_size: usize,

    /// Use duplex channels (single connection, both directions).
    // Only affects the local in-process run; --server/--client are checked
    // first in `main` and always use simplex channels.
    #[arg(long)]
    duplex: bool,

    /// Run a full benchmark suite; write CSV to the given path.
    #[arg(long)]
    suite: Option<PathBuf>,

    /// Compare two suite CSV files (baseline, then current).
    // `num_args = 2` makes clap require two values per occurrence, which is
    // what lets `main` index [0] and [1] directly.
    #[arg(long, num_args = 2, value_names = ["BASELINE", "CURRENT"])]
    diff: Option<Vec<PathBuf>>,
}
100
101async fn run_server(addr: ChannelAddr) -> anyhow::Result<()> {
102    let (listen_addr, mut rx) = channel::serve::<Message>(addr)?;
103    println!("Server listening on {listen_addr}");
104
105    // First message is a Hello carrying the client's reply address.
106    let client_addr = rx
107        .recv()
108        .await?
109        .into_hello()
110        .map_err(|_| anyhow::anyhow!("expected Hello"))?;
111    let client_tx = channel::dial(client_addr)?;
112
113    // Echo loop.
114    loop {
115        let msg = rx.recv().await?;
116        client_tx.post(msg);
117    }
118}
119
120async fn run_client(
121    server_addr: ChannelAddr,
122    num_iterations: usize,
123    message_size: usize,
124) -> anyhow::Result<()> {
125    let server_tx = channel::dial::<Message>(server_addr.clone())?;
126
127    // Open our own channel for receiving replies.
128    let (client_addr, mut client_rx) =
129        channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone()))?;
130
131    println!("Client connected to {server_addr}");
132
133    // Tell the server where to send replies.
134    server_tx.post(Message::Hello(client_addr));
135
136    let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
137    let message_bytes = message.payload_len();
138
139    // Warmup.
140    for _ in 0..10 {
141        server_tx.post(message.clone());
142        client_rx.recv().await?;
143    }
144
145    println!("Payload size: {message_size} bytes");
146    println!("Starting {num_iterations} ping-pong iterations...");
147
148    let mut latencies = Vec::with_capacity(num_iterations);
149    let mut total_bytes_sent = 0usize;
150    let mut total_bytes_received = 0usize;
151
152    let total_start = Instant::now();
153
154    for i in 0..num_iterations {
155        let start = Instant::now();
156        server_tx.post(message.clone()); // cheap: Part is Arc-backed
157        total_bytes_sent += message_bytes;
158
159        let response = client_rx.recv().await?;
160        total_bytes_received += response.payload_len();
161
162        latencies.push(start.elapsed());
163
164        if (i + 1) % 100 == 0 {
165            println!("Completed {}/{num_iterations} iterations", i + 1);
166        }
167    }
168
169    let total_elapsed = total_start.elapsed();
170
171    let avg_ms = latencies.iter().sum::<Duration>().as_secs_f64() * 1000.0 / latencies.len() as f64;
172    let min_ms = latencies.iter().min().unwrap().as_secs_f64() * 1000.0;
173    let max_ms = latencies.iter().max().unwrap().as_secs_f64() * 1000.0;
174
175    let total_bytes = total_bytes_sent + total_bytes_received;
176    let total_secs = total_elapsed.as_secs_f64();
177    let bw_bps = total_bytes as f64 / total_secs;
178    let bw_mbps = bw_bps * 8.0 / (1024.0 * 1024.0);
179
180    println!();
181    println!("Results:");
182    println!("Average latency: {avg_ms:.3} ms");
183    println!("Min latency: {min_ms:.3} ms");
184    println!("Max latency: {max_ms:.3} ms");
185    println!("Total iterations: {}", latencies.len());
186    println!("Total time: {total_secs:.3} seconds");
187    println!("Bytes sent: {total_bytes_sent} bytes");
188    println!("Bytes received: {total_bytes_received} bytes");
189    println!("Total bytes transferred: {total_bytes} bytes");
190    println!("Bandwidth: {bw_bps:.0} bytes/sec ({bw_mbps:.2} Mbps)");
191
192    Ok(())
193}
194
195/// TCP local mode: spawn server and client as separate OS processes.
196fn run_local_subprocess(
197    transport: &str,
198    port: u16,
199    iterations: usize,
200    message_size: usize,
201) -> anyhow::Result<()> {
202    println!("Running local benchmark (subprocesses)...");
203    println!("Using port {port}");
204
205    let exe = std::env::current_exe()?;
206
207    // Start server subprocess.
208    let mut server = Command::new(&exe)
209        .args([
210            "--transport",
211            transport,
212            "--server",
213            "--port",
214            &port.to_string(),
215        ])
216        .spawn()?;
217
218    // Give the server time to bind.
219    std::thread::sleep(Duration::from_millis(200));
220
221    // Start client subprocess.
222    let addr = format!("tcp:[::1]:{port}");
223    let mut client = Command::new(&exe)
224        .args([
225            "--transport",
226            transport,
227            "--client",
228            &addr,
229            "--iterations",
230            &iterations.to_string(),
231            "--message-size",
232            &message_size.to_string(),
233        ])
234        .spawn()?;
235
236    // Wait for client to finish.
237    client.wait()?;
238
239    // Terminate server.
240    let _ = server.kill();
241    let _ = server.wait();
242
243    Ok(())
244}
245
246/// Non-TCP local mode: run server and client as tokio tasks in-process.
247async fn run_local_inprocess(
248    transport: ChannelTransport,
249    iterations: usize,
250    message_size: usize,
251) -> anyhow::Result<()> {
252    println!("Running local benchmark (in-process, {transport})...");
253
254    let (server_addr, server_rx) = channel::serve::<Message>(ChannelAddr::any(transport))?;
255    let _server = tokio::spawn(async move {
256        let _ = run_server_loop(server_rx).await;
257    });
258
259    run_client(server_addr, iterations, message_size).await
260}
261
262/// Echo loop used by the in-process server task.
263async fn run_server_loop(mut rx: ChannelRx<Message>) -> anyhow::Result<()> {
264    let client_addr = rx
265        .recv()
266        .await?
267        .into_hello()
268        .map_err(|_| anyhow::anyhow!("expected Hello"))?;
269    let client_tx = channel::dial(client_addr)?;
270    loop {
271        let msg = rx.recv().await?;
272        client_tx.post(msg);
273    }
274}
275
276/// Run a duplex benchmark in-process.
277async fn run_local_duplex(
278    transport: ChannelTransport,
279    iterations: usize,
280    message_size: usize,
281) -> anyhow::Result<()> {
282    println!("Running duplex benchmark (in-process, {transport})...");
283
284    let mut server = duplex::serve::<Message, Message>(ChannelAddr::any(transport))?;
285    let server_addr = server.addr().clone();
286
287    // Server task: accept one link, echo back.
288    let server_handle = tokio::spawn(async move {
289        let (mut rx, tx) = server.accept().await.unwrap();
290        while let Ok(msg) = rx.recv().await {
291            tx.post(msg);
292        }
293    });
294
295    let (client_tx, mut client_rx) = duplex::dial::<Message, Message>(server_addr.clone()).unwrap();
296
297    println!("Client connected to {server_addr} (duplex)");
298
299    let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
300    let message_bytes = message.payload_len();
301
302    // Warmup.
303    for _ in 0..10 {
304        client_tx.post(message.clone());
305        client_rx.recv().await?;
306    }
307
308    println!("Payload size: {message_size} bytes");
309    println!("Starting {iterations} ping-pong iterations...");
310
311    let mut latencies = Vec::with_capacity(iterations);
312    let mut total_bytes_sent = 0usize;
313    let mut total_bytes_received = 0usize;
314
315    let total_start = Instant::now();
316
317    for i in 0..iterations {
318        let start = Instant::now();
319        client_tx.post(message.clone());
320        total_bytes_sent += message_bytes;
321
322        let response = client_rx.recv().await?;
323        total_bytes_received += response.payload_len();
324
325        latencies.push(start.elapsed());
326
327        if (i + 1) % 100 == 0 {
328            println!("Completed {}/{iterations} iterations", i + 1);
329        }
330    }
331
332    let total_elapsed = total_start.elapsed();
333
334    let avg_ms = latencies.iter().sum::<Duration>().as_secs_f64() * 1000.0 / latencies.len() as f64;
335    let min_ms = latencies.iter().min().unwrap().as_secs_f64() * 1000.0;
336    let max_ms = latencies.iter().max().unwrap().as_secs_f64() * 1000.0;
337
338    let total_bytes = total_bytes_sent + total_bytes_received;
339    let total_secs = total_elapsed.as_secs_f64();
340    let bw_bps = total_bytes as f64 / total_secs;
341    let bw_mbps = bw_bps * 8.0 / (1024.0 * 1024.0);
342
343    println!();
344    println!("Results:");
345    println!("Average latency: {avg_ms:.3} ms");
346    println!("Min latency: {min_ms:.3} ms");
347    println!("Max latency: {max_ms:.3} ms");
348    println!("Total iterations: {}", latencies.len());
349    println!("Total time: {total_secs:.3} seconds");
350    println!("Bytes sent: {total_bytes_sent} bytes");
351    println!("Bytes received: {total_bytes_received} bytes");
352    println!("Total bytes transferred: {total_bytes} bytes");
353    println!("Bandwidth: {bw_bps:.0} bytes/sec ({bw_mbps:.2} Mbps)");
354
355    server_handle.abort();
356    Ok(())
357}
358
359/// Quiet duplex ping-pong benchmark for suite mode, returning total elapsed time.
360async fn bench_ping_pong_duplex(
361    transport: ChannelTransport,
362    num_iterations: usize,
363    message_size: usize,
364) -> anyhow::Result<Duration> {
365    let mut server = duplex::serve::<Message, Message>(ChannelAddr::any(transport))?;
366    let server_addr = server.addr().clone();
367
368    let server_handle = tokio::spawn(async move {
369        let (mut rx, tx) = server.accept().await.unwrap();
370        while let Ok(msg) = rx.recv().await {
371            tx.post(msg);
372        }
373    });
374
375    let (client_tx, mut client_rx) = duplex::dial::<Message, Message>(server_addr).unwrap();
376
377    let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
378
379    // Warmup.
380    for _ in 0..10 {
381        client_tx.post(message.clone());
382        client_rx.recv().await?;
383    }
384
385    let start = Instant::now();
386    for _ in 0..num_iterations {
387        client_tx.post(message.clone());
388        client_rx.recv().await?;
389    }
390    let elapsed = start.elapsed();
391
392    server_handle.abort();
393    Ok(elapsed)
394}
395
/// Payload sizes (bytes) swept by `run_suite`; each becomes one table/CSV column.
///
/// `run_diff` refuses to compare files whose size columns differ, so changing
/// this list makes previously written baselines incomparable.
const SUITE_SIZES: &[usize] = &[100, 1_000, 10_000, 100_000, 1_000_000];
397
398/// Run a single in-process ping-pong benchmark, returning total elapsed time.
399async fn bench_ping_pong(
400    transport: ChannelTransport,
401    num_iterations: usize,
402    message_size: usize,
403) -> anyhow::Result<Duration> {
404    let (server_addr, server_rx) = channel::serve::<Message>(ChannelAddr::any(transport))?;
405    let server_handle = tokio::spawn(async move {
406        let _ = run_server_loop(server_rx).await;
407    });
408
409    let server_tx = channel::dial::<Message>(server_addr)?;
410    let (client_addr, mut client_rx) =
411        channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone()))?;
412    server_tx.post(Message::Hello(client_addr));
413
414    let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
415
416    // Warmup.
417    for _ in 0..10 {
418        server_tx.post(message.clone());
419        client_rx.recv().await?;
420    }
421
422    let start = Instant::now();
423    for _ in 0..num_iterations {
424        server_tx.post(message.clone());
425        client_rx.recv().await?;
426    }
427    let elapsed = start.elapsed();
428
429    server_handle.abort();
430    Ok(elapsed)
431}
432
433/// Benchmark entry: name, transport, and whether to use duplex.
434struct BenchEntry {
435    name: &'static str,
436    transport: ChannelTransport,
437    use_duplex: bool,
438}
439
440/// Run a benchmark suite across transports and message sizes, writing CSV output.
441async fn run_suite(output: &std::path::Path, iterations: usize) -> anyhow::Result<()> {
442    let entries = vec![
443        BenchEntry {
444            name: "local",
445            transport: ChannelTransport::Local,
446            use_duplex: false,
447        },
448        BenchEntry {
449            name: "unix",
450            transport: ChannelTransport::Unix,
451            use_duplex: false,
452        },
453        BenchEntry {
454            name: "tcp",
455            transport: ChannelTransport::Tcp(TcpMode::Hostname),
456            use_duplex: false,
457        },
458        BenchEntry {
459            name: "duplex-unix",
460            transport: ChannelTransport::Unix,
461            use_duplex: true,
462        },
463        BenchEntry {
464            name: "duplex-tcp",
465            transport: ChannelTransport::Tcp(TcpMode::Hostname),
466            use_duplex: true,
467        },
468    ];
469
470    let mut file = std::fs::File::create(output)?;
471
472    // CSV header.
473    let headers: Vec<String> = SUITE_SIZES.iter().map(|s| s.to_string()).collect();
474    writeln!(file, "transport,{}", headers.join(","))?;
475
476    // Table header to stdout.
477    print!("{:<14}", "transport");
478    for size in SUITE_SIZES {
479        print!("{:>14}", size);
480    }
481    println!();
482
483    for entry in &entries {
484        print!("{:<14}", entry.name);
485        let mut times_ms = Vec::new();
486        for &size in SUITE_SIZES {
487            let dur = if entry.use_duplex {
488                bench_ping_pong_duplex(entry.transport.clone(), iterations, size).await?
489            } else {
490                bench_ping_pong(entry.transport.clone(), iterations, size).await?
491            };
492            let ms = dur.as_secs_f64() * 1000.0;
493            times_ms.push(ms);
494            print!("{:>12.3}ms", ms);
495        }
496        println!();
497
498        let values: Vec<String> = times_ms.iter().map(|t| format!("{t:.3}")).collect();
499        writeln!(file, "{},{}", entry.name, values.join(","))?;
500    }
501
502    eprintln!("\nResults written to {}", output.display());
503    Ok(())
504}
505
506/// Compare two suite CSV files and print a delta table.
507fn run_diff(baseline: &std::path::Path, current: &std::path::Path) -> anyhow::Result<()> {
508    let parse_csv =
509        |path: &std::path::Path| -> anyhow::Result<(Vec<String>, Vec<(String, Vec<f64>)>)> {
510            let content = std::fs::read_to_string(path)?;
511            let mut lines = content.lines();
512            let header = lines.next().ok_or_else(|| anyhow::anyhow!("empty CSV"))?;
513            let sizes: Vec<String> = header
514                .split(',')
515                .skip(1)
516                .map(|s| s.trim().to_string())
517                .collect();
518            let mut rows = Vec::new();
519            for line in lines {
520                if line.is_empty() {
521                    continue;
522                }
523                let parts: Vec<&str> = line.split(',').collect();
524                let name = parts[0].trim().to_string();
525                let values: Vec<f64> = parts[1..]
526                    .iter()
527                    .map(|s| s.trim().parse().unwrap_or(f64::NAN))
528                    .collect();
529                rows.push((name, values));
530            }
531            Ok((sizes, rows))
532        };
533
534    let (sizes_a, rows_a) = parse_csv(baseline)?;
535    let (sizes_b, rows_b) = parse_csv(current)?;
536
537    if sizes_a != sizes_b {
538        anyhow::bail!("CSV files have different message size columns");
539    }
540
541    println!("baseline: {}", baseline.display());
542    println!("current:  {}", current.display());
543    println!();
544
545    // Header.
546    print!("{:<16}", "");
547    for size in &sizes_a {
548        print!("{:>14}", size);
549    }
550    println!();
551
552    for (name, vals_base) in &rows_a {
553        let Some((_, vals_curr)) = rows_b.iter().find(|(n, _)| n == name) else {
554            continue;
555        };
556
557        print!("{:<16}", format!("{name} (base)"));
558        for v in vals_base {
559            print!("{:>12.3}ms", v);
560        }
561        println!();
562
563        print!("{:<16}", format!("{name} (curr)"));
564        for v in vals_curr {
565            print!("{:>12.3}ms", v);
566        }
567        println!();
568
569        print!("{:<16}", format!("{name} (diff)"));
570        for (a, b) in vals_base.iter().zip(vals_curr.iter()) {
571            if *a > 0.0 {
572                let pct = (b - a) / a * 100.0;
573                let sign = if pct >= 0.0 { "+" } else { "" };
574                print!("{:>14}", format!("{sign}{pct:.1}%"));
575            } else {
576                print!("{:>14}", "N/A");
577            }
578        }
579        println!();
580        println!();
581    }
582
583    Ok(())
584}
585
586/// Build a server listen address from the transport and port.
587fn server_listen_addr(transport: &ChannelTransport, port: u16) -> ChannelAddr {
588    match transport {
589        ChannelTransport::Tcp(_) => {
590            ChannelAddr::Tcp(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
591        }
592        _ => ChannelAddr::any(transport.clone()),
593    }
594}
595
596#[tokio::main(flavor = "current_thread")]
597async fn main() -> anyhow::Result<()> {
598    let args = Cli::parse();
599
600    if let Some(path) = args.suite {
601        run_suite(&path, args.iterations).await
602    } else if let Some(paths) = args.diff {
603        run_diff(&paths[0], &paths[1])
604    } else if args.server {
605        let addr = server_listen_addr(&args.transport, args.port);
606        run_server(addr).await
607    } else if let Some(addr) = args.client {
608        run_client(addr, args.iterations, args.message_size).await
609    } else if args.duplex {
610        run_local_duplex(args.transport, args.iterations, args.message_size).await
611    } else {
612        match &args.transport {
613            ChannelTransport::Tcp(_) => {
614                run_local_subprocess("tcp", args.port, args.iterations, args.message_size)
615            }
616            _ => run_local_inprocess(args.transport, args.iterations, args.message_size).await,
617        }
618    }
619}