1#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Meta;
39use syn::MetaNameValue;
40use syn::Token;
41use syn::Type;
42use syn::bracketed;
43use syn::parse::Parse;
44use syn::parse::ParseStream;
45use syn::parse_macro_input;
46use syn::punctuated::Punctuated;
47use syn::spanned::Spanned;
48
49const REPLY_VARIANT_ERROR: &str = indoc! {r#"
50`call` message expects a typed `OncePortRef` or `OncePortHandle` argument in the last position
51
52= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
53= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
54"#};
55
56const REPLY_USAGE_ERROR: &str = indoc! {r#"
57`call` message expects at most one `reply` argument
58
59= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
60= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
61"#};
62
/// Per-field marker recorded while parsing a message-enum variant.
enum FieldFlag {
    /// An ordinary argument field.
    None,
    /// The field is marked `#[reply]`: the port on which a call's result is sent.
    Reply,
}
67
68enum Variant {
70 Named {
72 enum_name: Ident,
73 name: Ident,
74 field_names: Vec<Ident>,
75 field_types: Vec<Type>,
76 field_flags: Vec<FieldFlag>,
77 },
78 Anon {
80 enum_name: Ident,
81 name: Ident,
82 field_types: Vec<Type>,
83 field_flags: Vec<FieldFlag>,
84 },
85}
86
87impl Variant {
88 fn len(&self) -> usize {
90 self.field_types().len()
91 }
92
93 fn enum_name(&self) -> &Ident {
95 match self {
96 Variant::Named { enum_name, .. } => enum_name,
97 Variant::Anon { enum_name, .. } => enum_name,
98 }
99 }
100
101 fn name(&self) -> &Ident {
103 match self {
104 Variant::Named { name, .. } => name,
105 Variant::Anon { name, .. } => name,
106 }
107 }
108
109 fn snake_name(&self) -> Ident {
111 Ident::new(
112 &self.name().to_string().to_case(Case::Snake),
113 self.name().span(),
114 )
115 }
116
117 fn qualified_name(&self) -> proc_macro2::TokenStream {
119 let enum_name = self.enum_name();
120 let name = self.name();
121 quote! { #enum_name::#name }
122 }
123
124 fn field_names(&self) -> Vec<Ident> {
127 match self {
128 Variant::Named { field_names, .. } => field_names.clone(),
129 Variant::Anon { field_types, .. } => (0usize..field_types.len())
130 .map(|idx| format_ident!("arg{}", idx))
131 .collect(),
132 }
133 }
134
135 fn field_types(&self) -> &Vec<Type> {
137 match self {
138 Variant::Named { field_types, .. } => field_types,
139 Variant::Anon { field_types, .. } => field_types,
140 }
141 }
142
143 fn field_flags(&self) -> &Vec<FieldFlag> {
145 match self {
146 Variant::Named { field_flags, .. } => field_flags,
147 Variant::Anon { field_flags, .. } => field_flags,
148 }
149 }
150
151 fn constructor(&self) -> proc_macro2::TokenStream {
153 let qualified_name = self.qualified_name();
154 let field_names = self.field_names();
155 match self {
156 Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
157 Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
158 }
159 }
160}
161
162#[allow(clippy::large_enum_variant)]
165enum Message {
166 Call {
169 variant: Variant,
170 reply_port_is_handle: bool,
172 return_type: Type,
174 log_level: Option<Ident>,
176 },
177 OneWay {
178 variant: Variant,
179 log_level: Option<Ident>,
181 },
182}
183
184impl Message {
185 fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
186 match &variant
187 .field_flags()
188 .iter()
189 .zip(variant.field_types())
190 .filter_map(|(flag, ty)| match flag {
191 FieldFlag::Reply => Some(ty),
192 FieldFlag::None => None,
193 })
194 .collect::<Vec<&Type>>()[..]
195 {
196 [] => Ok(Self::OneWay { variant, log_level }),
197 [reply_port_ty] => {
198 let syn::Type::Path(type_path) = reply_port_ty else {
199 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
200 };
201 let Some(last_segment) = type_path.path.segments.last() else {
202 return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
203 };
204 if last_segment.ident != "OncePortRef" && last_segment.ident != "OncePortHandle" {
205 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
206 }
207 let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
208 return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
209 };
210 let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
211 return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
212 };
213 let reply_port_is_handle = last_segment.ident == "OncePortHandle";
214 let return_type = return_ty.clone();
215 Ok(Self::Call {
216 variant,
217 reply_port_is_handle,
218 return_type,
219 log_level,
220 })
221 }
222 _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
223 }
224 }
225
226 fn args(&self) -> Vec<(Ident, Type)> {
228 match self {
229 Message::Call { variant, .. } => variant
230 .field_names()
231 .into_iter()
232 .zip(variant.field_types().clone())
233 .take(variant.len() - 1)
234 .collect(),
235 Message::OneWay { variant, .. } => variant
236 .field_names()
237 .into_iter()
238 .zip(variant.field_types().clone())
239 .collect(),
240 }
241 }
242
243 fn variant(&self) -> &Variant {
244 match self {
245 Message::Call { variant, .. } => variant,
246 Message::OneWay { variant, .. } => variant,
247 }
248 }
249
250 fn reply_port_position(&self) -> Option<usize> {
251 self.variant()
252 .field_flags()
253 .iter()
254 .position(|flag| matches!(flag, FieldFlag::Reply))
255 }
256
257 fn reply_port_arg(&self) -> Option<(Ident, Type)> {
259 match self {
260 Message::Call { variant, .. } => {
261 let pos = self.reply_port_position()?;
262 Some((
263 variant.field_names()[pos].clone(),
264 variant.field_types()[pos].clone(),
265 ))
266 }
267 Message::OneWay { .. } => None,
268 }
269 }
270}
271
272fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
273 let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
274 Some(attr) => {
275 let Ok(meta) = attr.meta.require_list() else {
276 return Err(syn::Error::new(
277 Span::call_site(),
278 indoc! {"
279 `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
280
281 = help use `#[log_level(info)]` or `#[log_level(error)]`
282 "},
283 ));
284 };
285 let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
286 if parsed.len() != 1 {
287 return Err(syn::Error::new(
288 Span::call_site(),
289 indoc! {"
290 `log_level` attribute must specify exactly one level
291
292 = help use `#[log_level(warn)]` or `#[log_level(info)]`
293 "},
294 ));
295 };
296 Some(parsed.first().unwrap().to_string())
297 }
298 None => None,
299 };
300
301 if level.is_none() {
302 return Ok(None);
303 }
304 let level = level.unwrap();
305
306 match level.as_str() {
307 "error" | "warn" | "info" | "debug" | "trace" => {}
308 _ => {
309 return Err(syn::Error::new(
310 Span::call_site(),
311 indoc! {"
312 `log_level` attribute must be one of 'error, warn, info, debug, trace'
313
314 = help use `#[log_level(warn)]` or `#[log_level(info)]`
315 "},
316 ));
317 }
318 }
319
320 Ok(Some(Ident::new(
321 level.to_ascii_uppercase().as_str(),
322 Span::call_site(),
323 )))
324}
325
326fn parse_field_flag(field: &Field) -> FieldFlag {
327 for attr in field.attrs.iter() {
328 match &attr.meta {
329 syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
330 _ => {}
331 }
332 }
333 FieldFlag::None
334}
335
336fn parse_message_enum(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
338 let variants = if let Data::Enum(data_enum) = &input.data {
339 &data_enum.variants
340 } else {
341 return Err(syn::Error::new_spanned(
342 input,
343 "handlers can only be derived for enums",
344 ));
345 };
346
347 let mut messages = Vec::new();
348
349 for variant in variants {
350 let name = variant.ident.clone();
351 let attrs = &variant.attrs;
352
353 let message_variant = match &variant.fields {
354 syn::Fields::Unnamed(fields_) => Variant::Anon {
355 enum_name: input.ident.clone(),
356 name,
357 field_types: fields_
358 .unnamed
359 .iter()
360 .map(|field| field.ty.clone())
361 .collect(),
362 field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
363 },
364 syn::Fields::Named(fields_) => Variant::Named {
365 enum_name: input.ident.clone(),
366 name,
367 field_names: fields_
368 .named
369 .iter()
370 .map(|field| field.ident.clone().unwrap())
371 .collect(),
372 field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
373 field_flags: fields_.named.iter().map(parse_field_flag).collect(),
374 },
375 _ => {
376 return Err(syn::Error::new_spanned(
377 variant,
378 indoc! {r#"
379 `Handler` currently only supports named or tuple struct variants
380
381 = help use `MyCall(Arg1Type, Arg2Type, ..)`,
382 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
383 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
384 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>)`
385 = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
386 = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>)`
387 "#},
388 ));
389 }
390 };
391 let log_level = parse_log_level(attrs)?;
392
393 messages.push(Message::new(
394 variant.fields.span(),
395 message_variant,
396 log_level,
397 )?);
398 }
399
400 Ok(messages)
401}
402
403#[proc_macro_derive(Handler, attributes(reply))]
560pub fn derive_handler(input: TokenStream) -> TokenStream {
561 let input = parse_macro_input!(input as DeriveInput);
562 let name: Ident = input.ident.clone();
563 let (impl_generics, ty_generics, _) = input.generics.split_for_impl();
564
565 let messages = match parse_message_enum(input.clone()) {
566 Ok(messages) => messages,
567 Err(err) => return TokenStream::from(err.to_compile_error()),
568 };
569
570 let mut handler_trait_methods = Vec::new();
572
573 let mut match_arms = Vec::new();
575
576 let mut client_trait_methods = Vec::new();
578
579 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
580
581 for message in messages {
582 match message {
583 Message::Call {
584 ref variant,
585 ref reply_port_is_handle,
586 ref return_type,
587 ref log_level,
588 } => {
589 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
590 let variant_name_snake = variant.snake_name();
591 let enum_name = variant.enum_name();
592 let _variant_qualified_name = variant.qualified_name();
593 let log_level = match (&global_log_level, log_level) {
594 (_, Some(local)) => local.clone(),
595 (Some(global), None) => global.clone(),
596 _ => Ident::new("DEBUG", Span::call_site()),
597 };
598 let _log_level = if *reply_port_is_handle {
599 quote! {
600 tracing::Level::#log_level
601 }
602 } else {
603 quote! {
604 tracing::Level::TRACE
605 }
606 };
607 let log_message = quote! {
608 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
609 "rpc" => "call",
610 "actor_id" => cx.self_id().to_string(),
611 "message_type" => stringify!(#enum_name),
612 "variant" => stringify!(#variant_name_snake),
613 ));
614 };
615
616 handler_trait_methods.push(quote! {
617 #[doc = "The generated handler method for this enum variant."]
618 async fn #variant_name_snake(
619 &mut self,
620 cx: &hyperactor::Context<Self>,
621 #(#arg_names: #arg_types),*)
622 -> Result<#return_type, hyperactor::anyhow::Error>;
623 });
624
625 client_trait_methods.push(quote! {
626 #[doc = "The generated client method for this enum variant."]
627 async fn #variant_name_snake(
628 &self,
629 caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
630 #(#arg_names: #arg_types),*)
631 -> Result<#return_type, hyperactor::anyhow::Error>;
632 });
633
634 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
635 let constructor = variant.constructor();
636 let result_ident = Ident::new("result", Span::mixed_site());
637 let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
638 if *reply_port_is_handle {
639 match_arms.push(quote! {
640 #constructor => {
641 #log_message
642 #construct_result_future
645 #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
646 }
647 });
648 } else {
649 match_arms.push(quote! {
650 #constructor => {
651 #log_message
652 #construct_result_future
655 #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
656 }
657 });
658 }
659 }
660 Message::OneWay {
661 ref variant,
662 ref log_level,
663 } => {
664 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
665 let variant_name_snake = variant.snake_name();
666 let enum_name = variant.enum_name();
667 let log_level = match (&global_log_level, log_level) {
668 (_, Some(local)) => local.clone(),
669 (Some(global), None) => global.clone(),
670 _ => Ident::new("TRACE", Span::call_site()),
671 };
672 let _log_level = quote! {
673 tracing::Level::#log_level
674 };
675 let log_message = quote! {
676 hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
677 "rpc" => "call",
678 "actor_id" => cx.self_id().to_string(),
679 "message_type" => stringify!(#enum_name),
680 "variant" => stringify!(#variant_name_snake),
681 ));
682 };
683
684 handler_trait_methods.push(quote! {
685 #[doc = "The generated handler method for this enum variant."]
686 async fn #variant_name_snake(
687 &mut self,
688 cx: &hyperactor::Context<Self>,
689 #(#arg_names: #arg_types),*)
690 -> Result<(), hyperactor::anyhow::Error>;
691 });
692
693 client_trait_methods.push(quote! {
694 #[doc = "The generated client method for this enum variant."]
695 async fn #variant_name_snake(
696 &self,
697 caps: &impl hyperactor::cap::CanSend,
698 #(#arg_names: #arg_types),*)
699 -> Result<(), hyperactor::anyhow::Error>;
700 });
701
702 let constructor = variant.constructor();
703
704 match_arms.push(quote! {
705 #constructor => {
706 #log_message
707 self.#variant_name_snake(cx, #(#arg_names),*).await
708 },
709 });
710 }
711 }
712 }
713
714 let handler_trait_name = format_ident!("{}Handler", name);
715 let client_trait_name = format_ident!("{}Client", name);
716
717 let expanded = quote! {
718 #[doc = "The custom handler trait for this message type."]
719 #[hyperactor::async_trait::async_trait]
720 pub trait #handler_trait_name #impl_generics: hyperactor::Actor + Send + Sync {
721 #(#handler_trait_methods)*
722
723 #[doc = "Handle the next message."]
724 async fn handle(
725 &mut self,
726 cx: &hyperactor::Context<Self>,
727 message: #name #ty_generics,
728 ) -> hyperactor::anyhow::Result<()> {
729 match message {
731 #(#match_arms)*
732 }
733 }
734 }
735
736 #[doc = "The custom client trait for this message type."]
737 #[hyperactor::async_trait::async_trait]
738 pub trait #client_trait_name #impl_generics: Send + Sync {
739 #(#client_trait_methods)*
740 }
741 };
742
743 TokenStream::from(expanded)
744}
745
746#[proc_macro_derive(HandleClient, attributes(log_level))]
749pub fn derive_handle_client(input: TokenStream) -> TokenStream {
750 derive_client(input, true)
751}
752
753#[proc_macro_derive(RefClient, attributes(log_level))]
756pub fn derive_ref_client(input: TokenStream) -> TokenStream {
757 derive_client(input, false)
758}
759
760fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
761 let input = parse_macro_input!(input as DeriveInput);
762 let name = input.ident.clone();
763
764 let messages = match parse_message_enum(input.clone()) {
765 Ok(messages) => messages,
766 Err(err) => return TokenStream::from(err.to_compile_error()),
767 };
768
769 let mut impl_methods = Vec::new();
771
772 let send_message = if is_handle {
773 quote! { self.send(message)? }
774 } else {
775 quote! { self.send(caps, message)? }
776 };
777 let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
778
779 for message in messages {
780 match message {
781 Message::Call {
782 ref variant,
783 ref reply_port_is_handle,
784 ref return_type,
785 ref log_level,
786 } => {
787 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
788 let variant_name_snake = variant.snake_name();
789 let enum_name = variant.enum_name();
790
791 let (reply_port_arg, _) = message.reply_port_arg().unwrap();
792 let constructor = variant.constructor();
793 let log_level = match (&global_log_level, log_level) {
794 (_, Some(local)) => local.clone(),
795 (Some(global), None) => global.clone(),
796 _ => Ident::new("DEBUG", Span::call_site()),
797 };
798 let log_level = if is_handle {
799 quote! {
800 tracing::Level::#log_level
801 }
802 } else {
803 quote! {
804 tracing::Level::TRACE
805 }
806 };
807 let log_message = quote! {
808 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
809 "rpc" => "call",
810 "actor_id" => self.actor_id().to_string(),
811 "message_type" => stringify!(#enum_name),
812 "variant" => stringify!(#variant_name_snake),
813 ));
814
815 };
816 if *reply_port_is_handle {
817 impl_methods.push(quote! {
818 #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
819 async fn #variant_name_snake(
820 &self,
821 caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
822 #(#arg_names: #arg_types),*)
823 -> Result<#return_type, hyperactor::anyhow::Error> {
824 let (#reply_port_arg, reply_receiver) =
825 hyperactor::mailbox::open_once_port::<#return_type>(caps);
826 let message = #constructor;
827 #log_message;
828 #send_message;
829 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
830 }
831 });
832 } else {
833 impl_methods.push(quote! {
834 #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
835 async fn #variant_name_snake(
836 &self,
837 caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
838 #(#arg_names: #arg_types),*)
839 -> Result<#return_type, hyperactor::anyhow::Error> {
840 let (#reply_port_arg, reply_receiver) =
841 hyperactor::mailbox::open_once_port::<#return_type>(caps);
842 let #reply_port_arg = #reply_port_arg.bind();
843 let message = #constructor;
844 #log_message;
845 #send_message;
846 reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
847 }
848 });
849 }
850 }
851 Message::OneWay {
852 ref variant,
853 ref log_level,
854 } => {
855 let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
856 let variant_name_snake = variant.snake_name();
857 let enum_name = variant.enum_name();
858 let constructor = variant.constructor();
859 let log_level = match (&global_log_level, log_level) {
860 (_, Some(local)) => local.clone(),
861 (Some(global), None) => global.clone(),
862 _ => Ident::new("DEBUG", Span::call_site()),
863 };
864 let _log_level = if is_handle {
865 quote! {
866 tracing::Level::TRACE
867 }
868 } else {
869 quote! {
870 tracing::Level::#log_level
871 }
872 };
873 let log_message = quote! {
874 hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
875 "rpc" => "oneway",
876 "actor_id" => self.actor_id().to_string(),
877 "message_type" => stringify!(#enum_name),
878 "variant" => stringify!(#variant_name_snake),
879 ));
880 };
881 impl_methods.push(quote! {
882 async fn #variant_name_snake(
883 &self,
884 caps: &impl hyperactor::cap::CanSend,
885 #(#arg_names: #arg_types),*)
886 -> Result<(), hyperactor::anyhow::Error> {
887 let message = #constructor;
888 #log_message;
889 #send_message;
890 Ok(())
891 }
892 });
893 }
894 }
895 }
896
897 let trait_name = format_ident!("{}Client", name);
898
899 let (_, ty_generics, _) = input.generics.split_for_impl();
900
901 let a_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
903 let mut trait_generics = input.generics.clone();
904 trait_generics.params.insert(
905 0,
906 syn::GenericParam::Type(syn::TypeParam {
907 ident: a_ident.clone(),
908 attrs: vec![],
909 colon_token: None,
910 bounds: Punctuated::new(),
911 eq_token: None,
912 default: None,
913 }),
914 );
915 let (impl_generics, _, _) = trait_generics.split_for_impl();
916
917 let expanded = if is_handle {
918 quote! {
919 #[hyperactor::async_trait::async_trait]
920 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#a_ident>
921 where #a_ident: hyperactor::Handler<#name #ty_generics> {
922 #(#impl_methods)*
923 }
924 }
925 } else {
926 quote! {
927 #[hyperactor::async_trait::async_trait]
928 impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#a_ident>
929 where #a_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
930 #(#impl_methods)*
931 }
932 }
933 };
934
935 TokenStream::from(expanded)
936}
937
/// Diagnostic emitted when `#[forward(...)]` is not given exactly one argument.
const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
`forward` expects the message type that is being forwarded

= help: use `#[forward(MessageType)]`
"#};
943
944#[proc_macro_attribute]
946pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
947 let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
948 if attr_args.len() != 1 {
949 return TokenStream::from(
950 syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
951 );
952 }
953
954 let message_type = attr_args.first().unwrap();
955 let input = parse_macro_input!(item as ItemImpl);
956
957 let self_type = match *input.self_ty {
958 syn::Type::Path(ref type_path) => {
959 let segment = type_path.path.segments.last().unwrap();
960 segment.clone() }
962 _ => {
963 return TokenStream::from(
964 syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
965 .to_compile_error(),
966 );
967 }
968 };
969
970 let trait_name = match input.trait_ {
971 Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
972 None => {
973 return TokenStream::from(
974 syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
975 .to_compile_error(),
976 );
977 }
978 };
979
980 let expanded = quote! {
981 #input
982
983 #[hyperactor::async_trait::async_trait]
984 impl hyperactor::Handler<#message_type> for #self_type {
985 async fn handle(
986 &mut self,
987 cx: &hyperactor::Context<Self>,
988 message: #message_type,
989 ) -> hyperactor::anyhow::Result<()> {
990 <Self as #trait_name>::handle(self, cx, message).await
991 }
992 }
993 };
994
995 TokenStream::from(expanded)
996}
997
998#[proc_macro_attribute]
1011pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1012 let args =
1013 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1014 let input = parse_macro_input!(input as ItemFn);
1015 let output = quote! {
1016 #[hyperactor::tracing::instrument(err, skip_all, #args)]
1017 #input
1018 };
1019
1020 TokenStream::from(output)
1021}
1022
1023#[proc_macro_attribute]
1034pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1035 let args =
1036 parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1037 let input = parse_macro_input!(input as ItemFn);
1038
1039 let output = quote! {
1040 #[hyperactor::tracing::instrument(skip_all, #args)]
1041 #input
1042 };
1043
1044 TokenStream::from(output)
1045}
1046
1047#[proc_macro_derive(Named, attributes(named))]
1052pub fn named_derive(input: TokenStream) -> TokenStream {
1053 let input = parse_macro_input!(input as DeriveInput);
1055 let struct_name = &input.ident;
1056
1057 let mut typename = quote! {
1058 concat!(std::module_path!(), "::", stringify!(#struct_name))
1059 };
1060
1061 for attr in &input.attrs {
1062 if attr.path().is_ident("named") {
1063 if let Ok(meta) = attr.parse_args_with(
1064 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
1065 ) {
1066 for item in meta {
1067 if let Meta::NameValue(MetaNameValue {
1068 path,
1069 value: Expr::Lit(expr_lit),
1070 ..
1071 }) = item
1072 {
1073 if path.is_ident("name") {
1074 if let Lit::Str(name) = expr_lit.lit {
1075 typename = quote! { #name };
1076 }
1077 } else {
1078 return TokenStream::from(
1079 syn::Error::new_spanned(
1080 path,
1081 "unsupported attribute (only `name` is supported)",
1082 )
1083 .to_compile_error(),
1084 );
1085 }
1086 }
1087 }
1088 }
1089 }
1090 }
1091
1092 let type_params: Vec<_> = input.generics.type_params().collect();
1094 let has_generics = !type_params.is_empty();
1095
1096 let mut generics_with_bounds = input.generics.clone();
1098 if has_generics {
1099 for param in generics_with_bounds.type_params_mut() {
1100 param
1101 .bounds
1102 .push(syn::parse_quote!(hyperactor::data::Named));
1103 }
1104 }
1105 let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1106
1107 let (typename_impl, typehash_impl) = if has_generics {
1109 let placeholders = vec!["{}"; type_params.len()].join(", ");
1111 let placeholders_format_string = format!("<{}>", placeholders);
1112 let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#struct_name), #placeholders_format_string) };
1113
1114 let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1115 (
1116 quote! {
1117 hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1118 },
1119 quote! {
1120 hyperactor::cityhasher::hash(Self::typename())
1121 },
1122 )
1123 } else {
1124 (
1125 typename,
1126 quote! {
1127 static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1128 hyperactor::cityhasher::hash(<#struct_name as hyperactor::data::Named>::typename())
1129 });
1130 *TYPEHASH
1131 },
1132 )
1133 };
1134
1135 let arm_impl = match &input.data {
1137 Data::Enum(DataEnum { variants, .. }) => {
1138 let match_arms = variants.iter().map(|v| {
1139 let variant_name = &v.ident;
1140 let variant_str = variant_name.to_string();
1141 match &v.fields {
1142 Fields::Unit => quote! { Self::#variant_name => Some(#variant_str) },
1143 Fields::Unnamed(_) => quote! { Self::#variant_name(..) => Some(#variant_str) },
1144 Fields::Named(_) => quote! { Self::#variant_name { .. } => Some(#variant_str) },
1145 }
1146 });
1147 quote! {
1148 fn arm(&self) -> Option<&'static str> {
1149 match self {
1150 #(#match_arms,)*
1151 }
1152 }
1153 }
1154 }
1155 _ => quote! {},
1156 };
1157
1158 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
1159 let expanded = quote! {
1162 impl #impl_generics_with_bounds hyperactor::data::Named for #struct_name #ty_generics #where_clause {
1163 fn typename() -> &'static str { #typename_impl }
1164 fn typehash() -> u64 { #typehash_impl }
1165 #arm_impl
1166 }
1167 };
1168
1169 TokenStream::from(expanded)
1170}
1171
1172struct HandlerSpec {
1173 ty: Type,
1174 cast: bool,
1175}
1176
1177impl Parse for HandlerSpec {
1178 fn parse(input: ParseStream) -> syn::Result<Self> {
1179 let ty: Type = input.parse()?;
1180
1181 if input.peek(syn::token::Brace) {
1182 let content;
1183 syn::braced!(content in input);
1184 let key: Ident = content.parse()?;
1185 content.parse::<Token![=]>()?;
1186 let expr: Expr = content.parse()?;
1187
1188 let cast = if key == "cast" {
1189 if let Expr::Lit(ExprLit {
1190 lit: Lit::Bool(b), ..
1191 }) = expr
1192 {
1193 b.value
1194 } else {
1195 return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1196 }
1197 } else {
1198 return Err(syn::Error::new_spanned(
1199 key,
1200 "unsupported field (expected `cast`)",
1201 ));
1202 };
1203
1204 Ok(HandlerSpec { ty, cast })
1205 } else if input.is_empty() || input.peek(Token![,]) {
1206 Ok(HandlerSpec { ty, cast: false })
1207 } else {
1208 let unexpected: proc_macro2::TokenTree = input.parse()?;
1210 Err(syn::Error::new_spanned(
1211 unexpected,
1212 "unexpected token after type — expected `{ ... }` or nothing",
1213 ))
1214 }
1215 }
1216}
1217
1218impl HandlerSpec {
1219 fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1220 let mut tys = Vec::new();
1221 for HandlerSpec { ty, cast } in handlers {
1222 if cast {
1223 let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1224 let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1225 tys.push(wrapped_ty);
1226 }
1227 tys.push(ty);
1228 }
1229 tys
1230 }
1231}
1232
1233struct ExportAttr {
1235 spawn: bool,
1236 handlers: Vec<HandlerSpec>,
1237}
1238
1239impl Parse for ExportAttr {
1240 fn parse(input: ParseStream) -> syn::Result<Self> {
1241 let mut spawn = false;
1242 let mut handlers: Vec<HandlerSpec> = vec![];
1243
1244 while !input.is_empty() {
1245 let key: Ident = input.parse()?;
1246 input.parse::<Token![=]>()?;
1247
1248 if key == "spawn" {
1249 let expr: Expr = input.parse()?;
1250 if let Expr::Lit(ExprLit {
1251 lit: Lit::Bool(b), ..
1252 }) = expr
1253 {
1254 spawn = b.value;
1255 } else {
1256 return Err(syn::Error::new_spanned(
1257 expr,
1258 "expected boolean for `spawn`",
1259 ));
1260 }
1261 } else if key == "handlers" {
1262 let content;
1263 bracketed!(content in input);
1264 let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1265 handlers = raw_handlers.into_iter().collect();
1266 } else {
1267 return Err(syn::Error::new_spanned(
1268 key,
1269 "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1270 ));
1271 }
1272
1273 let _ = input.parse::<Token![,]>();
1275 }
1276
1277 Ok(ExportAttr { spawn, handlers })
1278 }
1279}
1280
1281#[proc_macro_attribute]
1308pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1309 let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1310 let data_type_name = &input.ident;
1311
1312 let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1313 let tys = HandlerSpec::add_indexed(handlers);
1314
1315 let mut handles = Vec::new();
1316 let mut bindings = Vec::new();
1317
1318 for ty in &tys {
1319 handles.push(quote! {
1320 impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
1321 });
1322 bindings.push(quote! {
1323 ports.bind::<#ty>();
1324 });
1325 }
1326
1327 let mut expanded = quote! {
1328 #input
1329
1330 impl hyperactor::actor::RemoteActor for #data_type_name {}
1331
1332 #(#handles)*
1333
1334 impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
1336
1337 impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
1338 fn bind(ports: &hyperactor::proc::Ports<Self>) {
1339 #(#bindings)*
1340 }
1341 }
1342
1343 impl hyperactor::data::Named for #data_type_name {
1345 fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
1346 }
1347 };
1348
1349 if spawn {
1350 expanded.extend(quote! {
1351
1352 hyperactor::remote!(#data_type_name);
1353 });
1354 }
1355
1356 TokenStream::from(expanded)
1357}
1358
1359struct AliasInput {
1361 alias: Ident,
1362 handlers: Vec<HandlerSpec>,
1363}
1364
1365impl syn::parse::Parse for AliasInput {
1366 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1367 let alias: Ident = input.parse()?;
1368 let _: Token![,] = input.parse()?;
1369 let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1370 let handlers = raw_handlers.into_iter().collect();
1371 Ok(AliasInput { alias, handlers })
1372 }
1373}
1374
1375#[proc_macro]
1391pub fn alias(input: TokenStream) -> TokenStream {
1392 let AliasInput { alias, handlers } = parse_macro_input!(input as AliasInput);
1393 let tys = HandlerSpec::add_indexed(handlers);
1394
1395 let expanded = quote! {
1396 #[doc = "The generated alias struct."]
1397 #[derive(Debug, Named)]
1398 pub struct #alias;
1399 impl hyperactor::actor::RemoteActor for #alias {}
1400
1401 impl<A> hyperactor::actor::Binds<A> for #alias
1402 where
1403 A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)* {
1404 fn bind(ports: &hyperactor::proc::Ports<A>) {
1405 #(
1406 ports.bind::<#tys>();
1407 )*
1408 }
1409 }
1410
1411 #(
1412 impl hyperactor::actor::RemoteHandles<#tys> for #alias {}
1413 )*
1414 };
1415
1416 TokenStream::from(expanded)
1417}
1418
1419fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1420 let mut is_included = false;
1421 for attr in &field.attrs {
1422 if attr.path().is_ident("binding") {
1423 attr.parse_nested_meta(|meta| {
1425 if meta.path.is_ident("include") {
1426 is_included = true;
1427 Ok(())
1428 } else {
1429 let path = meta.path.to_token_stream().to_string().replace(' ', "");
1430 Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1431 }
1432 })?
1433 }
1434 }
1435 Ok(is_included)
1436}
1437
1438enum FieldAccessor {
1443 Named(Ident),
1444 Unnamed(Index),
1445}
1446
1447struct ParsedField {
1449 accessor: FieldAccessor,
1450 ty: Type,
1451 included: bool,
1452}
1453
1454impl From<&ParsedField> for (Ident, Type) {
1455 fn from(field: &ParsedField) -> Self {
1456 let field_ident = match &field.accessor {
1457 FieldAccessor::Named(ident) => ident.clone(),
1458 FieldAccessor::Unnamed(i) => {
1459 Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1460 }
1461 };
1462 (field_ident, field.ty.clone())
1463 }
1464}
1465
1466fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1467 match fields {
1468 Fields::Named(named) => named
1469 .named
1470 .iter()
1471 .map(|f| {
1472 let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1473 Ok(ParsedField {
1474 accessor,
1475 ty: f.ty.clone(),
1476 included: include_in_bind_unbind(f)?,
1477 })
1478 })
1479 .collect(),
1480 Fields::Unnamed(unnamed) => unnamed
1481 .unnamed
1482 .iter()
1483 .enumerate()
1484 .map(|(i, f)| {
1485 let accessor = FieldAccessor::Unnamed(Index::from(i));
1486 Ok(ParsedField {
1487 accessor,
1488 ty: f.ty.clone(),
1489 included: include_in_bind_unbind(f)?,
1490 })
1491 })
1492 .collect(),
1493 Fields::Unit => Ok(Vec::new()),
1494 }
1495}
1496
1497fn gen_struct_items<F>(
1498 fields: &Fields,
1499 make_item: F,
1500 is_mutable: bool,
1501) -> syn::Result<Vec<proc_macro2::TokenStream>>
1502where
1503 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1504{
1505 let borrow = if is_mutable {
1506 quote! { &mut }
1507 } else {
1508 quote! { & }
1509 };
1510 let items: Vec<_> = collect_all_fields(fields)?
1511 .into_iter()
1512 .filter(|f| f.included)
1513 .map(
1514 |ParsedField {
1515 accessor,
1516 ty,
1517 included,
1518 }| {
1519 assert!(included);
1520 let field_accessor = match accessor {
1521 FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1522 FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1523 };
1524 make_item(field_accessor, ty)
1525 },
1526 )
1527 .collect();
1528 Ok(items)
1529}
1530
1531fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1541 all_fields
1542 .iter()
1543 .map(
1544 |ParsedField {
1545 accessor,
1546 ty: _,
1547 included,
1548 }| {
1549 match accessor {
1550 FieldAccessor::Named(ident) => {
1551 if *included {
1552 quote! { #ident }
1553 } else {
1554 quote! { #ident: _ }
1555 }
1556 }
1557 FieldAccessor::Unnamed(i) => {
1558 if *included {
1559 let ident = Ident::new(
1560 &format!("f{}", i.index),
1561 proc_macro2::Span::call_site(),
1562 );
1563 quote! { #ident }
1564 } else {
1565 quote! { _ }
1566 }
1567 }
1568 }
1569 },
1570 )
1571 .collect()
1572}
1573
1574fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1581where
1582 F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1583{
1584 data.variants
1585 .iter()
1586 .map(|variant| {
1587 let name = &variant.ident;
1588 let all_fields = collect_all_fields(&variant.fields)?;
1589 let field_accessors = gen_enum_field_accessors(&all_fields);
1590 let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1591 let items = included_fields
1592 .iter()
1593 .map(|f| {
1594 let (accessor, ty) = <(Ident, Type)>::from(*f);
1595 make_item(quote! { #accessor }, ty)
1596 })
1597 .collect::<Vec<_>>();
1598
1599 Ok(match &variant.fields {
1600 Fields::Named(_) => {
1601 quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1602 }
1603 Fields::Unnamed(_) => {
1604 quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1605 }
1606 Fields::Unit => quote! { Self::#name => { #(#items)* } },
1607 })
1608 })
1609 .collect()
1610}
1611
1612#[proc_macro_derive(Bind, attributes(binding))]
1694pub fn derive_bind(input: TokenStream) -> TokenStream {
1695 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1696 quote! {
1697 hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1698 }
1699 }
1700
1701 let input = parse_macro_input!(input as DeriveInput);
1702 let name = &input.ident;
1703 let inner = match &input.data {
1704 Data::Struct(DataStruct { fields, .. }) => {
1705 match gen_struct_items(fields, make_item, true) {
1706 Ok(collects) => {
1707 quote! { #(#collects)* }
1708 }
1709 Err(e) => {
1710 return TokenStream::from(e.to_compile_error());
1711 }
1712 }
1713 }
1714 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1715 Ok(arms) => {
1716 quote! { match self { #(#arms),* } }
1717 }
1718 Err(e) => {
1719 return TokenStream::from(e.to_compile_error());
1720 }
1721 },
1722 _ => panic!("Bind can only be derived for structs and enums"),
1723 };
1724 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1725 let expand = quote! {
1726 #[automatically_derived]
1727 impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1728 fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1729 #inner
1730 Ok(())
1731 }
1732 }
1733 };
1734 TokenStream::from(expand)
1735}
1736
1737#[proc_macro_derive(Unbind, attributes(binding))]
1751pub fn derive_unbind(input: TokenStream) -> TokenStream {
1752 fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1753 quote! {
1754 hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1755 }
1756 }
1757
1758 let input = parse_macro_input!(input as DeriveInput);
1759 let name = &input.ident;
1760 let inner = match &input.data {
1761 Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1762 {
1763 Ok(collects) => {
1764 quote! { #(#collects)* }
1765 }
1766 Err(e) => {
1767 return TokenStream::from(e.to_compile_error());
1768 }
1769 },
1770 Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1771 Ok(arms) => {
1772 quote! { match self { #(#arms),* } }
1773 }
1774 Err(e) => {
1775 return TokenStream::from(e.to_compile_error());
1776 }
1777 },
1778 _ => panic!("Unbind can only be derived for structs and enums"),
1779 };
1780 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1781 let expand = quote! {
1782 #[automatically_derived]
1783 impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1784 fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1785 #inner
1786 Ok(())
1787 }
1788 }
1789 };
1790 TokenStream::from(expand)
1791}
1792
1793#[proc_macro_derive(Actor, attributes(actor))]
1840pub fn derive_actor(input: TokenStream) -> TokenStream {
1841 let input = parse_macro_input!(input as DeriveInput);
1842 let name = &input.ident;
1843 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1844
1845 let is_passthrough = input.attrs.iter().any(|attr| {
1846 if attr.path().is_ident("actor") {
1847 if let Ok(meta) = attr.parse_args_with(
1848 syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
1849 ) {
1850 return meta.iter().any(|ident| ident == "passthrough");
1851 }
1852 }
1853 false
1854 });
1855
1856 let expanded = if is_passthrough {
1857 quote! {
1858 #[hyperactor::async_trait::async_trait]
1859 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1860 type Params = Self;
1861
1862 async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
1863 Ok(instance)
1864 }
1865 }
1866 }
1867 } else {
1868 quote! {
1869 #[hyperactor::async_trait::async_trait]
1870 impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1871 type Params = ();
1872
1873 async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
1874 Ok(Default::default())
1875 }
1876 }
1877 }
1878 };
1879
1880 TokenStream::from(expanded)
1881}