1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::bracketed;
41use syn::parse::Parse;
42use syn::parse::ParseStream;
43use syn::parse_macro_input;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46
47const REPLY_VARIANT_ERROR: &str = indoc! {r#"
48`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position
49
50= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
51= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
52"#};
53
54const REPLY_USAGE_ERROR: &str = indoc! {r#"
55`call` message expects at most one `reply` argument
56
57= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
58= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
59"#};
60
/// Marker extracted from a single message field's attributes.
enum FieldFlag {
    /// An ordinary message argument.
    None,
    /// The field is annotated with a bare `#[reply]` and carries the reply port.
    Reply,
}
65
66#[allow(dead_code)]
68enum Variant {
69 Named {
71 enum_name: Ident,
72 name: Ident,
73 field_names: Vec<Ident>,
74 field_types: Vec<Type>,
75 field_flags: Vec<FieldFlag>,
76 is_struct: bool,
77 generics: syn::Generics,
78 },
79 Anon {
81 enum_name: Ident,
82 name: Ident,
83 field_types: Vec<Type>,
84 field_flags: Vec<FieldFlag>,
85 is_struct: bool,
86 generics: syn::Generics,
87 },
88}
89
90impl Variant {
91 fn len(&self) -> usize {
93 self.field_types().len()
94 }
95
96 fn is_struct(&self) -> bool {
98 match self {
99 Variant::Named { is_struct, .. } => *is_struct,
100 Variant::Anon { is_struct, .. } => *is_struct,
101 }
102 }
103
104 fn enum_name(&self) -> &Ident {
106 match self {
107 Variant::Named { enum_name, .. } => enum_name,
108 Variant::Anon { enum_name, .. } => enum_name,
109 }
110 }
111
112 fn name(&self) -> &Ident {
114 match self {
115 Variant::Named { name, .. } => name,
116 Variant::Anon { name, .. } => name,
117 }
118 }
119
120 #[allow(dead_code)]
122 fn generics(&self) -> &syn::Generics {
123 match self {
124 Variant::Named { generics, .. } => generics,
125 Variant::Anon { generics, .. } => generics,
126 }
127 }
128
129 fn snake_name(&self) -> Ident {
131 Ident::new(
132 &self.name().to_string().to_case(Case::Snake),
133 self.name().span(),
134 )
135 }
136
137 fn qualified_name(&self) -> proc_macro2::TokenStream {
139 let enum_name = self.enum_name();
140 let name = self.name();
141
142 if self.is_struct() {
143 quote! { #enum_name }
144 } else {
145 quote! { #enum_name::#name }
146 }
147 }
148
149 fn field_names(&self) -> Vec<Ident> {
152 match self {
153 Variant::Named { field_names, .. } => field_names.clone(),
154 Variant::Anon { field_types, .. } => (0usize..field_types.len())
155 .map(|idx| format_ident!("arg{}", idx))
156 .collect(),
157 }
158 }
159
160 fn field_types(&self) -> &Vec<Type> {
162 match self {
163 Variant::Named { field_types, .. } => field_types,
164 Variant::Anon { field_types, .. } => field_types,
165 }
166 }
167
168 fn field_flags(&self) -> &Vec<FieldFlag> {
170 match self {
171 Variant::Named { field_flags, .. } => field_flags,
172 Variant::Anon { field_flags, .. } => field_flags,
173 }
174 }
175
176 fn constructor(&self) -> proc_macro2::TokenStream {
178 let qualified_name = self.qualified_name();
179 let field_names = self.field_names();
180 match self {
181 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
182 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
183 }
184 }
185}
186
/// Kind of port used by a `call` message's `#[reply]` field, classified from
/// the name of the port type.
struct ReplyPort {
    /// `true` for `PortHandle` / `OncePortHandle`; `false` for the `*PortRef` types.
    is_handle: bool,
    /// `true` for the `OncePortHandle` / `OncePortRef` types.
    is_once: bool,
}
191
192impl ReplyPort {
193 fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
194 ReplyPort {
195 is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
196 is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
197 }
198 }
199
200 fn open_op(&self) -> proc_macro2::TokenStream {
201 if self.is_once {
202 quote! { hyperactor::mailbox::open_once_port }
203 } else {
204 quote! { hyperactor::mailbox::open_port }
205 }
206 }
207
208 fn rx_modifier(&self) -> proc_macro2::TokenStream {
209 if self.is_once {
210 quote! {}
211 } else {
212 quote! { mut }
213 }
214 }
215}
216
217#[allow(clippy::large_enum_variant)]
220enum Message {
221 Call {
224 variant: Variant,
225 reply_port: ReplyPort,
227 return_type: Type,
229 log_level: Option<Ident>,
231 },
232 OneWay {
233 variant: Variant,
234 log_level: Option<Ident>,
236 },
237}
238
239impl Message {
240 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
241 match &variant
242 .field_flags()
243 .iter()
244 .zip(variant.field_types())
245 .filter_map(|(flag, ty)| match flag {
246 FieldFlag::Reply => Some(ty),
247 FieldFlag::None => None,
248 })
249 .collect::<Vec<&Type>>()[..]
250 {
251 [] => Ok(Self::OneWay { variant, log_level }),
252 [reply_port_ty] => {
253 let syn::Type::Path(type_path) = reply_port_ty else {
254 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
255 };
256 let Some(last_segment) = type_path.path.segments.last() else {
257 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
258 };
259 if last_segment.ident != "OncePortRef"
260 && last_segment.ident != "OncePortHandle"
261 && last_segment.ident != "PortRef"
262 && last_segment.ident != "PortHandle"
263 {
264 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
265 }
266 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
267 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
268 };
269 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
270 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
271 };
272 let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
273 let return_type = return_ty.clone();
274 Ok(Self::Call {
275 variant,
276 reply_port,
277 return_type,
278 log_level,
279 })
280 }
281 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
282 }
283 }
284
285 fn args(&self) -> Vec<(Ident, Type)> {
287 match self {
288 Message::Call { variant, .. } => variant
289 .field_names()
290 .into_iter()
291 .zip(variant.field_types().clone())
292 .take(variant.len() - 1)
293 .collect(),
294 Message::OneWay { variant, .. } => variant
295 .field_names()
296 .into_iter()
297 .zip(variant.field_types().clone())
298 .collect(),
299 }
300 }
301
302 fn variant(&self) -> &Variant {
303 match self {
304 Message::Call { variant, .. } => variant,
305 Message::OneWay { variant, .. } => variant,
306 }
307 }
308
309 fn reply_port_position(&self) -> Option<usize> {
310 self.variant()
311 .field_flags()
312 .iter()
313 .position(|flag| matches!(flag, FieldFlag::Reply))
314 }
315
316 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
318 match self {
319 Message::Call { variant, .. } => {
320 let pos = self.reply_port_position()?;
321 Some((
322 variant.field_names()[pos].clone(),
323 variant.field_types()[pos].clone(),
324 ))
325 }
326 Message::OneWay { .. } => None,
327 }
328 }
329}
330
331fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
332 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
333 Some(attr) => {
334 let Ok(meta) = attr.meta.require_list() else {
335 return Err(syn::Error::new(
336 Span::call_site(),
337 indoc! {"
338 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
339
340 = help use `#[log_level(info)]` or `#[log_level(error)]`
341 "},
342 ));
343 };
344 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
345 if parsed.len() != 1 {
346 return Err(syn::Error::new(
347 Span::call_site(),
348 indoc! {"
349 `log_level` attribute must specify exactly one level
350
351 = help use `#[log_level(warn)]` or `#[log_level(info)]`
352 "},
353 ));
354 };
355 Some(parsed.first().unwrap().to_string())
356 }
357 None => None,
358 };
359
360 if level.is_none() {
361 return Ok(None);
362 }
363 let level = level.unwrap();
364
365 match level.as_str() {
366 "error" | "warn" | "info" | "debug" | "trace" => {}
367 _ => {
368 return Err(syn::Error::new(
369 Span::call_site(),
370 indoc! {"
371 `log_level` attribute must be one of 'error, warn, info, debug, trace'
372
373 = help use `#[log_level(warn)]` or `#[log_level(info)]`
374 "},
375 ));
376 }
377 }
378
379 Ok(Some(Ident::new(
380 level.to_ascii_uppercase().as_str(),
381 Span::call_site(),
382 )))
383}
384
385fn parse_field_flag(field: &Field) -> FieldFlag {
386 for attr in field.attrs.iter() {
387 match &attr.meta {
388 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
389 _ => {}
390 }
391 }
392 FieldFlag::None
393}
394
395fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
397 match &input.data {
398 Data::Enum(data_enum) => {
399 let mut messages = Vec::new();
400
401 for variant in &data_enum.variants {
402 let name = variant.ident.clone();
403 let attrs = &variant.attrs;
404
405 let message_variant = match &variant.fields {
406 syn::Fields::Unnamed(fields_) => Variant::Anon {
407 enum_name: input.ident.clone(),
408 name,
409 field_types: fields_
410 .unnamed
411 .iter()
412 .map(|field| field.ty.clone())
413 .collect(),
414 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
415 is_struct: false,
416 generics: input.generics.clone(),
417 },
418 syn::Fields::Named(fields_) => Variant::Named {
419 enum_name: input.ident.clone(),
420 name,
421 field_names: fields_
422 .named
423 .iter()
424 .map(|field| field.ident.clone().unwrap())
425 .collect(),
426 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
427 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
428 is_struct: false,
429 generics: input.generics.clone(),
430 },
431 _ => {
432 return Err(syn::Error::new_spanned(
433 variant,
434 indoc! {r#"
435 `Handler` currently only supports named or tuple struct variants
436
437 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
438 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
439 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
440 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
441 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
442 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
443 "#},
444 ));
445 }
446 };
447 let log_level = parse_log_level(attrs)?;
448
449 messages.push(Message::new(
450 variant.fields.span(),
451 message_variant,
452 log_level,
453 )?);
454 }
455
456 Ok(messages)
457 }
458 Data::Struct(data_struct) => {
459 let struct_name = input.ident.clone();
460 let attrs = &input.attrs;
461
462 let message_variant = match &data_struct.fields {
463 syn::Fields::Unnamed(fields_) => Variant::Anon {
464 enum_name: struct_name.clone(),
465 name: struct_name,
466 field_types: fields_
467 .unnamed
468 .iter()
469 .map(|field| field.ty.clone())
470 .collect(),
471 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
472 is_struct: true,
473 generics: input.generics.clone(),
474 },
475 syn::Fields::Named(fields_) => Variant::Named {
476 enum_name: struct_name.clone(),
477 name: struct_name,
478 field_names: fields_
479 .named
480 .iter()
481 .map(|field| field.ident.clone().unwrap())
482 .collect(),
483 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
484 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
485 is_struct: true,
486 generics: input.generics.clone(),
487 },
488 syn::Fields::Unit => Variant::Anon {
489 enum_name: struct_name.clone(),
490 name: struct_name,
491 field_types: Vec::new(),
492 field_flags: Vec::new(),
493 is_struct: true,
494 generics: input.generics.clone(),
495 },
496 };
497
498 let log_level = parse_log_level(attrs)?;
499 let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
500
501 Ok(vec![message])
502 }
503 _ => Err(syn::Error::new_spanned(
504 input,
505 "handlers can only be derived for enums and structs",
506 )),
507 }
508}
509
510#[proc_macro_derive(Handler, attributes(reply))]
669pub fn derive_handler(input: TokenStream) -> TokenStream {
670 let input = parse_macro_input!(input as DeriveInput);
671 let name: Ident = input.ident.clone();
672 let (_, ty_generics, _) = input.generics.split_for_impl();
673
674 let messages = match parse_messages(input.clone()) {
675 Ok(messages) => messages,
676 Err(err) => return TokenStream::from(err.to_compile_error()),
677 };
678
679 let mut handler_trait_methods = Vec::new();
681
682 let mut match_arms = Vec::new();
684
685 let mut client_trait_methods = Vec::new();
687
688 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
689
690 for message in &messages {
691 match message {
692 Message::Call {
693 variant,
694 reply_port,
695 return_type,
696 log_level,
697 } => {
698 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
699 let variant_name_snake = variant.snake_name();
700 let variant_name_snake_deprecated =
701 format_ident!("{}_deprecated", variant_name_snake);
702 let enum_name = variant.enum_name();
703 let _variant_qualified_name = variant.qualified_name();
704 let log_level = match (&global_log_level, log_level) {
705 (_, Some(local)) => local.clone(),
706 (Some(global), None) => global.clone(),
707 _ => Ident::new("DEBUG", Span::call_site()),
708 };
709 let _log_level = if reply_port.is_handle {
710 quote! {
711 tracing::Level::#log_level
712 }
713 } else {
714 quote! {
715 tracing::Level::TRACE
716 }
717 };
718 let log_message = quote! {
719 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
720 "rpc" => "call",
721 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
722 "message_type" => stringify!(#enum_name),
723 "variant" => stringify!(#variant_name_snake),
724 ));
725 };
726
727 handler_trait_methods.push(quote! {
728 #[doc = "The generated handler method for this enum variant."]
729 async fn #variant_name_snake(
730 &mut self,
731 cx: &hyperactor::Context<Self>,
732 #(#arg_names: #arg_types),*)
733 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
734 });
735
736 client_trait_methods.push(quote! {
737 #[doc = "The generated client method for this enum variant."]
738 async fn #variant_name_snake(
739 &self,
740 cx: &impl hyperactor::context::Actor,
741 #(#arg_names: #arg_types),*)
742 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
743
744 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
745 async fn #variant_name_snake_deprecated(
746 &self,
747 cx: &impl hyperactor::context::Actor,
748 #(#arg_names: #arg_types),*)
749 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
750 });
751
752 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
753 let constructor = variant.constructor();
754 let result_ident = Ident::new("result", Span::mixed_site());
755 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
756 match_arms.push(quote! {
757 #constructor => {
758 #log_message
759 #construct_result_future
762 #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::internal_macro_support::anyhow::Error::from)
763 }
764 });
765 }
766 Message::OneWay { variant, log_level } => {
767 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
768 let variant_name_snake = variant.snake_name();
769 let variant_name_snake_deprecated =
770 format_ident!("{}_deprecated", variant_name_snake);
771 let enum_name = variant.enum_name();
772 let log_level = match (&global_log_level, log_level) {
773 (_, Some(local)) => local.clone(),
774 (Some(global), None) => global.clone(),
775 _ => Ident::new("TRACE", Span::call_site()),
776 };
777 let _log_level = quote! {
778 tracing::Level::#log_level
779 };
780 let log_message = quote! {
781 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
782 "rpc" => "call",
783 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
784 "message_type" => stringify!(#enum_name),
785 "variant" => stringify!(#variant_name_snake),
786 ));
787 };
788
789 handler_trait_methods.push(quote! {
790 #[doc = "The generated handler method for this enum variant."]
791 async fn #variant_name_snake(
792 &mut self,
793 cx: &hyperactor::Context<Self>,
794 #(#arg_names: #arg_types),*)
795 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
796 });
797
798 client_trait_methods.push(quote! {
799 #[doc = "The generated client method for this enum variant."]
800 async fn #variant_name_snake(
801 &self,
802 cx: &impl hyperactor::context::Actor,
803 #(#arg_names: #arg_types),*)
804 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
805
806 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
807 async fn #variant_name_snake_deprecated(
808 &self,
809 cx: &impl hyperactor::context::Actor,
810 #(#arg_names: #arg_types),*)
811 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
812 });
813
814 let constructor = variant.constructor();
815
816 match_arms.push(quote! {
817 #constructor => {
818 #log_message
819 self.#variant_name_snake(cx, #(#arg_names),*).await
820 },
821 });
822 }
823 }
824 }
825
826 let handler_trait_name = format_ident!("{}Handler", name);
827 let client_trait_name = format_ident!("{}Client", name);
828
829 let mut handler_generics = input.generics.clone();
833 for param in handler_generics.type_params_mut() {
834 param.bounds.push(syn::parse_quote!(serde::Serialize));
835 param
836 .bounds
837 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
838 param.bounds.push(syn::parse_quote!(Send));
839 param.bounds.push(syn::parse_quote!(Sync));
840 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
841 param.bounds.push(syn::parse_quote!(typeuri::Named));
842 }
843 let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
844 let (client_impl_generics, _, _) = input.generics.split_for_impl();
845
846 let expanded = quote! {
847 #[doc = "The custom handler trait for this message type."]
848 #[hyperactor::internal_macro_support::async_trait::async_trait]
849 pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync {
850 #(#handler_trait_methods)*
851
852 #[doc = "Handle the next message."]
853 async fn handle(
854 &mut self,
855 cx: &hyperactor::Context<Self>,
856 message: #name #ty_generics,
857 ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
858 match message {
860 #(#match_arms)*
861 }
862 }
863 }
864
865 #[doc = "The custom client trait for this message type."]
866 #[hyperactor::internal_macro_support::async_trait::async_trait]
867 pub trait #client_trait_name #client_impl_generics: Send + Sync {
868 #(#client_trait_methods)*
869 }
870 };
871
872 TokenStream::from(expanded)
873}
874
875#[proc_macro_derive(HandleClient, attributes(log_level))]
878pub fn derive_handle_client(input: TokenStream) -> TokenStream {
879 derive_client(input, true)
880}
881
882#[proc_macro_derive(RefClient, attributes(log_level))]
885pub fn derive_ref_client(input: TokenStream) -> TokenStream {
886 derive_client(input, false)
887}
888
889fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
890 let input = parse_macro_input!(input as DeriveInput);
891 let name = input.ident.clone();
892
893 let messages = match parse_messages(input.clone()) {
894 Ok(messages) => messages,
895 Err(err) => return TokenStream::from(err.to_compile_error()),
896 };
897
898 let mut impl_methods = Vec::new();
900
901 let send_message = quote! { self.send(cx, message)? };
902 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
903
904 for message in &messages {
905 match message {
906 Message::Call {
907 variant,
908 reply_port,
909 return_type,
910 log_level,
911 } => {
912 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
913 let variant_name_snake = variant.snake_name();
914 let variant_name_snake_deprecated =
915 format_ident!("{}_deprecated", variant_name_snake);
916 let enum_name = variant.enum_name();
917
918 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
919 let constructor = variant.constructor();
920 let log_level = match (&global_log_level, log_level) {
921 (_, Some(local)) => local.clone(),
922 (Some(global), None) => global.clone(),
923 _ => Ident::new("DEBUG", Span::call_site()),
924 };
925 let log_level = if is_handle {
926 quote! {
927 tracing::Level::#log_level
928 }
929 } else {
930 quote! {
931 tracing::Level::TRACE
932 }
933 };
934 let log_message = quote! {
935 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
936 "rpc" => "call",
937 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
938 "message_type" => stringify!(#enum_name),
939 "variant" => stringify!(#variant_name_snake),
940 ));
941
942 };
943 let open_port = reply_port.open_op();
944 let rx_mod = reply_port.rx_modifier();
945 if reply_port.is_handle {
946 impl_methods.push(quote! {
947 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
948 async fn #variant_name_snake(
949 &self,
950 cx: &impl hyperactor::context::Actor,
951 #(#arg_names: #arg_types),*)
952 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
953 let (#reply_port_arg, #rx_mod reply_receiver) =
954 #open_port::<#return_type>(cx);
955 let message = #constructor;
956 #log_message;
957 #send_message;
958 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
959 }
960
961 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
962 async fn #variant_name_snake_deprecated(
963 &self,
964 cx: &impl hyperactor::context::Actor,
965 #(#arg_names: #arg_types),*)
966 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
967 let (#reply_port_arg, #rx_mod reply_receiver) =
968 #open_port::<#return_type>(cx);
969 let message = #constructor;
970 #log_message;
971 #send_message;
972 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
973 }
974 });
975 } else {
976 impl_methods.push(quote! {
977 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
978 async fn #variant_name_snake(
979 &self,
980 cx: &impl hyperactor::context::Actor,
981 #(#arg_names: #arg_types),*)
982 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
983 let (#reply_port_arg, #rx_mod reply_receiver) =
984 #open_port::<#return_type>(cx);
985 let #reply_port_arg = #reply_port_arg.bind();
986 let message = #constructor;
987 #log_message;
988 #send_message;
989 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
990 }
991
992 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
993 async fn #variant_name_snake_deprecated(
994 &self,
995 cx: &impl hyperactor::context::Actor,
996 #(#arg_names: #arg_types),*)
997 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
998 let (#reply_port_arg, #rx_mod reply_receiver) =
999 #open_port::<#return_type>(cx);
1000 let #reply_port_arg = #reply_port_arg.bind();
1001 let message = #constructor;
1002 #log_message;
1003 #send_message;
1004 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
1005 }
1006 });
1007 }
1008 }
1009 Message::OneWay { variant, log_level } => {
1010 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1011 let variant_name_snake = variant.snake_name();
1012 let variant_name_snake_deprecated =
1013 format_ident!("{}_deprecated", variant_name_snake);
1014 let enum_name = variant.enum_name();
1015 let constructor = variant.constructor();
1016 let log_level = match (&global_log_level, log_level) {
1017 (_, Some(local)) => local.clone(),
1018 (Some(global), None) => global.clone(),
1019 _ => Ident::new("DEBUG", Span::call_site()),
1020 };
1021 let _log_level = if is_handle {
1022 quote! {
1023 tracing::Level::TRACE
1024 }
1025 } else {
1026 quote! {
1027 tracing::Level::#log_level
1028 }
1029 };
1030 let log_message = quote! {
1031 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1032 "rpc" => "oneway",
1033 "actor_id" => self.actor_id().to_string(),
1034 "message_type" => stringify!(#enum_name),
1035 "variant" => stringify!(#variant_name_snake),
1036 ));
1037 };
1038 impl_methods.push(quote! {
1039 async fn #variant_name_snake(
1040 &self,
1041 cx: &impl hyperactor::context::Actor,
1042 #(#arg_names: #arg_types),*)
1043 -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1044 let message = #constructor;
1045 #log_message;
1046 #send_message;
1047 Ok(())
1048 }
1049
1050 async fn #variant_name_snake_deprecated(
1051 &self,
1052 cx: &impl hyperactor::context::Actor,
1053 #(#arg_names: #arg_types),*)
1054 -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1055 let message = #constructor;
1056 #log_message;
1057 #send_message;
1058 Ok(())
1059 }
1060 });
1061 }
1062 }
1063 }
1064
1065 let trait_name = format_ident!("{}Client", name);
1066
1067 let (_, ty_generics, _) = input.generics.split_for_impl();
1068
1069 let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1071 let mut trait_generics = input.generics.clone();
1072 trait_generics.params.insert(
1073 0,
1074 syn::GenericParam::Type(syn::TypeParam {
1075 ident: actor_ident.clone(),
1076 attrs: vec![],
1077 colon_token: None,
1078 bounds: Punctuated::new(),
1079 eq_token: None,
1080 default: None,
1081 }),
1082 );
1083
1084 for param in trait_generics.type_params_mut() {
1085 if param.ident == actor_ident {
1086 continue;
1087 }
1088 param.bounds.push(syn::parse_quote!(serde::Serialize));
1089 param
1090 .bounds
1091 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1092 param.bounds.push(syn::parse_quote!(Send));
1093 param.bounds.push(syn::parse_quote!(Sync));
1094 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1095 param.bounds.push(syn::parse_quote!(typeuri::Named));
1096 }
1097
1098 let (impl_generics, _, _) = trait_generics.split_for_impl();
1099
1100 let expanded = if is_handle {
1101 quote! {
1102 #[hyperactor::internal_macro_support::async_trait::async_trait]
1103 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1104 where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1105 #(#impl_methods)*
1106 }
1107 }
1108 } else {
1109 quote! {
1110 #[hyperactor::internal_macro_support::async_trait::async_trait]
1111 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1112 where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1113 #(#impl_methods)*
1114 }
1115 }
1116 };
1117
1118 TokenStream::from(expanded)
1119}
1120
/// Error reported by `#[handle(...)]` when it is not given exactly one
/// argument (the message type being handled).
const HANDLE_ARGUMENT_ERROR: &str = indoc! {r#"
`handle` expects the message type that is being handled

= help: use `#[handle(MessageType)]`
"#};
1126
1127#[proc_macro_attribute]
1129pub fn handle(attr: TokenStream, item: TokenStream) -> TokenStream {
1130 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1131 if attr_args.len() != 1 {
1132 return TokenStream::from(
1133 syn::Error::new_spanned(attr_args, HANDLE_ARGUMENT_ERROR).to_compile_error(),
1134 );
1135 }
1136
1137 let message_type = attr_args.first().unwrap();
1138 let input = parse_macro_input!(item as ItemImpl);
1139
1140 let self_type = match *input.self_ty {
1141 syn::Type::Path(ref type_path) => {
1142 let segment = type_path.path.segments.last().unwrap();
1143 segment.clone() }
1145 _ => {
1146 return TokenStream::from(
1147 syn::Error::new_spanned(input.self_ty, "`handle` argument must be a type")
1148 .to_compile_error(),
1149 );
1150 }
1151 };
1152
1153 let trait_name = match input.trait_ {
1154 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1155 None => {
1156 return TokenStream::from(
1157 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1158 .to_compile_error(),
1159 );
1160 }
1161 };
1162
1163 let expanded = quote! {
1164 #input
1165
1166 #[hyperactor::internal_macro_support::async_trait::async_trait]
1167 impl hyperactor::Handler<#message_type> for #self_type {
1168 async fn handle(
1169 &mut self,
1170 cx: &hyperactor::Context<Self>,
1171 message: #message_type,
1172 ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
1173 <Self as #trait_name>::handle(self, cx, message).await
1174 }
1175 }
1176 };
1177
1178 TokenStream::from(expanded)
1179}
1180
1181#[proc_macro_attribute]
1194pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1195 let args =
1196 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1197 let input = parse_macro_input!(input as ItemFn);
1198 let output = quote! {
1199 #[hyperactor::internal_macro_support::tracing::instrument(err, skip_all, #args)]
1200 #input
1201 };
1202
1203 TokenStream::from(output)
1204}
1205
1206#[proc_macro_attribute]
1217pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1218 let args =
1219 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1220 let input = parse_macro_input!(input as ItemFn);
1221
1222 let output = quote! {
1223 #[hyperactor::internal_macro_support::tracing::instrument(skip_all, #args)]
1224 #input
1225 };
1226
1227 TokenStream::from(output)
1228}
1229
1230struct HandlerSpec {
1231 ty: Type,
1232 cast: bool,
1233}
1234
1235impl Parse for HandlerSpec {
1236 fn parse(input: ParseStream) -> syn::Result<Self> {
1237 let ty: Type = input.parse()?;
1238
1239 if input.peek(syn::token::Brace) {
1240 let content;
1241 syn::braced!(content in input);
1242 let key: Ident = content.parse()?;
1243 content.parse::<Token![=]>()?;
1244 let expr: Expr = content.parse()?;
1245
1246 let cast = if key == "cast" {
1247 if let Expr::Lit(ExprLit {
1248 lit: Lit::Bool(b), ..
1249 }) = expr
1250 {
1251 b.value
1252 } else {
1253 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1254 }
1255 } else {
1256 return Err(syn::Error::new_spanned(
1257 key,
1258 "unsupported field (expected `cast`)",
1259 ));
1260 };
1261
1262 Ok(HandlerSpec { ty, cast })
1263 } else if input.is_empty() || input.peek(Token![,]) {
1264 Ok(HandlerSpec { ty, cast: false })
1265 } else {
1266 let unexpected: proc_macro2::TokenTree = input.parse()?;
1268 Err(syn::Error::new_spanned(
1269 unexpected,
1270 "unexpected token after type — expected `{ ... }` or nothing",
1271 ))
1272 }
1273 }
1274}
1275
1276impl HandlerSpec {
1277 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1278 let mut tys = Vec::new();
1279 for HandlerSpec { ty, cast } in handlers {
1280 if cast {
1281 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1282 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1283 tys.push(wrapped_ty);
1284 }
1285 tys.push(ty);
1286 }
1287 tys
1288 }
1289}
1290
1291struct ExportAttr {
1293 spawn: bool,
1294 handlers: Vec<HandlerSpec>,
1295}
1296
1297impl Parse for ExportAttr {
1298 fn parse(input: ParseStream) -> syn::Result<Self> {
1299 let mut spawn = false;
1300 let mut handlers: Vec<HandlerSpec> = vec![];
1301
1302 while !input.is_empty() {
1303 let key: Ident = input.parse()?;
1304 input.parse::<Token![=]>()?;
1305
1306 if key == "spawn" {
1307 let expr: Expr = input.parse()?;
1308 if let Expr::Lit(ExprLit {
1309 lit: Lit::Bool(b), ..
1310 }) = expr
1311 {
1312 spawn = b.value;
1313 } else {
1314 return Err(syn::Error::new_spanned(
1315 expr,
1316 "expected boolean for `spawn`",
1317 ));
1318 }
1319 } else if key == "handlers" {
1320 let content;
1321 bracketed!(content in input);
1322 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1323 handlers = raw_handlers.into_iter().collect();
1324 } else {
1325 return Err(syn::Error::new_spanned(
1326 key,
1327 "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1328 ));
1329 }
1330
1331 let _ = input.parse::<Token![,]>();
1333 }
1334
1335 Ok(ExportAttr { spawn, handlers })
1336 }
1337}
1338
1339#[proc_macro_attribute]
1366pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1367 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1368 let data_type_name = &input.ident;
1369 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1370
1371 let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1372 let tys = HandlerSpec::add_indexed(handlers);
1373
1374 let mut handles = Vec::new();
1375 let mut bindings = Vec::new();
1376 let mut type_registrations = Vec::new();
1377
1378 for ty in &tys {
1379 handles.push(quote! {
1380 impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
1381 impl #impl_generics hyperactor::remote::Accepts<#ty> for #data_type_name #ty_generics #where_clause {}
1382 });
1383 bindings.push(quote! {
1384 ports.bind::<#ty>();
1385 });
1386 type_registrations.push(quote! {
1387 wirevalue::register_type!(#ty);
1388 });
1389 }
1390
1391 let mut expanded = quote! {
1392 #input
1393
1394 impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
1395
1396 #(#handles)*
1397
1398 #(#type_registrations)*
1399
1400 impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1402 impl #impl_generics hyperactor::remote::Accepts<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1403
1404 impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::introspect::IntrospectMessage> for #data_type_name #ty_generics #where_clause {}
1406 impl #impl_generics hyperactor::remote::Accepts<hyperactor::introspect::IntrospectMessage> for #data_type_name #ty_generics #where_clause {}
1407
1408 impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
1409 fn bind(ports: &hyperactor::proc::Ports<Self>) {
1410 #(#bindings)*
1411 }
1412 }
1413
1414 impl #impl_generics typeuri::Named for #data_type_name #ty_generics #where_clause {
1416 fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
1417 }
1418 };
1419
1420 if spawn {
1421 expanded.extend(quote! {
1422 hyperactor::remote!(#data_type_name);
1423 });
1424 }
1425
1426 TokenStream::from(expanded)
1427}
1428
/// Parsed input of `behavior!(Name<Generics>, Handler, ...)`.
struct BehaviorInput {
    /// Name of the behavior struct to generate.
    behavior: Ident,
    /// Generic parameters written after the name (may be empty). Only the
    /// `<...>` list is parsed here, so `where_clause` is not populated.
    generics: syn::Generics,
    /// Handler entries following the first comma.
    handlers: Vec<HandlerSpec>,
}
1435
1436impl syn::parse::Parse for BehaviorInput {
1437 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1438 let behavior: Ident = input.parse()?;
1439 let generics: syn::Generics = input.parse()?;
1440 let _: Token![,] = input.parse()?;
1441 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1442 let handlers = raw_handlers.into_iter().collect();
1443 Ok(BehaviorInput {
1444 behavior,
1445 generics,
1446 handlers,
1447 })
1448 }
1449}
1450
1451#[proc_macro]
1476pub fn behavior(input: TokenStream) -> TokenStream {
1477 let BehaviorInput {
1478 behavior,
1479 generics,
1480 handlers,
1481 } = parse_macro_input!(input as BehaviorInput);
1482 let tys = HandlerSpec::add_indexed(handlers);
1483
1484 let mut bounded_generics = generics.clone();
1486 for param in bounded_generics.type_params_mut() {
1487 param.bounds.push(syn::parse_quote!(typeuri::Named));
1488 param.bounds.push(syn::parse_quote!(serde::Serialize));
1489 param.bounds.push(syn::parse_quote!(std::marker::Send));
1490 param.bounds.push(syn::parse_quote!(std::marker::Sync));
1491 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1492 let lifetime =
1495 syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1496 param
1497 .bounds
1498 .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1499 }
1500
1501 let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1503
1504 let mut binds_generics = bounded_generics.clone();
1506 binds_generics.params.insert(
1507 0,
1508 syn::GenericParam::Type(syn::TypeParam {
1509 attrs: vec![],
1510 ident: Ident::new("A", proc_macro2::Span::call_site()),
1511 colon_token: None,
1512 bounds: Punctuated::new(),
1513 eq_token: None,
1514 default: None,
1515 }),
1516 );
1517 let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1518
1519 let type_params: Vec<_> = bounded_generics.type_params().collect();
1521 let has_generics = !type_params.is_empty();
1522
1523 let (typename_impl, typehash_impl) = if has_generics {
1524 let placeholders = vec!["{}"; type_params.len()].join(", ");
1526 let placeholders_format_string = format!("<{}>", placeholders);
1527 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1528
1529 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1530 (
1531 quote! {
1532 typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1533 },
1534 quote! {
1535 typeuri::cityhasher::hash(Self::typename())
1536 },
1537 )
1538 } else {
1539 (
1540 quote! {
1541 concat!(std::module_path!(), "::", stringify!(#behavior))
1542 },
1543 quote! {
1544 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1545 typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
1546 });
1547 *TYPEHASH
1548 },
1549 )
1550 };
1551
1552 let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1553
1554 let expanded = quote! {
1555 #[doc = "The generated behavior struct."]
1556 #[derive(Debug, serde::Serialize, serde::Deserialize)]
1557 pub struct #behavior #impl_generics #where_clause {
1558 _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1559 }
1560
1561 impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
1562 fn typename() -> &'static str {
1563 #typename_impl
1564 }
1565
1566 fn typehash() -> u64 {
1567 #typehash_impl
1568 }
1569 }
1570
1571 impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1572
1573 impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1574 where
1575 A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1576 #where_clause
1577 {
1578 fn bind(ports: &hyperactor::proc::Ports<A>) {
1579 #(
1580 ports.bind::<#tys>();
1581 )*
1582 }
1583 }
1584
1585 #(
1586 impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1587 impl #impl_generics hyperactor::remote::Accepts<#tys> for #behavior #ty_generics #where_clause {}
1588 )*
1589 };
1590
1591 TokenStream::from(expanded)
1592}
1593
1594fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1595 let mut is_included = false;
1596 for attr in &field.attrs {
1597 if attr.path().is_ident("binding") {
1598 attr.parse_nested_meta(|meta| {
1600 if meta.path.is_ident("include") {
1601 is_included = true;
1602 Ok(())
1603 } else {
1604 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1605 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1606 }
1607 })?
1608 }
1609 }
1610 Ok(is_included)
1611}
1612
/// How a field is addressed, both in `self.<field>` accesses and in the
/// patterns generated for enum variants.
enum FieldAccessor {
    /// A named field, addressed by its identifier.
    Named(Ident),
    /// A tuple field, addressed by its position. In enum patterns it is bound
    /// to a synthesized identifier `f<index>`.
    Unnamed(Index),
}
1621
/// A struct or variant field as seen by the `Bind`/`Unbind` derives.
struct ParsedField {
    /// How to refer to the field.
    accessor: FieldAccessor,
    /// The field's declared type.
    ty: Type,
    /// Whether the field carries `#[binding(include)]`.
    included: bool,
}
1628
1629impl From<&ParsedField> for (Ident, Type) {
1630 fn from(field: &ParsedField) -> Self {
1631 let field_ident = match &field.accessor {
1632 FieldAccessor::Named(ident) => ident.clone(),
1633 FieldAccessor::Unnamed(i) => {
1634 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1635 }
1636 };
1637 (field_ident, field.ty.clone())
1638 }
1639}
1640
1641fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1642 match fields {
1643 Fields::Named(named) => named
1644 .named
1645 .iter()
1646 .map(|f| {
1647 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1648 Ok(ParsedField {
1649 accessor,
1650 ty: f.ty.clone(),
1651 included: include_in_bind_unbind(f)?,
1652 })
1653 })
1654 .collect(),
1655 Fields::Unnamed(unnamed) => unnamed
1656 .unnamed
1657 .iter()
1658 .enumerate()
1659 .map(|(i, f)| {
1660 let accessor = FieldAccessor::Unnamed(Index::from(i));
1661 Ok(ParsedField {
1662 accessor,
1663 ty: f.ty.clone(),
1664 included: include_in_bind_unbind(f)?,
1665 })
1666 })
1667 .collect(),
1668 Fields::Unit => Ok(Vec::new()),
1669 }
1670}
1671
1672fn gen_struct_items<F>(
1673 fields: &Fields,
1674 make_item: F,
1675 is_mutable: bool,
1676) -> syn::Result<Vec<proc_macro2::TokenStream>>
1677where
1678 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1679{
1680 let borrow = if is_mutable {
1681 quote! { &mut }
1682 } else {
1683 quote! { & }
1684 };
1685 let items: Vec<_> = collect_all_fields(fields)?
1686 .into_iter()
1687 .filter(|f| f.included)
1688 .map(
1689 |ParsedField {
1690 accessor,
1691 ty,
1692 included,
1693 }| {
1694 assert!(included);
1695 let field_accessor = match accessor {
1696 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1697 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1698 };
1699 make_item(field_accessor, ty)
1700 },
1701 )
1702 .collect();
1703 Ok(items)
1704}
1705
1706fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1716 all_fields
1717 .iter()
1718 .map(
1719 |ParsedField {
1720 accessor,
1721 ty: _,
1722 included,
1723 }| {
1724 match accessor {
1725 FieldAccessor::Named(ident) => {
1726 if *included {
1727 quote! { #ident }
1728 } else {
1729 quote! { #ident: _ }
1730 }
1731 }
1732 FieldAccessor::Unnamed(i) => {
1733 if *included {
1734 let ident = Ident::new(
1735 &format!("f{}", i.index),
1736 proc_macro2::Span::call_site(),
1737 );
1738 quote! { #ident }
1739 } else {
1740 quote! { _ }
1741 }
1742 }
1743 }
1744 },
1745 )
1746 .collect()
1747}
1748
1749fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1756where
1757 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1758{
1759 data.variants
1760 .iter()
1761 .map(|variant| {
1762 let name = &variant.ident;
1763 let all_fields = collect_all_fields(&variant.fields)?;
1764 let field_accessors = gen_enum_field_accessors(&all_fields);
1765 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1766 let items = included_fields
1767 .iter()
1768 .map(|f| {
1769 let (accessor, ty) = <(Ident, Type)>::from(*f);
1770 make_item(quote! { #accessor }, ty)
1771 })
1772 .collect::<Vec<_>>();
1773
1774 Ok(match &variant.fields {
1775 Fields::Named(_) => {
1776 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1777 }
1778 Fields::Unnamed(_) => {
1779 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1780 }
1781 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1782 })
1783 })
1784 .collect()
1785}
1786
1787#[proc_macro_derive(Bind, attributes(binding))]
1869pub fn derive_bind(input: TokenStream) -> TokenStream {
1870 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1871 quote! {
1872 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1873 }
1874 }
1875
1876 let input = parse_macro_input!(input as DeriveInput);
1877 let name = &input.ident;
1878 let inner = match &input.data {
1879 Data::Struct(DataStruct { fields, .. }) => {
1880 match gen_struct_items(fields, make_item, true) {
1881 Ok(collects) => {
1882 quote! { #(#collects)* }
1883 }
1884 Err(e) => {
1885 return TokenStream::from(e.to_compile_error());
1886 }
1887 }
1888 }
1889 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1890 Ok(arms) => {
1891 quote! { match self { #(#arms),* } }
1892 }
1893 Err(e) => {
1894 return TokenStream::from(e.to_compile_error());
1895 }
1896 },
1897 _ => panic!("Bind can only be derived for structs and enums"),
1898 };
1899 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1900 let expand = quote! {
1901 #[automatically_derived]
1902 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1903 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1904 #inner
1905 Ok(())
1906 }
1907 }
1908 };
1909 TokenStream::from(expand)
1910}
1911
1912#[proc_macro_derive(Unbind, attributes(binding))]
1926pub fn derive_unbind(input: TokenStream) -> TokenStream {
1927 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1928 quote! {
1929 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1930 }
1931 }
1932
1933 let input = parse_macro_input!(input as DeriveInput);
1934 let name = &input.ident;
1935 let inner = match &input.data {
1936 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1937 {
1938 Ok(collects) => {
1939 quote! { #(#collects)* }
1940 }
1941 Err(e) => {
1942 return TokenStream::from(e.to_compile_error());
1943 }
1944 },
1945 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1946 Ok(arms) => {
1947 quote! { match self { #(#arms),* } }
1948 }
1949 Err(e) => {
1950 return TokenStream::from(e.to_compile_error());
1951 }
1952 },
1953 _ => panic!("Unbind can only be derived for structs and enums"),
1954 };
1955 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1956 let expand = quote! {
1957 #[automatically_derived]
1958 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1959 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1960 #inner
1961 Ok(())
1962 }
1963 }
1964 };
1965 TokenStream::from(expand)
1966}
1967
1968fn parse_observe_function(
1970 attr: TokenStream,
1971 item: TokenStream,
1972) -> syn::Result<(ItemFn, String, String)> {
1973 let input = syn::parse::<ItemFn>(item)?;
1974
1975 if input.sig.asyncness.is_none() {
1976 return Err(syn::Error::new(
1977 input.sig.span(),
1978 "observe macros can only be applied to async functions",
1979 ));
1980 }
1981
1982 let fn_name_str = input.sig.ident.to_string();
1983 let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
1984
1985 Ok((input, fn_name_str, module_name_str))
1986}
1987
1988fn create_telemetry_setup(
1990 module_name_str: &str,
1991 fn_name_str: &str,
1992 include_error: bool,
1993) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
1994 let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
1995 let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
1996
1997 let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
1998
1999 let error_ident = if include_error {
2000 Some(Ident::new(
2001 "error",
2002 Span::from(proc_macro::Span::def_site()),
2003 ))
2004 } else {
2005 None
2006 };
2007
2008 let error_declaration = if let Some(ref error_ident) = error_ident {
2009 quote! {
2010 hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2011 }
2012 } else {
2013 quote! {}
2014 };
2015
2016 let setup_code = quote! {
2017 use hyperactor_telemetry;
2018 hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2019 hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2020 #error_declaration
2021 };
2022
2023 (latency_ident, success_ident, error_ident, setup_code)
2024}
2025
2026#[proc_macro_attribute]
2046pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2047 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2048 Ok(parsed) => parsed,
2049 Err(err) => return err.to_compile_error().into(),
2050 };
2051
2052 let fn_name = &input.sig.ident;
2053 let vis = &input.vis;
2054 let args = &input.sig.inputs;
2055 let return_type = &input.sig.output;
2056 let body = &input.block;
2057 let attrs = &input.attrs;
2058 let generics = &input.sig.generics;
2059
2060 let (latency_ident, success_ident, error_ident, telemetry_setup) =
2061 create_telemetry_setup(&module_name_str, &fn_name_str, true);
2062 let error_ident = error_ident.unwrap();
2063
2064 let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2065
2066 let expanded = quote! {
2068 #(#attrs)*
2069 #vis async fn #fn_name #generics(#args) #return_type {
2070 #telemetry_setup
2071
2072 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2073 let _timer = #latency_ident.start(kv_pairs);
2074
2075 let #result_ident = async #body.await;
2076
2077 match &#result_ident {
2078 Ok(_) => {
2079 #success_ident.add(
2080 1,
2081 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2082 );
2083 }
2084 Err(_) => {
2085 #error_ident.add(
2086 1,
2087 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2088 );
2089 }
2090 }
2091
2092 #result_ident
2093 }
2094 };
2095
2096 expanded.into()
2097}
2098
2099#[proc_macro_attribute]
2118pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2119 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2120 Ok(parsed) => parsed,
2121 Err(err) => return err.to_compile_error().into(),
2122 };
2123
2124 let fn_name = &input.sig.ident;
2125 let vis = &input.vis;
2126 let args = &input.sig.inputs;
2127 let return_type = &input.sig.output;
2128 let body = &input.block;
2129 let attrs = &input.attrs;
2130 let generics = &input.sig.generics;
2131
2132 let (latency_ident, success_ident, _, telemetry_setup) =
2133 create_telemetry_setup(&module_name_str, &fn_name_str, false);
2134
2135 let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2136
2137 let expanded = quote! {
2139 #(#attrs)*
2140 #vis async fn #fn_name #generics(#args) #return_type {
2141 #telemetry_setup
2142
2143 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2144 let _timer = #latency_ident.start(kv_pairs);
2145
2146 let #return_ident = async #body.await;
2147
2148 #success_ident.add(
2149 1,
2150 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2151 );
2152 #return_ident
2153 }
2154 };
2155
2156 expanded.into()
2157}
2158
/// Checks that `s` is a valid singleton label.
///
/// A valid label:
/// - is 1 to 63 bytes long;
/// - starts with an ASCII lowercase letter;
/// - ends with an ASCII lowercase letter or digit;
/// - contains only ASCII lowercase letters, digits, and `-`.
///
/// Rules are checked in that order, and the first violation is returned as a
/// human-readable message.
fn validate_label(s: &str) -> Result<(), String> {
    let bytes = s.as_bytes();
    let (Some(&head), Some(&tail)) = (bytes.first(), bytes.last()) else {
        return Err("label must not be empty".to_string());
    };

    if bytes.len() > 63 {
        return Err("label exceeds 63 characters".to_string());
    }
    if !head.is_ascii_lowercase() {
        return Err("label must start with a lowercase letter".to_string());
    }
    if !(tail.is_ascii_lowercase() || tail.is_ascii_digit()) {
        return Err("label must end with a lowercase letter or digit".to_string());
    }

    let offending = s
        .chars()
        .find(|c| !(c.is_ascii_lowercase() || c.is_ascii_digit() || *c == '-'));
    match offending {
        Some(ch) => Err(format!("label contains invalid character '{ch}'")),
        None => Ok(()),
    }
}
2181
/// Parses `s` as a hexadecimal instance uid.
///
/// `s` must be 1 to 16 bytes long and consist only of ASCII hex digits
/// (either case). Returns the parsed value, or a human-readable message
/// describing the first violated rule.
fn validate_hex_uid(s: &str) -> Result<u64, String> {
    if !(1..=16).contains(&s.len()) {
        return Err(format!("hex uid must be 1-16 hex characters, got '{s}'"));
    }
    if let Some(bad) = s.chars().find(|c| !c.is_ascii_hexdigit()) {
        return Err(format!("hex uid contains invalid character '{bad}'"));
    }
    u64::from_str_radix(s, 16).map_err(|e| format!("invalid hex uid '{s}': {e}"))
}
2193
2194#[proc_macro]
2200pub fn uid(input: TokenStream) -> TokenStream {
2201 let input2: proc_macro2::TokenStream = input.into();
2202 let combined: String = input2.into_iter().map(|tt| tt.to_string()).collect();
2203
2204 if combined.is_empty() {
2205 return TokenStream::from(quote! { compile_error!("uid! macro requires an argument") });
2206 }
2207
2208 if let Some(rest) = combined.strip_prefix('_') {
2210 return match validate_label(rest) {
2211 Ok(()) => TokenStream::from(quote! {
2212 hyperactor::id::Uid::Singleton(
2213 hyperactor::id::Label::new(#rest).unwrap()
2214 )
2215 }),
2216 Err(e) => {
2217 let msg = format!("invalid singleton uid: {e}");
2218 TokenStream::from(quote! { compile_error!(#msg) })
2219 }
2220 };
2221 }
2222
2223 match validate_hex_uid(&combined) {
2225 Ok(uid_val) => TokenStream::from(quote! {
2226 hyperactor::id::Uid::Instance(#uid_val)
2227 }),
2228 Err(e) => {
2229 let msg = format!("invalid uid: {e}");
2230 TokenStream::from(quote! { compile_error!(#msg) })
2231 }
2232 }
2233}