hyperactor_telemetry/
sqlite.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10use std::fs;
11use std::os::unix::fs::PermissionsExt;
12use std::path::PathBuf;
13use std::sync::Arc;
14use std::sync::Mutex;
15
16use anyhow::Result;
17use anyhow::anyhow;
18use lazy_static::lazy_static;
19use rusqlite::Connection;
20use rusqlite::functions::FunctionFlags;
21use serde::Serialize;
22use serde_json::Value as JValue;
23use serde_rusqlite::*;
24use tracing::Event;
25use tracing::Subscriber;
26use tracing_subscriber::Layer;
27use tracing_subscriber::Registry;
28use tracing_subscriber::prelude::*;
29use tracing_subscriber::reload;
30
31pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
32
33lazy_static! {
34    // Reload handle allows us to include a no-op layer during init, but load
35    // the layer dynamically during tests.
36    static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
37        Mutex::new(None);
38}
/// Describes a SQLite table whose columns are all stored as `TEXT`.
///
/// Implementors supply only a table name and its column list; the SQL used to create and
/// populate the table is derived from those.
pub trait TableDef {
    /// Name of the table.
    fn name(&self) -> &'static str;

    /// Column names in declaration order (the implicit `seq` key is not included).
    fn columns(&self) -> &'static [&'static str];

    /// `create table if not exists` statement with a `seq INTEGER primary key` column followed
    /// by one `TEXT` column per entry of [`TableDef::columns`].
    fn create_table_stmt(&self) -> String {
        let column_defs: Vec<String> = self
            .columns()
            .iter()
            .map(|column| format!("{column} TEXT "))
            .collect();
        format!(
            "create table if not exists {} (seq INTEGER primary key, {})",
            self.name(),
            column_defs.join(",")
        )
    }

    /// `insert` statement binding every column to a same-named parameter (`:column`).
    fn insert_stmt(&self) -> String {
        let cols = self.columns();
        let placeholders: Vec<String> = cols.iter().map(|column| format!(":{column}")).collect();
        format!(
            "insert into {} ({}) values ({})",
            self.name(),
            cols.join(", "),
            placeholders.join(", ")
        )
    }
}
64
65impl TableDef for (&'static str, &'static [&'static str]) {
66    fn name(&self) -> &'static str {
67        self.0
68    }
69
70    fn columns(&self) -> &'static [&'static str] {
71        self.1
72    }
73}
74
/// Precomputed SQL for one table: its column list plus its create and insert statements.
#[derive(Clone, Debug)]
pub struct Table {
    /// Column names bound by `insert_stmt` (the `seq` key is not included).
    pub columns: &'static [&'static str],
    /// Statement that creates the table if it does not already exist.
    pub create_table_stmt: String,
    /// Statement that inserts one row using `:column` named parameters.
    pub insert_stmt: String,
}
81
82impl From<(&'static str, &'static [&'static str])> for Table {
83    fn from(value: (&'static str, &'static [&'static str])) -> Self {
84        Self {
85            columns: value.columns(),
86            create_table_stmt: value.create_table_stmt(),
87            insert_stmt: value.insert_stmt(),
88        }
89    }
90}
91
/// Identifies one of the tables the SQLite tracing layer writes to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableName {
    /// The `actor_lifecycle` table.
    ActorLifecycle,
    /// The `messages` table.
    Messages,
    /// The `log_events` table (catch-all for events with other targets).
    LogEvents,
}
98
99impl TableName {
100    pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
101    pub const MESSAGES_STR: &'static str = "messages";
102    pub const LOG_EVENTS_STR: &'static str = "log_events";
103
104    pub fn as_str(&self) -> &'static str {
105        match self {
106            TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
107            TableName::Messages => Self::MESSAGES_STR,
108            TableName::LogEvents => Self::LOG_EVENTS_STR,
109        }
110    }
111
112    pub fn get_table(&self) -> &'static Table {
113        match self {
114            TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
115            TableName::Messages => &MESSAGES,
116            TableName::LogEvents => &LOG_EVENTS,
117        }
118    }
119}
120
121lazy_static! {
122    static ref ACTOR_LIFECYCLE: Table = (
123        TableName::ActorLifecycle.as_str(),
124        [
125            "actor_id",
126            "actor",
127            "name",
128            "supervised_actor",
129            "actor_status",
130            "module_path",
131            "line",
132            "file",
133        ]
134        .as_slice()
135    )
136        .into();
137    static ref MESSAGES: Table = (
138        TableName::Messages.as_str(),
139        [
140            "span_id",
141            "time_us",
142            "src",
143            "dest",
144            "payload",
145            "module_path",
146            "line",
147            "file",
148        ]
149        .as_slice()
150    )
151        .into();
152    static ref LOG_EVENTS: Table = (
153        TableName::LogEvents.as_str(),
154        [
155            "span_id",
156            "time_us",
157            "name",
158            "message",
159            "actor_id",
160            "level",
161            "line",
162            "file",
163            "module_path",
164        ]
165        .as_slice()
166    )
167        .into();
168    pub static ref ALL_TABLES: Vec<Table> = vec![
169        ACTOR_LIFECYCLE.clone(),
170        MESSAGES.clone(),
171        LOG_EVENTS.clone()
172    ];
173}
174
175pub struct SqliteLayer {
176    conn: Arc<Mutex<Connection>>,
177}
178use tracing::field::Visit;
179
180#[derive(Debug, Clone, Default, Serialize)]
181pub struct SqlVisitor(pub HashMap<String, JValue>);
182
183impl Visit for SqlVisitor {
184    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
185        self.0.insert(
186            field.name().to_string(),
187            JValue::String(format!("{:?}", value)),
188        );
189    }
190
191    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
192        self.0
193            .insert(field.name().to_string(), JValue::String(value.to_string()));
194    }
195
196    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
197        self.0
198            .insert(field.name().to_string(), JValue::Number(value.into()));
199    }
200
201    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
202        let n = serde_json::Number::from_f64(value).unwrap();
203        self.0.insert(field.name().to_string(), JValue::Number(n));
204    }
205
206    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
207        self.0
208            .insert(field.name().to_string(), JValue::Number(value.into()));
209    }
210
211    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
212        self.0.insert(field.name().to_string(), JValue::Bool(value));
213    }
214}
215
/// Records `$event`'s fields plus its callsite location into `$table` using `$conn`.
///
/// Expands to statements that use `?`, so it can only be invoked inside a function returning a
/// `Result` whose error type accepts `anyhow::Error`.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let mut visitor = SqlVisitor::default();
        $event.record(&mut visitor);
        let metadata = $event.metadata();
        // Callsite location is stored alongside the event's own fields (and overrides any
        // event field with the same name).
        visitor.0.insert(
            "module_path".to_string(),
            metadata.module_path().map(String::from).into(),
        );
        visitor.0.insert("line".to_string(), metadata.line().into());
        visitor
            .0
            .insert("file".to_string(), metadata.file().map(String::from).into());
        insert_event_fields(&$conn, $table, visitor)?;
    };
}
234
235/// Public helper to insert event fields into database using the same logic as the old implementation.
236/// This is used by the unified SqliteExporter to ensure identical behavior.
237pub fn insert_event_fields(conn: &Connection, table: &Table, fields: SqlVisitor) -> Result<()> {
238    conn.prepare_cached(&table.insert_stmt)?.execute(
239        serde_rusqlite::to_params_named_with_fields(fields, table.columns)?
240            .to_slice()
241            .as_slice(),
242    )?;
243    Ok(())
244}
245
246impl SqliteLayer {
247    pub fn new() -> Result<Self> {
248        let conn = Connection::open_in_memory()?;
249        Self::setup_connection(conn)
250    }
251
252    pub fn new_with_file(db_path: &str) -> Result<Self> {
253        let conn = Connection::open(db_path)?;
254        Self::setup_connection(conn)
255    }
256
257    fn setup_connection(conn: Connection) -> Result<Self> {
258        for table in ALL_TABLES.iter() {
259            conn.execute(&table.create_table_stmt, [])?;
260        }
261        conn.create_scalar_function(
262            "assert",
263            2,
264            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
265            move |ctx| {
266                let condition: bool = ctx.get(0)?;
267                let message: String = ctx.get(1)?;
268
269                if !condition {
270                    return Err(rusqlite::Error::UserFunctionError(
271                        anyhow!("assertion failed:{condition} {message}",).into(),
272                    ));
273                }
274
275                Ok(condition)
276            },
277        )?;
278
279        Ok(Self {
280            conn: Arc::new(Mutex::new(conn)),
281        })
282    }
283
284    fn insert_event(&self, event: &Event<'_>) -> Result<()> {
285        let conn = self.conn.lock().unwrap();
286        match (event.metadata().target(), event.metadata().name()) {
287            (TableName::MESSAGES_STR, _) => {
288                insert_event!(TableName::Messages.get_table(), conn, event);
289            }
290            (TableName::ACTOR_LIFECYCLE_STR, _) => {
291                insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
292            }
293            _ => {
294                insert_event!(TableName::LogEvents.get_table(), conn, event);
295            }
296        }
297        Ok(())
298    }
299
300    pub fn connection(&self) -> Arc<Mutex<Connection>> {
301        self.conn.clone()
302    }
303}
304
305impl<S: Subscriber> Layer<S> for SqliteLayer {
306    fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
307        self.insert_event(event).unwrap();
308    }
309}
310
311#[allow(dead_code)]
312fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
313    let table_name_str = table_name.as_str();
314
315    // Get column names
316    let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
317    let column_info = stmt.query_map([], |row| {
318        row.get::<_, String>(1) // Column name is at index 1
319    })?;
320
321    let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
322
323    // Print header
324    println!("=== {} ===", table_name_str.to_uppercase());
325    println!("{}", columns.join(" | "));
326    println!("{}", "-".repeat(columns.len() * 10));
327
328    // Print rows
329    let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
330    let rows = stmt.query_map([], |row| {
331        let mut values = Vec::new();
332        for (i, column) in columns.iter().enumerate() {
333            // Handle different column types properly
334            let value = if i == 0 && *column == "seq" {
335                // First column is always the INTEGER seq column
336                match row.get::<_, Option<i64>>(i)? {
337                    Some(v) => v.to_string(),
338                    None => "NULL".to_string(),
339                }
340            } else {
341                // All other columns are TEXT
342                match row.get::<_, Option<String>>(i)? {
343                    Some(v) => v,
344                    None => "NULL".to_string(),
345                }
346            };
347            values.push(value);
348        }
349        Ok(values.join(" | "))
350    })?;
351
352    for row in rows {
353        println!("{}", row?);
354    }
355    println!();
356    Ok(())
357}
358
359fn init_tracing_subscriber(layer: SqliteLayer) {
360    let handle = RELOAD_HANDLE.lock().unwrap();
361    if let Some(reload_handle) = handle.as_ref() {
362        let _ = reload_handle.reload(layer);
363    } else {
364        tracing_subscriber::registry().with(layer).init();
365    }
366}
367
368// === API ===
369
370// Creates a new reload handler and no-op layer for initialization
371pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
372    let (layer, reload_handle) = reload::Layer::new(None);
373    let mut handle = RELOAD_HANDLE.lock().unwrap();
374    *handle = Some(reload_handle);
375    Ok(layer)
376}
377
378/// RAII guard for SQLite tracing database
379pub struct SqliteTracing {
380    db_path: Option<PathBuf>,
381    connection: Arc<Mutex<Connection>>,
382}
383
384impl SqliteTracing {
385    /// Create a new SqliteTracing with a temporary file
386    pub fn new() -> Result<Self> {
387        let temp_dir = std::env::temp_dir();
388        let file_name = format!("hyperactor_trace_{}.db", std::process::id());
389        let db_path = temp_dir.join(file_name);
390
391        let db_path_str = db_path.to_string_lossy();
392        let layer = SqliteLayer::new_with_file(&db_path_str)?;
393        let connection = layer.connection();
394
395        // Set file permissions to be readable and writable by owner and group
396        // This ensures the Python application can access the database file
397        if let Ok(metadata) = fs::metadata(&db_path) {
398            let mut permissions = metadata.permissions();
399            permissions.set_mode(0o664); // rw-rw-r--
400            let _ = fs::set_permissions(&db_path, permissions);
401        }
402
403        init_tracing_subscriber(layer);
404
405        Ok(Self {
406            db_path: Some(db_path),
407            connection,
408        })
409    }
410
411    /// Create a new SqliteTracing with in-memory database
412    pub fn new_in_memory() -> Result<Self> {
413        let layer = SqliteLayer::new()?;
414        let connection = layer.connection();
415
416        init_tracing_subscriber(layer);
417
418        Ok(Self {
419            db_path: None,
420            connection,
421        })
422    }
423
424    /// Get the path to the temporary database file (None for in-memory)
425    pub fn db_path(&self) -> Option<&PathBuf> {
426        self.db_path.as_ref()
427    }
428
429    /// Get a reference to the database connection
430    pub fn connection(&self) -> Arc<Mutex<Connection>> {
431        self.connection.clone()
432    }
433}
434
435impl Drop for SqliteTracing {
436    fn drop(&mut self) {
437        // Reset the layer to None
438        let handle = RELOAD_HANDLE.lock().unwrap();
439        if let Some(reload_handle) = handle.as_ref() {
440            let _ = reload_handle.reload(None);
441        }
442
443        // Delete the temporary file if it exists
444        if let Some(db_path) = &self.db_path {
445            if db_path.exists() {
446                let _ = fs::remove_file(db_path);
447            }
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use tracing::info;

    use super::*;

    /// Counts the rows currently stored in `table`.
    fn row_count(conn: &Mutex<Connection>, table: &str) -> Result<i64> {
        let sql = format!("SELECT COUNT(*) FROM {table}");
        let count: i64 = conn.lock().unwrap().query_row(&sql, [], |row| row.get(0))?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let tracing = SqliteTracing::new()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        // A file-backed tracer exposes a path, and the file exists while it is alive.
        assert!(tracing.db_path().is_some());
        assert!(tracing.db_path().unwrap().exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        // In-memory databases have no backing file.
        assert!(tracing.db_path().is_none());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let db_path = {
            let tracing = SqliteTracing::new()?;
            let conn = tracing.connection();

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            assert!(row_count(&conn, "log_events")? > 0);

            tracing.db_path().unwrap().clone()
        }; // `tracing` is dropped here, which should delete the file.

        assert!(!db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        // One event per routing target.
        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        // Each event must land in exactly the table matching its target.
        for table in ["messages", "actor_lifecycle", "log_events"] {
            assert_eq!(row_count(&conn, table)?, 1);
        }

        Ok(())
    }
}