1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::bracketed;
41use syn::parse::Parse;
42use syn::parse::ParseStream;
43use syn::parse_macro_input;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46
47const REPLY_VARIANT_ERROR: &str = indoc! {r#"
48`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position
49
50= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
51= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
52"#};
53
/// Error reported when a message has more than one `#[reply]` field.
const REPLY_USAGE_ERROR: &str = indoc! {r#"
`call` message expects at most one `reply` argument

= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
"#};
60
/// Marker derived from a single field's attributes.
enum FieldFlag {
    /// An ordinary message argument.
    None,
    /// The field carries `#[reply]` and holds the reply port.
    Reply,
}
65
66#[allow(dead_code)]
68enum Variant {
69 Named {
71 enum_name: Ident,
72 name: Ident,
73 field_names: Vec<Ident>,
74 field_types: Vec<Type>,
75 field_flags: Vec<FieldFlag>,
76 is_struct: bool,
77 generics: syn::Generics,
78 },
79 Anon {
81 enum_name: Ident,
82 name: Ident,
83 field_types: Vec<Type>,
84 field_flags: Vec<FieldFlag>,
85 is_struct: bool,
86 generics: syn::Generics,
87 },
88}
89
90impl Variant {
91 fn len(&self) -> usize {
93 self.field_types().len()
94 }
95
96 fn is_struct(&self) -> bool {
98 match self {
99 Variant::Named { is_struct, .. } => *is_struct,
100 Variant::Anon { is_struct, .. } => *is_struct,
101 }
102 }
103
104 fn enum_name(&self) -> &Ident {
106 match self {
107 Variant::Named { enum_name, .. } => enum_name,
108 Variant::Anon { enum_name, .. } => enum_name,
109 }
110 }
111
112 fn name(&self) -> &Ident {
114 match self {
115 Variant::Named { name, .. } => name,
116 Variant::Anon { name, .. } => name,
117 }
118 }
119
120 #[allow(dead_code)]
122 fn generics(&self) -> &syn::Generics {
123 match self {
124 Variant::Named { generics, .. } => generics,
125 Variant::Anon { generics, .. } => generics,
126 }
127 }
128
129 fn snake_name(&self) -> Ident {
131 Ident::new(
132 &self.name().to_string().to_case(Case::Snake),
133 self.name().span(),
134 )
135 }
136
137 fn qualified_name(&self) -> proc_macro2::TokenStream {
139 let enum_name = self.enum_name();
140 let name = self.name();
141
142 if self.is_struct() {
143 quote! { #enum_name }
144 } else {
145 quote! { #enum_name::#name }
146 }
147 }
148
149 fn field_names(&self) -> Vec<Ident> {
152 match self {
153 Variant::Named { field_names, .. } => field_names.clone(),
154 Variant::Anon { field_types, .. } => (0usize..field_types.len())
155 .map(|idx| format_ident!("arg{}", idx))
156 .collect(),
157 }
158 }
159
160 fn field_types(&self) -> &Vec<Type> {
162 match self {
163 Variant::Named { field_types, .. } => field_types,
164 Variant::Anon { field_types, .. } => field_types,
165 }
166 }
167
168 fn field_flags(&self) -> &Vec<FieldFlag> {
170 match self {
171 Variant::Named { field_flags, .. } => field_flags,
172 Variant::Anon { field_flags, .. } => field_flags,
173 }
174 }
175
176 fn constructor(&self) -> proc_macro2::TokenStream {
178 let qualified_name = self.qualified_name();
179 let field_names = self.field_names();
180 match self {
181 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
182 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
183 }
184 }
185}
186
/// Classification of a `#[reply]` port type.
struct ReplyPort {
    /// `true` for `PortHandle` / `OncePortHandle`; `false` for the `*PortRef` types.
    is_handle: bool,
    /// `true` for the single-use `OncePortHandle` / `OncePortRef` types.
    is_once: bool,
}
191
192impl ReplyPort {
193 fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
194 ReplyPort {
195 is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
196 is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
197 }
198 }
199
200 fn open_op(&self) -> proc_macro2::TokenStream {
201 if self.is_once {
202 quote! { hyperactor::mailbox::open_once_port }
203 } else {
204 quote! { hyperactor::mailbox::open_port }
205 }
206 }
207
208 fn rx_modifier(&self) -> proc_macro2::TokenStream {
209 if self.is_once {
210 quote! {}
211 } else {
212 quote! { mut }
213 }
214 }
215}
216
217#[allow(clippy::large_enum_variant)]
220enum Message {
221 Call {
224 variant: Variant,
225 reply_port: ReplyPort,
227 return_type: Type,
229 log_level: Option<Ident>,
231 },
232 OneWay {
233 variant: Variant,
234 log_level: Option<Ident>,
236 },
237}
238
239impl Message {
240 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
241 match &variant
242 .field_flags()
243 .iter()
244 .zip(variant.field_types())
245 .filter_map(|(flag, ty)| match flag {
246 FieldFlag::Reply => Some(ty),
247 FieldFlag::None => None,
248 })
249 .collect::<Vec<&Type>>()[..]
250 {
251 [] => Ok(Self::OneWay { variant, log_level }),
252 [reply_port_ty] => {
253 let syn::Type::Path(type_path) = reply_port_ty else {
254 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
255 };
256 let Some(last_segment) = type_path.path.segments.last() else {
257 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
258 };
259 if last_segment.ident != "OncePortRef"
260 && last_segment.ident != "OncePortHandle"
261 && last_segment.ident != "PortRef"
262 && last_segment.ident != "PortHandle"
263 {
264 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
265 }
266 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
267 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
268 };
269 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
270 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
271 };
272 let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
273 let return_type = return_ty.clone();
274 Ok(Self::Call {
275 variant,
276 reply_port,
277 return_type,
278 log_level,
279 })
280 }
281 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
282 }
283 }
284
285 fn args(&self) -> Vec<(Ident, Type)> {
287 match self {
288 Message::Call { variant, .. } => variant
289 .field_names()
290 .into_iter()
291 .zip(variant.field_types().clone())
292 .take(variant.len() - 1)
293 .collect(),
294 Message::OneWay { variant, .. } => variant
295 .field_names()
296 .into_iter()
297 .zip(variant.field_types().clone())
298 .collect(),
299 }
300 }
301
302 fn variant(&self) -> &Variant {
303 match self {
304 Message::Call { variant, .. } => variant,
305 Message::OneWay { variant, .. } => variant,
306 }
307 }
308
309 fn reply_port_position(&self) -> Option<usize> {
310 self.variant()
311 .field_flags()
312 .iter()
313 .position(|flag| matches!(flag, FieldFlag::Reply))
314 }
315
316 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
318 match self {
319 Message::Call { variant, .. } => {
320 let pos = self.reply_port_position()?;
321 Some((
322 variant.field_names()[pos].clone(),
323 variant.field_types()[pos].clone(),
324 ))
325 }
326 Message::OneWay { .. } => None,
327 }
328 }
329}
330
331fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
332 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
333 Some(attr) => {
334 let Ok(meta) = attr.meta.require_list() else {
335 return Err(syn::Error::new(
336 Span::call_site(),
337 indoc! {"
338 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
339
340 = help use `#[log_level(info)]` or `#[log_level(error)]`
341 "},
342 ));
343 };
344 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
345 if parsed.len() != 1 {
346 return Err(syn::Error::new(
347 Span::call_site(),
348 indoc! {"
349 `log_level` attribute must specify exactly one level
350
351 = help use `#[log_level(warn)]` or `#[log_level(info)]`
352 "},
353 ));
354 };
355 Some(parsed.first().unwrap().to_string())
356 }
357 None => None,
358 };
359
360 if level.is_none() {
361 return Ok(None);
362 }
363 let level = level.unwrap();
364
365 match level.as_str() {
366 "error" | "warn" | "info" | "debug" | "trace" => {}
367 _ => {
368 return Err(syn::Error::new(
369 Span::call_site(),
370 indoc! {"
371 `log_level` attribute must be one of 'error, warn, info, debug, trace'
372
373 = help use `#[log_level(warn)]` or `#[log_level(info)]`
374 "},
375 ));
376 }
377 }
378
379 Ok(Some(Ident::new(
380 level.to_ascii_uppercase().as_str(),
381 Span::call_site(),
382 )))
383}
384
385fn parse_field_flag(field: &Field) -> FieldFlag {
386 for attr in field.attrs.iter() {
387 match &attr.meta {
388 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
389 _ => {}
390 }
391 }
392 FieldFlag::None
393}
394
395fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
397 match &input.data {
398 Data::Enum(data_enum) => {
399 let mut messages = Vec::new();
400
401 for variant in &data_enum.variants {
402 let name = variant.ident.clone();
403 let attrs = &variant.attrs;
404
405 let message_variant = match &variant.fields {
406 syn::Fields::Unnamed(fields_) => Variant::Anon {
407 enum_name: input.ident.clone(),
408 name,
409 field_types: fields_
410 .unnamed
411 .iter()
412 .map(|field| field.ty.clone())
413 .collect(),
414 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
415 is_struct: false,
416 generics: input.generics.clone(),
417 },
418 syn::Fields::Named(fields_) => Variant::Named {
419 enum_name: input.ident.clone(),
420 name,
421 field_names: fields_
422 .named
423 .iter()
424 .map(|field| field.ident.clone().unwrap())
425 .collect(),
426 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
427 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
428 is_struct: false,
429 generics: input.generics.clone(),
430 },
431 _ => {
432 return Err(syn::Error::new_spanned(
433 variant,
434 indoc! {r#"
435 `Handler` currently only supports named or tuple struct variants
436
437 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
438 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
439 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
440 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
441 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
442 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
443 "#},
444 ));
445 }
446 };
447 let log_level = parse_log_level(attrs)?;
448
449 messages.push(Message::new(
450 variant.fields.span(),
451 message_variant,
452 log_level,
453 )?);
454 }
455
456 Ok(messages)
457 }
458 Data::Struct(data_struct) => {
459 let struct_name = input.ident.clone();
460 let attrs = &input.attrs;
461
462 let message_variant = match &data_struct.fields {
463 syn::Fields::Unnamed(fields_) => Variant::Anon {
464 enum_name: struct_name.clone(),
465 name: struct_name,
466 field_types: fields_
467 .unnamed
468 .iter()
469 .map(|field| field.ty.clone())
470 .collect(),
471 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
472 is_struct: true,
473 generics: input.generics.clone(),
474 },
475 syn::Fields::Named(fields_) => Variant::Named {
476 enum_name: struct_name.clone(),
477 name: struct_name,
478 field_names: fields_
479 .named
480 .iter()
481 .map(|field| field.ident.clone().unwrap())
482 .collect(),
483 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
484 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
485 is_struct: true,
486 generics: input.generics.clone(),
487 },
488 syn::Fields::Unit => Variant::Anon {
489 enum_name: struct_name.clone(),
490 name: struct_name,
491 field_types: Vec::new(),
492 field_flags: Vec::new(),
493 is_struct: true,
494 generics: input.generics.clone(),
495 },
496 };
497
498 let log_level = parse_log_level(attrs)?;
499 let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
500
501 Ok(vec![message])
502 }
503 _ => Err(syn::Error::new_spanned(
504 input,
505 "handlers can only be derived for enums and structs",
506 )),
507 }
508}
509
510#[proc_macro_derive(Handler, attributes(reply))]
669pub fn derive_handler(input: TokenStream) -> TokenStream {
670 let input = parse_macro_input!(input as DeriveInput);
671 let name: Ident = input.ident.clone();
672 let (_, ty_generics, _) = input.generics.split_for_impl();
673
674 let messages = match parse_messages(input.clone()) {
675 Ok(messages) => messages,
676 Err(err) => return TokenStream::from(err.to_compile_error()),
677 };
678
679 let mut handler_trait_methods = Vec::new();
681
682 let mut match_arms = Vec::new();
684
685 let mut client_trait_methods = Vec::new();
687
688 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
689
690 for message in &messages {
691 match message {
692 Message::Call {
693 variant,
694 reply_port,
695 return_type,
696 log_level,
697 } => {
698 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
699 let variant_name_snake = variant.snake_name();
700 let variant_name_snake_deprecated =
701 format_ident!("{}_deprecated", variant_name_snake);
702 let enum_name = variant.enum_name();
703 let _variant_qualified_name = variant.qualified_name();
704 let log_level = match (&global_log_level, log_level) {
705 (_, Some(local)) => local.clone(),
706 (Some(global), None) => global.clone(),
707 _ => Ident::new("DEBUG", Span::call_site()),
708 };
709 let _log_level = if reply_port.is_handle {
710 quote! {
711 tracing::Level::#log_level
712 }
713 } else {
714 quote! {
715 tracing::Level::TRACE
716 }
717 };
718 let log_message = quote! {
719 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
720 "rpc" => "call",
721 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
722 "message_type" => stringify!(#enum_name),
723 "variant" => stringify!(#variant_name_snake),
724 ));
725 };
726
727 handler_trait_methods.push(quote! {
728 #[doc = "The generated handler method for this enum variant."]
729 async fn #variant_name_snake(
730 &mut self,
731 cx: &hyperactor::Context<Self>,
732 #(#arg_names: #arg_types),*)
733 -> Result<#return_type, hyperactor::anyhow::Error>;
734 });
735
736 client_trait_methods.push(quote! {
737 #[doc = "The generated client method for this enum variant."]
738 async fn #variant_name_snake(
739 &self,
740 cx: &impl hyperactor::context::Actor,
741 #(#arg_names: #arg_types),*)
742 -> Result<#return_type, hyperactor::anyhow::Error>;
743
744 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
745 async fn #variant_name_snake_deprecated(
746 &self,
747 cx: &impl hyperactor::context::Actor,
748 #(#arg_names: #arg_types),*)
749 -> Result<#return_type, hyperactor::anyhow::Error>;
750 });
751
752 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
753 let constructor = variant.constructor();
754 let result_ident = Ident::new("result", Span::mixed_site());
755 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
756 if reply_port.is_handle {
757 match_arms.push(quote! {
758 #constructor => {
759 #log_message
760 #construct_result_future
763 #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
764 }
765 });
766 } else {
767 match_arms.push(quote! {
768 #constructor => {
769 #log_message
770 #construct_result_future
773 #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
774 }
775 });
776 }
777 }
778 Message::OneWay { variant, log_level } => {
779 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
780 let variant_name_snake = variant.snake_name();
781 let variant_name_snake_deprecated =
782 format_ident!("{}_deprecated", variant_name_snake);
783 let enum_name = variant.enum_name();
784 let log_level = match (&global_log_level, log_level) {
785 (_, Some(local)) => local.clone(),
786 (Some(global), None) => global.clone(),
787 _ => Ident::new("TRACE", Span::call_site()),
788 };
789 let _log_level = quote! {
790 tracing::Level::#log_level
791 };
792 let log_message = quote! {
793 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
794 "rpc" => "call",
795 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
796 "message_type" => stringify!(#enum_name),
797 "variant" => stringify!(#variant_name_snake),
798 ));
799 };
800
801 handler_trait_methods.push(quote! {
802 #[doc = "The generated handler method for this enum variant."]
803 async fn #variant_name_snake(
804 &mut self,
805 cx: &hyperactor::Context<Self>,
806 #(#arg_names: #arg_types),*)
807 -> Result<(), hyperactor::anyhow::Error>;
808 });
809
810 client_trait_methods.push(quote! {
811 #[doc = "The generated client method for this enum variant."]
812 async fn #variant_name_snake(
813 &self,
814 cx: &impl hyperactor::context::Actor,
815 #(#arg_names: #arg_types),*)
816 -> Result<(), hyperactor::anyhow::Error>;
817
818 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
819 async fn #variant_name_snake_deprecated(
820 &self,
821 cx: &impl hyperactor::context::Actor,
822 #(#arg_names: #arg_types),*)
823 -> Result<(), hyperactor::anyhow::Error>;
824 });
825
826 let constructor = variant.constructor();
827
828 match_arms.push(quote! {
829 #constructor => {
830 #log_message
831 self.#variant_name_snake(cx, #(#arg_names),*).await
832 },
833 });
834 }
835 }
836 }
837
838 let handler_trait_name = format_ident!("{}Handler", name);
839 let client_trait_name = format_ident!("{}Client", name);
840
841 let mut handler_generics = input.generics.clone();
845 for param in handler_generics.type_params_mut() {
846 param.bounds.push(syn::parse_quote!(serde::Serialize));
847 param
848 .bounds
849 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
850 param.bounds.push(syn::parse_quote!(Send));
851 param.bounds.push(syn::parse_quote!(Sync));
852 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
853 param.bounds.push(syn::parse_quote!(typeuri::Named));
854 }
855 let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
856 let (client_impl_generics, _, _) = input.generics.split_for_impl();
857
858 let expanded = quote! {
859 #[doc = "The custom handler trait for this message type."]
860 #[hyperactor::async_trait::async_trait]
861 pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync {
862 #(#handler_trait_methods)*
863
864 #[doc = "Handle the next message."]
865 async fn handle(
866 &mut self,
867 cx: &hyperactor::Context<Self>,
868 message: #name #ty_generics,
869 ) -> hyperactor::anyhow::Result<()> {
870 match message {
872 #(#match_arms)*
873 }
874 }
875 }
876
877 #[doc = "The custom client trait for this message type."]
878 #[hyperactor::async_trait::async_trait]
879 pub trait #client_trait_name #client_impl_generics: Send + Sync {
880 #(#client_trait_methods)*
881 }
882 };
883
884 TokenStream::from(expanded)
885}
886
887#[proc_macro_derive(HandleClient, attributes(log_level))]
890pub fn derive_handle_client(input: TokenStream) -> TokenStream {
891 derive_client(input, true)
892}
893
894#[proc_macro_derive(RefClient, attributes(log_level))]
897pub fn derive_ref_client(input: TokenStream) -> TokenStream {
898 derive_client(input, false)
899}
900
901fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
902 let input = parse_macro_input!(input as DeriveInput);
903 let name = input.ident.clone();
904
905 let messages = match parse_messages(input.clone()) {
906 Ok(messages) => messages,
907 Err(err) => return TokenStream::from(err.to_compile_error()),
908 };
909
910 let mut impl_methods = Vec::new();
912
913 let send_message = if is_handle {
914 quote! { self.send(message)? }
915 } else {
916 quote! { self.send(cx, message)? }
917 };
918 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
919
920 for message in &messages {
921 match message {
922 Message::Call {
923 variant,
924 reply_port,
925 return_type,
926 log_level,
927 } => {
928 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
929 let variant_name_snake = variant.snake_name();
930 let variant_name_snake_deprecated =
931 format_ident!("{}_deprecated", variant_name_snake);
932 let enum_name = variant.enum_name();
933
934 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
935 let constructor = variant.constructor();
936 let log_level = match (&global_log_level, log_level) {
937 (_, Some(local)) => local.clone(),
938 (Some(global), None) => global.clone(),
939 _ => Ident::new("DEBUG", Span::call_site()),
940 };
941 let log_level = if is_handle {
942 quote! {
943 tracing::Level::#log_level
944 }
945 } else {
946 quote! {
947 tracing::Level::TRACE
948 }
949 };
950 let log_message = quote! {
951 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
952 "rpc" => "call",
953 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
954 "message_type" => stringify!(#enum_name),
955 "variant" => stringify!(#variant_name_snake),
956 ));
957
958 };
959 let open_port = reply_port.open_op();
960 let rx_mod = reply_port.rx_modifier();
961 if reply_port.is_handle {
962 impl_methods.push(quote! {
963 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
964 async fn #variant_name_snake(
965 &self,
966 cx: &impl hyperactor::context::Actor,
967 #(#arg_names: #arg_types),*)
968 -> Result<#return_type, hyperactor::anyhow::Error> {
969 let (#reply_port_arg, #rx_mod reply_receiver) =
970 #open_port::<#return_type>(cx);
971 let message = #constructor;
972 #log_message;
973 #send_message;
974 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
975 }
976
977 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
978 async fn #variant_name_snake_deprecated(
979 &self,
980 cx: &impl hyperactor::context::Actor,
981 #(#arg_names: #arg_types),*)
982 -> Result<#return_type, hyperactor::anyhow::Error> {
983 let (#reply_port_arg, #rx_mod reply_receiver) =
984 #open_port::<#return_type>(cx);
985 let message = #constructor;
986 #log_message;
987 #send_message;
988 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
989 }
990 });
991 } else {
992 impl_methods.push(quote! {
993 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
994 async fn #variant_name_snake(
995 &self,
996 cx: &impl hyperactor::context::Actor,
997 #(#arg_names: #arg_types),*)
998 -> Result<#return_type, hyperactor::anyhow::Error> {
999 let (#reply_port_arg, #rx_mod reply_receiver) =
1000 #open_port::<#return_type>(cx);
1001 let #reply_port_arg = #reply_port_arg.bind();
1002 let message = #constructor;
1003 #log_message;
1004 #send_message;
1005 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1006 }
1007
1008 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
1009 async fn #variant_name_snake_deprecated(
1010 &self,
1011 cx: &impl hyperactor::context::Actor,
1012 #(#arg_names: #arg_types),*)
1013 -> Result<#return_type, hyperactor::anyhow::Error> {
1014 let (#reply_port_arg, #rx_mod reply_receiver) =
1015 #open_port::<#return_type>(cx);
1016 let #reply_port_arg = #reply_port_arg.bind();
1017 let message = #constructor;
1018 #log_message;
1019 #send_message;
1020 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1021 }
1022 });
1023 }
1024 }
1025 Message::OneWay { variant, log_level } => {
1026 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1027 let variant_name_snake = variant.snake_name();
1028 let variant_name_snake_deprecated =
1029 format_ident!("{}_deprecated", variant_name_snake);
1030 let enum_name = variant.enum_name();
1031 let constructor = variant.constructor();
1032 let log_level = match (&global_log_level, log_level) {
1033 (_, Some(local)) => local.clone(),
1034 (Some(global), None) => global.clone(),
1035 _ => Ident::new("DEBUG", Span::call_site()),
1036 };
1037 let _log_level = if is_handle {
1038 quote! {
1039 tracing::Level::TRACE
1040 }
1041 } else {
1042 quote! {
1043 tracing::Level::#log_level
1044 }
1045 };
1046 let log_message = quote! {
1047 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1048 "rpc" => "oneway",
1049 "actor_id" => self.actor_id().to_string(),
1050 "message_type" => stringify!(#enum_name),
1051 "variant" => stringify!(#variant_name_snake),
1052 ));
1053 };
1054 impl_methods.push(quote! {
1055 async fn #variant_name_snake(
1056 &self,
1057 cx: &impl hyperactor::context::Actor,
1058 #(#arg_names: #arg_types),*)
1059 -> Result<(), hyperactor::anyhow::Error> {
1060 let message = #constructor;
1061 #log_message;
1062 #send_message;
1063 Ok(())
1064 }
1065
1066 async fn #variant_name_snake_deprecated(
1067 &self,
1068 cx: &impl hyperactor::context::Actor,
1069 #(#arg_names: #arg_types),*)
1070 -> Result<(), hyperactor::anyhow::Error> {
1071 let message = #constructor;
1072 #log_message;
1073 #send_message;
1074 Ok(())
1075 }
1076 });
1077 }
1078 }
1079 }
1080
1081 let trait_name = format_ident!("{}Client", name);
1082
1083 let (_, ty_generics, _) = input.generics.split_for_impl();
1084
1085 let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1087 let mut trait_generics = input.generics.clone();
1088 trait_generics.params.insert(
1089 0,
1090 syn::GenericParam::Type(syn::TypeParam {
1091 ident: actor_ident.clone(),
1092 attrs: vec![],
1093 colon_token: None,
1094 bounds: Punctuated::new(),
1095 eq_token: None,
1096 default: None,
1097 }),
1098 );
1099
1100 for param in trait_generics.type_params_mut() {
1101 if param.ident == actor_ident {
1102 continue;
1103 }
1104 param.bounds.push(syn::parse_quote!(serde::Serialize));
1105 param
1106 .bounds
1107 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1108 param.bounds.push(syn::parse_quote!(Send));
1109 param.bounds.push(syn::parse_quote!(Sync));
1110 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1111 param.bounds.push(syn::parse_quote!(typeuri::Named));
1112 }
1113
1114 let (impl_generics, _, _) = trait_generics.split_for_impl();
1115
1116 let expanded = if is_handle {
1117 quote! {
1118 #[hyperactor::async_trait::async_trait]
1119 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1120 where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1121 #(#impl_methods)*
1122 }
1123 }
1124 } else {
1125 quote! {
1126 #[hyperactor::async_trait::async_trait]
1127 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1128 where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1129 #(#impl_methods)*
1130 }
1131 }
1132 };
1133
1134 TokenStream::from(expanded)
1135}
1136
/// Error reported when `#[forward(...)]` is not given exactly one argument.
const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
`forward` expects the message type that is being forwarded

= help: use `#[forward(MessageType)]`
"#};
1142
1143#[proc_macro_attribute]
1145pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
1146 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1147 if attr_args.len() != 1 {
1148 return TokenStream::from(
1149 syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
1150 );
1151 }
1152
1153 let message_type = attr_args.first().unwrap();
1154 let input = parse_macro_input!(item as ItemImpl);
1155
1156 let self_type = match *input.self_ty {
1157 syn::Type::Path(ref type_path) => {
1158 let segment = type_path.path.segments.last().unwrap();
1159 segment.clone() }
1161 _ => {
1162 return TokenStream::from(
1163 syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
1164 .to_compile_error(),
1165 );
1166 }
1167 };
1168
1169 let trait_name = match input.trait_ {
1170 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1171 None => {
1172 return TokenStream::from(
1173 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1174 .to_compile_error(),
1175 );
1176 }
1177 };
1178
1179 let expanded = quote! {
1180 #input
1181
1182 #[hyperactor::async_trait::async_trait]
1183 impl hyperactor::Handler<#message_type> for #self_type {
1184 async fn handle(
1185 &mut self,
1186 cx: &hyperactor::Context<Self>,
1187 message: #message_type,
1188 ) -> hyperactor::anyhow::Result<()> {
1189 <Self as #trait_name>::handle(self, cx, message).await
1190 }
1191 }
1192 };
1193
1194 TokenStream::from(expanded)
1195}
1196
1197#[proc_macro_attribute]
1210pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1211 let args =
1212 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1213 let input = parse_macro_input!(input as ItemFn);
1214 let output = quote! {
1215 #[hyperactor::tracing::instrument(err, skip_all, #args)]
1216 #input
1217 };
1218
1219 TokenStream::from(output)
1220}
1221
1222#[proc_macro_attribute]
1233pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1234 let args =
1235 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1236 let input = parse_macro_input!(input as ItemFn);
1237
1238 let output = quote! {
1239 #[hyperactor::tracing::instrument(skip_all, #args)]
1240 #input
1241 };
1242
1243 TokenStream::from(output)
1244}
1245
1246struct HandlerSpec {
1247 ty: Type,
1248 cast: bool,
1249}
1250
1251impl Parse for HandlerSpec {
1252 fn parse(input: ParseStream) -> syn::Result<Self> {
1253 let ty: Type = input.parse()?;
1254
1255 if input.peek(syn::token::Brace) {
1256 let content;
1257 syn::braced!(content in input);
1258 let key: Ident = content.parse()?;
1259 content.parse::<Token![=]>()?;
1260 let expr: Expr = content.parse()?;
1261
1262 let cast = if key == "cast" {
1263 if let Expr::Lit(ExprLit {
1264 lit: Lit::Bool(b), ..
1265 }) = expr
1266 {
1267 b.value
1268 } else {
1269 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1270 }
1271 } else {
1272 return Err(syn::Error::new_spanned(
1273 key,
1274 "unsupported field (expected `cast`)",
1275 ));
1276 };
1277
1278 Ok(HandlerSpec { ty, cast })
1279 } else if input.is_empty() || input.peek(Token![,]) {
1280 Ok(HandlerSpec { ty, cast: false })
1281 } else {
1282 let unexpected: proc_macro2::TokenTree = input.parse()?;
1284 Err(syn::Error::new_spanned(
1285 unexpected,
1286 "unexpected token after type — expected `{ ... }` or nothing",
1287 ))
1288 }
1289 }
1290}
1291
1292impl HandlerSpec {
1293 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1294 let mut tys = Vec::new();
1295 for HandlerSpec { ty, cast } in handlers {
1296 if cast {
1297 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1298 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1299 tys.push(wrapped_ty);
1300 }
1301 tys.push(ty);
1302 }
1303 tys
1304 }
1305}
1306
1307struct ExportAttr {
1309 spawn: bool,
1310 handlers: Vec<HandlerSpec>,
1311}
1312
1313impl Parse for ExportAttr {
1314 fn parse(input: ParseStream) -> syn::Result<Self> {
1315 let mut spawn = false;
1316 let mut handlers: Vec<HandlerSpec> = vec![];
1317
1318 while !input.is_empty() {
1319 let key: Ident = input.parse()?;
1320 input.parse::<Token![=]>()?;
1321
1322 if key == "spawn" {
1323 let expr: Expr = input.parse()?;
1324 if let Expr::Lit(ExprLit {
1325 lit: Lit::Bool(b), ..
1326 }) = expr
1327 {
1328 spawn = b.value;
1329 } else {
1330 return Err(syn::Error::new_spanned(
1331 expr,
1332 "expected boolean for `spawn`",
1333 ));
1334 }
1335 } else if key == "handlers" {
1336 let content;
1337 bracketed!(content in input);
1338 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1339 handlers = raw_handlers.into_iter().collect();
1340 } else {
1341 return Err(syn::Error::new_spanned(
1342 key,
1343 "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1344 ));
1345 }
1346
1347 let _ = input.parse::<Token![,]>();
1349 }
1350
1351 Ok(ExportAttr { spawn, handlers })
1352 }
1353}
1354
1355#[proc_macro_attribute]
1382pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1383 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1384 let data_type_name = &input.ident;
1385 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1386
1387 let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1388 let tys = HandlerSpec::add_indexed(handlers);
1389
1390 let mut handles = Vec::new();
1391 let mut bindings = Vec::new();
1392 let mut type_registrations = Vec::new();
1393
1394 for ty in &tys {
1395 handles.push(quote! {
1396 impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
1397 });
1398 bindings.push(quote! {
1399 ports.bind::<#ty>();
1400 });
1401 type_registrations.push(quote! {
1402 wirevalue::register_type!(#ty);
1403 });
1404 }
1405
1406 let mut expanded = quote! {
1407 #input
1408
1409 impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
1410
1411 #(#handles)*
1412
1413 #(#type_registrations)*
1414
1415 impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1417
1418 impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
1419 fn bind(ports: &hyperactor::proc::Ports<Self>) {
1420 #(#bindings)*
1421 }
1422 }
1423
1424 impl #impl_generics typeuri::Named for #data_type_name #ty_generics #where_clause {
1426 fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
1427 }
1428 };
1429
1430 if spawn {
1431 expanded.extend(quote! {
1432 hyperactor::remote!(#data_type_name);
1433 });
1434 }
1435
1436 TokenStream::from(expanded)
1437}
1438
/// Parsed input of `behavior!`: `Name<Params>, Handler, Handler, ...`.
struct BehaviorInput {
    /// Name of the generated behavior struct.
    behavior: Ident,
    /// Type parameters following the name (may be empty).
    generics: syn::Generics,
    /// The handled message types.
    handlers: Vec<HandlerSpec>,
}
1445
1446impl syn::parse::Parse for BehaviorInput {
1447 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1448 let behavior: Ident = input.parse()?;
1449 let generics: syn::Generics = input.parse()?;
1450 let _: Token![,] = input.parse()?;
1451 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1452 let handlers = raw_handlers.into_iter().collect();
1453 Ok(BehaviorInput {
1454 behavior,
1455 generics,
1456 handlers,
1457 })
1458 }
1459}
1460
1461#[proc_macro]
1486pub fn behavior(input: TokenStream) -> TokenStream {
1487 let BehaviorInput {
1488 behavior,
1489 generics,
1490 handlers,
1491 } = parse_macro_input!(input as BehaviorInput);
1492 let tys = HandlerSpec::add_indexed(handlers);
1493
1494 let mut bounded_generics = generics.clone();
1496 for param in bounded_generics.type_params_mut() {
1497 param.bounds.push(syn::parse_quote!(typeuri::Named));
1498 param.bounds.push(syn::parse_quote!(serde::Serialize));
1499 param.bounds.push(syn::parse_quote!(std::marker::Send));
1500 param.bounds.push(syn::parse_quote!(std::marker::Sync));
1501 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1502 let lifetime =
1505 syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1506 param
1507 .bounds
1508 .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1509 }
1510
1511 let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1513
1514 let mut binds_generics = bounded_generics.clone();
1516 binds_generics.params.insert(
1517 0,
1518 syn::GenericParam::Type(syn::TypeParam {
1519 attrs: vec![],
1520 ident: Ident::new("A", proc_macro2::Span::call_site()),
1521 colon_token: None,
1522 bounds: Punctuated::new(),
1523 eq_token: None,
1524 default: None,
1525 }),
1526 );
1527 let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1528
1529 let type_params: Vec<_> = bounded_generics.type_params().collect();
1531 let has_generics = !type_params.is_empty();
1532
1533 let (typename_impl, typehash_impl) = if has_generics {
1534 let placeholders = vec!["{}"; type_params.len()].join(", ");
1536 let placeholders_format_string = format!("<{}>", placeholders);
1537 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1538
1539 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1540 (
1541 quote! {
1542 typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1543 },
1544 quote! {
1545 typeuri::cityhasher::hash(Self::typename())
1546 },
1547 )
1548 } else {
1549 (
1550 quote! {
1551 concat!(std::module_path!(), "::", stringify!(#behavior))
1552 },
1553 quote! {
1554 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1555 typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
1556 });
1557 *TYPEHASH
1558 },
1559 )
1560 };
1561
1562 let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1563
1564 let expanded = quote! {
1565 #[doc = "The generated behavior struct."]
1566 #[derive(Debug, serde::Serialize, serde::Deserialize)]
1567 pub struct #behavior #impl_generics #where_clause {
1568 _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1569 }
1570
1571 impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
1572 fn typename() -> &'static str {
1573 #typename_impl
1574 }
1575
1576 fn typehash() -> u64 {
1577 #typehash_impl
1578 }
1579 }
1580
1581 impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1582
1583 impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1584 where
1585 A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1586 #where_clause
1587 {
1588 fn bind(ports: &hyperactor::proc::Ports<A>) {
1589 #(
1590 ports.bind::<#tys>();
1591 )*
1592 }
1593 }
1594
1595 #(
1596 impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1597 )*
1598 };
1599
1600 TokenStream::from(expanded)
1601}
1602
1603fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1604 let mut is_included = false;
1605 for attr in &field.attrs {
1606 if attr.path().is_ident("binding") {
1607 attr.parse_nested_meta(|meta| {
1609 if meta.path.is_ident("include") {
1610 is_included = true;
1611 Ok(())
1612 } else {
1613 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1614 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1615 }
1616 })?
1617 }
1618 }
1619 Ok(is_included)
1620}
1621
1622enum FieldAccessor {
1627 Named(Ident),
1628 Unnamed(Index),
1629}
1630
/// A field of a struct or enum variant, as seen by the `Bind`/`Unbind` derives.
struct ParsedField {
    /// How to refer to the field.
    accessor: FieldAccessor,
    /// The field's declared type.
    ty: Type,
    /// Whether the field is annotated with `#[binding(include)]`.
    included: bool,
}
1637
1638impl From<&ParsedField> for (Ident, Type) {
1639 fn from(field: &ParsedField) -> Self {
1640 let field_ident = match &field.accessor {
1641 FieldAccessor::Named(ident) => ident.clone(),
1642 FieldAccessor::Unnamed(i) => {
1643 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1644 }
1645 };
1646 (field_ident, field.ty.clone())
1647 }
1648}
1649
1650fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1651 match fields {
1652 Fields::Named(named) => named
1653 .named
1654 .iter()
1655 .map(|f| {
1656 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1657 Ok(ParsedField {
1658 accessor,
1659 ty: f.ty.clone(),
1660 included: include_in_bind_unbind(f)?,
1661 })
1662 })
1663 .collect(),
1664 Fields::Unnamed(unnamed) => unnamed
1665 .unnamed
1666 .iter()
1667 .enumerate()
1668 .map(|(i, f)| {
1669 let accessor = FieldAccessor::Unnamed(Index::from(i));
1670 Ok(ParsedField {
1671 accessor,
1672 ty: f.ty.clone(),
1673 included: include_in_bind_unbind(f)?,
1674 })
1675 })
1676 .collect(),
1677 Fields::Unit => Ok(Vec::new()),
1678 }
1679}
1680
1681fn gen_struct_items<F>(
1682 fields: &Fields,
1683 make_item: F,
1684 is_mutable: bool,
1685) -> syn::Result<Vec<proc_macro2::TokenStream>>
1686where
1687 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1688{
1689 let borrow = if is_mutable {
1690 quote! { &mut }
1691 } else {
1692 quote! { & }
1693 };
1694 let items: Vec<_> = collect_all_fields(fields)?
1695 .into_iter()
1696 .filter(|f| f.included)
1697 .map(
1698 |ParsedField {
1699 accessor,
1700 ty,
1701 included,
1702 }| {
1703 assert!(included);
1704 let field_accessor = match accessor {
1705 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1706 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1707 };
1708 make_item(field_accessor, ty)
1709 },
1710 )
1711 .collect();
1712 Ok(items)
1713}
1714
1715fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1725 all_fields
1726 .iter()
1727 .map(
1728 |ParsedField {
1729 accessor,
1730 ty: _,
1731 included,
1732 }| {
1733 match accessor {
1734 FieldAccessor::Named(ident) => {
1735 if *included {
1736 quote! { #ident }
1737 } else {
1738 quote! { #ident: _ }
1739 }
1740 }
1741 FieldAccessor::Unnamed(i) => {
1742 if *included {
1743 let ident = Ident::new(
1744 &format!("f{}", i.index),
1745 proc_macro2::Span::call_site(),
1746 );
1747 quote! { #ident }
1748 } else {
1749 quote! { _ }
1750 }
1751 }
1752 }
1753 },
1754 )
1755 .collect()
1756}
1757
1758fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1765where
1766 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1767{
1768 data.variants
1769 .iter()
1770 .map(|variant| {
1771 let name = &variant.ident;
1772 let all_fields = collect_all_fields(&variant.fields)?;
1773 let field_accessors = gen_enum_field_accessors(&all_fields);
1774 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1775 let items = included_fields
1776 .iter()
1777 .map(|f| {
1778 let (accessor, ty) = <(Ident, Type)>::from(*f);
1779 make_item(quote! { #accessor }, ty)
1780 })
1781 .collect::<Vec<_>>();
1782
1783 Ok(match &variant.fields {
1784 Fields::Named(_) => {
1785 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1786 }
1787 Fields::Unnamed(_) => {
1788 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1789 }
1790 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1791 })
1792 })
1793 .collect()
1794}
1795
1796#[proc_macro_derive(Bind, attributes(binding))]
1878pub fn derive_bind(input: TokenStream) -> TokenStream {
1879 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1880 quote! {
1881 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1882 }
1883 }
1884
1885 let input = parse_macro_input!(input as DeriveInput);
1886 let name = &input.ident;
1887 let inner = match &input.data {
1888 Data::Struct(DataStruct { fields, .. }) => {
1889 match gen_struct_items(fields, make_item, true) {
1890 Ok(collects) => {
1891 quote! { #(#collects)* }
1892 }
1893 Err(e) => {
1894 return TokenStream::from(e.to_compile_error());
1895 }
1896 }
1897 }
1898 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1899 Ok(arms) => {
1900 quote! { match self { #(#arms),* } }
1901 }
1902 Err(e) => {
1903 return TokenStream::from(e.to_compile_error());
1904 }
1905 },
1906 _ => panic!("Bind can only be derived for structs and enums"),
1907 };
1908 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1909 let expand = quote! {
1910 #[automatically_derived]
1911 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1912 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1913 #inner
1914 Ok(())
1915 }
1916 }
1917 };
1918 TokenStream::from(expand)
1919}
1920
1921#[proc_macro_derive(Unbind, attributes(binding))]
1935pub fn derive_unbind(input: TokenStream) -> TokenStream {
1936 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1937 quote! {
1938 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1939 }
1940 }
1941
1942 let input = parse_macro_input!(input as DeriveInput);
1943 let name = &input.ident;
1944 let inner = match &input.data {
1945 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1946 {
1947 Ok(collects) => {
1948 quote! { #(#collects)* }
1949 }
1950 Err(e) => {
1951 return TokenStream::from(e.to_compile_error());
1952 }
1953 },
1954 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1955 Ok(arms) => {
1956 quote! { match self { #(#arms),* } }
1957 }
1958 Err(e) => {
1959 return TokenStream::from(e.to_compile_error());
1960 }
1961 },
1962 _ => panic!("Unbind can only be derived for structs and enums"),
1963 };
1964 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1965 let expand = quote! {
1966 #[automatically_derived]
1967 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1968 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1969 #inner
1970 Ok(())
1971 }
1972 }
1973 };
1974 TokenStream::from(expand)
1975}
1976
1977fn parse_observe_function(
1979 attr: TokenStream,
1980 item: TokenStream,
1981) -> syn::Result<(ItemFn, String, String)> {
1982 let input = syn::parse::<ItemFn>(item)?;
1983
1984 if input.sig.asyncness.is_none() {
1985 return Err(syn::Error::new(
1986 input.sig.span(),
1987 "observe macros can only be applied to async functions",
1988 ));
1989 }
1990
1991 let fn_name_str = input.sig.ident.to_string();
1992 let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
1993
1994 Ok((input, fn_name_str, module_name_str))
1995}
1996
1997fn create_telemetry_setup(
1999 module_name_str: &str,
2000 fn_name_str: &str,
2001 include_error: bool,
2002) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2003 let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2004 let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2005
2006 let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2007
2008 let error_ident = if include_error {
2009 Some(Ident::new(
2010 "error",
2011 Span::from(proc_macro::Span::def_site()),
2012 ))
2013 } else {
2014 None
2015 };
2016
2017 let error_declaration = if let Some(ref error_ident) = error_ident {
2018 quote! {
2019 hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2020 }
2021 } else {
2022 quote! {}
2023 };
2024
2025 let setup_code = quote! {
2026 use hyperactor_telemetry;
2027 hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2028 hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2029 #error_declaration
2030 };
2031
2032 (latency_ident, success_ident, error_ident, setup_code)
2033}
2034
2035#[proc_macro_attribute]
2055pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2056 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2057 Ok(parsed) => parsed,
2058 Err(err) => return err.to_compile_error().into(),
2059 };
2060
2061 let fn_name = &input.sig.ident;
2062 let vis = &input.vis;
2063 let args = &input.sig.inputs;
2064 let return_type = &input.sig.output;
2065 let body = &input.block;
2066 let attrs = &input.attrs;
2067 let generics = &input.sig.generics;
2068
2069 let (latency_ident, success_ident, error_ident, telemetry_setup) =
2070 create_telemetry_setup(&module_name_str, &fn_name_str, true);
2071 let error_ident = error_ident.unwrap();
2072
2073 let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2074
2075 let expanded = quote! {
2077 #(#attrs)*
2078 #vis async fn #fn_name #generics(#args) #return_type {
2079 #telemetry_setup
2080
2081 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2082 let _timer = #latency_ident.start(kv_pairs);
2083
2084 let #result_ident = async #body.await;
2085
2086 match &#result_ident {
2087 Ok(_) => {
2088 #success_ident.add(
2089 1,
2090 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2091 );
2092 }
2093 Err(_) => {
2094 #error_ident.add(
2095 1,
2096 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2097 );
2098 }
2099 }
2100
2101 #result_ident
2102 }
2103 };
2104
2105 expanded.into()
2106}
2107
2108#[proc_macro_attribute]
2127pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2128 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2129 Ok(parsed) => parsed,
2130 Err(err) => return err.to_compile_error().into(),
2131 };
2132
2133 let fn_name = &input.sig.ident;
2134 let vis = &input.vis;
2135 let args = &input.sig.inputs;
2136 let return_type = &input.sig.output;
2137 let body = &input.block;
2138 let attrs = &input.attrs;
2139 let generics = &input.sig.generics;
2140
2141 let (latency_ident, success_ident, _, telemetry_setup) =
2142 create_telemetry_setup(&module_name_str, &fn_name_str, false);
2143
2144 let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2145
2146 let expanded = quote! {
2148 #(#attrs)*
2149 #vis async fn #fn_name #generics(#args) #return_type {
2150 #telemetry_setup
2151
2152 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2153 let _timer = #latency_ident.start(kv_pairs);
2154
2155 let #return_ident = async #body.await;
2156
2157 #success_ident.add(
2158 1,
2159 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2160 );
2161 #return_ident
2162 }
2163 };
2164
2165 expanded.into()
2166}