Skip to main content

hyperactor_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines macros used by the [`hyperactor`] crate.
10
11#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::WherePredicate;
41use syn::bracketed;
42use syn::parse::Parse;
43use syn::parse::ParseStream;
44use syn::parse_macro_input;
45use syn::punctuated::Punctuated;
46use syn::spanned::Spanned;
47
/// Error reported by `Message::new` when the field marked `#[reply]` is not a
/// path type whose last segment is `OncePortRef`, `OncePortHandle`, `PortRef`
/// or `PortHandle` carrying an angle-bracketed type argument.
///
/// NOTE(review): the text says "in the last position", but `Message::new`
/// locates the reply field by its `#[reply]` attribute and never checks its
/// position; the help lines below also omit the `#[reply]` attribute that
/// marks the reply field.
const REPLY_VARIANT_ERROR: &str = indoc! {r#"
`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position

= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
"#};
54
/// Error reported by `Message::new` when more than one field of a single
/// message is marked `#[reply]`.
const REPLY_USAGE_ERROR: &str = indoc! {r#"
`call` message expects at most one `reply` argument

= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
"#};
61
/// Role of a single message field, as determined by `parse_field_flag`.
enum FieldFlag {
    /// A regular argument: the field carries no `#[reply]` attribute.
    None,
    /// The field is marked `#[reply]`; for a call message this is the port on
    /// which the handler's result is posted.
    Reply,
}
66
67/// Represents a variant of an enum.
68#[allow(dead_code)]
69enum Variant {
70    /// A named variant (i.e., `MyVariant { .. }`).
71    Named {
72        enum_name: Ident,
73        name: Ident,
74        field_names: Vec<Ident>,
75        field_types: Vec<Type>,
76        field_flags: Vec<FieldFlag>,
77        is_struct: bool,
78        generics: syn::Generics,
79    },
80    /// An anonymous variant (i.e., `MyVariant(..)`).
81    Anon {
82        enum_name: Ident,
83        name: Ident,
84        field_types: Vec<Type>,
85        field_flags: Vec<FieldFlag>,
86        is_struct: bool,
87        generics: syn::Generics,
88    },
89}
90
91impl Variant {
92    /// The number of fields in the variant.
93    fn len(&self) -> usize {
94        self.field_types().len()
95    }
96
97    /// Returns whether this variant was defined as a struct.
98    fn is_struct(&self) -> bool {
99        match self {
100            Variant::Named { is_struct, .. } => *is_struct,
101            Variant::Anon { is_struct, .. } => *is_struct,
102        }
103    }
104
105    /// The name of the enum containing the variant.
106    fn enum_name(&self) -> &Ident {
107        match self {
108            Variant::Named { enum_name, .. } => enum_name,
109            Variant::Anon { enum_name, .. } => enum_name,
110        }
111    }
112
113    /// The name of the variant itself.
114    fn name(&self) -> &Ident {
115        match self {
116            Variant::Named { name, .. } => name,
117            Variant::Anon { name, .. } => name,
118        }
119    }
120
121    /// The generics of the variant itself.
122    #[allow(dead_code)]
123    fn generics(&self) -> &syn::Generics {
124        match self {
125            Variant::Named { generics, .. } => generics,
126            Variant::Anon { generics, .. } => generics,
127        }
128    }
129
130    /// The snake_name of the variant itself.
131    fn snake_name(&self) -> Ident {
132        Ident::new(
133            &self.name().to_string().to_case(Case::Snake),
134            self.name().span(),
135        )
136    }
137
138    /// The variant's qualified name.
139    fn qualified_name(&self) -> proc_macro2::TokenStream {
140        let enum_name = self.enum_name();
141        let name = self.name();
142
143        if self.is_struct() {
144            quote! { #enum_name }
145        } else {
146            quote! { #enum_name::#name }
147        }
148    }
149
150    /// Names of the fields in the variant. Anonymous variants are named
151    /// according to their position in the argument list.
152    fn field_names(&self) -> Vec<Ident> {
153        match self {
154            Variant::Named { field_names, .. } => field_names.clone(),
155            Variant::Anon { field_types, .. } => (0usize..field_types.len())
156                .map(|idx| format_ident!("arg{}", idx))
157                .collect(),
158        }
159    }
160
161    /// The types of the fields int the variant.
162    fn field_types(&self) -> &Vec<Type> {
163        match self {
164            Variant::Named { field_types, .. } => field_types,
165            Variant::Anon { field_types, .. } => field_types,
166        }
167    }
168
169    /// Return the field flags for this variant.
170    fn field_flags(&self) -> &Vec<FieldFlag> {
171        match self {
172            Variant::Named { field_flags, .. } => field_flags,
173            Variant::Anon { field_flags, .. } => field_flags,
174        }
175    }
176
177    /// The constructor for the variant, using the field names directly.
178    fn constructor(&self) -> proc_macro2::TokenStream {
179        let qualified_name = self.qualified_name();
180        let field_names = self.field_names();
181        match self {
182            Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
183            Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
184        }
185    }
186}
187
/// Kind of port carried by a call message's `#[reply]` field.
struct ReplyPort {
    /// `true` for `PortHandle` / `OncePortHandle`, `false` otherwise.
    is_handle: bool,
    /// `true` for `OncePortHandle` / `OncePortRef`, `false` otherwise.
    is_once: bool,
}
192
193impl ReplyPort {
194    fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
195        ReplyPort {
196            is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
197            is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
198        }
199    }
200
201    fn open_op(&self) -> proc_macro2::TokenStream {
202        if self.is_once {
203            quote! { hyperactor::mailbox::open_once_port }
204        } else {
205            quote! { hyperactor::mailbox::open_port }
206        }
207    }
208
209    fn rx_modifier(&self) -> proc_macro2::TokenStream {
210        if self.is_once {
211            quote! {}
212        } else {
213            quote! { mut }
214        }
215    }
216}
217
218/// Represents a message that can be sent to a handler, each message is associated with
219/// a variant.
220#[allow(clippy::large_enum_variant)]
221enum Message {
222    /// A call message is a request-response message, the last argument is
223    /// a [`hyperactor::OncePortRef`] or [`hyperactor::OncePortHandle`].
224    Call {
225        variant: Variant,
226        /// Tells whether the reply argument is a handle.
227        reply_port: ReplyPort,
228        /// The underlying return type (i.e., the type of the reply port).
229        return_type: Type,
230        /// the log level for generated instrumentation for handlers of this message.
231        log_level: Option<Ident>,
232    },
233    OneWay {
234        variant: Variant,
235        /// the log level for generated instrumentation for handlers of this message.
236        log_level: Option<Ident>,
237    },
238}
239
240impl Message {
241    fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
242        match &variant
243            .field_flags()
244            .iter()
245            .zip(variant.field_types())
246            .filter_map(|(flag, ty)| match flag {
247                FieldFlag::Reply => Some(ty),
248                FieldFlag::None => None,
249            })
250            .collect::<Vec<&Type>>()[..]
251        {
252            [] => Ok(Self::OneWay { variant, log_level }),
253            [reply_port_ty] => {
254                let syn::Type::Path(type_path) = reply_port_ty else {
255                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
256                };
257                let Some(last_segment) = type_path.path.segments.last() else {
258                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
259                };
260                if last_segment.ident != "OncePortRef"
261                    && last_segment.ident != "OncePortHandle"
262                    && last_segment.ident != "PortRef"
263                    && last_segment.ident != "PortHandle"
264                {
265                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
266                }
267                let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
268                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
269                };
270                let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
271                    return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
272                };
273                let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
274                let return_type = return_ty.clone();
275                Ok(Self::Call {
276                    variant,
277                    reply_port,
278                    return_type,
279                    log_level,
280                })
281            }
282            _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
283        }
284    }
285
286    /// The arguments of this message.
287    fn args(&self) -> Vec<(Ident, Type)> {
288        match self {
289            Message::Call { variant, .. } => variant
290                .field_names()
291                .into_iter()
292                .zip(variant.field_types().clone())
293                .take(variant.len() - 1)
294                .collect(),
295            Message::OneWay { variant, .. } => variant
296                .field_names()
297                .into_iter()
298                .zip(variant.field_types().clone())
299                .collect(),
300        }
301    }
302
303    fn variant(&self) -> &Variant {
304        match self {
305            Message::Call { variant, .. } => variant,
306            Message::OneWay { variant, .. } => variant,
307        }
308    }
309
310    fn reply_port_position(&self) -> Option<usize> {
311        self.variant()
312            .field_flags()
313            .iter()
314            .position(|flag| matches!(flag, FieldFlag::Reply))
315    }
316
317    /// The reply port argument of this message.
318    fn reply_port_arg(&self) -> Option<(Ident, Type)> {
319        match self {
320            Message::Call { variant, .. } => {
321                let pos = self.reply_port_position()?;
322                Some((
323                    variant.field_names()[pos].clone(),
324                    variant.field_types()[pos].clone(),
325                ))
326            }
327            Message::OneWay { .. } => None,
328        }
329    }
330}
331
332fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
333    let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
334        Some(attr) => {
335            let Ok(meta) = attr.meta.require_list() else {
336                return Err(syn::Error::new(
337                    Span::call_site(),
338                    indoc! {"
339                            `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
340
341                            = help use `#[log_level(info)]` or `#[log_level(error)]`
342                        "},
343                ));
344            };
345            let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
346            if parsed.len() != 1 {
347                return Err(syn::Error::new(
348                    Span::call_site(),
349                    indoc! {"
350                            `log_level` attribute must specify exactly one level
351
352                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
353                        "},
354                ));
355            };
356            Some(parsed.first().unwrap().to_string())
357        }
358        None => None,
359    };
360
361    if level.is_none() {
362        return Ok(None);
363    }
364    let level = level.unwrap();
365
366    match level.as_str() {
367        "error" | "warn" | "info" | "debug" | "trace" => {}
368        _ => {
369            return Err(syn::Error::new(
370                Span::call_site(),
371                indoc! {"
372                            `log_level` attribute must be one of 'error, warn, info, debug, trace'
373
374                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
375                        "},
376            ));
377        }
378    }
379
380    Ok(Some(Ident::new(
381        level.to_ascii_uppercase().as_str(),
382        Span::call_site(),
383    )))
384}
385
386fn parse_field_flag(field: &Field) -> FieldFlag {
387    for attr in field.attrs.iter() {
388        match &attr.meta {
389            syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
390            _ => {}
391        }
392    }
393    FieldFlag::None
394}
395
396/// Parse a message enum or struct into its constituent messages.
397fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
398    match &input.data {
399        Data::Enum(data_enum) => {
400            let mut messages = Vec::new();
401
402            for variant in &data_enum.variants {
403                let name = variant.ident.clone();
404                let attrs = &variant.attrs;
405
406                let message_variant = match &variant.fields {
407                    syn::Fields::Unnamed(fields_) => Variant::Anon {
408                        enum_name: input.ident.clone(),
409                        name,
410                        field_types: fields_
411                            .unnamed
412                            .iter()
413                            .map(|field| field.ty.clone())
414                            .collect(),
415                        field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
416                        is_struct: false,
417                        generics: input.generics.clone(),
418                    },
419                    syn::Fields::Named(fields_) => Variant::Named {
420                        enum_name: input.ident.clone(),
421                        name,
422                        field_names: fields_
423                            .named
424                            .iter()
425                            .map(|field| field.ident.clone().unwrap())
426                            .collect(),
427                        field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
428                        field_flags: fields_.named.iter().map(parse_field_flag).collect(),
429                        is_struct: false,
430                        generics: input.generics.clone(),
431                    },
432                    _ => {
433                        return Err(syn::Error::new_spanned(
434                            variant,
435                            indoc! {r#"
436                                `Handler` currently only supports named or tuple struct variants
437
438                                = help use `MyCall(Arg1Type, Arg2Type, ..)`,
439                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
440                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
441                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
442                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
443                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
444                              "#},
445                        ));
446                    }
447                };
448                let log_level = parse_log_level(attrs)?;
449
450                messages.push(Message::new(
451                    variant.fields.span(),
452                    message_variant,
453                    log_level,
454                )?);
455            }
456
457            Ok(messages)
458        }
459        Data::Struct(data_struct) => {
460            let struct_name = input.ident.clone();
461            let attrs = &input.attrs;
462
463            let message_variant = match &data_struct.fields {
464                syn::Fields::Unnamed(fields_) => Variant::Anon {
465                    enum_name: struct_name.clone(),
466                    name: struct_name,
467                    field_types: fields_
468                        .unnamed
469                        .iter()
470                        .map(|field| field.ty.clone())
471                        .collect(),
472                    field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
473                    is_struct: true,
474                    generics: input.generics.clone(),
475                },
476                syn::Fields::Named(fields_) => Variant::Named {
477                    enum_name: struct_name.clone(),
478                    name: struct_name,
479                    field_names: fields_
480                        .named
481                        .iter()
482                        .map(|field| field.ident.clone().unwrap())
483                        .collect(),
484                    field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
485                    field_flags: fields_.named.iter().map(parse_field_flag).collect(),
486                    is_struct: true,
487                    generics: input.generics.clone(),
488                },
489                syn::Fields::Unit => Variant::Anon {
490                    enum_name: struct_name.clone(),
491                    name: struct_name,
492                    field_types: Vec::new(),
493                    field_flags: Vec::new(),
494                    is_struct: true,
495                    generics: input.generics.clone(),
496                },
497            };
498
499            let log_level = parse_log_level(attrs)?;
500            let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
501
502            Ok(vec![message])
503        }
504        _ => Err(syn::Error::new_spanned(
505            input,
506            "handlers can only be derived for enums and structs",
507        )),
508    }
509}
510
511/// Derive a custom handler trait for given an enum containing tuple
512/// structs.  The handler trait defines a method corresponding
513/// to each of the enum's variants, and a `handle` function
514/// that dispatches messages to the correct method.  The macro
515/// supports two messaging patterns: "call" and "oneway". A call is a
516/// request-response message; a [`hyperactor::mailbox::OncePortRef`] or
517/// [`hyperactor::mailbox::OncePortHandle`] in the last position is used
518/// to send the return value.
519///
520/// The macro also derives a client trait that can be automatically implemented
521/// by specifying [`HandleClient`] for `ActorHandle<Actor>` and [`RefClient`]
522/// for `ActorRef<Actor>` accordingly. We require two implementations because
523/// not `ActorRef`s require that its message type is serializable.
524///
525/// The associated [`hyperactor_macros::handler`] macro can be used to add
526/// a dispatching handler directly to an [`hyperactor::Actor`].
527///
528/// # Example
529///
530/// The following example creates a "shopping list" actor responsible for
531/// maintaining a shopping list.
532///
533/// ```
534/// use std::collections::HashSet;
535/// use std::time::Duration;
536///
537/// use async_trait::async_trait;
538/// use hyperactor::Actor;
539/// use hyperactor::HandleClient;
540/// use hyperactor::Handler;
541/// use hyperactor::Instance;
542/// use hyperactor::RefClient;
543/// use hyperactor::proc::Proc;
544/// use hyperactor::reference;
545/// use serde::Deserialize;
546/// use serde::Serialize;
547/// use typeuri::Named;
548///
549/// #[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
550/// enum ShoppingList {
551///     // Oneway messages dispatch messages asynchronously, with no reply.
552///     Add(String),
553///     Remove(String),
554///
555///     // Call messages dispatch a request, expecting a reply to the
556///     // provided port, which must be in the last position.
557///     Exists(String, #[reply] reference::OncePortRef<bool>),
558///
559///     List(#[reply] reference::OncePortRef<Vec<String>>),
560/// }
561///
562/// // Define an actor.
563/// #[derive(Debug)]
564/// #[hyperactor::export(ShoppingList)]
565/// #[hyperactor::spawnable]
566/// struct ShoppingListActor(HashSet<String>);
567///
568/// #[async_trait]
569/// impl Actor for ShoppingListActor {
570///     type Params = ();
571///
572///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
573///         Ok(Self(HashSet::new()))
574///     }
575/// }
576///
577/// // ShoppingListHandler is the trait generated by derive(Handler) above.
578/// // We implement the trait here for the actor, defining a handler for
579/// // each ShoppingList message.
580/// //
581/// // The `handle` attribute installs a handler that routes messages
582/// // to the `ShoppingListHandler` implementation directly. This can also
583/// // be done manually:
584/// //
585/// // ```ignore
586/// //<ShoppingListActor as ShoppingListHandler>
587/// //     ::handle(self, comm, message).await
588/// // ```
589/// #[async_trait]
590/// #[hyperactor::handle(ShoppingList)]
591/// impl ShoppingListHandler for ShoppingListActor {
592///     async fn add(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
593///         eprintln!("insert {}", item);
594///         self.0.insert(item);
595///         Ok(())
596///     }
597///
598///     async fn remove(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
599///         eprintln!("remove {}", item);
600///         self.0.remove(&item);
601///         Ok(())
602///     }
603///
604///     async fn exists(
605///         &mut self,
606///         _cx: &Context<Self>,
607///         item: String,
608///     ) -> Result<bool, anyhow::Error> {
609///         Ok(self.0.contains(&item))
610///     }
611///
612///     async fn list(&mut self, _cx: &Context<Self>) -> Result<Vec<String>, anyhow::Error> {
613///         Ok(self.0.iter().cloned().collect())
614///     }
615/// }
616///
617/// #[tokio::main]
618/// async fn main() -> Result<(), anyhow::Error> {
619///     let mut proc = Proc::isolated();
620///
621///     // Spawn our actor, and get a handle for rank 0.
622///     let shopping_list_actor: hyperactor::ActorHandle<ShoppingListActor> =
623///         proc.spawn("shopping", ()).await?;
624///
625///     // We join the system, so that we can send messages to actors.
626///     let client = proc.attach("client").unwrap();
627///
628///     // todo: consider making this a macro to remove the magic names
629///
630///     // Derive(Handler) generates client methods, which call the
631///     // remote handler provided an actor instance,
632///     // the destination actor, and the method arguments.
633///
634///     shopping_list_actor.add(&client, "milk".into()).await?;
635///     shopping_list_actor.add(&client, "eggs".into()).await?;
636///
637///     println!(
638///         "got milk? {}",
639///         shopping_list_actor.exists(&client, "milk".into()).await?
640///     );
641///     println!(
642///         "got yoghurt? {}",
643///         shopping_list_actor
644///             .exists(&client, "yoghurt".into())
645///             .await?
646///     );
647///
648///     shopping_list_actor.remove(&client, "milk".into()).await?;
649///     println!(
650///         "got milk now? {}",
651///         shopping_list_actor.exists(&client, "milk".into()).await?
652///     );
653///
654///     println!(
655///         "shopping list: {:?}",
656///         shopping_list_actor.list(&client).await?
657///     );
658///
659///     let _ = proc
660///         .destroy_and_wait(Duration::from_secs(1), "example cleanup")
661///         .await?;
662///     Ok(())
663/// }
664/// ```
665#[proc_macro_derive(Handler, attributes(reply))]
666pub fn derive_handler(input: TokenStream) -> TokenStream {
667    let input = parse_macro_input!(input as DeriveInput);
668    let name: Ident = input.ident.clone();
669    let (_, ty_generics, _) = input.generics.split_for_impl();
670
671    let messages = match parse_messages(input.clone()) {
672        Ok(messages) => messages,
673        Err(err) => return TokenStream::from(err.to_compile_error()),
674    };
675
676    // Trait definition methods for the handler.
677    let mut handler_trait_methods = Vec::new();
678
679    // The arms of the match used in the message dispatcher.
680    let mut match_arms = Vec::new();
681
682    // Trait implemented by clients.
683    let mut client_trait_methods = Vec::new();
684
685    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
686
687    for message in &messages {
688        match message {
689            Message::Call {
690                variant,
691                reply_port,
692                return_type,
693                log_level,
694            } => {
695                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
696                let variant_name_snake = variant.snake_name();
697                let variant_name_snake_deprecated =
698                    format_ident!("{}_deprecated", variant_name_snake);
699                let enum_name = variant.enum_name();
700                let _variant_qualified_name = variant.qualified_name();
701                let log_level = match (&global_log_level, log_level) {
702                    (_, Some(local)) => local.clone(),
703                    (Some(global), None) => global.clone(),
704                    _ => Ident::new("DEBUG", Span::call_site()),
705                };
706                let _log_level = if reply_port.is_handle {
707                    quote! {
708                        tracing::Level::#log_level
709                    }
710                } else {
711                    quote! {
712                        tracing::Level::TRACE
713                    }
714                };
715                let log_message = quote! {
716                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
717                            "rpc" => "call",
718                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
719                            "message_type" => stringify!(#enum_name),
720                            "variant" => stringify!(#variant_name_snake),
721                        ));
722                };
723
724                handler_trait_methods.push(quote! {
725                    #[doc = "The generated handler method for this enum variant."]
726                    async fn #variant_name_snake(
727                        &mut self,
728                        cx: &hyperactor::Context<Self>,
729                        #(#arg_names: #arg_types),*)
730                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
731                });
732
733                client_trait_methods.push(quote! {
734                    #[doc = "The generated client method for this enum variant."]
735                    async fn #variant_name_snake(
736                        &self,
737                        cx: &impl hyperactor::context::Actor,
738                        #(#arg_names: #arg_types),*)
739                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
740
741                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
742                    async fn #variant_name_snake_deprecated(
743                        &self,
744                        cx: &impl hyperactor::context::Actor,
745                        #(#arg_names: #arg_types),*)
746                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
747                });
748
749                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
750                let constructor = variant.constructor();
751                let result_ident = Ident::new("result", Span::mixed_site());
752                let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
753                match_arms.push(quote! {
754                    #constructor => {
755                        #log_message
756                        // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
757                        // This would require Result<Result<..., in order to handle RPC errors.
758                        #construct_result_future
759                        use hyperactor::Endpoint as _;
760                        #reply_port_arg.post(cx, #result_ident);
761                        Ok(())
762                    }
763                });
764            }
765            Message::OneWay { variant, log_level } => {
766                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
767                let variant_name_snake = variant.snake_name();
768                let variant_name_snake_deprecated =
769                    format_ident!("{}_deprecated", variant_name_snake);
770                let enum_name = variant.enum_name();
771                let log_level = match (&global_log_level, log_level) {
772                    (_, Some(local)) => local.clone(),
773                    (Some(global), None) => global.clone(),
774                    _ => Ident::new("TRACE", Span::call_site()),
775                };
776                let _log_level = quote! {
777                    tracing::Level::#log_level
778                };
779                let log_message = quote! {
780                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
781                            "rpc" => "call",
782                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
783                            "message_type" => stringify!(#enum_name),
784                            "variant" => stringify!(#variant_name_snake),
785                        ));
786                };
787
788                handler_trait_methods.push(quote! {
789                    #[doc = "The generated handler method for this enum variant."]
790                    async fn #variant_name_snake(
791                        &mut self,
792                        cx: &hyperactor::Context<Self>,
793                        #(#arg_names: #arg_types),*)
794                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
795                });
796
797                client_trait_methods.push(quote! {
798                    #[doc = "The generated client method for this enum variant."]
799                    async fn #variant_name_snake(
800                        &self,
801                        cx: &impl hyperactor::context::Actor,
802                        #(#arg_names: #arg_types),*)
803                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
804
805                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
806                    async fn #variant_name_snake_deprecated(
807                        &self,
808                        cx: &impl hyperactor::context::Actor,
809                        #(#arg_names: #arg_types),*)
810                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
811                });
812
813                let constructor = variant.constructor();
814
815                match_arms.push(quote! {
816                    #constructor => {
817                        #log_message
818                        self.#variant_name_snake(cx, #(#arg_names),*).await
819                    },
820                });
821            }
822        }
823    }
824
825    let handler_trait_name = format_ident!("{}Handler", name);
826    let client_trait_name = format_ident!("{}Client", name);
827
828    // We impose additional constraints on the generics in the implementation;
829    // but the trait itself should not impose additional constraints:
830
831    let mut handler_generics = input.generics.clone();
832    for param in handler_generics.type_params_mut() {
833        param.bounds.push(syn::parse_quote!(serde::Serialize));
834        param
835            .bounds
836            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
837        param.bounds.push(syn::parse_quote!(Send));
838        param.bounds.push(syn::parse_quote!(Sync));
839        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
840        param.bounds.push(syn::parse_quote!(typeuri::Named));
841    }
842    let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
843    let (client_impl_generics, _, _) = input.generics.split_for_impl();
844
845    let expanded = quote! {
846        #[doc = "The custom handler trait for this message type."]
847        #[hyperactor::internal_macro_support::async_trait::async_trait]
848        pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync  {
849            #(#handler_trait_methods)*
850
851            #[doc = "Handle the next message."]
852            async fn handle(
853                &mut self,
854                cx: &hyperactor::Context<Self>,
855                message: #name #ty_generics,
856            ) -> hyperactor::internal_macro_support::anyhow::Result<()>  {
857                 // Dispatch based on message type.
858                 match message {
859                     #(#match_arms)*
860                }
861            }
862        }
863
864        #[doc = "The custom client trait for this message type."]
865        #[hyperactor::internal_macro_support::async_trait::async_trait]
866        pub trait #client_trait_name #client_impl_generics: Send + Sync  {
867            #(#client_trait_methods)*
868        }
869    };
870
871    TokenStream::from(expanded)
872}
873
874/// Derives a client implementation on `ActorHandle<Actor>`.
875/// See [`Handler`] documentation for details.
876#[proc_macro_derive(HandleClient, attributes(log_level))]
877pub fn derive_handle_client(input: TokenStream) -> TokenStream {
878    derive_client(input, true)
879}
880
881/// Derives a client implementation on `ActorRef<Actor>`.
882/// See [`Handler`] documentation for details.
883#[proc_macro_derive(RefClient, attributes(log_level))]
884pub fn derive_ref_client(input: TokenStream) -> TokenStream {
885    derive_client(input, false)
886}
887
888fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
889    let input = parse_macro_input!(input as DeriveInput);
890    let name = input.ident.clone();
891
892    let messages = match parse_messages(input.clone()) {
893        Ok(messages) => messages,
894        Err(err) => return TokenStream::from(err.to_compile_error()),
895    };
896
897    // The client implementation methods.
898    let mut impl_methods = Vec::new();
899
900    let send_message = quote! { hyperactor::Endpoint::post(self, cx, message); };
901    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
902
903    for message in &messages {
904        match message {
905            Message::Call {
906                variant,
907                reply_port,
908                return_type,
909                log_level,
910            } => {
911                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
912                let variant_name_snake = variant.snake_name();
913                let variant_name_snake_deprecated =
914                    format_ident!("{}_deprecated", variant_name_snake);
915                let enum_name = variant.enum_name();
916
917                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
918                let constructor = variant.constructor();
919                let log_level = match (&global_log_level, log_level) {
920                    (_, Some(local)) => local.clone(),
921                    (Some(global), None) => global.clone(),
922                    _ => Ident::new("DEBUG", Span::call_site()),
923                };
924                let log_level = if is_handle {
925                    quote! {
926                        tracing::Level::#log_level
927                    }
928                } else {
929                    quote! {
930                        tracing::Level::TRACE
931                    }
932                };
933                let log_message = quote! {
934                        hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
935                            "rpc" => "call",
936                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_addr().to_string(),
937                            "message_type" => stringify!(#enum_name),
938                            "variant" => stringify!(#variant_name_snake),
939                        ));
940
941                };
942                let open_port = reply_port.open_op();
943                let rx_mod = reply_port.rx_modifier();
944                if reply_port.is_handle {
945                    impl_methods.push(quote! {
946                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
947                        async fn #variant_name_snake(
948                            &self,
949                            cx: &impl hyperactor::context::Actor,
950                            #(#arg_names: #arg_types),*)
951                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
952                            let (#reply_port_arg, #rx_mod reply_receiver) =
953                                #open_port::<#return_type>(cx);
954                            let message = #constructor;
955                            #log_message;
956                            #send_message;
957                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
958                        }
959
960                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
961                        async fn #variant_name_snake_deprecated(
962                            &self,
963                            cx: &impl hyperactor::context::Actor,
964                            #(#arg_names: #arg_types),*)
965                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
966                            let (#reply_port_arg, #rx_mod reply_receiver) =
967                                #open_port::<#return_type>(cx);
968                            let message = #constructor;
969                            #log_message;
970                            #send_message;
971                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
972                        }
973                    });
974                } else {
975                    impl_methods.push(quote! {
976                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
977                        async fn #variant_name_snake(
978                            &self,
979                            cx: &impl hyperactor::context::Actor,
980                            #(#arg_names: #arg_types),*)
981                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
982                            let (#reply_port_arg, #rx_mod reply_receiver) =
983                                #open_port::<#return_type>(cx);
984                            let #reply_port_arg = #reply_port_arg.bind();
985                            let message = #constructor;
986                            #log_message;
987                            #send_message;
988                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
989                        }
990
991                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
992                        async fn #variant_name_snake_deprecated(
993                            &self,
994                            cx: &impl hyperactor::context::Actor,
995                            #(#arg_names: #arg_types),*)
996                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
997                            let (#reply_port_arg, #rx_mod reply_receiver) =
998                                #open_port::<#return_type>(cx);
999                            let #reply_port_arg = #reply_port_arg.bind();
1000                            let message = #constructor;
1001                            #log_message;
1002                            #send_message;
1003                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
1004                        }
1005                    });
1006                }
1007            }
1008            Message::OneWay { variant, log_level } => {
1009                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1010                let variant_name_snake = variant.snake_name();
1011                let variant_name_snake_deprecated =
1012                    format_ident!("{}_deprecated", variant_name_snake);
1013                let enum_name = variant.enum_name();
1014                let constructor = variant.constructor();
1015                let log_level = match (&global_log_level, log_level) {
1016                    (_, Some(local)) => local.clone(),
1017                    (Some(global), None) => global.clone(),
1018                    _ => Ident::new("DEBUG", Span::call_site()),
1019                };
1020                let _log_level = if is_handle {
1021                    quote! {
1022                        tracing::Level::TRACE
1023                    }
1024                } else {
1025                    quote! {
1026                        tracing::Level::#log_level
1027                    }
1028                };
1029                let log_message = quote! {
1030                    hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1031                        "rpc" => "oneway",
1032                        "actor_id" => self.actor_addr().to_string(),
1033                        "message_type" => stringify!(#enum_name),
1034                        "variant" => stringify!(#variant_name_snake),
1035                    ));
1036                };
1037                impl_methods.push(quote! {
1038                    async fn #variant_name_snake(
1039                        &self,
1040                        cx: &impl hyperactor::context::Actor,
1041                        #(#arg_names: #arg_types),*)
1042                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1043                        let message = #constructor;
1044                        #log_message;
1045                        #send_message;
1046                        Ok(())
1047                    }
1048
1049                    async fn #variant_name_snake_deprecated(
1050                        &self,
1051                        cx: &impl hyperactor::context::Actor,
1052                        #(#arg_names: #arg_types),*)
1053                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1054                        let message = #constructor;
1055                        #log_message;
1056                        #send_message;
1057                        Ok(())
1058                    }
1059                });
1060            }
1061        }
1062    }
1063
1064    let trait_name = format_ident!("{}Client", name);
1065
1066    let (_, ty_generics, _) = input.generics.split_for_impl();
1067
1068    // Add a new generic parameter 'A'
1069    let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1070    let mut trait_generics = input.generics.clone();
1071    trait_generics.params.insert(
1072        0,
1073        syn::GenericParam::Type(syn::TypeParam {
1074            ident: actor_ident.clone(),
1075            attrs: vec![],
1076            colon_token: None,
1077            bounds: Punctuated::new(),
1078            eq_token: None,
1079            default: None,
1080        }),
1081    );
1082
1083    for param in trait_generics.type_params_mut() {
1084        if param.ident == actor_ident {
1085            continue;
1086        }
1087        param.bounds.push(syn::parse_quote!(serde::Serialize));
1088        param
1089            .bounds
1090            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1091        param.bounds.push(syn::parse_quote!(Send));
1092        param.bounds.push(syn::parse_quote!(Sync));
1093        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1094        param.bounds.push(syn::parse_quote!(typeuri::Named));
1095    }
1096
1097    let (impl_generics, _, _) = trait_generics.split_for_impl();
1098
1099    let expanded = if is_handle {
1100        quote! {
1101            #[hyperactor::internal_macro_support::async_trait::async_trait]
1102            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1103              where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1104                #(#impl_methods)*
1105            }
1106        }
1107    } else {
1108        quote! {
1109            #[hyperactor::internal_macro_support::async_trait::async_trait]
1110            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1111              where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1112                #(#impl_methods)*
1113            }
1114        }
1115    };
1116
1117    TokenStream::from(expanded)
1118}
1119
1120const HANDLE_ARGUMENT_ERROR: &str = indoc! {r#"
1121`handle` expects the message type that is being handled
1122
1123= help: use `#[handle(MessageType)]`
1124"#};
1125
1126/// Install a [`Handler`] that routes messages of the provided type to this handler trait implementation.
1127#[proc_macro_attribute]
1128pub fn handle(attr: TokenStream, item: TokenStream) -> TokenStream {
1129    let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1130    if attr_args.len() != 1 {
1131        return TokenStream::from(
1132            syn::Error::new_spanned(attr_args, HANDLE_ARGUMENT_ERROR).to_compile_error(),
1133        );
1134    }
1135
1136    let message_type = attr_args.first().unwrap();
1137    let input = parse_macro_input!(item as ItemImpl);
1138
1139    let self_type = match *input.self_ty {
1140        syn::Type::Path(ref type_path) => {
1141            let segment = type_path.path.segments.last().unwrap();
1142            segment.clone() //ident.clone()
1143        }
1144        _ => {
1145            return TokenStream::from(
1146                syn::Error::new_spanned(input.self_ty, "`handle` argument must be a type")
1147                    .to_compile_error(),
1148            );
1149        }
1150    };
1151
1152    let trait_name = match input.trait_ {
1153        Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1154        None => {
1155            return TokenStream::from(
1156                syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1157                    .to_compile_error(),
1158            );
1159        }
1160    };
1161
1162    let expanded = quote! {
1163        #input
1164
1165        #[hyperactor::internal_macro_support::async_trait::async_trait]
1166        impl hyperactor::Handler<#message_type> for #self_type {
1167            async fn handle(
1168                &mut self,
1169                cx: &hyperactor::Context<Self>,
1170                message: #message_type,
1171            ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
1172                <Self as #trait_name>::handle(self, cx, message).await
1173            }
1174        }
1175    };
1176
1177    TokenStream::from(expanded)
1178}
1179
1180/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1181/// We set a default level of INFO while always setting ERROR if the function returns Result::Err giving us
1182/// consistent and high quality structured logs. Because this wraps around tracing::instrument, all parameters
1183/// mentioned in https://fburl.com/9jlkb5q4 should be valid. For functions that don't return a [`Result`] type, use
1184/// [`instrument_infallible`]
1185///
1186/// ```
1187/// #[telemetry::instrument]
1188/// async fn yolo() -> anyhow::Result<i32> {
1189///     Ok(420)
1190/// }
1191/// ```
1192#[proc_macro_attribute]
1193pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1194    let args =
1195        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1196    let input = parse_macro_input!(input as ItemFn);
1197    let output = quote! {
1198        #[hyperactor::internal_macro_support::tracing::instrument(err, skip_all, #args)]
1199        #input
1200    };
1201
1202    TokenStream::from(output)
1203}
1204
1205/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1206/// Because this wraps around tracing::instrument, all parameters mentioned in
1207/// https://fburl.com/9jlkb5q4 should be valid.
1208///
1209/// ```
1210/// #[telemetry::instrument]
1211/// async fn yolo() -> i32 {
1212///     420
1213/// }
1214/// ```
1215#[proc_macro_attribute]
1216pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1217    let args =
1218        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1219    let input = parse_macro_input!(input as ItemFn);
1220
1221    let output = quote! {
1222        #[hyperactor::internal_macro_support::tracing::instrument(skip_all, #args)]
1223        #input
1224    };
1225
1226    TokenStream::from(output)
1227}
1228
/// One entry of a handler list, as parsed for `#[export(...)]` and `behavior!`:
/// a message type, optionally followed by a `{ cast = <bool> }` options block.
struct HandlerSpec {
    /// The handled message type.
    ty: Type,
    /// Value of the `cast` option (defaults to `false`). When set, users of the
    /// spec also register `hyperactor::message::IndexedErasedUnbound<ty>`.
    cast: bool,
}
1233
1234impl Parse for HandlerSpec {
1235    fn parse(input: ParseStream) -> syn::Result<Self> {
1236        let ty: Type = input.parse()?;
1237
1238        if input.peek(syn::token::Brace) {
1239            let content;
1240            syn::braced!(content in input);
1241            let key: Ident = content.parse()?;
1242            content.parse::<Token![=]>()?;
1243            let expr: Expr = content.parse()?;
1244
1245            let cast = if key == "cast" {
1246                if let Expr::Lit(ExprLit {
1247                    lit: Lit::Bool(b), ..
1248                }) = expr
1249                {
1250                    b.value
1251                } else {
1252                    return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1253                }
1254            } else {
1255                return Err(syn::Error::new_spanned(
1256                    key,
1257                    "unsupported field (expected `cast`)",
1258                ));
1259            };
1260
1261            Ok(HandlerSpec { ty, cast })
1262        } else if input.is_empty() || input.peek(Token![,]) {
1263            Ok(HandlerSpec { ty, cast: false })
1264        } else {
1265            // Something unexpected follows the type
1266            let unexpected: proc_macro2::TokenTree = input.parse()?;
1267            Err(syn::Error::new_spanned(
1268                unexpected,
1269                "unexpected token after type — expected `{ ... }` or nothing",
1270            ))
1271        }
1272    }
1273}
1274
1275impl HandlerSpec {
1276    fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1277        let mut tys = Vec::new();
1278        for HandlerSpec { ty, cast } in handlers {
1279            if cast {
1280                let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1281                let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1282                tys.push(wrapped_ty);
1283            }
1284            tys.push(ty);
1285        }
1286        tys
1287    }
1288}
1289
1290fn named_impl(data_type_name: &Ident, generics: &syn::Generics) -> proc_macro2::TokenStream {
1291    let generics_with_bounds = generics_with_named_bounds(generics);
1292    let type_params: Vec<_> = generics.type_params().collect();
1293    let has_generics = !type_params.is_empty();
1294
1295    let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1296    let (_, ty_generics, where_clause) = generics.split_for_impl();
1297
1298    let (typename_impl, typehash_impl) = if has_generics {
1299        let placeholders = vec!["{}"; type_params.len()].join(", ");
1300        let placeholders_format_string = format!("<{}>", placeholders);
1301        let format_string = quote! {
1302            concat!(
1303                std::module_path!(),
1304                "::",
1305                stringify!(#data_type_name),
1306                #placeholders_format_string
1307            )
1308        };
1309        let type_param_idents: Vec<_> = type_params.iter().map(|param| &param.ident).collect();
1310        (
1311            quote! {
1312                typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1313            },
1314            quote! {
1315                typeuri::cityhasher::hash(Self::typename())
1316            },
1317        )
1318    } else {
1319        (
1320            quote! {
1321                concat!(std::module_path!(), "::", stringify!(#data_type_name))
1322            },
1323            quote! {
1324                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1325                    typeuri::cityhasher::hash(<#data_type_name as typeuri::Named>::typename())
1326                });
1327                *TYPEHASH
1328            },
1329        )
1330    };
1331
1332    quote! {
1333        impl #impl_generics_with_bounds typeuri::Named for #data_type_name #ty_generics #where_clause {
1334            fn typename() -> &'static str {
1335                #typename_impl
1336            }
1337
1338            fn typehash() -> u64 {
1339                #typehash_impl
1340            }
1341        }
1342    }
1343}
1344
1345fn generics_with_named_bounds(generics: &syn::Generics) -> syn::Generics {
1346    let mut generics = generics.clone();
1347    for param in generics.type_params_mut() {
1348        param.bounds.push(syn::parse_quote!(typeuri::Named));
1349    }
1350    generics
1351}
1352
1353fn generics_with_predicates(
1354    generics: &syn::Generics,
1355    predicates: impl IntoIterator<Item = WherePredicate>,
1356) -> syn::Generics {
1357    let mut generics = generics.clone();
1358    generics.make_where_clause().predicates.extend(predicates);
1359    generics
1360}
1361
/// Parsed arguments of the `#[export(...)]` attribute macro.
///
/// Accepts either a plain handler list (`#[export(A, B { cast = true })]`) or the
/// compatibility form `#[export(handlers = [A, B])]`. An empty attribute yields
/// no handlers.
struct ExportAttr {
    /// The handler specs to export, in the order they were written.
    handlers: Vec<HandlerSpec>,
}
1366
1367impl Parse for ExportAttr {
1368    fn parse(input: ParseStream) -> syn::Result<Self> {
1369        if input.is_empty() {
1370            return Ok(Self {
1371                handlers: Vec::new(),
1372            });
1373        }
1374
1375        let compatibility_form = {
1376            let fork = input.fork();
1377            fork.parse::<Ident>().is_ok() && fork.parse::<Token![=]>().is_ok()
1378        };
1379
1380        if !compatibility_form {
1381            let handlers = input
1382                .parse_terminated(HandlerSpec::parse, Token![,])?
1383                .into_iter()
1384                .collect();
1385            return Ok(Self { handlers });
1386        }
1387
1388        let mut handlers: Vec<HandlerSpec> = vec![];
1389
1390        while !input.is_empty() {
1391            let key: Ident = input.parse()?;
1392            input.parse::<Token![=]>()?;
1393
1394            if key == "spawn" {
1395                let expr: Expr = input.parse()?;
1396                return Err(syn::Error::new_spanned(
1397                    expr,
1398                    "`spawn = true` is no longer supported; use `#[spawnable]` on concrete actor declarations or `hyperactor::register_spawnable!(ConcreteType)` for generic instantiations",
1399                ));
1400            } else if key == "handlers" {
1401                let content;
1402                bracketed!(content in input);
1403                let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1404                handlers = raw_handlers.into_iter().collect();
1405            } else {
1406                return Err(syn::Error::new_spanned(
1407                    key,
1408                    "unexpected key in `#[export(...)]`. Use direct handler lists, or the compatibility key `handlers`",
1409                ));
1410            }
1411
1412            // optional trailing comma
1413            let _ = input.parse::<Token![,]>();
1414        }
1415
1416        Ok(ExportAttr { handlers })
1417    }
1418}
1419
1420/// Exports handlers for this actor. The set of exported handlers
1421/// determine the messages that may be sent to remote references of
1422/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
1423/// [`hyperactor::RemoteMessage`] may be exported.
1424///
1425/// # Example
1426///
1427/// In the following example, `MyActor` exports handlers for two message types,
1428/// `MyMessage` and `MyOtherMessage`. Consequently, `ActorRef`s of the actor's
1429/// type may dispatch messages of these types.
1430///
1431/// ```ignore
1432/// #[export(MyMessage, MyOtherMessage)]
1433/// struct MyActor {}
1434/// ```
1435#[proc_macro_attribute]
1436pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1437    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1438    let data_type_name = &input.ident;
1439    let (_, ty_generics, _) = input.generics.split_for_impl();
1440    let named_generics = generics_with_named_bounds(&input.generics);
1441    let (named_impl_generics, named_ty_generics, named_where_clause) =
1442        named_generics.split_for_impl();
1443
1444    let ExportAttr { handlers } = parse_macro_input!(attr as ExportAttr);
1445
1446    let mut handles = Vec::new();
1447    let mut bindings = Vec::new();
1448    let mut bind_predicates = Vec::new();
1449    let actor_ty: Type = syn::parse_quote!(#data_type_name #ty_generics);
1450
1451    for HandlerSpec { ty, cast } in &handlers {
1452        let message_generics = generics_with_predicates(
1453            &named_generics,
1454            [syn::parse_quote!(#ty: hyperactor::RemoteMessage)],
1455        );
1456        let (message_impl_generics, message_ty_generics, message_where_clause) =
1457            message_generics.split_for_impl();
1458        handles.push(quote! {
1459            impl #message_impl_generics hyperactor::actor::RemoteHandles<#ty>
1460                for #data_type_name #message_ty_generics #message_where_clause {}
1461            impl #message_impl_generics hyperactor::remote::Accepts<#ty>
1462                for #data_type_name #message_ty_generics #message_where_clause {}
1463        });
1464        bindings.push(quote! {
1465            ports.bind::<#ty>();
1466        });
1467        bind_predicates.push(syn::parse_quote!(#ty: hyperactor::RemoteMessage));
1468        bind_predicates.push(syn::parse_quote!(#actor_ty: hyperactor::Handler<#ty>));
1469
1470        if *cast {
1471            let indexed_ty: Type =
1472                syn::parse_quote!(hyperactor::message::IndexedErasedUnbound<#ty>);
1473            let indexed_generics = generics_with_predicates(
1474                &named_generics,
1475                [
1476                    syn::parse_quote!(#ty: hyperactor::message::Castable),
1477                    syn::parse_quote!(#indexed_ty: hyperactor::RemoteMessage),
1478                ],
1479            );
1480            let (indexed_impl_generics, indexed_ty_generics, indexed_where_clause) =
1481                indexed_generics.split_for_impl();
1482            handles.push(quote! {
1483                impl #indexed_impl_generics hyperactor::actor::RemoteHandles<#indexed_ty>
1484                    for #data_type_name #indexed_ty_generics #indexed_where_clause {}
1485                impl #indexed_impl_generics hyperactor::remote::Accepts<#indexed_ty>
1486                    for #data_type_name #indexed_ty_generics #indexed_where_clause {}
1487            });
1488            bindings.push(quote! {
1489                ports.bind::<#indexed_ty>();
1490            });
1491            bind_predicates.push(syn::parse_quote!(#ty: hyperactor::message::Castable));
1492            bind_predicates.push(syn::parse_quote!(#indexed_ty: hyperactor::RemoteMessage));
1493        }
1494    }
1495
1496    let bind_generics = generics_with_predicates(&named_generics, bind_predicates);
1497    let (bind_impl_generics, bind_ty_generics, bind_where_clause) = bind_generics.split_for_impl();
1498    let named_impl = named_impl(data_type_name, &input.generics);
1499
1500    let expanded = quote! {
1501        #input
1502
1503        impl #named_impl_generics hyperactor::actor::Referable for #data_type_name #named_ty_generics #named_where_clause {}
1504
1505        #(#handles)*
1506
1507        // Always export the `IntrospectMessage` type.
1508        impl #named_impl_generics hyperactor::actor::RemoteHandles<hyperactor::introspect::IntrospectMessage> for #data_type_name #named_ty_generics #named_where_clause {}
1509        impl #named_impl_generics hyperactor::remote::Accepts<hyperactor::introspect::IntrospectMessage> for #data_type_name #named_ty_generics #named_where_clause {}
1510
1511        impl #bind_impl_generics hyperactor::actor::Binds<#data_type_name #bind_ty_generics> for #data_type_name #bind_ty_generics #bind_where_clause {
1512            fn bind(ports: &hyperactor::proc::HandlerPorts<Self>) {
1513                #(#bindings)*
1514            }
1515        }
1516
1517        #named_impl
1518    };
1519
1520    TokenStream::from(expanded)
1521}
1522
1523/// Marks a concrete actor declaration as remotely spawnable.
1524///
1525/// When combined with `#[export(...)]`, place `#[spawnable]` below `#[export]`.
1526#[proc_macro_attribute]
1527pub fn spawnable(attr: TokenStream, item: TokenStream) -> TokenStream {
1528    if !attr.is_empty() {
1529        return syn::Error::new(Span::call_site(), "`#[spawnable]` does not take arguments")
1530            .to_compile_error()
1531            .into();
1532    }
1533
1534    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1535    if !matches!(input.data, Data::Struct(_)) {
1536        return syn::Error::new(
1537            input.span(),
1538            "`#[spawnable]` only supports struct actor declarations",
1539        )
1540        .to_compile_error()
1541        .into();
1542    }
1543
1544    if !input.generics.params.is_empty() {
1545        return syn::Error::new(
1546            input.generics.span(),
1547            "generic actor families cannot use `#[spawnable]`; use `hyperactor::register_spawnable!(ConcreteType)` instead",
1548        )
1549        .to_compile_error()
1550        .into();
1551    }
1552
1553    let data_type_name = &input.ident;
1554    quote! {
1555        #input
1556        hyperactor::register_spawnable!(#data_type_name);
1557    }
1558    .into()
1559}
1560
/// Represents the full input to [`fn behavior`].
///
/// Parsed as: the behavior name, optional generic parameters (e.g. `<T>`),
/// a comma, then a comma-separated list of handler specs.
struct BehaviorInput {
    /// Name of the behavior struct to generate.
    behavior: Ident,
    /// Generic parameters written after the name; empty when there are none.
    generics: syn::Generics,
    /// The handled message types, in the order they were written.
    handlers: Vec<HandlerSpec>,
}
1567
1568impl syn::parse::Parse for BehaviorInput {
1569    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1570        let behavior: Ident = input.parse()?;
1571        let generics: syn::Generics = input.parse()?;
1572        let _: Token![,] = input.parse()?;
1573        let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1574        let handlers = raw_handlers.into_iter().collect();
1575        Ok(BehaviorInput {
1576            behavior,
1577            generics,
1578            handlers,
1579        })
1580    }
1581}
1582
1583/// Create a [`Referable`] definition, handling a specific set of message types.
1584/// Behaviors are used to create an [`ActorRef`] without having to depend on the
1585/// actor's implementation. If the message type need to be cast, add `castable`
1586/// flag to those types. e.g. the following example creates a behavior with 5
1587/// message types, and 4 of which need to be cast.
1588///
1589/// ```
1590/// hyperactor::behavior!(
1591///     TestActorBehavior,
1592///     TestMessage { castable = true },
1593///     () {castable = true },
1594///     MyGeneric<()> {castable = true },
1595///     u64,
1596/// );
1597/// ```
1598///
1599/// This macro also supports generic behaviors:
1600/// ```
1601/// hyperactor::behavior!(
1602///     TestBehavior<T>,
1603///     Message<T> { castable = true },
1604///     u64,
1605/// );
1606/// ```
1607#[proc_macro]
1608pub fn behavior(input: TokenStream) -> TokenStream {
1609    let BehaviorInput {
1610        behavior,
1611        generics,
1612        handlers,
1613    } = parse_macro_input!(input as BehaviorInput);
1614    let tys = HandlerSpec::add_indexed(handlers);
1615
1616    // Add bounds to generics for Named, Serialize, Deserialize
1617    let mut bounded_generics = generics.clone();
1618    for param in bounded_generics.type_params_mut() {
1619        param.bounds.push(syn::parse_quote!(typeuri::Named));
1620        param.bounds.push(syn::parse_quote!(serde::Serialize));
1621        param.bounds.push(syn::parse_quote!(std::marker::Send));
1622        param.bounds.push(syn::parse_quote!(std::marker::Sync));
1623        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1624        // Note: lifetime parameters are not *actually* hygienic.
1625        // https://github.com/rust-lang/rust/issues/54727
1626        let lifetime =
1627            syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1628        param
1629            .bounds
1630            .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1631    }
1632
1633    // Split the generics for use in different contexts
1634    let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1635
1636    // Create a combined generics for the Binds impl that includes both A and the behavior's generics
1637    let mut binds_generics = bounded_generics.clone();
1638    binds_generics.params.insert(
1639        0,
1640        syn::GenericParam::Type(syn::TypeParam {
1641            attrs: vec![],
1642            ident: Ident::new("A", proc_macro2::Span::call_site()),
1643            colon_token: None,
1644            bounds: Punctuated::new(),
1645            eq_token: None,
1646            default: None,
1647        }),
1648    );
1649    let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1650
1651    // Determine typename and typehash implementation based on whether we have generics
1652    let type_params: Vec<_> = bounded_generics.type_params().collect();
1653    let has_generics = !type_params.is_empty();
1654
1655    let (typename_impl, typehash_impl) = if has_generics {
1656        // Create format string with placeholders for each generic parameter
1657        let placeholders = vec!["{}"; type_params.len()].join(", ");
1658        let placeholders_format_string = format!("<{}>", placeholders);
1659        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1660
1661        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1662        (
1663            quote! {
1664                typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1665            },
1666            quote! {
1667                typeuri::cityhasher::hash(Self::typename())
1668            },
1669        )
1670    } else {
1671        (
1672            quote! {
1673                concat!(std::module_path!(), "::", stringify!(#behavior))
1674            },
1675            quote! {
1676                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1677                    typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
1678                });
1679                *TYPEHASH
1680            },
1681        )
1682    };
1683
1684    let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1685
1686    let expanded = quote! {
1687        #[doc = "The generated behavior struct."]
1688        #[derive(Debug, serde::Serialize, serde::Deserialize)]
1689        pub struct #behavior #impl_generics #where_clause {
1690            _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1691        }
1692
1693        impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
1694            fn typename() -> &'static str {
1695                #typename_impl
1696            }
1697
1698            fn typehash() -> u64 {
1699                #typehash_impl
1700            }
1701        }
1702
1703        impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1704
1705        impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1706        where
1707            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1708            #where_clause
1709        {
1710            fn bind(ports: &hyperactor::proc::HandlerPorts<A>) {
1711                #(
1712                    ports.bind::<#tys>();
1713                )*
1714            }
1715        }
1716
1717        #(
1718            impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1719            impl #impl_generics hyperactor::remote::Accepts<#tys> for #behavior #ty_generics #where_clause {}
1720        )*
1721    };
1722
1723    TokenStream::from(expanded)
1724}
1725
1726fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1727    let mut is_included = false;
1728    for attr in &field.attrs {
1729        if attr.path().is_ident("binding") {
1730            // parse #[binding(include)] and look for exactly "include"
1731            attr.parse_nested_meta(|meta| {
1732                if meta.path.is_ident("include") {
1733                    is_included = true;
1734                    Ok(())
1735                } else {
1736                    let path = meta.path.to_token_stream().to_string().replace(' ', "");
1737                    Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1738                }
1739            })?
1740        }
1741    }
1742    Ok(is_included)
1743}
1744
/// The field accessor in a struct or enum variant.
/// e.g.:
///   struct NamedStruct { foo: u32 } => FieldAccessor::Named(Ident::new("foo", Span::call_site()))
///   struct UnnamedStruct(u32) => FieldAccessor::Unnamed(Index::from(0))
enum FieldAccessor {
    /// A named field, reached as `self.<ident>`.
    Named(Ident),
    /// A tuple field, reached by position, e.g. `self.0`.
    Unnamed(Index),
}
1753
/// Result of parsing a field in a struct or an enum variant.
struct ParsedField {
    /// How the field is reached: by name or by tuple position.
    accessor: FieldAccessor,
    /// The field's declared type.
    ty: Type,
    /// Whether the field is annotated with `#[binding(include)]`.
    included: bool,
}
1760
1761impl From<&ParsedField> for (Ident, Type) {
1762    fn from(field: &ParsedField) -> Self {
1763        let field_ident = match &field.accessor {
1764            FieldAccessor::Named(ident) => ident.clone(),
1765            FieldAccessor::Unnamed(i) => {
1766                Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1767            }
1768        };
1769        (field_ident, field.ty.clone())
1770    }
1771}
1772
1773fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1774    match fields {
1775        Fields::Named(named) => named
1776            .named
1777            .iter()
1778            .map(|f| {
1779                let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1780                Ok(ParsedField {
1781                    accessor,
1782                    ty: f.ty.clone(),
1783                    included: include_in_bind_unbind(f)?,
1784                })
1785            })
1786            .collect(),
1787        Fields::Unnamed(unnamed) => unnamed
1788            .unnamed
1789            .iter()
1790            .enumerate()
1791            .map(|(i, f)| {
1792                let accessor = FieldAccessor::Unnamed(Index::from(i));
1793                Ok(ParsedField {
1794                    accessor,
1795                    ty: f.ty.clone(),
1796                    included: include_in_bind_unbind(f)?,
1797                })
1798            })
1799            .collect(),
1800        Fields::Unit => Ok(Vec::new()),
1801    }
1802}
1803
1804fn gen_struct_items<F>(
1805    fields: &Fields,
1806    make_item: F,
1807    is_mutable: bool,
1808) -> syn::Result<Vec<proc_macro2::TokenStream>>
1809where
1810    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1811{
1812    let borrow = if is_mutable {
1813        quote! { &mut }
1814    } else {
1815        quote! { & }
1816    };
1817    let items: Vec<_> = collect_all_fields(fields)?
1818        .into_iter()
1819        .filter(|f| f.included)
1820        .map(
1821            |ParsedField {
1822                 accessor,
1823                 ty,
1824                 included,
1825             }| {
1826                assert!(included);
1827                let field_accessor = match accessor {
1828                    FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1829                    FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1830                };
1831                make_item(field_accessor, ty)
1832            },
1833        )
1834        .collect();
1835    Ok(items)
1836}
1837
1838/// Generate the field accessor for a enum variant in pattern matching. e.g.
1839/// the <GENERATED> parts in the following example:
1840///
1841///   match my_enum {
1842///     // e.g. MyEnum::Tuple(_, f1, f2, _)
1843///     MyEnum::Tuple(<GENERATED>) => { ... }
1844///     // e.g. MyEnum::Struct { field0: _, field1 }
1845///     MyEnum::Struct(<GENERATED>) => { ... }
1846///   }
1847fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1848    all_fields
1849        .iter()
1850        .map(
1851            |ParsedField {
1852                 accessor,
1853                 ty: _,
1854                 included,
1855             }| {
1856                match accessor {
1857                    FieldAccessor::Named(ident) => {
1858                        if *included {
1859                            quote! { #ident }
1860                        } else {
1861                            quote! { #ident: _ }
1862                        }
1863                    }
1864                    FieldAccessor::Unnamed(i) => {
1865                        if *included {
1866                            let ident = Ident::new(
1867                                &format!("f{}", i.index),
1868                                proc_macro2::Span::call_site(),
1869                            );
1870                            quote! { #ident }
1871                        } else {
1872                            quote! { _ }
1873                        }
1874                    }
1875                }
1876            },
1877        )
1878        .collect()
1879}
1880
1881/// Generate all the parts for enum variants. e.g. the <GENERATED> part in the
1882/// following example:
1883///
1884///   match my_enum {
1885///      <GENERATED>
1886///   }
1887fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1888where
1889    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1890{
1891    data.variants
1892        .iter()
1893        .map(|variant| {
1894            let name = &variant.ident;
1895            let all_fields = collect_all_fields(&variant.fields)?;
1896            let field_accessors = gen_enum_field_accessors(&all_fields);
1897            let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1898            let items = included_fields
1899                .iter()
1900                .map(|f| {
1901                    let (accessor, ty) = <(Ident, Type)>::from(*f);
1902                    make_item(quote! { #accessor }, ty)
1903                })
1904                .collect::<Vec<_>>();
1905
1906            Ok(match &variant.fields {
1907                Fields::Named(_) => {
1908                    quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1909                }
1910                Fields::Unnamed(_) => {
1911                    quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1912                }
1913                Fields::Unit => quote! { Self::#name => { #(#items)* } },
1914            })
1915        })
1916        .collect()
1917}
1918
1919/// Derive a custom implementation of [`hyperactor::message::Bind`] trait for
1920/// a struct or enum. This macro is normally used in tandem with [`fn derive_unbind`]
1921/// to make the applied struct or enum castable.
1922///
1923/// Specifically, the derived implementation iterates through fields annotated
1924/// with `#[binding(include)]` based on their order of declaration in the struct
1925/// or enum. These fields' types must implement `Bind` trait as well. During the
1926/// iteration, parameters from `bindings` are bound to these fields.
1927///
1928/// # Example
1929///
1930/// This macro supports named and unamed structs and enums. Below are examples
1931/// of the supported types:
1932///
1933/// ```
1934/// #[derive(Bind, Unbind)]
1935/// struct MyNamedStruct {
1936///     field0: u64,
1937///     field1: MyReply,
1938///     #[binding(include)]
1939/// nnnn     field2: PortRef<MyReply>,
1940///     field3: bool,
1941///     #[binding(include)]
1942///     field4: hyperactor::PortRef<u64>,
1943/// }
1944///
1945/// #[derive(Bind, Unbind)]
1946/// struct MyUnamedStruct(
1947///     u64,
1948///     MyReply,
1949///     #[binding(include)] hyperactor::PortRef<MyReply>,
1950///     bool,
1951///     #[binding(include)] PortRef<u64>,
1952/// );
1953///
1954/// #[derive(Bind, Unbind)]
1955/// enum MyEnum {
1956///     Unit,
1957///     NoopTuple(u64, bool),
1958///     NoopStruct {
1959///         field0: u64,
1960///         field1: bool,
1961///     },
1962///     Tuple(
1963///         u64,
1964///         MyReply,
1965///         #[binding(include)] PortRef<MyReply>,
1966///         bool,
1967///         #[binding(include)] hyperactor::PortRef<u64>,
1968///     ),
1969///     Struct {
1970///         field0: u64,
1971///         field1: MyReply,
1972///         #[binding(include)]
1973///         field2: PortRef<MyReply>,
1974///         field3: bool,
1975///         #[binding(include)]
1976///         field4: hyperactor::PortRef<u64>,
1977///     },
1978/// }
1979/// ```
1980///
1981/// The following shows what derived `Bind`` and `Unbind`` implementations for
1982/// `MyNamedStruct` will look like. The implementations of other types are
1983/// similar, and thus are not shown here.
1984/// ```ignore
1985/// impl Bind for MyNamedStruct {
1986/// fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
1987///     Bind::bind(&mut self.field2, bindings)?;
1988///     Bind::bind(&mut self.field4, bindings)?;
1989///     Ok(())
1990/// }
1991///
1992/// impl Unbind for MyNamedStruct {
1993///     fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
1994///         Unbind::unbind(&self.field2, bindings)?;
1995///         Unbind::unbind(&self.field4, bindings)?;
1996///         Ok(())
1997///     }
1998/// }
1999/// ```
2000#[proc_macro_derive(Bind, attributes(binding))]
2001pub fn derive_bind(input: TokenStream) -> TokenStream {
2002    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2003        quote! {
2004            hyperactor::message::Bind::bind(#field_accessor, bindings)?;
2005        }
2006    }
2007
2008    let input = parse_macro_input!(input as DeriveInput);
2009    let name = &input.ident;
2010    let inner = match &input.data {
2011        Data::Struct(DataStruct { fields, .. }) => {
2012            match gen_struct_items(fields, make_item, true) {
2013                Ok(collects) => {
2014                    quote! { #(#collects)* }
2015                }
2016                Err(e) => {
2017                    return TokenStream::from(e.to_compile_error());
2018                }
2019            }
2020        }
2021        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2022            Ok(arms) => {
2023                quote! { match self { #(#arms),* } }
2024            }
2025            Err(e) => {
2026                return TokenStream::from(e.to_compile_error());
2027            }
2028        },
2029        _ => panic!("Bind can only be derived for structs and enums"),
2030    };
2031    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2032    let expand = quote! {
2033        #[automatically_derived]
2034        impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
2035            fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2036                #inner
2037                Ok(())
2038            }
2039        }
2040    };
2041    TokenStream::from(expand)
2042}
2043
2044/// Derive a custom implementation of [`hyperactor::message::Unbind`] trait for
2045/// a struct or enum. This macro is normally used in tandem with [`fn derive_bind`]
2046/// to make the applied struct or enum castable.
2047///
2048/// Specifically, the derived implementation iterates through fields annoated
2049/// with `#[binding(include)]` based on their order of declaration in the struct
2050/// or enum. These fields' types must implement `Unbind` trait as well. During
2051/// the iteration, parameters from these fields are extracted and stored in
2052/// `bindings`.
2053///
2054/// # Example
2055///
2056/// See [`fn derive_bind`]'s documentation for examples.
2057#[proc_macro_derive(Unbind, attributes(binding))]
2058pub fn derive_unbind(input: TokenStream) -> TokenStream {
2059    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2060        quote! {
2061            hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
2062        }
2063    }
2064
2065    let input = parse_macro_input!(input as DeriveInput);
2066    let name = &input.ident;
2067    let inner = match &input.data {
2068        Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
2069        {
2070            Ok(collects) => {
2071                quote! { #(#collects)* }
2072            }
2073            Err(e) => {
2074                return TokenStream::from(e.to_compile_error());
2075            }
2076        },
2077        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2078            Ok(arms) => {
2079                quote! { match self { #(#arms),* } }
2080            }
2081            Err(e) => {
2082                return TokenStream::from(e.to_compile_error());
2083            }
2084        },
2085        _ => panic!("Unbind can only be derived for structs and enums"),
2086    };
2087    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2088    let expand = quote! {
2089        #[automatically_derived]
2090        impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
2091            fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2092                #inner
2093                Ok(())
2094            }
2095        }
2096    };
2097    TokenStream::from(expand)
2098}
2099
2100// Helper function for common parsing and validation
2101fn parse_observe_function(
2102    attr: TokenStream,
2103    item: TokenStream,
2104) -> syn::Result<(ItemFn, String, String)> {
2105    let input = syn::parse::<ItemFn>(item)?;
2106
2107    if input.sig.asyncness.is_none() {
2108        return Err(syn::Error::new(
2109            input.sig.span(),
2110            "observe macros can only be applied to async functions",
2111        ));
2112    }
2113
2114    let fn_name_str = input.sig.ident.to_string();
2115    let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
2116
2117    Ok((input, fn_name_str, module_name_str))
2118}
2119
2120// Helper function for creating telemetry identifiers and setup code
2121fn create_telemetry_setup(
2122    module_name_str: &str,
2123    fn_name_str: &str,
2124    include_error: bool,
2125) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2126    let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2127    let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2128
2129    let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2130
2131    let error_ident = if include_error {
2132        Some(Ident::new(
2133            "error",
2134            Span::from(proc_macro::Span::def_site()),
2135        ))
2136    } else {
2137        None
2138    };
2139
2140    let error_declaration = if let Some(ref error_ident) = error_ident {
2141        quote! {
2142            hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2143        }
2144    } else {
2145        quote! {}
2146    };
2147
2148    let setup_code = quote! {
2149        use hyperactor_telemetry;
2150        hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2151        hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2152        #error_declaration
2153    };
2154
2155    (latency_ident, success_ident, error_ident, setup_code)
2156}
2157
2158/// A procedural macro that automatically injects telemetry code into async functions
2159/// that return a Result type.
2160///
2161/// This macro wraps async functions and adds instrumentation to measure:
2162/// 1. Latency - how long the function takes to execute
2163/// 2. Error counter - function error count
2164/// 3. Success counter - function completion count
2165///
2166/// # Example
2167///
2168/// ```rust
2169/// use hyperactor_actor::observe_result;
2170///
2171/// #[observe_result("my_module")]
2172/// async fn process_request(user_id: &str) -> Result<String, Error> {
2173///     // Function implementation
2174///     // Telemetry will be automatically collected
2175/// }
2176/// ```
2177#[proc_macro_attribute]
2178pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2179    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2180        Ok(parsed) => parsed,
2181        Err(err) => return err.to_compile_error().into(),
2182    };
2183
2184    let fn_name = &input.sig.ident;
2185    let vis = &input.vis;
2186    let args = &input.sig.inputs;
2187    let return_type = &input.sig.output;
2188    let body = &input.block;
2189    let attrs = &input.attrs;
2190    let generics = &input.sig.generics;
2191
2192    let (latency_ident, success_ident, error_ident, telemetry_setup) =
2193        create_telemetry_setup(&module_name_str, &fn_name_str, true);
2194    let error_ident = error_ident.unwrap();
2195
2196    let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2197
2198    // Generate the instrumented function
2199    let expanded = quote! {
2200        #(#attrs)*
2201        #vis async fn #fn_name #generics(#args) #return_type {
2202            #telemetry_setup
2203
2204            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2205            let _timer = #latency_ident.start(kv_pairs);
2206
2207            let #result_ident = async #body.await;
2208
2209            match &#result_ident {
2210                Ok(_) => {
2211                    #success_ident.add(
2212                        1,
2213                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2214                    );
2215                }
2216                Err(_) => {
2217                    #error_ident.add(
2218                        1,
2219                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2220                    );
2221                }
2222            }
2223
2224            #result_ident
2225        }
2226    };
2227
2228    expanded.into()
2229}
2230
2231/// A procedural macro that automatically injects telemetry code into async functions
2232/// that do not return a Result type.
2233///
2234/// This macro wraps async functions and adds instrumentation to measure:
2235/// 1. Latency - how long the function takes to execute
2236/// 2. Success counter - function completion count
2237///
2238/// # Example
2239///
2240/// ```rust
2241/// use hyperactor_actor::observe_async;
2242///
2243/// #[observe_async("my_module")]
2244/// async fn process_data(data: &str) -> String {
2245///     // Function implementation
2246///     // Telemetry will be automatically collected
2247/// }
2248/// ```
2249#[proc_macro_attribute]
2250pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2251    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2252        Ok(parsed) => parsed,
2253        Err(err) => return err.to_compile_error().into(),
2254    };
2255
2256    let fn_name = &input.sig.ident;
2257    let vis = &input.vis;
2258    let args = &input.sig.inputs;
2259    let return_type = &input.sig.output;
2260    let body = &input.block;
2261    let attrs = &input.attrs;
2262    let generics = &input.sig.generics;
2263
2264    let (latency_ident, success_ident, _, telemetry_setup) =
2265        create_telemetry_setup(&module_name_str, &fn_name_str, false);
2266
2267    let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2268
2269    // Generate the instrumented function
2270    let expanded = quote! {
2271        #(#attrs)*
2272        #vis async fn #fn_name #generics(#args) #return_type {
2273            #telemetry_setup
2274
2275            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2276            let _timer = #latency_ident.start(kv_pairs);
2277
2278            let #return_ident = async #body.await;
2279
2280            #success_ident.add(
2281                1,
2282                hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2283            );
2284            #return_ident
2285        }
2286    };
2287
2288    expanded.into()
2289}
2290
/// Checks that `s` is a valid singleton label.
///
/// A valid label is 1 to 63 bytes long. It contains only ASCII lowercase
/// letters, digits and `-`, starts with a lowercase letter, and ends with a
/// lowercase letter or digit. Checks run in that reported order: emptiness,
/// length, first character, last character, then every character. The first
/// failure is returned as a human-readable message.
fn validate_label(s: &str) -> Result<(), String> {
    let bytes = s.as_bytes();
    let (Some(&head), Some(&tail)) = (bytes.first(), bytes.last()) else {
        return Err("label must not be empty".to_string());
    };
    if bytes.len() > 63 {
        return Err("label exceeds 63 characters".to_string());
    }
    if !head.is_ascii_lowercase() {
        return Err("label must start with a lowercase letter".to_string());
    }
    if !(tail.is_ascii_lowercase() || tail.is_ascii_digit()) {
        return Err("label must end with a lowercase letter or digit".to_string());
    }
    let is_allowed = |c: &char| c.is_ascii_lowercase() || c.is_ascii_digit() || *c == '-';
    match s.chars().find(|c| !is_allowed(c)) {
        Some(ch) => Err(format!("label contains invalid character '{ch}'")),
        None => Ok(()),
    }
}
2313
/// Parses `s` as a 1-16 character hexadecimal instance uid.
///
/// Both upper- and lowercase hex digits are accepted. Returns the parsed value,
/// or a message describing the first problem: bad length, then a non-hex
/// character.
fn validate_hex_uid(s: &str) -> Result<u64, String> {
    if !(1..=16).contains(&s.len()) {
        return Err(format!("hex uid must be 1-16 hex characters, got '{s}'"));
    }
    if let Some(bad) = s.chars().find(|c| !c.is_ascii_hexdigit()) {
        return Err(format!("hex uid contains invalid character '{bad}'"));
    }
    u64::from_str_radix(s, 16).map_err(|e| format!("invalid hex uid '{s}': {e}"))
}
2325
2326/// Compile-time validated [`hyperactor::id::Uid`] construction.
2327///
2328/// Accepts two forms:
2329/// - `uid!(_my-singleton)` — a singleton Uid
2330/// - `uid!(d5d54d7201103869)` — an instance Uid
2331#[proc_macro]
2332pub fn uid(input: TokenStream) -> TokenStream {
2333    let input2: proc_macro2::TokenStream = input.into();
2334    let combined: String = input2.into_iter().map(|tt| tt.to_string()).collect();
2335
2336    if combined.is_empty() {
2337        return TokenStream::from(quote! { compile_error!("uid! macro requires an argument") });
2338    }
2339
2340    // Singleton: starts with '_'
2341    if let Some(rest) = combined.strip_prefix('_') {
2342        return match validate_label(rest) {
2343            Ok(()) => TokenStream::from(quote! {
2344                hyperactor::id::Uid::Singleton(
2345                    hyperactor::id::Label::new(#rest).unwrap()
2346                )
2347            }),
2348            Err(e) => {
2349                let msg = format!("invalid singleton uid: {e}");
2350                TokenStream::from(quote! { compile_error!(#msg) })
2351            }
2352        };
2353    }
2354
2355    // Instance: bare hex
2356    match validate_hex_uid(&combined) {
2357        Ok(uid_val) => TokenStream::from(quote! {
2358            hyperactor::id::Uid::Instance(#uid_val, None)
2359        }),
2360        Err(e) => {
2361            let msg = format!("invalid uid: {e}");
2362            TokenStream::from(quote! { compile_error!(#msg) })
2363        }
2364    }
2365}