hyperactor_mesh/
logging.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::fmt;
11use std::path::Path;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::Context as TaskContext;
16use std::task::Poll;
17use std::time::Duration;
18use std::time::SystemTime;
19
20use anyhow::Result;
21use async_trait::async_trait;
22use chrono::DateTime;
23use chrono::Local;
24use hyperactor::Actor;
25use hyperactor::ActorRef;
26use hyperactor::Bind;
27use hyperactor::Context;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::Named;
32use hyperactor::OncePortRef;
33use hyperactor::RefClient;
34use hyperactor::Unbind;
35use hyperactor::channel;
36use hyperactor::channel::ChannelAddr;
37use hyperactor::channel::ChannelRx;
38use hyperactor::channel::ChannelTransport;
39use hyperactor::channel::ChannelTx;
40use hyperactor::channel::Rx;
41use hyperactor::channel::Tx;
42use hyperactor::channel::TxStatus;
43use hyperactor::clock::Clock;
44use hyperactor::clock::RealClock;
45use hyperactor::data::Serialized;
46use hyperactor_telemetry::env;
47use hyperactor_telemetry::log_file_path;
48use serde::Deserialize;
49use serde::Serialize;
50use tokio::io;
51use tokio::sync::Mutex;
52use tokio::sync::watch::Receiver;
53
54use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
55
56mod line_prefixing_writer;
57use line_prefixing_writer::LinePrefixingWriter;
58
/// Aggregation window, in seconds, that a freshly created `LogClientActor`
/// starts with (see `LogClientActor::new`).
const DEFAULT_AGGREGATE_WINDOW_SEC: u64 = 5;
60
/// Calculate the Levenshtein distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions needed to turn `left` into `right`. Characters
/// are Unicode scalar values (`char`), not bytes.
///
/// Only two rows of the dynamic-programming table are alive at any time, so
/// the memory used is proportional to the length of `right` instead of the
/// product of both lengths.
fn levenshtein_distance(left: &str, right: &str) -> usize {
    let right_chars: Vec<char> = right.chars().collect();
    let right_len = right_chars.len();

    // Turning an empty string into the other one takes one insertion per char.
    if left.is_empty() {
        return right_len;
    }
    if right_len == 0 {
        return left.chars().count();
    }

    // `previous[j]` is the distance between the already-processed prefix of
    // `left` and the first `j` chars of `right`; it starts as the row for the
    // empty prefix.
    let mut previous: Vec<usize> = (0..=right_len).collect();
    let mut current: Vec<usize> = vec![0; right_len + 1];

    for (i, left_char) in left.chars().enumerate() {
        // Distance from the first `i + 1` chars of `left` to the empty string.
        current[0] = i + 1;

        for (j, right_char) in right_chars.iter().enumerate() {
            let substitution = previous[j] + usize::from(left_char != *right_char);
            let deletion = previous[j + 1] + 1;
            let insertion = current[j] + 1;
            current[j + 1] = substitution.min(deletion).min(insertion);
        }

        std::mem::swap(&mut previous, &mut current);
    }

    // After the final swap the last computed row lives in `previous`.
    previous[right_len]
}
110
111/// Calculate the normalized edit distance between two strings (0.0 to 1.0)
112fn normalized_edit_distance(left: &str, right: &str) -> f64 {
113    let distance = levenshtein_distance(left, right) as f64;
114    let max_len = std::cmp::max(left.len(), right.len()) as f64;
115
116    if max_len == 0.0 {
117        0.0 // Both strings are empty, so they're identical
118    } else {
119        distance / max_len
120    }
121}
122
#[derive(Debug, Clone)]
/// LogLine represents a single log line with its content and count.
///
/// An entry is created by `LogLine::new` for the first line of its kind;
/// `Aggregator::add_line` bumps `count` for every later line it considers
/// similar, while `content` stays as the first line seen.
struct LogLine {
    /// Text of the line that created this entry.
    content: String,
    /// How many lines have been folded into this entry, the first one included.
    pub count: u64,
}
129
130impl LogLine {
131    fn new(content: String) -> Self {
132        Self { content, count: 1 }
133    }
134}
135
136impl fmt::Display for LogLine {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        write!(
139            f,
140            "\x1b[33m[{} similar log lines]\x1b[0m {}",
141            self.count, self.content
142        )
143    }
144}
145
#[derive(Debug, Clone)]
/// Aggregator is a struct that holds a list of LogLines and a start time.
/// It can aggregate new log lines to existing ones if they are "similar" based on edit distance.
struct Aggregator {
    /// One entry per group of similar lines collected since the last reset.
    lines: Vec<LogLine>,
    /// When the aggregator was created or last reset; printed in the header
    /// of the aggregated output.
    start_time: SystemTime,
    /// A new line joins an existing entry only if their normalized edit
    /// distance is strictly below this value (0.0 to 1.0).
    similarity_threshold: f64, // Threshold for considering two strings similar (0.0 to 1.0)
}
154
155impl Aggregator {
156    fn new() -> Self {
157        // Default threshold: strings with normalized edit distance < 0.15 are considered similar
158        Self::new_with_threshold(0.15)
159    }
160
161    fn new_with_threshold(threshold: f64) -> Self {
162        Aggregator {
163            lines: vec![],
164            start_time: RealClock.system_time_now(),
165            similarity_threshold: threshold,
166        }
167    }
168
169    fn reset(&mut self) {
170        self.lines.clear();
171        self.start_time = RealClock.system_time_now();
172    }
173
174    fn add_line(&mut self, line: &str) -> anyhow::Result<()> {
175        // Find the most similar existing line
176        let mut best_match_idx = None;
177        let mut best_similarity = f64::MAX;
178
179        for (idx, existing_line) in self.lines.iter().enumerate() {
180            let distance = normalized_edit_distance(&existing_line.content, line);
181
182            // If this line is more similar than our current best match
183            if distance < best_similarity && distance < self.similarity_threshold {
184                best_match_idx = Some(idx);
185                best_similarity = distance;
186            }
187        }
188
189        // If we found a similar enough line, increment its count
190        if let Some(idx) = best_match_idx {
191            self.lines[idx].count += 1;
192        } else {
193            // Otherwise, add a new line
194            self.lines.push(LogLine::new(line.to_string()));
195        }
196
197        Ok(())
198    }
199
200    fn is_empty(&self) -> bool {
201        self.lines.is_empty()
202    }
203}
204
205// Helper function to format SystemTime
206fn format_system_time(time: SystemTime) -> String {
207    let datetime: DateTime<Local> = time.into();
208    datetime.format("%Y-%m-%d %H:%M:%S").to_string()
209}
210
211impl fmt::Display for Aggregator {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        // Format the start time
214        let start_time_str = format_system_time(self.start_time);
215
216        // Get and format the current time
217        let current_time = RealClock.system_time_now();
218        let end_time_str = format_system_time(current_time);
219
220        // Write the header with formatted time window
221        writeln!(
222            f,
223            "\x1b[36m>>> Aggregated Logs ({}) >>>\x1b[0m",
224            start_time_str
225        )?;
226
227        // Write each log line
228        for line in self.lines.iter() {
229            writeln!(f, "{}", line)?;
230        }
231        writeln!(
232            f,
233            "\x1b[36m<<< Aggregated Logs ({}) <<<\x1b[0m",
234            end_time_str
235        )?;
236        Ok(())
237    }
238}
239
/// Messages that can be sent to the LogClientActor remotely.
///
/// The same message type also travels over the local log channel between
/// `LocalLogSender` and `LogForwardActor`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient
)]
pub enum LogMessage {
    /// Log details
    Log {
        /// The hostname of the process that generated the log
        hostname: String,
        /// The pid of the process that generated the log
        pid: u32,
        /// The target output stream (stdout or stderr)
        output_target: OutputTarget,
        /// The log payload as bytes. `deserialize_message_lines` also accepts
        /// a serialized `String` here.
        payload: Serialized,
    },

    /// Flush the log
    Flush {
        /// Indicate if the current flush is synced or non-synced.
        /// If synced, a version number is available. Otherwise, none.
        sync_version: Option<u64>,
    },
}
271
/// Messages that can be sent to the LogClient locally.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient
)]
pub enum LogClientMessage {
    /// Change how the client aggregates incoming log lines.
    SetAggregate {
        /// The time window in seconds to aggregate logs. If None, aggregation is disabled.
        aggregate_window_sec: Option<u64>,
    },

    /// Synchronously flush all the logs from all the procs. This is for client to call.
    StartSyncFlush {
        /// Expect this many procs to ack the flush message.
        expected_procs: usize,
        /// Return once we have received the acks from all the procs
        reply: OncePortRef<()>,
        /// Return to the caller the current flush version
        version: OncePortRef<u64>,
    },
}
299
/// Trait for sending logs
///
/// Implementations receive raw output bytes tagged with the stream they were
/// written to and are responsible for delivering them to a log consumer.
#[async_trait]
pub trait LogSender: Send + Sync {
    /// Send a log payload in bytes
    ///
    /// `target` identifies the stream (stdout or stderr) the bytes came from.
    fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()>;

    /// Flush the log channel, ensuring all messages are delivered
    /// Returns when the flush message has been acknowledged
    ///
    /// NOTE(review): `LocalLogSender::flush` posts the flush message and
    /// returns without waiting for an acknowledgement -- confirm which
    /// behavior implementors are expected to provide.
    fn flush(&mut self) -> anyhow::Result<()>;
}
310
/// Represents the target output stream (stdout or stderr)
///
/// Used both to tag forwarded log payloads and as the key of the per-stream
/// aggregators kept by `LogClientActor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum OutputTarget {
    /// Standard output stream
    Stdout,
    /// Standard error stream
    Stderr,
}
319
/// Write the log to a local unix channel so some actors can listen to it and stream the log back.
pub struct LocalLogSender {
    /// Hostname attached to every log message; `"unknown_host"` when it
    /// cannot be determined.
    hostname: String,
    /// Process id attached to every log message, as given to `new`.
    pid: u32,
    /// Sending half of the log channel.
    tx: ChannelTx<LogMessage>,
    /// Status of `tx`; messages are only posted while it is `TxStatus::Active`.
    status: Receiver<TxStatus>,
}
327
328impl LocalLogSender {
329    fn new(log_channel: ChannelAddr, pid: u32) -> Result<Self, anyhow::Error> {
330        let tx = channel::dial::<LogMessage>(log_channel)?;
331        let status = tx.status().clone();
332
333        let hostname = hostname::get()
334            .unwrap_or_else(|_| "unknown_host".into())
335            .into_string()
336            .unwrap_or("unknown_host".to_string());
337        Ok(Self {
338            hostname,
339            pid,
340            tx,
341            status,
342        })
343    }
344}
345
346#[async_trait]
347impl LogSender for LocalLogSender {
348    fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
349        if TxStatus::Active == *self.status.borrow() {
350            // Do not use tx.send, it will block the allocator as the child process state is unknown.
351            self.tx.post(LogMessage::Log {
352                hostname: self.hostname.clone(),
353                pid: self.pid,
354                output_target: target,
355                payload: Serialized::serialize_anon(&payload)?,
356            });
357        } else {
358            tracing::debug!(
359                "log sender {} is not active, skip sending log",
360                self.tx.addr()
361            )
362        }
363
364        Ok(())
365    }
366
367    fn flush(&mut self) -> anyhow::Result<()> {
368        // send will make sure message is delivered
369        if TxStatus::Active == *self.status.borrow() {
370            // Do not use tx.send, it will block the allocator as the child process state is unknown.
371            self.tx.post(LogMessage::Flush { sync_version: None });
372        } else {
373            tracing::debug!(
374                "log sender {} is not active, skip sending flush message",
375                self.tx.addr()
376            );
377        }
378        Ok(())
379    }
380}
381
/// A custom writer that tees its input: every buffer is written to a local
/// writer (`std_writer`) and also handed to a `LogSender` for forwarding.
pub struct LogWriter<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> {
    /// Stream tag (stdout or stderr) passed along with every forwarded buffer.
    output_target: OutputTarget,
    /// Local destination that is written first.
    std_writer: S,
    /// Forwards the bytes once the local write has succeeded.
    log_sender: T,
}
389
390fn create_file_writer(
391    local_rank: usize,
392    output_target: OutputTarget,
393    env: env::Env,
394) -> Result<Box<dyn io::AsyncWrite + Send + Unpin + 'static>> {
395    let suffix = match output_target {
396        OutputTarget::Stderr => "stderr",
397        OutputTarget::Stdout => "stdout",
398    };
399    let (path, filename) = log_file_path(env)?;
400    let path = Path::new(&path);
401    let mut full_path = PathBuf::from(path);
402
403    // This is the PID of the "owner" of the proc mesh, the proc mesh
404    // this proc "belongs" to. In other words,the PID of the process
405    // that invokes `cmd.spawn()` (where `cmd: &mut
406    // tokio::process::Command`) to start the process that will host
407    // the proc that this file writer relates to.
408    let file_created_by_pid = std::process::id();
409
410    full_path.push(format!(
411        "{}_{}_{}.{}",
412        filename, file_created_by_pid, local_rank, suffix
413    ));
414    let file = std::fs::OpenOptions::new()
415        .create(true)
416        .append(true)
417        .open(full_path)?;
418    let tokio_file = tokio::fs::File::from_std(file);
419    // TODO: should we buffer this?
420    Ok(Box::new(tokio_file))
421}
422
423fn get_local_log_destination(
424    local_rank: usize,
425    output_target: OutputTarget,
426) -> Result<Box<dyn io::AsyncWrite + Send + Unpin>> {
427    let env: env::Env = env::Env::current();
428    Ok(match env {
429        env::Env::Test => match output_target {
430            OutputTarget::Stdout => Box::new(LinePrefixingWriter::new(local_rank, io::stdout())),
431            OutputTarget::Stderr => Box::new(LinePrefixingWriter::new(local_rank, io::stderr())),
432        },
433        env::Env::Local | env::Env::MastEmulator | env::Env::Mast => {
434            create_file_writer(local_rank, output_target, env)?
435        }
436    })
437}
438
439/// Helper function to create stdout and stderr LogWriter instances
440///
441/// # Arguments
442///
443/// * `log_channel` - The unix channel for the writer to stream logs to
444/// * `pid` - The process ID of the process
445///
446/// # Returns
447///
448/// A tuple of boxed writers for stdout and stderr
449pub fn create_log_writers(
450    local_rank: usize,
451    log_channel: ChannelAddr,
452    pid: u32,
453) -> Result<
454    (
455        Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
456        Box<dyn io::AsyncWrite + Send + Unpin + 'static>,
457    ),
458    anyhow::Error,
459> {
460    // Create LogWriter instances for stdout and stderr using the shared log sender
461    let stdout_writer = LogWriter::with_default_writer(
462        local_rank,
463        OutputTarget::Stdout,
464        LocalLogSender::new(log_channel.clone(), pid)?,
465    )?;
466    let stderr_writer = LogWriter::with_default_writer(
467        local_rank,
468        OutputTarget::Stderr,
469        LocalLogSender::new(log_channel, pid)?,
470    )?;
471
472    Ok((Box::new(stdout_writer), Box::new(stderr_writer)))
473}
474
475impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> LogWriter<T, S> {
476    /// Creates a new LogWriter.
477    ///
478    /// # Arguments
479    ///
480    /// * `output_target` - The target output stream (stdout or stderr)
481    /// * `std_writer` - The writer to use for stdout/stderr
482    /// * `log_sender` - The log sender to use for sending logs
483    pub fn new(output_target: OutputTarget, std_writer: S, log_sender: T) -> Self {
484        Self {
485            output_target,
486            std_writer,
487            log_sender,
488        }
489    }
490}
491
492impl<T: LogSender + Unpin + 'static> LogWriter<T, Box<dyn io::AsyncWrite + Send + Unpin>> {
493    /// Creates a new LogWriter with the default stdout/stderr writer.
494    ///
495    /// # Arguments
496    ///
497    /// * `output_target` - The target output stream (stdout or stderr)
498    /// * `log_sender` - The log sender to use for sending logs
499    pub fn with_default_writer(
500        local_rank: usize,
501        output_target: OutputTarget,
502        log_sender: T,
503    ) -> Result<Self> {
504        // Use a default writer based on the output target
505        let std_writer = get_local_log_destination(local_rank, output_target)?;
506
507        Ok(Self {
508            output_target,
509            std_writer,
510            log_sender,
511        })
512    }
513}
514
515impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static> io::AsyncWrite
516    for LogWriter<T, S>
517{
518    fn poll_write(
519        self: Pin<&mut Self>,
520        cx: &mut TaskContext<'_>,
521        buf: &[u8],
522    ) -> Poll<Result<usize, io::Error>> {
523        // Get a mutable reference to the std_writer field
524        let this = self.get_mut();
525
526        // First, write to stdout/stderr
527        match Pin::new(&mut this.std_writer).poll_write(cx, buf) {
528            Poll::Ready(Ok(_)) => {
529                // Forward the buffer directly to the log sender without parsing
530                let output_target = this.output_target.clone();
531                let data_to_send = buf.to_vec();
532
533                // Use the log sender directly without cloning
534                // Since LogSender::send takes &self, we don't need to clone it
535                if let Err(e) = this.log_sender.send(output_target, data_to_send) {
536                    tracing::error!("error sending log: {}", e);
537                }
538                // Return success with the full buffer size
539                Poll::Ready(Ok(buf.len()))
540            }
541            other => other, // Propagate any errors or Pending state
542        }
543    }
544
545    fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Result<(), io::Error>> {
546        let this = self.get_mut();
547
548        match Pin::new(&mut this.std_writer).poll_flush(cx) {
549            Poll::Ready(Ok(())) => {
550                if let Err(e) = this.log_sender.flush() {
551                    tracing::error!("error sending flush: {}", e);
552                }
553                Poll::Ready(Ok(()))
554            }
555            other => other, // Propagate any errors or Pending state from the std_writer flush
556        }
557    }
558
559    fn poll_shutdown(
560        self: Pin<&mut Self>,
561        cx: &mut TaskContext<'_>,
562    ) -> Poll<Result<(), io::Error>> {
563        let this = self.get_mut();
564        Pin::new(&mut this.std_writer).poll_shutdown(cx)
565    }
566}
567
/// Messages that can be sent to the `LogForwardActor`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub enum LogForwardMessage {
    /// Receive the log from the parent process and forward it to the client.
    Forward {},

    /// Whether to stream the log back to the client.
    SetMode { stream_to_client: bool },

    /// Flush the log with a version number.
    ForceSyncFlush { version: u64 },
}
591
/// A log forwarder that receives the log from its parent process and forwards it back to the client
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [LogForwardMessage {cast = true}],
)]
pub struct LogForwardActor {
    /// Receiving half of the log channel.
    rx: ChannelRx<LogMessage>,
    /// Sender used by the actor to put flush messages on the log channel;
    /// shared with the delayed flush tasks it spawns.
    flush_tx: Arc<Mutex<ChannelTx<LogMessage>>>,
    /// Earliest time at which another periodic flush may be scheduled.
    next_flush_deadline: SystemTime,
    /// The client actor that log and flush messages are forwarded to.
    logging_client_ref: ActorRef<LogClientActor>,
    /// When false, received log messages are dropped instead of forwarded.
    stream_to_client: bool,
}
605
#[async_trait]
impl Actor for LogForwardActor {
    /// Reference to the client actor that receives the forwarded logs.
    type Params = ActorRef<LogClientActor>;

    /// Starts serving the log channel named by the `BOOTSTRAP_LOG_CHANNEL`
    /// environment variable and dials it for sending flush messages.
    ///
    /// Falls back to an arbitrary unix channel when the variable is unset or
    /// when serving the given address fails. Fails if the variable is set but
    /// does not parse, if the fallback cannot be served, or if dialing fails.
    async fn new(logging_client_ref: Self::Params) -> Result<Self> {
        let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) {
            Ok(channel) => channel.parse()?,
            Err(err) => {
                tracing::debug!(
                    "log forwarder actor failed to read env var {}: {}",
                    BOOTSTRAP_LOG_CHANNEL,
                    err
                );
                // TODO: an empty channel to serve
                ChannelAddr::any(ChannelTransport::Unix)
            }
        };
        tracing::info!(
            "log forwarder {} serve at {}",
            std::process::id(),
            log_channel
        );

        let rx = match channel::serve(log_channel.clone()).await {
            Ok((_, rx)) => rx,
            Err(err) => {
                // This can happen if we are not spawned on a separate process, as with a
                // local mesh. For a local mesh, log streaming is not needed anyway.
                tracing::error!(
                    "log forwarder actor failed to bootstrap on given channel {}: {}",
                    log_channel,
                    err
                );
                channel::serve(ChannelAddr::any(ChannelTransport::Unix))
                    .await?
                    .1
            }
        };

        // Dial the same channel to send flush message to drain the log queue.
        // NOTE(review): this dials the originally requested `log_channel`, not the
        // address actually bound by `serve` (its first return value is discarded).
        // In the fallback paths above those may differ -- confirm flush messages
        // still reach `rx` in that case.
        let flush_tx = Arc::new(Mutex::new(channel::dial::<LogMessage>(log_channel)?));
        let now = RealClock.system_time_now();

        Ok(Self {
            rx,
            flush_tx,
            // A deadline of "now" lets the first flush message schedule its follow-up.
            next_flush_deadline: now,
            logging_client_ref,
            stream_to_client: true,
        })
    }

    /// Kicks off the forwarding loop and sends the first periodic flush message.
    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
        // The `Forward` handler re-sends this message to itself after every
        // received channel message, which keeps the loop going.
        this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;

        // Make sure we start the flush loop periodically so the log channel will not deadlock.
        self.flush_tx
            .lock()
            .await
            .send(LogMessage::Flush { sync_version: None })
            .await?;
        Ok(())
    }
}
670
671#[async_trait]
672#[hyperactor::forward(LogForwardMessage)]
673impl LogForwardMessageHandler for LogForwardActor {
674    async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
675        match self.rx.recv().await {
676            Ok(LogMessage::Flush { sync_version }) => {
677                let now = RealClock.system_time_now();
678                match sync_version {
679                    None => {
680                        // Schedule another flush to keep the log channel from deadlocking.
681                        let delay = Duration::from_secs(1);
682                        if now >= self.next_flush_deadline {
683                            self.next_flush_deadline = now + delay;
684                            let flush_tx = self.flush_tx.clone();
685                            tokio::spawn(async move {
686                                RealClock.sleep(delay).await;
687                                if let Err(e) = flush_tx
688                                    .lock()
689                                    .await
690                                    .send(LogMessage::Flush { sync_version: None })
691                                    .await
692                                {
693                                    tracing::error!("failed to send flush message: {}", e);
694                                }
695                            });
696                        }
697                    }
698                    version => {
699                        self.logging_client_ref.flush(ctx, version).await?;
700                    }
701                }
702            }
703            Ok(LogMessage::Log {
704                hostname,
705                pid,
706                output_target,
707                payload,
708            }) => {
709                if self.stream_to_client {
710                    self.logging_client_ref
711                        .log(ctx, hostname, pid, output_target, payload)
712                        .await?;
713                }
714            }
715            Err(e) => {
716                return Err(e.into());
717            }
718        }
719
720        // This is not ideal as we are using raw tx/rx.
721        ctx.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
722
723        Ok(())
724    }
725
726    async fn set_mode(
727        &mut self,
728        _ctx: &Context<Self>,
729        stream_to_client: bool,
730    ) -> Result<(), anyhow::Error> {
731        self.stream_to_client = stream_to_client;
732        Ok(())
733    }
734
735    async fn force_sync_flush(
736        &mut self,
737        _cx: &Context<Self>,
738        version: u64,
739    ) -> Result<(), anyhow::Error> {
740        self.flush_tx
741            .lock()
742            .await
743            .send(LogMessage::Flush {
744                sync_version: Some(version),
745            })
746            .await
747            .map_err(anyhow::Error::from)
748    }
749}
750
751/// Deserialize a serialized message and split it into UTF-8 lines
752fn deserialize_message_lines(
753    serialized_message: &hyperactor::data::Serialized,
754) -> Result<Vec<String>> {
755    // Try to deserialize as String first
756    if let Ok(message_str) = serialized_message.deserialized::<String>() {
757        return Ok(message_str.lines().map(|s| s.to_string()).collect());
758    }
759
760    // If that fails, try to deserialize as Vec<u8> and convert to UTF-8
761    if let Ok(message_bytes) = serialized_message.deserialized::<Vec<u8>>() {
762        let message_str = String::from_utf8(message_bytes)?;
763        return Ok(message_str.lines().map(|s| s.to_string()).collect());
764    }
765
766    // If both fail, return an error
767    anyhow::bail!("Failed to deserialize message as either String or Vec<u8>")
768}
769
/// A client to receive logs from remote processes
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [LogMessage, LogClientMessage],
)]
pub struct LogClientActor {
    /// Aggregation window in seconds; `None` prints every line immediately.
    aggregate_window_sec: Option<u64>,
    /// One aggregator per output stream.
    aggregators: HashMap<OutputTarget, Aggregator>,
    /// When the aggregators were last flushed, or a line last printed
    /// directly while aggregation is off.
    last_flush_time: SystemTime,
    /// Deadline of the flush already scheduled via a delayed self-message, if any.
    next_flush_deadline: Option<SystemTime>,

    // For flush sync barrier
    /// Version that incoming synced flush acks are compared against.
    current_flush_version: u64,
    /// Port to reply on once every proc has acked the ongoing synced flush.
    current_flush_port: Option<OncePortRef<()>>,
    /// Number of procs that have not yet acked the ongoing synced flush.
    current_unflushed_procs: usize,
}
787
788impl LogClientActor {
789    fn print_aggregators(&mut self) {
790        for (output_target, aggregator) in self.aggregators.iter_mut() {
791            if aggregator.is_empty() {
792                continue;
793            }
794            match output_target {
795                OutputTarget::Stdout => {
796                    println!("{}", aggregator);
797                }
798                OutputTarget::Stderr => {
799                    eprintln!("{}", aggregator);
800                }
801            }
802
803            // Reset the aggregator
804            aggregator.reset();
805        }
806    }
807
808    fn print_log_line(hostname: &str, pid: u32, output_target: OutputTarget, line: String) {
809        let message = format!("[{} {}] {}", hostname, pid, line);
810        match output_target {
811            OutputTarget::Stdout => println!("{}", message),
812            OutputTarget::Stderr => eprintln!("{}", message),
813        }
814    }
815
816    fn flush_internal(&mut self) {
817        self.print_aggregators();
818        self.last_flush_time = RealClock.system_time_now();
819        self.next_flush_deadline = None;
820    }
821}
822
823#[async_trait]
824impl Actor for LogClientActor {
825    /// The aggregation window in seconds.
826    type Params = ();
827
828    async fn new(_: ()) -> Result<Self, anyhow::Error> {
829        // Initialize aggregators
830        let mut aggregators = HashMap::new();
831        aggregators.insert(OutputTarget::Stderr, Aggregator::new());
832        aggregators.insert(OutputTarget::Stdout, Aggregator::new());
833
834        Ok(Self {
835            aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC),
836            aggregators,
837            last_flush_time: RealClock.system_time_now(),
838            next_flush_deadline: None,
839            current_flush_version: 0,
840            current_flush_port: None,
841            current_unflushed_procs: 0,
842        })
843    }
844}
845
impl Drop for LogClientActor {
    /// Prints whatever the aggregators still hold so that buffered lines are
    /// not lost when the actor goes away.
    fn drop(&mut self) {
        // Flush the remaining logs before shutting down
        self.print_aggregators();
    }
}
852
853#[async_trait]
854#[hyperactor::forward(LogMessage)]
855impl LogMessageHandler for LogClientActor {
856    async fn log(
857        &mut self,
858        cx: &Context<Self>,
859        hostname: String,
860        pid: u32,
861        output_target: OutputTarget,
862        payload: Serialized,
863    ) -> Result<(), anyhow::Error> {
864        // Deserialize the message and process line by line with UTF-8
865        let message_lines = deserialize_message_lines(&payload)?;
866        let hostname = hostname.as_str();
867
868        match self.aggregate_window_sec {
869            None => {
870                for line in message_lines {
871                    Self::print_log_line(hostname, pid, output_target, line);
872                }
873                self.last_flush_time = RealClock.system_time_now();
874            }
875            Some(window) => {
876                for line in message_lines {
877                    if let Some(aggregator) = self.aggregators.get_mut(&output_target) {
878                        if let Err(e) = aggregator.add_line(&line) {
879                            tracing::error!("error adding log line: {}", e);
880                            // For the sake of completeness, flush the log lines.
881                            Self::print_log_line(hostname, pid, output_target, line);
882                        }
883                    } else {
884                        tracing::error!("unknown output target: {:?}", output_target);
885                        // For the sake of completeness, flush the log lines.
886                        Self::print_log_line(hostname, pid, output_target, line);
887                    }
888                }
889
890                let new_deadline = self.last_flush_time + Duration::from_secs(window);
891                let now = RealClock.system_time_now();
892                if new_deadline <= now {
893                    self.flush_internal();
894                } else {
895                    let delay = new_deadline.duration_since(now)?;
896                    match self.next_flush_deadline {
897                        None => {
898                            self.next_flush_deadline = Some(new_deadline);
899                            cx.self_message_with_delay(
900                                LogMessage::Flush { sync_version: None },
901                                delay,
902                            )?;
903                        }
904                        Some(deadline) => {
905                            // Some early log lines have alrady triggered the flush.
906                            if new_deadline < deadline {
907                                // This can happen if the user has adjusted the aggregation window.
908                                self.next_flush_deadline = Some(new_deadline);
909                                cx.self_message_with_delay(
910                                    LogMessage::Flush { sync_version: None },
911                                    delay,
912                                )?;
913                            }
914                        }
915                    }
916                }
917            }
918        }
919
920        Ok(())
921    }
922
923    async fn flush(
924        &mut self,
925        cx: &Context<Self>,
926        sync_version: Option<u64>,
927    ) -> Result<(), anyhow::Error> {
928        match sync_version {
929            None => {
930                self.flush_internal();
931            }
932            Some(version) => {
933                if version != self.current_flush_version {
934                    tracing::error!(
935                        "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully",
936                        version,
937                        self.current_flush_version
938                    );
939                    return Ok(());
940                }
941
942                if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() {
943                    // This is a serious issue; it's better to error out.
944                    anyhow::bail!("found no ongoing flush request");
945                }
946                self.current_unflushed_procs -= 1;
947
948                tracing::debug!(
949                    "ack sync flush: version {}; remaining procs: {}",
950                    self.current_flush_version,
951                    self.current_unflushed_procs
952                );
953
954                if self.current_unflushed_procs == 0 {
955                    self.flush_internal();
956                    let reply = self.current_flush_port.take().unwrap();
957                    self.current_flush_port = None;
958                    reply.send(cx, ()).map_err(anyhow::Error::from)?;
959                }
960            }
961        }
962
963        Ok(())
964    }
965}
966
967#[async_trait]
968#[hyperactor::forward(LogClientMessage)]
969impl LogClientMessageHandler for LogClientActor {
970    async fn set_aggregate(
971        &mut self,
972        _cx: &Context<Self>,
973        aggregate_window_sec: Option<u64>,
974    ) -> Result<(), anyhow::Error> {
975        if self.aggregate_window_sec.is_some() && aggregate_window_sec.is_none() {
976            // Make sure we flush whatever in the aggregators before disabling aggregation.
977            self.print_aggregators();
978        }
979        self.aggregate_window_sec = aggregate_window_sec;
980        Ok(())
981    }
982
983    async fn start_sync_flush(
984        &mut self,
985        cx: &Context<Self>,
986        expected_procs_flushed: usize,
987        reply: OncePortRef<()>,
988        version: OncePortRef<u64>,
989    ) -> Result<(), anyhow::Error> {
990        if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() {
991            tracing::warn!(
992                "found unfinished ongoing flush: version {}; {} unflushed procs",
993                self.current_flush_version,
994                self.current_unflushed_procs,
995            );
996        }
997
998        self.current_flush_version += 1;
999        tracing::debug!(
1000            "start sync flush with version {}",
1001            self.current_flush_version
1002        );
1003        self.current_flush_port = Some(reply.clone());
1004        self.current_unflushed_procs = expected_procs_flushed;
1005        version
1006            .send(cx, self.current_flush_version)
1007            .map_err(anyhow::Error::from)?;
1008        Ok(())
1009    }
1010}
1011
1012#[cfg(test)]
1013mod tests {
1014    use std::sync::Arc;
1015    use std::sync::Mutex;
1016
1017    use hyperactor::channel;
1018    use hyperactor::channel::ChannelAddr;
1019    use hyperactor::channel::ChannelTx;
1020    use hyperactor::channel::Tx;
1021    use hyperactor::id;
1022    use hyperactor::mailbox::BoxedMailboxSender;
1023    use hyperactor::mailbox::DialMailboxRouter;
1024    use hyperactor::mailbox::MailboxServer;
1025    use hyperactor::proc::Proc;
1026    use tokio::io::AsyncWriteExt;
1027    use tokio::sync::mpsc;
1028
1029    use super::*;
1030
1031    #[tokio::test]
1032    async fn test_forwarding_log_to_client() {
1033        // Setup the basics
1034        let router = DialMailboxRouter::new();
1035        let (proc_addr, client_rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
1036            .await
1037            .unwrap();
1038        let proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone()));
1039        proc.clone().serve(client_rx);
1040        router.bind(id!(client[0]).into(), proc_addr.clone());
1041        let client = proc.attach("client").unwrap();
1042
1043        // Spin up both the forwarder and the client
1044        let log_channel = ChannelAddr::any(ChannelTransport::Unix);
1045        // SAFETY: Unit test
1046        unsafe {
1047            std::env::set_var(BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
1048        }
1049        let log_client: ActorRef<LogClientActor> =
1050            proc.spawn("log_client", ()).await.unwrap().bind();
1051        let log_forwarder: ActorRef<LogForwardActor> = proc
1052            .spawn("log_forwarder", log_client)
1053            .await
1054            .unwrap()
1055            .bind();
1056
1057        // Write some logs that will not be streamed
1058        let tx: ChannelTx<LogMessage> = channel::dial(log_channel).unwrap();
1059        tx.post(LogMessage::Log {
1060            hostname: "my_host".into(),
1061            pid: 1,
1062            output_target: OutputTarget::Stderr,
1063            payload: Serialized::serialize_anon(&"will not stream".to_string()).unwrap(),
1064        });
1065
1066        // Turn on streaming
1067        log_forwarder.set_mode(&client, true).await.unwrap();
1068        tx.post(LogMessage::Log {
1069            hostname: "my_host".into(),
1070            pid: 1,
1071            output_target: OutputTarget::Stderr,
1072            payload: Serialized::serialize_anon(&"will stream".to_string()).unwrap(),
1073        });
1074
1075        // TODO: it is hard to test out anything meaningful here as the client flushes to stdout.
1076    }
1077
1078    #[test]
1079    fn test_deserialize_message_lines_string() {
1080        // Test deserializing a String message with multiple lines
1081        let message = "Line 1\nLine 2\nLine 3".to_string();
1082        let serialized = Serialized::serialize_anon(&message).unwrap();
1083
1084        let result = deserialize_message_lines(&serialized).unwrap();
1085
1086        assert_eq!(result, vec!["Line 1", "Line 2", "Line 3"]);
1087
1088        // Test deserializing a Vec<u8> message with UTF-8 content
1089        let message_bytes = "Hello\nWorld\nUTF-8 \u{1F980}".as_bytes().to_vec();
1090        let serialized = Serialized::serialize_anon(&message_bytes).unwrap();
1091
1092        let result = deserialize_message_lines(&serialized).unwrap();
1093
1094        assert_eq!(result, vec!["Hello", "World", "UTF-8 \u{1F980}"]);
1095
1096        // Test deserializing a single line message
1097        let message = "Single line message".to_string();
1098        let serialized = Serialized::serialize_anon(&message).unwrap();
1099
1100        let result = deserialize_message_lines(&serialized).unwrap();
1101
1102        assert_eq!(result, vec!["Single line message"]);
1103
1104        // Test deserializing an empty lines
1105        let message = "\n\n".to_string();
1106        let serialized = Serialized::serialize_anon(&message).unwrap();
1107
1108        let result = deserialize_message_lines(&serialized).unwrap();
1109
1110        assert_eq!(result, vec!["", ""]);
1111
1112        // Test error handling for invalid UTF-8 bytes
1113        let invalid_utf8_bytes = vec![0xFF, 0xFE, 0xFD]; // Invalid UTF-8 sequence
1114        let serialized = Serialized::serialize_anon(&invalid_utf8_bytes).unwrap();
1115
1116        let result = deserialize_message_lines(&serialized);
1117
1118        assert!(result.is_err());
1119        assert!(result.unwrap_err().to_string().contains("invalid utf-8"));
1120    }
1121
1122    // Mock implementation of AsyncWrite that captures written data
1123    struct MockWriter {
1124        data: Arc<Mutex<Vec<u8>>>,
1125    }
1126
1127    impl MockWriter {
1128        fn new() -> (Self, Arc<Mutex<Vec<u8>>>) {
1129            let data = Arc::new(Mutex::new(Vec::new()));
1130            (Self { data: data.clone() }, data)
1131        }
1132    }
1133
1134    impl io::AsyncWrite for MockWriter {
1135        fn poll_write(
1136            self: Pin<&mut Self>,
1137            _cx: &mut TaskContext<'_>,
1138            buf: &[u8],
1139        ) -> Poll<Result<usize, io::Error>> {
1140            let mut data = self.data.lock().unwrap();
1141            data.extend_from_slice(buf);
1142            Poll::Ready(Ok(buf.len()))
1143        }
1144
1145        fn poll_flush(
1146            self: Pin<&mut Self>,
1147            _cx: &mut TaskContext<'_>,
1148        ) -> Poll<Result<(), io::Error>> {
1149            Poll::Ready(Ok(()))
1150        }
1151
1152        fn poll_shutdown(
1153            self: Pin<&mut Self>,
1154            _cx: &mut TaskContext<'_>,
1155        ) -> Poll<Result<(), io::Error>> {
1156            Poll::Ready(Ok(()))
1157        }
1158    }
1159
1160    // Mock implementation of LogSender for testing
1161    struct MockLogSender {
1162        log_sender: mpsc::UnboundedSender<(OutputTarget, String)>, // (output_target, content)
1163        flush_called: Arc<Mutex<bool>>,                            // Track if flush was called
1164    }
1165
1166    impl MockLogSender {
1167        fn new(log_sender: mpsc::UnboundedSender<(OutputTarget, String)>) -> Self {
1168            Self {
1169                log_sender,
1170                flush_called: Arc::new(Mutex::new(false)),
1171            }
1172        }
1173    }
1174
1175    #[async_trait]
1176    impl LogSender for MockLogSender {
1177        fn send(&mut self, output_target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
1178            // For testing purposes, convert to string if it's valid UTF-8
1179            let line = match std::str::from_utf8(&payload) {
1180                Ok(s) => s.to_string(),
1181                Err(_) => String::from_utf8_lossy(&payload).to_string(),
1182            };
1183
1184            self.log_sender
1185                .send((output_target, line))
1186                .map_err(|e| anyhow::anyhow!("Failed to send log in test: {}", e))
1187        }
1188
1189        fn flush(&mut self) -> anyhow::Result<()> {
1190            // Mark that flush was called
1191            let mut flush_called = self.flush_called.lock().unwrap();
1192            *flush_called = true;
1193
1194            // For testing purposes, just return Ok
1195            // In a real implementation, this would wait for all messages to be delivered
1196            Ok(())
1197        }
1198    }
1199
1200    #[tokio::test]
1201    async fn test_log_writer_direct_forwarding() {
1202        // Create a channel to receive logs
1203        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1204
1205        // Create a mock log sender
1206        let mock_log_sender = MockLogSender::new(log_sender);
1207
1208        // Create a mock writer for stdout
1209        let (mock_writer, _) = MockWriter::new();
1210        let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);
1211
1212        // Create a log writer with the mock sender
1213        let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);
1214
1215        // Write some data
1216        writer.write_all(b"Hello, world!").await.unwrap();
1217        writer.flush().await.unwrap();
1218
1219        // Check that the log was sent as is
1220        let (output_target, content) = log_receiver.recv().await.unwrap();
1221        assert_eq!(output_target, OutputTarget::Stdout);
1222        assert_eq!(content, "Hello, world!");
1223
1224        // Write more data
1225        writer.write_all(b"\nNext line").await.unwrap();
1226        writer.flush().await.unwrap();
1227
1228        // Check that the second chunk was sent as is
1229        let (output_target, content) = log_receiver.recv().await.unwrap();
1230        assert_eq!(output_target, OutputTarget::Stdout);
1231        assert_eq!(content, "\nNext line");
1232    }
1233
1234    #[tokio::test]
1235    async fn test_log_writer_stdout_stderr() {
1236        // Create a channel to receive logs
1237        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1238
1239        // Create mock log senders for stdout and stderr
1240        let stdout_sender = MockLogSender::new(log_sender.clone());
1241        let stderr_sender = MockLogSender::new(log_sender);
1242
1243        // Create mock writers for stdout and stderr
1244        let (stdout_mock_writer, _) = MockWriter::new();
1245        let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1246
1247        let (stderr_mock_writer, _) = MockWriter::new();
1248        let stderr_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stderr_mock_writer);
1249
1250        // Create log writers with the mock senders
1251        let mut stdout_writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, stdout_sender);
1252        let mut stderr_writer = LogWriter::new(OutputTarget::Stderr, stderr_writer, stderr_sender);
1253
1254        // Write to stdout and stderr
1255        stdout_writer.write_all(b"Stdout data").await.unwrap();
1256        stdout_writer.flush().await.unwrap();
1257
1258        stderr_writer.write_all(b"Stderr data").await.unwrap();
1259        stderr_writer.flush().await.unwrap();
1260
1261        // Check that logs were sent with correct output targets
1262        // Note: We can't guarantee the order of reception since they're sent from different tasks
1263        let mut received_stdout = false;
1264        let mut received_stderr = false;
1265
1266        for _ in 0..2 {
1267            let (output_target, content) = log_receiver.recv().await.unwrap();
1268            match output_target {
1269                OutputTarget::Stdout => {
1270                    assert_eq!(content, "Stdout data");
1271                    received_stdout = true;
1272                }
1273                OutputTarget::Stderr => {
1274                    assert_eq!(content, "Stderr data");
1275                    received_stderr = true;
1276                }
1277            }
1278        }
1279
1280        assert!(received_stdout);
1281        assert!(received_stderr);
1282    }
1283
1284    #[tokio::test]
1285    async fn test_log_writer_binary_data() {
1286        // Create a channel to receive logs
1287        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1288
1289        // Create a mock log sender
1290        let mock_log_sender = MockLogSender::new(log_sender);
1291
1292        // Create a mock writer for stdout
1293        let (mock_writer, _) = MockWriter::new();
1294        let std_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(mock_writer);
1295
1296        // Create a log writer with the mock sender
1297        let mut writer = LogWriter::new(OutputTarget::Stdout, std_writer, mock_log_sender);
1298
1299        // Write binary data (including non-UTF8 bytes)
1300        let binary_data = vec![0x48, 0x65, 0x6C, 0x6C, 0x6F, 0xFF, 0xFE, 0x00];
1301        writer.write_all(&binary_data).await.unwrap();
1302        writer.flush().await.unwrap();
1303
1304        // Check that the log was sent and converted to string (with lossy UTF-8 conversion in MockLogSender)
1305        let (output_target, content) = log_receiver.recv().await.unwrap();
1306        assert_eq!(output_target, OutputTarget::Stdout);
1307        // The content should be "Hello" followed by replacement characters for invalid bytes
1308        assert!(content.starts_with("Hello"));
1309        // The rest of the content will be replacement characters, but we don't care about the exact representation
1310    }
1311
1312    #[tokio::test]
1313    async fn test_log_writer_poll_flush() {
1314        // Create a channel to receive logs
1315        let (log_sender, _log_receiver) = mpsc::unbounded_channel();
1316
1317        // Create a mock log sender that tracks flush calls
1318        let mock_log_sender = MockLogSender::new(log_sender);
1319        let log_sender_flush_tracker = mock_log_sender.flush_called.clone();
1320
1321        // Create mock writers for stdout and stderr
1322        let (stdout_mock_writer, _) = MockWriter::new();
1323        let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1324
1325        // Create a log writer with the mocks
1326        let mut writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, mock_log_sender);
1327
1328        // Call flush on the writer
1329        writer.flush().await.unwrap();
1330
1331        // Verify that log sender's flush were called
1332        assert!(
1333            *log_sender_flush_tracker.lock().unwrap(),
1334            "LogSender's flush was not called"
1335        );
1336    }
1337
1338    #[test]
1339    fn test_string_similarity() {
1340        // Test exact match
1341        assert_eq!(normalized_edit_distance("hello", "hello"), 0.0);
1342
1343        // Test completely different strings
1344        assert_eq!(normalized_edit_distance("hello", "i'mdiff"), 1.0);
1345
1346        // Test similar strings
1347        assert!(normalized_edit_distance("hello", "helo") < 0.5);
1348        assert!(normalized_edit_distance("hello", "hello!") < 0.5);
1349
1350        // Test empty strings
1351        assert_eq!(normalized_edit_distance("", ""), 0.0);
1352        assert_eq!(normalized_edit_distance("hello", ""), 1.0);
1353    }
1354
1355    #[test]
1356    fn test_add_line_to_empty_aggregator() {
1357        let mut aggregator = Aggregator::new();
1358        let result = aggregator.add_line("ERROR 404 not found");
1359
1360        assert!(result.is_ok());
1361        assert_eq!(aggregator.lines.len(), 1);
1362        assert_eq!(aggregator.lines[0].content, "ERROR 404 not found");
1363        assert_eq!(aggregator.lines[0].count, 1);
1364    }
1365
1366    #[test]
1367    fn test_add_line_merges_with_similar_line() {
1368        let mut aggregator = Aggregator::new_with_threshold(0.2);
1369
1370        // Add first line
1371        aggregator.add_line("ERROR 404 timeout").unwrap();
1372        assert_eq!(aggregator.lines.len(), 1);
1373
1374        // Add second line that should merge (similar enough)
1375        aggregator.add_line("ERROR 500 timeout").unwrap();
1376        assert_eq!(aggregator.lines.len(), 1); // Should still be 1 line after merge
1377        assert_eq!(aggregator.lines[0].count, 2);
1378
1379        // Add third line that's too different
1380        aggregator
1381            .add_line("WARNING database connection failed")
1382            .unwrap();
1383        assert_eq!(aggregator.lines.len(), 2); // Should be 2 lines now
1384
1385        // Add fourth line similar to third
1386        aggregator
1387            .add_line("WARNING database connection timed out")
1388            .unwrap();
1389        assert_eq!(aggregator.lines.len(), 2); // Should still be 2 lines
1390        assert_eq!(aggregator.lines[1].count, 2); // Second group has 2 lines
1391    }
1392
1393    #[test]
1394    fn test_aggregation_of_similar_log_lines() {
1395        let mut aggregator = Aggregator::new_with_threshold(0.2);
1396
1397        // Add the provided log lines with small differences
1398        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/i/ims-ansi, falling back to search/replace to update prefixes for this file.").unwrap();
1399        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,351 conda-unpack-fb:292] Found invalid offsets for lib/pkgconfig/ncursesw.pc, falling back to search/replace to update prefixes for this file.").unwrap();
1400        aggregator.add_line("[1 similar log lines] WARNING <<2025, 2025>> -07-30 <<0, 0>> :41:45,366 conda-unpack-fb:292] Found invalid offsets for share/terminfo/k/kt7, falling back to search/replace to update prefixes for this file.").unwrap();
1401
1402        // Check that we have only one aggregated line due to similarity
1403        assert_eq!(aggregator.lines.len(), 1);
1404
1405        // Check that the count is 3
1406        assert_eq!(aggregator.lines[0].count, 3);
1407    }
1408}