monarch_record_batch_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Derive macro for generating Arrow RecordBatch buffers from row structs.
10//!
11//! # Example
12//!
13//! ```ignore
14//! #[derive(RecordBatchRow)]
15//! struct Span {
16//!     id: u64,
17//!     name: String,
18//!     timestamp: i64,
19//!     parent_id: Option<u64>,
20//! }
21//! ```
22//!
23//! This generates:
//! - `SpanBuffer` struct with a `Vec<T>` for each field; `Option<T>` fields are
//!   stored as `Vec<Option<T>>` and become nullable Arrow columns
25//! - `insert(&mut self, row: Span)` method
26//! - `schema() -> SchemaRef` method
27//! - `impl monarch_record_batch::RecordBatchBuffer` with `len()` and
28//!   `drain_to_record_batch()` methods
29//!
30//! The consumer crate must depend on `monarch_record_batch` (for the trait)
31//! and `datafusion` (for Arrow types used in the generated code).
32
33use proc_macro::TokenStream;
34use quote::format_ident;
35use quote::quote;
36use syn::DeriveInput;
37use syn::Field;
38use syn::Fields;
39use syn::Type;
40use syn::parse_macro_input;
41
42/// Derive macro for generating Arrow RecordBatch buffer types.
43#[proc_macro_derive(RecordBatchRow)]
44pub fn derive_record_batch_row(input: TokenStream) -> TokenStream {
45    let input = parse_macro_input!(input as DeriveInput);
46    let name = &input.ident;
47    let buffer_name = format_ident!("{}Buffer", name);
48
49    let fields = match &input.data {
50        syn::Data::Struct(data) => match &data.fields {
51            Fields::Named(fields) => &fields.named,
52            _ => panic!("RecordBatchRow only supports named fields"),
53        },
54        _ => panic!("RecordBatchRow only supports structs"),
55    };
56
57    let field_info: Vec<FieldInfo> = fields.iter().map(FieldInfo::from_field).collect();
58
59    let buffer_fields = field_info.iter().map(|f| {
60        let name = &f.name;
61        let vec_ty = &f.vec_type;
62        quote! { #name: #vec_ty }
63    });
64
65    let insert_pushes = field_info.iter().map(|f| {
66        let name = &f.name;
67        quote! { self.#name.push(row.#name); }
68    });
69
70    let schema_fields = field_info.iter().map(|f| {
71        let field_name_str = f.name.to_string();
72        let nullable = f.nullable;
73        let data_type = &f.arrow_data_type;
74        quote! {
75            datafusion::arrow::datatypes::Field::new(#field_name_str, #data_type, #nullable)
76        }
77    });
78
79    let column_conversions = field_info.iter().map(|f| {
80        let name = &f.name;
81        let array_conversion = &f.array_conversion;
82        quote! {
83            std::sync::Arc::new(#array_conversion(std::mem::take(&mut self.#name)))
84        }
85    });
86
87    let first_field = &field_info[0].name;
88
89    let expanded = quote! {
90        #[derive(Default)]
91        pub struct #buffer_name {
92            #(#buffer_fields,)*
93        }
94
95        impl #buffer_name {
96            pub fn insert(&mut self, row: #name) {
97                #(#insert_pushes)*
98            }
99
100            pub fn schema() -> datafusion::arrow::datatypes::SchemaRef {
101                std::sync::Arc::new(datafusion::arrow::datatypes::Schema::new(vec![
102                    #(#schema_fields,)*
103                ]))
104            }
105        }
106
107        impl monarch_record_batch::RecordBatchBuffer for #buffer_name {
108            fn len(&self) -> usize {
109                self.#first_field.len()
110            }
111
112            fn drain_to_record_batch(&mut self) -> Result<datafusion::arrow::record_batch::RecordBatch, datafusion::arrow::error::ArrowError> {
113                let schema = #buffer_name::schema();
114                let columns: Vec<datafusion::arrow::array::ArrayRef> = vec![
115                    #(#column_conversions,)*
116                ];
117                datafusion::arrow::record_batch::RecordBatch::try_new(schema, columns)
118            }
119        }
120    };
121
122    TokenStream::from(expanded)
123}
124
125struct FieldInfo {
126    name: syn::Ident,
127    vec_type: proc_macro2::TokenStream,
128    nullable: bool,
129    arrow_data_type: proc_macro2::TokenStream,
130    array_conversion: proc_macro2::TokenStream,
131}
132
133impl FieldInfo {
134    fn from_field(field: &Field) -> Self {
135        let name = field.ident.clone().expect("field must have name");
136        let (inner_ty, nullable) = extract_option_inner(&field.ty);
137
138        let (vec_type, arrow_data_type, array_conversion) = if nullable {
139            let vec_ty = quote! { Vec<Option<#inner_ty>> };
140            let (data_type, array_conv) = get_arrow_type_and_conversion(inner_ty);
141            (vec_ty, data_type, array_conv)
142        } else {
143            let vec_ty = quote! { Vec<#inner_ty> };
144            let (data_type, array_conv) = get_arrow_type_and_conversion(inner_ty);
145            (vec_ty, data_type, array_conv)
146        };
147
148        FieldInfo {
149            name,
150            vec_type,
151            nullable,
152            arrow_data_type,
153            array_conversion,
154        }
155    }
156}
157
158fn extract_option_inner(ty: &Type) -> (&Type, bool) {
159    if let Type::Path(type_path) = ty {
160        if let Some(segment) = type_path.path.segments.last() {
161            if segment.ident == "Option" {
162                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
163                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
164                        return (inner, true);
165                    }
166                }
167            }
168        }
169    }
170    (ty, false)
171}
172
173fn get_arrow_type_and_conversion(
174    ty: &Type,
175) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
176    let type_str = quote!(#ty).to_string().replace(' ', "");
177
178    match type_str.as_str() {
179        "u64" => (
180            quote! { datafusion::arrow::datatypes::DataType::UInt64 },
181            quote! { datafusion::arrow::array::UInt64Array::from },
182        ),
183        "u32" => (
184            quote! { datafusion::arrow::datatypes::DataType::UInt32 },
185            quote! { datafusion::arrow::array::UInt32Array::from },
186        ),
187        "i64" => (
188            quote! { datafusion::arrow::datatypes::DataType::Int64 },
189            quote! { datafusion::arrow::array::Int64Array::from },
190        ),
191        "i32" => (
192            quote! { datafusion::arrow::datatypes::DataType::Int32 },
193            quote! { datafusion::arrow::array::Int32Array::from },
194        ),
195        "String" => (
196            quote! { datafusion::arrow::datatypes::DataType::Utf8 },
197            quote! { datafusion::arrow::array::StringArray::from },
198        ),
199        "bool" => (
200            quote! { datafusion::arrow::datatypes::DataType::Boolean },
201            quote! { datafusion::arrow::array::BooleanArray::from },
202        ),
203        "f64" => (
204            quote! { datafusion::arrow::datatypes::DataType::Float64 },
205            quote! { datafusion::arrow::array::Float64Array::from },
206        ),
207        "f32" => (
208            quote! { datafusion::arrow::datatypes::DataType::Float32 },
209            quote! { datafusion::arrow::array::Float32Array::from },
210        ),
211        _ => panic!("unsupported type: {}", type_str),
212    }
213}