// monarch_distributed_telemetry/query_engine.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! QueryEngine - DataFusion query execution, creates ports, collects results
10
11use std::sync::Arc;
12
13use datafusion::arrow::datatypes::SchemaRef;
14use datafusion::arrow::ipc::reader::StreamReader;
15use datafusion::arrow::ipc::writer::StreamWriter;
16use datafusion::arrow::record_batch::RecordBatch;
17use datafusion::catalog::Session;
18use datafusion::datasource::TableProvider;
19use datafusion::error::Result as DFResult;
20use datafusion::logical_expr::Expr;
21use datafusion::logical_expr::TableProviderFilterPushDown;
22use datafusion::logical_expr::TableType;
23use datafusion::physical_expr::EquivalenceProperties;
24use datafusion::physical_expr::Partitioning;
25use datafusion::physical_plan::DisplayAs;
26use datafusion::physical_plan::DisplayFormatType;
27use datafusion::physical_plan::ExecutionPlan;
28use datafusion::physical_plan::PlanProperties;
29use datafusion::physical_plan::SendableRecordBatchStream;
30use datafusion::physical_plan::execution_plan::Boundedness;
31use datafusion::physical_plan::execution_plan::EmissionType;
32use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
33use datafusion::prelude::SessionConfig;
34use datafusion::prelude::SessionContext;
35use datafusion::sql::unparser::expr_to_sql;
36use hyperactor::Instance;
37use hyperactor::context::Mailbox as MailboxTrait;
38use hyperactor::mailbox::PortReceiver;
39use monarch_hyperactor::actor::PythonActor;
40use monarch_hyperactor::context::PyInstance;
41use monarch_hyperactor::mailbox::PyPortId;
42use monarch_hyperactor::pytokio::PyPythonTask;
43use monarch_hyperactor::runtime::get_tokio_runtime;
44use pyo3::exceptions::PyException;
45use pyo3::prelude::*;
46use pyo3::types::PyBytes;
47use pyo3::types::PyModule;
48use tokio::sync::mpsc;
49
50use crate::QueryResponse;
51
52// ============================================================================
53// Deserialization helpers
54// ============================================================================
55
56fn deserialize_schema(data: &[u8]) -> anyhow::Result<SchemaRef> {
57    let reader = StreamReader::try_new(std::io::Cursor::new(data), None)?;
58    Ok(reader.schema())
59}
60
61fn deserialize_batch(data: &[u8]) -> anyhow::Result<Option<RecordBatch>> {
62    let mut reader = StreamReader::try_new(std::io::Cursor::new(data), None)?;
63    Ok(reader.next().transpose()?)
64}
65
66// ============================================================================
67// Helper to spawn reader task that drains PortReceiver into a channel
68// ============================================================================
69
/// Spawns a task that reads QueryResponse messages from port_receiver until
/// the completion future resolves with the expected batch count, then waits
/// for exactly that many batches.
///
/// `completion_future` resolves when the distributed `scan` endpoint call
/// finishes on every rank; its Python result (iterated as (rank, count)
/// pairs) yields the total number of batches to expect, which tells the
/// reader when it is safe to stop draining the port.
///
/// Returns a stream that reads from the channel.
fn create_draining_stream<F>(
    schema: SchemaRef,
    port_receiver: PortReceiver<QueryResponse>,
    completion_future: F,
) -> SendableRecordBatchStream
where
    F: std::future::Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
{
    // Bounded channel decouples the port-reader task from the DataFusion
    // consumer of the returned stream.
    let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(32);

    // Spawn a task that reads messages until we have all expected batches
    get_tokio_runtime().spawn(async move {
        let mut receiver = port_receiver;
        // Batches received so far.
        let mut batch_count: usize = 0;
        // Total expected batches; unknown until the scan call completes.
        let mut expected_batches: Option<usize> = None;

        // Pin so `select!` can poll the future by `&mut` across iterations.
        tokio::pin!(completion_future);

        loop {
            // Check if we've received all expected batches
            if let Some(expected) = expected_batches {
                if batch_count >= expected {
                    tracing::info!(
                        "QueryEngine reader: received all {} expected batches",
                        expected
                    );
                    break;
                }
            }

            // `biased` prefers learning the expected count over receiving
            // more data when both are ready.
            tokio::select! {
                biased;

                // Check if the scan has completed (only if we don't have expected count yet)
                result = &mut completion_future, if expected_batches.is_none() => {
                    match result {
                        Ok(py_result) => {
                            // Extract the batch count from the Python result
                            // Result is a ValueMesh which is iterable, yielding (rank_dict, count) tuples
                            let count = Python::attach(|py| {
                                let bound = py_result.bind(py);
                                let mut total: usize = 0;
                                // Iterate the ValueMesh
                                if let Ok(iter) = bound.try_iter() {
                                    for item in iter {
                                        if let Ok(tuple) = item {
                                            // Each item is (rank_dict, count) - get second element
                                            if let Ok(count_val) = tuple.get_item(1) {
                                                if let Ok(count) = count_val.extract::<usize>() {
                                                    total += count;
                                                }
                                            }
                                        }
                                    }
                                }
                                // NOTE(review): items that fail to iterate or
                                // extract are silently skipped, which lowers
                                // the expected total — confirm this is the
                                // intended best-effort behavior.
                                total
                            });
                            tracing::info!(
                                "QueryEngine reader: scan completed, expecting {} batches, have {}",
                                count,
                                batch_count
                            );
                            expected_batches = Some(count);
                        }
                        Err(e) => {
                            tracing::error!("QueryEngine reader: scan failed: {:?}", e);
                            let _ = tx.send(Err(datafusion::error::DataFusionError::External(
                                anyhow::anyhow!("Scan failed: {:?}", e).into(),
                            ))).await;
                            break;
                        }
                    }
                }

                // Receive data from the port
                recv_result = receiver.recv() => {
                    match recv_result {
                        Ok(QueryResponse { data }) => {
                            match deserialize_batch(&data.into_bytes()) {
                                Ok(Some(batch)) => {
                                    batch_count += 1;
                                    // Even if the consumer went away we keep
                                    // counting so the loop terminates once all
                                    // ranks have flushed their batches.
                                    if tx.send(Ok(batch)).await.is_err() {
                                        tracing::info!(
                                            "QueryEngine reader: consumer dropped, continuing to drain"
                                        );
                                    }
                                }
                                // Empty IPC payload (schema-only stream): not
                                // counted toward expected_batches.
                                Ok(None) => {}
                                Err(e) => {
                                    let _ = tx
                                        .send(Err(datafusion::error::DataFusionError::External(e.into())))
                                        .await;
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!("QueryEngine reader: error receiving: {:?}", e);
                            let _ = tx
                                .send(Err(datafusion::error::DataFusionError::External(
                                    anyhow::anyhow!("Error receiving: {:?}", e).into(),
                                )))
                                .await;
                            break;
                        }
                    }
                }
            }
        }
        tracing::info!("QueryEngine reader: complete, {} batches", batch_count);
    });

    // Convert channel receiver to a stream
    let stream = futures::stream::unfold(rx, |mut rx| async move {
        rx.recv().await.map(|item| (item, rx))
    });

    Box::pin(RecordBatchStreamAdapter::new(schema, stream))
}
193
/// DataFusion table backed by a distributed Python actor mesh.
///
/// Scans are delegated to the actor's `scan` endpoint; result batches stream
/// back as Arrow IPC bytes over a hyperactor port.
struct DistributedTableProvider {
    // Name of the remote table this provider exposes.
    table_name: String,
    // Arrow schema fetched from the actor at registration time.
    schema: SchemaRef,
    // Python actor (mesh) exposing the `scan` endpoint.
    actor: Py<PyAny>,
    /// Actor instance for creating ports
    instance: Instance<PythonActor>,
}
201
202impl std::fmt::Debug for DistributedTableProvider {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("DistributedTableProvider")
205            .field("table_name", &self.table_name)
206            .finish()
207    }
208}
209
210fn expr_to_sql_string(expr: &Expr) -> Option<String> {
211    expr_to_sql(expr).ok().map(|sql| sql.to_string())
212}
213
214#[async_trait::async_trait]
215impl TableProvider for DistributedTableProvider {
216    fn as_any(&self) -> &dyn std::any::Any {
217        self
218    }
219    fn schema(&self) -> SchemaRef {
220        self.schema.clone()
221    }
222    fn table_type(&self) -> TableType {
223        TableType::Base
224    }
225
226    fn supports_filters_pushdown(
227        &self,
228        filters: &[&Expr],
229    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
230        Ok(filters
231            .iter()
232            .map(|e| {
233                if expr_to_sql_string(e).is_some() {
234                    TableProviderFilterPushDown::Exact
235                } else {
236                    TableProviderFilterPushDown::Unsupported
237                }
238            })
239            .collect())
240    }
241
242    async fn scan(
243        &self,
244        _state: &dyn Session,
245        projection: Option<&Vec<usize>>,
246        filters: &[Expr],
247        limit: Option<usize>,
248    ) -> DFResult<Arc<dyn ExecutionPlan>> {
249        let where_clauses: Vec<String> = filters.iter().filter_map(expr_to_sql_string).collect();
250        let where_clause = if where_clauses.is_empty() {
251            None
252        } else {
253            Some(where_clauses.join(" AND "))
254        };
255
256        let output_schema = match projection {
257            Some(proj) => Arc::new(datafusion::arrow::datatypes::Schema::new(
258                proj.iter()
259                    .filter_map(|&i| self.schema.fields().get(i).cloned())
260                    .collect::<Vec<_>>(),
261            )),
262            None => self.schema.clone(),
263        };
264
265        // Clone actor and instance for the execution plan
266        let (actor, instance) =
267            Python::attach(|py| (self.actor.clone_ref(py), self.instance.clone_for_py()));
268
269        Ok(Arc::new(DistributedExec {
270            table_name: self.table_name.clone(),
271            schema: output_schema.clone(),
272            projection: projection.cloned(),
273            where_clause,
274            limit,
275            actor,
276            instance,
277            properties: PlanProperties::new(
278                EquivalenceProperties::new(output_schema),
279                Partitioning::UnknownPartitioning(1),
280                EmissionType::Final,
281                Boundedness::Bounded,
282            ),
283        }))
284    }
285}
286
/// Physical plan node that runs a distributed scan via the Python actor and
/// streams the resulting Arrow batches back over a hyperactor port.
struct DistributedExec {
    // Remote table to scan.
    table_name: String,
    // Output schema (already narrowed to the projection).
    schema: SchemaRef,
    // Column indices pushed down to the remote scan, if any.
    projection: Option<Vec<usize>>,
    // SQL WHERE clause rendered from the pushed-down filters, if any.
    where_clause: Option<String>,
    // Row limit pushed down to the remote scan, if any.
    limit: Option<usize>,
    // Python actor (mesh) whose `scan` endpoint is invoked at execute time.
    actor: Py<PyAny>,
    // Actor instance used to open the reply port in `execute`.
    instance: Instance<PythonActor>,
    // Cached plan properties (single unknown partition, bounded, final).
    properties: PlanProperties,
}
297
298impl std::fmt::Debug for DistributedExec {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.debug_struct("DistributedExec")
301            .field("table_name", &self.table_name)
302            .finish()
303    }
304}
305
306impl DisplayAs for DistributedExec {
307    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
308        write!(f, "DistributedExec: table={}", self.table_name)
309    }
310}
311
impl ExecutionPlan for DistributedExec {
    fn name(&self) -> &str {
        "DistributedExec"
    }
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    /// Opens a reply port, kicks off the distributed scan on the actor mesh,
    /// and returns a stream that drains the port until the scan's reported
    /// batch count has been received.
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let schema = self.schema.clone();

        // Open port using the instance captured at scan time
        let (handle, receiver) = self.instance.mailbox().open_port::<QueryResponse>();
        let dest_port_ref = handle.bind();

        // Start the distributed scan
        let completion_future = Python::attach(
            |py| -> anyhow::Result<
                std::pin::Pin<
                    Box<dyn std::future::Future<Output = PyResult<Py<PyAny>>> + Send + 'static>,
                >,
            > {
                let dest_port_id: PyPortId = dest_port_ref.port_id().clone().into();

                // Call actor.scan.call(dest, table, proj, limit, filter) to get a Future
                let scan = self.actor.getattr(py, "scan")?;
                let future_obj = scan.call_method1(
                    py,
                    "call",
                    (
                        dest_port_id,
                        self.table_name.clone(),
                        self.projection.clone(),
                        self.limit,
                        self.where_clause.clone(),
                    ),
                )?;

                // Extract the PythonTask from the Future object
                // Future._status is an _Unawaited(coro) where coro is a PythonTask
                // NOTE(review): this reaches into monarch Future internals
                // (`_status.coro`); it will break if that private layout
                // changes — confirm there is no public accessor.
                let status = future_obj.getattr(py, "_status")?;
                // _Unawaited is a NamedTuple with .coro attribute
                let python_task_obj = status.getattr(py, "coro")?;
                let mut python_task: PyRefMut<'_, PyPythonTask> = python_task_obj.extract(py)?;
                // take_task() moves the Rust future out of the Python task so
                // we can await it on our own runtime.
                let completion_future = python_task.take_task()?;

                Ok(completion_future)
            },
        )
        .map_err(|e| datafusion::error::DataFusionError::External(e.into()))?;

        Ok(create_draining_stream(schema, receiver, completion_future))
    }
}
382
/// Python-facing SQL engine over a distributed telemetry actor mesh.
///
/// Holds a DataFusion session whose tables are `DistributedTableProvider`s,
/// one per table reported by the actor at construction time.
#[pyclass(
    name = "QueryEngine",
    module = "monarch._rust_bindings.monarch_distributed_telemetry.query_engine"
)]
pub struct QueryEngine {
    // DataFusion session with all remote tables registered.
    session: SessionContext,
}
390
391#[pymethods]
392impl QueryEngine {
393    /// Create a new QueryEngine.
394    ///
395    /// Args:
396    ///     actor: A singleton DistributedTelemetryActor (ActorMesh) to query
397    #[new]
398    fn new(py: Python<'_>, actor: Py<PyAny>) -> PyResult<Self> {
399        // Get actor instance from current Python context
400        let actor_module = py.import("monarch.actor")?;
401        let ctx = actor_module.call_method0("context")?;
402        let actor_instance_obj = ctx.getattr("actor_instance")?;
403        let py_instance: PyRef<'_, PyInstance> = actor_instance_obj.extract()?;
404        let instance: Instance<PythonActor> = py_instance.clone_for_py();
405
406        let session = Self::setup_tables(py, &actor, instance)?;
407        Ok(Self { session })
408    }
409
410    fn __repr__(&self) -> String {
411        "<QueryEngine>".into()
412    }
413
414    /// Execute a SQL query and return results as Arrow IPC bytes.
415    fn query<'py>(&self, py: Python<'py>, sql: String) -> PyResult<Bound<'py, PyBytes>> {
416        let session_ctx = self.session.clone();
417
418        // Release the GIL and run the async query on the shared monarch runtime.
419        let results: Vec<RecordBatch> = py
420            .detach(|| {
421                get_tokio_runtime().block_on(async {
422                    let df = session_ctx.sql(&sql).await?;
423                    df.collect().await
424                })
425            })
426            .map_err(|e| PyException::new_err(e.to_string()))?;
427
428        // Serialize all results as a single Arrow IPC stream
429        let schema = results
430            .first()
431            .map(|b| b.schema())
432            .unwrap_or_else(|| Arc::new(datafusion::arrow::datatypes::Schema::empty()));
433        let mut buf = Vec::new();
434        let mut writer = StreamWriter::try_new(&mut buf, &schema)
435            .map_err(|e| PyException::new_err(e.to_string()))?;
436        for batch in &results {
437            writer
438                .write(batch)
439                .map_err(|e| PyException::new_err(e.to_string()))?;
440        }
441        writer
442            .finish()
443            .map_err(|e| PyException::new_err(e.to_string()))?;
444
445        Ok(PyBytes::new(py, &buf))
446    }
447}
448
449impl QueryEngine {
450    fn setup_tables(
451        py: Python<'_>,
452        actor: &Py<PyAny>,
453        instance: Instance<PythonActor>,
454    ) -> PyResult<SessionContext> {
455        // Get table names from the actor mesh via endpoint call
456        // table_names is an endpoint, so we get it then call .call().get().item()
457        let tables: Vec<String> = actor
458            .getattr(py, "table_names")?
459            .call_method0(py, "call")?
460            .call_method0(py, "get")?
461            .call_method0(py, "item")?
462            .extract(py)?;
463
464        let config = SessionConfig::new().with_information_schema(true);
465        let ctx = SessionContext::new_with_config(config);
466
467        for table_name in &tables {
468            // Get schema from actor via endpoint call
469            let schema_bytes: Vec<u8> = actor
470                .getattr(py, "schema_for")?
471                .call_method1(py, "call", (table_name,))?
472                .call_method0(py, "get")?
473                .call_method0(py, "item")?
474                .extract(py)?;
475
476            let schema = deserialize_schema(&schema_bytes).map_err(|e| {
477                PyException::new_err(format!("Failed to deserialize schema: {}", e))
478            })?;
479
480            let provider = DistributedTableProvider {
481                table_name: table_name.clone(),
482                schema,
483                actor: actor.clone_ref(py),
484                instance: instance.clone_for_py(),
485            };
486
487            ctx.register_table(table_name, Arc::new(provider))
488                .map_err(|e| PyException::new_err(e.to_string()))?;
489        }
490
491        Ok(ctx)
492    }
493}
494
495pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
496    module.add_class::<QueryEngine>()?;
497    Ok(())
498}