monarch_distributed_telemetry/
database_scanner.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! DatabaseScanner - Local MemTable operations, scans with child stream merging
10
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::Mutex as StdMutex;
14use std::time::SystemTime;
15use std::time::UNIX_EPOCH;
16
17use datafusion::arrow::datatypes::SchemaRef;
18use datafusion::arrow::record_batch::RecordBatch;
19use datafusion::datasource::MemTable;
20use datafusion::datasource::TableProvider;
21use datafusion::prelude::SessionContext;
22use hyperactor::Instance;
23use hyperactor::reference;
24use monarch_hyperactor::actor::PythonActor;
25use monarch_hyperactor::context::PyInstance;
26use monarch_hyperactor::mailbox::PyPortId;
27use monarch_hyperactor::runtime::get_tokio_runtime;
28use monarch_record_batch::RecordBatchBuffer;
29use pyo3::exceptions::PyException;
30use pyo3::prelude::*;
31use pyo3::types::PyBytes;
32use pyo3::types::PyModule;
33use serde_multipart::Part;
34
35use crate::EntityDispatcher;
36use crate::QueryResponse;
37use crate::RecordBatchSink;
38use crate::pyspy_table::PySpyDumpBuffer;
39use crate::pyspy_table::PySpyFrameBuffer;
40use crate::pyspy_table::PySpyLocalVariableBuffer;
41use crate::pyspy_table::PySpyStackTraceBuffer;
42use crate::serialize_batch;
43use crate::serialize_schema;
44use crate::timestamp_to_micros;
45
/// Wraps a table's data so we can dynamically push new batches.
/// The MemTable is created on initialization and shared with queries.
pub struct LiveTableData {
    /// The MemTable that queries use. Shared via `Arc` with any
    /// `SessionContext` the table is registered with, so batches pushed
    /// here become visible to subsequent queries without re-registration.
    mem_table: Arc<MemTable>,
}
52
53impl LiveTableData {
54    fn new(schema: SchemaRef) -> Self {
55        let mem_table = MemTable::try_new(schema, vec![vec![]])
56            .expect("failed to create MemTable with empty partition");
57        Self {
58            mem_table: Arc::new(mem_table),
59        }
60    }
61
62    /// Push a new batch to the table.
63    pub async fn push(&self, batch: RecordBatch) {
64        if batch.num_rows() == 0 {
65            return;
66        }
67
68        let partition = &self.mem_table.batches[0];
69        let mut guard = partition.write().await;
70        guard.push(batch);
71    }
72
73    /// Filter the table's data, keeping only rows that match the WHERE clause.
74    ///
75    /// Holds the write lock for the entire operation to prevent data loss
76    /// from concurrent `push()` calls.
77    pub async fn apply_retention(
78        &self,
79        table_name: &str,
80        where_clause: &str,
81    ) -> anyhow::Result<()> {
82        use futures::TryStreamExt;
83
84        let partition = &self.mem_table.batches[0];
85        let mut guard = partition.write().await;
86
87        // Drain current batches into a temporary MemTable for querying.
88        let current_batches: Vec<RecordBatch> = guard.drain(..).collect();
89        let tmp = MemTable::try_new(self.mem_table.schema(), vec![current_batches])?;
90
91        let ctx = SessionContext::new();
92        ctx.register_table(table_name, Arc::new(tmp))?;
93
94        let query = format!("SELECT * FROM {table_name} WHERE {where_clause}");
95        let df = ctx.sql(&query).await?;
96        let filtered: Vec<RecordBatch> = df.execute_stream().await?.try_collect().await?;
97
98        for batch in filtered {
99            if batch.num_rows() > 0 {
100                guard.push(batch);
101            }
102        }
103        Ok(())
104    }
105
106    /// Get the schema.
107    pub fn schema(&self) -> SchemaRef {
108        self.mem_table.schema()
109    }
110
111    /// Get the MemTable for registering with a SessionContext.
112    pub fn mem_table(&self) -> Arc<MemTable> {
113        self.mem_table.clone()
114    }
115}
116
/// Opaque handle to the shared table storage.
///
/// External crates receive this capability via
/// [`DatabaseScanner::table_store()`]. The raw storage map is not
/// part of the public API.
///
/// # Table-store invariants (TS-*)
///
/// - **TS-1 (opaque capability):** External crates do not receive
///   the raw `Arc<StdMutex<HashMap<...>>>`.
/// - **TS-2 (behavior parity):** [`TableStore::ingest_batch`]
///   preserves existing ingestion semantics (ID-1 through ID-6).
/// - **TS-3 (read capability minimality):** [`table_names`](Self::table_names)
///   and [`table_provider`](Self::table_provider) expose only what
///   downstream query setup needs. Callers receive
///   `Arc<dyn TableProvider>`, not the backing `MemTable`.
/// - **TS-4 (ownership preserved):** Storage ownership remains in
///   `monarch_distributed_telemetry`. `TableStore` is a handle, not
///   an independent store.
#[derive(Clone)]
pub struct TableStore {
    /// Shared map from table name to live table data; the same storage
    /// the owning `DatabaseScanner` uses (TS-4).
    inner: Arc<StdMutex<HashMap<String, Arc<LiveTableData>>>>,
}
140
141impl TableStore {
142    /// Create an empty standalone table store.
143    ///
144    /// Useful for testing or standalone ingestion scenarios where
145    /// the full [`DatabaseScanner`] lifecycle is not needed.
146    pub fn new_empty() -> Self {
147        Self {
148            inner: Arc::new(StdMutex::new(HashMap::new())),
149        }
150    }
151
152    /// Ingest a `RecordBatch` into a named table (TS-2).
153    ///
154    /// Async so callers in async contexts can await directly without
155    /// hitting the `block_in_place` bridge in `push_batch_to_tables`.
156    ///
157    /// See the ID-* invariants on
158    /// `DatabaseScanner::push_batch_to_tables` for behavioral
159    /// guarantees (this method preserves the same semantics).
160    pub async fn ingest_batch(&self, table_name: &str, batch: RecordBatch) -> anyhow::Result<()> {
161        let table = {
162            let mut guard = self
163                .inner
164                .lock()
165                .map_err(|_| anyhow::anyhow!("lock poisoned"))?;
166            guard
167                .entry(table_name.to_string())
168                .or_insert_with(|| Arc::new(LiveTableData::new(batch.schema())))
169                .clone()
170        };
171        table.push(batch).await;
172        Ok(())
173    }
174
175    /// Return sorted table names currently in storage (TS-3).
176    pub fn table_names(&self) -> anyhow::Result<Vec<String>> {
177        let guard = self
178            .inner
179            .lock()
180            .map_err(|_| anyhow::anyhow!("lock poisoned"))?;
181        let mut names: Vec<String> = guard.keys().cloned().collect();
182        names.sort();
183        Ok(names)
184    }
185
186    /// Return a [`TableProvider`] for a named table, or `None` if
187    /// the table does not exist (TS-3).
188    ///
189    /// The returned provider can be registered directly with a
190    /// DataFusion `SessionContext`. Callers do not see the backing
191    /// storage type.
192    pub fn table_provider(
193        &self,
194        table_name: &str,
195    ) -> anyhow::Result<Option<Arc<dyn TableProvider>>> {
196        let guard = self
197            .inner
198            .lock()
199            .map_err(|_| anyhow::anyhow!("lock poisoned"))?;
200        Ok(guard
201            .get(table_name)
202            .map(|t| t.mem_table() as Arc<dyn TableProvider>))
203    }
204}
205
/// Default retention duration: 10 minutes in seconds.
const DEFAULT_RETENTION_SECS: u64 = 10 * 60;

/// Tables that keep only recent data; all others have unlimited retention.
/// Retention is enforced by `apply_retention_policies`, which filters each
/// of these tables on its `timestamp_us` column.
const RETENTION_TABLES: &[&str] = &["sent_messages", "messages", "message_status_events"];
211
/// Per-rank in-memory telemetry database exposed to Python: ingests trace
/// and entity events into live DataFusion tables and serves scans over them.
#[pyclass(
    name = "DatabaseScanner",
    module = "monarch._rust_bindings.monarch_distributed_telemetry.database_scanner"
)]
pub struct DatabaseScanner {
    /// Tables stored by name - each holds the schema and shared PartitionData
    table_data: Arc<StdMutex<HashMap<String, Arc<LiveTableData>>>>,
    /// This scanner's rank; used only to label log messages during scans.
    rank: usize,
    /// Retention window in microseconds. 0 means unlimited (retention is skipped).
    retention_us: i64,
    /// Handle to flush the RecordBatchSink for trace events (spans, events)
    sink: Option<RecordBatchSink>,
    /// Handle to flush the EntityDispatcher for entity events (actors, meshes)
    dispatcher: Option<EntityDispatcher>,
}
227
228#[pymethods]
229impl DatabaseScanner {
230    #[new]
231    #[pyo3(signature = (rank, batch_size=1000, retention_secs=DEFAULT_RETENTION_SECS))]
232    fn new(rank: usize, batch_size: usize, retention_secs: u64) -> PyResult<Self> {
233        let mut scanner = Self {
234            table_data: Arc::new(StdMutex::new(HashMap::new())),
235            rank,
236            retention_us: retention_secs as i64 * 1_000_000,
237            sink: None,
238            dispatcher: None,
239        };
240
241        // Create and register a RecordBatchSink for trace events (spans, events)
242        let sink = scanner.create_record_batch_sink(batch_size);
243        scanner.sink = Some(sink.clone());
244        hyperactor_telemetry::register_sink(Box::new(sink));
245
246        // Create and register an EntityDispatcher for entity events (actors, meshes)
247        let dispatcher = scanner.create_entity_dispatcher(batch_size);
248        scanner.dispatcher = Some(dispatcher.clone());
249        hyperactor_telemetry::set_entity_dispatcher(Box::new(dispatcher));
250
251        // Pre-register py-spy tables so QueryEngine discovers them at setup time
252        for (name, batch) in [
253            (
254                "pyspy_dumps",
255                PySpyDumpBuffer::default().drain_to_record_batch().unwrap(),
256            ),
257            (
258                "pyspy_stack_traces",
259                PySpyStackTraceBuffer::default()
260                    .drain_to_record_batch()
261                    .unwrap(),
262            ),
263            (
264                "pyspy_frames",
265                PySpyFrameBuffer::default().drain_to_record_batch().unwrap(),
266            ),
267            (
268                "pyspy_local_variables",
269                PySpyLocalVariableBuffer::default()
270                    .drain_to_record_batch()
271                    .unwrap(),
272            ),
273        ] {
274            Self::push_batch_to_tables(&scanner.table_data, name, batch).unwrap();
275        }
276
277        Ok(scanner)
278    }
279
280    /// Flush any pending trace events and entity events to the tables,
281    /// then apply time-based retention policies.
282    fn flush(&self) -> PyResult<()> {
283        if let Some(ref sink) = self.sink {
284            sink.flush()
285                .map_err(|e| PyException::new_err(format!("failed to flush sink: {}", e)))?;
286        }
287        if let Some(ref dispatcher) = self.dispatcher {
288            dispatcher
289                .flush()
290                .map_err(|e| PyException::new_err(format!("failed to flush dispatcher: {}", e)))?;
291        }
292        self.apply_retention_policies()?;
293        Ok(())
294    }
295
296    /// Filter a single table, keeping only rows that match the WHERE clause.
297    fn apply_retention(&self, table_name: &str, where_clause: &str) -> PyResult<()> {
298        let table = {
299            let guard = self
300                .table_data
301                .lock()
302                .map_err(|_| PyException::new_err("lock poisoned"))?;
303            match guard.get(table_name) {
304                Some(t) => t.clone(),
305                None => return Ok(()),
306            }
307        };
308
309        let result = if let Ok(handle) = tokio::runtime::Handle::try_current() {
310            tokio::task::block_in_place(|| {
311                handle.block_on(table.apply_retention(table_name, where_clause))
312            })
313        } else {
314            get_tokio_runtime().block_on(table.apply_retention(table_name, where_clause))
315        };
316        result.map_err(|e| PyException::new_err(e.to_string()))
317    }
318
319    /// Get list of table names.
320    fn table_names(&self) -> PyResult<Vec<String>> {
321        self.flush()?;
322        let guard = self
323            .table_data
324            .lock()
325            .map_err(|_| PyException::new_err("lock poisoned"))?;
326        Ok(guard.keys().cloned().collect())
327    }
328
329    /// Get schema for a table in Arrow IPC format.
330    fn schema_for<'py>(&self, py: Python<'py>, table: &str) -> PyResult<Bound<'py, PyBytes>> {
331        self.flush()?;
332        let guard = self
333            .table_data
334            .lock()
335            .map_err(|_| PyException::new_err("lock poisoned"))?;
336        let table_data = guard
337            .get(table)
338            .ok_or_else(|| PyException::new_err(format!("table '{}' not found", table)))?;
339        let schema = table_data.schema();
340        let bytes = serialize_schema(&schema).map_err(|e| PyException::new_err(e.to_string()))?;
341        Ok(PyBytes::new(py, &bytes))
342    }
343
344    /// Store a py-spy dump result into the pyspy_stacks table.
345    fn store_pyspy_dump_py(
346        &self,
347        dump_id: &str,
348        proc_ref: &str,
349        pyspy_result_json: &str,
350    ) -> PyResult<()> {
351        self.store_pyspy_dump(dump_id, proc_ref, pyspy_result_json)
352            .map_err(|e| PyException::new_err(e.to_string()))
353    }
354
355    /// Perform a scan, sending results directly to the dest port.
356    ///
357    /// Sends local scan results to `dest` synchronously. The Python caller
358    /// is responsible for calling children and waiting for them to complete.
359    /// When this method and all child scans return, all data has been sent.
360    ///
361    /// Args:
362    ///     dest: The destination PortId to send results to
363    ///     table_name: Name of the table to scan
364    ///     projection: Optional list of column indices to project
365    ///     limit: Optional row limit
366    ///     filter_expr: Optional SQL WHERE clause
367    ///
368    /// Returns:
369    ///     Number of batches sent
370    fn scan(
371        &self,
372        py: Python<'_>,
373        dest: &PyPortId,
374        table_name: String,
375        projection: Option<Vec<usize>>,
376        limit: Option<usize>,
377        filter_expr: Option<String>,
378    ) -> PyResult<usize> {
379        self.flush()?;
380
381        // Get actor instance from context and extract the Rust Instance once
382        let actor_module = py.import("monarch.actor")?;
383        let ctx = actor_module.call_method0("context")?;
384        let actor_instance_obj = ctx.getattr("actor_instance")?;
385        let py_instance: PyRef<'_, PyInstance> = actor_instance_obj.extract()?;
386        let instance: Instance<PythonActor> = py_instance.clone_for_py();
387
388        // Build destination PortRef once
389        let dest_port_id: reference::PortId = dest.clone().into();
390        let dest_ref: reference::PortRef<QueryResponse> = reference::PortRef::attest(dest_port_id);
391
392        // Execute scan, streaming batches directly to destination
393        self.execute_scan_streaming(
394            &table_name,
395            projection,
396            filter_expr,
397            limit,
398            &instance,
399            &dest_ref,
400        )
401    }
402}
403
impl DatabaseScanner {
    /// Push a batch into the named table in `table_data`.
    ///
    /// # Ingestion invariants (ID-*)
    ///
    /// - **ID-1 (create on first batch):** If `table_name` is absent,
    ///   a new `LiveTableData` is created from `batch.schema()`.
    /// - **ID-2 (empty batch registers schema):** An empty batch
    ///   creates the table entry and preserves the schema —
    ///   `LiveTableData::push` is a no-op for zero rows, but the
    ///   `entry().or_insert_with()` runs unconditionally.
    /// - **ID-3 (append on existing table):** A non-empty batch for
    ///   an existing table appends rows.
    /// - **ID-4 (error surface):** Lock poisoning propagates as
    ///   `Err`. `push()` itself is infallible.
    fn push_batch_to_tables(
        table_data: &Arc<StdMutex<HashMap<String, Arc<LiveTableData>>>>,
        table_name: &str,
        batch: RecordBatch,
    ) -> anyhow::Result<()> {
        // Resolve (or create) the table inside a short lock scope; the lock
        // is released before the (potentially blocking) push below.
        let table = {
            let mut guard = table_data
                .lock()
                .map_err(|_| anyhow::anyhow!("lock poisoned"))?;
            guard
                .entry(table_name.to_string())
                .or_insert_with(|| Arc::new(LiveTableData::new(batch.schema())))
                .clone()
        };

        // Push the batch (push ignores empty batches).
        // Use block_in_place + Handle::current() when called from within a tokio
        // runtime (e.g., from notify_sent_message on a worker thread), otherwise
        // fall back to creating/reusing a runtime via get_tokio_runtime().
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            tokio::task::block_in_place(|| handle.block_on(table.push(batch)));
        } else {
            get_tokio_runtime().block_on(table.push(batch));
        }
        Ok(())
    }

    /// Create a RecordBatchSink that pushes batches to this scanner's tables.
    ///
    /// The sink can be registered with hyperactor_telemetry::register_sink()
    /// to receive trace events and store them as queryable tables.
    pub fn create_record_batch_sink(&self, batch_size: usize) -> RecordBatchSink {
        let table_data = self.table_data.clone();

        RecordBatchSink::new(
            batch_size,
            Box::new(move |table_name, batch| {
                // Ingestion errors are logged, not propagated: the sink
                // callback has no error channel back to the caller.
                if let Err(e) = Self::push_batch_to_tables(&table_data, table_name, batch) {
                    tracing::error!("Failed to push batch to table {}: {}", table_name, e);
                }
            }),
        )
    }

    /// Create an EntityDispatcher that pushes batches to this scanner's tables.
    ///
    /// The dispatcher can be registered with hyperactor_telemetry::set_entity_dispatcher()
    /// to receive entity events (actors, meshes) and store them as queryable tables.
    pub fn create_entity_dispatcher(&self, batch_size: usize) -> EntityDispatcher {
        let table_data = self.table_data.clone();

        EntityDispatcher::new(
            batch_size,
            Box::new(move |table_name, batch| {
                // Same best-effort policy as the record-batch sink above.
                if let Err(e) = Self::push_batch_to_tables(&table_data, table_name, batch) {
                    tracing::error!("Failed to push batch to table {}: {}", table_name, e);
                }
            }),
        )
    }

    /// Parse a py-spy result JSON and store data in normalized py-spy tables.
    ///
    /// Populates four tables matching the `hyperactor_mesh::pyspy` structs:
    /// - `pyspy_dumps`: one row per dump
    /// - `pyspy_stack_traces`: one row per thread (matches `PySpyStackTrace`)
    /// - `pyspy_frames`: one row per frame (matches `PySpyFrame`)
    /// - `pyspy_local_variables`: one row per local variable (matches `PySpyLocalVariable`)
    ///
    /// Design notes:
    /// - Non-Ok results (`BinaryNotFound`, `Failed`) are silently dropped.
    ///   We intentionally do not record them as structured telemetry today;
    ///   the caller can log or count those cases if needed.
    /// - `dump_id` is caller-provided; uniqueness is the caller's responsibility.
    /// - `timestamp_us` records ingestion time, not py-spy capture time (the
    ///   py-spy JSON carries no capture timestamp).
    /// - We parse via `serde_json::Value` rather than importing the typed
    ///   `PySpyResult` to avoid a crate dependency on `hyperactor_mesh`. The
    ///   tradeoff is that schema drift in the py-spy structs will not be caught
    ///   at compile time.
    pub fn store_pyspy_dump(
        &self,
        dump_id: &str,
        proc_ref: &str,
        pyspy_result_json: &str,
    ) -> anyhow::Result<()> {
        use monarch_record_batch::RecordBatchBuffer;

        use crate::pyspy_table::PySpyDump;
        use crate::pyspy_table::PySpyDumpBuffer;
        use crate::pyspy_table::PySpyFrame;
        use crate::pyspy_table::PySpyFrameBuffer;
        use crate::pyspy_table::PySpyLocalVariable;
        use crate::pyspy_table::PySpyLocalVariableBuffer;
        use crate::pyspy_table::PySpyStackTrace;
        use crate::pyspy_table::PySpyStackTraceBuffer;

        let value: serde_json::Value = serde_json::from_str(pyspy_result_json)?;
        // Only the "Ok" variant carries a dump; anything else is dropped.
        let ok = match value.get("Ok") {
            Some(ok) => ok,
            None => return Ok(()),
        };

        // Missing/ill-typed fields default rather than fail: a partial dump
        // is still stored.
        let pid = ok.get("pid").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
        let binary = ok
            .get("binary")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        let traces = ok.get("stack_traces").and_then(|v| v.as_array());

        let now_us = timestamp_to_micros(&SystemTime::now());

        // Insert dump row
        let mut dump_buf = PySpyDumpBuffer::default();
        dump_buf.insert(PySpyDump {
            dump_id: dump_id.to_string(),
            timestamp_us: now_us,
            pid,
            binary,
            proc_ref: proc_ref.to_string(),
        });
        Self::push_batch_to_tables(
            &self.table_data,
            "pyspy_dumps",
            dump_buf.drain_to_record_batch()?,
        )?;

        // Insert stack trace, frame, and local variable rows
        let mut trace_buf = PySpyStackTraceBuffer::default();
        let mut frame_buf = PySpyFrameBuffer::default();
        let mut local_buf = PySpyLocalVariableBuffer::default();

        if let Some(traces) = traces {
            for trace in traces {
                let thread_id = trace.get("thread_id").and_then(|v| v.as_u64()).unwrap_or(0);

                trace_buf.insert(PySpyStackTrace {
                    dump_id: dump_id.to_string(),
                    // Per-trace pid falls back to the dump-level pid.
                    pid: trace
                        .get("pid")
                        .and_then(|v| v.as_i64())
                        .unwrap_or(pid as i64) as i32,
                    thread_id,
                    thread_name: trace
                        .get("thread_name")
                        .and_then(|v| v.as_str())
                        .map(String::from),
                    os_thread_id: trace.get("os_thread_id").and_then(|v| v.as_u64()),
                    active: trace
                        .get("active")
                        .and_then(|v| v.as_bool())
                        .unwrap_or(false),
                    owns_gil: trace
                        .get("owns_gil")
                        .and_then(|v| v.as_bool())
                        .unwrap_or(false),
                });

                if let Some(frames) = trace.get("frames").and_then(|v| v.as_array()) {
                    // Frame rows are keyed by (dump_id, thread_id, frame_depth).
                    for (depth, frame) in frames.iter().enumerate() {
                        frame_buf.insert(PySpyFrame {
                            dump_id: dump_id.to_string(),
                            thread_id,
                            frame_depth: depth as i32,
                            name: frame
                                .get("name")
                                .and_then(|v| v.as_str())
                                .unwrap_or("")
                                .to_string(),
                            filename: frame
                                .get("filename")
                                .and_then(|v| v.as_str())
                                .unwrap_or("")
                                .to_string(),
                            module: frame
                                .get("module")
                                .and_then(|v| v.as_str())
                                .map(String::from),
                            short_filename: frame
                                .get("short_filename")
                                .and_then(|v| v.as_str())
                                .map(String::from),
                            line: frame.get("line").and_then(|v| v.as_i64()).unwrap_or(0) as i32,
                            is_entry: frame
                                .get("is_entry")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false),
                        });

                        if let Some(locals) = frame.get("locals").and_then(|v| v.as_array()) {
                            for local in locals {
                                local_buf.insert(PySpyLocalVariable {
                                    dump_id: dump_id.to_string(),
                                    thread_id,
                                    frame_depth: depth as i32,
                                    name: local
                                        .get("name")
                                        .and_then(|v| v.as_str())
                                        .unwrap_or("")
                                        .to_string(),
                                    addr: local.get("addr").and_then(|v| v.as_u64()).unwrap_or(0),
                                    arg: local
                                        .get("arg")
                                        .and_then(|v| v.as_bool())
                                        .unwrap_or(false),
                                    repr: local
                                        .get("repr")
                                        .and_then(|v| v.as_str())
                                        .map(String::from),
                                });
                            }
                        }
                    }
                }
            }
        }

        Self::push_batch_to_tables(
            &self.table_data,
            "pyspy_stack_traces",
            trace_buf.drain_to_record_batch()?,
        )?;
        Self::push_batch_to_tables(
            &self.table_data,
            "pyspy_frames",
            frame_buf.drain_to_record_batch()?,
        )?;
        Self::push_batch_to_tables(
            &self.table_data,
            "pyspy_local_variables",
            local_buf.drain_to_record_batch()?,
        )?;
        Ok(())
    }

    /// Apply retention policies for all configured tables.
    /// Skipped when retention_us is 0 (unlimited).
    fn apply_retention_policies(&self) -> PyResult<()> {
        if self.retention_us == 0 {
            return Ok(());
        }

        let now_us = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock before unix epoch")
            .as_micros() as i64;
        let cutoff = now_us - self.retention_us;
        // Keep only rows newer than the cutoff; every RETENTION_TABLES table
        // is expected to have a `timestamp_us` column.
        let where_clause = format!("timestamp_us > {cutoff}");

        for &table_name in RETENTION_TABLES {
            self.apply_retention(table_name, &where_clause)?;
        }
        Ok(())
    }

    /// Return an opaque [`TableStore`] handle for external callers.
    pub fn table_store(&self) -> TableStore {
        TableStore {
            inner: self.table_data.clone(),
        }
    }

    /// Run a scan query over a local table and stream each result batch to
    /// `dest_ref` as a serialized `QueryResponse`.
    ///
    /// Returns the number of batches serialized. Send errors are logged at
    /// debug level and do not abort the scan or reduce the count; batches
    /// that fail to serialize are skipped silently.
    ///
    /// NOTE(review): column and table names are interpolated unquoted into
    /// the SQL text; names needing quoting would break the query — confirm
    /// schemas only use identifier-safe names.
    fn execute_scan_streaming(
        &self,
        table_name: &str,
        projection: Option<Vec<usize>>,
        where_clause: Option<String>,
        limit: Option<usize>,
        instance: &Instance<PythonActor>,
        dest_ref: &reference::PortRef<QueryResponse>,
    ) -> PyResult<usize> {
        let rank = self.rank;

        // Get the LiveTableData's MemTable
        let (schema, mem_table) = {
            let guard = self
                .table_data
                .lock()
                .map_err(|_| PyException::new_err("lock poisoned"))?;
            let table_data = guard
                .get(table_name)
                .ok_or_else(|| PyException::new_err(format!("table '{}' not found", table_name)))?;
            (table_data.schema(), table_data.mem_table())
        };

        // Handle empty projection (e.g., for COUNT(*) queries)
        // DataFusion may request 0 columns but we still need row counts
        let is_empty_projection = matches!(&projection, Some(proj) if proj.is_empty());

        // Build a query using DataFusion
        let ctx = SessionContext::new();
        ctx.register_table(table_name, mem_table)
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Build SELECT clause - for empty projection, use NULL as fake_column
        let columns = match &projection {
            Some(proj) if !proj.is_empty() => {
                // Out-of-range indices are dropped; if none survive, fall
                // back to selecting all columns.
                let selected: Vec<_> = proj
                    .iter()
                    .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
                    .collect();
                if selected.is_empty() {
                    "*".into()
                } else {
                    selected.join(", ")
                }
            }
            Some(_) => "NULL as fake_column".into(),
            _ => "*".into(),
        };

        let query = format!(
            "SELECT {} FROM {}{}{}",
            columns,
            table_name,
            where_clause
                .map(|c| format!(" WHERE {}", c))
                .unwrap_or_default(),
            limit.map(|n| format!(" LIMIT {}", n)).unwrap_or_default()
        );

        // Execute and stream batches directly to destination
        let batch_count = get_tokio_runtime()
            .block_on(async {
                use futures::StreamExt;

                let df = ctx.sql(&query).await?;
                let mut stream = df.execute_stream().await?;
                let mut count: usize = 0;

                while let Some(result) = stream.next().await {
                    let batch = result?;

                    // For empty projection, project to empty schema
                    let batch = if is_empty_projection {
                        batch.project(&[])?
                    } else {
                        batch
                    };

                    if let Ok(data) = serialize_batch(&batch) {
                        tracing::info!("Scanner {}: sending batch {}", rank, count);
                        let msg = QueryResponse {
                            data: Part::from(data),
                        };
                        // Best-effort send: a failed send is logged but the
                        // batch still counts as produced.
                        if let Err(e) = dest_ref.send(instance, msg) {
                            tracing::debug!(
                                "Scanner {}: send error for batch {}: {:?}",
                                rank,
                                count,
                                e
                            );
                        }
                        count += 1;
                    }
                }

                tracing::info!(
                    "Scanner {}: local scan complete, sent {} batches",
                    rank,
                    count
                );
                Ok::<usize, datafusion::error::DataFusionError>(count)
            })
            .map_err(|e| PyException::new_err(e.to_string()))?;

        Ok(batch_count)
    }
}
789
790pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
791    module.add_class::<DatabaseScanner>()?;
792    Ok(())
793}
794
795#[cfg(test)]
796mod tests {
797    use std::sync::Arc;
798
799    use datafusion::arrow::array::Array;
800    use datafusion::arrow::array::BooleanArray;
801    use datafusion::arrow::array::Int32Array;
802    use datafusion::arrow::array::Int64Array;
803    use datafusion::arrow::array::StringArray;
804    use datafusion::arrow::array::UInt64Array;
805    use datafusion::arrow::datatypes::DataType;
806    use datafusion::arrow::datatypes::Field;
807    use datafusion::arrow::datatypes::Schema;
808    use datafusion::arrow::record_batch::RecordBatch;
809
810    use super::*;
811
812    fn make_batch(values: &[i64]) -> RecordBatch {
813        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
814        let col = Int64Array::from(values.to_vec());
815        RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap()
816    }
817
818    async fn row_count(table: &LiveTableData) -> usize {
819        table.mem_table.batches[0]
820            .read()
821            .await
822            .iter()
823            .map(|b| b.num_rows())
824            .sum()
825    }
826
827    #[tokio::test]
828    async fn test_empty_batch_ignored() {
829        let table = LiveTableData::new(make_batch(&[]).schema());
830
831        table.push(make_batch(&[])).await;
832        assert_eq!(row_count(&table).await, 0);
833    }
834
835    #[tokio::test]
836    async fn test_apply_retention_filters_rows() {
837        // Push rows with x values 1..=5, then keep only x >= 3.
838        let table = LiveTableData::new(make_batch(&[]).schema());
839        table.push(make_batch(&[1, 2, 3, 4, 5])).await;
840
841        table.apply_retention("t", "x >= 3").await.unwrap();
842
843        // 3 rows should remain (3, 4, 5).
844        assert_eq!(row_count(&table).await, 3);
845    }
846
847    #[tokio::test]
848    async fn test_apply_retention_keeps_all() {
849        let table = LiveTableData::new(make_batch(&[]).schema());
850        table.push(make_batch(&[1, 2, 3])).await;
851
852        table.apply_retention("t", "1=1").await.unwrap();
853
854        assert_eq!(row_count(&table).await, 3);
855    }
856
857    #[tokio::test]
858    async fn test_concurrent_push_during_retention() {
859        // Verify that a push() concurrent with apply_retention() is not lost.
860        let table = Arc::new(LiveTableData::new(make_batch(&[]).schema()));
861        table.push(make_batch(&[1, 2, 3, 4, 5])).await;
862
863        let table_clone = table.clone();
864        let push_handle = tokio::spawn(async move {
865            // This push races with apply_retention. The write lock ensures
866            // it either completes before or after retention, never lost.
867            table_clone.push(make_batch(&[10, 11])).await;
868        });
869
870        // Retain only x >= 3 from the original batch.
871        table.apply_retention("t", "x >= 3").await.unwrap();
872        push_handle.await.unwrap();
873
874        // The pushed batch (10, 11) must survive regardless of ordering.
875        // If push ran first: 1,2,3,4,5,10,11 -> retain x>=3 -> 3,4,5,10,11 = 5 rows
876        // If push ran after: 1,2,3,4,5 -> retain x>=3 -> 3,4,5 -> push 10,11 = 5 rows
877        assert_eq!(row_count(&table).await, 5);
878    }
879
880    fn table_row_count(scanner: &DatabaseScanner, table_name: &str) -> usize {
881        let guard = scanner.table_data.lock().unwrap();
882        match guard.get(table_name) {
883            Some(table) => get_tokio_runtime().block_on(async {
884                table.mem_table().batches[0]
885                    .read()
886                    .await
887                    .iter()
888                    .map(|b| b.num_rows())
889                    .sum::<usize>()
890            }),
891            None => 0,
892        }
893    }
894
895    fn table_batches(scanner: &DatabaseScanner, table_name: &str) -> Vec<RecordBatch> {
896        let guard = scanner.table_data.lock().unwrap();
897        match guard.get(table_name) {
898            Some(table) => get_tokio_runtime()
899                .block_on(async { table.mem_table().batches[0].read().await.clone() }),
900            None => vec![],
901        }
902    }
903
904    #[test]
905    fn test_store_pyspy_dump_creates_normalized_rows() {
906        let scanner = DatabaseScanner {
907            table_data: Arc::new(StdMutex::new(HashMap::new())),
908            rank: 0,
909            retention_us: 0,
910            sink: None,
911            dispatcher: None,
912        };
913
914        let json = r#"{
915            "Ok": {
916                "pid": 1234, "binary": "python3",
917                "stack_traces": [{
918                    "pid": 1234, "thread_id": 100,
919                    "thread_name": "MainThread", "os_thread_id": 5678,
920                    "active": true, "owns_gil": true,
921                    "frames": [
922                        {"name": "inner", "filename": "a.py", "module": "a",
923                         "short_filename": "a.py", "line": 10, "locals": [
924                            {"name": "x", "addr": 100, "arg": true, "repr": "42"},
925                            {"name": "y", "addr": 200, "arg": false, "repr": null}
926                         ], "is_entry": false},
927                        {"name": "outer", "filename": "a.py", "module": "a",
928                         "short_filename": "a.py", "line": 5, "locals": [
929                            {"name": "z", "addr": 300, "arg": true, "repr": "'hello'"}
930                         ], "is_entry": true}
931                    ]
932                }],
933                "warnings": []
934            }
935        }"#;
936
937        scanner.store_pyspy_dump("dump-1", "proc[0]", json).unwrap();
938
939        assert_eq!(table_row_count(&scanner, "pyspy_dumps"), 1);
940        assert_eq!(table_row_count(&scanner, "pyspy_stack_traces"), 1);
941        assert_eq!(table_row_count(&scanner, "pyspy_frames"), 2);
942        assert_eq!(table_row_count(&scanner, "pyspy_local_variables"), 3);
943
944        // Verify pyspy_dumps content
945        let batches = table_batches(&scanner, "pyspy_dumps");
946        let batch = &batches[0];
947        let dump_ids = batch
948            .column_by_name("dump_id")
949            .unwrap()
950            .as_any()
951            .downcast_ref::<StringArray>()
952            .unwrap();
953        let pids = batch
954            .column_by_name("pid")
955            .unwrap()
956            .as_any()
957            .downcast_ref::<Int32Array>()
958            .unwrap();
959        let binaries = batch
960            .column_by_name("binary")
961            .unwrap()
962            .as_any()
963            .downcast_ref::<StringArray>()
964            .unwrap();
965        let proc_refs = batch
966            .column_by_name("proc_ref")
967            .unwrap()
968            .as_any()
969            .downcast_ref::<StringArray>()
970            .unwrap();
971        assert_eq!(dump_ids.value(0), "dump-1");
972        assert_eq!(pids.value(0), 1234);
973        assert_eq!(binaries.value(0), "python3");
974        assert_eq!(proc_refs.value(0), "proc[0]");
975
976        // Verify pyspy_stack_traces content
977        let batches = table_batches(&scanner, "pyspy_stack_traces");
978        let batch = &batches[0];
979        let dump_ids = batch
980            .column_by_name("dump_id")
981            .unwrap()
982            .as_any()
983            .downcast_ref::<StringArray>()
984            .unwrap();
985        let thread_ids = batch
986            .column_by_name("thread_id")
987            .unwrap()
988            .as_any()
989            .downcast_ref::<UInt64Array>()
990            .unwrap();
991        let thread_names = batch
992            .column_by_name("thread_name")
993            .unwrap()
994            .as_any()
995            .downcast_ref::<StringArray>()
996            .unwrap();
997        let os_thread_ids = batch
998            .column_by_name("os_thread_id")
999            .unwrap()
1000            .as_any()
1001            .downcast_ref::<UInt64Array>()
1002            .unwrap();
1003        let actives = batch
1004            .column_by_name("active")
1005            .unwrap()
1006            .as_any()
1007            .downcast_ref::<BooleanArray>()
1008            .unwrap();
1009        let owns_gils = batch
1010            .column_by_name("owns_gil")
1011            .unwrap()
1012            .as_any()
1013            .downcast_ref::<BooleanArray>()
1014            .unwrap();
1015        assert_eq!(dump_ids.value(0), "dump-1");
1016        assert_eq!(thread_ids.value(0), 100);
1017        assert_eq!(thread_names.value(0), "MainThread");
1018        assert_eq!(os_thread_ids.value(0), 5678);
1019        assert!(actives.value(0), "thread should be active");
1020        assert!(owns_gils.value(0), "thread should own GIL");
1021
1022        // Verify pyspy_frames content (2 rows: inner at depth 0, outer at depth 1)
1023        let batches = table_batches(&scanner, "pyspy_frames");
1024        let batch = &batches[0];
1025        let names = batch
1026            .column_by_name("name")
1027            .unwrap()
1028            .as_any()
1029            .downcast_ref::<StringArray>()
1030            .unwrap();
1031        let filenames = batch
1032            .column_by_name("filename")
1033            .unwrap()
1034            .as_any()
1035            .downcast_ref::<StringArray>()
1036            .unwrap();
1037        let depths = batch
1038            .column_by_name("frame_depth")
1039            .unwrap()
1040            .as_any()
1041            .downcast_ref::<Int32Array>()
1042            .unwrap();
1043        let lines = batch
1044            .column_by_name("line")
1045            .unwrap()
1046            .as_any()
1047            .downcast_ref::<Int32Array>()
1048            .unwrap();
1049        let is_entries = batch
1050            .column_by_name("is_entry")
1051            .unwrap()
1052            .as_any()
1053            .downcast_ref::<BooleanArray>()
1054            .unwrap();
1055        assert_eq!(names.value(0), "inner");
1056        assert_eq!(filenames.value(0), "a.py");
1057        assert_eq!(depths.value(0), 0);
1058        assert_eq!(lines.value(0), 10);
1059        assert!(!is_entries.value(0), "inner frame is not entry");
1060        assert_eq!(names.value(1), "outer");
1061        assert_eq!(filenames.value(1), "a.py");
1062        assert_eq!(depths.value(1), 1);
1063        assert_eq!(lines.value(1), 5);
1064        assert!(is_entries.value(1), "outer frame is entry");
1065
1066        // Verify pyspy_local_variables content (3 rows)
1067        let batches = table_batches(&scanner, "pyspy_local_variables");
1068        let batch = &batches[0];
1069        let dump_ids = batch
1070            .column_by_name("dump_id")
1071            .unwrap()
1072            .as_any()
1073            .downcast_ref::<StringArray>()
1074            .unwrap();
1075        let thread_ids = batch
1076            .column_by_name("thread_id")
1077            .unwrap()
1078            .as_any()
1079            .downcast_ref::<UInt64Array>()
1080            .unwrap();
1081        let depths = batch
1082            .column_by_name("frame_depth")
1083            .unwrap()
1084            .as_any()
1085            .downcast_ref::<Int32Array>()
1086            .unwrap();
1087        let var_names = batch
1088            .column_by_name("name")
1089            .unwrap()
1090            .as_any()
1091            .downcast_ref::<StringArray>()
1092            .unwrap();
1093        let addrs = batch
1094            .column_by_name("addr")
1095            .unwrap()
1096            .as_any()
1097            .downcast_ref::<UInt64Array>()
1098            .unwrap();
1099        let args = batch
1100            .column_by_name("arg")
1101            .unwrap()
1102            .as_any()
1103            .downcast_ref::<BooleanArray>()
1104            .unwrap();
1105        let reprs = batch
1106            .column_by_name("repr")
1107            .unwrap()
1108            .as_any()
1109            .downcast_ref::<StringArray>()
1110            .unwrap();
1111        // Row 0: x, addr=100, arg=true, repr=Some("42")
1112        assert_eq!(dump_ids.value(0), "dump-1");
1113        assert_eq!(thread_ids.value(0), 100);
1114        assert_eq!(depths.value(0), 0);
1115        assert_eq!(var_names.value(0), "x");
1116        assert_eq!(addrs.value(0), 100);
1117        assert!(args.value(0), "x is an argument");
1118        assert_eq!(reprs.value(0), "42");
1119        assert!(!reprs.is_null(0), "x repr should be Some");
1120        // Row 1: y, addr=200, arg=false, repr=None
1121        assert_eq!(dump_ids.value(1), "dump-1");
1122        assert_eq!(thread_ids.value(1), 100);
1123        assert_eq!(depths.value(1), 0);
1124        assert_eq!(var_names.value(1), "y");
1125        assert_eq!(addrs.value(1), 200);
1126        assert!(!args.value(1), "y is not an argument");
1127        assert!(reprs.is_null(1), "y repr should be None");
1128        // Row 2: z, addr=300, arg=true, repr=Some("'hello'")
1129        assert_eq!(dump_ids.value(2), "dump-1");
1130        assert_eq!(thread_ids.value(2), 100);
1131        assert_eq!(depths.value(2), 1);
1132        assert_eq!(var_names.value(2), "z");
1133        assert_eq!(addrs.value(2), 300);
1134        assert!(args.value(2), "z is an argument");
1135        assert_eq!(reprs.value(2), "'hello'");
1136        assert!(!reprs.is_null(2), "z repr should be Some");
1137    }
1138
1139    #[test]
1140    fn test_store_pyspy_dump_failed_result_no_rows() {
1141        let scanner = DatabaseScanner {
1142            table_data: Arc::new(StdMutex::new(HashMap::new())),
1143            rank: 0,
1144            retention_us: 0,
1145            sink: None,
1146            dispatcher: None,
1147        };
1148
1149        let json =
1150            r#"{"Failed": {"pid": 1, "binary": "py-spy", "exit_code": 1, "stderr": "error"}}"#;
1151        scanner.store_pyspy_dump("dump-2", "proc[0]", json).unwrap();
1152
1153        assert_eq!(table_row_count(&scanner, "pyspy_dumps"), 0);
1154        assert_eq!(table_row_count(&scanner, "pyspy_stack_traces"), 0);
1155        assert_eq!(table_row_count(&scanner, "pyspy_frames"), 0);
1156    }
1157
1158    #[test]
1159    fn test_store_pyspy_dump_invalid_json_errors() {
1160        let scanner = DatabaseScanner {
1161            table_data: Arc::new(StdMutex::new(HashMap::new())),
1162            rank: 0,
1163            retention_us: 0,
1164            sink: None,
1165            dispatcher: None,
1166        };
1167        assert!(scanner.store_pyspy_dump("x", "p", "not json").is_err());
1168    }
1169
1170    #[test]
1171    fn test_store_pyspy_dump_multiple_threads() {
1172        let scanner = DatabaseScanner {
1173            table_data: Arc::new(StdMutex::new(HashMap::new())),
1174            rank: 0,
1175            retention_us: 0,
1176            sink: None,
1177            dispatcher: None,
1178        };
1179
1180        let json = r#"{
1181            "Ok": {
1182                "pid": 1, "binary": "python3",
1183                "stack_traces": [
1184                    {"pid": 1, "thread_id": 1, "thread_name": "Main", "os_thread_id": 10,
1185                     "active": true, "owns_gil": true,
1186                     "frames": [{"name": "f1", "filename": "a.py", "line": 1, "is_entry": false}]},
1187                    {"pid": 1, "thread_id": 2, "thread_name": "Worker", "os_thread_id": 11,
1188                     "active": false, "owns_gil": false,
1189                     "frames": [{"name": "f2", "filename": "b.py", "line": 2, "is_entry": false}]}
1190                ],
1191                "warnings": []
1192            }
1193        }"#;
1194
1195        scanner.store_pyspy_dump("dump-3", "proc[0]", json).unwrap();
1196
1197        assert_eq!(table_row_count(&scanner, "pyspy_dumps"), 1);
1198        assert_eq!(table_row_count(&scanner, "pyspy_stack_traces"), 2);
1199        assert_eq!(table_row_count(&scanner, "pyspy_frames"), 2);
1200
1201        // Verify pyspy_stack_traces content: two threads
1202        let batches = table_batches(&scanner, "pyspy_stack_traces");
1203        let batch = &batches[0];
1204        let thread_ids = batch
1205            .column_by_name("thread_id")
1206            .unwrap()
1207            .as_any()
1208            .downcast_ref::<UInt64Array>()
1209            .unwrap();
1210        let thread_names = batch
1211            .column_by_name("thread_name")
1212            .unwrap()
1213            .as_any()
1214            .downcast_ref::<StringArray>()
1215            .unwrap();
1216        let actives = batch
1217            .column_by_name("active")
1218            .unwrap()
1219            .as_any()
1220            .downcast_ref::<BooleanArray>()
1221            .unwrap();
1222        let owns_gils = batch
1223            .column_by_name("owns_gil")
1224            .unwrap()
1225            .as_any()
1226            .downcast_ref::<BooleanArray>()
1227            .unwrap();
1228        // Thread 1: Main, active, owns GIL
1229        assert_eq!(thread_ids.value(0), 1);
1230        assert_eq!(thread_names.value(0), "Main");
1231        assert!(actives.value(0), "Main thread should be active");
1232        assert!(owns_gils.value(0), "Main thread should own GIL");
1233        // Thread 2: Worker, not active, no GIL
1234        assert_eq!(thread_ids.value(1), 2);
1235        assert_eq!(thread_names.value(1), "Worker");
1236        assert!(!actives.value(1), "Worker thread should not be active");
1237        assert!(!owns_gils.value(1), "Worker thread should not own GIL");
1238
1239        // Verify pyspy_frames content: f1 on thread 1, f2 on thread 2
1240        let batches = table_batches(&scanner, "pyspy_frames");
1241        let batch = &batches[0];
1242        let names = batch
1243            .column_by_name("name")
1244            .unwrap()
1245            .as_any()
1246            .downcast_ref::<StringArray>()
1247            .unwrap();
1248        let frame_thread_ids = batch
1249            .column_by_name("thread_id")
1250            .unwrap()
1251            .as_any()
1252            .downcast_ref::<UInt64Array>()
1253            .unwrap();
1254        let filenames = batch
1255            .column_by_name("filename")
1256            .unwrap()
1257            .as_any()
1258            .downcast_ref::<StringArray>()
1259            .unwrap();
1260        assert_eq!(names.value(0), "f1");
1261        assert_eq!(frame_thread_ids.value(0), 1);
1262        assert_eq!(filenames.value(0), "a.py");
1263        assert_eq!(names.value(1), "f2");
1264        assert_eq!(frame_thread_ids.value(1), 2);
1265        assert_eq!(filenames.value(1), "b.py");
1266    }
1267
1268    // --- ingest_batch tests ---
1269    // These reference the ID-* invariants defined on ingest_batch.
1270
1271    // ID-1, ID-2: empty batch creates the table with schema but 0
1272    // rows.
1273    #[tokio::test]
1274    async fn test_ingest_batch_creates_table_for_empty_batch() {
1275        let store = TableStore::new_empty();
1276        let empty = make_batch(&[]);
1277
1278        store.ingest_batch("t", empty.clone()).await.unwrap();
1279
1280        let names = store.table_names().unwrap();
1281        assert!(names.contains(&"t".to_owned()), "ID-1: table should exist");
1282        assert_eq!(
1283            query_row_count("t", store.table_provider("t").unwrap().unwrap()).await,
1284            0,
1285            "ID-2: 0 rows"
1286        );
1287    }
1288
1289    // ID-1, ID-3: non-empty batch creates table and appends rows.
1290    #[tokio::test]
1291    async fn test_ingest_batch_appends_non_empty_batch() {
1292        let store = TableStore::new_empty();
1293
1294        store
1295            .ingest_batch("t", make_batch(&[1, 2, 3]))
1296            .await
1297            .unwrap();
1298
1299        assert_eq!(
1300            query_row_count("t", store.table_provider("t").unwrap().unwrap()).await,
1301            3
1302        );
1303    }
1304
1305    // ID-3: two batches to the same table accumulate rows.
1306    #[tokio::test]
1307    async fn test_ingest_batch_reuses_existing_table() {
1308        let store = TableStore::new_empty();
1309
1310        store.ingest_batch("t", make_batch(&[1, 2])).await.unwrap();
1311        store
1312            .ingest_batch("t", make_batch(&[3, 4, 5]))
1313            .await
1314            .unwrap();
1315
1316        assert_eq!(
1317            store.table_names().unwrap().len(),
1318            1,
1319            "ID-3: still one table"
1320        );
1321        assert_eq!(
1322            query_row_count("t", store.table_provider("t").unwrap().unwrap()).await,
1323            5
1324        );
1325    }
1326
1327    // ID-2, ID-3: empty batch registers schema, then non-empty batch
1328    // appends rows using the same schema.
1329    #[tokio::test]
1330    async fn test_ingest_batch_empty_then_non_empty() {
1331        let store = TableStore::new_empty();
1332
1333        // Register schema with empty batch.
1334        store.ingest_batch("t", make_batch(&[])).await.unwrap();
1335        assert_eq!(
1336            query_row_count("t", store.table_provider("t").unwrap().unwrap()).await,
1337            0
1338        );
1339
1340        // Append rows.
1341        store
1342            .ingest_batch("t", make_batch(&[10, 20]))
1343            .await
1344            .unwrap();
1345        assert_eq!(
1346            query_row_count("t", store.table_provider("t").unwrap().unwrap()).await,
1347            2
1348        );
1349    }
1350
1351    // --- TableStore tests ---
1352    // These reference the TS-* invariants defined on TableStore.
1353
1354    /// Register a provider in a fresh SessionContext and return the
1355    /// row count from `SELECT * FROM {table_name}`.
1356    async fn query_row_count(table_name: &str, provider: Arc<dyn TableProvider>) -> usize {
1357        let ctx = SessionContext::new();
1358        ctx.register_table(table_name, provider).unwrap();
1359        let df = ctx
1360            .sql(&format!("SELECT * FROM {table_name}"))
1361            .await
1362            .unwrap();
1363        df.collect()
1364            .await
1365            .unwrap()
1366            .iter()
1367            .map(|b| b.num_rows())
1368            .sum()
1369    }
1370
1371    // TS-2, TS-3: ingest via TableStore, register the returned
1372    // table_provider in a SessionContext, and query it. Proves the
1373    // opaque handle is sufficient for downstream query setup.
1374    #[tokio::test]
1375    async fn test_table_store_ingest_and_query() {
1376        let store = TableStore::new_empty();
1377
1378        store
1379            .ingest_batch("t", make_batch(&[10, 20, 30]))
1380            .await
1381            .unwrap();
1382
1383        let provider = store
1384            .table_provider("t")
1385            .unwrap()
1386            .expect("TS-3: table_provider should return Some");
1387
1388        assert_eq!(
1389            query_row_count("t", provider).await,
1390            3,
1391            "TS-3: query should return ingested rows"
1392        );
1393    }
1394
1395    // TS-3: table_names returns all ingested table names, sorted.
1396    #[tokio::test]
1397    async fn test_table_store_table_names() {
1398        let store = TableStore::new_empty();
1399
1400        store.ingest_batch("beta", make_batch(&[1])).await.unwrap();
1401        store.ingest_batch("alpha", make_batch(&[2])).await.unwrap();
1402
1403        let names = store.table_names().unwrap();
1404        assert_eq!(names, vec!["alpha", "beta"], "TS-3: names should be sorted");
1405    }
1406
1407    // TS-2 (ID-2 passthrough): empty batch registers schema via
1408    // TableStore. Proves the table is visible through table_names
1409    // and table_provider without re-proving row-count internals.
1410    #[tokio::test]
1411    async fn test_table_store_empty_batch_registers() {
1412        let store = TableStore::new_empty();
1413
1414        store.ingest_batch("t", make_batch(&[])).await.unwrap();
1415
1416        assert!(
1417            store.table_names().unwrap().contains(&"t".to_owned()),
1418            "TS-2: empty batch should register table name"
1419        );
1420        assert!(
1421            store.table_provider("t").unwrap().is_some(),
1422            "TS-2: empty batch should make table_provider available"
1423        );
1424    }
1425
1426    // TS-3: table_provider for unknown table returns None.
1427    #[test]
1428    fn test_table_store_missing_table() {
1429        let store = TableStore::new_empty();
1430
1431        assert!(
1432            store.table_provider("missing").unwrap().is_none(),
1433            "TS-3: unknown table should return None"
1434        );
1435    }
1436}