1use std::collections::HashMap;
10use std::fmt;
11use std::path::Path;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::Context as TaskContext;
16use std::task::Poll;
17use std::time::Duration;
18use std::time::SystemTime;
19
20use anyhow::Result;
21use async_trait::async_trait;
22use chrono::DateTime;
23use chrono::Local;
24use hyperactor::Actor;
25use hyperactor::ActorRef;
26use hyperactor::Bind;
27use hyperactor::Context;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::Named;
32use hyperactor::OncePortRef;
33use hyperactor::RefClient;
34use hyperactor::Unbind;
35use hyperactor::channel;
36use hyperactor::channel::ChannelAddr;
37use hyperactor::channel::ChannelRx;
38use hyperactor::channel::ChannelTransport;
39use hyperactor::channel::ChannelTx;
40use hyperactor::channel::Rx;
41use hyperactor::channel::Tx;
42use hyperactor::channel::TxStatus;
43use hyperactor::clock::Clock;
44use hyperactor::clock::RealClock;
45use hyperactor::data::Serialized;
46use hyperactor_telemetry::env;
47use hyperactor_telemetry::log_file_path;
48use serde::Deserialize;
49use serde::Serialize;
50use tokio::io;
51use tokio::sync::Mutex;
52use tokio::sync::watch::Receiver;
53
54use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
55
56mod line_prefixing_writer;
57use line_prefixing_writer::LinePrefixingWriter;
58
/// Default aggregation window, in seconds; a new `LogClientActor` starts with
/// `aggregate_window_sec` set to this value.
pub(crate) const DEFAULT_AGGREGATE_WINDOW_SEC: u64 = 5;
60
/// Computes the Levenshtein (edit) distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions needed to turn `left` into `right`. Comparison
/// is done per `char` (Unicode scalar value), not per byte.
///
/// Only two rows of the dynamic-programming table are kept alive at a time, so
/// the extra memory is proportional to the length of `right` rather than to the
/// product of both lengths.
fn levenshtein_distance(left: &str, right: &str) -> usize {
    let left_chars: Vec<char> = left.chars().collect();
    let right_chars: Vec<char> = right.chars().collect();

    // Turning an empty string into the other one takes one insertion per char.
    if left_chars.is_empty() {
        return right_chars.len();
    }
    if right_chars.is_empty() {
        return left_chars.len();
    }

    // `previous[j]` is the distance between the already-processed prefix of
    // `left` and the first `j` chars of `right`; it starts as the row for the
    // empty prefix (j insertions).
    let mut previous: Vec<usize> = (0..=right_chars.len()).collect();
    let mut current: Vec<usize> = vec![0; right_chars.len() + 1];

    for (i, &left_char) in left_chars.iter().enumerate() {
        // Distance from the first `i + 1` chars of `left` to the empty string.
        current[0] = i + 1;

        for (j, &right_char) in right_chars.iter().enumerate() {
            let substitution = previous[j] + usize::from(left_char != right_char);
            let deletion = previous[j + 1] + 1;
            let insertion = current[j] + 1;
            current[j + 1] = substitution.min(deletion).min(insertion);
        }

        std::mem::swap(&mut previous, &mut current);
    }

    // After the final swap the last completed row lives in `previous`.
    previous[right_chars.len()]
}
110
111fn normalized_edit_distance(left: &str, right: &str) -> f64 {
113 let distance = levenshtein_distance(left, right) as f64;
114 let max_len = std::cmp::max(left.len(), right.len()) as f64;
115
116 if max_len == 0.0 {
117 0.0 } else {
119 distance / max_len
120 }
121}
122
/// A representative log line together with the number of lines that were
/// folded into it.
#[derive(Debug, Clone)]
struct LogLine {
    /// Text of the first line seen for this group.
    content: String,
    /// How many lines have been merged into this entry (starts at 1).
    pub count: u64,
}

impl LogLine {
    /// Creates an entry for a line that has been seen exactly once.
    fn new(content: String) -> Self {
        LogLine { count: 1, content }
    }
}

impl fmt::Display for LogLine {
    /// Renders the line prefixed by a yellow (ANSI) "similar log lines" counter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LogLine { content, count } = self;
        write!(f, "\x1b[33m[{} similar log lines]\x1b[0m {}", count, content)
    }
}
145
146#[derive(Debug, Clone)]
147struct Aggregator {
150 lines: Vec<LogLine>,
151 start_time: SystemTime,
152 similarity_threshold: f64, }
154
155impl Aggregator {
156 fn new() -> Self {
157 Self::new_with_threshold(0.15)
159 }
160
161 fn new_with_threshold(threshold: f64) -> Self {
162 Aggregator {
163 lines: vec![],
164 start_time: RealClock.system_time_now(),
165 similarity_threshold: threshold,
166 }
167 }
168
169 fn reset(&mut self) {
170 self.lines.clear();
171 self.start_time = RealClock.system_time_now();
172 }
173
174 fn add_line(&mut self, line: &str) -> anyhow::Result<()> {
175 let mut best_match_idx = None;
177 let mut best_similarity = f64::MAX;
178
179 for (idx, existing_line) in self.lines.iter().enumerate() {
180 let distance = normalized_edit_distance(&existing_line.content, line);
181
182 if distance < best_similarity && distance < self.similarity_threshold {
184 best_match_idx = Some(idx);
185 best_similarity = distance;
186 }
187 }
188
189 if let Some(idx) = best_match_idx {
191 self.lines[idx].count += 1;
192 } else {
193 self.lines.push(LogLine::new(line.to_string()));
195 }
196
197 Ok(())
198 }
199
200 fn is_empty(&self) -> bool {
201 self.lines.is_empty()
202 }
203}
204
205fn format_system_time(time: SystemTime) -> String {
207 let datetime: DateTime<Local> = time.into();
208 datetime.format("%Y-%m-%d %H:%M:%S").to_string()
209}
210
211impl fmt::Display for Aggregator {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 let start_time_str = format_system_time(self.start_time);
215
216 let current_time = RealClock.system_time_now();
218 let end_time_str = format_system_time(current_time);
219
220 writeln!(
222 f,
223 "\x1b[36m>>> Aggregated Logs ({}) >>>\x1b[0m",
224 start_time_str
225 )?;
226
227 for line in self.lines.iter() {
229 writeln!(f, "{}", line)?;
230 }
231 writeln!(
232 f,
233 "\x1b[36m<<< Aggregated Logs ({}) <<<\x1b[0m",
234 end_time_str
235 )?;
236 Ok(())
237 }
238}
239
240#[derive(
242 Debug,
243 Clone,
244 Serialize,
245 Deserialize,
246 Named,
247 Handler,
248 HandleClient,
249 RefClient
250)]
251pub enum LogMessage {
252 Log {
254 hostname: String,
256 pid: u32,
258 output_target: OutputTarget,
260 payload: Serialized,
262 },
263
264 Flush {
266 sync_version: Option<u64>,
269 },
270}
271
272#[derive(
274 Debug,
275 Clone,
276 Serialize,
277 Deserialize,
278 Named,
279 Handler,
280 HandleClient,
281 RefClient
282)]
283pub enum LogClientMessage {
284 SetAggregate {
285 aggregate_window_sec: Option<u64>,
287 },
288
289 StartSyncFlush {
291 expected_procs: usize,
293 reply: OncePortRef<()>,
295 version: OncePortRef<u64>,
297 },
298}
299
300#[async_trait]
302pub trait LogSender: Send + Sync {
303 fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()>;
305
306 fn flush(&mut self) -> anyhow::Result<()>;
309}
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
313pub enum OutputTarget {
314 Stdout,
316 Stderr,
318}
319
320pub struct LocalLogSender {
322 hostname: String,
323 pid: u32,
324 tx: ChannelTx<LogMessage>,
325 status: Receiver<TxStatus>,
326}
327
328impl LocalLogSender {
329 fn new(log_channel: ChannelAddr, pid: u32) -> Result<Self, anyhow::Error> {
330 let tx = channel::dial::<LogMessage>(log_channel)?;
331 let status = tx.status().clone();
332
333 let hostname = hostname::get()
334 .unwrap_or_else(|_| "unknown_host".into())
335 .into_string()
336 .unwrap_or("unknown_host".to_string());
337 Ok(Self {
338 hostname,
339 pid,
340 tx,
341 status,
342 })
343 }
344}
345
346#[async_trait]
347impl LogSender for LocalLogSender {
348 fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
349 if TxStatus::Active == *self.status.borrow() {
350 self.tx.post(LogMessage::Log {
352 hostname: self.hostname.clone(),
353 pid: self.pid,
354 output_target: target,
355 payload: Serialized::serialize(&payload)?,
356 });
357 } else {
358 tracing::debug!(
359 "log sender {} is not active, skip sending log",
360 self.tx.addr()
361 )
362 }
363
364 Ok(())
365 }
366
367 fn flush(&mut self) -> anyhow::Result<()> {
368 if TxStatus::Active == *self.status.borrow() {
370 self.tx.post(LogMessage::Flush { sync_version: None });
372 } else {
373 tracing::debug!(
374 "log sender {} is not active, skip sending flush message",
375 self.tx.addr()
376 );
377 }
378 Ok(())
379 }
380}
381
382pub struct LogWriter<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> {
385 output_target: OutputTarget,
386 std_writer: S,
387 log_sender: T,
388}
389
390fn create_file_writer(
391 local_rank: usize,
392 output_target: OutputTarget,
393 env: env::Env,
394) -> Result<Box<dyn io::AsyncWrite + Send + Unpin + 'static>> {
395 let suffix = match output_target {
396 OutputTarget::Stderr => "stderr",
397 OutputTarget::Stdout => "stdout",
398 };
399 let (path, filename) = log_file_path(env)?;
400 let path = Path::new(&path);
401 let mut full_path = PathBuf::from(path);
402
403 let file_created_by_pid = std::process::id();
409
410 full_path.push(format!(
411 "{}_{}_{}.{}",
412 filename, file_created_by_pid, local_rank, suffix
413 ));
414 let file = std::fs::OpenOptions::new()
415 .create(true)
416 .append(true)
417 .open(full_path)?;
418 let tokio_file = tokio::fs::File::from_std(file);
419 Ok(Box::new(tokio_file))
421}
422
423fn get_local_log_destination(
424 local_rank: usize,
425 output_target: OutputTarget,
426) -> Result<Box<dyn io::AsyncWrite + Send + Unpin>> {
427 let env: env::Env = env::Env::current();
428 Ok(match env {
429 env::Env::Test | env::Env::Local => match output_target {
430 OutputTarget::Stdout => Box::new(LinePrefixingWriter::new(local_rank, io::stdout())),
431 OutputTarget::Stderr => Box::new(LinePrefixingWriter::new(local_rank, io::stderr())),
432 },
433 env::Env::MastEmulator | env::Env::Mast => {
434 create_file_writer(local_rank, output_target, env)?
435 }
436 })
437}
438
439pub fn create_log_writers(
450 local_rank: usize,
451 log_channel: ChannelAddr,
452 pid: u32,
453) -> Result<
454 (
455 Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
456 Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
457 ),
458 anyhow::Error,
459> {
460 let stdout_writer = LogWriter::with_default_writer(
462 local_rank,
463 OutputTarget::Stdout,
464 LocalLogSender::new(log_channel.clone(), pid)?,
465 )?;
466 let stderr_writer = LogWriter::with_default_writer(
467 local_rank,
468 OutputTarget::Stderr,
469 LocalLogSender::new(log_channel, pid)?,
470 )?;
471
472 Ok((Box::new(stdout_writer), Box::new(stderr_writer)))
473}
474
475impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> LogWriter<T, S> {
476 pub fn new(output_target: OutputTarget, std_writer: S, log_sender: T) -> Self {
484 Self {
485 output_target,
486 std_writer,
487 log_sender,
488 }
489 }
490}
491
492impl<T: LogSender + Unpin + 'static> LogWriter<T, Box<dyn io::AsyncWrite + Send + Unpin>> {
493 pub fn with_default_writer(
500 local_rank: usize,
501 output_target: OutputTarget,
502 log_sender: T,
503 ) -> Result<Self> {
504 let std_writer = get_local_log_destination(local_rank, output_target)?;
506
507 Ok(Self {
508 output_target,
509 std_writer,
510 log_sender,
511 })
512 }
513}
514
515impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> io::AsyncWrite
516 for LogWriter<T, S>
517{
518 fn poll_write(
519 self: Pin<&mut Self>,
520 cx: &mut TaskContext<'_>,
521 buf: &[u8],
522 ) -> Poll<Result<usize, io::Error>> {
523 let this = self.get_mut();
525
526 match Pin::new(&mut this.std_writer).poll_write(cx, buf) {
528 Poll::Ready(Ok(_)) => {
529 let output_target = this.output_target;
531 let data_to_send = buf.to_vec();
532
533 if let Err(e) = this.log_sender.send(output_target, data_to_send) {
536 tracing::error!("error sending log: {}", e);
537 }
538 Poll::Ready(Ok(buf.len()))
540 }
541 other => other, }
543 }
544
545 fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Result<(), io::Error>> {
546 let this = self.get_mut();
547
548 match Pin::new(&mut this.std_writer).poll_flush(cx) {
549 Poll::Ready(Ok(())) => {
550 if let Err(e) = this.log_sender.flush() {
551 tracing::error!("error sending flush: {}", e);
552 }
553 Poll::Ready(Ok(()))
554 }
555 other => other, }
557 }
558
559 fn poll_shutdown(
560 self: Pin<&mut Self>,
561 cx: &mut TaskContext<'_>,
562 ) -> Poll<Result<(), io::Error>> {
563 let this = self.get_mut();
564 Pin::new(&mut this.std_writer).poll_shutdown(cx)
565 }
566}
567
568#[derive(
570 Debug,
571 Clone,
572 Serialize,
573 Deserialize,
574 Named,
575 Handler,
576 HandleClient,
577 RefClient,
578 Bind,
579 Unbind
580)]
581pub enum LogForwardMessage {
582 Forward {},
584
585 SetMode { stream_to_client: bool },
587
588 ForceSyncFlush { version: u64 },
590}
591
592#[derive(Debug)]
594#[hyperactor::export(
595 spawn = true,
596 handlers = [LogForwardMessage {cast = true}],
597)]
598pub struct LogForwardActor {
599 rx: ChannelRx<LogMessage>,
600 flush_tx: Arc<Mutex<ChannelTx<LogMessage>>>,
601 next_flush_deadline: SystemTime,
602 logging_client_ref: ActorRef<LogClientActor>,
603 stream_to_client: bool,
604}
605
606#[async_trait]
607impl Actor for LogForwardActor {
608 type Params = ActorRef<LogClientActor>;
609
610 async fn new(logging_client_ref: Self::Params) -> Result<Self> {
611 let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) {
612 Ok(channel) => channel.parse()?,
613 Err(err) => {
614 tracing::debug!(
615 "log forwarder actor failed to read env var {}: {}",
616 BOOTSTRAP_LOG_CHANNEL,
617 err
618 );
619 ChannelAddr::any(ChannelTransport::Unix)
621 }
622 };
623 tracing::info!(
624 "log forwarder {} serve at {}",
625 std::process::id(),
626 log_channel
627 );
628
629 let rx = match channel::serve(log_channel.clone()) {
630 Ok((_, rx)) => rx,
631 Err(err) => {
632 tracing::error!(
635 "log forwarder actor failed to bootstrap on given channel {}: {}",
636 log_channel,
637 err
638 );
639 channel::serve(ChannelAddr::any(ChannelTransport::Unix))?.1
640 }
641 };
642
643 let flush_tx = Arc::new(Mutex::new(channel::dial::<LogMessage>(log_channel)?));
645 let now = RealClock.system_time_now();
646
647 Ok(Self {
648 rx,
649 flush_tx,
650 next_flush_deadline: now,
651 logging_client_ref,
652 stream_to_client: true,
653 })
654 }
655
656 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
657 this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
658
659 self.flush_tx
661 .lock()
662 .await
663 .send(LogMessage::Flush { sync_version: None })
664 .await?;
665 Ok(())
666 }
667}
668
669#[async_trait]
670#[hyperactor::forward(LogForwardMessage)]
671impl LogForwardMessageHandler for LogForwardActor {
672 async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
673 match self.rx.recv().await {
674 Ok(LogMessage::Flush { sync_version }) => {
675 let now = RealClock.system_time_now();
676 match sync_version {
677 None => {
678 let delay = Duration::from_secs(1);
680 if now >= self.next_flush_deadline {
681 self.next_flush_deadline = now + delay;
682 let flush_tx = self.flush_tx.clone();
683 tokio::spawn(async move {
684 RealClock.sleep(delay).await;
685 if let Err(e) = flush_tx
686 .lock()
687 .await
688 .send(LogMessage::Flush { sync_version: None })
689 .await
690 {
691 tracing::error!("failed to send flush message: {}", e);
692 }
693 });
694 }
695 }
696 version => {
697 self.logging_client_ref.flush(ctx, version).await?;
698 }
699 }
700 }
701 Ok(LogMessage::Log {
702 hostname,
703 pid,
704 output_target,
705 payload,
706 }) => {
707 if self.stream_to_client {
708 self.logging_client_ref
709 .log(ctx, hostname, pid, output_target, payload)
710 .await?;
711 }
712 }
713 Err(e) => {
714 return Err(e.into());
715 }
716 }
717
718 ctx.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
720
721 Ok(())
722 }
723
724 async fn set_mode(
725 &mut self,
726 _ctx: &Context<Self>,
727 stream_to_client: bool,
728 ) -> Result<(), anyhow::Error> {
729 self.stream_to_client = stream_to_client;
730 Ok(())
731 }
732
733 async fn force_sync_flush(
734 &mut self,
735 _cx: &Context<Self>,
736 version: u64,
737 ) -> Result<(), anyhow::Error> {
738 self.flush_tx
739 .lock()
740 .await
741 .send(LogMessage::Flush {
742 sync_version: Some(version),
743 })
744 .await
745 .map_err(anyhow::Error::from)
746 }
747}
748
749fn deserialize_message_lines(
751 serialized_message: &hyperactor::data::Serialized,
752) -> Result<Vec<String>> {
753 if let Ok(message_str) = serialized_message.deserialized::<String>() {
755 return Ok(message_str.lines().map(|s| s.to_string()).collect());
756 }
757
758 if let Ok(message_bytes) = serialized_message.deserialized::<Vec<u8>>() {
760 let message_str = String::from_utf8(message_bytes)?;
761 return Ok(message_str.lines().map(|s| s.to_string()).collect());
762 }
763
764 anyhow::bail!("failed to deserialize message as either String or Vec<u8>")
766}
767
768#[derive(Debug)]
770#[hyperactor::export(
771 spawn = true,
772 handlers = [LogMessage, LogClientMessage],
773)]
774pub struct LogClientActor {
775 aggregate_window_sec: Option<u64>,
776 aggregators: HashMap<OutputTarget, Aggregator>,
777 last_flush_time: SystemTime,
778 next_flush_deadline: Option<SystemTime>,
779
780 current_flush_version: u64,
782 current_flush_port: Option<OncePortRef<()>>,
783 current_unflushed_procs: usize,
784}
785
786impl LogClientActor {
787 fn print_aggregators(&mut self) {
788 for (output_target, aggregator) in self.aggregators.iter_mut() {
789 if aggregator.is_empty() {
790 continue;
791 }
792 match output_target {
793 OutputTarget::Stdout => {
794 println!("{}", aggregator);
795 }
796 OutputTarget::Stderr => {
797 eprintln!("{}", aggregator);
798 }
799 }
800
801 aggregator.reset();
803 }
804 }
805
806 fn print_log_line(hostname: &str, pid: u32, output_target: OutputTarget, line: String) {
807 let message = format!("[{} {}] {}", hostname, pid, line);
808
809 #[cfg(test)]
810 crate::logging::test_tap::push(&message);
811
812 match output_target {
813 OutputTarget::Stdout => println!("{}", message),
814 OutputTarget::Stderr => eprintln!("{}", message),
815 }
816 }
817
818 fn flush_internal(&mut self) {
819 self.print_aggregators();
820 self.last_flush_time = RealClock.system_time_now();
821 self.next_flush_deadline = None;
822 }
823}
824
825#[async_trait]
826impl Actor for LogClientActor {
827 type Params = ();
829
830 async fn new(_: ()) -> Result<Self, anyhow::Error> {
831 let mut aggregators = HashMap::new();
833 aggregators.insert(OutputTarget::Stderr, Aggregator::new());
834 aggregators.insert(OutputTarget::Stdout, Aggregator::new());
835
836 Ok(Self {
837 aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC),
838 aggregators,
839 last_flush_time: RealClock.system_time_now(),
840 next_flush_deadline: None,
841 current_flush_version: 0,
842 current_flush_port: None,
843 current_unflushed_procs: 0,
844 })
845 }
846}
847
848impl Drop for LogClientActor {
849 fn drop(&mut self) {
850 self.print_aggregators();
852 }
853}
854
855#[async_trait]
856#[hyperactor::forward(LogMessage)]
857impl LogMessageHandler for LogClientActor {
858 async fn log(
859 &mut self,
860 cx: &Context<Self>,
861 hostname: String,
862 pid: u32,
863 output_target: OutputTarget,
864 payload: Serialized,
865 ) -> Result<(), anyhow::Error> {
866 let message_lines = deserialize_message_lines(&payload)?;
868 let hostname = hostname.as_str();
869
870 match self.aggregate_window_sec {
871 None => {
872 for line in message_lines {
873 Self::print_log_line(hostname, pid, output_target, line);
874 }
875 self.last_flush_time = RealClock.system_time_now();
876 }
877 Some(window) => {
878 for line in message_lines {
879 if let Some(aggregator) = self.aggregators.get_mut(&output_target) {
880 if let Err(e) = aggregator.add_line(&line) {
881 tracing::error!("error adding log line: {}", e);
882 Self::print_log_line(hostname, pid, output_target, line);
884 }
885 } else {
886 tracing::error!("unknown output target: {:?}", output_target);
887 Self::print_log_line(hostname, pid, output_target, line);
889 }
890 }
891
892 let new_deadline = self.last_flush_time + Duration::from_secs(window);
893 let now = RealClock.system_time_now();
894 if new_deadline <= now {
895 self.flush_internal();
896 } else {
897 let delay = new_deadline.duration_since(now)?;
898 match self.next_flush_deadline {
899 None => {
900 self.next_flush_deadline = Some(new_deadline);
901 cx.self_message_with_delay(
902 LogMessage::Flush { sync_version: None },
903 delay,
904 )?;
905 }
906 Some(deadline) => {
907 if new_deadline < deadline {
909 self.next_flush_deadline = Some(new_deadline);
911 cx.self_message_with_delay(
912 LogMessage::Flush { sync_version: None },
913 delay,
914 )?;
915 }
916 }
917 }
918 }
919 }
920 }
921
922 Ok(())
923 }
924
925 async fn flush(
926 &mut self,
927 cx: &Context<Self>,
928 sync_version: Option<u64>,
929 ) -> Result<(), anyhow::Error> {
930 match sync_version {
931 None => {
932 self.flush_internal();
933 }
934 Some(version) => {
935 if version != self.current_flush_version {
936 tracing::error!(
937 "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully",
938 version,
939 self.current_flush_version
940 );
941 return Ok(());
942 }
943
944 if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() {
945 anyhow::bail!("found no ongoing flush request");
947 }
948 self.current_unflushed_procs -= 1;
949
950 tracing::debug!(
951 "ack sync flush: version {}; remaining procs: {}",
952 self.current_flush_version,
953 self.current_unflushed_procs
954 );
955
956 if self.current_unflushed_procs == 0 {
957 self.flush_internal();
958 let reply = self.current_flush_port.take().unwrap();
959 self.current_flush_port = None;
960 reply.send(cx, ()).map_err(anyhow::Error::from)?;
961 }
962 }
963 }
964
965 Ok(())
966 }
967}
968
969#[async_trait]
970#[hyperactor::forward(LogClientMessage)]
971impl LogClientMessageHandler for LogClientActor {
972 async fn set_aggregate(
973 &mut self,
974 _cx: &Context<Self>,
975 aggregate_window_sec: Option<u64>,
976 ) -> Result<(), anyhow::Error> {
977 if self.aggregate_window_sec.is_some() && aggregate_window_sec.is_none() {
978 self.print_aggregators();
980 }
981 self.aggregate_window_sec = aggregate_window_sec;
982 Ok(())
983 }
984
985 async fn start_sync_flush(
986 &mut self,
987 cx: &Context<Self>,
988 expected_procs_flushed: usize,
989 reply: OncePortRef<()>,
990 version: OncePortRef<u64>,
991 ) -> Result<(), anyhow::Error> {
992 if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() {
993 tracing::warn!(
994 "found unfinished ongoing flush: version {}; {} unflushed procs",
995 self.current_flush_version,
996 self.current_unflushed_procs,
997 );
998 }
999
1000 self.current_flush_version += 1;
1001 tracing::debug!(
1002 "start sync flush with version {}",
1003 self.current_flush_version
1004 );
1005 self.current_flush_port = Some(reply.clone());
1006 self.current_unflushed_procs = expected_procs_flushed;
1007 version
1008 .send(cx, self.current_flush_version)
1009 .map_err(anyhow::Error::from)?;
1010 Ok(())
1011 }
1012}
1013
/// Test-only tap that captures the lines printed by `LogClientActor`.
#[cfg(test)]
pub mod test_tap {
    use std::sync::Mutex;
    use std::sync::OnceLock;

    use tokio::sync::mpsc::UnboundedReceiver;
    use tokio::sync::mpsc::UnboundedSender;

    /// Sending half of the tap; set at most once per process.
    static TAP: OnceLock<UnboundedSender<String>> = OnceLock::new();
    /// Receiving half of the tap; set at most once per process.
    static RX: OnceLock<Mutex<UnboundedReceiver<String>>> = OnceLock::new();

    /// Installs the sender; later calls are ignored once one is installed.
    pub fn install(tx: UnboundedSender<String>) {
        TAP.set(tx).ok();
    }

    /// Registers the receiver; later calls are ignored once one is registered.
    pub fn set_receiver(rx: UnboundedReceiver<String>) {
        RX.set(Mutex::new(rx)).ok();
    }

    /// Records `s` if a sender is installed; send failures are ignored.
    pub fn push(s: &str) {
        let Some(tx) = TAP.get() else {
            return;
        };
        tx.send(s.to_string()).ok();
    }

    /// Returns every line captured so far without blocking; empty when no
    /// receiver has been registered.
    pub fn drain() -> Vec<String> {
        let Some(rx) = RX.get() else {
            return Vec::new();
        };
        let mut rx = rx.lock().unwrap();
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }
}
1054
1055#[cfg(test)]
1056mod tests {
1057 use std::sync::Arc;
1058 use std::sync::Mutex;
1059
1060 use hyperactor::channel;
1061 use hyperactor::channel::ChannelAddr;
1062 use hyperactor::channel::ChannelTx;
1063 use hyperactor::channel::Tx;
1064 use hyperactor::id;
1065 use hyperactor::mailbox::BoxedMailboxSender;
1066 use hyperactor::mailbox::DialMailboxRouter;
1067 use hyperactor::mailbox::MailboxServer;
1068 use hyperactor::proc::Proc;
1069 use tokio::io::AsyncWriteExt;
1070 use tokio::sync::mpsc;
1071
1072 use super::*;
1073
1074 #[tokio::test]
1075 async fn test_forwarding_log_to_client() {
1076 let router = DialMailboxRouter::new();
1078 let (proc_addr, client_rx) =
1079 channel::serve(ChannelAddr::any(ChannelTransport::Unix)).unwrap();
1080 let proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone()));
1081 proc.clone().serve(client_rx);
1082 router.bind(id!(client[0]).into(), proc_addr.clone());
1083 let (client, _handle) = proc.instance("client").unwrap();
1084
1085 let log_channel = ChannelAddr::any(ChannelTransport::Unix);
1087 unsafe {
1089 std::env::set_var(BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
1090 }
1091 let log_client: ActorRef<LogClientActor> =
1092 proc.spawn("log_client", ()).await.unwrap().bind();
1093 let log_forwarder: ActorRef<LogForwardActor> = proc
1094 .spawn("log_forwarder", log_client)
1095 .await
1096 .unwrap()
1097 .bind();
1098
1099 let tx: ChannelTx<LogMessage> = channel::dial(log_channel).unwrap();
1101 tx.post(LogMessage::Log {
1102 hostname: "my_host".into(),
1103 pid: 1,
1104 output_target: OutputTarget::Stderr,
1105 payload: Serialized::serialize(&"will not stream".to_string()).unwrap(),
1106 });
1107
1108 log_forwarder.set_mode(&client, true).await.unwrap();
1110 tx.post(LogMessage::Log {
1111 hostname: "my_host".into(),
1112 pid: 1,
1113 output_target: OutputTarget::Stderr,
1114 payload: Serialized::serialize(&"will stream".to_string()).unwrap(),
1115 });
1116
1117 }
1119
1120 #[test]
1121 fn test_deserialize_message_lines_string() {
1122 let message = "Line 1\nLine 2\nLine 3".to_string();
1124 let serialized = Serialized::serialize(&message).unwrap();
1125
1126 let result = deserialize_message_lines(&serialized).unwrap();
1127
1128 assert_eq!(result, vec!["Line 1", "Line 2", "Line 3"]);
1129
1130 let message_bytes = "Hello\nWorld\nUTF-8 \u{1F980}".as_bytes().to_vec();
1132 let serialized = Serialized::serialize(&message_bytes).unwrap();
1133
1134 let result = deserialize_message_lines(&serialized).unwrap();
1135
1136 assert_eq!(result, vec!["Hello", "World", "UTF-8 \u{1F980}"]);
1137
1138 let message = "Single line message".to_string();
1140 let serialized = Serialized::serialize(&message).unwrap();
1141
1142 let result = deserialize_message_lines(&serialized).unwrap();
1143
1144 assert_eq!(result, vec!["Single line message"]);
1145
1146 let message = "\n\n".to_string();
1148 let serialized = Serialized::serialize(&message).unwrap();
1149
1150 let result = deserialize_message_lines(&serialized).unwrap();
1151
1152 assert_eq!(result, vec!["", ""]);
1153
1154 let invalid_utf8_bytes = vec![0xFF, 0xFE, 0xFD]; let serialized = Serialized::serialize_as::<Vec<u8>, _>(&invalid_utf8_bytes).unwrap();
1157
1158 let result = deserialize_message_lines(&serialized);
1159
1160 assert!(result.is_err());
1161 let message = result.unwrap_err().to_string();
1162 assert!(message.contains("invalid utf-8"), "{}", message);
1163 }
1164
1165 struct MockWriter {
1167 data: Arc<Mutex<Vec<u8>>>,
1168 }
1169
1170 impl MockWriter {
1171 fn new() -> (Self, Arc<Mutex<Vec<u8>>>) {
1172 let data = Arc::new(Mutex::new(Vec::new()));
1173 (Self { data: data.clone() }, data)
1174 }
1175 }
1176
1177 impl io::AsyncWrite for MockWriter {
1178 fn poll_write(
1179 self: Pin<&mut Self>,
1180 _cx: &mut TaskContext<'_>,
1181 buf: &[u8],
1182 ) -> Poll<Result<usize, io::Error>> {
1183 let mut data = self.data.lock().unwrap();
1184 data.extend_from_slice(buf);
1185 Poll::Ready(Ok(buf.len()))
1186 }
1187
1188 fn poll_flush(
1189 self: Pin<&mut Self>,
1190 _cx: &mut TaskContext<'_>,
1191 ) -> Poll<Result<(), io::Error>> {
1192 Poll::Ready(Ok(()))
1193 }
1194
1195 fn poll_shutdown(
1196 self: Pin<&mut Self>,
1197 _cx: &mut TaskContext<'_>,
1198 ) -> Poll<Result<(), io::Error>> {
1199 Poll::Ready(Ok(()))
1200 }
1201 }
1202
1203 struct MockLogSender {
1205 log_sender: mpsc::UnboundedSender<(OutputTarget, String)>, flush_called: Arc<Mutex<bool>>, }
1208
1209 impl MockLogSender {
1210 fn new(log_sender: mpsc::UnboundedSender<(OutputTarget, String)>) -> Self {
1211 Self {
1212 log_sender,
1213 flush_called: Arc::new(Mutex::new(false)),
1214 }
1215 }
1216 }
1217
1218 #[async_trait]
1219 impl LogSender for MockLogSender {
1220 fn send(&mut self, output_target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
1221 let line = match std::str::from_utf8(&payload) {
1223 Ok(s) => s.to_string(),
1224 Err(_) => String::from_utf8_lossy(&payload).to_string(),
1225 };
1226
1227 self.log_sender
1228 .send((output_target, line))
1229 .map_err(|e| anyhow::anyhow!("Failed to send log in test: {}", e))
1230 }
1231
1232 fn flush(&mut self) -> anyhow::Result<()> {
1233 let mut flush_called = self.flush_called.lock().unwrap();
1235 *flush_called = true;
1236
1237 Ok(())
1240 }
1241 }
1242
1243 #[tokio::test]
1244 async fn test_log_writer_direct_forwarding() {
1245 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1247
1248 let mock_log_sender = MockLogSender::new(log_sender);
1250
1251 let (mock_writer, _) = MockWriter::new();
1253 let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);
1254
1255 let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);
1257
1258 writer.write_all(b"Hello, world!").await.unwrap();
1260 writer.flush().await.unwrap();
1261
1262 let (output_target, content) = log_receiver.recv().await.unwrap();
1264 assert_eq!(output_target, OutputTarget::Stdout);
1265 assert_eq!(content, "Hello, world!");
1266
1267 writer.write_all(b"\nNext line").await.unwrap();
1269 writer.flush().await.unwrap();
1270
1271 let (output_target, content) = log_receiver.recv().await.unwrap();
1273 assert_eq!(output_target, OutputTarget::Stdout);
1274 assert_eq!(content, "\nNext line");
1275 }
1276
1277 #[tokio::test]
1278 async fn test_log_writer_stdout_stderr() {
1279 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1281
1282 let stdout_sender = MockLogSender::new(log_sender.clone());
1284 let stderr_sender = MockLogSender::new(log_sender);
1285
1286 let (stdout_mock_writer, _) = MockWriter::new();
1288 let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1289
1290 let (stderr_mock_writer, _) = MockWriter::new();
1291 let stderr_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stderr_mock_writer);
1292
1293 let mut stdout_writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, stdout_sender);
1295 let mut stderr_writer = LogWriter::new(OutputTarget::Stderr, stderr_writer, stderr_sender);
1296
1297 stdout_writer.write_all(b"Stdout data").await.unwrap();
1299 stdout_writer.flush().await.unwrap();
1300
1301 stderr_writer.write_all(b"Stderr data").await.unwrap();
1302 stderr_writer.flush().await.unwrap();
1303
1304 let mut received_stdout = false;
1307 let mut received_stderr = false;
1308
1309 for _ in 0..2 {
1310 let (output_target, content) = log_receiver.recv().await.unwrap();
1311 match output_target {
1312 OutputTarget::Stdout => {
1313 assert_eq!(content, "Stdout data");
1314 received_stdout = true;
1315 }
1316 OutputTarget::Stderr => {
1317 assert_eq!(content, "Stderr data");
1318 received_stderr = true;
1319 }
1320 }
1321 }
1322
1323 assert!(received_stdout);
1324 assert!(received_stderr);
1325 }
1326
1327 #[tokio::test]
1328 async fn test_log_writer_binary_data() {
1329 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1331
1332 let mock_log_sender = MockLogSender::new(log_sender);
1334
1335 let (mock_writer, _) = MockWriter::new();
1337 let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);
1338
1339 let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);
1341
1342 let binary_data = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F, 0xFF, 0xFE, 0x00];
1344 writer.write_all(&binary_data).await.unwrap();
1345 writer.flush().await.unwrap();
1346
1347 let (output_target, content) = log_receiver.recv().await.unwrap();
1349 assert_eq!(output_target, OutputTarget::Stdout);
1350 assert!(content.starts_with("Hello"));
1352 }
1354
1355 #[tokio::test]
1356 async fn test_log_writer_poll_flush() {
1357 let (log_sender, _log_receiver) = mpsc::unbounded_channel();
1359
1360 let mock_log_sender = MockLogSender::new(log_sender);
1362 let log_sender_flush_tracker = mock_log_sender.flush_called.clone();
1363
1364 let (stdout_mock_writer, _) = MockWriter::new();
1366 let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1367
1368 let mut writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, mock_log_sender);
1370
1371 writer.flush().await.unwrap();
1373
1374 assert!(
1376 *log_sender_flush_tracker.lock().unwrap(),
1377 "LogSender's flush was not called"
1378 );
1379 }
1380
1381 #[test]
1382 fn test_string_similarity() {
1383 assert_eq!(normalized_edit_distance("hello", "hello"), 0.0);
1385
1386 assert_eq!(normalized_edit_distance("hello", "i'mdiff"), 1.0);
1388
1389 assert!(normalized_edit_distance("hello", "helo") < 0.5);
1391 assert!(normalized_edit_distance("hello", "hello!") < 0.5);
1392
1393 assert_eq!(normalized_edit_distance("", ""), 0.0);
1395 assert_eq!(normalized_edit_distance("hello", ""), 1.0);
1396 }
1397
1398 #[test]
1399 fn test_add_line_to_empty_aggregator() {
1400 let mut aggregator = Aggregator::new();
1401 let result = aggregator.add_line("ERROR 404 not found");
1402
1403 assert!(result.is_ok());
1404 assert_eq!(aggregator.lines.len(), 1);
1405 assert_eq!(aggregator.lines[0].content, "ERROR 404 not found");
1406 assert_eq!(aggregator.lines[0].count, 1);
1407 }
1408
1409 #[test]
1410 fn test_add_line_merges_with_similar_line() {
1411 let mut aggregator = Aggregator::new_with_threshold(0.2);
1412
1413 aggregator.add_line("ERROR 404 timeout").unwrap();
1415 assert_eq!(aggregator.lines.len(), 1);
1416
1417 aggregator.add_line("ERROR 500 timeout").unwrap();
1419 assert_eq!(aggregator.lines.len(), 1); assert_eq!(aggregator.lines[0].count, 2);
1421
1422 aggregator
1424 .add_line("WARNING database connection failed")
1425 .unwrap();
1426 assert_eq!(aggregator.lines.len(), 2); aggregator
1430 .add_line("WARNING database connection timed out")
1431 .unwrap();
1432 assert_eq!(aggregator.lines.len(), 2); assert_eq!(aggregator.lines[1].count, 2); }
1435
1436 #[test]
1437 fn test_aggregation_of_similar_log_lines() {
1438 let mut aggregator = Aggregator::new_with_threshold(0.2);
1439
1440 aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/i/ims-ansi, falling back to search/replace to update prefixes for this file.").unwrap();
1442 aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,351 conda-unpack-fb:292] Found invalid offsets for lib/pkgconfig/ncursesw.pc, falling back to search/replace to update prefixes for this file.").unwrap();
1443 aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/k/kt7, falling back to search/replace to update prefixes for this file.").unwrap();
1444
1445 assert_eq!(aggregator.lines.len(), 1);
1447
1448 assert_eq!(aggregator.lines[0].count, 3);
1450 }
1451}