1use std::collections::HashMap;
10use std::fmt;
11use std::path::Path;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::Context as TaskContext;
16use std::task::Poll;
17use std::time::Duration;
18use std::time::SystemTime;
19
20use anyhow::Result;
21use async_trait::async_trait;
22use chrono::DateTime;
23use chrono::Local;
24use hyperactor::Actor;
25use hyperactor::ActorRef;
26use hyperactor::Bind;
27use hyperactor::Context;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::Named;
32use hyperactor::OncePortRef;
33use hyperactor::RefClient;
34use hyperactor::Unbind;
35use hyperactor::channel;
36use hyperactor::channel::ChannelAddr;
37use hyperactor::channel::ChannelRx;
38use hyperactor::channel::ChannelTransport;
39use hyperactor::channel::ChannelTx;
40use hyperactor::channel::Rx;
41use hyperactor::channel::Tx;
42use hyperactor::channel::TxStatus;
43use hyperactor::clock::Clock;
44use hyperactor::clock::RealClock;
45use hyperactor::data::Serialized;
46use hyperactor_telemetry::env;
47use hyperactor_telemetry::log_file_path;
48use serde::Deserialize;
49use serde::Serialize;
50use tokio::io;
51use tokio::sync::Mutex;
52use tokio::sync::watch::Receiver;
53
54use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
55
56mod line_prefixing_writer;
57use line_prefixing_writer::LinePrefixingWriter;
58
/// Default length, in seconds, of the log aggregation window that
/// [`LogClientActor`] starts with (its initial `aggregate_window_sec`).
const DEFAULT_AGGREGATE_WINDOW_SEC: u64 = 5;
60
/// Computes the Levenshtein (edit) distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions, or substitutions needed to turn `left` into `right`. Strings are
/// compared by Unicode scalar values (`char`s), not bytes.
///
/// Uses the classic dynamic-programming recurrence but keeps only two rows of
/// the table, so memory is `O(len(right))` instead of `O(len(left) * len(right))`.
fn levenshtein_distance(left: &str, right: &str) -> usize {
    let left_chars: Vec<char> = left.chars().collect();
    let right_chars: Vec<char> = right.chars().collect();

    // Distance to/from the empty string is just the other string's length.
    if left_chars.is_empty() {
        return right_chars.len();
    }
    if right_chars.is_empty() {
        return left_chars.len();
    }

    // `prev[j]` = distance between the first `i` chars of `left` and the first
    // `j` chars of `right`; `curr` is the row being filled for `i + 1`.
    let mut prev: Vec<usize> = (0..=right_chars.len()).collect();
    let mut curr = vec![0; right_chars.len() + 1];

    for (i, &left_char) in left_chars.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &right_char) in right_chars.iter().enumerate() {
            let substitution_cost = usize::from(left_char != right_char);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + substitution_cost); // substitution or match
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last computed row lives in `prev`.
    prev[right_chars.len()]
}

/// Returns the edit distance between `left` and `right` normalized to `[0.0, 1.0]`.
///
/// The raw [`levenshtein_distance`] (which counts `char`s) is divided by the
/// length of the longer string, also measured in `char`s, so the result is
/// `0.0` for identical strings and `1.0` for completely different ones. Two
/// empty strings have distance `0.0`.
fn normalized_edit_distance(left: &str, right: &str) -> f64 {
    // Measure lengths in chars to match the unit used by `levenshtein_distance`;
    // byte lengths would under-report the distance for non-ASCII text.
    let max_len = left.chars().count().max(right.chars().count());
    if max_len == 0 {
        return 0.0;
    }
    levenshtein_distance(left, right) as f64 / max_len as f64
}
122
/// One distinct log line tracked by an aggregator, together with how many
/// similar lines have been folded into it.
#[derive(Debug, Clone)]
struct LogLine {
    /// Text of the line that started this group; later similar lines only
    /// bump `count`.
    content: String,
    /// Number of lines (including the first) merged into this entry.
    pub count: u64,
}

impl LogLine {
    /// Starts a new group from `content` with a count of one.
    fn new(content: String) -> Self {
        LogLine { count: 1, content }
    }
}

impl fmt::Display for LogLine {
    /// Renders the line prefixed by a yellow (ANSI `33`) "[N similar log lines]" tag.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { content, count } = self;
        write!(f, "\x1b[33m[{count} similar log lines]\x1b[0m {content}")
    }
}
145
146#[derive(Debug, Clone)]
147struct Aggregator {
150 lines: Vec<LogLine>,
151 start_time: SystemTime,
152 similarity_threshold: f64, }
154
155impl Aggregator {
156 fn new() -> Self {
157 Self::new_with_threshold(0.15)
159 }
160
161 fn new_with_threshold(threshold: f64) -> Self {
162 Aggregator {
163 lines: vec![],
164 start_time: RealClock.system_time_now(),
165 similarity_threshold: threshold,
166 }
167 }
168
169 fn reset(&mut self) {
170 self.lines.clear();
171 self.start_time = RealClock.system_time_now();
172 }
173
174 fn add_line(&mut self, line: &str) -> anyhow::Result<()> {
175 let mut best_match_idx = None;
177 let mut best_similarity = f64::MAX;
178
179 for (idx, existing_line) in self.lines.iter().enumerate() {
180 let distance = normalized_edit_distance(&existing_line.content, line);
181
182 if distance < best_similarity && distance < self.similarity_threshold {
184 best_match_idx = Some(idx);
185 best_similarity = distance;
186 }
187 }
188
189 if let Some(idx) = best_match_idx {
191 self.lines[idx].count += 1;
192 } else {
193 self.lines.push(LogLine::new(line.to_string()));
195 }
196
197 Ok(())
198 }
199
200 fn is_empty(&self) -> bool {
201 self.lines.is_empty()
202 }
203}
204
205fn format_system_time(time: SystemTime) -> String {
207 let datetime: DateTime<Local> = time.into();
208 datetime.format("%Y-%m-%d %H:%M:%S").to_string()
209}
210
211impl fmt::Display for Aggregator {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 let start_time_str = format_system_time(self.start_time);
215
216 let current_time = RealClock.system_time_now();
218 let end_time_str = format_system_time(current_time);
219
220 writeln!(
222 f,
223 "\x1b[36m>>> Aggregated Logs ({}) >>>\x1b[0m",
224 start_time_str
225 )?;
226
227 for line in self.lines.iter() {
229 writeln!(f, "{}", line)?;
230 }
231 writeln!(
232 f,
233 "\x1b[36m<<< Aggregated Logs ({}) <<<\x1b[0m",
234 end_time_str
235 )?;
236 Ok(())
237 }
238}
239
/// Messages carried over the log channel. [`LogForwardActor`] reads them from
/// the channel, and [`LogClientActor`] handles them.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient
)]
pub enum LogMessage {
    /// A chunk of output captured from a process.
    Log {
        /// Hostname of the machine that produced the output.
        hostname: String,
        /// PID given to the sender that produced this output.
        pid: u32,
        /// Whether the output was written to stdout or stderr.
        output_target: OutputTarget,
        /// Serialized output; decoded by `deserialize_message_lines` as either
        /// a `String` or a `Vec<u8>`.
        payload: Serialized,
    },

    /// Requests a flush of buffered/aggregated logs.
    Flush {
        /// `None` for a periodic flush. `Some(version)` acknowledges the
        /// synchronous flush with that version (see `LogClientMessage::StartSyncFlush`).
        sync_version: Option<u64>,
    },
}
271
/// Control messages for [`LogClientActor`].
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient
)]
pub enum LogClientMessage {
    /// Sets the aggregation window.
    SetAggregate {
        /// Window length in seconds; `None` disables aggregation so each line
        /// is printed as it arrives.
        aggregate_window_sec: Option<u64>,
    },

    /// Begins a synchronous flush.
    StartSyncFlush {
        /// Number of `Flush { sync_version: Some(_) }` acknowledgements to wait
        /// for before the flush is considered complete.
        expected_procs: usize,
        /// Receives `()` once all expected acknowledgements have arrived.
        reply: OncePortRef<()>,
        /// Receives the version number assigned to this flush.
        version: OncePortRef<u64>,
    },
}
299
/// Destination for captured process output. [`LogWriter`] calls it after each
/// successful write and flush of its local destination.
#[async_trait]
pub trait LogSender: Send + Sync {
    /// Forwards `payload` that was written to `target`.
    fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()>;

    /// Requests a flush of previously sent output.
    fn flush(&mut self) -> anyhow::Result<()>;
}
310
/// Which standard stream a piece of output belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum OutputTarget {
    /// Standard output.
    Stdout,
    /// Standard error.
    Stderr,
}
319
320pub struct LocalLogSender {
322 hostname: String,
323 pid: u32,
324 tx: ChannelTx<LogMessage>,
325 status: Receiver<TxStatus>,
326}
327
328impl LocalLogSender {
329 fn new(log_channel: ChannelAddr, pid: u32) -> Result<Self, anyhow::Error> {
330 let tx = channel::dial::<LogMessage>(log_channel)?;
331 let status = tx.status().clone();
332
333 let hostname = hostname::get()
334 .unwrap_or_else(|_| "unknown_host".into())
335 .into_string()
336 .unwrap_or("unknown_host".to_string());
337 Ok(Self {
338 hostname,
339 pid,
340 tx,
341 status,
342 })
343 }
344}
345
346#[async_trait]
347impl LogSender for LocalLogSender {
348 fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
349 if TxStatus::Active == *self.status.borrow() {
350 self.tx.post(LogMessage::Log {
352 hostname: self.hostname.clone(),
353 pid: self.pid,
354 output_target: target,
355 payload: Serialized::serialize_anon(&payload)?,
356 });
357 } else {
358 tracing::debug!(
359 "log sender {} is not active, skip sending log",
360 self.tx.addr()
361 )
362 }
363
364 Ok(())
365 }
366
367 fn flush(&mut self) -> anyhow::Result<()> {
368 if TxStatus::Active == *self.status.borrow() {
370 self.tx.post(LogMessage::Flush { sync_version: None });
372 } else {
373 tracing::debug!(
374 "log sender {} is not active, skip sending flush message",
375 self.tx.addr()
376 );
377 }
378 Ok(())
379 }
380}
381
/// [`io::AsyncWrite`] adapter that writes to a local destination (`std_writer`)
/// and forwards what it writes to a [`LogSender`].
pub struct LogWriter<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> {
    /// Stream this writer's output belongs to; attached to each forwarded chunk.
    output_target: OutputTarget,
    /// Local destination that receives the bytes first. When built via
    /// `with_default_writer`, this is the console or a log file.
    std_writer: S,
    /// Receives a copy of written output and flush requests.
    log_sender: T,
}
389
390fn create_file_writer(
391 local_rank: usize,
392 output_target: OutputTarget,
393 env: env::Env,
394) -> Result<Box<dyn io::AsyncWrite + Send + Unpin + 'static>> {
395 let suffix = match output_target {
396 OutputTarget::Stderr => "stderr",
397 OutputTarget::Stdout => "stdout",
398 };
399 let (path, filename) = log_file_path(env)?;
400 let path = Path::new(&path);
401 let mut full_path = PathBuf::from(path);
402
403 let file_created_by_pid = std::process::id();
409
410 full_path.push(format!(
411 "{}_{}_{}.{}",
412 filename, file_created_by_pid, local_rank, suffix
413 ));
414 let file = std::fs::OpenOptions::new()
415 .create(true)
416 .append(true)
417 .open(full_path)?;
418 let tokio_file = tokio::fs::File::from_std(file);
419 Ok(Box::new(tokio_file))
421}
422
423fn get_local_log_destination(
424 local_rank: usize,
425 output_target: OutputTarget,
426) -> Result<Box<dyn io::AsyncWrite + Send + Unpin>> {
427 let env: env::Env = env::Env::current();
428 Ok(match env {
429 env::Env::Test => match output_target {
430 OutputTarget::Stdout => Box::new(LinePrefixingWriter::new(local_rank, io::stdout())),
431 OutputTarget::Stderr => Box::new(LinePrefixingWriter::new(local_rank, io::stderr())),
432 },
433 env::Env::Local | env::Env::MastEmulator | env::Env::Mast => {
434 create_file_writer(local_rank, output_target, env)?
435 }
436 })
437}
438
439pub fn create_log_writers(
450 local_rank: usize,
451 log_channel: ChannelAddr,
452 pid: u32,
453) -> Result<
454 (
455 Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
456 Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
457 ),
458 anyhow::Error,
459> {
460 let stdout_writer = LogWriter::with_default_writer(
462 local_rank,
463 OutputTarget::Stdout,
464 LocalLogSender::new(log_channel.clone(), pid)?,
465 )?;
466 let stderr_writer = LogWriter::with_default_writer(
467 local_rank,
468 OutputTarget::Stderr,
469 LocalLogSender::new(log_channel, pid)?,
470 )?;
471
472 Ok((Box::new(stdout_writer), Box::new(stderr_writer)))
473}
474
475impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> LogWriter<T, S> {
476 pub fn new(output_target: OutputTarget, std_writer: S, log_sender: T) -> Self {
484 Self {
485 output_target,
486 std_writer,
487 log_sender,
488 }
489 }
490}
491
492impl<T: LogSender + Unpin + 'static> LogWriter<T, Box<dyn io::AsyncWrite + Send + Unpin>> {
493 pub fn with_default_writer(
500 local_rank: usize,
501 output_target: OutputTarget,
502 log_sender: T,
503 ) -> Result<Self> {
504 let std_writer = get_local_log_destination(local_rank, output_target)?;
506
507 Ok(Self {
508 output_target,
509 std_writer,
510 log_sender,
511 })
512 }
513}
514
515impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> io::AsyncWrite
516 for LogWriter<T, S>
517{
518 fn poll_write(
519 self: Pin<&mut Self>,
520 cx: &mut TaskContext<'_>,
521 buf: &[u8],
522 ) -> Poll<Result<usize, io::Error>> {
523 let this = self.get_mut();
525
526 match Pin::new(&mut this.std_writer).poll_write(cx, buf) {
528 Poll::Ready(Ok(_)) => {
529 let output_target = this.output_target.clone();
531 let data_to_send = buf.to_vec();
532
533 if let Err(e) = this.log_sender.send(output_target, data_to_send) {
536 tracing::error!("error sending log: {}", e);
537 }
538 Poll::Ready(Ok(buf.len()))
540 }
541 other => other, }
543 }
544
545 fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Result<(), io::Error>> {
546 let this = self.get_mut();
547
548 match Pin::new(&mut this.std_writer).poll_flush(cx) {
549 Poll::Ready(Ok(())) => {
550 if let Err(e) = this.log_sender.flush() {
551 tracing::error!("error sending flush: {}", e);
552 }
553 Poll::Ready(Ok(()))
554 }
555 other => other, }
557 }
558
559 fn poll_shutdown(
560 self: Pin<&mut Self>,
561 cx: &mut TaskContext<'_>,
562 ) -> Poll<Result<(), io::Error>> {
563 let this = self.get_mut();
564 Pin::new(&mut this.std_writer).poll_shutdown(cx)
565 }
566}
567
/// Messages handled by [`LogForwardActor`].
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub enum LogForwardMessage {
    /// Receives and handles one message from the log channel. The handler
    /// re-sends this message to itself to keep receiving.
    Forward {},

    /// Enables (`true`) or disables (`false`) forwarding of received log lines
    /// to the log client.
    SetMode { stream_to_client: bool },

    /// Pushes `Flush { sync_version: Some(version) }` into the log channel.
    /// When `Forward` later receives it, it is passed on to the log client as
    /// a flush acknowledgement.
    ForceSyncFlush { version: u64 },
}
591
/// Actor that reads [`LogMessage`]s from a local log channel and relays them
/// to a [`LogClientActor`].
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [LogForwardMessage {cast = true}],
)]
pub struct LogForwardActor {
    /// Receiving end of the log channel served by this actor.
    rx: ChannelRx<LogMessage>,
    /// Sender dialed to the log channel, used to inject flush messages into it.
    flush_tx: Arc<Mutex<ChannelTx<LogMessage>>>,
    /// Earliest time at which another periodic flush may be scheduled.
    next_flush_deadline: SystemTime,
    /// Client that receives forwarded log lines and flush acknowledgements.
    logging_client_ref: ActorRef<LogClientActor>,
    /// Whether received log lines are forwarded to `logging_client_ref`.
    stream_to_client: bool,
}
605
606#[async_trait]
607impl Actor for LogForwardActor {
608 type Params = ActorRef<LogClientActor>;
609
610 async fn new(logging_client_ref: Self::Params) -> Result<Self> {
611 let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) {
612 Ok(channel) => channel.parse()?,
613 Err(err) => {
614 tracing::debug!(
615 "log forwarder actor failed to read env var {}: {}",
616 BOOTSTRAP_LOG_CHANNEL,
617 err
618 );
619 ChannelAddr::any(ChannelTransport::Unix)
621 }
622 };
623 tracing::info!(
624 "log forwarder {} serve at {}",
625 std::process::id(),
626 log_channel
627 );
628
629 let rx = match channel::serve(log_channel.clone()).await {
630 Ok((_, rx)) => rx,
631 Err(err) => {
632 tracing::error!(
635 "log forwarder actor failed to bootstrap on given channel {}: {}",
636 log_channel,
637 err
638 );
639 channel::serve(ChannelAddr::any(ChannelTransport::Unix))
640 .await?
641 .1
642 }
643 };
644
645 let flush_tx = Arc::new(Mutex::new(channel::dial::<LogMessage>(log_channel)?));
647 let now = RealClock.system_time_now();
648
649 Ok(Self {
650 rx,
651 flush_tx,
652 next_flush_deadline: now,
653 logging_client_ref,
654 stream_to_client: true,
655 })
656 }
657
658 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
659 this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
660
661 self.flush_tx
663 .lock()
664 .await
665 .send(LogMessage::Flush { sync_version: None })
666 .await?;
667 Ok(())
668 }
669}
670
#[async_trait]
#[hyperactor::forward(LogForwardMessage)]
impl LogForwardMessageHandler for LogForwardActor {
    /// Receives one message from the log channel, handles it, then sends
    /// `Forward` to itself again (with zero delay) to receive the next one.
    ///
    /// - `Flush { sync_version: None }`: if `next_flush_deadline` has passed,
    ///   moves the deadline one second ahead. It then spawns a task that sends
    ///   another unversioned flush into the channel after one second.
    /// - `Flush { sync_version: Some(_) }`: passed to the log client's `flush`.
    /// - `Log { .. }`: passed to the log client's `log` if `stream_to_client`
    ///   is set.
    ///
    /// Any error, whether from receiving or from messaging the client, is
    /// returned before `Forward` is re-sent. That ends the forwarding loop.
    async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
        // NOTE(review): this handler awaits the channel, so other messages to this
        // actor (`SetMode`, `ForceSyncFlush`) presumably wait until the channel
        // yields something. The periodic flush handling below appears to bound
        // that wait — confirm.
        match self.rx.recv().await {
            Ok(LogMessage::Flush { sync_version }) => {
                let now = RealClock.system_time_now();
                match sync_version {
                    None => {
                        let delay = Duration::from_secs(1);
                        // Schedule at most one periodic flush per `delay` interval.
                        if now >= self.next_flush_deadline {
                            self.next_flush_deadline = now + delay;
                            let flush_tx = self.flush_tx.clone();
                            // Detached task: sleeps, then feeds a flush back into our
                            // own channel so this handler runs again.
                            tokio::spawn(async move {
                                RealClock.sleep(delay).await;
                                if let Err(e) = flush_tx
                                    .lock()
                                    .await
                                    .send(LogMessage::Flush { sync_version: None })
                                    .await
                                {
                                    tracing::error!("failed to send flush message: {}", e);
                                }
                            });
                        }
                    }
                    version => {
                        // Versioned flush: acknowledge to the client.
                        self.logging_client_ref.flush(ctx, version).await?;
                    }
                }
            }
            Ok(LogMessage::Log {
                hostname,
                pid,
                output_target,
                payload,
            }) => {
                if self.stream_to_client {
                    self.logging_client_ref
                        .log(ctx, hostname, pid, output_target, payload)
                        .await?;
                }
            }
            Err(e) => {
                return Err(e.into());
            }
        }

        // Re-arm the loop to receive the next message.
        ctx.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;

        Ok(())
    }

    /// Sets whether received log lines are forwarded to the log client.
    async fn set_mode(
        &mut self,
        _ctx: &Context<Self>,
        stream_to_client: bool,
    ) -> Result<(), anyhow::Error> {
        self.stream_to_client = stream_to_client;
        Ok(())
    }

    /// Sends `Flush { sync_version: Some(version) }` into the log channel via
    /// `flush_tx`. When `forward` later receives it, it is passed on to the
    /// log client.
    async fn force_sync_flush(
        &mut self,
        _cx: &Context<Self>,
        version: u64,
    ) -> Result<(), anyhow::Error> {
        self.flush_tx
            .lock()
            .await
            .send(LogMessage::Flush {
                sync_version: Some(version),
            })
            .await
            .map_err(anyhow::Error::from)
    }
}
750
751fn deserialize_message_lines(
753 serialized_message: &hyperactor::data::Serialized,
754) -> Result<Vec<String>> {
755 if let Ok(message_str) = serialized_message.deserialized::<String>() {
757 return Ok(message_str.lines().map(|s| s.to_string()).collect());
758 }
759
760 if let Ok(message_bytes) = serialized_message.deserialized::<Vec<u8>>() {
762 let message_str = String::from_utf8(message_bytes)?;
763 return Ok(message_str.lines().map(|s| s.to_string()).collect());
764 }
765
766 anyhow::bail!("Failed to deserialize message as either String or Vec<u8>")
768}
769
/// Actor that receives log lines (relayed by [`LogForwardActor`]s) and prints
/// them to this process's stdout/stderr. Lines are printed either immediately
/// or aggregated over a time window.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [LogMessage, LogClientMessage],
)]
pub struct LogClientActor {
    /// Aggregation window in seconds; `None` prints each line immediately.
    aggregate_window_sec: Option<u64>,
    /// One aggregator per output stream.
    aggregators: HashMap<OutputTarget, Aggregator>,
    /// Time of the most recent flush (or immediate print).
    last_flush_time: SystemTime,
    /// Deadline of the currently scheduled aggregation flush, if any.
    next_flush_deadline: Option<SystemTime>,

    // State of the current synchronous flush (see `LogClientMessage::StartSyncFlush`).
    /// Version number of the most recently started synchronous flush.
    current_flush_version: u64,
    /// Port to reply on once the current synchronous flush completes.
    current_flush_port: Option<OncePortRef<()>>,
    /// Number of acknowledgements still outstanding for the current
    /// synchronous flush.
    current_unflushed_procs: usize,
}
787
788impl LogClientActor {
789 fn print_aggregators(&mut self) {
790 for (output_target, aggregator) in self.aggregators.iter_mut() {
791 if aggregator.is_empty() {
792 continue;
793 }
794 match output_target {
795 OutputTarget::Stdout => {
796 println!("{}", aggregator);
797 }
798 OutputTarget::Stderr => {
799 eprintln!("{}", aggregator);
800 }
801 }
802
803 aggregator.reset();
805 }
806 }
807
808 fn print_log_line(hostname: &str, pid: u32, output_target: OutputTarget, line: String) {
809 let message = format!("[{} {}] {}", hostname, pid, line);
810 match output_target {
811 OutputTarget::Stdout => println!("{}", message),
812 OutputTarget::Stderr => eprintln!("{}", message),
813 }
814 }
815
816 fn flush_internal(&mut self) {
817 self.print_aggregators();
818 self.last_flush_time = RealClock.system_time_now();
819 self.next_flush_deadline = None;
820 }
821}
822
823#[async_trait]
824impl Actor for LogClientActor {
825 type Params = ();
827
828 async fn new(_: ()) -> Result<Self, anyhow::Error> {
829 let mut aggregators = HashMap::new();
831 aggregators.insert(OutputTarget::Stderr, Aggregator::new());
832 aggregators.insert(OutputTarget::Stdout, Aggregator::new());
833
834 Ok(Self {
835 aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC),
836 aggregators,
837 last_flush_time: RealClock.system_time_now(),
838 next_flush_deadline: None,
839 current_flush_version: 0,
840 current_flush_port: None,
841 current_unflushed_procs: 0,
842 })
843 }
844}
845
846impl Drop for LogClientActor {
847 fn drop(&mut self) {
848 self.print_aggregators();
850 }
851}
852
853#[async_trait]
854#[hyperactor::forward(LogMessage)]
855impl LogMessageHandler for LogClientActor {
856 async fn log(
857 &mut self,
858 cx: &Context<Self>,
859 hostname: String,
860 pid: u32,
861 output_target: OutputTarget,
862 payload: Serialized,
863 ) -> Result<(), anyhow::Error> {
864 let message_lines = deserialize_message_lines(&payload)?;
866 let hostname = hostname.as_str();
867
868 match self.aggregate_window_sec {
869 None => {
870 for line in message_lines {
871 Self::print_log_line(hostname, pid, output_target, line);
872 }
873 self.last_flush_time = RealClock.system_time_now();
874 }
875 Some(window) => {
876 for line in message_lines {
877 if let Some(aggregator) = self.aggregators.get_mut(&output_target) {
878 if let Err(e) = aggregator.add_line(&line) {
879 tracing::error!("error adding log line: {}", e);
880 Self::print_log_line(hostname, pid, output_target, line);
882 }
883 } else {
884 tracing::error!("unknown output target: {:?}", output_target);
885 Self::print_log_line(hostname, pid, output_target, line);
887 }
888 }
889
890 let new_deadline = self.last_flush_time + Duration::from_secs(window);
891 let now = RealClock.system_time_now();
892 if new_deadline <= now {
893 self.flush_internal();
894 } else {
895 let delay = new_deadline.duration_since(now)?;
896 match self.next_flush_deadline {
897 None => {
898 self.next_flush_deadline = Some(new_deadline);
899 cx.self_message_with_delay(
900 LogMessage::Flush { sync_version: None },
901 delay,
902 )?;
903 }
904 Some(deadline) => {
905 if new_deadline < deadline {
907 self.next_flush_deadline = Some(new_deadline);
909 cx.self_message_with_delay(
910 LogMessage::Flush { sync_version: None },
911 delay,
912 )?;
913 }
914 }
915 }
916 }
917 }
918 }
919
920 Ok(())
921 }
922
923 async fn flush(
924 &mut self,
925 cx: &Context<Self>,
926 sync_version: Option<u64>,
927 ) -> Result<(), anyhow::Error> {
928 match sync_version {
929 None => {
930 self.flush_internal();
931 }
932 Some(version) => {
933 if version != self.current_flush_version {
934 tracing::error!(
935 "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully",
936 version,
937 self.current_flush_version
938 );
939 return Ok(());
940 }
941
942 if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() {
943 anyhow::bail!("found no ongoing flush request");
945 }
946 self.current_unflushed_procs -= 1;
947
948 tracing::debug!(
949 "ack sync flush: version {}; remaining procs: {}",
950 self.current_flush_version,
951 self.current_unflushed_procs
952 );
953
954 if self.current_unflushed_procs == 0 {
955 self.flush_internal();
956 let reply = self.current_flush_port.take().unwrap();
957 self.current_flush_port = None;
958 reply.send(cx, ()).map_err(anyhow::Error::from)?;
959 }
960 }
961 }
962
963 Ok(())
964 }
965}
966
967#[async_trait]
968#[hyperactor::forward(LogClientMessage)]
969impl LogClientMessageHandler for LogClientActor {
970 async fn set_aggregate(
971 &mut self,
972 _cx: &Context<Self>,
973 aggregate_window_sec: Option<u64>,
974 ) -> Result<(), anyhow::Error> {
975 if self.aggregate_window_sec.is_some() && aggregate_window_sec.is_none() {
976 self.print_aggregators();
978 }
979 self.aggregate_window_sec = aggregate_window_sec;
980 Ok(())
981 }
982
983 async fn start_sync_flush(
984 &mut self,
985 cx: &Context<Self>,
986 expected_procs_flushed: usize,
987 reply: OncePortRef<()>,
988 version: OncePortRef<u64>,
989 ) -> Result<(), anyhow::Error> {
990 if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() {
991 tracing::warn!(
992 "found unfinished ongoing flush: version {}; {} unflushed procs",
993 self.current_flush_version,
994 self.current_unflushed_procs,
995 );
996 }
997
998 self.current_flush_version += 1;
999 tracing::debug!(
1000 "start sync flush with version {}",
1001 self.current_flush_version
1002 );
1003 self.current_flush_port = Some(reply.clone());
1004 self.current_unflushed_procs = expected_procs_flushed;
1005 version
1006 .send(cx, self.current_flush_version)
1007 .map_err(anyhow::Error::from)?;
1008 Ok(())
1009 }
1010}
1011
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::channel;
    use hyperactor::channel::ChannelAddr;
    use hyperactor::channel::ChannelTx;
    use hyperactor::channel::Tx;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::proc::Proc;
    use tokio::io::AsyncWriteExt;
    use tokio::sync::mpsc;

    use super::*;

    /// Spawns a log client and a log forwarder in one proc, then posts log
    /// messages directly into the forwarder's log channel.
    // NOTE(review): this test makes no assertions; it only checks that the
    // setup and posting do not panic.
    #[tokio::test]
    async fn test_forwarding_log_to_client() {
        // Serve a mailbox for the proc so actors in it can be messaged.
        let router = DialMailboxRouter::new();
        let (proc_addr, client_rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
            .await
            .unwrap();
        let proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone()));
        proc.clone().serve(client_rx);
        router.bind(id!(client[0]).into(), proc_addr.clone());
        let client = proc.attach("client").unwrap();

        // The forwarder reads its channel address from this env var in `new`.
        let log_channel = ChannelAddr::any(ChannelTransport::Unix);
        // SAFETY: `set_var` is only sound if no other thread reads or writes the
        // environment concurrently. NOTE(review): other tests in this binary may
        // run in parallel threads — confirm none of them touch the environment.
        unsafe {
            std::env::set_var(BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
        }
        let log_client: ActorRef<LogClientActor> =
            proc.spawn("log_client", ()).await.unwrap().bind();
        let log_forwarder: ActorRef<LogForwardActor> = proc
            .spawn("log_forwarder", log_client)
            .await
            .unwrap()
            .bind();

        let tx: ChannelTx<LogMessage> = channel::dial(log_channel).unwrap();
        // NOTE(review): `stream_to_client` starts out `true` in
        // `LogForwardActor::new`, so this line is presumably streamed despite
        // its text — confirm whether streaming was meant to be disabled first.
        tx.post(LogMessage::Log {
            hostname: "my_host".into(),
            pid: 1,
            output_target: OutputTarget::Stderr,
            payload: Serialized::serialize_anon(&"will not stream".to_string()).unwrap(),
        });

        log_forwarder.set_mode(&client, true).await.unwrap();
        tx.post(LogMessage::Log {
            hostname: "my_host".into(),
            pid: 1,
            output_target: OutputTarget::Stderr,
            payload: Serialized::serialize_anon(&"will stream".to_string()).unwrap(),
        });

    }

    /// Covers `String` payloads, `Vec<u8>` payloads (including non-ASCII),
    /// a single line, blank lines, and invalid UTF-8.
    #[test]
    fn test_deserialize_message_lines_string() {
        let message = "Line 1\nLine 2\nLine 3".to_string();
        let serialized = Serialized::serialize_anon(&message).unwrap();

        let result = deserialize_message_lines(&serialized).unwrap();

        assert_eq!(result, vec!["Line 1", "Line 2", "Line 3"]);

        // Raw bytes containing a multi-byte UTF-8 character.
        let message_bytes = "Hello\nWorld\nUTF-8 \u{1F980}".as_bytes().to_vec();
        let serialized = Serialized::serialize_anon(&message_bytes).unwrap();

        let result = deserialize_message_lines(&serialized).unwrap();

        assert_eq!(result, vec!["Hello", "World", "UTF-8 \u{1F980}"]);

        let message = "Single line message".to_string();
        let serialized = Serialized::serialize_anon(&message).unwrap();

        let result = deserialize_message_lines(&serialized).unwrap();

        assert_eq!(result, vec!["Single line message"]);

        // `str::lines` yields one empty line per '\n' and no trailing empty line.
        let message = "\n\n".to_string();
        let serialized = Serialized::serialize_anon(&message).unwrap();

        let result = deserialize_message_lines(&serialized).unwrap();

        assert_eq!(result, vec!["", ""]);

        // Invalid UTF-8 must surface as an error, not be replaced lossily.
        let invalid_utf8_bytes = vec![0xFF, 0xFE, 0xFD];
        let serialized = Serialized::serialize_anon(&invalid_utf8_bytes).unwrap();

        let result = deserialize_message_lines(&serialized);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid utf-8"));
    }

    /// In-memory `AsyncWrite` that accepts every write in full and records the
    /// bytes in a shared buffer.
    struct MockWriter {
        data: Arc<Mutex<Vec<u8>>>,
    }

    impl MockWriter {
        /// Returns the writer and a handle to the buffer it writes into.
        fn new() -> (Self, Arc<Mutex<Vec<u8>>>) {
            let data = Arc::new(Mutex::new(Vec::new()));
            (Self { data: data.clone() }, data)
        }
    }

    impl io::AsyncWrite for MockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut TaskContext<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let mut data = self.data.lock().unwrap();
            data.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut TaskContext<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut TaskContext<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// `LogSender` that pushes each sent payload (decoded lossily as text) into
    /// an mpsc channel and records whether `flush` was called.
    struct MockLogSender {
        log_sender: mpsc::UnboundedSender<(OutputTarget, String)>,
        flush_called: Arc<Mutex<bool>>,
    }

    impl MockLogSender {
        fn new(log_sender: mpsc::UnboundedSender<(OutputTarget, String)>) -> Self {
            Self {
                log_sender,
                flush_called: Arc::new(Mutex::new(false)),
            }
        }
    }

    #[async_trait]
    impl LogSender for MockLogSender {
        fn send(&mut self, output_target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
            let line = match std::str::from_utf8(&payload) {
                Ok(s) => s.to_string(),
                Err(_) => String::from_utf8_lossy(&payload).to_string(),
            };

            self.log_sender
                .send((output_target, line))
                .map_err(|e| anyhow::anyhow!("Failed to send log in test: {}", e))
        }

        fn flush(&mut self) -> anyhow::Result<()> {
            let mut flush_called = self.flush_called.lock().unwrap();
            *flush_called = true;

            Ok(())
        }
    }

    /// Each write through `LogWriter` is forwarded to the sender unchanged.
    #[tokio::test]
    async fn test_log_writer_direct_forwarding() {
        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();

        let mock_log_sender = MockLogSender::new(log_sender);

        let (mock_writer, _) = MockWriter::new();
        let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);

        let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);

        writer.write_all(b"Hello, world!").await.unwrap();
        writer.flush().await.unwrap();

        let (output_target, content) = log_receiver.recv().await.unwrap();
        assert_eq!(output_target, OutputTarget::Stdout);
        assert_eq!(content, "Hello, world!");

        writer.write_all(b"\nNext line").await.unwrap();
        writer.flush().await.unwrap();

        let (output_target, content) = log_receiver.recv().await.unwrap();
        assert_eq!(output_target, OutputTarget::Stdout);
        assert_eq!(content, "\nNext line");
    }

    /// Writers for stdout and stderr tag forwarded output with their own target.
    #[tokio::test]
    async fn test_log_writer_stdout_stderr() {
        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();

        let stdout_sender = MockLogSender::new(log_sender.clone());
        let stderr_sender = MockLogSender::new(log_sender);

        let (stdout_mock_writer, _) = MockWriter::new();
        let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);

        let (stderr_mock_writer, _) = MockWriter::new();
        let stderr_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stderr_mock_writer);

        let mut stdout_writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, stdout_sender);
        let mut stderr_writer = LogWriter::new(OutputTarget::Stderr, stderr_writer, stderr_sender);

        stdout_writer.write_all(b"Stdout data").await.unwrap();
        stdout_writer.flush().await.unwrap();

        stderr_writer.write_all(b"Stderr data").await.unwrap();
        stderr_writer.flush().await.unwrap();

        // Both senders share one channel, so accept either arrival order.
        let mut received_stdout = false;
        let mut received_stderr = false;

        for _ in 0..2 {
            let (output_target, content) = log_receiver.recv().await.unwrap();
            match output_target {
                OutputTarget::Stdout => {
                    assert_eq!(content, "Stdout data");
                    received_stdout = true;
                }
                OutputTarget::Stderr => {
                    assert_eq!(content, "Stderr data");
                    received_stderr = true;
                }
            }
        }

        assert!(received_stdout);
        assert!(received_stderr);
    }

    /// Non-UTF-8 bytes are still forwarded; the mock decodes them lossily.
    #[tokio::test]
    async fn test_log_writer_binary_data() {
        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();

        let mock_log_sender = MockLogSender::new(log_sender);

        let (mock_writer, _) = MockWriter::new();
        let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);

        let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);

        // "Hello" followed by bytes that are not valid UTF-8.
        let binary_data = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F, 0xFF, 0xFE, 0x00];
        writer.write_all(&binary_data).await.unwrap();
        writer.flush().await.unwrap();

        let (output_target, content) = log_receiver.recv().await.unwrap();
        assert_eq!(output_target, OutputTarget::Stdout);
        assert!(content.starts_with("Hello"));
    }

    /// Flushing a `LogWriter` also flushes its log sender.
    #[tokio::test]
    async fn test_log_writer_poll_flush() {
        let (log_sender, _log_receiver) = mpsc::unbounded_channel();

        let mock_log_sender = MockLogSender::new(log_sender);
        let log_sender_flush_tracker = mock_log_sender.flush_called.clone();

        let (stdout_mock_writer, _) = MockWriter::new();
        let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);

        let mut writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, mock_log_sender);

        writer.flush().await.unwrap();

        assert!(
            *log_sender_flush_tracker.lock().unwrap(),
            "LogSender's flush was not called"
        );
    }

    /// Spot checks of `normalized_edit_distance`, including empty strings.
    #[test]
    fn test_string_similarity() {
        assert_eq!(normalized_edit_distance("hello", "hello"), 0.0);

        assert_eq!(normalized_edit_distance("hello", "i'mdiff"), 1.0);

        assert!(normalized_edit_distance("hello", "helo") < 0.5);
        assert!(normalized_edit_distance("hello", "hello!") < 0.5);

        assert_eq!(normalized_edit_distance("", ""), 0.0);
        assert_eq!(normalized_edit_distance("hello", ""), 1.0);
    }

    #[test]
    fn test_add_line_to_empty_aggregator() {
        let mut aggregator = Aggregator::new();
        let result = aggregator.add_line("ERROR 404 not found");

        assert!(result.is_ok());
        assert_eq!(aggregator.lines.len(), 1);
        assert_eq!(aggregator.lines[0].content, "ERROR 404 not found");
        assert_eq!(aggregator.lines[0].count, 1);
    }

    /// Similar lines merge into an existing group; dissimilar ones start a new group.
    #[test]
    fn test_add_line_merges_with_similar_line() {
        let mut aggregator = Aggregator::new_with_threshold(0.2);

        aggregator.add_line("ERROR 404 timeout").unwrap();
        assert_eq!(aggregator.lines.len(), 1);

        aggregator.add_line("ERROR 500 timeout").unwrap();
        assert_eq!(aggregator.lines.len(), 1);
        assert_eq!(aggregator.lines[0].count, 2);

        aggregator
            .add_line("WARNING database connection failed")
            .unwrap();
        assert_eq!(aggregator.lines.len(), 2);
        aggregator
            .add_line("WARNING database connection timed out")
            .unwrap();
        assert_eq!(aggregator.lines.len(), 2);
        assert_eq!(aggregator.lines[1].count, 2);
    }

    /// Long log lines that differ only in a path and timestamp collapse into one group.
    #[test]
    fn test_aggregation_of_similar_log_lines() {
        let mut aggregator = Aggregator::new_with_threshold(0.2);

        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/i/ims-ansi, falling back to search/replace to update prefixes for this file.").unwrap();
        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,351 conda-unpack-fb:292] Found invalid offsets for lib/pkgconfig/ncursesw.pc, falling back to search/replace to update prefixes for this file.").unwrap();
        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/k/kt7, falling back to search/replace to update prefixes for this file.").unwrap();

        assert_eq!(aggregator.lines.len(), 1);

        assert_eq!(aggregator.lines[0].count, 3);
    }
}