1use std::collections::HashMap;
10use std::fs;
11use std::os::unix::fs::PermissionsExt;
12use std::path::PathBuf;
13use std::sync::Arc;
14use std::sync::Mutex;
15
16use anyhow::Result;
17use anyhow::anyhow;
18use lazy_static::lazy_static;
19use rusqlite::Connection;
20use rusqlite::functions::FunctionFlags;
21use serde::Serialize;
22use serde_json::Value as JValue;
23use serde_rusqlite::*;
24use tracing::Event;
25use tracing::Subscriber;
26use tracing_subscriber::Layer;
27use tracing_subscriber::Registry;
28use tracing_subscriber::prelude::*;
29use tracing_subscriber::reload;
30
31pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
32
33lazy_static! {
34 static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
37 Mutex::new(None);
38}
/// Describes a SQLite table whose user columns are all stored as `TEXT`.
///
/// Implementors only provide the table name and its column list; the provided
/// methods derive the DDL and the named-parameter insert statement from them.
pub trait TableDef {
    /// SQL name of the table.
    fn name(&self) -> &'static str;

    /// Column names in declaration order (the `seq` key is added implicitly).
    fn columns(&self) -> &'static [&'static str];

    /// Builds `create table if not exists <name> (seq INTEGER primary key, ...)`
    /// with every column declared as `TEXT`.
    fn create_table_stmt(&self) -> String {
        let mut column_defs = String::new();
        for (idx, col) in self.columns().iter().enumerate() {
            if idx > 0 {
                column_defs.push(',');
            }
            column_defs.push_str(col);
            column_defs.push_str(" TEXT ");
        }
        format!(
            "create table if not exists {} (seq INTEGER primary key, {})",
            self.name(),
            column_defs
        )
    }

    /// Builds `insert into <name> (c1, c2, ...) values (:c1, :c2, ...)`, using
    /// one named parameter per column.
    fn insert_stmt(&self) -> String {
        let cols = self.columns();
        let placeholders: Vec<String> = cols.iter().map(|c| format!(":{c}")).collect();
        format!(
            "insert into {} ({}) values ({})",
            self.name(),
            cols.join(", "),
            placeholders.join(", ")
        )
    }
}
64
65impl TableDef for (&'static str, &'static [&'static str]) {
66 fn name(&self) -> &'static str {
67 self.0
68 }
69
70 fn columns(&self) -> &'static [&'static str] {
71 self.1
72 }
73}
74
/// A table definition whose SQL statements have been rendered once up front.
#[derive(Debug, Clone)]
pub struct Table {
    /// Column names bound by `insert_stmt`, in parameter order.
    pub columns: &'static [&'static str],
    /// DDL executed when a connection is set up.
    pub create_table_stmt: String,
    /// Insert statement with one `:column` named parameter per column.
    pub insert_stmt: String,
}
81
82impl From<(&'static str, &'static [&'static str])> for Table {
83 fn from(value: (&'static str, &'static [&'static str])) -> Self {
84 Self {
85 columns: value.columns(),
86 create_table_stmt: value.create_table_stmt(),
87 insert_stmt: value.insert_stmt(),
88 }
89 }
90}
91
/// The SQLite tables the tracing layer writes to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TableName {
    /// `actor_lifecycle`: receives events whose target is `"actor_lifecycle"`.
    ActorLifecycle,
    /// `messages`: receives events whose target is `"messages"`.
    Messages,
    /// `log_events`: receives every other event.
    LogEvents,
}
98
99impl TableName {
100 pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
101 pub const MESSAGES_STR: &'static str = "messages";
102 pub const LOG_EVENTS_STR: &'static str = "log_events";
103
104 pub fn as_str(&self) -> &'static str {
105 match self {
106 TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
107 TableName::Messages => Self::MESSAGES_STR,
108 TableName::LogEvents => Self::LOG_EVENTS_STR,
109 }
110 }
111
112 pub fn get_table(&self) -> &'static Table {
113 match self {
114 TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
115 TableName::Messages => &MESSAGES,
116 TableName::LogEvents => &LOG_EVENTS,
117 }
118 }
119}
120
121lazy_static! {
122 static ref ACTOR_LIFECYCLE: Table = (
123 TableName::ActorLifecycle.as_str(),
124 [
125 "actor_id",
126 "actor",
127 "name",
128 "supervised_actor",
129 "actor_status",
130 "module_path",
131 "line",
132 "file",
133 ]
134 .as_slice()
135 )
136 .into();
137 static ref MESSAGES: Table = (
138 TableName::Messages.as_str(),
139 [
140 "span_id",
141 "time_us",
142 "src",
143 "dest",
144 "payload",
145 "module_path",
146 "line",
147 "file",
148 ]
149 .as_slice()
150 )
151 .into();
152 static ref LOG_EVENTS: Table = (
153 TableName::LogEvents.as_str(),
154 [
155 "span_id",
156 "time_us",
157 "name",
158 "message",
159 "actor_id",
160 "level",
161 "line",
162 "file",
163 "module_path",
164 ]
165 .as_slice()
166 )
167 .into();
168 pub static ref ALL_TABLES: Vec<Table> = vec![
169 ACTOR_LIFECYCLE.clone(),
170 MESSAGES.clone(),
171 LOG_EVENTS.clone()
172 ];
173}
174
175pub struct SqliteLayer {
176 conn: Arc<Mutex<Connection>>,
177}
178use tracing::field::Visit;
179
180#[derive(Debug, Clone, Default, Serialize)]
181pub struct SqlVisitor(pub HashMap<String, JValue>);
182
183impl Visit for SqlVisitor {
184 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
185 self.0.insert(
186 field.name().to_string(),
187 JValue::String(format!("{:?}", value)),
188 );
189 }
190
191 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
192 self.0
193 .insert(field.name().to_string(), JValue::String(value.to_string()));
194 }
195
196 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
197 self.0
198 .insert(field.name().to_string(), JValue::Number(value.into()));
199 }
200
201 fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
202 let n = serde_json::Number::from_f64(value).unwrap();
203 self.0.insert(field.name().to_string(), JValue::Number(n));
204 }
205
206 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
207 self.0
208 .insert(field.name().to_string(), JValue::Number(value.into()));
209 }
210
211 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
212 self.0.insert(field.name().to_string(), JValue::Bool(value));
213 }
214}
215
/// Records `$event`'s fields plus its `module_path`, `line` and `file` metadata
/// and inserts them as one row of `$table` via `insert_event_fields`.
///
/// Must be invoked inside a function returning `anyhow::Result`, since failures
/// are propagated with `?`.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let mut v: SqlVisitor = Default::default();
        $event.record(&mut v);
        // Source-location metadata is stored alongside the event's own fields.
        let meta = $event.metadata();
        v.0.insert(
            "module_path".to_string(),
            meta.module_path().map(String::from).into(),
        );
        v.0.insert("line".to_string(), meta.line().into());
        v.0.insert("file".to_string(), meta.file().map(String::from).into());
        insert_event_fields(&$conn, $table, v)?;
    };
}
234
235pub fn insert_event_fields(conn: &Connection, table: &Table, fields: SqlVisitor) -> Result<()> {
238 conn.prepare_cached(&table.insert_stmt)?.execute(
239 serde_rusqlite::to_params_named_with_fields(fields, table.columns)?
240 .to_slice()
241 .as_slice(),
242 )?;
243 Ok(())
244}
245
246impl SqliteLayer {
247 pub fn new() -> Result<Self> {
248 let conn = Connection::open_in_memory()?;
249 Self::setup_connection(conn)
250 }
251
252 pub fn new_with_file(db_path: &str) -> Result<Self> {
253 let conn = Connection::open(db_path)?;
254 Self::setup_connection(conn)
255 }
256
257 fn setup_connection(conn: Connection) -> Result<Self> {
258 for table in ALL_TABLES.iter() {
259 conn.execute(&table.create_table_stmt, [])?;
260 }
261 conn.create_scalar_function(
262 "assert",
263 2,
264 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
265 move |ctx| {
266 let condition: bool = ctx.get(0)?;
267 let message: String = ctx.get(1)?;
268
269 if !condition {
270 return Err(rusqlite::Error::UserFunctionError(
271 anyhow!("assertion failed:{condition} {message}",).into(),
272 ));
273 }
274
275 Ok(condition)
276 },
277 )?;
278
279 Ok(Self {
280 conn: Arc::new(Mutex::new(conn)),
281 })
282 }
283
284 fn insert_event(&self, event: &Event<'_>) -> Result<()> {
285 let conn = self.conn.lock().unwrap();
286 match (event.metadata().target(), event.metadata().name()) {
287 (TableName::MESSAGES_STR, _) => {
288 insert_event!(TableName::Messages.get_table(), conn, event);
289 }
290 (TableName::ACTOR_LIFECYCLE_STR, _) => {
291 insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
292 }
293 _ => {
294 insert_event!(TableName::LogEvents.get_table(), conn, event);
295 }
296 }
297 Ok(())
298 }
299
300 pub fn connection(&self) -> Arc<Mutex<Connection>> {
301 self.conn.clone()
302 }
303}
304
305impl<S: Subscriber> Layer<S> for SqliteLayer {
306 fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
307 self.insert_event(event).unwrap();
308 }
309}
310
311#[allow(dead_code)]
312fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
313 let table_name_str = table_name.as_str();
314
315 let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
317 let column_info = stmt.query_map([], |row| {
318 row.get::<_, String>(1) })?;
320
321 let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
322
323 println!("=== {} ===", table_name_str.to_uppercase());
325 println!("{}", columns.join(" | "));
326 println!("{}", "-".repeat(columns.len() * 10));
327
328 let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
330 let rows = stmt.query_map([], |row| {
331 let mut values = Vec::new();
332 for (i, column) in columns.iter().enumerate() {
333 let value = if i == 0 && *column == "seq" {
335 match row.get::<_, Option<i64>>(i)? {
337 Some(v) => v.to_string(),
338 None => "NULL".to_string(),
339 }
340 } else {
341 match row.get::<_, Option<String>>(i)? {
343 Some(v) => v,
344 None => "NULL".to_string(),
345 }
346 };
347 values.push(value);
348 }
349 Ok(values.join(" | "))
350 })?;
351
352 for row in rows {
353 println!("{}", row?);
354 }
355 println!();
356 Ok(())
357}
358
359fn init_tracing_subscriber(layer: SqliteLayer) {
360 let handle = RELOAD_HANDLE.lock().unwrap();
361 if let Some(reload_handle) = handle.as_ref() {
362 let _ = reload_handle.reload(layer);
363 } else {
364 tracing_subscriber::registry().with(layer).init();
365 }
366}
367
368pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
372 let (layer, reload_handle) = reload::Layer::new(None);
373 let mut handle = RELOAD_HANDLE.lock().unwrap();
374 *handle = Some(reload_handle);
375 Ok(layer)
376}
377
378pub struct SqliteTracing {
380 db_path: Option<PathBuf>,
381 connection: Arc<Mutex<Connection>>,
382}
383
384impl SqliteTracing {
385 pub fn new() -> Result<Self> {
387 let temp_dir = std::env::temp_dir();
388 let file_name = format!("hyperactor_trace_{}.db", std::process::id());
389 let db_path = temp_dir.join(file_name);
390
391 let db_path_str = db_path.to_string_lossy();
392 let layer = SqliteLayer::new_with_file(&db_path_str)?;
393 let connection = layer.connection();
394
395 if let Ok(metadata) = fs::metadata(&db_path) {
398 let mut permissions = metadata.permissions();
399 permissions.set_mode(0o664); let _ = fs::set_permissions(&db_path, permissions);
401 }
402
403 init_tracing_subscriber(layer);
404
405 Ok(Self {
406 db_path: Some(db_path),
407 connection,
408 })
409 }
410
411 pub fn new_in_memory() -> Result<Self> {
413 let layer = SqliteLayer::new()?;
414 let connection = layer.connection();
415
416 init_tracing_subscriber(layer);
417
418 Ok(Self {
419 db_path: None,
420 connection,
421 })
422 }
423
424 pub fn db_path(&self) -> Option<&PathBuf> {
426 self.db_path.as_ref()
427 }
428
429 pub fn connection(&self) -> Arc<Mutex<Connection>> {
431 self.connection.clone()
432 }
433}
434
435impl Drop for SqliteTracing {
436 fn drop(&mut self) {
437 let handle = RELOAD_HANDLE.lock().unwrap();
439 if let Some(reload_handle) = handle.as_ref() {
440 let _ = reload_handle.reload(None);
441 }
442
443 if let Some(db_path) = &self.db_path {
445 if db_path.exists() {
446 let _ = fs::remove_file(db_path);
447 }
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use tracing::info;

    use super::*;

    /// Runs `SELECT COUNT(*)` on `table` through the shared connection.
    fn row_count(conn: &Arc<Mutex<Connection>>, table: &str) -> Result<i64> {
        let guard = conn.lock().unwrap();
        let count = guard.query_row(&format!("SELECT COUNT(*) FROM {table}"), [], |row| {
            row.get(0)
        })?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let sink = SqliteTracing::new()?;
        let conn = sink.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        // A file-backed sink must expose a path that exists on disk.
        assert!(sink.db_path().is_some());
        assert!(sink.db_path().unwrap().exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let sink = SqliteTracing::new_in_memory()?;
        let conn = sink.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        assert!(sink.db_path().is_none());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let db_path = {
            let sink = SqliteTracing::new()?;
            let conn = sink.connection();

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            assert!(row_count(&conn, "log_events")? > 0);

            sink.db_path().unwrap().clone()
        };
        // `sink` was dropped at the end of the block, which must delete the file.
        assert!(!db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let sink = SqliteTracing::new_in_memory()?;
        let conn = sink.connection();

        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        // Each event must land in exactly one table, chosen by its target.
        for table in ["messages", "actor_lifecycle", "log_events"] {
            assert_eq!(row_count(&conn, table)?, 1);
        }

        Ok(())
    }
}