monarch_record_batch_macros/
lib.rs1use proc_macro::TokenStream;
34use quote::format_ident;
35use quote::quote;
36use syn::DeriveInput;
37use syn::Field;
38use syn::Fields;
39use syn::Type;
40use syn::parse_macro_input;
41
42#[proc_macro_derive(RecordBatchRow)]
44pub fn derive_record_batch_row(input: TokenStream) -> TokenStream {
45 let input = parse_macro_input!(input as DeriveInput);
46 let name = &input.ident;
47 let buffer_name = format_ident!("{}Buffer", name);
48
49 let fields = match &input.data {
50 syn::Data::Struct(data) => match &data.fields {
51 Fields::Named(fields) => &fields.named,
52 _ => panic!("RecordBatchRow only supports named fields"),
53 },
54 _ => panic!("RecordBatchRow only supports structs"),
55 };
56
57 let field_info: Vec<FieldInfo> = fields.iter().map(FieldInfo::from_field).collect();
58
59 let buffer_fields = field_info.iter().map(|f| {
60 let name = &f.name;
61 let vec_ty = &f.vec_type;
62 quote! { #name: #vec_ty }
63 });
64
65 let insert_pushes = field_info.iter().map(|f| {
66 let name = &f.name;
67 quote! { self.#name.push(row.#name); }
68 });
69
70 let schema_fields = field_info.iter().map(|f| {
71 let field_name_str = f.name.to_string();
72 let nullable = f.nullable;
73 let data_type = &f.arrow_data_type;
74 quote! {
75 datafusion::arrow::datatypes::Field::new(#field_name_str, #data_type, #nullable)
76 }
77 });
78
79 let column_conversions = field_info.iter().map(|f| {
80 let name = &f.name;
81 let array_conversion = &f.array_conversion;
82 quote! {
83 std::sync::Arc::new(#array_conversion(std::mem::take(&mut self.#name)))
84 }
85 });
86
87 let first_field = &field_info[0].name;
88
89 let expanded = quote! {
90 #[derive(Default)]
91 pub struct #buffer_name {
92 #(#buffer_fields,)*
93 }
94
95 impl #buffer_name {
96 pub fn insert(&mut self, row: #name) {
97 #(#insert_pushes)*
98 }
99
100 pub fn schema() -> datafusion::arrow::datatypes::SchemaRef {
101 std::sync::Arc::new(datafusion::arrow::datatypes::Schema::new(vec![
102 #(#schema_fields,)*
103 ]))
104 }
105 }
106
107 impl monarch_record_batch::RecordBatchBuffer for #buffer_name {
108 fn len(&self) -> usize {
109 self.#first_field.len()
110 }
111
112 fn drain_to_record_batch(&mut self) -> Result<datafusion::arrow::record_batch::RecordBatch, datafusion::arrow::error::ArrowError> {
113 let schema = #buffer_name::schema();
114 let columns: Vec<datafusion::arrow::array::ArrayRef> = vec![
115 #(#column_conversions,)*
116 ];
117 datafusion::arrow::record_batch::RecordBatch::try_new(schema, columns)
118 }
119 }
120 };
121
122 TokenStream::from(expanded)
123}
124
125struct FieldInfo {
126 name: syn::Ident,
127 vec_type: proc_macro2::TokenStream,
128 nullable: bool,
129 arrow_data_type: proc_macro2::TokenStream,
130 array_conversion: proc_macro2::TokenStream,
131}
132
133impl FieldInfo {
134 fn from_field(field: &Field) -> Self {
135 let name = field.ident.clone().expect("field must have name");
136 let (inner_ty, nullable) = extract_option_inner(&field.ty);
137
138 let (vec_type, arrow_data_type, array_conversion) = if nullable {
139 let vec_ty = quote! { Vec<Option<#inner_ty>> };
140 let (data_type, array_conv) = get_arrow_type_and_conversion(inner_ty);
141 (vec_ty, data_type, array_conv)
142 } else {
143 let vec_ty = quote! { Vec<#inner_ty> };
144 let (data_type, array_conv) = get_arrow_type_and_conversion(inner_ty);
145 (vec_ty, data_type, array_conv)
146 };
147
148 FieldInfo {
149 name,
150 vec_type,
151 nullable,
152 arrow_data_type,
153 array_conversion,
154 }
155 }
156}
157
158fn extract_option_inner(ty: &Type) -> (&Type, bool) {
159 if let Type::Path(type_path) = ty {
160 if let Some(segment) = type_path.path.segments.last() {
161 if segment.ident == "Option" {
162 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
163 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
164 return (inner, true);
165 }
166 }
167 }
168 }
169 }
170 (ty, false)
171}
172
173fn get_arrow_type_and_conversion(
174 ty: &Type,
175) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
176 let type_str = quote!(#ty).to_string().replace(' ', "");
177
178 match type_str.as_str() {
179 "u64" => (
180 quote! { datafusion::arrow::datatypes::DataType::UInt64 },
181 quote! { datafusion::arrow::array::UInt64Array::from },
182 ),
183 "u32" => (
184 quote! { datafusion::arrow::datatypes::DataType::UInt32 },
185 quote! { datafusion::arrow::array::UInt32Array::from },
186 ),
187 "i64" => (
188 quote! { datafusion::arrow::datatypes::DataType::Int64 },
189 quote! { datafusion::arrow::array::Int64Array::from },
190 ),
191 "i32" => (
192 quote! { datafusion::arrow::datatypes::DataType::Int32 },
193 quote! { datafusion::arrow::array::Int32Array::from },
194 ),
195 "String" => (
196 quote! { datafusion::arrow::datatypes::DataType::Utf8 },
197 quote! { datafusion::arrow::array::StringArray::from },
198 ),
199 "bool" => (
200 quote! { datafusion::arrow::datatypes::DataType::Boolean },
201 quote! { datafusion::arrow::array::BooleanArray::from },
202 ),
203 "f64" => (
204 quote! { datafusion::arrow::datatypes::DataType::Float64 },
205 quote! { datafusion::arrow::array::Float64Array::from },
206 ),
207 "f32" => (
208 quote! { datafusion::arrow::datatypes::DataType::Float32 },
209 quote! { datafusion::arrow::array::Float32Array::from },
210 ),
211 _ => panic!("unsupported type: {}", type_str),
212 }
213}