hyperactor_telemetry/
sqlite.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10use std::fs;
11use std::os::unix::fs::PermissionsExt;
12use std::path::PathBuf;
13use std::sync::Arc;
14use std::sync::Mutex;
15
16use anyhow::Result;
17use anyhow::anyhow;
18use lazy_static::lazy_static;
19use rusqlite::Connection;
20use rusqlite::functions::FunctionFlags;
21use serde::Serialize;
22use serde_json::Value as JValue;
23use serde_rusqlite::*;
24use tracing::Event;
25use tracing::Subscriber;
26use tracing_subscriber::Layer;
27use tracing_subscriber::Registry;
28use tracing_subscriber::prelude::*;
29use tracing_subscriber::reload;
30
31pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
32
33lazy_static! {
34    // Reload handle allows us to include a no-op layer during init, but load
35    // the layer dynamically during tests.
36    static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
37        Mutex::new(None);
38}
/// Describes a SQLite table whose data columns are all stored as `TEXT`.
///
/// Implementors only supply the table name and its column list; the SQL statements used to
/// create the table and insert rows into it are derived from those two values.
pub trait TableDef {
    /// SQL name of the table.
    fn name(&self) -> &'static str;

    /// Names of the table's data columns, in declaration order.
    fn columns(&self) -> &'static [&'static str];

    /// Builds a `create table if not exists` statement.
    ///
    /// The table gets a leading `seq INTEGER primary key` column followed by every entry of
    /// [`TableDef::columns`] typed as `TEXT`.
    fn create_table_stmt(&self) -> String {
        let name = self.name();
        let mut columns = String::new();
        for (idx, column) in self.columns().iter().enumerate() {
            if idx > 0 {
                columns.push(',');
            }
            columns.push_str(column);
            columns.push_str(" TEXT ");
        }
        format!("create table if not exists {name} (seq INTEGER primary key, {columns})")
    }

    /// Builds an `insert` statement with one named parameter (`:column`) per column, so rows
    /// can be bound by column name.
    fn insert_stmt(&self) -> String {
        let name = self.name();
        let declared = self.columns();
        let columns = declared.join(", ");
        let params = declared
            .iter()
            .map(|column| format!(":{column}"))
            .collect::<Vec<_>>()
            .join(", ");
        format!("insert into {name} ({columns}) values ({params})")
    }
}

/// A `(name, columns)` pair is the simplest table definition.
impl TableDef for (&'static str, &'static [&'static str]) {
    fn name(&self) -> &'static str {
        let (name, _) = *self;
        name
    }

    fn columns(&self) -> &'static [&'static str] {
        let (_, columns) = *self;
        columns
    }
}

/// A table definition with its SQL statements rendered once, ready to execute.
#[derive(Clone, Debug)]
pub struct Table {
    /// Data columns of the table (excluding the implicit `seq` key).
    pub columns: &'static [&'static str],
    /// Output of [`TableDef::create_table_stmt`].
    pub create_table_stmt: String,
    /// Output of [`TableDef::insert_stmt`].
    pub insert_stmt: String,
}

impl From<(&'static str, &'static [&'static str])> for Table {
    fn from(def: (&'static str, &'static [&'static str])) -> Self {
        let columns = def.columns();
        let create_table_stmt = def.create_table_stmt();
        let insert_stmt = def.insert_stmt();
        Table {
            columns,
            create_table_stmt,
            insert_stmt,
        }
    }
}
91
92#[derive(Debug, Clone, Copy, PartialEq, Eq)]
93pub enum TableName {
94    ActorLifecycle,
95    Messages,
96    LogEvents,
97}
98
99impl TableName {
100    pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
101    pub const MESSAGES_STR: &'static str = "messages";
102    pub const LOG_EVENTS_STR: &'static str = "log_events";
103
104    pub fn as_str(&self) -> &'static str {
105        match self {
106            TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
107            TableName::Messages => Self::MESSAGES_STR,
108            TableName::LogEvents => Self::LOG_EVENTS_STR,
109        }
110    }
111
112    pub fn get_table(&self) -> &'static Table {
113        match self {
114            TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
115            TableName::Messages => &MESSAGES,
116            TableName::LogEvents => &LOG_EVENTS,
117        }
118    }
119}
120
121lazy_static! {
122    static ref ACTOR_LIFECYCLE: Table = (
123        TableName::ActorLifecycle.as_str(),
124        [
125            "actor_id",
126            "actor",
127            "name",
128            "supervised_actor",
129            "actor_status",
130            "module_path",
131            "line",
132            "file",
133        ]
134        .as_slice()
135    )
136        .into();
137    static ref MESSAGES: Table = (
138        TableName::Messages.as_str(),
139        [
140            "span_id",
141            "time_us",
142            "src",
143            "dest",
144            "payload",
145            "module_path",
146            "line",
147            "file",
148        ]
149        .as_slice()
150    )
151        .into();
152    static ref LOG_EVENTS: Table = (
153        TableName::LogEvents.as_str(),
154        [
155            "span_id",
156            "time_us",
157            "name",
158            "message",
159            "actor_id",
160            "level",
161            "line",
162            "file",
163            "module_path",
164        ]
165        .as_slice()
166    )
167        .into();
168    static ref ALL_TABLES: Vec<Table> = vec![
169        ACTOR_LIFECYCLE.clone(),
170        MESSAGES.clone(),
171        LOG_EVENTS.clone()
172    ];
173}
174
175pub struct SqliteLayer {
176    conn: Arc<Mutex<Connection>>,
177}
178use tracing::field::Visit;
179
180#[derive(Debug, Clone, Default, Serialize)]
181struct SqlVisitor(HashMap<String, JValue>);
182
183impl Visit for SqlVisitor {
184    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
185        self.0.insert(
186            field.name().to_string(),
187            JValue::String(format!("{:?}", value)),
188        );
189    }
190
191    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
192        self.0
193            .insert(field.name().to_string(), JValue::String(value.to_string()));
194    }
195
196    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
197        self.0
198            .insert(field.name().to_string(), JValue::Number(value.into()));
199    }
200
201    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
202        let n = serde_json::Number::from_f64(value).unwrap();
203        self.0.insert(field.name().to_string(), JValue::Number(n));
204    }
205
206    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
207        self.0
208            .insert(field.name().to_string(), JValue::Number(value.into()));
209    }
210
211    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
212        self.0.insert(field.name().to_string(), JValue::Bool(value));
213    }
214}
215
/// Records `$event`'s fields plus its source-location metadata as one row of `$table`.
///
/// Expands to statements that use `?` on `rusqlite` and `serde_rusqlite` errors, so it must be
/// invoked inside a function whose error type accepts both (e.g. `anyhow::Result`).
/// The `module_path`, `line`, and `file` entries come from the event metadata and overwrite any
/// event fields with the same names. Only the table's declared columns are passed to
/// `to_params_named_with_fields`.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let mut visitor = SqlVisitor::default();
        $event.record(&mut visitor);
        let metadata = $event.metadata();
        let fields = &mut visitor.0;
        fields.insert(
            "module_path".to_string(),
            metadata.module_path().map(String::from).into(),
        );
        fields.insert("line".to_string(), metadata.line().into());
        fields.insert("file".to_string(), metadata.file().map(String::from).into());
        let mut statement = $conn.prepare_cached(&$table.insert_stmt)?;
        let params = serde_rusqlite::to_params_named_with_fields(visitor, $table.columns)?;
        statement.execute(params.to_slice().as_slice())?;
    };
}
234
235impl SqliteLayer {
236    pub fn new() -> Result<Self> {
237        let conn = Connection::open_in_memory()?;
238        Self::setup_connection(conn)
239    }
240
241    pub fn new_with_file(db_path: &str) -> Result<Self> {
242        let conn = Connection::open(db_path)?;
243        Self::setup_connection(conn)
244    }
245
246    fn setup_connection(conn: Connection) -> Result<Self> {
247        for table in ALL_TABLES.iter() {
248            conn.execute(&table.create_table_stmt, [])?;
249        }
250        conn.create_scalar_function(
251            "assert",
252            2,
253            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
254            move |ctx| {
255                let condition: bool = ctx.get(0)?;
256                let message: String = ctx.get(1)?;
257
258                if !condition {
259                    return Err(rusqlite::Error::UserFunctionError(
260                        anyhow!("assertion failed:{condition} {message}",).into(),
261                    ));
262                }
263
264                Ok(condition)
265            },
266        )?;
267
268        Ok(Self {
269            conn: Arc::new(Mutex::new(conn)),
270        })
271    }
272
273    fn insert_event(&self, event: &Event<'_>) -> Result<()> {
274        let conn = self.conn.lock().unwrap();
275        match (event.metadata().target(), event.metadata().name()) {
276            (TableName::MESSAGES_STR, _) => {
277                insert_event!(TableName::Messages.get_table(), conn, event);
278            }
279            (TableName::ACTOR_LIFECYCLE_STR, _) => {
280                insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
281            }
282            _ => {
283                insert_event!(TableName::LogEvents.get_table(), conn, event);
284            }
285        }
286        Ok(())
287    }
288
289    pub fn connection(&self) -> Arc<Mutex<Connection>> {
290        self.conn.clone()
291    }
292}
293
294impl<S: Subscriber> Layer<S> for SqliteLayer {
295    fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
296        self.insert_event(event).unwrap();
297    }
298}
299
300#[allow(dead_code)]
301fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
302    let table_name_str = table_name.as_str();
303
304    // Get column names
305    let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
306    let column_info = stmt.query_map([], |row| {
307        row.get::<_, String>(1) // Column name is at index 1
308    })?;
309
310    let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
311
312    // Print header
313    println!("=== {} ===", table_name_str.to_uppercase());
314    println!("{}", columns.join(" | "));
315    println!("{}", "-".repeat(columns.len() * 10));
316
317    // Print rows
318    let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
319    let rows = stmt.query_map([], |row| {
320        let mut values = Vec::new();
321        for (i, column) in columns.iter().enumerate() {
322            // Handle different column types properly
323            let value = if i == 0 && *column == "seq" {
324                // First column is always the INTEGER seq column
325                match row.get::<_, Option<i64>>(i)? {
326                    Some(v) => v.to_string(),
327                    None => "NULL".to_string(),
328                }
329            } else {
330                // All other columns are TEXT
331                match row.get::<_, Option<String>>(i)? {
332                    Some(v) => v,
333                    None => "NULL".to_string(),
334                }
335            };
336            values.push(value);
337        }
338        Ok(values.join(" | "))
339    })?;
340
341    for row in rows {
342        println!("{}", row?);
343    }
344    println!();
345    Ok(())
346}
347
348fn init_tracing_subscriber(layer: SqliteLayer) {
349    let handle = RELOAD_HANDLE.lock().unwrap();
350    if let Some(reload_handle) = handle.as_ref() {
351        let _ = reload_handle.reload(layer);
352    } else {
353        tracing_subscriber::registry().with(layer).init();
354    }
355}
356
357// === API ===
358
359// Creates a new reload handler and no-op layer for initialization
360pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
361    let (layer, reload_handle) = reload::Layer::new(None);
362    let mut handle = RELOAD_HANDLE.lock().unwrap();
363    *handle = Some(reload_handle);
364    Ok(layer)
365}
366
367/// RAII guard for SQLite tracing database
368pub struct SqliteTracing {
369    db_path: Option<PathBuf>,
370    connection: Arc<Mutex<Connection>>,
371}
372
373impl SqliteTracing {
374    /// Create a new SqliteTracing with a temporary file
375    pub fn new() -> Result<Self> {
376        let temp_dir = std::env::temp_dir();
377        let file_name = format!("hyperactor_trace_{}.db", std::process::id());
378        let db_path = temp_dir.join(file_name);
379
380        let db_path_str = db_path.to_string_lossy();
381        let layer = SqliteLayer::new_with_file(&db_path_str)?;
382        let connection = layer.connection();
383
384        // Set file permissions to be readable and writable by owner and group
385        // This ensures the Python application can access the database file
386        if let Ok(metadata) = fs::metadata(&db_path) {
387            let mut permissions = metadata.permissions();
388            permissions.set_mode(0o664); // rw-rw-r--
389            let _ = fs::set_permissions(&db_path, permissions);
390        }
391
392        init_tracing_subscriber(layer);
393
394        Ok(Self {
395            db_path: Some(db_path),
396            connection,
397        })
398    }
399
400    /// Create a new SqliteTracing with in-memory database
401    pub fn new_in_memory() -> Result<Self> {
402        let layer = SqliteLayer::new()?;
403        let connection = layer.connection();
404
405        init_tracing_subscriber(layer);
406
407        Ok(Self {
408            db_path: None,
409            connection,
410        })
411    }
412
413    /// Get the path to the temporary database file (None for in-memory)
414    pub fn db_path(&self) -> Option<&PathBuf> {
415        self.db_path.as_ref()
416    }
417
418    /// Get a reference to the database connection
419    pub fn connection(&self) -> Arc<Mutex<Connection>> {
420        self.connection.clone()
421    }
422}
423
424impl Drop for SqliteTracing {
425    fn drop(&mut self) {
426        // Reset the layer to None
427        let handle = RELOAD_HANDLE.lock().unwrap();
428        if let Some(reload_handle) = handle.as_ref() {
429            let _ = reload_handle.reload(None);
430        }
431
432        // Delete the temporary file if it exists
433        if let Some(db_path) = &self.db_path {
434            if db_path.exists() {
435                let _ = fs::remove_file(db_path);
436            }
437        }
438    }
439}
440
#[cfg(test)]
mod tests {
    use tracing::info;

    use super::*;

    /// Runs a single-value `COUNT(*)` query against the shared connection, holding the lock
    /// only for the duration of the query.
    fn count_rows(conn: &Arc<Mutex<Connection>>, sql: &str) -> Result<i64> {
        let guard = conn.lock().unwrap();
        let count: i64 = guard.query_row(sql, [], |row| row.get(0))?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let tracing = SqliteTracing::new()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = count_rows(&conn, "SELECT COUNT(*) FROM messages")?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        // A file-backed instance reports a path, and the file exists while the guard is alive.
        let db_path = tracing.db_path();
        assert!(db_path.is_some());
        assert!(db_path.unwrap().exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = count_rows(&conn, "SELECT COUNT(*) FROM messages")?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        // In-memory databases have no backing file.
        assert!(tracing.db_path().is_none());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let db_path = {
            let tracing = SqliteTracing::new()?;

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            let count = count_rows(&tracing.connection(), "SELECT COUNT(*) FROM log_events")?;
            assert!(count > 0);

            tracing.db_path().unwrap().clone()
        }; // `tracing` is dropped here, which should delete the file.

        assert!(!db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        // Each target must land in exactly one table.
        for sql in [
            "SELECT COUNT(*) FROM messages",
            "SELECT COUNT(*) FROM actor_lifecycle",
            "SELECT COUNT(*) FROM log_events",
        ] {
            assert_eq!(count_rows(&conn, sql)?, 1);
        }

        Ok(())
    }
}