1use std::collections::HashMap;
10use std::fs;
11use std::os::unix::fs::PermissionsExt;
12use std::path::PathBuf;
13use std::sync::Arc;
14use std::sync::Mutex;
15
16use anyhow::Result;
17use anyhow::anyhow;
18use lazy_static::lazy_static;
19use rusqlite::Connection;
20use rusqlite::functions::FunctionFlags;
21use serde::Serialize;
22use serde_json::Value as JValue;
23use serde_rusqlite::*;
24use tracing::Event;
25use tracing::Subscriber;
26use tracing_subscriber::Layer;
27use tracing_subscriber::Registry;
28use tracing_subscriber::prelude::*;
29use tracing_subscriber::reload;
30
31pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
32
33lazy_static! {
34 static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
37 Mutex::new(None);
38}
/// Describes a SQLite table by name and column list, and derives the SQL
/// statements used to create it and to insert rows into it.
pub trait TableDef {
    /// Table name, interpolated verbatim into the generated SQL.
    fn name(&self) -> &'static str;

    /// Column names in declaration order. Every column is declared `TEXT`.
    fn columns(&self) -> &'static [&'static str];

    /// Builds a `create table if not exists` statement whose first column is
    /// an integer `seq` primary key, followed by one `TEXT` column per entry
    /// of [`TableDef::columns`].
    fn create_table_stmt(&self) -> String {
        let name = self.name();
        let mut column_defs: Vec<String> = Vec::with_capacity(self.columns().len());
        for col in self.columns() {
            column_defs.push(format!("{col} TEXT "));
        }
        let columns = column_defs.join(",");
        format!("create table if not exists {name} (seq INTEGER primary key, {columns})")
    }

    /// Builds an `insert` statement binding every column through a SQLite
    /// named parameter (`:column`).
    fn insert_stmt(&self) -> String {
        let cols = self.columns();
        let placeholders: Vec<String> = cols.iter().map(|c| format!(":{c}")).collect();
        let (name, columns, params) = (self.name(), cols.join(", "), placeholders.join(", "));
        format!("insert into {name} ({columns}) values ({params})")
    }
}

/// A `(name, columns)` pair is the simplest table definition.
impl TableDef for (&'static str, &'static [&'static str]) {
    fn name(&self) -> &'static str {
        let (table_name, _) = *self;
        table_name
    }

    fn columns(&self) -> &'static [&'static str] {
        let (_, cols) = *self;
        cols
    }
}
74
75#[derive(Clone, Debug)]
76pub struct Table {
77 pub columns: &'static [&'static str],
78 pub create_table_stmt: String,
79 pub insert_stmt: String,
80}
81
82impl From<(&'static str, &'static [&'static str])> for Table {
83 fn from(value: (&'static str, &'static [&'static str])) -> Self {
84 Self {
85 columns: value.columns(),
86 create_table_stmt: value.create_table_stmt(),
87 insert_stmt: value.insert_stmt(),
88 }
89 }
90}
91
92#[derive(Debug, Clone, Copy, PartialEq, Eq)]
93pub enum TableName {
94 ActorLifecycle,
95 Messages,
96 LogEvents,
97}
98
99impl TableName {
100 pub const ACTOR_LIFECYCLE_STR: &'static str = "actor_lifecycle";
101 pub const MESSAGES_STR: &'static str = "messages";
102 pub const LOG_EVENTS_STR: &'static str = "log_events";
103
104 pub fn as_str(&self) -> &'static str {
105 match self {
106 TableName::ActorLifecycle => Self::ACTOR_LIFECYCLE_STR,
107 TableName::Messages => Self::MESSAGES_STR,
108 TableName::LogEvents => Self::LOG_EVENTS_STR,
109 }
110 }
111
112 pub fn get_table(&self) -> &'static Table {
113 match self {
114 TableName::ActorLifecycle => &ACTOR_LIFECYCLE,
115 TableName::Messages => &MESSAGES,
116 TableName::LogEvents => &LOG_EVENTS,
117 }
118 }
119}
120
121lazy_static! {
122 static ref ACTOR_LIFECYCLE: Table = (
123 TableName::ActorLifecycle.as_str(),
124 [
125 "actor_id",
126 "actor",
127 "name",
128 "supervised_actor",
129 "actor_status",
130 "module_path",
131 "line",
132 "file",
133 ]
134 .as_slice()
135 )
136 .into();
137 static ref MESSAGES: Table = (
138 TableName::Messages.as_str(),
139 [
140 "span_id",
141 "time_us",
142 "src",
143 "dest",
144 "payload",
145 "module_path",
146 "line",
147 "file",
148 ]
149 .as_slice()
150 )
151 .into();
152 static ref LOG_EVENTS: Table = (
153 TableName::LogEvents.as_str(),
154 [
155 "span_id",
156 "time_us",
157 "name",
158 "message",
159 "actor_id",
160 "level",
161 "line",
162 "file",
163 "module_path",
164 ]
165 .as_slice()
166 )
167 .into();
168 static ref ALL_TABLES: Vec<Table> = vec![
169 ACTOR_LIFECYCLE.clone(),
170 MESSAGES.clone(),
171 LOG_EVENTS.clone()
172 ];
173}
174
175pub struct SqliteLayer {
176 conn: Arc<Mutex<Connection>>,
177}
178use tracing::field::Visit;
179
180#[derive(Debug, Clone, Default, Serialize)]
181struct SqlVisitor(HashMap<String, JValue>);
182
183impl Visit for SqlVisitor {
184 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
185 self.0.insert(
186 field.name().to_string(),
187 JValue::String(format!("{:?}", value)),
188 );
189 }
190
191 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
192 self.0
193 .insert(field.name().to_string(), JValue::String(value.to_string()));
194 }
195
196 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
197 self.0
198 .insert(field.name().to_string(), JValue::Number(value.into()));
199 }
200
201 fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
202 let n = serde_json::Number::from_f64(value).unwrap();
203 self.0.insert(field.name().to_string(), JValue::Number(n));
204 }
205
206 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
207 self.0
208 .insert(field.name().to_string(), JValue::Number(value.into()));
209 }
210
211 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
212 self.0.insert(field.name().to_string(), JValue::Bool(value));
213 }
214}
215
/// Inserts `$event` as one row of `$table` through the (locked) connection
/// `$conn`.
///
/// The event's fields are gathered with [`SqlVisitor`]. Then `module_path`,
/// `line` and `file` from the event metadata are added, overwriting any event
/// field of the same name. Parameters are built from that map for the names
/// in `$table.columns` (via `serde_rusqlite::to_params_named_with_fields`) and
/// bound to the table's cached insert statement.
///
/// Expands to statements that use `?`, so it must be invoked inside a function
/// returning a compatible `Result`.
macro_rules! insert_event {
    ($table:expr, $conn:ident, $event:ident) => {
        let target_table = $table;
        let mut fields = SqlVisitor::default();
        $event.record(&mut fields);
        let metadata = $event.metadata();
        fields.0.insert(
            "module_path".to_string(),
            metadata.module_path().map(String::from).into(),
        );
        fields.0.insert("line".to_string(), metadata.line().into());
        fields
            .0
            .insert("file".to_string(), metadata.file().map(String::from).into());
        // Prepare before building params to keep the original error precedence.
        let mut statement = $conn.prepare_cached(&target_table.insert_stmt)?;
        let params = serde_rusqlite::to_params_named_with_fields(fields, target_table.columns)?;
        statement.execute(params.to_slice().as_slice())?;
    };
}
234
235impl SqliteLayer {
236 pub fn new() -> Result<Self> {
237 let conn = Connection::open_in_memory()?;
238 Self::setup_connection(conn)
239 }
240
241 pub fn new_with_file(db_path: &str) -> Result<Self> {
242 let conn = Connection::open(db_path)?;
243 Self::setup_connection(conn)
244 }
245
246 fn setup_connection(conn: Connection) -> Result<Self> {
247 for table in ALL_TABLES.iter() {
248 conn.execute(&table.create_table_stmt, [])?;
249 }
250 conn.create_scalar_function(
251 "assert",
252 2,
253 FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
254 move |ctx| {
255 let condition: bool = ctx.get(0)?;
256 let message: String = ctx.get(1)?;
257
258 if !condition {
259 return Err(rusqlite::Error::UserFunctionError(
260 anyhow!("assertion failed:{condition} {message}",).into(),
261 ));
262 }
263
264 Ok(condition)
265 },
266 )?;
267
268 Ok(Self {
269 conn: Arc::new(Mutex::new(conn)),
270 })
271 }
272
273 fn insert_event(&self, event: &Event<'_>) -> Result<()> {
274 let conn = self.conn.lock().unwrap();
275 match (event.metadata().target(), event.metadata().name()) {
276 (TableName::MESSAGES_STR, _) => {
277 insert_event!(TableName::Messages.get_table(), conn, event);
278 }
279 (TableName::ACTOR_LIFECYCLE_STR, _) => {
280 insert_event!(TableName::ActorLifecycle.get_table(), conn, event);
281 }
282 _ => {
283 insert_event!(TableName::LogEvents.get_table(), conn, event);
284 }
285 }
286 Ok(())
287 }
288
289 pub fn connection(&self) -> Arc<Mutex<Connection>> {
290 self.conn.clone()
291 }
292}
293
294impl<S: Subscriber> Layer<S> for SqliteLayer {
295 fn on_event(&self, event: &Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
296 self.insert_event(event).unwrap();
297 }
298}
299
300#[allow(dead_code)]
301fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
302 let table_name_str = table_name.as_str();
303
304 let mut stmt = conn.prepare(&format!("PRAGMA table_info({})", table_name_str))?;
306 let column_info = stmt.query_map([], |row| {
307 row.get::<_, String>(1) })?;
309
310 let columns: Vec<String> = column_info.collect::<Result<Vec<_>, _>>()?;
311
312 println!("=== {} ===", table_name_str.to_uppercase());
314 println!("{}", columns.join(" | "));
315 println!("{}", "-".repeat(columns.len() * 10));
316
317 let mut stmt = conn.prepare(&format!("SELECT * FROM {}", table_name_str))?;
319 let rows = stmt.query_map([], |row| {
320 let mut values = Vec::new();
321 for (i, column) in columns.iter().enumerate() {
322 let value = if i == 0 && *column == "seq" {
324 match row.get::<_, Option<i64>>(i)? {
326 Some(v) => v.to_string(),
327 None => "NULL".to_string(),
328 }
329 } else {
330 match row.get::<_, Option<String>>(i)? {
332 Some(v) => v,
333 None => "NULL".to_string(),
334 }
335 };
336 values.push(value);
337 }
338 Ok(values.join(" | "))
339 })?;
340
341 for row in rows {
342 println!("{}", row?);
343 }
344 println!();
345 Ok(())
346}
347
348fn init_tracing_subscriber(layer: SqliteLayer) {
349 let handle = RELOAD_HANDLE.lock().unwrap();
350 if let Some(reload_handle) = handle.as_ref() {
351 let _ = reload_handle.reload(layer);
352 } else {
353 tracing_subscriber::registry().with(layer).init();
354 }
355}
356
357pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
361 let (layer, reload_handle) = reload::Layer::new(None);
362 let mut handle = RELOAD_HANDLE.lock().unwrap();
363 *handle = Some(reload_handle);
364 Ok(layer)
365}
366
367pub struct SqliteTracing {
369 db_path: Option<PathBuf>,
370 connection: Arc<Mutex<Connection>>,
371}
372
373impl SqliteTracing {
374 pub fn new() -> Result<Self> {
376 let temp_dir = std::env::temp_dir();
377 let file_name = format!("hyperactor_trace_{}.db", std::process::id());
378 let db_path = temp_dir.join(file_name);
379
380 let db_path_str = db_path.to_string_lossy();
381 let layer = SqliteLayer::new_with_file(&db_path_str)?;
382 let connection = layer.connection();
383
384 if let Ok(metadata) = fs::metadata(&db_path) {
387 let mut permissions = metadata.permissions();
388 permissions.set_mode(0o664); let _ = fs::set_permissions(&db_path, permissions);
390 }
391
392 init_tracing_subscriber(layer);
393
394 Ok(Self {
395 db_path: Some(db_path),
396 connection,
397 })
398 }
399
400 pub fn new_in_memory() -> Result<Self> {
402 let layer = SqliteLayer::new()?;
403 let connection = layer.connection();
404
405 init_tracing_subscriber(layer);
406
407 Ok(Self {
408 db_path: None,
409 connection,
410 })
411 }
412
413 pub fn db_path(&self) -> Option<&PathBuf> {
415 self.db_path.as_ref()
416 }
417
418 pub fn connection(&self) -> Arc<Mutex<Connection>> {
420 self.connection.clone()
421 }
422}
423
424impl Drop for SqliteTracing {
425 fn drop(&mut self) {
426 let handle = RELOAD_HANDLE.lock().unwrap();
428 if let Some(reload_handle) = handle.as_ref() {
429 let _ = reload_handle.reload(None);
430 }
431
432 if let Some(db_path) = &self.db_path {
434 if db_path.exists() {
435 let _ = fs::remove_file(db_path);
436 }
437 }
438 }
439}
440
#[cfg(test)]
mod tests {
    use tracing::info;

    use super::*;

    /// Runs `SELECT COUNT(*)` on `table`, holding the connection lock only for
    /// the duration of the query.
    fn row_count(conn: &Arc<Mutex<Connection>>, table: &str) -> Result<i64> {
        let sql = format!("SELECT COUNT(*) FROM {table}");
        let count: i64 = conn.lock().unwrap().query_row(&sql, [], |row| row.get(0))?;
        Ok(count)
    }

    #[test]
    fn test_sqlite_tracing_with_file() -> Result<()> {
        let tracing = SqliteTracing::new()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test msg");
        info!(target:"log_events", test_field = "test_value", "Test event");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
        assert!(count > 0);

        assert!(matches!(tracing.db_path(), Some(path) if path.exists()));
        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_in_memory() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", test_field = "test_value", "Test event in memory");

        let count = row_count(&conn, "messages")?;
        print_table(&conn.lock().unwrap(), TableName::Messages)?;
        assert!(count > 0);

        assert!(tracing.db_path().is_none());
        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_cleanup() -> Result<()> {
        let db_path = {
            let tracing = SqliteTracing::new()?;
            let conn = tracing.connection();

            info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");

            assert!(row_count(&conn, "log_events")? > 0);
            tracing.db_path().cloned().unwrap()
        };
        // `tracing` was dropped at the end of the block above, which must
        // have removed the database file.
        assert!(!db_path.exists());
        Ok(())
    }

    #[test]
    fn test_sqlite_tracing_different_targets() -> Result<()> {
        let tracing = SqliteTracing::new_in_memory()?;
        let conn = tracing.connection();

        info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
        info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
        info!(target:"log_events", test_field = "general_event", "General event");

        // Each target is routed to exactly one table, one row apiece.
        for table in ["messages", "actor_lifecycle", "log_events"] {
            assert_eq!(row_count(&conn, table)?, 1);
        }
        Ok(())
    }
}