1use std::collections::HashMap;
10use std::fs;
11use std::path::PathBuf;
12use std::sync::Arc;
13use std::sync::Mutex;
14
15use anyhow::Result;
16use anyhow::anyhow;
17use lazy_static::lazy_static;
18use rusqlite::Connection;
19use rusqlite::functions::FunctionFlags;
20use serde::Serialize;
21use serde_json::Value as JValue;
22use serde_rusqlite::*;
23use tracing::Event;
24use tracing::Subscriber;
25use tracing_subscriber::Layer;
26use tracing_subscriber::Registry;
27use tracing_subscriber::prelude::*;
28use tracing_subscriber::reload;
29
30pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
31
32lazy_static! {
33 static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
36 Mutex::new(None);
37}
/// Static description of a tracing table (its name and ordered columns), plus the SQL
/// derived from that description.
pub trait TableDef {
    /// SQL name of the table.
    fn name(&self) -> &'static str;

    /// Column names, in the order used by both the `CREATE TABLE` and `INSERT` statements.
    fn columns(&self) -> &'static [&'static str];

    /// Builds an idempotent `create table if not exists` statement.
    ///
    /// Every column is declared `TEXT`, and an extra `seq INTEGER primary key` column is
    /// prepended to the column list.
    fn create_table_stmt(&self) -> String {
        let name = self.name();
        let mut columns = String::new();
        for (idx, column) in self.columns().iter().enumerate() {
            if idx > 0 {
                columns.push(',');
            }
            columns.push_str(column);
            columns.push_str(" TEXT ");
        }
        format!("create table if not exists {name} (seq INTEGER primary key, {columns})")
    }

    /// Builds an `insert` statement that binds each column to a named parameter `:<column>`.
    fn insert_stmt(&self) -> String {
        let name = self.name();
        let all = self.columns();
        let columns = all.join(", ");
        let mut params = String::new();
        for (idx, column) in all.iter().enumerate() {
            if idx > 0 {
                params.push_str(", ");
            }
            params.push(':');
            params.push_str(column);
        }
        format!("insert into {name} ({columns}) values ({params})")
    }
}
63
64impl TableDef for (&'static str, &'static [&'static str]) {
65 fn name(&self) -> &'static str {
66 self.0
67 }
68
69 fn columns(&self) -> &'static [&'static str] {
70 self.1
71 }
72}
73
/// A table definition with its SQL statements pre-rendered once.
#[derive(Debug, Clone)]
pub struct Table {
    /// Column names, in statement order.
    pub columns: &'static [&'static str],
    /// `create table if not exists ...` statement for this table.
    pub create_table_stmt: String,
    /// `insert into ...` statement using one `:<column>` named parameter per column.
    pub insert_stmt: String,
}
80
81impl From<(&'static str, &'static [&'static str])> for Table {
82 fn from(value: (&'static str, &'static [&'static str])) -> Self {
83 Self {
84 columns: value.columns(),
85 create_table_stmt: value.create_table_stmt(),
86 insert_stmt: value.insert_stmt(),
87 }
88 }
89}
90
/// Identifies one of the tracing tables this module creates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableName {
    /// The `actor_lifecycle` table.
    ActorLifecycle,
    /// The `messages` table.
    Messages,
    /// The `log_events` table.
    LogEvents,
}
97
98impl TableName {
99 pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
100 pub const MESSAGES_STR: &'static str = "messages";
101 pub const LOG_EVENTS_STR: &'static str = "log_events";
102
103 pub fn as_str(&self) -> &'static str {
104 match self {
105 TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
106 TableName::Messages => Self::MESSAGES_STR,
107 TableName::LogEvents => Self::LOG_EVENTS_STR,
108 }
109 }
110
111 pub fn get_table(&self) -> &'static Table {
112 match self {
113 TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
114 TableName::Messages => &MESSAGES,
115 TableName::LogEvents => &LOG_EVENTS,
116 }
117 }
118}
119
120lazy_static! {
121 static ref ACTOR_LIFECYCLE: Table = (
122 TableName::ActorLifecycle.as_str(),
123 [
124 "actor_id",
125 "actor",
126 "name",
127 "supervised_actor",
128 "actor_status",
129 "module_path",
130 "line",
131 "file",
132 ]
133 .as_slice()
134 )
135 .into();
136 static ref MESSAGES: Table = (
137 TableName::Messages.as_str(),
138 [
139 "span_id",
140 "time_us",
141 "src",
142 "dest",
143 "payload",
144 "module_path",
145 "line",
146 "file",
147 ]
148 .as_slice()
149 )
150 .into();
151 static ref LOG_EVENTS: Table = (
152 TableName::LogEvents.as_str(),
153 [
154 "span_id",
155 "time_us",
156 "name",
157 "message",
158 "actor_id",
159 "level",
160 "line",
161 "file",
162 "module_path",
163 ]
164 .as_slice()
165 )
166 .into();
167 static ref ALL_TABLES: Vec<Table> = vec![
168 ACTOR_LIFECYCLE.clone(),
169 MESSAGES.clone(),
170 LOG_EVENTS.clone()
171 ];
172}
173
174pub struct SqliteLayer {
175 conn: Arc<Mutex<Connection>>,
176}
177use tracing::field::Visit;
178
179#[derive(Debug, Clone, Default, Serialize)]
180struct SqlVisitor(HashMap<String, JValue>);
181
182impl Visit for SqlVisitor {
183 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
184 self.0.insert(
185 field.name().to_string(),
186 JValue::String(format!("{:?}", value)),
187 );
188 }
189
190 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
191 self.0
192 .insert(field.name().to_string(), JValue::String(value.to_string()));
193 }
194
195 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
196 self.0
197 .insert(field.name().to_string(), JValue::Number(value.into()));
198 }
199
200 fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
201 let n = serde_json::Number::from_f64(value).unwrap();
202 self.0.insert(field.name().to_string(), JValue::Number(n));
203 }
204
205 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
206 self.0
207 .insert(field.name().to_string(), JValue::Number(value.into()));
208 }
209
210 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
211 self.0.insert(field.name().to_string(), JValue::Bool(value));
212 }
213}
214
/// Inserts `$event` as one row of `$table` (a `&Table`) through `$conn`.
///
/// The macro expands to statements and uses `?`, so it must be used inside a function whose
/// error type absorbs both rusqlite and serde_rusqlite errors.
///
/// The event's own fields are recorded first. `module_path`, `line` and `file` from the
/// callsite metadata are inserted afterwards, so they overwrite event fields of the same name.
/// `$table.columns` is passed as the field list to `to_params_named_with_fields`.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let table = $table;
        let mut fields: SqlVisitor = Default::default();
        $event.record(&mut fields);
        let meta = $event.metadata();
        for (key, value) in [
            ("module_path", JValue::from(meta.module_path().map(String::from))),
            ("line", JValue::from(meta.line())),
            ("file", JValue::from(meta.file().map(String::from))),
        ] {
            fields.0.insert(key.to_string(), value);
        }
        let mut stmt = $conn.prepare_cached(&table.insert_stmt)?;
        let params = serde_rusqlite::to_params_named_with_fields(fields, table.columns)?;
        stmt.execute(params.to_slice().as_slice())?;
    };
}
233
234impl SqliteLayer {
235 pub fn new() -> Result<Self> {
236 let conn = Connection::open_in_memory()?;
237 Self::setup_connection(conn)
238 }
239
240 pub fn new_with_file(db_path: &str) -> Result<Self> {
241 let conn = Connection::open(db_path)?;
242 Self::setup_connection(conn)
243 }
244
245 fn setup_connection(conn: Connection) -> Result<Self> {
246 for table in ALL_TABLES.iter() {
247 conn.execute(&table.create_table_stmt, [])?;
248 }
249 conn.create_scalar_function(
250 "assert",
251 2,
252 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
253 move |ctx| {
254 let condition: bool = ctx.get(0)?;
255 let message: String = ctx.get(1)?;
256
257 if !condition {
258 return Err(rusqlite::Error::UserFunctionError(
259 anyhow!("assertion failed:{condition} {message}",).into(),
260 ));
261 }
262
263 Ok(condition)
264 },
265 )?;
266
267 Ok(Self {
268 conn: Arc::new(Mutex::new(conn)),
269 })
270 }
271
272 fn insert_event(&self, event: &Event<'_>) -> Result<()> {
273 let conn = self.conn.lock().unwrap();
274 match (event.metadata().target(), event.metadata().name()) {
275 (TableName::MESSAGES_STR, _) => {
276 insert_event!(TableName::Messages.get_table(), conn, event);
277 }
278 (TableName::ACTOR_LIFECYCLE_STR, _) => {
279 insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
280 }
281 _ => {
282 insert_event!(TableName::LogEvents.get_table(), conn, event);
283 }
284 }
285 Ok(())
286 }
287
288 pub fn connection(&self) -> Arc<Mutex<Connection>> {
289 self.conn.clone()
290 }
291}
292
293impl<S: Subscriber> Layer<S> for SqliteLayer {
294 fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
295 self.insert_event(event).unwrap();
296 }
297}
298
299#[allow(dead_code)]
300fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
301 let table_name_str = table_name.as_str();
302
303 let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
305 let column_info = stmt.query_map([], |row| {
306 row.get::<_, String>(1) })?;
308
309 let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
310
311 println!("=== {} ===", table_name_str.to_uppercase());
313 println!("{}", columns.join(" | "));
314 println!("{}", "-".repeat(columns.len() * 10));
315
316 let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
318 let rows = stmt.query_map([], |row| {
319 let mut values = Vec::new();
320 for (i, column) in columns.iter().enumerate() {
321 let value = if i == 0 && *column == "seq" {
323 match row.get::<_, Option<i64>>(i)? {
325 Some(v) => v.to_string(),
326 None => "NULL".to_string(),
327 }
328 } else {
329 match row.get::<_, Option<String>>(i)? {
331 Some(v) => v,
332 None => "NULL".to_string(),
333 }
334 };
335 values.push(value);
336 }
337 Ok(values.join(" | "))
338 })?;
339
340 for row in rows {
341 println!("{}", row?);
342 }
343 println!();
344 Ok(())
345}
346
347fn init_tracing_subscriber(layer: SqliteLayer) {
348 let handle = RELOAD_HANDLE.lock().unwrap();
349 if let Some(reload_handle) = handle.as_ref() {
350 let _ = reload_handle.reload(layer);
351 } else {
352 tracing_subscriber::registry().with(layer).init();
353 }
354}
355
356pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
360 let (layer, reload_handle) = reload::Layer::new(None);
361 let mut handle = RELOAD_HANDLE.lock().unwrap();
362 *handle = Some(reload_handle);
363 Ok(layer)
364}
365
366pub struct SqliteTracing {
368 db_path: Option<PathBuf>,
369 connection: Arc<Mutex<Connection>>,
370}
371
372impl SqliteTracing {
373 pub fn new() -> Result<Self> {
375 let temp_dir = std::env::temp_dir();
376 let file_name = format!("hyperactor_trace_{}.db", std::process::id());
377 let db_path = temp_dir.join(file_name);
378
379 let db_path_str = db_path.to_string_lossy();
380 let layer = SqliteLayer::new_with_file(&db_path_str)?;
381 let connection = layer.connection();
382
383 init_tracing_subscriber(layer);
384
385 Ok(Self {
386 db_path: Some(db_path),
387 connection,
388 })
389 }
390
391 pub fn new_in_memory() -> Result<Self> {
393 let layer = SqliteLayer::new()?;
394 let connection = layer.connection();
395
396 init_tracing_subscriber(layer);
397
398 Ok(Self {
399 db_path: None,
400 connection,
401 })
402 }
403
404 pub fn db_path(&self) -> Option<&PathBuf> {
406 self.db_path.as_ref()
407 }
408
409 pub fn connection(&self) -> Arc<Mutex<Connection>> {
411 self.connection.clone()
412 }
413}
414
415impl Drop for SqliteTracing {
416 fn drop(&mut self) {
417 let handle = RELOAD_HANDLE.lock().unwrap();
419 if let Some(reload_handle) = handle.as_ref() {
420 let _ = reload_handle.reload(None);
421 }
422
423 if let Some(db_path) = &self.db_path {
425 if db_path.exists() {
426 let _ = fs::remove_file(db_path);
427 }
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use tracing::info;

    use super::*;

    /// Serializes the tests in this module.
    ///
    /// Each test swaps its own layer into the process-wide tracing subscriber. Concurrent tests
    /// on the harness's thread pool would otherwise record into each other's databases and
    /// break the exact-count assertions.
    static TEST_SERIAL: Mutex<()> = Mutex::new(());

    /// Takes the serialization lock. A panicking test poisons it, but the guarded data is
    /// `()`, so the lock is recovered.
    fn serialize_test() -> std::sync::MutexGuard<'static, ()> {
        TEST_SERIAL
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the number of rows currently in `table`.
    fn count_rows(conn: &Mutex<Connection>, table: TableName) -> Result<i64> {
        let count = conn.lock().unwrap().query_row(
            &format!("SELECT COUNT(*) FROM {}", table.as_str()),
            [],
            |row| row.get(0),
        )?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let _serial = serialize_test();
        let tracing = SqliteTracing::new()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = count_rows(&conn, TableName::Messages)?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        let db_path = tracing
            .db_path()
            .expect("file-backed tracing must report a db path");
        assert!(db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let _serial = serialize_test();
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = count_rows(&conn, TableName::Messages)?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        assert!(tracing.db_path().is_none());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let _serial = serialize_test();
        let db_path = {
            let tracing = SqliteTracing::new()?;
            let conn = tracing.connection();

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            assert!(count_rows(&conn, TableName::LogEvents)? > 0);

            tracing.db_path().unwrap().clone()
        };
        // `tracing` has been dropped at the end of the block above, which removes the file.
        assert!(!db_path.exists());

        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let _serial = serialize_test();
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        assert_eq!(count_rows(&conn, TableName::Messages)?, 1);
        assert_eq!(count_rows(&conn, TableName::ActorLifecycle)?, 1);
        assert_eq!(count_rows(&conn, TableName::LogEvents)?, 1);

        Ok(())
    }
}