1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::WherePredicate;
41use syn::bracketed;
42use syn::parse::Parse;
43use syn::parse::ParseStream;
44use syn::parse_macro_input;
45use syn::punctuated::Punctuated;
46use syn::spanned::Spanned;
47
// Diagnostic emitted by `Message::new` when the single `#[reply]` field is not a path
// type ending in `OncePortRef`, `PortRef`, `OncePortHandle`, or `PortHandle` with a
// leading type argument.
const REPLY_VARIANT_ERROR: &str = indoc! {r#"
`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position

= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
"#};
54
// Diagnostic emitted by `Message::new` when more than one field of a message is
// annotated with `#[reply]`.
const REPLY_USAGE_ERROR: &str = indoc! {r#"
`call` message expects at most one `reply` argument

= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
"#};
61
/// Per-field marker computed from a field's attributes by `parse_field_flag`.
enum FieldFlag {
    /// The field carries no recognized marker attribute.
    None,
    /// The field is annotated with a bare `#[reply]`; `Message::new` treats it as the
    /// reply port of a `call` message.
    Reply,
}
66
/// The shape of one message: an enum variant, or the whole struct when the derive
/// input is a struct.
#[allow(dead_code)]
enum Variant {
    /// A variant (or struct) with named fields.
    Named {
        /// Name of the enclosing enum, or of the struct itself.
        enum_name: Ident,
        /// Variant name; for structs this is the struct name again.
        name: Ident,
        /// Field identifiers in declaration order.
        field_names: Vec<Ident>,
        /// Field types, parallel to `field_names`.
        field_types: Vec<Type>,
        /// Field markers, parallel to `field_types`.
        field_flags: Vec<FieldFlag>,
        /// True when the derive input is a struct; `qualified_name` then omits the
        /// variant segment.
        is_struct: bool,
        /// Generics of the derive input.
        generics: syn::Generics,
    },
    /// A tuple variant (or tuple/unit struct). Fields are referred to as `arg0`,
    /// `arg1`, ... in generated code.
    Anon {
        /// Name of the enclosing enum, or of the struct itself.
        enum_name: Ident,
        /// Variant name; for structs this is the struct name again.
        name: Ident,
        /// Field types in declaration order.
        field_types: Vec<Type>,
        /// Field markers, parallel to `field_types`.
        field_flags: Vec<FieldFlag>,
        /// True when the derive input is a struct; `qualified_name` then omits the
        /// variant segment.
        is_struct: bool,
        /// Generics of the derive input.
        generics: syn::Generics,
    },
}
90
91impl Variant {
92 fn len(&self) -> usize {
94 self.field_types().len()
95 }
96
97 fn is_struct(&self) -> bool {
99 match self {
100 Variant::Named { is_struct, .. } => *is_struct,
101 Variant::Anon { is_struct, .. } => *is_struct,
102 }
103 }
104
105 fn enum_name(&self) -> &Ident {
107 match self {
108 Variant::Named { enum_name, .. } => enum_name,
109 Variant::Anon { enum_name, .. } => enum_name,
110 }
111 }
112
113 fn name(&self) -> &Ident {
115 match self {
116 Variant::Named { name, .. } => name,
117 Variant::Anon { name, .. } => name,
118 }
119 }
120
121 #[allow(dead_code)]
123 fn generics(&self) -> &syn::Generics {
124 match self {
125 Variant::Named { generics, .. } => generics,
126 Variant::Anon { generics, .. } => generics,
127 }
128 }
129
130 fn snake_name(&self) -> Ident {
132 Ident::new(
133 &self.name().to_string().to_case(Case::Snake),
134 self.name().span(),
135 )
136 }
137
138 fn qualified_name(&self) -> proc_macro2::TokenStream {
140 let enum_name = self.enum_name();
141 let name = self.name();
142
143 if self.is_struct() {
144 quote! { #enum_name }
145 } else {
146 quote! { #enum_name::#name }
147 }
148 }
149
150 fn field_names(&self) -> Vec<Ident> {
153 match self {
154 Variant::Named { field_names, .. } => field_names.clone(),
155 Variant::Anon { field_types, .. } => (0usize..field_types.len())
156 .map(|idx| format_ident!("arg{}", idx))
157 .collect(),
158 }
159 }
160
161 fn field_types(&self) -> &Vec<Type> {
163 match self {
164 Variant::Named { field_types, .. } => field_types,
165 Variant::Anon { field_types, .. } => field_types,
166 }
167 }
168
169 fn field_flags(&self) -> &Vec<FieldFlag> {
171 match self {
172 Variant::Named { field_flags, .. } => field_flags,
173 Variant::Anon { field_flags, .. } => field_flags,
174 }
175 }
176
177 fn constructor(&self) -> proc_macro2::TokenStream {
179 let qualified_name = self.qualified_name();
180 let field_names = self.field_names();
181 match self {
182 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
183 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
184 }
185 }
186}
187
/// Classification of a `#[reply]` field's port type, computed from the last segment
/// of its type path by `ReplyPort::from_last_segment`.
struct ReplyPort {
    /// True for `PortHandle` / `OncePortHandle` (as opposed to `PortRef` / `OncePortRef`).
    is_handle: bool,
    /// True for `OncePortRef` / `OncePortHandle`; such ports are opened with
    /// `hyperactor::mailbox::open_once_port`.
    is_once: bool,
}
192
193impl ReplyPort {
194 fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
195 ReplyPort {
196 is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
197 is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
198 }
199 }
200
201 fn open_op(&self) -> proc_macro2::TokenStream {
202 if self.is_once {
203 quote! { hyperactor::mailbox::open_once_port }
204 } else {
205 quote! { hyperactor::mailbox::open_port }
206 }
207 }
208
209 fn rx_modifier(&self) -> proc_macro2::TokenStream {
210 if self.is_once {
211 quote! {}
212 } else {
213 quote! { mut }
214 }
215 }
216}
217
/// A parsed message, built by `Message::new` from a `Variant`.
#[allow(clippy::large_enum_variant)]
enum Message {
    /// A message with exactly one `#[reply]` field; the client awaits a reply.
    Call {
        /// The message's shape.
        variant: Variant,
        /// Kind of the reply port (handle vs. ref, one-shot or not).
        reply_port: ReplyPort,
        /// The reply port's first type argument, i.e. the type of the reply value.
        return_type: Type,
        /// Per-message `#[log_level(...)]` override as an uppercase identifier, if any.
        log_level: Option<Ident>,
    },
    /// A message without a `#[reply]` field; the client posts it and returns `Ok(())`.
    OneWay {
        /// The message's shape.
        variant: Variant,
        /// Per-message `#[log_level(...)]` override as an uppercase identifier, if any.
        log_level: Option<Ident>,
    },
}
239
240impl Message {
241 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
242 match &variant
243 .field_flags()
244 .iter()
245 .zip(variant.field_types())
246 .filter_map(|(flag, ty)| match flag {
247 FieldFlag::Reply => Some(ty),
248 FieldFlag::None => None,
249 })
250 .collect::<Vec<&Type>>()[..]
251 {
252 [] => Ok(Self::OneWay { variant, log_level }),
253 [reply_port_ty] => {
254 let syn::Type::Path(type_path) = reply_port_ty else {
255 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
256 };
257 let Some(last_segment) = type_path.path.segments.last() else {
258 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
259 };
260 if last_segment.ident != "OncePortRef"
261 && last_segment.ident != "OncePortHandle"
262 && last_segment.ident != "PortRef"
263 && last_segment.ident != "PortHandle"
264 {
265 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
266 }
267 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
268 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
269 };
270 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
271 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
272 };
273 let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
274 let return_type = return_ty.clone();
275 Ok(Self::Call {
276 variant,
277 reply_port,
278 return_type,
279 log_level,
280 })
281 }
282 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
283 }
284 }
285
286 fn args(&self) -> Vec<(Ident, Type)> {
288 match self {
289 Message::Call { variant, .. } => variant
290 .field_names()
291 .into_iter()
292 .zip(variant.field_types().clone())
293 .take(variant.len() - 1)
294 .collect(),
295 Message::OneWay { variant, .. } => variant
296 .field_names()
297 .into_iter()
298 .zip(variant.field_types().clone())
299 .collect(),
300 }
301 }
302
303 fn variant(&self) -> &Variant {
304 match self {
305 Message::Call { variant, .. } => variant,
306 Message::OneWay { variant, .. } => variant,
307 }
308 }
309
310 fn reply_port_position(&self) -> Option<usize> {
311 self.variant()
312 .field_flags()
313 .iter()
314 .position(|flag| matches!(flag, FieldFlag::Reply))
315 }
316
317 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
319 match self {
320 Message::Call { variant, .. } => {
321 let pos = self.reply_port_position()?;
322 Some((
323 variant.field_names()[pos].clone(),
324 variant.field_types()[pos].clone(),
325 ))
326 }
327 Message::OneWay { .. } => None,
328 }
329 }
330}
331
332fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
333 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
334 Some(attr) => {
335 let Ok(meta) = attr.meta.require_list() else {
336 return Err(syn::Error::new(
337 Span::call_site(),
338 indoc! {"
339 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
340
341 = help use `#[log_level(info)]` or `#[log_level(error)]`
342 "},
343 ));
344 };
345 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
346 if parsed.len() != 1 {
347 return Err(syn::Error::new(
348 Span::call_site(),
349 indoc! {"
350 `log_level` attribute must specify exactly one level
351
352 = help use `#[log_level(warn)]` or `#[log_level(info)]`
353 "},
354 ));
355 };
356 Some(parsed.first().unwrap().to_string())
357 }
358 None => None,
359 };
360
361 if level.is_none() {
362 return Ok(None);
363 }
364 let level = level.unwrap();
365
366 match level.as_str() {
367 "error" | "warn" | "info" | "debug" | "trace" => {}
368 _ => {
369 return Err(syn::Error::new(
370 Span::call_site(),
371 indoc! {"
372 `log_level` attribute must be one of 'error, warn, info, debug, trace'
373
374 = help use `#[log_level(warn)]` or `#[log_level(info)]`
375 "},
376 ));
377 }
378 }
379
380 Ok(Some(Ident::new(
381 level.to_ascii_uppercase().as_str(),
382 Span::call_site(),
383 )))
384}
385
386fn parse_field_flag(field: &Field) -> FieldFlag {
387 for attr in field.attrs.iter() {
388 match &attr.meta {
389 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
390 _ => {}
391 }
392 }
393 FieldFlag::None
394}
395
396fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
398 match &input.data {
399 Data::Enum(data_enum) => {
400 let mut messages = Vec::new();
401
402 for variant in &data_enum.variants {
403 let name = variant.ident.clone();
404 let attrs = &variant.attrs;
405
406 let message_variant = match &variant.fields {
407 syn::Fields::Unnamed(fields_) => Variant::Anon {
408 enum_name: input.ident.clone(),
409 name,
410 field_types: fields_
411 .unnamed
412 .iter()
413 .map(|field| field.ty.clone())
414 .collect(),
415 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
416 is_struct: false,
417 generics: input.generics.clone(),
418 },
419 syn::Fields::Named(fields_) => Variant::Named {
420 enum_name: input.ident.clone(),
421 name,
422 field_names: fields_
423 .named
424 .iter()
425 .map(|field| field.ident.clone().unwrap())
426 .collect(),
427 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
428 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
429 is_struct: false,
430 generics: input.generics.clone(),
431 },
432 _ => {
433 return Err(syn::Error::new_spanned(
434 variant,
435 indoc! {r#"
436 `Handler` currently only supports named or tuple struct variants
437
438 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
439 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
440 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
441 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
442 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
443 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
444 "#},
445 ));
446 }
447 };
448 let log_level = parse_log_level(attrs)?;
449
450 messages.push(Message::new(
451 variant.fields.span(),
452 message_variant,
453 log_level,
454 )?);
455 }
456
457 Ok(messages)
458 }
459 Data::Struct(data_struct) => {
460 let struct_name = input.ident.clone();
461 let attrs = &input.attrs;
462
463 let message_variant = match &data_struct.fields {
464 syn::Fields::Unnamed(fields_) => Variant::Anon {
465 enum_name: struct_name.clone(),
466 name: struct_name,
467 field_types: fields_
468 .unnamed
469 .iter()
470 .map(|field| field.ty.clone())
471 .collect(),
472 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
473 is_struct: true,
474 generics: input.generics.clone(),
475 },
476 syn::Fields::Named(fields_) => Variant::Named {
477 enum_name: struct_name.clone(),
478 name: struct_name,
479 field_names: fields_
480 .named
481 .iter()
482 .map(|field| field.ident.clone().unwrap())
483 .collect(),
484 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
485 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
486 is_struct: true,
487 generics: input.generics.clone(),
488 },
489 syn::Fields::Unit => Variant::Anon {
490 enum_name: struct_name.clone(),
491 name: struct_name,
492 field_types: Vec::new(),
493 field_flags: Vec::new(),
494 is_struct: true,
495 generics: input.generics.clone(),
496 },
497 };
498
499 let log_level = parse_log_level(attrs)?;
500 let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
501
502 Ok(vec![message])
503 }
504 _ => Err(syn::Error::new_spanned(
505 input,
506 "handlers can only be derived for enums and structs",
507 )),
508 }
509}
510
511#[proc_macro_derive(Handler, attributes(reply))]
666pub fn derive_handler(input: TokenStream) -> TokenStream {
667 let input = parse_macro_input!(input as DeriveInput);
668 let name: Ident = input.ident.clone();
669 let (_, ty_generics, _) = input.generics.split_for_impl();
670
671 let messages = match parse_messages(input.clone()) {
672 Ok(messages) => messages,
673 Err(err) => return TokenStream::from(err.to_compile_error()),
674 };
675
676 let mut handler_trait_methods = Vec::new();
678
679 let mut match_arms = Vec::new();
681
682 let mut client_trait_methods = Vec::new();
684
685 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
686
687 for message in &messages {
688 match message {
689 Message::Call {
690 variant,
691 reply_port,
692 return_type,
693 log_level,
694 } => {
695 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
696 let variant_name_snake = variant.snake_name();
697 let variant_name_snake_deprecated =
698 format_ident!("{}_deprecated", variant_name_snake);
699 let enum_name = variant.enum_name();
700 let _variant_qualified_name = variant.qualified_name();
701 let log_level = match (&global_log_level, log_level) {
702 (_, Some(local)) => local.clone(),
703 (Some(global), None) => global.clone(),
704 _ => Ident::new("DEBUG", Span::call_site()),
705 };
706 let _log_level = if reply_port.is_handle {
707 quote! {
708 tracing::Level::#log_level
709 }
710 } else {
711 quote! {
712 tracing::Level::TRACE
713 }
714 };
715 let log_message = quote! {
716 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
717 "rpc" => "call",
718 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
719 "message_type" => stringify!(#enum_name),
720 "variant" => stringify!(#variant_name_snake),
721 ));
722 };
723
724 handler_trait_methods.push(quote! {
725 #[doc = "The generated handler method for this enum variant."]
726 async fn #variant_name_snake(
727 &mut self,
728 cx: &hyperactor::Context<Self>,
729 #(#arg_names: #arg_types),*)
730 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
731 });
732
733 client_trait_methods.push(quote! {
734 #[doc = "The generated client method for this enum variant."]
735 async fn #variant_name_snake(
736 &self,
737 cx: &impl hyperactor::context::Actor,
738 #(#arg_names: #arg_types),*)
739 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
740
741 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
742 async fn #variant_name_snake_deprecated(
743 &self,
744 cx: &impl hyperactor::context::Actor,
745 #(#arg_names: #arg_types),*)
746 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
747 });
748
749 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
750 let constructor = variant.constructor();
751 let result_ident = Ident::new("result", Span::mixed_site());
752 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
753 match_arms.push(quote! {
754 #constructor => {
755 #log_message
756 #construct_result_future
759 use hyperactor::Endpoint as _;
760 #reply_port_arg.post(cx, #result_ident);
761 Ok(())
762 }
763 });
764 }
765 Message::OneWay { variant, log_level } => {
766 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
767 let variant_name_snake = variant.snake_name();
768 let variant_name_snake_deprecated =
769 format_ident!("{}_deprecated", variant_name_snake);
770 let enum_name = variant.enum_name();
771 let log_level = match (&global_log_level, log_level) {
772 (_, Some(local)) => local.clone(),
773 (Some(global), None) => global.clone(),
774 _ => Ident::new("TRACE", Span::call_site()),
775 };
776 let _log_level = quote! {
777 tracing::Level::#log_level
778 };
779 let log_message = quote! {
780 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
781 "rpc" => "call",
782 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
783 "message_type" => stringify!(#enum_name),
784 "variant" => stringify!(#variant_name_snake),
785 ));
786 };
787
788 handler_trait_methods.push(quote! {
789 #[doc = "The generated handler method for this enum variant."]
790 async fn #variant_name_snake(
791 &mut self,
792 cx: &hyperactor::Context<Self>,
793 #(#arg_names: #arg_types),*)
794 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
795 });
796
797 client_trait_methods.push(quote! {
798 #[doc = "The generated client method for this enum variant."]
799 async fn #variant_name_snake(
800 &self,
801 cx: &impl hyperactor::context::Actor,
802 #(#arg_names: #arg_types),*)
803 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
804
805 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
806 async fn #variant_name_snake_deprecated(
807 &self,
808 cx: &impl hyperactor::context::Actor,
809 #(#arg_names: #arg_types),*)
810 -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
811 });
812
813 let constructor = variant.constructor();
814
815 match_arms.push(quote! {
816 #constructor => {
817 #log_message
818 self.#variant_name_snake(cx, #(#arg_names),*).await
819 },
820 });
821 }
822 }
823 }
824
825 let handler_trait_name = format_ident!("{}Handler", name);
826 let client_trait_name = format_ident!("{}Client", name);
827
828 let mut handler_generics = input.generics.clone();
832 for param in handler_generics.type_params_mut() {
833 param.bounds.push(syn::parse_quote!(serde::Serialize));
834 param
835 .bounds
836 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
837 param.bounds.push(syn::parse_quote!(Send));
838 param.bounds.push(syn::parse_quote!(Sync));
839 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
840 param.bounds.push(syn::parse_quote!(typeuri::Named));
841 }
842 let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
843 let (client_impl_generics, _, _) = input.generics.split_for_impl();
844
845 let expanded = quote! {
846 #[doc = "The custom handler trait for this message type."]
847 #[hyperactor::internal_macro_support::async_trait::async_trait]
848 pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync {
849 #(#handler_trait_methods)*
850
851 #[doc = "Handle the next message."]
852 async fn handle(
853 &mut self,
854 cx: &hyperactor::Context<Self>,
855 message: #name #ty_generics,
856 ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
857 match message {
859 #(#match_arms)*
860 }
861 }
862 }
863
864 #[doc = "The custom client trait for this message type."]
865 #[hyperactor::internal_macro_support::async_trait::async_trait]
866 pub trait #client_trait_name #client_impl_generics: Send + Sync {
867 #(#client_trait_methods)*
868 }
869 };
870
871 TokenStream::from(expanded)
872}
873
874#[proc_macro_derive(HandleClient, attributes(log_level))]
877pub fn derive_handle_client(input: TokenStream) -> TokenStream {
878 derive_client(input, true)
879}
880
881#[proc_macro_derive(RefClient, attributes(log_level))]
884pub fn derive_ref_client(input: TokenStream) -> TokenStream {
885 derive_client(input, false)
886}
887
888fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
889 let input = parse_macro_input!(input as DeriveInput);
890 let name = input.ident.clone();
891
892 let messages = match parse_messages(input.clone()) {
893 Ok(messages) => messages,
894 Err(err) => return TokenStream::from(err.to_compile_error()),
895 };
896
897 let mut impl_methods = Vec::new();
899
900 let send_message = quote! { hyperactor::Endpoint::post(self, cx, message); };
901 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
902
903 for message in &messages {
904 match message {
905 Message::Call {
906 variant,
907 reply_port,
908 return_type,
909 log_level,
910 } => {
911 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
912 let variant_name_snake = variant.snake_name();
913 let variant_name_snake_deprecated =
914 format_ident!("{}_deprecated", variant_name_snake);
915 let enum_name = variant.enum_name();
916
917 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
918 let constructor = variant.constructor();
919 let log_level = match (&global_log_level, log_level) {
920 (_, Some(local)) => local.clone(),
921 (Some(global), None) => global.clone(),
922 _ => Ident::new("DEBUG", Span::call_site()),
923 };
924 let log_level = if is_handle {
925 quote! {
926 tracing::Level::#log_level
927 }
928 } else {
929 quote! {
930 tracing::Level::TRACE
931 }
932 };
933 let log_message = quote! {
934 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
935 "rpc" => "call",
936 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
937 "message_type" => stringify!(#enum_name),
938 "variant" => stringify!(#variant_name_snake),
939 ));
940
941 };
942 let open_port = reply_port.open_op();
943 let rx_mod = reply_port.rx_modifier();
944 if reply_port.is_handle {
945 impl_methods.push(quote! {
946 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
947 async fn #variant_name_snake(
948 &self,
949 cx: &impl hyperactor::context::Actor,
950 #(#arg_names: #arg_types),*)
951 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
952 let (#reply_port_arg, #rx_mod reply_receiver) =
953 #open_port::<#return_type>(cx);
954 let message = #constructor;
955 #log_message;
956 #send_message;
957 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
958 }
959
960 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
961 async fn #variant_name_snake_deprecated(
962 &self,
963 cx: &impl hyperactor::context::Actor,
964 #(#arg_names: #arg_types),*)
965 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
966 let (#reply_port_arg, #rx_mod reply_receiver) =
967 #open_port::<#return_type>(cx);
968 let message = #constructor;
969 #log_message;
970 #send_message;
971 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
972 }
973 });
974 } else {
975 impl_methods.push(quote! {
976 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
977 async fn #variant_name_snake(
978 &self,
979 cx: &impl hyperactor::context::Actor,
980 #(#arg_names: #arg_types),*)
981 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
982 let (#reply_port_arg, #rx_mod reply_receiver) =
983 #open_port::<#return_type>(cx);
984 let #reply_port_arg = #reply_port_arg.bind();
985 let message = #constructor;
986 #log_message;
987 #send_message;
988 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
989 }
990
991 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
992 async fn #variant_name_snake_deprecated(
993 &self,
994 cx: &impl hyperactor::context::Actor,
995 #(#arg_names: #arg_types),*)
996 -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
997 let (#reply_port_arg, #rx_mod reply_receiver) =
998 #open_port::<#return_type>(cx);
999 let #reply_port_arg = #reply_port_arg.bind();
1000 let message = #constructor;
1001 #log_message;
1002 #send_message;
1003 reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
1004 }
1005 });
1006 }
1007 }
1008 Message::OneWay { variant, log_level } => {
1009 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1010 let variant_name_snake = variant.snake_name();
1011 let variant_name_snake_deprecated =
1012 format_ident!("{}_deprecated", variant_name_snake);
1013 let enum_name = variant.enum_name();
1014 let constructor = variant.constructor();
1015 let log_level = match (&global_log_level, log_level) {
1016 (_, Some(local)) => local.clone(),
1017 (Some(global), None) => global.clone(),
1018 _ => Ident::new("DEBUG", Span::call_site()),
1019 };
1020 let _log_level = if is_handle {
1021 quote! {
1022 tracing::Level::TRACE
1023 }
1024 } else {
1025 quote! {
1026 tracing::Level::#log_level
1027 }
1028 };
1029 let log_message = quote! {
1030 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1031 "rpc" => "oneway",
1032 "actor_id" => self.actor_addr().to_string(),
1033 "message_type" => stringify!(#enum_name),
1034 "variant" => stringify!(#variant_name_snake),
1035 ));
1036 };
1037 impl_methods.push(quote! {
1038 async fn #variant_name_snake(
1039 &self,
1040 cx: &impl hyperactor::context::Actor,
1041 #(#arg_names: #arg_types),*)
1042 -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1043 let message = #constructor;
1044 #log_message;
1045 #send_message;
1046 Ok(())
1047 }
1048
1049 async fn #variant_name_snake_deprecated(
1050 &self,
1051 cx: &impl hyperactor::context::Actor,
1052 #(#arg_names: #arg_types),*)
1053 -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1054 let message = #constructor;
1055 #log_message;
1056 #send_message;
1057 Ok(())
1058 }
1059 });
1060 }
1061 }
1062 }
1063
1064 let trait_name = format_ident!("{}Client", name);
1065
1066 let (_, ty_generics, _) = input.generics.split_for_impl();
1067
1068 let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1070 let mut trait_generics = input.generics.clone();
1071 trait_generics.params.insert(
1072 0,
1073 syn::GenericParam::Type(syn::TypeParam {
1074 ident: actor_ident.clone(),
1075 attrs: vec![],
1076 colon_token: None,
1077 bounds: Punctuated::new(),
1078 eq_token: None,
1079 default: None,
1080 }),
1081 );
1082
1083 for param in trait_generics.type_params_mut() {
1084 if param.ident == actor_ident {
1085 continue;
1086 }
1087 param.bounds.push(syn::parse_quote!(serde::Serialize));
1088 param
1089 .bounds
1090 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1091 param.bounds.push(syn::parse_quote!(Send));
1092 param.bounds.push(syn::parse_quote!(Sync));
1093 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1094 param.bounds.push(syn::parse_quote!(typeuri::Named));
1095 }
1096
1097 let (impl_generics, _, _) = trait_generics.split_for_impl();
1098
1099 let expanded = if is_handle {
1100 quote! {
1101 #[hyperactor::internal_macro_support::async_trait::async_trait]
1102 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1103 where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1104 #(#impl_methods)*
1105 }
1106 }
1107 } else {
1108 quote! {
1109 #[hyperactor::internal_macro_support::async_trait::async_trait]
1110 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1111 where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1112 #(#impl_methods)*
1113 }
1114 }
1115 };
1116
1117 TokenStream::from(expanded)
1118}
1119
// Diagnostic emitted by the `#[handle(...)]` attribute when it is not given exactly
// one argument.
const HANDLE_ARGUMENT_ERROR: &str = indoc! {r#"
`handle` expects the message type that is being handled

= help: use `#[handle(MessageType)]`
"#};
1125
1126#[proc_macro_attribute]
1128pub fn handle(attr: TokenStream, item: TokenStream) -> TokenStream {
1129 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1130 if attr_args.len() != 1 {
1131 return TokenStream::from(
1132 syn::Error::new_spanned(attr_args, HANDLE_ARGUMENT_ERROR).to_compile_error(),
1133 );
1134 }
1135
1136 let message_type = attr_args.first().unwrap();
1137 let input = parse_macro_input!(item as ItemImpl);
1138
1139 let self_type = match *input.self_ty {
1140 syn::Type::Path(ref type_path) => {
1141 let segment = type_path.path.segments.last().unwrap();
1142 segment.clone() }
1144 _ => {
1145 return TokenStream::from(
1146 syn::Error::new_spanned(input.self_ty, "`handle` argument must be a type")
1147 .to_compile_error(),
1148 );
1149 }
1150 };
1151
1152 let trait_name = match input.trait_ {
1153 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1154 None => {
1155 return TokenStream::from(
1156 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1157 .to_compile_error(),
1158 );
1159 }
1160 };
1161
1162 let expanded = quote! {
1163 #input
1164
1165 #[hyperactor::internal_macro_support::async_trait::async_trait]
1166 impl hyperactor::Handler<#message_type> for #self_type {
1167 async fn handle(
1168 &mut self,
1169 cx: &hyperactor::Context<Self>,
1170 message: #message_type,
1171 ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
1172 <Self as #trait_name>::handle(self, cx, message).await
1173 }
1174 }
1175 };
1176
1177 TokenStream::from(expanded)
1178}
1179
1180#[proc_macro_attribute]
1193pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1194 let args =
1195 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1196 let input = parse_macro_input!(input as ItemFn);
1197 let output = quote! {
1198 #[hyperactor::internal_macro_support::tracing::instrument(err, skip_all, #args)]
1199 #input
1200 };
1201
1202 TokenStream::from(output)
1203}
1204
1205#[proc_macro_attribute]
1216pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1217 let args =
1218 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1219 let input = parse_macro_input!(input as ItemFn);
1220
1221 let output = quote! {
1222 #[hyperactor::internal_macro_support::tracing::instrument(skip_all, #args)]
1223 #input
1224 };
1225
1226 TokenStream::from(output)
1227}
1228
/// One entry in a list of handled message types: `Type` or `Type { cast = <bool> }`.
struct HandlerSpec {
    /// The handled message type.
    ty: Type,
    /// When true, `HandlerSpec::add_indexed` also emits
    /// `hyperactor::message::IndexedErasedUnbound<ty>`.
    cast: bool,
}
1233
1234impl Parse for HandlerSpec {
1235 fn parse(input: ParseStream) -> syn::Result<Self> {
1236 let ty: Type = input.parse()?;
1237
1238 if input.peek(syn::token::Brace) {
1239 let content;
1240 syn::braced!(content in input);
1241 let key: Ident = content.parse()?;
1242 content.parse::<Token![=]>()?;
1243 let expr: Expr = content.parse()?;
1244
1245 let cast = if key == "cast" {
1246 if let Expr::Lit(ExprLit {
1247 lit: Lit::Bool(b), ..
1248 }) = expr
1249 {
1250 b.value
1251 } else {
1252 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1253 }
1254 } else {
1255 return Err(syn::Error::new_spanned(
1256 key,
1257 "unsupported field (expected `cast`)",
1258 ));
1259 };
1260
1261 Ok(HandlerSpec { ty, cast })
1262 } else if input.is_empty() || input.peek(Token![,]) {
1263 Ok(HandlerSpec { ty, cast: false })
1264 } else {
1265 let unexpected: proc_macro2::TokenTree = input.parse()?;
1267 Err(syn::Error::new_spanned(
1268 unexpected,
1269 "unexpected token after type — expected `{ ... }` or nothing",
1270 ))
1271 }
1272 }
1273}
1274
1275impl HandlerSpec {
1276 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1277 let mut tys = Vec::new();
1278 for HandlerSpec { ty, cast } in handlers {
1279 if cast {
1280 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1281 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1282 tys.push(wrapped_ty);
1283 }
1284 tys.push(ty);
1285 }
1286 tys
1287 }
1288}
1289
1290fn named_impl(data_type_name: &Ident, generics: &syn::Generics) -> proc_macro2::TokenStream {
1291 let generics_with_bounds = generics_with_named_bounds(generics);
1292 let type_params: Vec<_> = generics.type_params().collect();
1293 let has_generics = !type_params.is_empty();
1294
1295 let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1296 let (_, ty_generics, where_clause) = generics.split_for_impl();
1297
1298 let (typename_impl, typehash_impl) = if has_generics {
1299 let placeholders = vec!["{}"; type_params.len()].join(", ");
1300 let placeholders_format_string = format!("<{}>", placeholders);
1301 let format_string = quote! {
1302 concat!(
1303 std::module_path!(),
1304 "::",
1305 stringify!(#data_type_name),
1306 #placeholders_format_string
1307 )
1308 };
1309 let type_param_idents: Vec<_> = type_params.iter().map(|param| ¶m.ident).collect();
1310 (
1311 quote! {
1312 typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1313 },
1314 quote! {
1315 typeuri::cityhasher::hash(Self::typename())
1316 },
1317 )
1318 } else {
1319 (
1320 quote! {
1321 concat!(std::module_path!(), "::", stringify!(#data_type_name))
1322 },
1323 quote! {
1324 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1325 typeuri::cityhasher::hash(<#data_type_name as typeuri::Named>::typename())
1326 });
1327 *TYPEHASH
1328 },
1329 )
1330 };
1331
1332 quote! {
1333 impl #impl_generics_with_bounds typeuri::Named for #data_type_name #ty_generics #where_clause {
1334 fn typename() -> &'static str {
1335 #typename_impl
1336 }
1337
1338 fn typehash() -> u64 {
1339 #typehash_impl
1340 }
1341 }
1342 }
1343}
1344
1345fn generics_with_named_bounds(generics: &syn::Generics) -> syn::Generics {
1346 let mut generics = generics.clone();
1347 for param in generics.type_params_mut() {
1348 param.bounds.push(syn::parse_quote!(typeuri::Named));
1349 }
1350 generics
1351}
1352
1353fn generics_with_predicates(
1354 generics: &syn::Generics,
1355 predicates: impl IntoIterator<Item = WherePredicate>,
1356) -> syn::Generics {
1357 let mut generics = generics.clone();
1358 generics.make_where_clause().predicates.extend(predicates);
1359 generics
1360}
1361
1362struct ExportAttr {
1364 handlers: Vec<HandlerSpec>,
1365}
1366
1367impl Parse for ExportAttr {
1368 fn parse(input: ParseStream) -> syn::Result<Self> {
1369 if input.is_empty() {
1370 return Ok(Self {
1371 handlers: Vec::new(),
1372 });
1373 }
1374
1375 let compatibility_form = {
1376 let fork = input.fork();
1377 fork.parse::<Ident>().is_ok() && fork.parse::<Token![=]>().is_ok()
1378 };
1379
1380 if !compatibility_form {
1381 let handlers = input
1382 .parse_terminated(HandlerSpec::parse, Token![,])?
1383 .into_iter()
1384 .collect();
1385 return Ok(Self { handlers });
1386 }
1387
1388 let mut handlers: Vec<HandlerSpec> = vec![];
1389
1390 while !input.is_empty() {
1391 let key: Ident = input.parse()?;
1392 input.parse::<Token![=]>()?;
1393
1394 if key == "spawn" {
1395 let expr: Expr = input.parse()?;
1396 return Err(syn::Error::new_spanned(
1397 expr,
1398 "`spawn = true` is no longer supported; use `#[spawnable]` on concrete actor declarations or `hyperactor::register_spawnable!(ConcreteType)` for generic instantiations",
1399 ));
1400 } else if key == "handlers" {
1401 let content;
1402 bracketed!(content in input);
1403 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1404 handlers = raw_handlers.into_iter().collect();
1405 } else {
1406 return Err(syn::Error::new_spanned(
1407 key,
1408 "unexpected key in `#[export(...)]`. Use direct handler lists, or the compatibility key `handlers`",
1409 ));
1410 }
1411
1412 let _ = input.parse::<Token![,]>();
1414 }
1415
1416 Ok(ExportAttr { handlers })
1417 }
1418}
1419
#[proc_macro_attribute]
/// Exports an actor type for remote use.
///
/// Handlers are listed either directly, as
/// `#[export(Msg1, Msg2 { cast = true })]`, or in the compatibility form
/// `#[export(handlers = [Msg1, Msg2 { cast = true }])]`. `spawn = ...` is
/// rejected with a migration hint.
///
/// The annotated item is re-emitted unchanged, followed by:
/// - `hyperactor::actor::Referable` and a `typeuri::Named` impl;
/// - `hyperactor::actor::RemoteHandles<M>` and `hyperactor::remote::Accepts<M>`
///   for each listed message `M`;
/// - the same pair for `hyperactor::message::IndexedErasedUnbound<M>` when `M`
///   is marked `cast = true`;
/// - the same pair for `hyperactor::introspect::IntrospectMessage`;
/// - `hyperactor::actor::Binds<Self>`, which binds a port for each of those
///   message types.
///
/// All generated impls add a `typeuri::Named` bound to every type parameter.
pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
    let data_type_name = &input.ident;
    let (_, ty_generics, _) = input.generics.split_for_impl();
    // Every generated impl requires `typeuri::Named` on each type parameter.
    let named_generics = generics_with_named_bounds(&input.generics);
    let (named_impl_generics, named_ty_generics, named_where_clause) =
        named_generics.split_for_impl();

    let ExportAttr { handlers } = parse_macro_input!(attr as ExportAttr);

    // Per-message `RemoteHandles`/`Accepts` impls.
    let mut handles = Vec::new();
    // `ports.bind::<M>()` statements for the `Binds` impl.
    let mut bindings = Vec::new();
    // Where-clause predicates the `Binds` impl needs for the statements above.
    let mut bind_predicates = Vec::new();
    let actor_ty: Type = syn::parse_quote!(#data_type_name #ty_generics);

    for HandlerSpec { ty, cast } in &handlers {
        // Each message impl is conditioned on the message being a
        // `RemoteMessage`, so it still works when `ty` mentions type parameters.
        let message_generics = generics_with_predicates(
            &named_generics,
            [syn::parse_quote!(#ty: hyperactor::RemoteMessage)],
        );
        let (message_impl_generics, message_ty_generics, message_where_clause) =
            message_generics.split_for_impl();
        handles.push(quote! {
            impl #message_impl_generics hyperactor::actor::RemoteHandles<#ty>
                for #data_type_name #message_ty_generics #message_where_clause {}
            impl #message_impl_generics hyperactor::remote::Accepts<#ty>
                for #data_type_name #message_ty_generics #message_where_clause {}
        });
        bindings.push(quote! {
            ports.bind::<#ty>();
        });
        bind_predicates.push(syn::parse_quote!(#ty: hyperactor::RemoteMessage));
        bind_predicates.push(syn::parse_quote!(#actor_ty: hyperactor::Handler<#ty>));

        // `cast = true` additionally exports the indexed, erased form of the
        // message.
        if *cast {
            let indexed_ty: Type =
                syn::parse_quote!(hyperactor::message::IndexedErasedUnbound<#ty>);
            let indexed_generics = generics_with_predicates(
                &named_generics,
                [
                    syn::parse_quote!(#ty: hyperactor::message::Castable),
                    syn::parse_quote!(#indexed_ty: hyperactor::RemoteMessage),
                ],
            );
            let (indexed_impl_generics, indexed_ty_generics, indexed_where_clause) =
                indexed_generics.split_for_impl();
            handles.push(quote! {
                impl #indexed_impl_generics hyperactor::actor::RemoteHandles<#indexed_ty>
                    for #data_type_name #indexed_ty_generics #indexed_where_clause {}
                impl #indexed_impl_generics hyperactor::remote::Accepts<#indexed_ty>
                    for #data_type_name #indexed_ty_generics #indexed_where_clause {}
            });
            bindings.push(quote! {
                ports.bind::<#indexed_ty>();
            });
            // NOTE(review): unlike the plain message, no
            // `#actor_ty: Handler<#indexed_ty>` predicate is added here.
            // Presumably a blanket impl covers it; confirm.
            bind_predicates.push(syn::parse_quote!(#ty: hyperactor::message::Castable));
            bind_predicates.push(syn::parse_quote!(#indexed_ty: hyperactor::RemoteMessage));
        }
    }

    let bind_generics = generics_with_predicates(&named_generics, bind_predicates);
    let (bind_impl_generics, bind_ty_generics, bind_where_clause) = bind_generics.split_for_impl();
    let named_impl = named_impl(data_type_name, &input.generics);

    let expanded = quote! {
        #input

        impl #named_impl_generics hyperactor::actor::Referable for #data_type_name #named_ty_generics #named_where_clause {}

        #(#handles)*

        // Introspection messages are accepted by every exported actor.
        impl #named_impl_generics hyperactor::actor::RemoteHandles<hyperactor::introspect::IntrospectMessage> for #data_type_name #named_ty_generics #named_where_clause {}
        impl #named_impl_generics hyperactor::remote::Accepts<hyperactor::introspect::IntrospectMessage> for #data_type_name #named_ty_generics #named_where_clause {}

        impl #bind_impl_generics hyperactor::actor::Binds<#data_type_name #bind_ty_generics> for #data_type_name #bind_ty_generics #bind_where_clause {
            fn bind(ports: &hyperactor::proc::HandlerPorts<Self>) {
                #(#bindings)*
            }
        }

        #named_impl
    };

    TokenStream::from(expanded)
}
1522
1523#[proc_macro_attribute]
1527pub fn spawnable(attr: TokenStream, item: TokenStream) -> TokenStream {
1528 if !attr.is_empty() {
1529 return syn::Error::new(Span::call_site(), "`#[spawnable]` does not take arguments")
1530 .to_compile_error()
1531 .into();
1532 }
1533
1534 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1535 if !matches!(input.data, Data::Struct(_)) {
1536 return syn::Error::new(
1537 input.span(),
1538 "`#[spawnable]` only supports struct actor declarations",
1539 )
1540 .to_compile_error()
1541 .into();
1542 }
1543
1544 if !input.generics.params.is_empty() {
1545 return syn::Error::new(
1546 input.generics.span(),
1547 "generic actor families cannot use `#[spawnable]`; use `hyperactor::register_spawnable!(ConcreteType)` instead",
1548 )
1549 .to_compile_error()
1550 .into();
1551 }
1552
1553 let data_type_name = &input.ident;
1554 quote! {
1555 #input
1556 hyperactor::register_spawnable!(#data_type_name);
1557 }
1558 .into()
1559}
1560
1561struct BehaviorInput {
1563 behavior: Ident,
1564 generics: syn::Generics,
1565 handlers: Vec<HandlerSpec>,
1566}
1567
1568impl syn::parse::Parse for BehaviorInput {
1569 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1570 let behavior: Ident = input.parse()?;
1571 let generics: syn::Generics = input.parse()?;
1572 let _: Token![,] = input.parse()?;
1573 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1574 let handlers = raw_handlers.into_iter().collect();
1575 Ok(BehaviorInput {
1576 behavior,
1577 generics,
1578 handlers,
1579 })
1580 }
1581}
1582
#[proc_macro]
/// Declares a behavior: a named marker type that describes a set of message
/// types an actor handles.
///
/// Syntax: `behavior!(Name<T, ...>, Msg1, Msg2 { cast = true }, ...)`.
///
/// Emits:
/// - `pub struct Name<...>`, which holds only a `PhantomData` of its type
///   parameters and derives `Debug`, `serde::Serialize` and
///   `serde::Deserialize`;
/// - `typeuri::Named` and `hyperactor::actor::Referable` impls for it;
/// - `hyperactor::actor::Binds<A>` for every `A: hyperactor::Actor` that
///   implements `hyperactor::Handler` for each listed message type; it binds
///   a port per message type;
/// - `hyperactor::actor::RemoteHandles<M>` and `hyperactor::remote::Accepts<M>`
///   for each listed message type `M`.
///
/// A handler written `Msg { cast = true }` also contributes
/// `hyperactor::message::IndexedErasedUnbound<Msg>`. Every type parameter is
/// bounded by `typeuri::Named`, `serde::Serialize`, `Send`, `Sync`, `Debug`
/// and a higher-ranked `serde::Deserialize`.
pub fn behavior(input: TokenStream) -> TokenStream {
    let BehaviorInput {
        behavior,
        generics,
        handlers,
    } = parse_macro_input!(input as BehaviorInput);
    // Flatten the handler specs into message types. `cast = true` entries also
    // yield `IndexedErasedUnbound<M>`.
    let tys = HandlerSpec::add_indexed(handlers);

    // Add the bounds the generated struct and impls rely on to every type
    // parameter.
    let mut bounded_generics = generics.clone();
    for param in bounded_generics.type_params_mut() {
        param.bounds.push(syn::parse_quote!(typeuri::Named));
        param.bounds.push(syn::parse_quote!(serde::Serialize));
        param.bounds.push(syn::parse_quote!(std::marker::Send));
        param.bounds.push(syn::parse_quote!(std::marker::Sync));
        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
        // NOTE(review): the `mixed_site` span is presumably chosen so this
        // lifetime cannot clash with user-written lifetimes; confirm.
        let lifetime =
            syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
        param
            .bounds
            .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
    }

    let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();

    // `Binds<A>` is generic over the actor type, so prepend a fresh, unbounded
    // type parameter `A`. Its bounds go in the impl's where clause.
    // NOTE(review): this would collide with a user type parameter also named
    // `A`; confirm that is acceptable.
    let mut binds_generics = bounded_generics.clone();
    binds_generics.params.insert(
        0,
        syn::GenericParam::Type(syn::TypeParam {
            attrs: vec![],
            ident: Ident::new("A", proc_macro2::Span::call_site()),
            colon_token: None,
            bounds: Punctuated::new(),
            eq_token: None,
            default: None,
        }),
    );
    let (binds_impl_generics, _, _) = binds_generics.split_for_impl();

    // Only type parameters (not lifetimes or consts) take part in the name.
    let type_params: Vec<_> = bounded_generics.type_params().collect();
    let has_generics = !type_params.is_empty();

    let (typename_impl, typehash_impl) = if has_generics {
        // One `{}` placeholder per type parameter, filled in by
        // `intern_typename!`.
        let placeholders = vec!["{}"; type_params.len()].join(", ");
        let placeholders_format_string = format!("<{}>", placeholders);
        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };

        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
        (
            quote! {
                typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
            },
            quote! {
                typeuri::cityhasher::hash(Self::typename())
            },
        )
    } else {
        // Non-generic: the name is a string literal, and the hash is computed
        // once and cached.
        (
            quote! {
                concat!(std::module_path!(), "::", stringify!(#behavior))
            },
            quote! {
                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
                    typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
                });
                *TYPEHASH
            },
        )
    };

    // The struct stores its type parameters only through `PhantomData`.
    let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();

    // NOTE(review): `where_clause` comes from a `syn::Generics` parsed by
    // `BehaviorInput`, which appears never to include a where clause, so it is
    // expected to be `None`. If it were `Some`, interpolating it after
    // `A: ...,` in the `Binds` impl would emit a second `where` keyword.
    let expanded = quote! {
        #[doc = "The generated behavior struct."]
        #[derive(Debug, serde::Serialize, serde::Deserialize)]
        pub struct #behavior #impl_generics #where_clause {
            _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
        }

        impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
            fn typename() -> &'static str {
                #typename_impl
            }

            fn typehash() -> u64 {
                #typehash_impl
            }
        }

        impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}

        impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
        where
            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
            #where_clause
        {
            fn bind(ports: &hyperactor::proc::HandlerPorts<A>) {
                #(
                    ports.bind::<#tys>();
                )*
            }
        }

        #(
            impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
            impl #impl_generics hyperactor::remote::Accepts<#tys> for #behavior #ty_generics #where_clause {}
        )*
    };

    TokenStream::from(expanded)
}
1725
1726fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1727 let mut is_included = false;
1728 for attr in &field.attrs {
1729 if attr.path().is_ident("binding") {
1730 attr.parse_nested_meta(|meta| {
1732 if meta.path.is_ident("include") {
1733 is_included = true;
1734 Ok(())
1735 } else {
1736 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1737 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1738 }
1739 })?
1740 }
1741 }
1742 Ok(is_included)
1743}
1744
1745enum FieldAccessor {
1750 Named(Ident),
1751 Unnamed(Index),
1752}
1753
1754struct ParsedField {
1756 accessor: FieldAccessor,
1757 ty: Type,
1758 included: bool,
1759}
1760
1761impl From<&ParsedField> for (Ident, Type) {
1762 fn from(field: &ParsedField) -> Self {
1763 let field_ident = match &field.accessor {
1764 FieldAccessor::Named(ident) => ident.clone(),
1765 FieldAccessor::Unnamed(i) => {
1766 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1767 }
1768 };
1769 (field_ident, field.ty.clone())
1770 }
1771}
1772
1773fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1774 match fields {
1775 Fields::Named(named) => named
1776 .named
1777 .iter()
1778 .map(|f| {
1779 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1780 Ok(ParsedField {
1781 accessor,
1782 ty: f.ty.clone(),
1783 included: include_in_bind_unbind(f)?,
1784 })
1785 })
1786 .collect(),
1787 Fields::Unnamed(unnamed) => unnamed
1788 .unnamed
1789 .iter()
1790 .enumerate()
1791 .map(|(i, f)| {
1792 let accessor = FieldAccessor::Unnamed(Index::from(i));
1793 Ok(ParsedField {
1794 accessor,
1795 ty: f.ty.clone(),
1796 included: include_in_bind_unbind(f)?,
1797 })
1798 })
1799 .collect(),
1800 Fields::Unit => Ok(Vec::new()),
1801 }
1802}
1803
1804fn gen_struct_items<F>(
1805 fields: &Fields,
1806 make_item: F,
1807 is_mutable: bool,
1808) -> syn::Result<Vec<proc_macro2::TokenStream>>
1809where
1810 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1811{
1812 let borrow = if is_mutable {
1813 quote! { &mut }
1814 } else {
1815 quote! { & }
1816 };
1817 let items: Vec<_> = collect_all_fields(fields)?
1818 .into_iter()
1819 .filter(|f| f.included)
1820 .map(
1821 |ParsedField {
1822 accessor,
1823 ty,
1824 included,
1825 }| {
1826 assert!(included);
1827 let field_accessor = match accessor {
1828 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1829 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1830 };
1831 make_item(field_accessor, ty)
1832 },
1833 )
1834 .collect();
1835 Ok(items)
1836}
1837
1838fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1848 all_fields
1849 .iter()
1850 .map(
1851 |ParsedField {
1852 accessor,
1853 ty: _,
1854 included,
1855 }| {
1856 match accessor {
1857 FieldAccessor::Named(ident) => {
1858 if *included {
1859 quote! { #ident }
1860 } else {
1861 quote! { #ident: _ }
1862 }
1863 }
1864 FieldAccessor::Unnamed(i) => {
1865 if *included {
1866 let ident = Ident::new(
1867 &format!("f{}", i.index),
1868 proc_macro2::Span::call_site(),
1869 );
1870 quote! { #ident }
1871 } else {
1872 quote! { _ }
1873 }
1874 }
1875 }
1876 },
1877 )
1878 .collect()
1879}
1880
1881fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1888where
1889 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1890{
1891 data.variants
1892 .iter()
1893 .map(|variant| {
1894 let name = &variant.ident;
1895 let all_fields = collect_all_fields(&variant.fields)?;
1896 let field_accessors = gen_enum_field_accessors(&all_fields);
1897 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1898 let items = included_fields
1899 .iter()
1900 .map(|f| {
1901 let (accessor, ty) = <(Ident, Type)>::from(*f);
1902 make_item(quote! { #accessor }, ty)
1903 })
1904 .collect::<Vec<_>>();
1905
1906 Ok(match &variant.fields {
1907 Fields::Named(_) => {
1908 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1909 }
1910 Fields::Unnamed(_) => {
1911 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1912 }
1913 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1914 })
1915 })
1916 .collect()
1917}
1918
1919#[proc_macro_derive(Bind, attributes(binding))]
2001pub fn derive_bind(input: TokenStream) -> TokenStream {
2002 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2003 quote! {
2004 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
2005 }
2006 }
2007
2008 let input = parse_macro_input!(input as DeriveInput);
2009 let name = &input.ident;
2010 let inner = match &input.data {
2011 Data::Struct(DataStruct { fields, .. }) => {
2012 match gen_struct_items(fields, make_item, true) {
2013 Ok(collects) => {
2014 quote! { #(#collects)* }
2015 }
2016 Err(e) => {
2017 return TokenStream::from(e.to_compile_error());
2018 }
2019 }
2020 }
2021 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2022 Ok(arms) => {
2023 quote! { match self { #(#arms),* } }
2024 }
2025 Err(e) => {
2026 return TokenStream::from(e.to_compile_error());
2027 }
2028 },
2029 _ => panic!("Bind can only be derived for structs and enums"),
2030 };
2031 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2032 let expand = quote! {
2033 #[automatically_derived]
2034 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
2035 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2036 #inner
2037 Ok(())
2038 }
2039 }
2040 };
2041 TokenStream::from(expand)
2042}
2043
2044#[proc_macro_derive(Unbind, attributes(binding))]
2058pub fn derive_unbind(input: TokenStream) -> TokenStream {
2059 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2060 quote! {
2061 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
2062 }
2063 }
2064
2065 let input = parse_macro_input!(input as DeriveInput);
2066 let name = &input.ident;
2067 let inner = match &input.data {
2068 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
2069 {
2070 Ok(collects) => {
2071 quote! { #(#collects)* }
2072 }
2073 Err(e) => {
2074 return TokenStream::from(e.to_compile_error());
2075 }
2076 },
2077 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2078 Ok(arms) => {
2079 quote! { match self { #(#arms),* } }
2080 }
2081 Err(e) => {
2082 return TokenStream::from(e.to_compile_error());
2083 }
2084 },
2085 _ => panic!("Unbind can only be derived for structs and enums"),
2086 };
2087 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2088 let expand = quote! {
2089 #[automatically_derived]
2090 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
2091 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2092 #inner
2093 Ok(())
2094 }
2095 }
2096 };
2097 TokenStream::from(expand)
2098}
2099
2100fn parse_observe_function(
2102 attr: TokenStream,
2103 item: TokenStream,
2104) -> syn::Result<(ItemFn, String, String)> {
2105 let input = syn::parse::<ItemFn>(item)?;
2106
2107 if input.sig.asyncness.is_none() {
2108 return Err(syn::Error::new(
2109 input.sig.span(),
2110 "observe macros can only be applied to async functions",
2111 ));
2112 }
2113
2114 let fn_name_str = input.sig.ident.to_string();
2115 let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
2116
2117 Ok((input, fn_name_str, module_name_str))
2118}
2119
2120fn create_telemetry_setup(
2122 module_name_str: &str,
2123 fn_name_str: &str,
2124 include_error: bool,
2125) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2126 let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2127 let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2128
2129 let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2130
2131 let error_ident = if include_error {
2132 Some(Ident::new(
2133 "error",
2134 Span::from(proc_macro::Span::def_site()),
2135 ))
2136 } else {
2137 None
2138 };
2139
2140 let error_declaration = if let Some(ref error_ident) = error_ident {
2141 quote! {
2142 hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2143 }
2144 } else {
2145 quote! {}
2146 };
2147
2148 let setup_code = quote! {
2149 use hyperactor_telemetry;
2150 hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2151 hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2152 #error_declaration
2153 };
2154
2155 (latency_ident, success_ident, error_ident, setup_code)
2156}
2157
2158#[proc_macro_attribute]
2178pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2179 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2180 Ok(parsed) => parsed,
2181 Err(err) => return err.to_compile_error().into(),
2182 };
2183
2184 let fn_name = &input.sig.ident;
2185 let vis = &input.vis;
2186 let args = &input.sig.inputs;
2187 let return_type = &input.sig.output;
2188 let body = &input.block;
2189 let attrs = &input.attrs;
2190 let generics = &input.sig.generics;
2191
2192 let (latency_ident, success_ident, error_ident, telemetry_setup) =
2193 create_telemetry_setup(&module_name_str, &fn_name_str, true);
2194 let error_ident = error_ident.unwrap();
2195
2196 let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2197
2198 let expanded = quote! {
2200 #(#attrs)*
2201 #vis async fn #fn_name #generics(#args) #return_type {
2202 #telemetry_setup
2203
2204 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2205 let _timer = #latency_ident.start(kv_pairs);
2206
2207 let #result_ident = async #body.await;
2208
2209 match &#result_ident {
2210 Ok(_) => {
2211 #success_ident.add(
2212 1,
2213 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2214 );
2215 }
2216 Err(_) => {
2217 #error_ident.add(
2218 1,
2219 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2220 );
2221 }
2222 }
2223
2224 #result_ident
2225 }
2226 };
2227
2228 expanded.into()
2229}
2230
2231#[proc_macro_attribute]
2250pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2251 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2252 Ok(parsed) => parsed,
2253 Err(err) => return err.to_compile_error().into(),
2254 };
2255
2256 let fn_name = &input.sig.ident;
2257 let vis = &input.vis;
2258 let args = &input.sig.inputs;
2259 let return_type = &input.sig.output;
2260 let body = &input.block;
2261 let attrs = &input.attrs;
2262 let generics = &input.sig.generics;
2263
2264 let (latency_ident, success_ident, _, telemetry_setup) =
2265 create_telemetry_setup(&module_name_str, &fn_name_str, false);
2266
2267 let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2268
2269 let expanded = quote! {
2271 #(#attrs)*
2272 #vis async fn #fn_name #generics(#args) #return_type {
2273 #telemetry_setup
2274
2275 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2276 let _timer = #latency_ident.start(kv_pairs);
2277
2278 let #return_ident = async #body.await;
2279
2280 #success_ident.add(
2281 1,
2282 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2283 );
2284 #return_ident
2285 }
2286 };
2287
2288 expanded.into()
2289}
2290
/// Checks that `s` is a valid singleton label.
///
/// A valid label is non-empty and at most 63 bytes long. It contains only
/// ASCII lowercase letters, digits and `-`, starts with a lowercase letter,
/// and ends with a lowercase letter or digit. The checks run in that order
/// (emptiness, length, first byte, last byte, characters), and the first
/// failure's message is returned.
fn validate_label(s: &str) -> Result<(), String> {
    let bytes = s.as_bytes();
    let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) else {
        return Err("label must not be empty".to_string());
    };
    if bytes.len() > 63 {
        return Err("label exceeds 63 characters".to_string());
    }
    if !first.is_ascii_lowercase() {
        return Err("label must start with a lowercase letter".to_string());
    }
    if !(last.is_ascii_lowercase() || last.is_ascii_digit()) {
        return Err("label must end with a lowercase letter or digit".to_string());
    }
    let invalid = s
        .chars()
        .find(|&c| !(c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-'));
    match invalid {
        Some(ch) => Err(format!("label contains invalid character '{ch}'")),
        None => Ok(()),
    }
}
2313
/// Parses `s` as an instance uid: 1 to 16 ASCII hex digits (either case).
///
/// Returns the parsed value. Otherwise returns an error describing the bad
/// length, the first invalid character, or the parse failure.
fn validate_hex_uid(s: &str) -> Result<u64, String> {
    if !(1..=16).contains(&s.len()) {
        return Err(format!("hex uid must be 1-16 hex characters, got '{s}'"));
    }
    match s.chars().find(|c| !c.is_ascii_hexdigit()) {
        Some(ch) => Err(format!("hex uid contains invalid character '{ch}'")),
        None => u64::from_str_radix(s, 16).map_err(|e| format!("invalid hex uid '{s}': {e}")),
    }
}
2325
2326#[proc_macro]
2332pub fn uid(input: TokenStream) -> TokenStream {
2333 let input2: proc_macro2::TokenStream = input.into();
2334 let combined: String = input2.into_iter().map(|tt| tt.to_string()).collect();
2335
2336 if combined.is_empty() {
2337 return TokenStream::from(quote! { compile_error!("uid! macro requires an argument") });
2338 }
2339
2340 if let Some(rest) = combined.strip_prefix('_') {
2342 return match validate_label(rest) {
2343 Ok(()) => TokenStream::from(quote! {
2344 hyperactor::id::Uid::Singleton(
2345 hyperactor::id::Label::new(#rest).unwrap()
2346 )
2347 }),
2348 Err(e) => {
2349 let msg = format!("invalid singleton uid: {e}");
2350 TokenStream::from(quote! { compile_error!(#msg) })
2351 }
2352 };
2353 }
2354
2355 match validate_hex_uid(&combined) {
2357 Ok(uid_val) => TokenStream::from(quote! {
2358 hyperactor::id::Uid::Instance(#uid_val, None)
2359 }),
2360 Err(e) => {
2361 let msg = format!("invalid uid: {e}");
2362 TokenStream::from(quote! { compile_error!(#msg) })
2363 }
2364 }
2365}