hyperactor_telemetry/
sqlite.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::fs;
11use std::path::PathBuf;
12use std::sync::Arc;
13use std::sync::Mutex;
14
15use anyhow::Result;
16use anyhow::anyhow;
17use lazy_static::lazy_static;
18use rusqlite::Connection;
19use rusqlite::functions::FunctionFlags;
20use serde::Serialize;
21use serde_json::Value as JValue;
22use serde_rusqlite::*;
23use tracing::Event;
24use tracing::Subscriber;
25use tracing_subscriber::Layer;
26use tracing_subscriber::Registry;
27use tracing_subscriber::prelude::*;
28use tracing_subscriber::reload;
29
30pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
31
32lazy_static! {
33    // Reload handle allows us to include a no-op layer during init, but load
34    // the layer dynamically during tests.
35    static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
36        Mutex::new(None);
37}
/// Describes a telemetry table: its name and its TEXT columns.
///
/// The provided methods derive the SQL used to create the table and to insert
/// one row into it.
pub trait TableDef {
    /// SQL name of the table.
    fn name(&self) -> &'static str;

    /// Column names, in order. An implicit `seq INTEGER primary key` column is
    /// always added in front of these by [`TableDef::create_table_stmt`].
    fn columns(&self) -> &'static [&'static str];

    /// Builds `create table if not exists <name> (seq INTEGER primary key, <col> TEXT ,...)`.
    fn create_table_stmt(&self) -> String {
        let column_defs: Vec<String> = self
            .columns()
            .iter()
            .map(|column| format!("{column} TEXT "))
            .collect();
        format!(
            "create table if not exists {} (seq INTEGER primary key, {})",
            self.name(),
            column_defs.join(",")
        )
    }

    /// Builds `insert into <name> (<cols>) values (:<col>, ...)`, using one
    /// named parameter per column.
    fn insert_stmt(&self) -> String {
        let cols = self.columns();
        let placeholders: Vec<String> = cols.iter().map(|column| format!(":{column}")).collect();
        format!(
            "insert into {} ({}) values ({})",
            self.name(),
            cols.join(", "),
            placeholders.join(", ")
        )
    }
}
63
64impl TableDef for (&'static str, &'static [&'static str]) {
65    fn name(&self) -> &'static str {
66        self.0
67    }
68
69    fn columns(&self) -> &'static [&'static str] {
70        self.1
71    }
72}
73
74#[derive(Clone, Debug)]
75pub struct Table {
76    pub columns: &'static [&'static str],
77    pub create_table_stmt: String,
78    pub insert_stmt: String,
79}
80
81impl From<(&'static str, &'static [&'static str])> for Table {
82    fn from(value: (&'static str, &'static [&'static str])) -> Self {
83        Self {
84            columns: value.columns(),
85            create_table_stmt: value.create_table_stmt(),
86            insert_stmt: value.insert_stmt(),
87        }
88    }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Eq)]
92pub enum TableName {
93    ActorLifecycle,
94    Messages,
95    LogEvents,
96}
97
98impl TableName {
99    pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
100    pub const MESSAGES_STR: &'static str = "messages";
101    pub const LOG_EVENTS_STR: &'static str = "log_events";
102
103    pub fn as_str(&self) -> &'static str {
104        match self {
105            TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
106            TableName::Messages => Self::MESSAGES_STR,
107            TableName::LogEvents => Self::LOG_EVENTS_STR,
108        }
109    }
110
111    pub fn get_table(&self) -> &'static Table {
112        match self {
113            TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
114            TableName::Messages => &MESSAGES,
115            TableName::LogEvents => &LOG_EVENTS,
116        }
117    }
118}
119
120lazy_static! {
121    static ref ACTOR_LIFECYCLE: Table = (
122        TableName::ActorLifecycle.as_str(),
123        [
124            "actor_id",
125            "actor",
126            "name",
127            "supervised_actor",
128            "actor_status",
129            "module_path",
130            "line",
131            "file",
132        ]
133        .as_slice()
134    )
135        .into();
136    static ref MESSAGES: Table = (
137        TableName::Messages.as_str(),
138        [
139            "span_id",
140            "time_us",
141            "src",
142            "dest",
143            "payload",
144            "module_path",
145            "line",
146            "file",
147        ]
148        .as_slice()
149    )
150        .into();
151    static ref LOG_EVENTS: Table = (
152        TableName::LogEvents.as_str(),
153        [
154            "span_id",
155            "time_us",
156            "name",
157            "message",
158            "actor_id",
159            "level",
160            "line",
161            "file",
162            "module_path",
163        ]
164        .as_slice()
165    )
166        .into();
167    static ref ALL_TABLES: Vec<Table> = vec![
168        ACTOR_LIFECYCLE.clone(),
169        MESSAGES.clone(),
170        LOG_EVENTS.clone()
171    ];
172}
173
174pub struct SqliteLayer {
175    conn: Arc<Mutex<Connection>>,
176}
177use tracing::field::Visit;
178
179#[derive(Debug, Clone, Default, Serialize)]
180struct SqlVisitor(HashMap<String, JValue>);
181
182impl Visit for SqlVisitor {
183    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
184        self.0.insert(
185            field.name().to_string(),
186            JValue::String(format!("{:?}", value)),
187        );
188    }
189
190    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
191        self.0
192            .insert(field.name().to_string(), JValue::String(value.to_string()));
193    }
194
195    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
196        self.0
197            .insert(field.name().to_string(), JValue::Number(value.into()));
198    }
199
200    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
201        let n = serde_json::Number::from_f64(value).unwrap();
202        self.0.insert(field.name().to_string(), JValue::Number(n));
203    }
204
205    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
206        self.0
207            .insert(field.name().to_string(), JValue::Number(value.into()));
208    }
209
210    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
211        self.0.insert(field.name().to_string(), JValue::Bool(value));
212    }
213}
214
/// Inserts one row for `$event` into `$table` through the connection guard `$conn`.
///
/// The row holds the event's recorded fields plus its `module_path`, `line` and
/// `file` metadata. Those three metadata values overwrite any event fields with
/// the same names. Only names listed in the table's columns are bound.
///
/// Expands to statements that use `?`, so it must be invoked inside a function
/// returning a `Result` that rusqlite and serde_rusqlite errors convert into.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let table: &Table = $table;
        let mut visitor = SqlVisitor::default();
        $event.record(&mut visitor);
        let metadata = $event.metadata();
        let location: [(&str, JValue); 3] = [
            ("module_path", metadata.module_path().map(String::from).into()),
            ("line", metadata.line().into()),
            ("file", metadata.file().map(String::from).into()),
        ];
        for (key, value) in location {
            visitor.0.insert(key.to_string(), value);
        }
        let mut statement = $conn.prepare_cached(&table.insert_stmt)?;
        let named = serde_rusqlite::to_params_named_with_fields(visitor, table.columns)?;
        statement.execute(named.to_slice().as_slice())?;
    };
}
233
234impl SqliteLayer {
235    pub fn new() -> Result<Self> {
236        let conn = Connection::open_in_memory()?;
237        Self::setup_connection(conn)
238    }
239
240    pub fn new_with_file(db_path: &str) -> Result<Self> {
241        let conn = Connection::open(db_path)?;
242        Self::setup_connection(conn)
243    }
244
245    fn setup_connection(conn: Connection) -> Result<Self> {
246        for table in ALL_TABLES.iter() {
247            conn.execute(&table.create_table_stmt, [])?;
248        }
249        conn.create_scalar_function(
250            "assert",
251            2,
252            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
253            move |ctx| {
254                let condition: bool = ctx.get(0)?;
255                let message: String = ctx.get(1)?;
256
257                if !condition {
258                    return Err(rusqlite::Error::UserFunctionError(
259                        anyhow!("assertion failed:{condition} {message}",).into(),
260                    ));
261                }
262
263                Ok(condition)
264            },
265        )?;
266
267        Ok(Self {
268            conn: Arc::new(Mutex::new(conn)),
269        })
270    }
271
272    fn insert_event(&self, event: &Event<'_>) -> Result<()> {
273        let conn = self.conn.lock().unwrap();
274        match (event.metadata().target(), event.metadata().name()) {
275            (TableName::MESSAGES_STR, _) => {
276                insert_event!(TableName::Messages.get_table(), conn, event);
277            }
278            (TableName::ACTOR_LIFECYCLE_STR, _) => {
279                insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
280            }
281            _ => {
282                insert_event!(TableName::LogEvents.get_table(), conn, event);
283            }
284        }
285        Ok(())
286    }
287
288    pub fn connection(&self) -> Arc<Mutex<Connection>> {
289        self.conn.clone()
290    }
291}
292
293impl<S: Subscriber> Layer<S> for SqliteLayer {
294    fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
295        self.insert_event(event).unwrap();
296    }
297}
298
299#[allow(dead_code)]
300fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
301    let table_name_str = table_name.as_str();
302
303    // Get column names
304    let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
305    let column_info = stmt.query_map([], |row| {
306        row.get::<_, String>(1) // Column name is at index 1
307    })?;
308
309    let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
310
311    // Print header
312    println!("=== {} ===", table_name_str.to_uppercase());
313    println!("{}", columns.join(" | "));
314    println!("{}", "-".repeat(columns.len() * 10));
315
316    // Print rows
317    let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
318    let rows = stmt.query_map([], |row| {
319        let mut values = Vec::new();
320        for (i, column) in columns.iter().enumerate() {
321            // Handle different column types properly
322            let value = if i == 0 && *column == "seq" {
323                // First column is always the INTEGER seq column
324                match row.get::<_, Option<i64>>(i)? {
325                    Some(v) => v.to_string(),
326                    None => "NULL".to_string(),
327                }
328            } else {
329                // All other columns are TEXT
330                match row.get::<_, Option<String>>(i)? {
331                    Some(v) => v,
332                    None => "NULL".to_string(),
333                }
334            };
335            values.push(value);
336        }
337        Ok(values.join(" | "))
338    })?;
339
340    for row in rows {
341        println!("{}", row?);
342    }
343    println!();
344    Ok(())
345}
346
347fn init_tracing_subscriber(layer: SqliteLayer) {
348    let handle = RELOAD_HANDLE.lock().unwrap();
349    if let Some(reload_handle) = handle.as_ref() {
350        let _ = reload_handle.reload(layer);
351    } else {
352        tracing_subscriber::registry().with(layer).init();
353    }
354}
355
356// === API ===
357
358// Creates a new reload handler and no-op layer for initialization
359pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
360    let (layer, reload_handle) = reload::Layer::new(None);
361    let mut handle = RELOAD_HANDLE.lock().unwrap();
362    *handle = Some(reload_handle);
363    Ok(layer)
364}
365
366/// RAII guard for SQLite tracing database
367pub struct SqliteTracing {
368    db_path: Option<PathBuf>,
369    connection: Arc<Mutex<Connection>>,
370}
371
372impl SqliteTracing {
373    /// Create a new SqliteTracing with a temporary file
374    pub fn new() -> Result<Self> {
375        let temp_dir = std::env::temp_dir();
376        let file_name = format!("hyperactor_trace_{}.db", std::process::id());
377        let db_path = temp_dir.join(file_name);
378
379        let db_path_str = db_path.to_string_lossy();
380        let layer = SqliteLayer::new_with_file(&db_path_str)?;
381        let connection = layer.connection();
382
383        init_tracing_subscriber(layer);
384
385        Ok(Self {
386            db_path: Some(db_path),
387            connection,
388        })
389    }
390
391    /// Create a new SqliteTracing with in-memory database
392    pub fn new_in_memory() -> Result<Self> {
393        let layer = SqliteLayer::new()?;
394        let connection = layer.connection();
395
396        init_tracing_subscriber(layer);
397
398        Ok(Self {
399            db_path: None,
400            connection,
401        })
402    }
403
404    /// Get the path to the temporary database file (None for in-memory)
405    pub fn db_path(&self) -> Option<&PathBuf> {
406        self.db_path.as_ref()
407    }
408
409    /// Get a reference to the database connection
410    pub fn connection(&self) -> Arc<Mutex<Connection>> {
411        self.connection.clone()
412    }
413}
414
415impl Drop for SqliteTracing {
416    fn drop(&mut self) {
417        // Reset the layer to None
418        let handle = RELOAD_HANDLE.lock().unwrap();
419        if let Some(reload_handle) = handle.as_ref() {
420            let _ = reload_handle.reload(None);
421        }
422
423        // Delete the temporary file if it exists
424        if let Some(db_path) = &self.db_path {
425            if db_path.exists() {
426                let _ = fs::remove_file(db_path);
427            }
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    use std::sync::MutexGuard;
    use std::sync::PoisonError;

    use tracing::info;

    use super::*;

    /// Serializes the tests in this module.
    ///
    /// `SqliteTracing` installs its layer into the process-global subscriber,
    /// and its `Drop` resets the shared reload handle. Tests running
    /// concurrently on the default multi-threaded harness would therefore
    /// record into each other's databases and uninstall each other's layer.
    static SERIAL: Mutex<()> = Mutex::new(());

    /// Acquires the serialization lock. Poisoning is ignored so that one
    /// failed test does not cascade into failures of every later test.
    fn serial() -> MutexGuard<'static, ()> {
        SERIAL.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Returns the number of rows currently stored in `table`.
    fn row_count(conn: &Mutex<Connection>, table: &str) -> Result<i64> {
        let count: i64 = conn.lock().unwrap().query_row(
            &format!("SELECT COUNT(*) FROM {table}"),
            [],
            |row| row.get(0),
        )?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let _serial = serial();
        let tracing = SqliteTracing::new()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        // Verify we have a file path
        let db_path = tracing
            .db_path()
            .expect("file-backed tracing must have a path");
        assert!(db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let _serial = serial();
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        // Verify we don't have a file path for in-memory
        assert!(tracing.db_path().is_none());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let _serial = serial();
        let db_path = {
            let tracing = SqliteTracing::new()?;
            let conn = tracing.connection();

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            assert!(row_count(&conn, "log_events")? > 0);

            tracing.db_path().unwrap().clone()
        }; // tracing goes out of scope here, triggering Drop

        // File should be cleaned up after Drop
        assert!(!db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let _serial = serial();
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        // Test different event targets
        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        // Check that events went to the right tables
        assert_eq!(row_count(&conn, "messages")?, 1);
        assert_eq!(row_count(&conn, "actor_lifecycle")?, 1);
        assert_eq!(row_count(&conn, "log_events")?, 1);

        Ok(())
    }
}