1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Meta;
39use syn::MetaNameValue;
40use syn::Token;
41use syn::Type;
42use syn::bracketed;
43use syn::parse::Parse;
44use syn::parse::ParseStream;
45use syn::parse_macro_input;
46use syn::punctuated::Punctuated;
47use syn::spanned::Spanned;
48
/// Compile error reported by `Message::new` when the field marked `#[reply]` is
/// not a supported reply port: it must be a path type whose last segment is
/// `OncePortRef`, `PortRef`, `OncePortHandle`, or `PortHandle`, and whose first
/// generic argument is a type.
///
/// NOTE(review): the help lines omit the `#[reply]` marker, which
/// `parse_field_flag` requires before a field is treated as the reply port.
const REPLY_VARIANT_ERROR: &str = indoc! {r#"
`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position

= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
"#};

/// Compile error reported by `Message::new` when more than one field of a
/// message is marked `#[reply]`.
const REPLY_USAGE_ERROR: &str = indoc! {r#"
`call` message expects at most one `reply` argument

= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
"#};
62
/// Marker recorded for each field of a message while it is parsed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldFlag {
    /// An ordinary field that becomes an argument of the generated methods.
    None,
    /// The field annotated with a bare `#[reply]` attribute (the reply port).
    Reply,
}
67
/// The parsed shape of one message: either an enum variant or a whole struct.
// NOTE(review): presumably allowed because `generics` is only read by the
// (also `dead_code`-allowed) `Variant::generics` accessor — confirm.
#[allow(dead_code)]
enum Variant {
    /// A variant or struct with named fields, e.g. `Foo { a: A, b: B }`.
    Named {
        /// Ident of the type the macro is derived on.
        enum_name: Ident,
        /// Ident of the enum variant; for structs, the struct's own ident.
        name: Ident,
        /// Declared field names, in declaration order.
        field_names: Vec<Ident>,
        /// Field types, parallel to `field_names`.
        field_types: Vec<Type>,
        /// Per-field `#[reply]` markers, parallel to `field_types`.
        field_flags: Vec<FieldFlag>,
        /// `true` when derived on a struct rather than an enum variant.
        is_struct: bool,
        /// Generics of the type the macro is derived on.
        generics: syn::Generics,
    },
    /// A variant or struct with tuple fields, e.g. `Foo(A, B)`; a unit struct is
    /// also represented here with no fields.
    Anon {
        /// Ident of the type the macro is derived on.
        enum_name: Ident,
        /// Ident of the enum variant; for structs, the struct's own ident.
        name: Ident,
        /// Field types, in declaration order.
        field_types: Vec<Type>,
        /// Per-field `#[reply]` markers, parallel to `field_types`.
        field_flags: Vec<FieldFlag>,
        /// `true` when derived on a struct rather than an enum variant.
        is_struct: bool,
        /// Generics of the type the macro is derived on.
        generics: syn::Generics,
    },
}
91
92impl Variant {
93 fn len(&self) -> usize {
95 self.field_types().len()
96 }
97
98 fn is_struct(&self) -> bool {
100 match self {
101 Variant::Named { is_struct, .. } => *is_struct,
102 Variant::Anon { is_struct, .. } => *is_struct,
103 }
104 }
105
106 fn enum_name(&self) -> &Ident {
108 match self {
109 Variant::Named { enum_name, .. } => enum_name,
110 Variant::Anon { enum_name, .. } => enum_name,
111 }
112 }
113
114 fn name(&self) -> &Ident {
116 match self {
117 Variant::Named { name, .. } => name,
118 Variant::Anon { name, .. } => name,
119 }
120 }
121
122 #[allow(dead_code)]
124 fn generics(&self) -> &syn::Generics {
125 match self {
126 Variant::Named { generics, .. } => generics,
127 Variant::Anon { generics, .. } => generics,
128 }
129 }
130
131 fn snake_name(&self) -> Ident {
133 Ident::new(
134 &self.name().to_string().to_case(Case::Snake),
135 self.name().span(),
136 )
137 }
138
139 fn qualified_name(&self) -> proc_macro2::TokenStream {
141 let enum_name = self.enum_name();
142 let name = self.name();
143
144 if self.is_struct() {
145 quote! { #enum_name }
146 } else {
147 quote! { #enum_name::#name }
148 }
149 }
150
151 fn field_names(&self) -> Vec<Ident> {
154 match self {
155 Variant::Named { field_names, .. } => field_names.clone(),
156 Variant::Anon { field_types, .. } => (0usize..field_types.len())
157 .map(|idx| format_ident!("arg{}", idx))
158 .collect(),
159 }
160 }
161
162 fn field_types(&self) -> &Vec<Type> {
164 match self {
165 Variant::Named { field_types, .. } => field_types,
166 Variant::Anon { field_types, .. } => field_types,
167 }
168 }
169
170 fn field_flags(&self) -> &Vec<FieldFlag> {
172 match self {
173 Variant::Named { field_flags, .. } => field_flags,
174 Variant::Anon { field_flags, .. } => field_flags,
175 }
176 }
177
178 fn constructor(&self) -> proc_macro2::TokenStream {
180 let qualified_name = self.qualified_name();
181 let field_names = self.field_names();
182 match self {
183 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
184 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
185 }
186 }
187}
188
189struct ReplyPort {
190 is_handle: bool,
191 is_once: bool,
192}
193
194impl ReplyPort {
195 fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
196 ReplyPort {
197 is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
198 is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
199 }
200 }
201
202 fn open_op(&self) -> proc_macro2::TokenStream {
203 if self.is_once {
204 quote! { hyperactor::mailbox::open_once_port }
205 } else {
206 quote! { hyperactor::mailbox::open_port }
207 }
208 }
209
210 fn rx_modifier(&self) -> proc_macro2::TokenStream {
211 if self.is_once {
212 quote! {}
213 } else {
214 quote! { mut }
215 }
216 }
217}
218
/// A parsed message, classified by whether it expects a reply.
// `Call` carries extra parsed data (`ReplyPort`, `Type`), which makes it the
// larger variant; the lint is allowed rather than boxing it.
#[allow(clippy::large_enum_variant)]
enum Message {
    /// A message with exactly one `#[reply]` port field.
    Call {
        /// The parsed message shape, including the reply field.
        variant: Variant,
        /// Flavour of the reply port (handle/ref, once/multi).
        reply_port: ReplyPort,
        /// First generic argument of the reply port type, i.e. the reply value type.
        return_type: Type,
        /// Uppercase level ident from a variant-level `#[log_level(...)]`, if any.
        log_level: Option<Ident>,
    },
    /// A message with no `#[reply]` field.
    OneWay {
        /// The parsed message shape.
        variant: Variant,
        /// Uppercase level ident from a variant-level `#[log_level(...)]`, if any.
        log_level: Option<Ident>,
    },
}
240
241impl Message {
242 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
243 match &variant
244 .field_flags()
245 .iter()
246 .zip(variant.field_types())
247 .filter_map(|(flag, ty)| match flag {
248 FieldFlag::Reply => Some(ty),
249 FieldFlag::None => None,
250 })
251 .collect::<Vec<&Type>>()[..]
252 {
253 [] => Ok(Self::OneWay { variant, log_level }),
254 [reply_port_ty] => {
255 let syn::Type::Path(type_path) = reply_port_ty else {
256 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
257 };
258 let Some(last_segment) = type_path.path.segments.last() else {
259 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
260 };
261 if last_segment.ident != "OncePortRef"
262 && last_segment.ident != "OncePortHandle"
263 && last_segment.ident != "PortRef"
264 && last_segment.ident != "PortHandle"
265 {
266 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
267 }
268 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
269 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
270 };
271 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
272 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
273 };
274 let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
275 let return_type = return_ty.clone();
276 Ok(Self::Call {
277 variant,
278 reply_port,
279 return_type,
280 log_level,
281 })
282 }
283 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
284 }
285 }
286
287 fn args(&self) -> Vec<(Ident, Type)> {
289 match self {
290 Message::Call { variant, .. } => variant
291 .field_names()
292 .into_iter()
293 .zip(variant.field_types().clone())
294 .take(variant.len() - 1)
295 .collect(),
296 Message::OneWay { variant, .. } => variant
297 .field_names()
298 .into_iter()
299 .zip(variant.field_types().clone())
300 .collect(),
301 }
302 }
303
304 fn variant(&self) -> &Variant {
305 match self {
306 Message::Call { variant, .. } => variant,
307 Message::OneWay { variant, .. } => variant,
308 }
309 }
310
311 fn reply_port_position(&self) -> Option<usize> {
312 self.variant()
313 .field_flags()
314 .iter()
315 .position(|flag| matches!(flag, FieldFlag::Reply))
316 }
317
318 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
320 match self {
321 Message::Call { variant, .. } => {
322 let pos = self.reply_port_position()?;
323 Some((
324 variant.field_names()[pos].clone(),
325 variant.field_types()[pos].clone(),
326 ))
327 }
328 Message::OneWay { .. } => None,
329 }
330 }
331}
332
333fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
334 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
335 Some(attr) => {
336 let Ok(meta) = attr.meta.require_list() else {
337 return Err(syn::Error::new(
338 Span::call_site(),
339 indoc! {"
340 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
341
342 = help use `#[log_level(info)]` or `#[log_level(error)]`
343 "},
344 ));
345 };
346 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
347 if parsed.len() != 1 {
348 return Err(syn::Error::new(
349 Span::call_site(),
350 indoc! {"
351 `log_level` attribute must specify exactly one level
352
353 = help use `#[log_level(warn)]` or `#[log_level(info)]`
354 "},
355 ));
356 };
357 Some(parsed.first().unwrap().to_string())
358 }
359 None => None,
360 };
361
362 if level.is_none() {
363 return Ok(None);
364 }
365 let level = level.unwrap();
366
367 match level.as_str() {
368 "error" | "warn" | "info" | "debug" | "trace" => {}
369 _ => {
370 return Err(syn::Error::new(
371 Span::call_site(),
372 indoc! {"
373 `log_level` attribute must be one of 'error, warn, info, debug, trace'
374
375 = help use `#[log_level(warn)]` or `#[log_level(info)]`
376 "},
377 ));
378 }
379 }
380
381 Ok(Some(Ident::new(
382 level.to_ascii_uppercase().as_str(),
383 Span::call_site(),
384 )))
385}
386
387fn parse_field_flag(field: &Field) -> FieldFlag {
388 for attr in field.attrs.iter() {
389 match &attr.meta {
390 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
391 _ => {}
392 }
393 }
394 FieldFlag::None
395}
396
397fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
399 match &input.data {
400 Data::Enum(data_enum) => {
401 let mut messages = Vec::new();
402
403 for variant in &data_enum.variants {
404 let name = variant.ident.clone();
405 let attrs = &variant.attrs;
406
407 let message_variant = match &variant.fields {
408 syn::Fields::Unnamed(fields_) => Variant::Anon {
409 enum_name: input.ident.clone(),
410 name,
411 field_types: fields_
412 .unnamed
413 .iter()
414 .map(|field| field.ty.clone())
415 .collect(),
416 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
417 is_struct: false,
418 generics: input.generics.clone(),
419 },
420 syn::Fields::Named(fields_) => Variant::Named {
421 enum_name: input.ident.clone(),
422 name,
423 field_names: fields_
424 .named
425 .iter()
426 .map(|field| field.ident.clone().unwrap())
427 .collect(),
428 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
429 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
430 is_struct: false,
431 generics: input.generics.clone(),
432 },
433 _ => {
434 return Err(syn::Error::new_spanned(
435 variant,
436 indoc! {r#"
437 `Handler` currently only supports named or tuple struct variants
438
439 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
440 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
441 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
442 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
443 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
444 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
445 "#},
446 ));
447 }
448 };
449 let log_level = parse_log_level(attrs)?;
450
451 messages.push(Message::new(
452 variant.fields.span(),
453 message_variant,
454 log_level,
455 )?);
456 }
457
458 Ok(messages)
459 }
460 Data::Struct(data_struct) => {
461 let struct_name = input.ident.clone();
462 let attrs = &input.attrs;
463
464 let message_variant = match &data_struct.fields {
465 syn::Fields::Unnamed(fields_) => Variant::Anon {
466 enum_name: struct_name.clone(),
467 name: struct_name,
468 field_types: fields_
469 .unnamed
470 .iter()
471 .map(|field| field.ty.clone())
472 .collect(),
473 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
474 is_struct: true,
475 generics: input.generics.clone(),
476 },
477 syn::Fields::Named(fields_) => Variant::Named {
478 enum_name: struct_name.clone(),
479 name: struct_name,
480 field_names: fields_
481 .named
482 .iter()
483 .map(|field| field.ident.clone().unwrap())
484 .collect(),
485 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
486 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
487 is_struct: true,
488 generics: input.generics.clone(),
489 },
490 syn::Fields::Unit => Variant::Anon {
491 enum_name: struct_name.clone(),
492 name: struct_name,
493 field_types: Vec::new(),
494 field_flags: Vec::new(),
495 is_struct: true,
496 generics: input.generics.clone(),
497 },
498 };
499
500 let log_level = parse_log_level(attrs)?;
501 let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
502
503 Ok(vec![message])
504 }
505 _ => Err(syn::Error::new_spanned(
506 input,
507 "handlers can only be derived for enums and structs",
508 )),
509 }
510}
511
512#[proc_macro_derive(Handler, attributes(reply))]
671pub fn derive_handler(input: TokenStream) -> TokenStream {
672 let input = parse_macro_input!(input as DeriveInput);
673 let name: Ident = input.ident.clone();
674 let (_, ty_generics, _) = input.generics.split_for_impl();
675
676 let messages = match parse_messages(input.clone()) {
677 Ok(messages) => messages,
678 Err(err) => return TokenStream::from(err.to_compile_error()),
679 };
680
681 let mut handler_trait_methods = Vec::new();
683
684 let mut match_arms = Vec::new();
686
687 let mut client_trait_methods = Vec::new();
689
690 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
691
692 for message in &messages {
693 match message {
694 Message::Call {
695 variant,
696 reply_port,
697 return_type,
698 log_level,
699 } => {
700 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
701 let variant_name_snake = variant.snake_name();
702 let variant_name_snake_deprecated =
703 format_ident!("{}_deprecated", variant_name_snake);
704 let enum_name = variant.enum_name();
705 let _variant_qualified_name = variant.qualified_name();
706 let log_level = match (&global_log_level, log_level) {
707 (_, Some(local)) => local.clone(),
708 (Some(global), None) => global.clone(),
709 _ => Ident::new("DEBUG", Span::call_site()),
710 };
711 let _log_level = if reply_port.is_handle {
712 quote! {
713 tracing::Level::#log_level
714 }
715 } else {
716 quote! {
717 tracing::Level::TRACE
718 }
719 };
720 let log_message = quote! {
721 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
722 "rpc" => "call",
723 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
724 "message_type" => stringify!(#enum_name),
725 "variant" => stringify!(#variant_name_snake),
726 ));
727 };
728
729 handler_trait_methods.push(quote! {
730 #[doc = "The generated handler method for this enum variant."]
731 async fn #variant_name_snake(
732 &mut self,
733 cx: &hyperactor::Context<Self>,
734 #(#arg_names: #arg_types),*)
735 -> Result<#return_type, hyperactor::anyhow::Error>;
736 });
737
738 client_trait_methods.push(quote! {
739 #[doc = "The generated client method for this enum variant."]
740 async fn #variant_name_snake(
741 &self,
742 cx: &impl hyperactor::context::Actor,
743 #(#arg_names: #arg_types),*)
744 -> Result<#return_type, hyperactor::anyhow::Error>;
745
746 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
747 async fn #variant_name_snake_deprecated(
748 &self,
749 cx: &impl hyperactor::context::Actor,
750 #(#arg_names: #arg_types),*)
751 -> Result<#return_type, hyperactor::anyhow::Error>;
752 });
753
754 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
755 let constructor = variant.constructor();
756 let result_ident = Ident::new("result", Span::mixed_site());
757 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
758 if reply_port.is_handle {
759 match_arms.push(quote! {
760 #constructor => {
761 #log_message
762 #construct_result_future
765 #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
766 }
767 });
768 } else {
769 match_arms.push(quote! {
770 #constructor => {
771 #log_message
772 #construct_result_future
775 #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
776 }
777 });
778 }
779 }
780 Message::OneWay { variant, log_level } => {
781 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
782 let variant_name_snake = variant.snake_name();
783 let variant_name_snake_deprecated =
784 format_ident!("{}_deprecated", variant_name_snake);
785 let enum_name = variant.enum_name();
786 let log_level = match (&global_log_level, log_level) {
787 (_, Some(local)) => local.clone(),
788 (Some(global), None) => global.clone(),
789 _ => Ident::new("TRACE", Span::call_site()),
790 };
791 let _log_level = quote! {
792 tracing::Level::#log_level
793 };
794 let log_message = quote! {
795 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
796 "rpc" => "call",
797 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
798 "message_type" => stringify!(#enum_name),
799 "variant" => stringify!(#variant_name_snake),
800 ));
801 };
802
803 handler_trait_methods.push(quote! {
804 #[doc = "The generated handler method for this enum variant."]
805 async fn #variant_name_snake(
806 &mut self,
807 cx: &hyperactor::Context<Self>,
808 #(#arg_names: #arg_types),*)
809 -> Result<(), hyperactor::anyhow::Error>;
810 });
811
812 client_trait_methods.push(quote! {
813 #[doc = "The generated client method for this enum variant."]
814 async fn #variant_name_snake(
815 &self,
816 cx: &impl hyperactor::context::Actor,
817 #(#arg_names: #arg_types),*)
818 -> Result<(), hyperactor::anyhow::Error>;
819
820 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
821 async fn #variant_name_snake_deprecated(
822 &self,
823 cx: &impl hyperactor::context::Actor,
824 #(#arg_names: #arg_types),*)
825 -> Result<(), hyperactor::anyhow::Error>;
826 });
827
828 let constructor = variant.constructor();
829
830 match_arms.push(quote! {
831 #constructor => {
832 #log_message
833 self.#variant_name_snake(cx, #(#arg_names),*).await
834 },
835 });
836 }
837 }
838 }
839
840 let handler_trait_name = format_ident!("{}Handler", name);
841 let client_trait_name = format_ident!("{}Client", name);
842
843 let mut handler_generics = input.generics.clone();
847 for param in handler_generics.type_params_mut() {
848 param.bounds.push(syn::parse_quote!(serde::Serialize));
849 param
850 .bounds
851 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
852 param.bounds.push(syn::parse_quote!(Send));
853 param.bounds.push(syn::parse_quote!(Sync));
854 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
855 param.bounds.push(syn::parse_quote!(hyperactor::Named));
856 }
857 let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
858 let (client_impl_generics, _, _) = input.generics.split_for_impl();
859
860 let expanded = quote! {
861 #[doc = "The custom handler trait for this message type."]
862 #[hyperactor::async_trait::async_trait]
863 pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync {
864 #(#handler_trait_methods)*
865
866 #[doc = "Handle the next message."]
867 async fn handle(
868 &mut self,
869 cx: &hyperactor::Context<Self>,
870 message: #name #ty_generics,
871 ) -> hyperactor::anyhow::Result<()> {
872 match message {
874 #(#match_arms)*
875 }
876 }
877 }
878
879 #[doc = "The custom client trait for this message type."]
880 #[hyperactor::async_trait::async_trait]
881 pub trait #client_trait_name #client_impl_generics: Send + Sync {
882 #(#client_trait_methods)*
883 }
884 };
885
886 TokenStream::from(expanded)
887}
888
889#[proc_macro_derive(HandleClient, attributes(log_level))]
892pub fn derive_handle_client(input: TokenStream) -> TokenStream {
893 derive_client(input, true)
894}
895
896#[proc_macro_derive(RefClient, attributes(log_level))]
899pub fn derive_ref_client(input: TokenStream) -> TokenStream {
900 derive_client(input, false)
901}
902
903fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
904 let input = parse_macro_input!(input as DeriveInput);
905 let name = input.ident.clone();
906
907 let messages = match parse_messages(input.clone()) {
908 Ok(messages) => messages,
909 Err(err) => return TokenStream::from(err.to_compile_error()),
910 };
911
912 let mut impl_methods = Vec::new();
914
915 let send_message = if is_handle {
916 quote! { self.send(message)? }
917 } else {
918 quote! { self.send(cx, message)? }
919 };
920 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
921
922 for message in &messages {
923 match message {
924 Message::Call {
925 variant,
926 reply_port,
927 return_type,
928 log_level,
929 } => {
930 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
931 let variant_name_snake = variant.snake_name();
932 let variant_name_snake_deprecated =
933 format_ident!("{}_deprecated", variant_name_snake);
934 let enum_name = variant.enum_name();
935
936 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
937 let constructor = variant.constructor();
938 let log_level = match (&global_log_level, log_level) {
939 (_, Some(local)) => local.clone(),
940 (Some(global), None) => global.clone(),
941 _ => Ident::new("DEBUG", Span::call_site()),
942 };
943 let log_level = if is_handle {
944 quote! {
945 tracing::Level::#log_level
946 }
947 } else {
948 quote! {
949 tracing::Level::TRACE
950 }
951 };
952 let log_message = quote! {
953 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
954 "rpc" => "call",
955 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
956 "message_type" => stringify!(#enum_name),
957 "variant" => stringify!(#variant_name_snake),
958 ));
959
960 };
961 let open_port = reply_port.open_op();
962 let rx_mod = reply_port.rx_modifier();
963 if reply_port.is_handle {
964 impl_methods.push(quote! {
965 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
966 async fn #variant_name_snake(
967 &self,
968 cx: &impl hyperactor::context::Actor,
969 #(#arg_names: #arg_types),*)
970 -> Result<#return_type, hyperactor::anyhow::Error> {
971 let (#reply_port_arg, #rx_mod reply_receiver) =
972 #open_port::<#return_type>(cx);
973 let message = #constructor;
974 #log_message;
975 #send_message;
976 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
977 }
978
979 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
980 async fn #variant_name_snake_deprecated(
981 &self,
982 cx: &impl hyperactor::context::Actor,
983 #(#arg_names: #arg_types),*)
984 -> Result<#return_type, hyperactor::anyhow::Error> {
985 let (#reply_port_arg, #rx_mod reply_receiver) =
986 #open_port::<#return_type>(cx);
987 let message = #constructor;
988 #log_message;
989 #send_message;
990 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
991 }
992 });
993 } else {
994 impl_methods.push(quote! {
995 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
996 async fn #variant_name_snake(
997 &self,
998 cx: &impl hyperactor::context::Actor,
999 #(#arg_names: #arg_types),*)
1000 -> Result<#return_type, hyperactor::anyhow::Error> {
1001 let (#reply_port_arg, #rx_mod reply_receiver) =
1002 #open_port::<#return_type>(cx);
1003 let #reply_port_arg = #reply_port_arg.bind();
1004 let message = #constructor;
1005 #log_message;
1006 #send_message;
1007 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1008 }
1009
1010 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
1011 async fn #variant_name_snake_deprecated(
1012 &self,
1013 cx: &impl hyperactor::context::Actor,
1014 #(#arg_names: #arg_types),*)
1015 -> Result<#return_type, hyperactor::anyhow::Error> {
1016 let (#reply_port_arg, #rx_mod reply_receiver) =
1017 #open_port::<#return_type>(cx);
1018 let #reply_port_arg = #reply_port_arg.bind();
1019 let message = #constructor;
1020 #log_message;
1021 #send_message;
1022 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1023 }
1024 });
1025 }
1026 }
1027 Message::OneWay { variant, log_level } => {
1028 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1029 let variant_name_snake = variant.snake_name();
1030 let variant_name_snake_deprecated =
1031 format_ident!("{}_deprecated", variant_name_snake);
1032 let enum_name = variant.enum_name();
1033 let constructor = variant.constructor();
1034 let log_level = match (&global_log_level, log_level) {
1035 (_, Some(local)) => local.clone(),
1036 (Some(global), None) => global.clone(),
1037 _ => Ident::new("DEBUG", Span::call_site()),
1038 };
1039 let _log_level = if is_handle {
1040 quote! {
1041 tracing::Level::TRACE
1042 }
1043 } else {
1044 quote! {
1045 tracing::Level::#log_level
1046 }
1047 };
1048 let log_message = quote! {
1049 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1050 "rpc" => "oneway",
1051 "actor_id" => self.actor_id().to_string(),
1052 "message_type" => stringify!(#enum_name),
1053 "variant" => stringify!(#variant_name_snake),
1054 ));
1055 };
1056 impl_methods.push(quote! {
1057 async fn #variant_name_snake(
1058 &self,
1059 cx: &impl hyperactor::context::Actor,
1060 #(#arg_names: #arg_types),*)
1061 -> Result<(), hyperactor::anyhow::Error> {
1062 let message = #constructor;
1063 #log_message;
1064 #send_message;
1065 Ok(())
1066 }
1067
1068 async fn #variant_name_snake_deprecated(
1069 &self,
1070 cx: &impl hyperactor::context::Actor,
1071 #(#arg_names: #arg_types),*)
1072 -> Result<(), hyperactor::anyhow::Error> {
1073 let message = #constructor;
1074 #log_message;
1075 #send_message;
1076 Ok(())
1077 }
1078 });
1079 }
1080 }
1081 }
1082
1083 let trait_name = format_ident!("{}Client", name);
1084
1085 let (_, ty_generics, _) = input.generics.split_for_impl();
1086
1087 let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1089 let mut trait_generics = input.generics.clone();
1090 trait_generics.params.insert(
1091 0,
1092 syn::GenericParam::Type(syn::TypeParam {
1093 ident: actor_ident.clone(),
1094 attrs: vec![],
1095 colon_token: None,
1096 bounds: Punctuated::new(),
1097 eq_token: None,
1098 default: None,
1099 }),
1100 );
1101
1102 for param in trait_generics.type_params_mut() {
1103 if param.ident == actor_ident {
1104 continue;
1105 }
1106 param.bounds.push(syn::parse_quote!(serde::Serialize));
1107 param
1108 .bounds
1109 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1110 param.bounds.push(syn::parse_quote!(Send));
1111 param.bounds.push(syn::parse_quote!(Sync));
1112 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1113 param.bounds.push(syn::parse_quote!(hyperactor::Named));
1114 }
1115
1116 let (impl_generics, _, _) = trait_generics.split_for_impl();
1117
1118 let expanded = if is_handle {
1119 quote! {
1120 #[hyperactor::async_trait::async_trait]
1121 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1122 where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1123 #(#impl_methods)*
1124 }
1125 }
1126 } else {
1127 quote! {
1128 #[hyperactor::async_trait::async_trait]
1129 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1130 where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1131 #(#impl_methods)*
1132 }
1133 }
1134 };
1135
1136 TokenStream::from(expanded)
1137}
1138
/// Compile error reported by `#[forward(...)]` when it is not given exactly one
/// argument (the forwarded message type).
const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
`forward` expects the message type that is being forwarded

= help: use `#[forward(MessageType)]`
"#};
1144
1145#[proc_macro_attribute]
1147pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
1148 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1149 if attr_args.len() != 1 {
1150 return TokenStream::from(
1151 syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
1152 );
1153 }
1154
1155 let message_type = attr_args.first().unwrap();
1156 let input = parse_macro_input!(item as ItemImpl);
1157
1158 let self_type = match *input.self_ty {
1159 syn::Type::Path(ref type_path) => {
1160 let segment = type_path.path.segments.last().unwrap();
1161 segment.clone() }
1163 _ => {
1164 return TokenStream::from(
1165 syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
1166 .to_compile_error(),
1167 );
1168 }
1169 };
1170
1171 let trait_name = match input.trait_ {
1172 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1173 None => {
1174 return TokenStream::from(
1175 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1176 .to_compile_error(),
1177 );
1178 }
1179 };
1180
1181 let expanded = quote! {
1182 #input
1183
1184 #[hyperactor::async_trait::async_trait]
1185 impl hyperactor::Handler<#message_type> for #self_type {
1186 async fn handle(
1187 &mut self,
1188 cx: &hyperactor::Context<Self>,
1189 message: #message_type,
1190 ) -> hyperactor::anyhow::Result<()> {
1191 <Self as #trait_name>::handle(self, cx, message).await
1192 }
1193 }
1194 };
1195
1196 TokenStream::from(expanded)
1197}
1198
1199#[proc_macro_attribute]
1212pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1213 let args =
1214 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1215 let input = parse_macro_input!(input as ItemFn);
1216 let output = quote! {
1217 #[hyperactor::tracing::instrument(err, skip_all, #args)]
1218 #input
1219 };
1220
1221 TokenStream::from(output)
1222}
1223
1224#[proc_macro_attribute]
1235pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1236 let args =
1237 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1238 let input = parse_macro_input!(input as ItemFn);
1239
1240 let output = quote! {
1241 #[hyperactor::tracing::instrument(skip_all, #args)]
1242 #input
1243 };
1244
1245 TokenStream::from(output)
1246}
1247
1248#[proc_macro_derive(Named, attributes(named))]
1261pub fn derive_named(input: TokenStream) -> TokenStream {
1262 let input = parse_macro_input!(input as DeriveInput);
1264 let struct_name = &input.ident;
1265
1266 let mut typename = quote! {
1267 concat!(std::module_path!(), "::", stringify!(#struct_name))
1268 };
1269
1270 let type_params: Vec<_> = input.generics.type_params().collect();
1271 let has_generics = !type_params.is_empty();
1272 let mut register = !has_generics;
1274
1275 for attr in &input.attrs {
1276 if attr.path().is_ident("named") {
1277 if let Ok(meta) = attr.parse_args_with(
1278 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
1279 ) {
1280 for item in meta {
1281 if let Meta::NameValue(MetaNameValue {
1282 path,
1283 value: Expr::Lit(expr_lit),
1284 ..
1285 }) = item
1286 {
1287 if path.is_ident("name") {
1288 if let Lit::Str(name) = expr_lit.lit {
1289 typename = quote! { #name };
1290 } else {
1291 return TokenStream::from(
1292 syn::Error::new_spanned(path, "invalid name")
1293 .to_compile_error(),
1294 );
1295 }
1296 } else if path.is_ident("register") {
1297 if let Lit::Bool(flag) = expr_lit.lit {
1298 register = flag.value;
1299 } else {
1300 return TokenStream::from(
1301 syn::Error::new_spanned(path, "invalid registration flag")
1302 .to_compile_error(),
1303 );
1304 }
1305 } else {
1306 return TokenStream::from(
1307 syn::Error::new_spanned(
1308 path,
1309 "unsupported attribute (only `name` or `register` is supported)",
1310 )
1311 .to_compile_error(),
1312 );
1313 }
1314 }
1315 }
1316 }
1317 }
1318 }
1319
1320 let mut generics_with_bounds = input.generics.clone();
1322 if has_generics {
1323 for param in generics_with_bounds.type_params_mut() {
1324 param
1325 .bounds
1326 .push(syn::parse_quote!(hyperactor::data::Named));
1327 }
1328 }
1329 let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1330
1331 let (typename_impl, typehash_impl) = if has_generics {
1333 let placeholders = vec!["{}"; type_params.len()].join(", ");
1335 let placeholders_format_string = format!("<{}>", placeholders);
1336 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#struct_name), #placeholders_format_string) };
1337
1338 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1339 (
1340 quote! {
1341 hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1342 },
1343 quote! {
1344 hyperactor::cityhasher::hash(Self::typename())
1345 },
1346 )
1347 } else {
1348 (
1349 typename,
1350 quote! {
1351 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1352 hyperactor::cityhasher::hash(<#struct_name as hyperactor::data::Named>::typename())
1353 });
1354 *TYPEHASH
1355 },
1356 )
1357 };
1358
1359 let arm_impl = match &input.data {
1361 Data::Enum(DataEnum { variants, .. }) => {
1362 let match_arms = variants.iter().map(|v| {
1363 let variant_name = &v.ident;
1364 let variant_str = variant_name.to_string();
1365 match &v.fields {
1366 Fields::Unit => quote! { Self::#variant_name => Some(#variant_str) },
1367 Fields::Unnamed(_) => quote! { Self::#variant_name(..) => Some(#variant_str) },
1368 Fields::Named(_) => quote! { Self::#variant_name { .. } => Some(#variant_str) },
1369 }
1370 });
1371 quote! {
1372 fn arm(&self) -> Option<&'static str> {
1373 match self {
1374 #(#match_arms,)*
1375 }
1376 }
1377 }
1378 }
1379 _ => quote! {},
1380 };
1381
1382 let registration = if register {
1389 quote! {
1390 hyperactor::register_type!(#struct_name);
1391 }
1392 } else {
1393 quote! {
1394 }
1396 };
1397
1398 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
1399 let expanded = quote! {
1402 impl #impl_generics_with_bounds hyperactor::data::Named for #struct_name #ty_generics #where_clause {
1403 fn typename() -> &'static str { #typename_impl }
1404 fn typehash() -> u64 { #typehash_impl }
1405 #arm_impl
1406 }
1407
1408 #registration
1409 };
1410
1411 TokenStream::from(expanded)
1412}
1413
1414struct HandlerSpec {
1415 ty: Type,
1416 cast: bool,
1417}
1418
1419impl Parse for HandlerSpec {
1420 fn parse(input: ParseStream) -> syn::Result<Self> {
1421 let ty: Type = input.parse()?;
1422
1423 if input.peek(syn::token::Brace) {
1424 let content;
1425 syn::braced!(content in input);
1426 let key: Ident = content.parse()?;
1427 content.parse::<Token![=]>()?;
1428 let expr: Expr = content.parse()?;
1429
1430 let cast = if key == "cast" {
1431 if let Expr::Lit(ExprLit {
1432 lit: Lit::Bool(b), ..
1433 }) = expr
1434 {
1435 b.value
1436 } else {
1437 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1438 }
1439 } else {
1440 return Err(syn::Error::new_spanned(
1441 key,
1442 "unsupported field (expected `cast`)",
1443 ));
1444 };
1445
1446 Ok(HandlerSpec { ty, cast })
1447 } else if input.is_empty() || input.peek(Token![,]) {
1448 Ok(HandlerSpec { ty, cast: false })
1449 } else {
1450 let unexpected: proc_macro2::TokenTree = input.parse()?;
1452 Err(syn::Error::new_spanned(
1453 unexpected,
1454 "unexpected token after type — expected `{ ... }` or nothing",
1455 ))
1456 }
1457 }
1458}
1459
1460impl HandlerSpec {
1461 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1462 let mut tys = Vec::new();
1463 for HandlerSpec { ty, cast } in handlers {
1464 if cast {
1465 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1466 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1467 tys.push(wrapped_ty);
1468 }
1469 tys.push(ty);
1470 }
1471 tys
1472 }
1473}
1474
1475struct ExportAttr {
1477 spawn: bool,
1478 handlers: Vec<HandlerSpec>,
1479}
1480
1481impl Parse for ExportAttr {
1482 fn parse(input: ParseStream) -> syn::Result<Self> {
1483 let mut spawn = false;
1484 let mut handlers: Vec<HandlerSpec> = vec![];
1485
1486 while !input.is_empty() {
1487 let key: Ident = input.parse()?;
1488 input.parse::<Token![=]>()?;
1489
1490 if key == "spawn" {
1491 let expr: Expr = input.parse()?;
1492 if let Expr::Lit(ExprLit {
1493 lit: Lit::Bool(b), ..
1494 }) = expr
1495 {
1496 spawn = b.value;
1497 } else {
1498 return Err(syn::Error::new_spanned(
1499 expr,
1500 "expected boolean for `spawn`",
1501 ));
1502 }
1503 } else if key == "handlers" {
1504 let content;
1505 bracketed!(content in input);
1506 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1507 handlers = raw_handlers.into_iter().collect();
1508 } else {
1509 return Err(syn::Error::new_spanned(
1510 key,
1511 "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1512 ));
1513 }
1514
1515 let _ = input.parse::<Token![,]>();
1517 }
1518
1519 Ok(ExportAttr { spawn, handlers })
1520 }
1521}
1522
1523#[proc_macro_attribute]
1550pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1551 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1552 let data_type_name = &input.ident;
1553
1554 let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1555 let tys = HandlerSpec::add_indexed(handlers);
1556
1557 let mut handles = Vec::new();
1558 let mut bindings = Vec::new();
1559 let mut type_registrations = Vec::new();
1560
1561 for ty in &tys {
1562 handles.push(quote! {
1563 impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
1564 });
1565 bindings.push(quote! {
1566 ports.bind::<#ty>();
1567 });
1568 type_registrations.push(quote! {
1569 hyperactor::register_type!(#ty);
1570 });
1571 }
1572
1573 let mut expanded = quote! {
1574 #input
1575
1576 impl hyperactor::actor::Referable for #data_type_name {}
1577
1578 #(#handles)*
1579
1580 #(#type_registrations)*
1581
1582 impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
1584
1585 impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
1586 fn bind(ports: &hyperactor::proc::Ports<Self>) {
1587 #(#bindings)*
1588 }
1589 }
1590
1591 impl hyperactor::data::Named for #data_type_name {
1593 fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
1594 }
1595 };
1596
1597 if spawn {
1598 expanded.extend(quote! {
1599 hyperactor::remote!(#data_type_name);
1600 });
1601 }
1602
1603 TokenStream::from(expanded)
1604}
1605
1606struct AliasInput {
1608 alias: Ident,
1609 handlers: Vec<HandlerSpec>,
1610}
1611
1612impl syn::parse::Parse for AliasInput {
1613 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1614 let alias: Ident = input.parse()?;
1615 let _: Token![,] = input.parse()?;
1616 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1617 let handlers = raw_handlers.into_iter().collect();
1618 Ok(AliasInput { alias, handlers })
1619 }
1620}
1621
1622#[proc_macro]
1638pub fn alias(input: TokenStream) -> TokenStream {
1639 let AliasInput { alias, handlers } = parse_macro_input!(input as AliasInput);
1640 let tys = HandlerSpec::add_indexed(handlers);
1641
1642 let expanded = quote! {
1643 #[doc = "The generated alias struct."]
1644 #[derive(Debug, hyperactor::Named, serde::Serialize, serde::Deserialize)]
1645 pub struct #alias;
1646 impl hyperactor::actor::Referable for #alias {}
1647
1648 impl<A> hyperactor::actor::Binds<A> for #alias
1649 where
1650 A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)* {
1651 fn bind(ports: &hyperactor::proc::Ports<A>) {
1652 #(
1653 ports.bind::<#tys>();
1654 )*
1655 }
1656 }
1657
1658 #(
1659 impl hyperactor::actor::RemoteHandles<#tys> for #alias {}
1660 )*
1661 };
1662
1663 TokenStream::from(expanded)
1664}
1665
1666fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1667 let mut is_included = false;
1668 for attr in &field.attrs {
1669 if attr.path().is_ident("binding") {
1670 attr.parse_nested_meta(|meta| {
1672 if meta.path.is_ident("include") {
1673 is_included = true;
1674 Ok(())
1675 } else {
1676 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1677 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1678 }
1679 })?
1680 }
1681 }
1682 Ok(is_included)
1683}
1684
1685enum FieldAccessor {
1690 Named(Ident),
1691 Unnamed(Index),
1692}
1693
1694struct ParsedField {
1696 accessor: FieldAccessor,
1697 ty: Type,
1698 included: bool,
1699}
1700
1701impl From<&ParsedField> for (Ident, Type) {
1702 fn from(field: &ParsedField) -> Self {
1703 let field_ident = match &field.accessor {
1704 FieldAccessor::Named(ident) => ident.clone(),
1705 FieldAccessor::Unnamed(i) => {
1706 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1707 }
1708 };
1709 (field_ident, field.ty.clone())
1710 }
1711}
1712
1713fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1714 match fields {
1715 Fields::Named(named) => named
1716 .named
1717 .iter()
1718 .map(|f| {
1719 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1720 Ok(ParsedField {
1721 accessor,
1722 ty: f.ty.clone(),
1723 included: include_in_bind_unbind(f)?,
1724 })
1725 })
1726 .collect(),
1727 Fields::Unnamed(unnamed) => unnamed
1728 .unnamed
1729 .iter()
1730 .enumerate()
1731 .map(|(i, f)| {
1732 let accessor = FieldAccessor::Unnamed(Index::from(i));
1733 Ok(ParsedField {
1734 accessor,
1735 ty: f.ty.clone(),
1736 included: include_in_bind_unbind(f)?,
1737 })
1738 })
1739 .collect(),
1740 Fields::Unit => Ok(Vec::new()),
1741 }
1742}
1743
1744fn gen_struct_items<F>(
1745 fields: &Fields,
1746 make_item: F,
1747 is_mutable: bool,
1748) -> syn::Result<Vec<proc_macro2::TokenStream>>
1749where
1750 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1751{
1752 let borrow = if is_mutable {
1753 quote! { &mut }
1754 } else {
1755 quote! { & }
1756 };
1757 let items: Vec<_> = collect_all_fields(fields)?
1758 .into_iter()
1759 .filter(|f| f.included)
1760 .map(
1761 |ParsedField {
1762 accessor,
1763 ty,
1764 included,
1765 }| {
1766 assert!(included);
1767 let field_accessor = match accessor {
1768 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1769 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1770 };
1771 make_item(field_accessor, ty)
1772 },
1773 )
1774 .collect();
1775 Ok(items)
1776}
1777
1778fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1788 all_fields
1789 .iter()
1790 .map(
1791 |ParsedField {
1792 accessor,
1793 ty: _,
1794 included,
1795 }| {
1796 match accessor {
1797 FieldAccessor::Named(ident) => {
1798 if *included {
1799 quote! { #ident }
1800 } else {
1801 quote! { #ident: _ }
1802 }
1803 }
1804 FieldAccessor::Unnamed(i) => {
1805 if *included {
1806 let ident = Ident::new(
1807 &format!("f{}", i.index),
1808 proc_macro2::Span::call_site(),
1809 );
1810 quote! { #ident }
1811 } else {
1812 quote! { _ }
1813 }
1814 }
1815 }
1816 },
1817 )
1818 .collect()
1819}
1820
1821fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1828where
1829 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1830{
1831 data.variants
1832 .iter()
1833 .map(|variant| {
1834 let name = &variant.ident;
1835 let all_fields = collect_all_fields(&variant.fields)?;
1836 let field_accessors = gen_enum_field_accessors(&all_fields);
1837 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1838 let items = included_fields
1839 .iter()
1840 .map(|f| {
1841 let (accessor, ty) = <(Ident, Type)>::from(*f);
1842 make_item(quote! { #accessor }, ty)
1843 })
1844 .collect::<Vec<_>>();
1845
1846 Ok(match &variant.fields {
1847 Fields::Named(_) => {
1848 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1849 }
1850 Fields::Unnamed(_) => {
1851 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1852 }
1853 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1854 })
1855 })
1856 .collect()
1857}
1858
1859#[proc_macro_derive(Bind, attributes(binding))]
1941pub fn derive_bind(input: TokenStream) -> TokenStream {
1942 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1943 quote! {
1944 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1945 }
1946 }
1947
1948 let input = parse_macro_input!(input as DeriveInput);
1949 let name = &input.ident;
1950 let inner = match &input.data {
1951 Data::Struct(DataStruct { fields, .. }) => {
1952 match gen_struct_items(fields, make_item, true) {
1953 Ok(collects) => {
1954 quote! { #(#collects)* }
1955 }
1956 Err(e) => {
1957 return TokenStream::from(e.to_compile_error());
1958 }
1959 }
1960 }
1961 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1962 Ok(arms) => {
1963 quote! { match self { #(#arms),* } }
1964 }
1965 Err(e) => {
1966 return TokenStream::from(e.to_compile_error());
1967 }
1968 },
1969 _ => panic!("Bind can only be derived for structs and enums"),
1970 };
1971 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1972 let expand = quote! {
1973 #[automatically_derived]
1974 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1975 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1976 #inner
1977 Ok(())
1978 }
1979 }
1980 };
1981 TokenStream::from(expand)
1982}
1983
1984#[proc_macro_derive(Unbind, attributes(binding))]
1998pub fn derive_unbind(input: TokenStream) -> TokenStream {
1999 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2000 quote! {
2001 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
2002 }
2003 }
2004
2005 let input = parse_macro_input!(input as DeriveInput);
2006 let name = &input.ident;
2007 let inner = match &input.data {
2008 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
2009 {
2010 Ok(collects) => {
2011 quote! { #(#collects)* }
2012 }
2013 Err(e) => {
2014 return TokenStream::from(e.to_compile_error());
2015 }
2016 },
2017 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2018 Ok(arms) => {
2019 quote! { match self { #(#arms),* } }
2020 }
2021 Err(e) => {
2022 return TokenStream::from(e.to_compile_error());
2023 }
2024 },
2025 _ => panic!("Unbind can only be derived for structs and enums"),
2026 };
2027 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2028 let expand = quote! {
2029 #[automatically_derived]
2030 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
2031 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2032 #inner
2033 Ok(())
2034 }
2035 }
2036 };
2037 TokenStream::from(expand)
2038}
2039
2040#[proc_macro_derive(Actor, attributes(actor))]
2087pub fn derive_actor(input: TokenStream) -> TokenStream {
2088 let input = parse_macro_input!(input as DeriveInput);
2089 let name = &input.ident;
2090 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2091
2092 let is_passthrough = input.attrs.iter().any(|attr| {
2093 if attr.path().is_ident("actor") {
2094 if let Ok(meta) = attr.parse_args_with(
2095 syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
2096 ) {
2097 return meta.iter().any(|ident| ident == "passthrough");
2098 }
2099 }
2100 false
2101 });
2102
2103 let expanded = if is_passthrough {
2104 quote! {
2105 #[hyperactor::async_trait::async_trait]
2106 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2107 type Params = Self;
2108
2109 async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
2110 Ok(instance)
2111 }
2112 }
2113 }
2114 } else {
2115 quote! {
2116 #[hyperactor::async_trait::async_trait]
2117 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2118 type Params = ();
2119
2120 async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
2121 Ok(Default::default())
2122 }
2123 }
2124 }
2125 };
2126
2127 TokenStream::from(expanded)
2128}
2129
2130fn parse_observe_function(
2132 attr: TokenStream,
2133 item: TokenStream,
2134) -> syn::Result<(ItemFn, String, String)> {
2135 let input = syn::parse::<ItemFn>(item)?;
2136
2137 if input.sig.asyncness.is_none() {
2138 return Err(syn::Error::new(
2139 input.sig.span(),
2140 "observe macros can only be applied to async functions",
2141 ));
2142 }
2143
2144 let fn_name_str = input.sig.ident.to_string();
2145 let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
2146
2147 Ok((input, fn_name_str, module_name_str))
2148}
2149
2150fn create_telemetry_setup(
2152 module_name_str: &str,
2153 fn_name_str: &str,
2154 include_error: bool,
2155) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2156 let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2157 let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2158
2159 let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2160
2161 let error_ident = if include_error {
2162 Some(Ident::new(
2163 "error",
2164 Span::from(proc_macro::Span::def_site()),
2165 ))
2166 } else {
2167 None
2168 };
2169
2170 let error_declaration = if let Some(ref error_ident) = error_ident {
2171 quote! {
2172 hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2173 }
2174 } else {
2175 quote! {}
2176 };
2177
2178 let setup_code = quote! {
2179 use hyperactor_telemetry;
2180 hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2181 hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2182 #error_declaration
2183 };
2184
2185 (latency_ident, success_ident, error_ident, setup_code)
2186}
2187
2188#[proc_macro_attribute]
2208pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2209 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2210 Ok(parsed) => parsed,
2211 Err(err) => return err.to_compile_error().into(),
2212 };
2213
2214 let fn_name = &input.sig.ident;
2215 let vis = &input.vis;
2216 let args = &input.sig.inputs;
2217 let return_type = &input.sig.output;
2218 let body = &input.block;
2219 let attrs = &input.attrs;
2220 let generics = &input.sig.generics;
2221
2222 let (latency_ident, success_ident, error_ident, telemetry_setup) =
2223 create_telemetry_setup(&module_name_str, &fn_name_str, true);
2224 let error_ident = error_ident.unwrap();
2225
2226 let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2227
2228 let expanded = quote! {
2230 #(#attrs)*
2231 #vis async fn #fn_name #generics(#args) #return_type {
2232 #telemetry_setup
2233
2234 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2235 let _timer = #latency_ident.start(kv_pairs);
2236
2237 let #result_ident = async #body.await;
2238
2239 match &#result_ident {
2240 Ok(_) => {
2241 #success_ident.add(
2242 1,
2243 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2244 );
2245 }
2246 Err(_) => {
2247 #error_ident.add(
2248 1,
2249 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2250 );
2251 }
2252 }
2253
2254 #result_ident
2255 }
2256 };
2257
2258 expanded.into()
2259}
2260
2261#[proc_macro_attribute]
2280pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2281 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2282 Ok(parsed) => parsed,
2283 Err(err) => return err.to_compile_error().into(),
2284 };
2285
2286 let fn_name = &input.sig.ident;
2287 let vis = &input.vis;
2288 let args = &input.sig.inputs;
2289 let return_type = &input.sig.output;
2290 let body = &input.block;
2291 let attrs = &input.attrs;
2292 let generics = &input.sig.generics;
2293
2294 let (latency_ident, success_ident, _, telemetry_setup) =
2295 create_telemetry_setup(&module_name_str, &fn_name_str, false);
2296
2297 let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2298
2299 let expanded = quote! {
2301 #(#attrs)*
2302 #vis async fn #fn_name #generics(#args) #return_type {
2303 #telemetry_setup
2304
2305 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2306 let _timer = #latency_ident.start(kv_pairs);
2307
2308 let #return_ident = async #body.await;
2309
2310 #success_ident.add(
2311 1,
2312 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2313 );
2314 #return_ident
2315 }
2316 };
2317
2318 expanded.into()
2319}
2320
2321#[proc_macro_derive(AttrValue)]
2355pub fn derive_attr_value(input: TokenStream) -> TokenStream {
2356 let input = parse_macro_input!(input as DeriveInput);
2357 let name = &input.ident;
2358
2359 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2360
2361 TokenStream::from(quote! {
2362 impl #impl_generics hyperactor::attrs::AttrValue for #name #ty_generics #where_clause {
2363 fn display(&self) -> String {
2364 self.to_string()
2365 }
2366
2367 fn parse(value: &str) -> Result<Self, anyhow::Error> {
2368 value.parse().map_err(|e| anyhow::anyhow!("failed to parse {}: {}", stringify!(#name), e))
2369 }
2370 }
2371 })
2372}