1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Meta;
39use syn::MetaNameValue;
40use syn::Token;
41use syn::Type;
42use syn::bracketed;
43use syn::parse::Parse;
44use syn::parse::ParseStream;
45use syn::parse_macro_input;
46use syn::punctuated::Punctuated;
47use syn::spanned::Spanned;
48
49const REPLY_VARIANT_ERROR: &str = indoc! {r#"
50`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position
51
52= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
53= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
54"#};
55
56const REPLY_USAGE_ERROR: &str = indoc! {r#"
57`call` message expects at most one `reply` argument
58
59= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
60= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
61"#};
62
/// Per-field marker recording whether a message field carries `#[reply]`.
enum FieldFlag {
    /// An ordinary message argument.
    None,
    /// The port on which the generated handler sends its result.
    Reply,
}
67
68#[allow(dead_code)]
70enum Variant {
71 Named {
73 enum_name: Ident,
74 name: Ident,
75 field_names: Vec<Ident>,
76 field_types: Vec<Type>,
77 field_flags: Vec<FieldFlag>,
78 is_struct: bool,
79 generics: syn::Generics,
80 },
81 Anon {
83 enum_name: Ident,
84 name: Ident,
85 field_types: Vec<Type>,
86 field_flags: Vec<FieldFlag>,
87 is_struct: bool,
88 generics: syn::Generics,
89 },
90}
91
92impl Variant {
93 fn len(&self) -> usize {
95 self.field_types().len()
96 }
97
98 fn is_struct(&self) -> bool {
100 match self {
101 Variant::Named { is_struct, .. } => *is_struct,
102 Variant::Anon { is_struct, .. } => *is_struct,
103 }
104 }
105
106 fn enum_name(&self) -> &Ident {
108 match self {
109 Variant::Named { enum_name, .. } => enum_name,
110 Variant::Anon { enum_name, .. } => enum_name,
111 }
112 }
113
114 fn name(&self) -> &Ident {
116 match self {
117 Variant::Named { name, .. } => name,
118 Variant::Anon { name, .. } => name,
119 }
120 }
121
122 #[allow(dead_code)]
124 fn generics(&self) -> &syn::Generics {
125 match self {
126 Variant::Named { generics, .. } => generics,
127 Variant::Anon { generics, .. } => generics,
128 }
129 }
130
131 fn snake_name(&self) -> Ident {
133 Ident::new(
134 &self.name().to_string().to_case(Case::Snake),
135 self.name().span(),
136 )
137 }
138
139 fn qualified_name(&self) -> proc_macro2::TokenStream {
141 let enum_name = self.enum_name();
142 let name = self.name();
143
144 if self.is_struct() {
145 quote! { #enum_name }
146 } else {
147 quote! { #enum_name::#name }
148 }
149 }
150
151 fn field_names(&self) -> Vec<Ident> {
154 match self {
155 Variant::Named { field_names, .. } => field_names.clone(),
156 Variant::Anon { field_types, .. } => (0usize..field_types.len())
157 .map(|idx| format_ident!("arg{}", idx))
158 .collect(),
159 }
160 }
161
162 fn field_types(&self) -> &Vec<Type> {
164 match self {
165 Variant::Named { field_types, .. } => field_types,
166 Variant::Anon { field_types, .. } => field_types,
167 }
168 }
169
170 fn field_flags(&self) -> &Vec<FieldFlag> {
172 match self {
173 Variant::Named { field_flags, .. } => field_flags,
174 Variant::Anon { field_flags, .. } => field_flags,
175 }
176 }
177
178 fn constructor(&self) -> proc_macro2::TokenStream {
180 let qualified_name = self.qualified_name();
181 let field_names = self.field_names();
182 match self {
183 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
184 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
185 }
186 }
187}
188
/// Kind of port carried by a `call` message's `#[reply]` field.
struct ReplyPort {
    /// `true` for `PortHandle` / `OncePortHandle`, `false` for the ref kinds.
    is_handle: bool,
    /// `true` for the `OncePortRef` / `OncePortHandle` kinds.
    is_once: bool,
}
193
194impl ReplyPort {
195 fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
196 ReplyPort {
197 is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
198 is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
199 }
200 }
201
202 fn open_op(&self) -> proc_macro2::TokenStream {
203 if self.is_once {
204 quote! { hyperactor::mailbox::open_once_port }
205 } else {
206 quote! { hyperactor::mailbox::open_port }
207 }
208 }
209
210 fn rx_modifier(&self) -> proc_macro2::TokenStream {
211 if self.is_once {
212 quote! {}
213 } else {
214 quote! { mut }
215 }
216 }
217}
218
219#[allow(clippy::large_enum_variant)]
222enum Message {
223 Call {
226 variant: Variant,
227 reply_port: ReplyPort,
229 return_type: Type,
231 log_level: Option<Ident>,
233 },
234 OneWay {
235 variant: Variant,
236 log_level: Option<Ident>,
238 },
239}
240
241impl Message {
242 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
243 match &variant
244 .field_flags()
245 .iter()
246 .zip(variant.field_types())
247 .filter_map(|(flag, ty)| match flag {
248 FieldFlag::Reply => Some(ty),
249 FieldFlag::None => None,
250 })
251 .collect::<Vec<&Type>>()[..]
252 {
253 [] => Ok(Self::OneWay { variant, log_level }),
254 [reply_port_ty] => {
255 let syn::Type::Path(type_path) = reply_port_ty else {
256 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
257 };
258 let Some(last_segment) = type_path.path.segments.last() else {
259 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
260 };
261 if last_segment.ident != "OncePortRef"
262 && last_segment.ident != "OncePortHandle"
263 && last_segment.ident != "PortRef"
264 && last_segment.ident != "PortHandle"
265 {
266 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
267 }
268 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
269 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
270 };
271 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
272 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
273 };
274 let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
275 let return_type = return_ty.clone();
276 Ok(Self::Call {
277 variant,
278 reply_port,
279 return_type,
280 log_level,
281 })
282 }
283 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
284 }
285 }
286
287 fn args(&self) -> Vec<(Ident, Type)> {
289 match self {
290 Message::Call { variant, .. } => variant
291 .field_names()
292 .into_iter()
293 .zip(variant.field_types().clone())
294 .take(variant.len() - 1)
295 .collect(),
296 Message::OneWay { variant, .. } => variant
297 .field_names()
298 .into_iter()
299 .zip(variant.field_types().clone())
300 .collect(),
301 }
302 }
303
304 fn variant(&self) -> &Variant {
305 match self {
306 Message::Call { variant, .. } => variant,
307 Message::OneWay { variant, .. } => variant,
308 }
309 }
310
311 fn reply_port_position(&self) -> Option<usize> {
312 self.variant()
313 .field_flags()
314 .iter()
315 .position(|flag| matches!(flag, FieldFlag::Reply))
316 }
317
318 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
320 match self {
321 Message::Call { variant, .. } => {
322 let pos = self.reply_port_position()?;
323 Some((
324 variant.field_names()[pos].clone(),
325 variant.field_types()[pos].clone(),
326 ))
327 }
328 Message::OneWay { .. } => None,
329 }
330 }
331}
332
333fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
334 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
335 Some(attr) => {
336 let Ok(meta) = attr.meta.require_list() else {
337 return Err(syn::Error::new(
338 Span::call_site(),
339 indoc! {"
340 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
341
342 = help use `#[log_level(info)]` or `#[log_level(error)]`
343 "},
344 ));
345 };
346 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
347 if parsed.len() != 1 {
348 return Err(syn::Error::new(
349 Span::call_site(),
350 indoc! {"
351 `log_level` attribute must specify exactly one level
352
353 = help use `#[log_level(warn)]` or `#[log_level(info)]`
354 "},
355 ));
356 };
357 Some(parsed.first().unwrap().to_string())
358 }
359 None => None,
360 };
361
362 if level.is_none() {
363 return Ok(None);
364 }
365 let level = level.unwrap();
366
367 match level.as_str() {
368 "error" | "warn" | "info" | "debug" | "trace" => {}
369 _ => {
370 return Err(syn::Error::new(
371 Span::call_site(),
372 indoc! {"
373 `log_level` attribute must be one of 'error, warn, info, debug, trace'
374
375 = help use `#[log_level(warn)]` or `#[log_level(info)]`
376 "},
377 ));
378 }
379 }
380
381 Ok(Some(Ident::new(
382 level.to_ascii_uppercase().as_str(),
383 Span::call_site(),
384 )))
385}
386
387fn parse_field_flag(field: &Field) -> FieldFlag {
388 for attr in field.attrs.iter() {
389 match &attr.meta {
390 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
391 _ => {}
392 }
393 }
394 FieldFlag::None
395}
396
397fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
399 match &input.data {
400 Data::Enum(data_enum) => {
401 let mut messages = Vec::new();
402
403 for variant in &data_enum.variants {
404 let name = variant.ident.clone();
405 let attrs = &variant.attrs;
406
407 let message_variant = match &variant.fields {
408 syn::Fields::Unnamed(fields_) => Variant::Anon {
409 enum_name: input.ident.clone(),
410 name,
411 field_types: fields_
412 .unnamed
413 .iter()
414 .map(|field| field.ty.clone())
415 .collect(),
416 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
417 is_struct: false,
418 generics: input.generics.clone(),
419 },
420 syn::Fields::Named(fields_) => Variant::Named {
421 enum_name: input.ident.clone(),
422 name,
423 field_names: fields_
424 .named
425 .iter()
426 .map(|field| field.ident.clone().unwrap())
427 .collect(),
428 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
429 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
430 is_struct: false,
431 generics: input.generics.clone(),
432 },
433 _ => {
434 return Err(syn::Error::new_spanned(
435 variant,
436 indoc! {r#"
437 `Handler` currently only supports named or tuple struct variants
438
439 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
440 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
441 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
442 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
443 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
444 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
445 "#},
446 ));
447 }
448 };
449 let log_level = parse_log_level(attrs)?;
450
451 messages.push(Message::new(
452 variant.fields.span(),
453 message_variant,
454 log_level,
455 )?);
456 }
457
458 Ok(messages)
459 }
460 Data::Struct(data_struct) => {
461 let struct_name = input.ident.clone();
462 let attrs = &input.attrs;
463
464 let message_variant = match &data_struct.fields {
465 syn::Fields::Unnamed(fields_) => Variant::Anon {
466 enum_name: struct_name.clone(),
467 name: struct_name,
468 field_types: fields_
469 .unnamed
470 .iter()
471 .map(|field| field.ty.clone())
472 .collect(),
473 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
474 is_struct: true,
475 generics: input.generics.clone(),
476 },
477 syn::Fields::Named(fields_) => Variant::Named {
478 enum_name: struct_name.clone(),
479 name: struct_name,
480 field_names: fields_
481 .named
482 .iter()
483 .map(|field| field.ident.clone().unwrap())
484 .collect(),
485 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
486 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
487 is_struct: true,
488 generics: input.generics.clone(),
489 },
490 syn::Fields::Unit => Variant::Anon {
491 enum_name: struct_name.clone(),
492 name: struct_name,
493 field_types: Vec::new(),
494 field_flags: Vec::new(),
495 is_struct: true,
496 generics: input.generics.clone(),
497 },
498 };
499
500 let log_level = parse_log_level(attrs)?;
501 let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
502
503 Ok(vec![message])
504 }
505 _ => Err(syn::Error::new_spanned(
506 input,
507 "handlers can only be derived for enums and structs",
508 )),
509 }
510}
511
512#[proc_macro_derive(Handler, attributes(reply))]
671pub fn derive_handler(input: TokenStream) -> TokenStream {
672 let input = parse_macro_input!(input as DeriveInput);
673 let name: Ident = input.ident.clone();
674 let (_, ty_generics, _) = input.generics.split_for_impl();
675
676 let messages = match parse_messages(input.clone()) {
677 Ok(messages) => messages,
678 Err(err) => return TokenStream::from(err.to_compile_error()),
679 };
680
681 let mut handler_trait_methods = Vec::new();
683
684 let mut match_arms = Vec::new();
686
687 let mut client_trait_methods = Vec::new();
689
690 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
691
692 for message in &messages {
693 match message {
694 Message::Call {
695 variant,
696 reply_port,
697 return_type,
698 log_level,
699 } => {
700 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
701 let variant_name_snake = variant.snake_name();
702 let variant_name_snake_deprecated =
703 format_ident!("{}_deprecated", variant_name_snake);
704 let enum_name = variant.enum_name();
705 let _variant_qualified_name = variant.qualified_name();
706 let log_level = match (&global_log_level, log_level) {
707 (_, Some(local)) => local.clone(),
708 (Some(global), None) => global.clone(),
709 _ => Ident::new("DEBUG", Span::call_site()),
710 };
711 let _log_level = if reply_port.is_handle {
712 quote! {
713 tracing::Level::#log_level
714 }
715 } else {
716 quote! {
717 tracing::Level::TRACE
718 }
719 };
720 let log_message = quote! {
721 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
722 "rpc" => "call",
723 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
724 "message_type" => stringify!(#enum_name),
725 "variant" => stringify!(#variant_name_snake),
726 ));
727 };
728
729 handler_trait_methods.push(quote! {
730 #[doc = "The generated handler method for this enum variant."]
731 async fn #variant_name_snake(
732 &mut self,
733 cx: &hyperactor::Context<Self>,
734 #(#arg_names: #arg_types),*)
735 -> Result<#return_type, hyperactor::anyhow::Error>;
736 });
737
738 client_trait_methods.push(quote! {
739 #[doc = "The generated client method for this enum variant."]
740 async fn #variant_name_snake(
741 &self,
742 cx: &impl hyperactor::context::Actor,
743 #(#arg_names: #arg_types),*)
744 -> Result<#return_type, hyperactor::anyhow::Error>;
745
746 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
747 async fn #variant_name_snake_deprecated(
748 &self,
749 cx: &impl hyperactor::context::Actor,
750 #(#arg_names: #arg_types),*)
751 -> Result<#return_type, hyperactor::anyhow::Error>;
752 });
753
754 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
755 let constructor = variant.constructor();
756 let result_ident = Ident::new("result", Span::mixed_site());
757 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
758 if reply_port.is_handle {
759 match_arms.push(quote! {
760 #constructor => {
761 #log_message
762 #construct_result_future
765 #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
766 }
767 });
768 } else {
769 match_arms.push(quote! {
770 #constructor => {
771 #log_message
772 #construct_result_future
775 #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
776 }
777 });
778 }
779 }
780 Message::OneWay { variant, log_level } => {
781 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
782 let variant_name_snake = variant.snake_name();
783 let variant_name_snake_deprecated =
784 format_ident!("{}_deprecated", variant_name_snake);
785 let enum_name = variant.enum_name();
786 let log_level = match (&global_log_level, log_level) {
787 (_, Some(local)) => local.clone(),
788 (Some(global), None) => global.clone(),
789 _ => Ident::new("TRACE", Span::call_site()),
790 };
791 let _log_level = quote! {
792 tracing::Level::#log_level
793 };
794 let log_message = quote! {
795 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
796 "rpc" => "call",
797 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
798 "message_type" => stringify!(#enum_name),
799 "variant" => stringify!(#variant_name_snake),
800 ));
801 };
802
803 handler_trait_methods.push(quote! {
804 #[doc = "The generated handler method for this enum variant."]
805 async fn #variant_name_snake(
806 &mut self,
807 cx: &hyperactor::Context<Self>,
808 #(#arg_names: #arg_types),*)
809 -> Result<(), hyperactor::anyhow::Error>;
810 });
811
812 client_trait_methods.push(quote! {
813 #[doc = "The generated client method for this enum variant."]
814 async fn #variant_name_snake(
815 &self,
816 cx: &impl hyperactor::context::Actor,
817 #(#arg_names: #arg_types),*)
818 -> Result<(), hyperactor::anyhow::Error>;
819
820 #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
821 async fn #variant_name_snake_deprecated(
822 &self,
823 cx: &impl hyperactor::context::Actor,
824 #(#arg_names: #arg_types),*)
825 -> Result<(), hyperactor::anyhow::Error>;
826 });
827
828 let constructor = variant.constructor();
829
830 match_arms.push(quote! {
831 #constructor => {
832 #log_message
833 self.#variant_name_snake(cx, #(#arg_names),*).await
834 },
835 });
836 }
837 }
838 }
839
840 let handler_trait_name = format_ident!("{}Handler", name);
841 let client_trait_name = format_ident!("{}Client", name);
842
843 let mut handler_generics = input.generics.clone();
847 for param in handler_generics.type_params_mut() {
848 param.bounds.push(syn::parse_quote!(serde::Serialize));
849 param
850 .bounds
851 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
852 param.bounds.push(syn::parse_quote!(Send));
853 param.bounds.push(syn::parse_quote!(Sync));
854 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
855 param.bounds.push(syn::parse_quote!(hyperactor::Named));
856 }
857 let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
858 let (client_impl_generics, _, _) = input.generics.split_for_impl();
859
860 let expanded = quote! {
861 #[doc = "The custom handler trait for this message type."]
862 #[hyperactor::async_trait::async_trait]
863 pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync {
864 #(#handler_trait_methods)*
865
866 #[doc = "Handle the next message."]
867 async fn handle(
868 &mut self,
869 cx: &hyperactor::Context<Self>,
870 message: #name #ty_generics,
871 ) -> hyperactor::anyhow::Result<()> {
872 match message {
874 #(#match_arms)*
875 }
876 }
877 }
878
879 #[doc = "The custom client trait for this message type."]
880 #[hyperactor::async_trait::async_trait]
881 pub trait #client_trait_name #client_impl_generics: Send + Sync {
882 #(#client_trait_methods)*
883 }
884 };
885
886 TokenStream::from(expanded)
887}
888
889#[proc_macro_derive(HandleClient, attributes(log_level))]
892pub fn derive_handle_client(input: TokenStream) -> TokenStream {
893 derive_client(input, true)
894}
895
896#[proc_macro_derive(RefClient, attributes(log_level))]
899pub fn derive_ref_client(input: TokenStream) -> TokenStream {
900 derive_client(input, false)
901}
902
903fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
904 let input = parse_macro_input!(input as DeriveInput);
905 let name = input.ident.clone();
906
907 let messages = match parse_messages(input.clone()) {
908 Ok(messages) => messages,
909 Err(err) => return TokenStream::from(err.to_compile_error()),
910 };
911
912 let mut impl_methods = Vec::new();
914
915 let send_message = if is_handle {
916 quote! { self.send(message)? }
917 } else {
918 quote! { self.send(cx, message)? }
919 };
920 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
921
922 for message in &messages {
923 match message {
924 Message::Call {
925 variant,
926 reply_port,
927 return_type,
928 log_level,
929 } => {
930 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
931 let variant_name_snake = variant.snake_name();
932 let variant_name_snake_deprecated =
933 format_ident!("{}_deprecated", variant_name_snake);
934 let enum_name = variant.enum_name();
935
936 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
937 let constructor = variant.constructor();
938 let log_level = match (&global_log_level, log_level) {
939 (_, Some(local)) => local.clone(),
940 (Some(global), None) => global.clone(),
941 _ => Ident::new("DEBUG", Span::call_site()),
942 };
943 let log_level = if is_handle {
944 quote! {
945 tracing::Level::#log_level
946 }
947 } else {
948 quote! {
949 tracing::Level::TRACE
950 }
951 };
952 let log_message = quote! {
953 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
954 "rpc" => "call",
955 "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
956 "message_type" => stringify!(#enum_name),
957 "variant" => stringify!(#variant_name_snake),
958 ));
959
960 };
961 let open_port = reply_port.open_op();
962 let rx_mod = reply_port.rx_modifier();
963 if reply_port.is_handle {
964 impl_methods.push(quote! {
965 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
966 async fn #variant_name_snake(
967 &self,
968 cx: &impl hyperactor::context::Actor,
969 #(#arg_names: #arg_types),*)
970 -> Result<#return_type, hyperactor::anyhow::Error> {
971 let (#reply_port_arg, #rx_mod reply_receiver) =
972 #open_port::<#return_type>(cx);
973 let message = #constructor;
974 #log_message;
975 #send_message;
976 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
977 }
978
979 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
980 async fn #variant_name_snake_deprecated(
981 &self,
982 cx: &impl hyperactor::context::Actor,
983 #(#arg_names: #arg_types),*)
984 -> Result<#return_type, hyperactor::anyhow::Error> {
985 let (#reply_port_arg, #rx_mod reply_receiver) =
986 #open_port::<#return_type>(cx);
987 let message = #constructor;
988 #log_message;
989 #send_message;
990 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
991 }
992 });
993 } else {
994 impl_methods.push(quote! {
995 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
996 async fn #variant_name_snake(
997 &self,
998 cx: &impl hyperactor::context::Actor,
999 #(#arg_names: #arg_types),*)
1000 -> Result<#return_type, hyperactor::anyhow::Error> {
1001 let (#reply_port_arg, #rx_mod reply_receiver) =
1002 #open_port::<#return_type>(cx);
1003 let #reply_port_arg = #reply_port_arg.bind();
1004 let message = #constructor;
1005 #log_message;
1006 #send_message;
1007 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1008 }
1009
1010 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
1011 async fn #variant_name_snake_deprecated(
1012 &self,
1013 cx: &impl hyperactor::context::Actor,
1014 #(#arg_names: #arg_types),*)
1015 -> Result<#return_type, hyperactor::anyhow::Error> {
1016 let (#reply_port_arg, #rx_mod reply_receiver) =
1017 #open_port::<#return_type>(cx);
1018 let #reply_port_arg = #reply_port_arg.bind();
1019 let message = #constructor;
1020 #log_message;
1021 #send_message;
1022 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1023 }
1024 });
1025 }
1026 }
1027 Message::OneWay { variant, log_level } => {
1028 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1029 let variant_name_snake = variant.snake_name();
1030 let variant_name_snake_deprecated =
1031 format_ident!("{}_deprecated", variant_name_snake);
1032 let enum_name = variant.enum_name();
1033 let constructor = variant.constructor();
1034 let log_level = match (&global_log_level, log_level) {
1035 (_, Some(local)) => local.clone(),
1036 (Some(global), None) => global.clone(),
1037 _ => Ident::new("DEBUG", Span::call_site()),
1038 };
1039 let _log_level = if is_handle {
1040 quote! {
1041 tracing::Level::TRACE
1042 }
1043 } else {
1044 quote! {
1045 tracing::Level::#log_level
1046 }
1047 };
1048 let log_message = quote! {
1049 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1050 "rpc" => "oneway",
1051 "actor_id" => self.actor_id().to_string(),
1052 "message_type" => stringify!(#enum_name),
1053 "variant" => stringify!(#variant_name_snake),
1054 ));
1055 };
1056 impl_methods.push(quote! {
1057 async fn #variant_name_snake(
1058 &self,
1059 cx: &impl hyperactor::context::Actor,
1060 #(#arg_names: #arg_types),*)
1061 -> Result<(), hyperactor::anyhow::Error> {
1062 let message = #constructor;
1063 #log_message;
1064 #send_message;
1065 Ok(())
1066 }
1067
1068 async fn #variant_name_snake_deprecated(
1069 &self,
1070 cx: &impl hyperactor::context::Actor,
1071 #(#arg_names: #arg_types),*)
1072 -> Result<(), hyperactor::anyhow::Error> {
1073 let message = #constructor;
1074 #log_message;
1075 #send_message;
1076 Ok(())
1077 }
1078 });
1079 }
1080 }
1081 }
1082
1083 let trait_name = format_ident!("{}Client", name);
1084
1085 let (_, ty_generics, _) = input.generics.split_for_impl();
1086
1087 let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1089 let mut trait_generics = input.generics.clone();
1090 trait_generics.params.insert(
1091 0,
1092 syn::GenericParam::Type(syn::TypeParam {
1093 ident: actor_ident.clone(),
1094 attrs: vec![],
1095 colon_token: None,
1096 bounds: Punctuated::new(),
1097 eq_token: None,
1098 default: None,
1099 }),
1100 );
1101
1102 for param in trait_generics.type_params_mut() {
1103 if param.ident == actor_ident {
1104 continue;
1105 }
1106 param.bounds.push(syn::parse_quote!(serde::Serialize));
1107 param
1108 .bounds
1109 .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1110 param.bounds.push(syn::parse_quote!(Send));
1111 param.bounds.push(syn::parse_quote!(Sync));
1112 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1113 param.bounds.push(syn::parse_quote!(hyperactor::Named));
1114 }
1115
1116 let (impl_generics, _, _) = trait_generics.split_for_impl();
1117
1118 let expanded = if is_handle {
1119 quote! {
1120 #[hyperactor::async_trait::async_trait]
1121 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1122 where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1123 #(#impl_methods)*
1124 }
1125 }
1126 } else {
1127 quote! {
1128 #[hyperactor::async_trait::async_trait]
1129 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1130 where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1131 #(#impl_methods)*
1132 }
1133 }
1134 };
1135
1136 TokenStream::from(expanded)
1137}
1138
/// Error reported when `#[forward(...)]` is not given exactly one argument.
const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
`forward` expects the message type that is being forwarded

= help: use `#[forward(MessageType)]`
"#};
1144
1145#[proc_macro_attribute]
1147pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
1148 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1149 if attr_args.len() != 1 {
1150 return TokenStream::from(
1151 syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
1152 );
1153 }
1154
1155 let message_type = attr_args.first().unwrap();
1156 let input = parse_macro_input!(item as ItemImpl);
1157
1158 let self_type = match *input.self_ty {
1159 syn::Type::Path(ref type_path) => {
1160 let segment = type_path.path.segments.last().unwrap();
1161 segment.clone() }
1163 _ => {
1164 return TokenStream::from(
1165 syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
1166 .to_compile_error(),
1167 );
1168 }
1169 };
1170
1171 let trait_name = match input.trait_ {
1172 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1173 None => {
1174 return TokenStream::from(
1175 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1176 .to_compile_error(),
1177 );
1178 }
1179 };
1180
1181 let expanded = quote! {
1182 #input
1183
1184 #[hyperactor::async_trait::async_trait]
1185 impl hyperactor::Handler<#message_type> for #self_type {
1186 async fn handle(
1187 &mut self,
1188 cx: &hyperactor::Context<Self>,
1189 message: #message_type,
1190 ) -> hyperactor::anyhow::Result<()> {
1191 <Self as #trait_name>::handle(self, cx, message).await
1192 }
1193 }
1194 };
1195
1196 TokenStream::from(expanded)
1197}
1198
1199#[proc_macro_attribute]
1212pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1213 let args =
1214 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1215 let input = parse_macro_input!(input as ItemFn);
1216 let output = quote! {
1217 #[hyperactor::tracing::instrument(err, skip_all, #args)]
1218 #input
1219 };
1220
1221 TokenStream::from(output)
1222}
1223
1224#[proc_macro_attribute]
1235pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1236 let args =
1237 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1238 let input = parse_macro_input!(input as ItemFn);
1239
1240 let output = quote! {
1241 #[hyperactor::tracing::instrument(skip_all, #args)]
1242 #input
1243 };
1244
1245 TokenStream::from(output)
1246}
1247
1248#[proc_macro_derive(Named, attributes(named))]
1261pub fn derive_named(input: TokenStream) -> TokenStream {
1262 let input = parse_macro_input!(input as DeriveInput);
1264 let struct_name = &input.ident;
1265
1266 let mut typename = quote! {
1267 concat!(std::module_path!(), "::", stringify!(#struct_name))
1268 };
1269
1270 let type_params: Vec<_> = input.generics.type_params().collect();
1271 let has_generics = !type_params.is_empty();
1272 let mut register = !has_generics;
1274
1275 for attr in &input.attrs {
1276 if attr.path().is_ident("named") {
1277 if let Ok(meta) = attr.parse_args_with(
1278 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
1279 ) {
1280 for item in meta {
1281 if let Meta::NameValue(MetaNameValue {
1282 path,
1283 value: Expr::Lit(expr_lit),
1284 ..
1285 }) = item
1286 {
1287 if path.is_ident("name") {
1288 if let Lit::Str(name) = expr_lit.lit {
1289 typename = quote! { #name };
1290 } else {
1291 return TokenStream::from(
1292 syn::Error::new_spanned(path, "invalid name")
1293 .to_compile_error(),
1294 );
1295 }
1296 } else if path.is_ident("register") {
1297 if let Lit::Bool(flag) = expr_lit.lit {
1298 register = flag.value;
1299 } else {
1300 return TokenStream::from(
1301 syn::Error::new_spanned(path, "invalid registration flag")
1302 .to_compile_error(),
1303 );
1304 }
1305 } else {
1306 return TokenStream::from(
1307 syn::Error::new_spanned(
1308 path,
1309 "unsupported attribute (only `name` or `register` is supported)",
1310 )
1311 .to_compile_error(),
1312 );
1313 }
1314 }
1315 }
1316 }
1317 }
1318 }
1319
1320 let mut generics_with_bounds = input.generics.clone();
1322 if has_generics {
1323 for param in generics_with_bounds.type_params_mut() {
1324 param
1325 .bounds
1326 .push(syn::parse_quote!(hyperactor::data::Named));
1327 }
1328 }
1329 let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1330
1331 let (typename_impl, typehash_impl) = if has_generics {
1333 let placeholders = vec!["{}"; type_params.len()].join(", ");
1335 let placeholders_format_string = format!("<{}>", placeholders);
1336 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#struct_name), #placeholders_format_string) };
1337
1338 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1339 (
1340 quote! {
1341 hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1342 },
1343 quote! {
1344 hyperactor::cityhasher::hash(Self::typename())
1345 },
1346 )
1347 } else {
1348 (
1349 typename,
1350 quote! {
1351 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1352 hyperactor::cityhasher::hash(<#struct_name as hyperactor::data::Named>::typename())
1353 });
1354 *TYPEHASH
1355 },
1356 )
1357 };
1358
1359 let arm_impl = match &input.data {
1361 Data::Enum(DataEnum { variants, .. }) => {
1362 let match_arms = variants.iter().map(|v| {
1363 let variant_name = &v.ident;
1364 let variant_str = variant_name.to_string();
1365 match &v.fields {
1366 Fields::Unit => quote! { Self::#variant_name => Some(#variant_str) },
1367 Fields::Unnamed(_) => quote! { Self::#variant_name(..) => Some(#variant_str) },
1368 Fields::Named(_) => quote! { Self::#variant_name { .. } => Some(#variant_str) },
1369 }
1370 });
1371 quote! {
1372 fn arm(&self) -> Option<&'static str> {
1373 match self {
1374 #(#match_arms,)*
1375 }
1376 }
1377 }
1378 }
1379 _ => quote! {},
1380 };
1381
1382 let registration = if register {
1389 quote! {
1390 hyperactor::register_type!(#struct_name);
1391 }
1392 } else {
1393 quote! {
1394 }
1396 };
1397
1398 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
1399 let expanded = quote! {
1402 impl #impl_generics_with_bounds hyperactor::data::Named for #struct_name #ty_generics #where_clause {
1403 fn typename() -> &'static str { #typename_impl }
1404 fn typehash() -> u64 { #typehash_impl }
1405 #arm_impl
1406 }
1407
1408 #registration
1409 };
1410
1411 TokenStream::from(expanded)
1412}
1413
1414struct HandlerSpec {
1415 ty: Type,
1416 cast: bool,
1417}
1418
1419impl Parse for HandlerSpec {
1420 fn parse(input: ParseStream) -> syn::Result<Self> {
1421 let ty: Type = input.parse()?;
1422
1423 if input.peek(syn::token::Brace) {
1424 let content;
1425 syn::braced!(content in input);
1426 let key: Ident = content.parse()?;
1427 content.parse::<Token![=]>()?;
1428 let expr: Expr = content.parse()?;
1429
1430 let cast = if key == "cast" {
1431 if let Expr::Lit(ExprLit {
1432 lit: Lit::Bool(b), ..
1433 }) = expr
1434 {
1435 b.value
1436 } else {
1437 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1438 }
1439 } else {
1440 return Err(syn::Error::new_spanned(
1441 key,
1442 "unsupported field (expected `cast`)",
1443 ));
1444 };
1445
1446 Ok(HandlerSpec { ty, cast })
1447 } else if input.is_empty() || input.peek(Token![,]) {
1448 Ok(HandlerSpec { ty, cast: false })
1449 } else {
1450 let unexpected: proc_macro2::TokenTree = input.parse()?;
1452 Err(syn::Error::new_spanned(
1453 unexpected,
1454 "unexpected token after type — expected `{ ... }` or nothing",
1455 ))
1456 }
1457 }
1458}
1459
1460impl HandlerSpec {
1461 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1462 let mut tys = Vec::new();
1463 for HandlerSpec { ty, cast } in handlers {
1464 if cast {
1465 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1466 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1467 tys.push(wrapped_ty);
1468 }
1469 tys.push(ty);
1470 }
1471 tys
1472 }
1473}
1474
1475struct ExportAttr {
1477 spawn: bool,
1478 handlers: Vec<HandlerSpec>,
1479}
1480
1481impl Parse for ExportAttr {
1482 fn parse(input: ParseStream) -> syn::Result<Self> {
1483 let mut spawn = false;
1484 let mut handlers: Vec<HandlerSpec> = vec![];
1485
1486 while !input.is_empty() {
1487 let key: Ident = input.parse()?;
1488 input.parse::<Token![=]>()?;
1489
1490 if key == "spawn" {
1491 let expr: Expr = input.parse()?;
1492 if let Expr::Lit(ExprLit {
1493 lit: Lit::Bool(b), ..
1494 }) = expr
1495 {
1496 spawn = b.value;
1497 } else {
1498 return Err(syn::Error::new_spanned(
1499 expr,
1500 "expected boolean for `spawn`",
1501 ));
1502 }
1503 } else if key == "handlers" {
1504 let content;
1505 bracketed!(content in input);
1506 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1507 handlers = raw_handlers.into_iter().collect();
1508 } else {
1509 return Err(syn::Error::new_spanned(
1510 key,
1511 "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1512 ));
1513 }
1514
1515 let _ = input.parse::<Token![,]>();
1517 }
1518
1519 Ok(ExportAttr { spawn, handlers })
1520 }
1521}
1522
1523#[proc_macro_attribute]
1550pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1551 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1552 let data_type_name = &input.ident;
1553 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1554
1555 let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1556 let tys = HandlerSpec::add_indexed(handlers);
1557
1558 let mut handles = Vec::new();
1559 let mut bindings = Vec::new();
1560 let mut type_registrations = Vec::new();
1561
1562 for ty in &tys {
1563 handles.push(quote! {
1564 impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
1565 });
1566 bindings.push(quote! {
1567 ports.bind::<#ty>();
1568 });
1569 type_registrations.push(quote! {
1570 hyperactor::register_type!(#ty);
1571 });
1572 }
1573
1574 let mut expanded = quote! {
1575 #input
1576
1577 impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
1578
1579 #(#handles)*
1580
1581 #(#type_registrations)*
1582
1583 impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1585
1586 impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
1587 fn bind(ports: &hyperactor::proc::Ports<Self>) {
1588 #(#bindings)*
1589 }
1590 }
1591
1592 impl #impl_generics hyperactor::data::Named for #data_type_name #ty_generics #where_clause {
1594 fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
1595 }
1596 };
1597
1598 if spawn {
1599 expanded.extend(quote! {
1600 hyperactor::remote!(#data_type_name);
1601 });
1602 }
1603
1604 TokenStream::from(expanded)
1605}
1606
1607struct BehaviorInput {
1609 behavior: Ident,
1610 generics: syn::Generics,
1611 handlers: Vec<HandlerSpec>,
1612}
1613
1614impl syn::parse::Parse for BehaviorInput {
1615 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1616 let behavior: Ident = input.parse()?;
1617 let generics: syn::Generics = input.parse()?;
1618 let _: Token![,] = input.parse()?;
1619 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1620 let handlers = raw_handlers.into_iter().collect();
1621 Ok(BehaviorInput {
1622 behavior,
1623 generics,
1624 handlers,
1625 })
1626 }
1627}
1628
1629#[proc_macro]
1654pub fn behavior(input: TokenStream) -> TokenStream {
1655 let BehaviorInput {
1656 behavior,
1657 generics,
1658 handlers,
1659 } = parse_macro_input!(input as BehaviorInput);
1660 let tys = HandlerSpec::add_indexed(handlers);
1661
1662 let mut bounded_generics = generics.clone();
1664 for param in bounded_generics.type_params_mut() {
1665 param.bounds.push(syn::parse_quote!(hyperactor::Named));
1666 param.bounds.push(syn::parse_quote!(serde::Serialize));
1667 param.bounds.push(syn::parse_quote!(std::marker::Send));
1668 param.bounds.push(syn::parse_quote!(std::marker::Sync));
1669 param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1670 let lifetime =
1673 syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1674 param
1675 .bounds
1676 .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1677 }
1678
1679 let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1681
1682 let mut binds_generics = bounded_generics.clone();
1684 binds_generics.params.insert(
1685 0,
1686 syn::GenericParam::Type(syn::TypeParam {
1687 attrs: vec![],
1688 ident: Ident::new("A", proc_macro2::Span::call_site()),
1689 colon_token: None,
1690 bounds: Punctuated::new(),
1691 eq_token: None,
1692 default: None,
1693 }),
1694 );
1695 let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1696
1697 let type_params: Vec<_> = bounded_generics.type_params().collect();
1699 let has_generics = !type_params.is_empty();
1700
1701 let (typename_impl, typehash_impl) = if has_generics {
1702 let placeholders = vec!["{}"; type_params.len()].join(", ");
1704 let placeholders_format_string = format!("<{}>", placeholders);
1705 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1706
1707 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1708 (
1709 quote! {
1710 hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1711 },
1712 quote! {
1713 hyperactor::cityhasher::hash(Self::typename())
1714 },
1715 )
1716 } else {
1717 (
1718 quote! {
1719 concat!(std::module_path!(), "::", stringify!(#behavior))
1720 },
1721 quote! {
1722 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1723 hyperactor::cityhasher::hash(<#behavior as hyperactor::data::Named>::typename())
1724 });
1725 *TYPEHASH
1726 },
1727 )
1728 };
1729
1730 let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1731
1732 let expanded = quote! {
1733 #[doc = "The generated behavior struct."]
1734 #[derive(Debug, serde::Serialize, serde::Deserialize)]
1735 pub struct #behavior #impl_generics #where_clause {
1736 _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1737 }
1738
1739 impl #impl_generics hyperactor::Named for #behavior #ty_generics #where_clause {
1740 fn typename() -> &'static str {
1741 #typename_impl
1742 }
1743
1744 fn typehash() -> u64 {
1745 #typehash_impl
1746 }
1747 }
1748
1749 impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1750
1751 impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1752 where
1753 A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1754 #where_clause
1755 {
1756 fn bind(ports: &hyperactor::proc::Ports<A>) {
1757 #(
1758 ports.bind::<#tys>();
1759 )*
1760 }
1761 }
1762
1763 #(
1764 impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1765 )*
1766 };
1767
1768 TokenStream::from(expanded)
1769}
1770
1771fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1772 let mut is_included = false;
1773 for attr in &field.attrs {
1774 if attr.path().is_ident("binding") {
1775 attr.parse_nested_meta(|meta| {
1777 if meta.path.is_ident("include") {
1778 is_included = true;
1779 Ok(())
1780 } else {
1781 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1782 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1783 }
1784 })?
1785 }
1786 }
1787 Ok(is_included)
1788}
1789
1790enum FieldAccessor {
1795 Named(Ident),
1796 Unnamed(Index),
1797}
1798
1799struct ParsedField {
1801 accessor: FieldAccessor,
1802 ty: Type,
1803 included: bool,
1804}
1805
1806impl From<&ParsedField> for (Ident, Type) {
1807 fn from(field: &ParsedField) -> Self {
1808 let field_ident = match &field.accessor {
1809 FieldAccessor::Named(ident) => ident.clone(),
1810 FieldAccessor::Unnamed(i) => {
1811 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1812 }
1813 };
1814 (field_ident, field.ty.clone())
1815 }
1816}
1817
1818fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1819 match fields {
1820 Fields::Named(named) => named
1821 .named
1822 .iter()
1823 .map(|f| {
1824 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1825 Ok(ParsedField {
1826 accessor,
1827 ty: f.ty.clone(),
1828 included: include_in_bind_unbind(f)?,
1829 })
1830 })
1831 .collect(),
1832 Fields::Unnamed(unnamed) => unnamed
1833 .unnamed
1834 .iter()
1835 .enumerate()
1836 .map(|(i, f)| {
1837 let accessor = FieldAccessor::Unnamed(Index::from(i));
1838 Ok(ParsedField {
1839 accessor,
1840 ty: f.ty.clone(),
1841 included: include_in_bind_unbind(f)?,
1842 })
1843 })
1844 .collect(),
1845 Fields::Unit => Ok(Vec::new()),
1846 }
1847}
1848
1849fn gen_struct_items<F>(
1850 fields: &Fields,
1851 make_item: F,
1852 is_mutable: bool,
1853) -> syn::Result<Vec<proc_macro2::TokenStream>>
1854where
1855 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1856{
1857 let borrow = if is_mutable {
1858 quote! { &mut }
1859 } else {
1860 quote! { & }
1861 };
1862 let items: Vec<_> = collect_all_fields(fields)?
1863 .into_iter()
1864 .filter(|f| f.included)
1865 .map(
1866 |ParsedField {
1867 accessor,
1868 ty,
1869 included,
1870 }| {
1871 assert!(included);
1872 let field_accessor = match accessor {
1873 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1874 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1875 };
1876 make_item(field_accessor, ty)
1877 },
1878 )
1879 .collect();
1880 Ok(items)
1881}
1882
1883fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1893 all_fields
1894 .iter()
1895 .map(
1896 |ParsedField {
1897 accessor,
1898 ty: _,
1899 included,
1900 }| {
1901 match accessor {
1902 FieldAccessor::Named(ident) => {
1903 if *included {
1904 quote! { #ident }
1905 } else {
1906 quote! { #ident: _ }
1907 }
1908 }
1909 FieldAccessor::Unnamed(i) => {
1910 if *included {
1911 let ident = Ident::new(
1912 &format!("f{}", i.index),
1913 proc_macro2::Span::call_site(),
1914 );
1915 quote! { #ident }
1916 } else {
1917 quote! { _ }
1918 }
1919 }
1920 }
1921 },
1922 )
1923 .collect()
1924}
1925
1926fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1933where
1934 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1935{
1936 data.variants
1937 .iter()
1938 .map(|variant| {
1939 let name = &variant.ident;
1940 let all_fields = collect_all_fields(&variant.fields)?;
1941 let field_accessors = gen_enum_field_accessors(&all_fields);
1942 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1943 let items = included_fields
1944 .iter()
1945 .map(|f| {
1946 let (accessor, ty) = <(Ident, Type)>::from(*f);
1947 make_item(quote! { #accessor }, ty)
1948 })
1949 .collect::<Vec<_>>();
1950
1951 Ok(match &variant.fields {
1952 Fields::Named(_) => {
1953 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1954 }
1955 Fields::Unnamed(_) => {
1956 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1957 }
1958 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1959 })
1960 })
1961 .collect()
1962}
1963
1964#[proc_macro_derive(Bind, attributes(binding))]
2046pub fn derive_bind(input: TokenStream) -> TokenStream {
2047 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2048 quote! {
2049 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
2050 }
2051 }
2052
2053 let input = parse_macro_input!(input as DeriveInput);
2054 let name = &input.ident;
2055 let inner = match &input.data {
2056 Data::Struct(DataStruct { fields, .. }) => {
2057 match gen_struct_items(fields, make_item, true) {
2058 Ok(collects) => {
2059 quote! { #(#collects)* }
2060 }
2061 Err(e) => {
2062 return TokenStream::from(e.to_compile_error());
2063 }
2064 }
2065 }
2066 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2067 Ok(arms) => {
2068 quote! { match self { #(#arms),* } }
2069 }
2070 Err(e) => {
2071 return TokenStream::from(e.to_compile_error());
2072 }
2073 },
2074 _ => panic!("Bind can only be derived for structs and enums"),
2075 };
2076 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2077 let expand = quote! {
2078 #[automatically_derived]
2079 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
2080 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2081 #inner
2082 Ok(())
2083 }
2084 }
2085 };
2086 TokenStream::from(expand)
2087}
2088
2089#[proc_macro_derive(Unbind, attributes(binding))]
2103pub fn derive_unbind(input: TokenStream) -> TokenStream {
2104 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2105 quote! {
2106 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
2107 }
2108 }
2109
2110 let input = parse_macro_input!(input as DeriveInput);
2111 let name = &input.ident;
2112 let inner = match &input.data {
2113 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
2114 {
2115 Ok(collects) => {
2116 quote! { #(#collects)* }
2117 }
2118 Err(e) => {
2119 return TokenStream::from(e.to_compile_error());
2120 }
2121 },
2122 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2123 Ok(arms) => {
2124 quote! { match self { #(#arms),* } }
2125 }
2126 Err(e) => {
2127 return TokenStream::from(e.to_compile_error());
2128 }
2129 },
2130 _ => panic!("Unbind can only be derived for structs and enums"),
2131 };
2132 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2133 let expand = quote! {
2134 #[automatically_derived]
2135 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
2136 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2137 #inner
2138 Ok(())
2139 }
2140 }
2141 };
2142 TokenStream::from(expand)
2143}
2144
2145#[proc_macro_derive(Actor, attributes(actor))]
2192pub fn derive_actor(input: TokenStream) -> TokenStream {
2193 let input = parse_macro_input!(input as DeriveInput);
2194 let name = &input.ident;
2195 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2196
2197 let is_passthrough = input.attrs.iter().any(|attr| {
2198 if attr.path().is_ident("actor") {
2199 if let Ok(meta) = attr.parse_args_with(
2200 syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
2201 ) {
2202 return meta.iter().any(|ident| ident == "passthrough");
2203 }
2204 }
2205 false
2206 });
2207
2208 let expanded = if is_passthrough {
2209 quote! {
2210 #[hyperactor::async_trait::async_trait]
2211 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2212 type Params = Self;
2213
2214 async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
2215 Ok(instance)
2216 }
2217 }
2218 }
2219 } else {
2220 quote! {
2221 #[hyperactor::async_trait::async_trait]
2222 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2223 type Params = ();
2224
2225 async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
2226 Ok(Default::default())
2227 }
2228 }
2229 }
2230 };
2231
2232 TokenStream::from(expanded)
2233}
2234
2235fn parse_observe_function(
2237 attr: TokenStream,
2238 item: TokenStream,
2239) -> syn::Result<(ItemFn, String, String)> {
2240 let input = syn::parse::<ItemFn>(item)?;
2241
2242 if input.sig.asyncness.is_none() {
2243 return Err(syn::Error::new(
2244 input.sig.span(),
2245 "observe macros can only be applied to async functions",
2246 ));
2247 }
2248
2249 let fn_name_str = input.sig.ident.to_string();
2250 let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
2251
2252 Ok((input, fn_name_str, module_name_str))
2253}
2254
2255fn create_telemetry_setup(
2257 module_name_str: &str,
2258 fn_name_str: &str,
2259 include_error: bool,
2260) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2261 let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2262 let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2263
2264 let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2265
2266 let error_ident = if include_error {
2267 Some(Ident::new(
2268 "error",
2269 Span::from(proc_macro::Span::def_site()),
2270 ))
2271 } else {
2272 None
2273 };
2274
2275 let error_declaration = if let Some(ref error_ident) = error_ident {
2276 quote! {
2277 hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2278 }
2279 } else {
2280 quote! {}
2281 };
2282
2283 let setup_code = quote! {
2284 use hyperactor_telemetry;
2285 hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2286 hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2287 #error_declaration
2288 };
2289
2290 (latency_ident, success_ident, error_ident, setup_code)
2291}
2292
2293#[proc_macro_attribute]
2313pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2314 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2315 Ok(parsed) => parsed,
2316 Err(err) => return err.to_compile_error().into(),
2317 };
2318
2319 let fn_name = &input.sig.ident;
2320 let vis = &input.vis;
2321 let args = &input.sig.inputs;
2322 let return_type = &input.sig.output;
2323 let body = &input.block;
2324 let attrs = &input.attrs;
2325 let generics = &input.sig.generics;
2326
2327 let (latency_ident, success_ident, error_ident, telemetry_setup) =
2328 create_telemetry_setup(&module_name_str, &fn_name_str, true);
2329 let error_ident = error_ident.unwrap();
2330
2331 let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2332
2333 let expanded = quote! {
2335 #(#attrs)*
2336 #vis async fn #fn_name #generics(#args) #return_type {
2337 #telemetry_setup
2338
2339 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2340 let _timer = #latency_ident.start(kv_pairs);
2341
2342 let #result_ident = async #body.await;
2343
2344 match &#result_ident {
2345 Ok(_) => {
2346 #success_ident.add(
2347 1,
2348 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2349 );
2350 }
2351 Err(_) => {
2352 #error_ident.add(
2353 1,
2354 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2355 );
2356 }
2357 }
2358
2359 #result_ident
2360 }
2361 };
2362
2363 expanded.into()
2364}
2365
2366#[proc_macro_attribute]
2385pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2386 let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2387 Ok(parsed) => parsed,
2388 Err(err) => return err.to_compile_error().into(),
2389 };
2390
2391 let fn_name = &input.sig.ident;
2392 let vis = &input.vis;
2393 let args = &input.sig.inputs;
2394 let return_type = &input.sig.output;
2395 let body = &input.block;
2396 let attrs = &input.attrs;
2397 let generics = &input.sig.generics;
2398
2399 let (latency_ident, success_ident, _, telemetry_setup) =
2400 create_telemetry_setup(&module_name_str, &fn_name_str, false);
2401
2402 let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2403
2404 let expanded = quote! {
2406 #(#attrs)*
2407 #vis async fn #fn_name #generics(#args) #return_type {
2408 #telemetry_setup
2409
2410 let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2411 let _timer = #latency_ident.start(kv_pairs);
2412
2413 let #return_ident = async #body.await;
2414
2415 #success_ident.add(
2416 1,
2417 hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2418 );
2419 #return_ident
2420 }
2421 };
2422
2423 expanded.into()
2424}
2425
2426#[proc_macro_derive(AttrValue)]
2460pub fn derive_attr_value(input: TokenStream) -> TokenStream {
2461 let input = parse_macro_input!(input as DeriveInput);
2462 let name = &input.ident;
2463
2464 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2465
2466 TokenStream::from(quote! {
2467 impl #impl_generics hyperactor::attrs::AttrValue for #name #ty_generics #where_clause {
2468 fn display(&self) -> String {
2469 self.to_string()
2470 }
2471
2472 fn parse(value: &str) -> Result<Self, anyhow::Error> {
2473 value.parse().map_err(|e| anyhow::anyhow!("failed to parse {}: {}", stringify!(#name), e))
2474 }
2475 }
2476 })
2477}