// monarch_distributed_telemetry/query_engine.rs

use std::sync::Arc;
12
13use datafusion::arrow::datatypes::SchemaRef;
14use datafusion::arrow::ipc::reader::StreamReader;
15use datafusion::arrow::ipc::writer::StreamWriter;
16use datafusion::arrow::record_batch::RecordBatch;
17use datafusion::catalog::Session;
18use datafusion::datasource::TableProvider;
19use datafusion::error::Result as DFResult;
20use datafusion::logical_expr::Expr;
21use datafusion::logical_expr::TableProviderFilterPushDown;
22use datafusion::logical_expr::TableType;
23use datafusion::physical_expr::EquivalenceProperties;
24use datafusion::physical_expr::Partitioning;
25use datafusion::physical_plan::DisplayAs;
26use datafusion::physical_plan::DisplayFormatType;
27use datafusion::physical_plan::ExecutionPlan;
28use datafusion::physical_plan::PlanProperties;
29use datafusion::physical_plan::SendableRecordBatchStream;
30use datafusion::physical_plan::execution_plan::Boundedness;
31use datafusion::physical_plan::execution_plan::EmissionType;
32use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
33use datafusion::prelude::SessionConfig;
34use datafusion::prelude::SessionContext;
35use datafusion::sql::unparser::expr_to_sql;
36use hyperactor::Instance;
37use hyperactor::context::Mailbox as MailboxTrait;
38use hyperactor::mailbox::PortReceiver;
39use monarch_hyperactor::actor::PythonActor;
40use monarch_hyperactor::context::PyInstance;
41use monarch_hyperactor::mailbox::PyPortId;
42use monarch_hyperactor::pytokio::PyPythonTask;
43use monarch_hyperactor::runtime::get_tokio_runtime;
44use pyo3::exceptions::PyException;
45use pyo3::prelude::*;
46use pyo3::types::PyBytes;
47use pyo3::types::PyModule;
48use tokio::sync::mpsc;
49
50use crate::QueryResponse;
51
52fn deserialize_schema(data: &[u8]) -> anyhow::Result<SchemaRef> {
57 let reader = StreamReader::try_new(std::io::Cursor::new(data), None)?;
58 Ok(reader.schema())
59}
60
61fn deserialize_batch(data: &[u8]) -> anyhow::Result<Option<RecordBatch>> {
62 let mut reader = StreamReader::try_new(std::io::Cursor::new(data), None)?;
63 Ok(reader.next().transpose()?)
64}
65
/// Bridge a hyperactor port receiving Arrow IPC payloads into a DataFusion
/// record-batch stream.
///
/// Spawns a background task on the shared tokio runtime that:
///   1. awaits `completion_future` (the remote scan call) to learn how many
///      batches to expect in total, then
///   2. keeps draining `port_receiver` until that many batches have been
///      received — even after the downstream consumer drops the stream, so
///      the port does not accumulate undelivered responses.
///
/// Batches (or errors) are forwarded over a bounded mpsc channel that the
/// returned stream reads from; the stream ends when the drainer task exits
/// and drops the sender.
fn create_draining_stream<F>(
    schema: SchemaRef,
    port_receiver: PortReceiver<QueryResponse>,
    completion_future: F,
) -> SendableRecordBatchStream
where
    F: std::future::Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
{
    // Bounded channel provides backpressure between the drainer task and the
    // DataFusion consumer.
    let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(32);

    get_tokio_runtime().spawn(async move {
        let mut receiver = port_receiver;
        let mut batch_count: usize = 0;
        // Unknown until `completion_future` resolves.
        let mut expected_batches: Option<usize> = None;

        tokio::pin!(completion_future);

        loop {
            if let Some(expected) = expected_batches {
                // Exit once every announced batch has arrived.
                if batch_count >= expected {
                    tracing::info!(
                        "QueryEngine reader: received all {} expected batches",
                        expected
                    );
                    break;
                }
            }

            tokio::select! {
                // `biased` so the completion result (which establishes the
                // exit condition) is polled before draining more batches.
                biased;

                result = &mut completion_future, if expected_batches.is_none() => {
                    match result {
                        Ok(py_result) => {
                            // The Python call result is iterated as tuples;
                            // index 1 of each tuple is extracted as a usize
                            // batch count and summed into the expected total.
                            // NOTE(review): items/tuples that fail iteration
                            // or extraction are silently skipped — confirm
                            // that is intended.
                            let count = Python::attach(|py| {
                                let bound = py_result.bind(py);
                                let mut total: usize = 0;
                                if let Ok(iter) = bound.try_iter() {
                                    for item in iter {
                                        if let Ok(tuple) = item {
                                            if let Ok(count_val) = tuple.get_item(1) {
                                                if let Ok(count) = count_val.extract::<usize>() {
                                                    total += count;
                                                }
                                            }
                                        }
                                    }
                                }
                                total
                            });
                            tracing::info!(
                                "QueryEngine reader: scan completed, expecting {} batches, have {}",
                                count,
                                batch_count
                            );
                            expected_batches = Some(count);
                        }
                        Err(e) => {
                            tracing::error!("QueryEngine reader: scan failed: {:?}", e);
                            let _ = tx.send(Err(datafusion::error::DataFusionError::External(
                                anyhow::anyhow!("Scan failed: {:?}", e).into(),
                            ))).await;
                            break;
                        }
                    }
                }

                recv_result = receiver.recv() => {
                    match recv_result {
                        Ok(QueryResponse { data }) => {
                            match deserialize_batch(&data.into_bytes()) {
                                Ok(Some(batch)) => {
                                    batch_count += 1;
                                    // Send failure means the consumer dropped
                                    // the stream; keep draining anyway so the
                                    // port is fully consumed.
                                    if tx.send(Ok(batch)).await.is_err() {
                                        tracing::info!(
                                            "QueryEngine reader: consumer dropped, continuing to drain"
                                        );
                                    }
                                }
                                // Payload contained no batch; it does not
                                // count toward `batch_count`.
                                Ok(None) => {}
                                Err(e) => {
                                    let _ = tx
                                        .send(Err(datafusion::error::DataFusionError::External(e.into())))
                                        .await;
                                }
                            }
                        }
                        Err(e) => {
                            tracing::error!("QueryEngine reader: error receiving: {:?}", e);
                            let _ = tx
                                .send(Err(datafusion::error::DataFusionError::External(
                                    anyhow::anyhow!("Error receiving: {:?}", e).into(),
                                )))
                                .await;
                            break;
                        }
                    }
                }
            }
        }
        tracing::info!("QueryEngine reader: complete, {} batches", batch_count);
    });

    // Adapt the mpsc receiver into a Stream: yields items until the drainer
    // task drops `tx`.
    let stream = futures::stream::unfold(rx, |mut rx| async move {
        rx.recv().await.map(|item| (item, rx))
    });

    Box::pin(RecordBatchStreamAdapter::new(schema, stream))
}
193
/// DataFusion table provider backed by a remote telemetry actor.
struct DistributedTableProvider {
    // Name the table is registered under and requested from the actor.
    table_name: String,
    // Arrow schema fetched from the actor at registration time.
    schema: SchemaRef,
    // Python actor handle; its `scan` endpoint produces the table's data.
    actor: Py<PyAny>,
    // Hyperactor instance used to open reply ports for scan results.
    instance: Instance<PythonActor>,
}
201
202impl std::fmt::Debug for DistributedTableProvider {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("DistributedTableProvider")
205 .field("table_name", &self.table_name)
206 .finish()
207 }
208}
209
210fn expr_to_sql_string(expr: &Expr) -> Option<String> {
211 expr_to_sql(expr).ok().map(|sql| sql.to_string())
212}
213
214#[async_trait::async_trait]
215impl TableProvider for DistributedTableProvider {
216 fn as_any(&self) -> &dyn std::any::Any {
217 self
218 }
219 fn schema(&self) -> SchemaRef {
220 self.schema.clone()
221 }
222 fn table_type(&self) -> TableType {
223 TableType::Base
224 }
225
226 fn supports_filters_pushdown(
227 &self,
228 filters: &[&Expr],
229 ) -> DFResult<Vec<TableProviderFilterPushDown>> {
230 Ok(filters
231 .iter()
232 .map(|e| {
233 if expr_to_sql_string(e).is_some() {
234 TableProviderFilterPushDown::Exact
235 } else {
236 TableProviderFilterPushDown::Unsupported
237 }
238 })
239 .collect())
240 }
241
242 async fn scan(
243 &self,
244 _state: &dyn Session,
245 projection: Option<&Vec<usize>>,
246 filters: &[Expr],
247 limit: Option<usize>,
248 ) -> DFResult<Arc<dyn ExecutionPlan>> {
249 let where_clauses: Vec<String> = filters.iter().filter_map(expr_to_sql_string).collect();
250 let where_clause = if where_clauses.is_empty() {
251 None
252 } else {
253 Some(where_clauses.join(" AND "))
254 };
255
256 let output_schema = match projection {
257 Some(proj) => Arc::new(datafusion::arrow::datatypes::Schema::new(
258 proj.iter()
259 .filter_map(|&i| self.schema.fields().get(i).cloned())
260 .collect::<Vec<_>>(),
261 )),
262 None => self.schema.clone(),
263 };
264
265 let (actor, instance) =
267 Python::attach(|py| (self.actor.clone_ref(py), self.instance.clone_for_py()));
268
269 Ok(Arc::new(DistributedExec {
270 table_name: self.table_name.clone(),
271 schema: output_schema.clone(),
272 projection: projection.cloned(),
273 where_clause,
274 limit,
275 actor,
276 instance,
277 properties: PlanProperties::new(
278 EquivalenceProperties::new(output_schema),
279 Partitioning::UnknownPartitioning(1),
280 EmissionType::Final,
281 Boundedness::Bounded,
282 ),
283 }))
284 }
285}
286
/// Leaf execution plan that fetches batches for one table from the remote
/// actor when executed.
struct DistributedExec {
    // Remote table to scan.
    table_name: String,
    // Output schema (already narrowed by any projection).
    schema: SchemaRef,
    // Column indices requested by DataFusion, forwarded to the actor.
    projection: Option<Vec<usize>>,
    // SQL WHERE clause assembled from pushed-down filters, if any.
    where_clause: Option<String>,
    // Row limit, forwarded to the actor.
    limit: Option<usize>,
    // Python actor handle exposing the `scan` endpoint.
    actor: Py<PyAny>,
    // Hyperactor instance used to open the reply port.
    instance: Instance<PythonActor>,
    // Cached plan properties: single unknown partition, bounded, final emission.
    properties: PlanProperties,
}
297
298impl std::fmt::Debug for DistributedExec {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.debug_struct("DistributedExec")
301 .field("table_name", &self.table_name)
302 .finish()
303 }
304}
305
306impl DisplayAs for DistributedExec {
307 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
308 write!(f, "DistributedExec: table={}", self.table_name)
309 }
310}
311
impl ExecutionPlan for DistributedExec {
    fn name(&self) -> &str {
        "DistributedExec"
    }
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    /// Kick off the remote scan and return a stream of its result batches.
    ///
    /// Opens a fresh reply port, invokes the actor's `scan` endpoint with the
    /// port id plus the pushed-down projection/limit/WHERE clause, then hands
    /// both the port receiver and the call's completion future to
    /// `create_draining_stream`, which forwards batches as they arrive.
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let schema = self.schema.clone();

        // Per-execution port on which QueryResponse messages arrive.
        let (handle, receiver) = self.instance.mailbox().open_port::<QueryResponse>();
        let dest_port_ref = handle.bind();

        let completion_future = Python::attach(
            |py| -> anyhow::Result<
                std::pin::Pin<
                    Box<dyn std::future::Future<Output = PyResult<Py<PyAny>>> + Send + 'static>,
                >,
            > {
                let dest_port_id: PyPortId = dest_port_ref.port_id().clone().into();

                // Invoke `actor.scan.call(port, table, projection, limit, where)`.
                let scan = self.actor.getattr(py, "scan")?;
                let future_obj = scan.call_method1(
                    py,
                    "call",
                    (
                        dest_port_id,
                        self.table_name.clone(),
                        self.projection.clone(),
                        self.limit,
                        self.where_clause.clone(),
                    ),
                )?;

                // NOTE(review): reaches into the returned future's private
                // `_status.coro` attributes to get at the underlying
                // PyPythonTask — this couples us to monarch's Python future
                // internals; confirm those attribute names are stable.
                let status = future_obj.getattr(py, "_status")?;
                let python_task_obj = status.getattr(py, "coro")?;
                let mut python_task: PyRefMut<'_, PyPythonTask> = python_task_obj.extract(py)?;
                // Take ownership of the Rust-side future so the draining task
                // can await it without holding the GIL.
                let completion_future = python_task.take_task()?;

                Ok(completion_future)
            },
        )
        .map_err(|e| datafusion::error::DataFusionError::External(e.into()))?;

        Ok(create_draining_stream(schema, receiver, completion_future))
    }
}
382
/// Python-exposed SQL engine over distributed telemetry tables.
///
/// Wraps a DataFusion `SessionContext` whose tables are
/// `DistributedTableProvider`s backed by a telemetry actor.
#[pyclass(
    name = "QueryEngine",
    module = "monarch._rust_bindings.monarch_distributed_telemetry.query_engine"
)]
pub struct QueryEngine {
    // DataFusion session with one registered provider per remote table.
    session: SessionContext,
}
390
#[pymethods]
impl QueryEngine {
    /// Build a query engine bound to `actor`.
    ///
    /// Resolves the calling actor's instance via `monarch.actor.context()`
    /// (so this must run inside an actor context) and registers one table
    /// per name the actor reports.
    #[new]
    fn new(py: Python<'_>, actor: Py<PyAny>) -> PyResult<Self> {
        let actor_module = py.import("monarch.actor")?;
        let ctx = actor_module.call_method0("context")?;
        let actor_instance_obj = ctx.getattr("actor_instance")?;
        let py_instance: PyRef<'_, PyInstance> = actor_instance_obj.extract()?;
        let instance: Instance<PythonActor> = py_instance.clone_for_py();

        let session = Self::setup_tables(py, &actor, instance)?;
        Ok(Self { session })
    }

    fn __repr__(&self) -> String {
        "<QueryEngine>".into()
    }

    /// Run `sql` to completion and return the full result set as Arrow IPC
    /// stream bytes (schema header plus every collected batch).
    ///
    /// Blocks the calling thread on the shared tokio runtime; the GIL is
    /// released (`py.detach`) while the query executes. An empty result set
    /// is encoded with an empty schema.
    fn query<'py>(&self, py: Python<'py>, sql: String) -> PyResult<Bound<'py, PyBytes>> {
        let session_ctx = self.session.clone();

        let results: Vec<RecordBatch> = py
            .detach(|| {
                get_tokio_runtime().block_on(async {
                    let df = session_ctx.sql(&sql).await?;
                    df.collect().await
                })
            })
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Serialize every batch into a single Arrow IPC stream buffer.
        let schema = results
            .first()
            .map(|b| b.schema())
            .unwrap_or_else(|| Arc::new(datafusion::arrow::datatypes::Schema::empty()));
        let mut buf = Vec::new();
        let mut writer = StreamWriter::try_new(&mut buf, &schema)
            .map_err(|e| PyException::new_err(e.to_string()))?;
        for batch in &results {
            writer
                .write(batch)
                .map_err(|e| PyException::new_err(e.to_string()))?;
        }
        writer
            .finish()
            .map_err(|e| PyException::new_err(e.to_string()))?;

        Ok(PyBytes::new(py, &buf))
    }
}
448
impl QueryEngine {
    /// Ask the actor for its table names and per-table schemas, then register
    /// a `DistributedTableProvider` for each in a new session context.
    ///
    /// Each actor endpoint is driven synchronously via
    /// `.call(...).get().item()`; schemas arrive as Arrow IPC bytes.
    /// NOTE(review): `.item()` appears to take a single value from the
    /// endpoint's result — confirm all shards agree on the schema.
    fn setup_tables(
        py: Python<'_>,
        actor: &Py<PyAny>,
        instance: Instance<PythonActor>,
    ) -> PyResult<SessionContext> {
        let tables: Vec<String> = actor
            .getattr(py, "table_names")?
            .call_method0(py, "call")?
            .call_method0(py, "get")?
            .call_method0(py, "item")?
            .extract(py)?;

        // information_schema enables SHOW TABLES / introspection queries.
        let config = SessionConfig::new().with_information_schema(true);
        let ctx = SessionContext::new_with_config(config);

        for table_name in &tables {
            // Fetch this table's schema as Arrow IPC bytes.
            let schema_bytes: Vec<u8> = actor
                .getattr(py, "schema_for")?
                .call_method1(py, "call", (table_name,))?
                .call_method0(py, "get")?
                .call_method0(py, "item")?
                .extract(py)?;

            let schema = deserialize_schema(&schema_bytes).map_err(|e| {
                PyException::new_err(format!("Failed to deserialize schema: {}", e))
            })?;

            let provider = DistributedTableProvider {
                table_name: table_name.clone(),
                schema,
                actor: actor.clone_ref(py),
                instance: instance.clone_for_py(),
            };

            ctx.register_table(table_name, Arc::new(provider))
                .map_err(|e| PyException::new_err(e.to_string()))?;
        }

        Ok(ctx)
    }
}
494
495pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
496 module.add_class::<QueryEngine>()?;
497 Ok(())
498}