1use std::io::Write;
22use std::net::IpAddr;
23use std::net::Ipv6Addr;
24use std::net::SocketAddr;
25use std::path::PathBuf;
26use std::process::Command;
27use std::time::Duration;
28use std::time::Instant;
29
30use clap::Parser;
31use enum_as_inner::EnumAsInner;
32use hyperactor::channel;
33use hyperactor::channel::ChannelAddr;
34use hyperactor::channel::ChannelRx;
35use hyperactor::channel::ChannelTransport;
36use hyperactor::channel::Rx;
37use hyperactor::channel::TcpMode;
38use hyperactor::channel::Tx;
39use hyperactor::channel::duplex;
40use serde::Deserialize;
41use serde::Serialize;
42use typeuri::Named;
43
/// Messages exchanged between the benchmark client and the echo server.
#[derive(Clone, Debug, Named, Serialize, Deserialize, EnumAsInner)]
enum Message {
    /// Handshake sent first by the client in the two-channel modes. It carries
    /// the address the client is serving on; the server dials that address and
    /// uses it to post echoed messages back.
    Hello(ChannelAddr),
    /// Benchmark payload. The echo servers in this file post it back to the
    /// client as received.
    Echo(serde_multipart::Part),
}
51
52impl Message {
53 fn payload_len(&self) -> usize {
54 match self {
55 Message::Hello(_) => 0,
56 Message::Echo(part) => part.len(),
57 }
58 }
59}
60
// Command-line options for the benchmark binary.
//
// `main` picks exactly one mode, checked in this order: `--suite`, `--diff`,
// `--server`, `--client`, `--duplex`, and finally the default local benchmark.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into `--help` text, which would change the program's output.
#[derive(Parser)]
#[command(about = "Hyperactor channel ping-pong benchmark")]
struct Cli {
    // Run as an echo server listening on the address derived from
    // `--transport` and `--port`.
    #[arg(long)]
    server: bool,

    // Run as a client against the echo server at this channel address.
    #[arg(long)]
    client: Option<ChannelAddr>,

    // Channel transport used by the server and the local benchmark modes.
    #[arg(long, default_value = "tcp")]
    transport: ChannelTransport,

    // Number of timed ping-pong round trips (warm-up round trips are extra).
    #[arg(long, default_value_t = 1000)]
    iterations: usize,

    // Port used for TCP: the server listens on it, and the local subprocess
    // benchmark connects to it.
    #[arg(long, default_value_t = 5555)]
    port: u16,

    // Size of the zero-filled payload carried by each `Echo` message.
    #[arg(long, default_value_t = 100)]
    message_size: usize,

    // Run the in-process benchmark over a single duplex channel instead of a
    // pair of one-way channels.
    #[arg(long)]
    duplex: bool,

    // Run the whole benchmark suite and write the results as CSV to this path.
    #[arg(long)]
    suite: Option<PathBuf>,

    // Compare two suite CSV files; exactly two paths are required
    // (baseline first, then current).
    #[arg(long, num_args = 2, value_names = ["BASELINE", "CURRENT"])]
    diff: Option<Vec<PathBuf>>,
}
100
101async fn run_server(addr: ChannelAddr) -> anyhow::Result<()> {
102 let (listen_addr, mut rx) = channel::serve::<Message>(addr)?;
103 println!("Server listening on {listen_addr}");
104
105 let client_addr = rx
107 .recv()
108 .await?
109 .into_hello()
110 .map_err(|_| anyhow::anyhow!("expected Hello"))?;
111 let client_tx = channel::dial(client_addr)?;
112
113 loop {
115 let msg = rx.recv().await?;
116 client_tx.post(msg);
117 }
118}
119
120async fn run_client(
121 server_addr: ChannelAddr,
122 num_iterations: usize,
123 message_size: usize,
124) -> anyhow::Result<()> {
125 let server_tx = channel::dial::<Message>(server_addr.clone())?;
126
127 let (client_addr, mut client_rx) =
129 channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone()))?;
130
131 println!("Client connected to {server_addr}");
132
133 server_tx.post(Message::Hello(client_addr));
135
136 let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
137 let message_bytes = message.payload_len();
138
139 for _ in 0..10 {
141 server_tx.post(message.clone());
142 client_rx.recv().await?;
143 }
144
145 println!("Payload size: {message_size} bytes");
146 println!("Starting {num_iterations} ping-pong iterations...");
147
148 let mut latencies = Vec::with_capacity(num_iterations);
149 let mut total_bytes_sent = 0usize;
150 let mut total_bytes_received = 0usize;
151
152 let total_start = Instant::now();
153
154 for i in 0..num_iterations {
155 let start = Instant::now();
156 server_tx.post(message.clone()); total_bytes_sent += message_bytes;
158
159 let response = client_rx.recv().await?;
160 total_bytes_received += response.payload_len();
161
162 latencies.push(start.elapsed());
163
164 if (i + 1) % 100 == 0 {
165 println!("Completed {}/{num_iterations} iterations", i + 1);
166 }
167 }
168
169 let total_elapsed = total_start.elapsed();
170
171 let avg_ms = latencies.iter().sum::<Duration>().as_secs_f64() * 1000.0 / latencies.len() as f64;
172 let min_ms = latencies.iter().min().unwrap().as_secs_f64() * 1000.0;
173 let max_ms = latencies.iter().max().unwrap().as_secs_f64() * 1000.0;
174
175 let total_bytes = total_bytes_sent + total_bytes_received;
176 let total_secs = total_elapsed.as_secs_f64();
177 let bw_bps = total_bytes as f64 / total_secs;
178 let bw_mbps = bw_bps * 8.0 / (1024.0 * 1024.0);
179
180 println!();
181 println!("Results:");
182 println!("Average latency: {avg_ms:.3} ms");
183 println!("Min latency: {min_ms:.3} ms");
184 println!("Max latency: {max_ms:.3} ms");
185 println!("Total iterations: {}", latencies.len());
186 println!("Total time: {total_secs:.3} seconds");
187 println!("Bytes sent: {total_bytes_sent} bytes");
188 println!("Bytes received: {total_bytes_received} bytes");
189 println!("Total bytes transferred: {total_bytes} bytes");
190 println!("Bandwidth: {bw_bps:.0} bytes/sec ({bw_mbps:.2} Mbps)");
191
192 Ok(())
193}
194
195fn run_local_subprocess(
197 transport: &str,
198 port: u16,
199 iterations: usize,
200 message_size: usize,
201) -> anyhow::Result<()> {
202 println!("Running local benchmark (subprocesses)...");
203 println!("Using port {port}");
204
205 let exe = std::env::current_exe()?;
206
207 let mut server = Command::new(&exe)
209 .args([
210 "--transport",
211 transport,
212 "--server",
213 "--port",
214 &port.to_string(),
215 ])
216 .spawn()?;
217
218 std::thread::sleep(Duration::from_millis(200));
220
221 let addr = format!("tcp:[::1]:{port}");
223 let mut client = Command::new(&exe)
224 .args([
225 "--transport",
226 transport,
227 "--client",
228 &addr,
229 "--iterations",
230 &iterations.to_string(),
231 "--message-size",
232 &message_size.to_string(),
233 ])
234 .spawn()?;
235
236 client.wait()?;
238
239 let _ = server.kill();
241 let _ = server.wait();
242
243 Ok(())
244}
245
246async fn run_local_inprocess(
248 transport: ChannelTransport,
249 iterations: usize,
250 message_size: usize,
251) -> anyhow::Result<()> {
252 println!("Running local benchmark (in-process, {transport})...");
253
254 let (server_addr, server_rx) = channel::serve::<Message>(ChannelAddr::any(transport))?;
255 let _server = tokio::spawn(async move {
256 let _ = run_server_loop(server_rx).await;
257 });
258
259 run_client(server_addr, iterations, message_size).await
260}
261
262async fn run_server_loop(mut rx: ChannelRx<Message>) -> anyhow::Result<()> {
264 let client_addr = rx
265 .recv()
266 .await?
267 .into_hello()
268 .map_err(|_| anyhow::anyhow!("expected Hello"))?;
269 let client_tx = channel::dial(client_addr)?;
270 loop {
271 let msg = rx.recv().await?;
272 client_tx.post(msg);
273 }
274}
275
276async fn run_local_duplex(
278 transport: ChannelTransport,
279 iterations: usize,
280 message_size: usize,
281) -> anyhow::Result<()> {
282 println!("Running duplex benchmark (in-process, {transport})...");
283
284 let mut server = duplex::serve::<Message, Message>(ChannelAddr::any(transport))?;
285 let server_addr = server.addr().clone();
286
287 let server_handle = tokio::spawn(async move {
289 let (mut rx, tx) = server.accept().await.unwrap();
290 while let Ok(msg) = rx.recv().await {
291 tx.post(msg);
292 }
293 });
294
295 let (client_tx, mut client_rx) = duplex::dial::<Message, Message>(server_addr.clone()).unwrap();
296
297 println!("Client connected to {server_addr} (duplex)");
298
299 let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
300 let message_bytes = message.payload_len();
301
302 for _ in 0..10 {
304 client_tx.post(message.clone());
305 client_rx.recv().await?;
306 }
307
308 println!("Payload size: {message_size} bytes");
309 println!("Starting {iterations} ping-pong iterations...");
310
311 let mut latencies = Vec::with_capacity(iterations);
312 let mut total_bytes_sent = 0usize;
313 let mut total_bytes_received = 0usize;
314
315 let total_start = Instant::now();
316
317 for i in 0..iterations {
318 let start = Instant::now();
319 client_tx.post(message.clone());
320 total_bytes_sent += message_bytes;
321
322 let response = client_rx.recv().await?;
323 total_bytes_received += response.payload_len();
324
325 latencies.push(start.elapsed());
326
327 if (i + 1) % 100 == 0 {
328 println!("Completed {}/{iterations} iterations", i + 1);
329 }
330 }
331
332 let total_elapsed = total_start.elapsed();
333
334 let avg_ms = latencies.iter().sum::<Duration>().as_secs_f64() * 1000.0 / latencies.len() as f64;
335 let min_ms = latencies.iter().min().unwrap().as_secs_f64() * 1000.0;
336 let max_ms = latencies.iter().max().unwrap().as_secs_f64() * 1000.0;
337
338 let total_bytes = total_bytes_sent + total_bytes_received;
339 let total_secs = total_elapsed.as_secs_f64();
340 let bw_bps = total_bytes as f64 / total_secs;
341 let bw_mbps = bw_bps * 8.0 / (1024.0 * 1024.0);
342
343 println!();
344 println!("Results:");
345 println!("Average latency: {avg_ms:.3} ms");
346 println!("Min latency: {min_ms:.3} ms");
347 println!("Max latency: {max_ms:.3} ms");
348 println!("Total iterations: {}", latencies.len());
349 println!("Total time: {total_secs:.3} seconds");
350 println!("Bytes sent: {total_bytes_sent} bytes");
351 println!("Bytes received: {total_bytes_received} bytes");
352 println!("Total bytes transferred: {total_bytes} bytes");
353 println!("Bandwidth: {bw_bps:.0} bytes/sec ({bw_mbps:.2} Mbps)");
354
355 server_handle.abort();
356 Ok(())
357}
358
359async fn bench_ping_pong_duplex(
361 transport: ChannelTransport,
362 num_iterations: usize,
363 message_size: usize,
364) -> anyhow::Result<Duration> {
365 let mut server = duplex::serve::<Message, Message>(ChannelAddr::any(transport))?;
366 let server_addr = server.addr().clone();
367
368 let server_handle = tokio::spawn(async move {
369 let (mut rx, tx) = server.accept().await.unwrap();
370 while let Ok(msg) = rx.recv().await {
371 tx.post(msg);
372 }
373 });
374
375 let (client_tx, mut client_rx) = duplex::dial::<Message, Message>(server_addr).unwrap();
376
377 let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
378
379 for _ in 0..10 {
381 client_tx.post(message.clone());
382 client_rx.recv().await?;
383 }
384
385 let start = Instant::now();
386 for _ in 0..num_iterations {
387 client_tx.post(message.clone());
388 client_rx.recv().await?;
389 }
390 let elapsed = start.elapsed();
391
392 server_handle.abort();
393 Ok(elapsed)
394}
395
/// Message sizes exercised by the benchmark suite. Each value is passed to the
/// benchmark functions as the payload length and becomes one column of the
/// suite's CSV output and printed table.
const SUITE_SIZES: &[usize] = &[100, 1_000, 10_000, 100_000, 1_000_000];
397
398async fn bench_ping_pong(
400 transport: ChannelTransport,
401 num_iterations: usize,
402 message_size: usize,
403) -> anyhow::Result<Duration> {
404 let (server_addr, server_rx) = channel::serve::<Message>(ChannelAddr::any(transport))?;
405 let server_handle = tokio::spawn(async move {
406 let _ = run_server_loop(server_rx).await;
407 });
408
409 let server_tx = channel::dial::<Message>(server_addr)?;
410 let (client_addr, mut client_rx) =
411 channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone()))?;
412 server_tx.post(Message::Hello(client_addr));
413
414 let message = Message::Echo(serde_multipart::Part::from(vec![0u8; message_size]));
415
416 for _ in 0..10 {
418 server_tx.post(message.clone());
419 client_rx.recv().await?;
420 }
421
422 let start = Instant::now();
423 for _ in 0..num_iterations {
424 server_tx.post(message.clone());
425 client_rx.recv().await?;
426 }
427 let elapsed = start.elapsed();
428
429 server_handle.abort();
430 Ok(elapsed)
431}
432
/// One row of the benchmark suite: a named transport configuration.
struct BenchEntry {
    /// Label used for the row in the printed table and in the CSV output.
    name: &'static str,
    /// Transport the benchmark channels are created on.
    transport: ChannelTransport,
    /// When `true` the row is measured with `bench_ping_pong_duplex`,
    /// otherwise with `bench_ping_pong`.
    use_duplex: bool,
}
439
440async fn run_suite(output: &std::path::Path, iterations: usize) -> anyhow::Result<()> {
442 let entries = vec![
443 BenchEntry {
444 name: "local",
445 transport: ChannelTransport::Local,
446 use_duplex: false,
447 },
448 BenchEntry {
449 name: "unix",
450 transport: ChannelTransport::Unix,
451 use_duplex: false,
452 },
453 BenchEntry {
454 name: "tcp",
455 transport: ChannelTransport::Tcp(TcpMode::Hostname),
456 use_duplex: false,
457 },
458 BenchEntry {
459 name: "duplex-unix",
460 transport: ChannelTransport::Unix,
461 use_duplex: true,
462 },
463 BenchEntry {
464 name: "duplex-tcp",
465 transport: ChannelTransport::Tcp(TcpMode::Hostname),
466 use_duplex: true,
467 },
468 ];
469
470 let mut file = std::fs::File::create(output)?;
471
472 let headers: Vec<String> = SUITE_SIZES.iter().map(|s| s.to_string()).collect();
474 writeln!(file, "transport,{}", headers.join(","))?;
475
476 print!("{:<14}", "transport");
478 for size in SUITE_SIZES {
479 print!("{:>14}", size);
480 }
481 println!();
482
483 for entry in &entries {
484 print!("{:<14}", entry.name);
485 let mut times_ms = Vec::new();
486 for &size in SUITE_SIZES {
487 let dur = if entry.use_duplex {
488 bench_ping_pong_duplex(entry.transport.clone(), iterations, size).await?
489 } else {
490 bench_ping_pong(entry.transport.clone(), iterations, size).await?
491 };
492 let ms = dur.as_secs_f64() * 1000.0;
493 times_ms.push(ms);
494 print!("{:>12.3}ms", ms);
495 }
496 println!();
497
498 let values: Vec<String> = times_ms.iter().map(|t| format!("{t:.3}")).collect();
499 writeln!(file, "{},{}", entry.name, values.join(","))?;
500 }
501
502 eprintln!("\nResults written to {}", output.display());
503 Ok(())
504}
505
506fn run_diff(baseline: &std::path::Path, current: &std::path::Path) -> anyhow::Result<()> {
508 let parse_csv =
509 |path: &std::path::Path| -> anyhow::Result<(Vec<String>, Vec<(String, Vec<f64>)>)> {
510 let content = std::fs::read_to_string(path)?;
511 let mut lines = content.lines();
512 let header = lines.next().ok_or_else(|| anyhow::anyhow!("empty CSV"))?;
513 let sizes: Vec<String> = header
514 .split(',')
515 .skip(1)
516 .map(|s| s.trim().to_string())
517 .collect();
518 let mut rows = Vec::new();
519 for line in lines {
520 if line.is_empty() {
521 continue;
522 }
523 let parts: Vec<&str> = line.split(',').collect();
524 let name = parts[0].trim().to_string();
525 let values: Vec<f64> = parts[1..]
526 .iter()
527 .map(|s| s.trim().parse().unwrap_or(f64::NAN))
528 .collect();
529 rows.push((name, values));
530 }
531 Ok((sizes, rows))
532 };
533
534 let (sizes_a, rows_a) = parse_csv(baseline)?;
535 let (sizes_b, rows_b) = parse_csv(current)?;
536
537 if sizes_a != sizes_b {
538 anyhow::bail!("CSV files have different message size columns");
539 }
540
541 println!("baseline: {}", baseline.display());
542 println!("current: {}", current.display());
543 println!();
544
545 print!("{:<16}", "");
547 for size in &sizes_a {
548 print!("{:>14}", size);
549 }
550 println!();
551
552 for (name, vals_base) in &rows_a {
553 let Some((_, vals_curr)) = rows_b.iter().find(|(n, _)| n == name) else {
554 continue;
555 };
556
557 print!("{:<16}", format!("{name} (base)"));
558 for v in vals_base {
559 print!("{:>12.3}ms", v);
560 }
561 println!();
562
563 print!("{:<16}", format!("{name} (curr)"));
564 for v in vals_curr {
565 print!("{:>12.3}ms", v);
566 }
567 println!();
568
569 print!("{:<16}", format!("{name} (diff)"));
570 for (a, b) in vals_base.iter().zip(vals_curr.iter()) {
571 if *a > 0.0 {
572 let pct = (b - a) / a * 100.0;
573 let sign = if pct >= 0.0 { "+" } else { "" };
574 print!("{:>14}", format!("{sign}{pct:.1}%"));
575 } else {
576 print!("{:>14}", "N/A");
577 }
578 }
579 println!();
580 println!();
581 }
582
583 Ok(())
584}
585
586fn server_listen_addr(transport: &ChannelTransport, port: u16) -> ChannelAddr {
588 match transport {
589 ChannelTransport::Tcp(_) => {
590 ChannelAddr::Tcp(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
591 }
592 _ => ChannelAddr::any(transport.clone()),
593 }
594}
595
596#[tokio::main(flavor = "current_thread")]
597async fn main() -> anyhow::Result<()> {
598 let args = Cli::parse();
599
600 if let Some(path) = args.suite {
601 run_suite(&path, args.iterations).await
602 } else if let Some(paths) = args.diff {
603 run_diff(&paths[0], &paths[1])
604 } else if args.server {
605 let addr = server_listen_addr(&args.transport, args.port);
606 run_server(addr).await
607 } else if let Some(addr) = args.client {
608 run_client(addr, args.iterations, args.message_size).await
609 } else if args.duplex {
610 run_local_duplex(args.transport, args.iterations, args.message_size).await
611 } else {
612 match &args.transport {
613 ChannelTransport::Tcp(_) => {
614 run_local_subprocess("tcp", args.port, args.iterations, args.message_size)
615 }
616 _ => run_local_inprocess(args.transport, args.iterations, args.message_size).await,
617 }
618 }
619}