hyperactor_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines macros used by the [`hyperactor`] crate.
10
11#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::bracketed;
41use syn::parse::Parse;
42use syn::parse::ParseStream;
43use syn::parse_macro_input;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46
/// Diagnostic text reported by `Message::new` when the field marked
/// `#[reply]` does not have the expected typed-port shape, i.e. a
/// `OncePortRef`, `PortRef`, `OncePortHandle`, or `PortHandle` carrying the
/// reply type as its generic argument.
const REPLY_VARIANT_ERROR: &str = indoc! {r#"
`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position

= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
"#};
53
/// Diagnostic text reported by `Message::new` when more than one field of a
/// message is marked `#[reply]`.
const REPLY_USAGE_ERROR: &str = indoc! {r#"
`call` message expects at most one `reply` argument

= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
"#};
60
/// Marker recorded for each field of a message, derived from the field's
/// attributes by `parse_field_flag`.
enum FieldFlag {
    /// The field carries no recognized marker attribute.
    None,
    /// The field is annotated with a bare `#[reply]` attribute.
    Reply,
}
65
/// Represents a variant of an enum.
///
/// Struct messages are described with the same type: `is_struct` is set and
/// both `enum_name` and `name` hold the struct's identifier.
#[allow(dead_code)]
enum Variant {
    /// A named variant (i.e., `MyVariant { .. }`).
    Named {
        /// Identifier of the enclosing enum (or of the struct itself).
        enum_name: Ident,
        /// Identifier of the variant.
        name: Ident,
        /// Identifiers of the fields, in declaration order.
        field_names: Vec<Ident>,
        /// Types of the fields, in declaration order.
        field_types: Vec<Type>,
        /// Flags of the fields, in declaration order.
        field_flags: Vec<FieldFlag>,
        /// Whether the variant was built from a struct definition.
        is_struct: bool,
        /// Generics of the enclosing enum or struct.
        generics: syn::Generics,
    },
    /// An anonymous variant (i.e., `MyVariant(..)`).
    Anon {
        /// Identifier of the enclosing enum (or of the struct itself).
        enum_name: Ident,
        /// Identifier of the variant.
        name: Ident,
        /// Types of the fields, in declaration order.
        field_types: Vec<Type>,
        /// Flags of the fields, in declaration order.
        field_flags: Vec<FieldFlag>,
        /// Whether the variant was built from a struct definition.
        is_struct: bool,
        /// Generics of the enclosing enum or struct.
        generics: syn::Generics,
    },
}
89
90impl Variant {
91    /// The number of fields in the variant.
92    fn len(&self) -> usize {
93        self.field_types().len()
94    }
95
96    /// Returns whether this variant was defined as a struct.
97    fn is_struct(&self) -> bool {
98        match self {
99            Variant::Named { is_struct, .. } => *is_struct,
100            Variant::Anon { is_struct, .. } => *is_struct,
101        }
102    }
103
104    /// The name of the enum containing the variant.
105    fn enum_name(&self) -> &Ident {
106        match self {
107            Variant::Named { enum_name, .. } => enum_name,
108            Variant::Anon { enum_name, .. } => enum_name,
109        }
110    }
111
112    /// The name of the variant itself.
113    fn name(&self) -> &Ident {
114        match self {
115            Variant::Named { name, .. } => name,
116            Variant::Anon { name, .. } => name,
117        }
118    }
119
120    /// The generics of the variant itself.
121    #[allow(dead_code)]
122    fn generics(&self) -> &syn::Generics {
123        match self {
124            Variant::Named { generics, .. } => generics,
125            Variant::Anon { generics, .. } => generics,
126        }
127    }
128
129    /// The snake_name of the variant itself.
130    fn snake_name(&self) -> Ident {
131        Ident::new(
132            &self.name().to_string().to_case(Case::Snake),
133            self.name().span(),
134        )
135    }
136
137    /// The variant's qualified name.
138    fn qualified_name(&self) -> proc_macro2::TokenStream {
139        let enum_name = self.enum_name();
140        let name = self.name();
141
142        if self.is_struct() {
143            quote! { #enum_name }
144        } else {
145            quote! { #enum_name::#name }
146        }
147    }
148
149    /// Names of the fields in the variant. Anonymous variants are named
150    /// according to their position in the argument list.
151    fn field_names(&self) -> Vec<Ident> {
152        match self {
153            Variant::Named { field_names, .. } => field_names.clone(),
154            Variant::Anon { field_types, .. } => (0usize..field_types.len())
155                .map(|idx| format_ident!("arg{}", idx))
156                .collect(),
157        }
158    }
159
160    /// The types of the fields int the variant.
161    fn field_types(&self) -> &Vec<Type> {
162        match self {
163            Variant::Named { field_types, .. } => field_types,
164            Variant::Anon { field_types, .. } => field_types,
165        }
166    }
167
168    /// Return the field flags for this variant.
169    fn field_flags(&self) -> &Vec<FieldFlag> {
170        match self {
171            Variant::Named { field_flags, .. } => field_flags,
172            Variant::Anon { field_flags, .. } => field_flags,
173        }
174    }
175
176    /// The constructor for the variant, using the field names directly.
177    fn constructor(&self) -> proc_macro2::TokenStream {
178        let qualified_name = self.qualified_name();
179        let field_names = self.field_names();
180        match self {
181            Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
182            Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
183        }
184    }
185}
186
/// Describes the kind of port used to deliver the reply of a call message.
struct ReplyPort {
    /// True when the port type is `PortHandle` or `OncePortHandle`.
    is_handle: bool,
    /// True when the port type is `OncePortHandle` or `OncePortRef`.
    is_once: bool,
}
191
192impl ReplyPort {
193    fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
194        ReplyPort {
195            is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
196            is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
197        }
198    }
199
200    fn open_op(&self) -> proc_macro2::TokenStream {
201        if self.is_once {
202            quote! { hyperactor::mailbox::open_once_port }
203        } else {
204            quote! { hyperactor::mailbox::open_port }
205        }
206    }
207
208    fn rx_modifier(&self) -> proc_macro2::TokenStream {
209        if self.is_once {
210            quote! {}
211        } else {
212            quote! { mut }
213        }
214    }
215}
216
/// Represents a message that can be sent to a handler, each message is associated with
/// a variant.
#[allow(clippy::large_enum_variant)]
enum Message {
    /// A call message is a request-response message; the field marked
    /// `#[reply]` is a typed port ref (`OncePortRef` or `PortRef`) or handle
    /// (`OncePortHandle` or `PortHandle`).
    Call {
        /// The variant this message was parsed from.
        variant: Variant,
        /// Describes the reply port: whether it is a handle, and whether it
        /// is a once port.
        reply_port: ReplyPort,
        /// The underlying return type (i.e., the type of the reply port).
        return_type: Type,
        /// The log level for generated instrumentation for handlers of this message.
        log_level: Option<Ident>,
    },
    /// A one-way message: a variant with no field marked `#[reply]`.
    OneWay {
        /// The variant this message was parsed from.
        variant: Variant,
        /// The log level for generated instrumentation for handlers of this message.
        log_level: Option<Ident>,
    },
}
238
239impl Message {
240    fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
241        match &variant
242            .field_flags()
243            .iter()
244            .zip(variant.field_types())
245            .filter_map(|(flag, ty)| match flag {
246                FieldFlag::Reply => Some(ty),
247                FieldFlag::None => None,
248            })
249            .collect::<Vec<&Type>>()[..]
250        {
251            [] => Ok(Self::OneWay { variant, log_level }),
252            [reply_port_ty] => {
253                let syn::Type::Path(type_path) = reply_port_ty else {
254                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
255                };
256                let Some(last_segment) = type_path.path.segments.last() else {
257                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
258                };
259                if last_segment.ident != "OncePortRef"
260                    && last_segment.ident != "OncePortHandle"
261                    && last_segment.ident != "PortRef"
262                    && last_segment.ident != "PortHandle"
263                {
264                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
265                }
266                let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
267                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
268                };
269                let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
270                    return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
271                };
272                let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
273                let return_type = return_ty.clone();
274                Ok(Self::Call {
275                    variant,
276                    reply_port,
277                    return_type,
278                    log_level,
279                })
280            }
281            _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
282        }
283    }
284
285    /// The arguments of this message.
286    fn args(&self) -> Vec<(Ident, Type)> {
287        match self {
288            Message::Call { variant, .. } => variant
289                .field_names()
290                .into_iter()
291                .zip(variant.field_types().clone())
292                .take(variant.len() - 1)
293                .collect(),
294            Message::OneWay { variant, .. } => variant
295                .field_names()
296                .into_iter()
297                .zip(variant.field_types().clone())
298                .collect(),
299        }
300    }
301
302    fn variant(&self) -> &Variant {
303        match self {
304            Message::Call { variant, .. } => variant,
305            Message::OneWay { variant, .. } => variant,
306        }
307    }
308
309    fn reply_port_position(&self) -> Option<usize> {
310        self.variant()
311            .field_flags()
312            .iter()
313            .position(|flag| matches!(flag, FieldFlag::Reply))
314    }
315
316    /// The reply port argument of this message.
317    fn reply_port_arg(&self) -> Option<(Ident, Type)> {
318        match self {
319            Message::Call { variant, .. } => {
320                let pos = self.reply_port_position()?;
321                Some((
322                    variant.field_names()[pos].clone(),
323                    variant.field_types()[pos].clone(),
324                ))
325            }
326            Message::OneWay { .. } => None,
327        }
328    }
329}
330
331fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
332    let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
333        Some(attr) => {
334            let Ok(meta) = attr.meta.require_list() else {
335                return Err(syn::Error::new(
336                    Span::call_site(),
337                    indoc! {"
338                            `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
339
340                            = help use `#[log_level(info)]` or `#[log_level(error)]`
341                        "},
342                ));
343            };
344            let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
345            if parsed.len() != 1 {
346                return Err(syn::Error::new(
347                    Span::call_site(),
348                    indoc! {"
349                            `log_level` attribute must specify exactly one level
350
351                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
352                        "},
353                ));
354            };
355            Some(parsed.first().unwrap().to_string())
356        }
357        None => None,
358    };
359
360    if level.is_none() {
361        return Ok(None);
362    }
363    let level = level.unwrap();
364
365    match level.as_str() {
366        "error" | "warn" | "info" | "debug" | "trace" => {}
367        _ => {
368            return Err(syn::Error::new(
369                Span::call_site(),
370                indoc! {"
371                            `log_level` attribute must be one of 'error, warn, info, debug, trace'
372
373                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
374                        "},
375            ));
376        }
377    }
378
379    Ok(Some(Ident::new(
380        level.to_ascii_uppercase().as_str(),
381        Span::call_site(),
382    )))
383}
384
385fn parse_field_flag(field: &Field) -> FieldFlag {
386    for attr in field.attrs.iter() {
387        match &attr.meta {
388            syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
389            _ => {}
390        }
391    }
392    FieldFlag::None
393}
394
395/// Parse a message enum or struct into its constituent messages.
396fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
397    match &input.data {
398        Data::Enum(data_enum) => {
399            let mut messages = Vec::new();
400
401            for variant in &data_enum.variants {
402                let name = variant.ident.clone();
403                let attrs = &variant.attrs;
404
405                let message_variant = match &variant.fields {
406                    syn::Fields::Unnamed(fields_) => Variant::Anon {
407                        enum_name: input.ident.clone(),
408                        name,
409                        field_types: fields_
410                            .unnamed
411                            .iter()
412                            .map(|field| field.ty.clone())
413                            .collect(),
414                        field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
415                        is_struct: false,
416                        generics: input.generics.clone(),
417                    },
418                    syn::Fields::Named(fields_) => Variant::Named {
419                        enum_name: input.ident.clone(),
420                        name,
421                        field_names: fields_
422                            .named
423                            .iter()
424                            .map(|field| field.ident.clone().unwrap())
425                            .collect(),
426                        field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
427                        field_flags: fields_.named.iter().map(parse_field_flag).collect(),
428                        is_struct: false,
429                        generics: input.generics.clone(),
430                    },
431                    _ => {
432                        return Err(syn::Error::new_spanned(
433                            variant,
434                            indoc! {r#"
435                                `Handler` currently only supports named or tuple struct variants
436
437                                = help use `MyCall(Arg1Type, Arg2Type, ..)`,
438                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
439                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
440                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
441                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
442                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
443                              "#},
444                        ));
445                    }
446                };
447                let log_level = parse_log_level(attrs)?;
448
449                messages.push(Message::new(
450                    variant.fields.span(),
451                    message_variant,
452                    log_level,
453                )?);
454            }
455
456            Ok(messages)
457        }
458        Data::Struct(data_struct) => {
459            let struct_name = input.ident.clone();
460            let attrs = &input.attrs;
461
462            let message_variant = match &data_struct.fields {
463                syn::Fields::Unnamed(fields_) => Variant::Anon {
464                    enum_name: struct_name.clone(),
465                    name: struct_name,
466                    field_types: fields_
467                        .unnamed
468                        .iter()
469                        .map(|field| field.ty.clone())
470                        .collect(),
471                    field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
472                    is_struct: true,
473                    generics: input.generics.clone(),
474                },
475                syn::Fields::Named(fields_) => Variant::Named {
476                    enum_name: struct_name.clone(),
477                    name: struct_name,
478                    field_names: fields_
479                        .named
480                        .iter()
481                        .map(|field| field.ident.clone().unwrap())
482                        .collect(),
483                    field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
484                    field_flags: fields_.named.iter().map(parse_field_flag).collect(),
485                    is_struct: true,
486                    generics: input.generics.clone(),
487                },
488                syn::Fields::Unit => Variant::Anon {
489                    enum_name: struct_name.clone(),
490                    name: struct_name,
491                    field_types: Vec::new(),
492                    field_flags: Vec::new(),
493                    is_struct: true,
494                    generics: input.generics.clone(),
495                },
496            };
497
498            let log_level = parse_log_level(attrs)?;
499            let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
500
501            Ok(vec![message])
502        }
503        _ => Err(syn::Error::new_spanned(
504            input,
505            "handlers can only be derived for enums and structs",
506        )),
507    }
508}
509
/// Derive a custom handler trait for a given enum containing tuple
511/// structs.  The handler trait defines a method corresponding
512/// to each of the enum's variants, and a `handle` function
513/// that dispatches messages to the correct method.  The macro
514/// supports two messaging patterns: "call" and "oneway". A call is a
515/// request-response message; a [`hyperactor::mailbox::OncePortRef`] or
516/// [`hyperactor::mailbox::OncePortHandle`] in the last position is used
517/// to send the return value.
518///
519/// The macro also derives a client trait that can be automatically implemented
520/// by specifying [`HandleClient`] for `ActorHandle<Actor>` and [`RefClient`]
/// for `ActorRef<Actor>` accordingly. We require two implementations because
/// only `ActorRef`s require that their message type is serializable.
523///
524/// The associated [`hyperactor_macros::handler`] macro can be used to add
525/// a dispatching handler directly to an [`hyperactor::Actor`].
526///
527/// # Example
528///
529/// The following example creates a "shopping list" actor responsible for
530/// maintaining a shopping list.
531///
532/// ```
533/// use std::collections::HashSet;
534/// use std::time::Duration;
535///
536/// use async_trait::async_trait;
537/// use hyperactor::Actor;
538/// use hyperactor::HandleClient;
539/// use hyperactor::Handler;
540/// use hyperactor::Instance;
541/// use hyperactor::OncePortRef;
542/// use hyperactor::RefClient;
543/// use hyperactor::proc::Proc;
544/// use serde::Deserialize;
545/// use serde::Serialize;
546/// use typeuri::Named;
547///
548/// #[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
549/// enum ShoppingList {
550///     // Oneway messages dispatch messages asynchronously, with no reply.
551///     Add(String),
552///     Remove(String),
553///
554///     // Call messages dispatch a request, expecting a reply to the
555///     // provided port, which must be in the last position.
556///     Exists(String, #[reply] OncePortRef<bool>),
557///
558///     List(#[reply] OncePortRef<Vec<String>>),
559/// }
560///
561/// // Define an actor.
562/// #[derive(Debug)]
563/// #[hyperactor::export(
564///     spawn = true,
565///     handlers = [
566///         ShoppingList,
567///     ],
568/// )]
569/// struct ShoppingListActor(HashSet<String>);
570///
571/// #[async_trait]
572/// impl Actor for ShoppingListActor {
573///     type Params = ();
574///
575///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
576///         Ok(Self(HashSet::new()))
577///     }
578/// }
579///
580/// // ShoppingListHandler is the trait generated by derive(Handler) above.
581/// // We implement the trait here for the actor, defining a handler for
582/// // each ShoppingList message.
583/// //
584/// // The `forward` attribute installs a handler that forwards messages
585/// // to the `ShoppingListHandler` implementation directly. This can also
586/// // be done manually:
587/// //
588/// // ```ignore
589/// //<ShoppingListActor as ShoppingListHandler>
590/// //     ::handle(self, comm, message).await
591/// // ```
592/// #[async_trait]
593/// #[hyperactor::forward(ShoppingList)]
594/// impl ShoppingListHandler for ShoppingListActor {
595///     async fn add(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
596///         eprintln!("insert {}", item);
597///         self.0.insert(item);
598///         Ok(())
599///     }
600///
601///     async fn remove(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
602///         eprintln!("remove {}", item);
603///         self.0.remove(&item);
604///         Ok(())
605///     }
606///
607///     async fn exists(
608///         &mut self,
609///         _cx: &Context<Self>,
610///         item: String,
611///     ) -> Result<bool, anyhow::Error> {
612///         Ok(self.0.contains(&item))
613///     }
614///
615///     async fn list(&mut self, _cx: &Context<Self>) -> Result<Vec<String>, anyhow::Error> {
616///         Ok(self.0.iter().cloned().collect())
617///     }
618/// }
619///
620/// #[tokio::main]
621/// async fn main() -> Result<(), anyhow::Error> {
622///     let mut proc = Proc::local();
623///
624///     // Spawn our actor, and get a handle for rank 0.
625///     let shopping_list_actor: hyperactor::ActorHandle<ShoppingListActor> =
626///         proc.spawn("shopping", ()).await?;
627///
628///     // We join the system, so that we can send messages to actors.
629///     let client = proc.attach("client").unwrap();
630///
631///     // todo: consider making this a macro to remove the magic names
632///
633///     // Derive(Handler) generates client methods, which call the
634///     // remote handler provided an actor instance,
635///     // the destination actor, and the method arguments.
636///
637///     shopping_list_actor.add(&client, "milk".into()).await?;
638///     shopping_list_actor.add(&client, "eggs".into()).await?;
639///
640///     println!(
641///         "got milk? {}",
642///         shopping_list_actor.exists(&client, "milk".into()).await?
643///     );
644///     println!(
645///         "got yoghurt? {}",
646///         shopping_list_actor
647///             .exists(&client, "yoghurt".into())
648///             .await?
649///     );
650///
651///     shopping_list_actor.remove(&client, "milk".into()).await?;
652///     println!(
653///         "got milk now? {}",
654///         shopping_list_actor.exists(&client, "milk".into()).await?
655///     );
656///
657///     println!(
658///         "shopping list: {:?}",
659///         shopping_list_actor.list(&client).await?
660///     );
661///
662///     let _ = proc
663///         .destroy_and_wait::<()>(Duration::from_secs(1), None)
664///         .await?;
665///     Ok(())
666/// }
667/// ```
668#[proc_macro_derive(Handler, attributes(reply))]
669pub fn derive_handler(input: TokenStream) -> TokenStream {
670    let input = parse_macro_input!(input as DeriveInput);
671    let name: Ident = input.ident.clone();
672    let (_, ty_generics, _) = input.generics.split_for_impl();
673
674    let messages = match parse_messages(input.clone()) {
675        Ok(messages) => messages,
676        Err(err) => return TokenStream::from(err.to_compile_error()),
677    };
678
679    // Trait definition methods for the handler.
680    let mut handler_trait_methods = Vec::new();
681
682    // The arms of the match used in the message dispatcher.
683    let mut match_arms = Vec::new();
684
685    // Trait implemented by clients.
686    let mut client_trait_methods = Vec::new();
687
688    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
689
690    for message in &messages {
691        match message {
692            Message::Call {
693                variant,
694                reply_port,
695                return_type,
696                log_level,
697            } => {
698                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
699                let variant_name_snake = variant.snake_name();
700                let variant_name_snake_deprecated =
701                    format_ident!("{}_deprecated", variant_name_snake);
702                let enum_name = variant.enum_name();
703                let _variant_qualified_name = variant.qualified_name();
704                let log_level = match (&global_log_level, log_level) {
705                    (_, Some(local)) => local.clone(),
706                    (Some(global), None) => global.clone(),
707                    _ => Ident::new("DEBUG", Span::call_site()),
708                };
709                let _log_level = if reply_port.is_handle {
710                    quote! {
711                        tracing::Level::#log_level
712                    }
713                } else {
714                    quote! {
715                        tracing::Level::TRACE
716                    }
717                };
718                let log_message = quote! {
719                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
720                            "rpc" => "call",
721                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
722                            "message_type" => stringify!(#enum_name),
723                            "variant" => stringify!(#variant_name_snake),
724                        ));
725                };
726
727                handler_trait_methods.push(quote! {
728                    #[doc = "The generated handler method for this enum variant."]
729                    async fn #variant_name_snake(
730                        &mut self,
731                        cx: &hyperactor::Context<Self>,
732                        #(#arg_names: #arg_types),*)
733                        -> Result<#return_type, hyperactor::anyhow::Error>;
734                });
735
736                client_trait_methods.push(quote! {
737                    #[doc = "The generated client method for this enum variant."]
738                    async fn #variant_name_snake(
739                        &self,
740                        cx: &impl hyperactor::context::Actor,
741                        #(#arg_names: #arg_types),*)
742                        -> Result<#return_type, hyperactor::anyhow::Error>;
743
744                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
745                    async fn #variant_name_snake_deprecated(
746                        &self,
747                        cx: &impl hyperactor::context::Actor,
748                        #(#arg_names: #arg_types),*)
749                        -> Result<#return_type, hyperactor::anyhow::Error>;
750                });
751
752                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
753                let constructor = variant.constructor();
754                let result_ident = Ident::new("result", Span::mixed_site());
755                let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
756                if reply_port.is_handle {
757                    match_arms.push(quote! {
758                        #constructor => {
759                            #log_message
760                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
761                            // This would require Result<Result<..., in order to handle RPC errors.
762                            #construct_result_future
763                            #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
764                        }
765                    });
766                } else {
767                    match_arms.push(quote! {
768                        #constructor => {
769                            #log_message
770                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
771                            // This would require Result<Result<..., in order to handle RPC errors.
772                            #construct_result_future
773                            #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
774                        }
775                    });
776                }
777            }
778            Message::OneWay { variant, log_level } => {
779                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
780                let variant_name_snake = variant.snake_name();
781                let variant_name_snake_deprecated =
782                    format_ident!("{}_deprecated", variant_name_snake);
783                let enum_name = variant.enum_name();
784                let log_level = match (&global_log_level, log_level) {
785                    (_, Some(local)) => local.clone(),
786                    (Some(global), None) => global.clone(),
787                    _ => Ident::new("TRACE", Span::call_site()),
788                };
789                let _log_level = quote! {
790                    tracing::Level::#log_level
791                };
792                let log_message = quote! {
793                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
794                            "rpc" => "call",
795                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
796                            "message_type" => stringify!(#enum_name),
797                            "variant" => stringify!(#variant_name_snake),
798                        ));
799                };
800
801                handler_trait_methods.push(quote! {
802                    #[doc = "The generated handler method for this enum variant."]
803                    async fn #variant_name_snake(
804                        &mut self,
805                        cx: &hyperactor::Context<Self>,
806                        #(#arg_names: #arg_types),*)
807                        -> Result<(), hyperactor::anyhow::Error>;
808                });
809
810                client_trait_methods.push(quote! {
811                    #[doc = "The generated client method for this enum variant."]
812                    async fn #variant_name_snake(
813                        &self,
814                        cx: &impl hyperactor::context::Actor,
815                        #(#arg_names: #arg_types),*)
816                        -> Result<(), hyperactor::anyhow::Error>;
817
818                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
819                    async fn #variant_name_snake_deprecated(
820                        &self,
821                        cx: &impl hyperactor::context::Actor,
822                        #(#arg_names: #arg_types),*)
823                        -> Result<(), hyperactor::anyhow::Error>;
824                });
825
826                let constructor = variant.constructor();
827
828                match_arms.push(quote! {
829                    #constructor => {
830                        #log_message
831                        self.#variant_name_snake(cx, #(#arg_names),*).await
832                    },
833                });
834            }
835        }
836    }
837
838    let handler_trait_name = format_ident!("{}Handler", name);
839    let client_trait_name = format_ident!("{}Client", name);
840
841    // We impose additional constraints on the generics in the implementation;
842    // but the trait itself should not impose additional constraints:
843
844    let mut handler_generics = input.generics.clone();
845    for param in handler_generics.type_params_mut() {
846        param.bounds.push(syn::parse_quote!(serde::Serialize));
847        param
848            .bounds
849            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
850        param.bounds.push(syn::parse_quote!(Send));
851        param.bounds.push(syn::parse_quote!(Sync));
852        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
853        param.bounds.push(syn::parse_quote!(typeuri::Named));
854    }
855    let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
856    let (client_impl_generics, _, _) = input.generics.split_for_impl();
857
858    let expanded = quote! {
859        #[doc = "The custom handler trait for this message type."]
860        #[hyperactor::async_trait::async_trait]
861        pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync  {
862            #(#handler_trait_methods)*
863
864            #[doc = "Handle the next message."]
865            async fn handle(
866                &mut self,
867                cx: &hyperactor::Context<Self>,
868                message: #name #ty_generics,
869            ) -> hyperactor::anyhow::Result<()>  {
870                 // Dispatch based on message type.
871                 match message {
872                     #(#match_arms)*
873                }
874            }
875        }
876
877        #[doc = "The custom client trait for this message type."]
878        #[hyperactor::async_trait::async_trait]
879        pub trait #client_trait_name #client_impl_generics: Send + Sync  {
880            #(#client_trait_methods)*
881        }
882    };
883
884    TokenStream::from(expanded)
885}
886
887/// Derives a client implementation on `ActorHandle<Actor>`.
888/// See [`Handler`] documentation for details.
889#[proc_macro_derive(HandleClient, attributes(log_level))]
890pub fn derive_handle_client(input: TokenStream) -> TokenStream {
891    derive_client(input, true)
892}
893
894/// Derives a client implementation on `ActorRef<Actor>`.
895/// See [`Handler`] documentation for details.
896#[proc_macro_derive(RefClient, attributes(log_level))]
897pub fn derive_ref_client(input: TokenStream) -> TokenStream {
898    derive_client(input, false)
899}
900
901fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
902    let input = parse_macro_input!(input as DeriveInput);
903    let name = input.ident.clone();
904
905    let messages = match parse_messages(input.clone()) {
906        Ok(messages) => messages,
907        Err(err) => return TokenStream::from(err.to_compile_error()),
908    };
909
910    // The client implementation methods.
911    let mut impl_methods = Vec::new();
912
913    let send_message = if is_handle {
914        quote! { self.send(message)? }
915    } else {
916        quote! { self.send(cx, message)? }
917    };
918    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
919
920    for message in &messages {
921        match message {
922            Message::Call {
923                variant,
924                reply_port,
925                return_type,
926                log_level,
927            } => {
928                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
929                let variant_name_snake = variant.snake_name();
930                let variant_name_snake_deprecated =
931                    format_ident!("{}_deprecated", variant_name_snake);
932                let enum_name = variant.enum_name();
933
934                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
935                let constructor = variant.constructor();
936                let log_level = match (&global_log_level, log_level) {
937                    (_, Some(local)) => local.clone(),
938                    (Some(global), None) => global.clone(),
939                    _ => Ident::new("DEBUG", Span::call_site()),
940                };
941                let log_level = if is_handle {
942                    quote! {
943                        tracing::Level::#log_level
944                    }
945                } else {
946                    quote! {
947                        tracing::Level::TRACE
948                    }
949                };
950                let log_message = quote! {
951                        hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
952                            "rpc" => "call",
953                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
954                            "message_type" => stringify!(#enum_name),
955                            "variant" => stringify!(#variant_name_snake),
956                        ));
957
958                };
959                let open_port = reply_port.open_op();
960                let rx_mod = reply_port.rx_modifier();
961                if reply_port.is_handle {
962                    impl_methods.push(quote! {
963                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
964                        async fn #variant_name_snake(
965                            &self,
966                            cx: &impl hyperactor::context::Actor,
967                            #(#arg_names: #arg_types),*)
968                            -> Result<#return_type, hyperactor::anyhow::Error> {
969                            let (#reply_port_arg, #rx_mod reply_receiver) =
970                                #open_port::<#return_type>(cx);
971                            let message = #constructor;
972                            #log_message;
973                            #send_message;
974                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
975                        }
976
977                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
978                        async fn #variant_name_snake_deprecated(
979                            &self,
980                            cx: &impl hyperactor::context::Actor,
981                            #(#arg_names: #arg_types),*)
982                            -> Result<#return_type, hyperactor::anyhow::Error> {
983                            let (#reply_port_arg, #rx_mod reply_receiver) =
984                                #open_port::<#return_type>(cx);
985                            let message = #constructor;
986                            #log_message;
987                            #send_message;
988                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
989                        }
990                    });
991                } else {
992                    impl_methods.push(quote! {
993                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
994                        async fn #variant_name_snake(
995                            &self,
996                            cx: &impl hyperactor::context::Actor,
997                            #(#arg_names: #arg_types),*)
998                            -> Result<#return_type, hyperactor::anyhow::Error> {
999                            let (#reply_port_arg, #rx_mod reply_receiver) =
1000                                #open_port::<#return_type>(cx);
1001                            let #reply_port_arg = #reply_port_arg.bind();
1002                            let message = #constructor;
1003                            #log_message;
1004                            #send_message;
1005                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1006                        }
1007
1008                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
1009                        async fn #variant_name_snake_deprecated(
1010                            &self,
1011                            cx: &impl hyperactor::context::Actor,
1012                            #(#arg_names: #arg_types),*)
1013                            -> Result<#return_type, hyperactor::anyhow::Error> {
1014                            let (#reply_port_arg, #rx_mod reply_receiver) =
1015                                #open_port::<#return_type>(cx);
1016                            let #reply_port_arg = #reply_port_arg.bind();
1017                            let message = #constructor;
1018                            #log_message;
1019                            #send_message;
1020                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1021                        }
1022                    });
1023                }
1024            }
1025            Message::OneWay { variant, log_level } => {
1026                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1027                let variant_name_snake = variant.snake_name();
1028                let variant_name_snake_deprecated =
1029                    format_ident!("{}_deprecated", variant_name_snake);
1030                let enum_name = variant.enum_name();
1031                let constructor = variant.constructor();
1032                let log_level = match (&global_log_level, log_level) {
1033                    (_, Some(local)) => local.clone(),
1034                    (Some(global), None) => global.clone(),
1035                    _ => Ident::new("DEBUG", Span::call_site()),
1036                };
1037                let _log_level = if is_handle {
1038                    quote! {
1039                        tracing::Level::TRACE
1040                    }
1041                } else {
1042                    quote! {
1043                        tracing::Level::#log_level
1044                    }
1045                };
1046                let log_message = quote! {
1047                    hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1048                        "rpc" => "oneway",
1049                        "actor_id" => self.actor_id().to_string(),
1050                        "message_type" => stringify!(#enum_name),
1051                        "variant" => stringify!(#variant_name_snake),
1052                    ));
1053                };
1054                impl_methods.push(quote! {
1055                    async fn #variant_name_snake(
1056                        &self,
1057                        cx: &impl hyperactor::context::Actor,
1058                        #(#arg_names: #arg_types),*)
1059                        -> Result<(), hyperactor::anyhow::Error> {
1060                        let message = #constructor;
1061                        #log_message;
1062                        #send_message;
1063                        Ok(())
1064                    }
1065
1066                    async fn #variant_name_snake_deprecated(
1067                        &self,
1068                        cx: &impl hyperactor::context::Actor,
1069                        #(#arg_names: #arg_types),*)
1070                        -> Result<(), hyperactor::anyhow::Error> {
1071                        let message = #constructor;
1072                        #log_message;
1073                        #send_message;
1074                        Ok(())
1075                    }
1076                });
1077            }
1078        }
1079    }
1080
1081    let trait_name = format_ident!("{}Client", name);
1082
1083    let (_, ty_generics, _) = input.generics.split_for_impl();
1084
1085    // Add a new generic parameter 'A'
1086    let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1087    let mut trait_generics = input.generics.clone();
1088    trait_generics.params.insert(
1089        0,
1090        syn::GenericParam::Type(syn::TypeParam {
1091            ident: actor_ident.clone(),
1092            attrs: vec![],
1093            colon_token: None,
1094            bounds: Punctuated::new(),
1095            eq_token: None,
1096            default: None,
1097        }),
1098    );
1099
1100    for param in trait_generics.type_params_mut() {
1101        if param.ident == actor_ident {
1102            continue;
1103        }
1104        param.bounds.push(syn::parse_quote!(serde::Serialize));
1105        param
1106            .bounds
1107            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1108        param.bounds.push(syn::parse_quote!(Send));
1109        param.bounds.push(syn::parse_quote!(Sync));
1110        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1111        param.bounds.push(syn::parse_quote!(typeuri::Named));
1112    }
1113
1114    let (impl_generics, _, _) = trait_generics.split_for_impl();
1115
1116    let expanded = if is_handle {
1117        quote! {
1118            #[hyperactor::async_trait::async_trait]
1119            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1120              where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1121                #(#impl_methods)*
1122            }
1123        }
1124    } else {
1125        quote! {
1126            #[hyperactor::async_trait::async_trait]
1127            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1128              where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1129                #(#impl_methods)*
1130            }
1131        }
1132    };
1133
1134    TokenStream::from(expanded)
1135}
1136
1137const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
1138`forward` expects the message type that is being forwarded
1139
1140= help: use `#[forward(MessageType)]`
1141"#};
1142
1143/// Forward messages of the provided type to this handler implementation.
1144#[proc_macro_attribute]
1145pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
1146    let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1147    if attr_args.len() != 1 {
1148        return TokenStream::from(
1149            syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
1150        );
1151    }
1152
1153    let message_type = attr_args.first().unwrap();
1154    let input = parse_macro_input!(item as ItemImpl);
1155
1156    let self_type = match *input.self_ty {
1157        syn::Type::Path(ref type_path) => {
1158            let segment = type_path.path.segments.last().unwrap();
1159            segment.clone() //ident.clone()
1160        }
1161        _ => {
1162            return TokenStream::from(
1163                syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
1164                    .to_compile_error(),
1165            );
1166        }
1167    };
1168
1169    let trait_name = match input.trait_ {
1170        Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1171        None => {
1172            return TokenStream::from(
1173                syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1174                    .to_compile_error(),
1175            );
1176        }
1177    };
1178
1179    let expanded = quote! {
1180        #input
1181
1182        #[hyperactor::async_trait::async_trait]
1183        impl hyperactor::Handler<#message_type> for #self_type {
1184            async fn handle(
1185                &mut self,
1186                cx: &hyperactor::Context<Self>,
1187                message: #message_type,
1188            ) -> hyperactor::anyhow::Result<()> {
1189                <Self as #trait_name>::handle(self, cx, message).await
1190            }
1191        }
1192    };
1193
1194    TokenStream::from(expanded)
1195}
1196
1197/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1198/// We set a default level of INFO while always setting ERROR if the function returns Result::Err giving us
1199/// consistent and high quality structured logs. Because this wraps around tracing::instrument, all parameters
1200/// mentioned in https://fburl.com/9jlkb5q4 should be valid. For functions that don't return a [`Result`] type, use
1201/// [`instrument_infallible`]
1202///
1203/// ```
1204/// #[telemetry::instrument]
1205/// async fn yolo() -> anyhow::Result<i32> {
1206///     Ok(420)
1207/// }
1208/// ```
1209#[proc_macro_attribute]
1210pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1211    let args =
1212        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1213    let input = parse_macro_input!(input as ItemFn);
1214    let output = quote! {
1215        #[hyperactor::tracing::instrument(err, skip_all, #args)]
1216        #input
1217    };
1218
1219    TokenStream::from(output)
1220}
1221
1222/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1223/// Because this wraps around tracing::instrument, all parameters mentioned in
1224/// https://fburl.com/9jlkb5q4 should be valid.
1225///
1226/// ```
1227/// #[telemetry::instrument]
1228/// async fn yolo() -> i32 {
1229///     420
1230/// }
1231/// ```
1232#[proc_macro_attribute]
1233pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1234    let args =
1235        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1236    let input = parse_macro_input!(input as ItemFn);
1237
1238    let output = quote! {
1239        #[hyperactor::tracing::instrument(skip_all, #args)]
1240        #input
1241    };
1242
1243    TokenStream::from(output)
1244}
1245
1246struct HandlerSpec {
1247    ty: Type,
1248    cast: bool,
1249}
1250
1251impl Parse for HandlerSpec {
1252    fn parse(input: ParseStream) -> syn::Result<Self> {
1253        let ty: Type = input.parse()?;
1254
1255        if input.peek(syn::token::Brace) {
1256            let content;
1257            syn::braced!(content in input);
1258            let key: Ident = content.parse()?;
1259            content.parse::<Token![=]>()?;
1260            let expr: Expr = content.parse()?;
1261
1262            let cast = if key == "cast" {
1263                if let Expr::Lit(ExprLit {
1264                    lit: Lit::Bool(b), ..
1265                }) = expr
1266                {
1267                    b.value
1268                } else {
1269                    return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1270                }
1271            } else {
1272                return Err(syn::Error::new_spanned(
1273                    key,
1274                    "unsupported field (expected `cast`)",
1275                ));
1276            };
1277
1278            Ok(HandlerSpec { ty, cast })
1279        } else if input.is_empty() || input.peek(Token![,]) {
1280            Ok(HandlerSpec { ty, cast: false })
1281        } else {
1282            // Something unexpected follows the type
1283            let unexpected: proc_macro2::TokenTree = input.parse()?;
1284            Err(syn::Error::new_spanned(
1285                unexpected,
1286                "unexpected token after type — expected `{ ... }` or nothing",
1287            ))
1288        }
1289    }
1290}
1291
1292impl HandlerSpec {
1293    fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1294        let mut tys = Vec::new();
1295        for HandlerSpec { ty, cast } in handlers {
1296            if cast {
1297                let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1298                let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1299                tys.push(wrapped_ty);
1300            }
1301            tys.push(ty);
1302        }
1303        tys
1304    }
1305}
1306
1307/// Attribute Struct for [`fn export`] macro.
1308struct ExportAttr {
1309    spawn: bool,
1310    handlers: Vec<HandlerSpec>,
1311}
1312
1313impl Parse for ExportAttr {
1314    fn parse(input: ParseStream) -> syn::Result<Self> {
1315        let mut spawn = false;
1316        let mut handlers: Vec<HandlerSpec> = vec![];
1317
1318        while !input.is_empty() {
1319            let key: Ident = input.parse()?;
1320            input.parse::<Token![=]>()?;
1321
1322            if key == "spawn" {
1323                let expr: Expr = input.parse()?;
1324                if let Expr::Lit(ExprLit {
1325                    lit: Lit::Bool(b), ..
1326                }) = expr
1327                {
1328                    spawn = b.value;
1329                } else {
1330                    return Err(syn::Error::new_spanned(
1331                        expr,
1332                        "expected boolean for `spawn`",
1333                    ));
1334                }
1335            } else if key == "handlers" {
1336                let content;
1337                bracketed!(content in input);
1338                let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1339                handlers = raw_handlers.into_iter().collect();
1340            } else {
1341                return Err(syn::Error::new_spanned(
1342                    key,
1343                    "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1344                ));
1345            }
1346
1347            // optional trailing comma
1348            let _ = input.parse::<Token![,]>();
1349        }
1350
1351        Ok(ExportAttr { spawn, handlers })
1352    }
1353}
1354
1355/// Exports handlers for this actor. The set of exported handlers
1356/// determine the messages that may be sent to remote references of
1357/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
1358/// [`hyperactor::RemoteMessage`] may be exported.
1359///
1360/// Additionally, an exported actor may be remotely spawned,
1361/// indicated by `spawn = true`. Such actors must also ensure that
1362/// their parameter type implements [`hyperactor::RemoteMessage`].
1363///
1364/// # Example
1365///
1366/// In the following example, `MyActor` can be spawned remotely. It also has
1367/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`.
1368/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these
1369/// types.
1370///
1371/// ```ignore
1372/// #[export(
1373///     spawn = true,
1374///     handlers = [
1375///         MyMessage,
1376///         MyOtherMessage,
1377///     ],
1378/// )]
1379/// struct MyActor {}
1380/// ```
1381#[proc_macro_attribute]
1382pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1383    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1384    let data_type_name = &input.ident;
1385    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1386
1387    let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1388    let tys = HandlerSpec::add_indexed(handlers);
1389
1390    let mut handles = Vec::new();
1391    let mut bindings = Vec::new();
1392    let mut type_registrations = Vec::new();
1393
1394    for ty in &tys {
1395        handles.push(quote! {
1396            impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
1397        });
1398        bindings.push(quote! {
1399            ports.bind::<#ty>();
1400        });
1401        type_registrations.push(quote! {
1402            wirevalue::register_type!(#ty);
1403        });
1404    }
1405
1406    let mut expanded = quote! {
1407        #input
1408
1409        impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
1410
1411        #(#handles)*
1412
1413        #(#type_registrations)*
1414
1415        // Always export the `Signal` type.
1416        impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1417
1418        impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
1419            fn bind(ports: &hyperactor::proc::Ports<Self>) {
1420                #(#bindings)*
1421            }
1422        }
1423
1424        // TODO: just use Named derive directly here.
1425        impl #impl_generics typeuri::Named for #data_type_name #ty_generics #where_clause {
1426            fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
1427        }
1428    };
1429
1430    if spawn {
1431        expanded.extend(quote! {
1432            hyperactor::remote!(#data_type_name);
1433        });
1434    }
1435
1436    TokenStream::from(expanded)
1437}
1438
1439/// Represents the full input to [`fn behavior`].
1440struct BehaviorInput {
1441    behavior: Ident,
1442    generics: syn::Generics,
1443    handlers: Vec<HandlerSpec>,
1444}
1445
1446impl syn::parse::Parse for BehaviorInput {
1447    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1448        let behavior: Ident = input.parse()?;
1449        let generics: syn::Generics = input.parse()?;
1450        let _: Token![,] = input.parse()?;
1451        let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1452        let handlers = raw_handlers.into_iter().collect();
1453        Ok(BehaviorInput {
1454            behavior,
1455            generics,
1456            handlers,
1457        })
1458    }
1459}
1460
1461/// Create a [`Referable`] definition, handling a specific set of message types.
1462/// Behaviors are used to create an [`ActorRef`] without having to depend on the
1463/// actor's implementation. If the message type need to be cast, add `castable`
1464/// flag to those types. e.g. the following example creates a behavior with 5
1465/// message types, and 4 of which need to be cast.
1466///
1467/// ```
1468/// hyperactor::behavior!(
1469///     TestActorBehavior,
1470///     TestMessage { castable = true },
1471///     () {castable = true },
1472///     MyGeneric<()> {castable = true },
1473///     u64,
1474/// );
1475/// ```
1476///
1477/// This macro also supports generic behaviors:
1478/// ```
1479/// hyperactor::behavior!(
1480///     TestBehavior<T>,
1481///     Message<T> { castable = true },
1482///     u64,
1483/// );
1484/// ```
1485#[proc_macro]
1486pub fn behavior(input: TokenStream) -> TokenStream {
1487    let BehaviorInput {
1488        behavior,
1489        generics,
1490        handlers,
1491    } = parse_macro_input!(input as BehaviorInput);
1492    let tys = HandlerSpec::add_indexed(handlers);
1493
1494    // Add bounds to generics for Named, Serialize, Deserialize
1495    let mut bounded_generics = generics.clone();
1496    for param in bounded_generics.type_params_mut() {
1497        param.bounds.push(syn::parse_quote!(typeuri::Named));
1498        param.bounds.push(syn::parse_quote!(serde::Serialize));
1499        param.bounds.push(syn::parse_quote!(std::marker::Send));
1500        param.bounds.push(syn::parse_quote!(std::marker::Sync));
1501        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1502        // Note: lifetime parameters are not *actually* hygienic.
1503        // https://github.com/rust-lang/rust/issues/54727
1504        let lifetime =
1505            syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1506        param
1507            .bounds
1508            .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1509    }
1510
1511    // Split the generics for use in different contexts
1512    let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1513
1514    // Create a combined generics for the Binds impl that includes both A and the behavior's generics
1515    let mut binds_generics = bounded_generics.clone();
1516    binds_generics.params.insert(
1517        0,
1518        syn::GenericParam::Type(syn::TypeParam {
1519            attrs: vec![],
1520            ident: Ident::new("A", proc_macro2::Span::call_site()),
1521            colon_token: None,
1522            bounds: Punctuated::new(),
1523            eq_token: None,
1524            default: None,
1525        }),
1526    );
1527    let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1528
1529    // Determine typename and typehash implementation based on whether we have generics
1530    let type_params: Vec<_> = bounded_generics.type_params().collect();
1531    let has_generics = !type_params.is_empty();
1532
1533    let (typename_impl, typehash_impl) = if has_generics {
1534        // Create format string with placeholders for each generic parameter
1535        let placeholders = vec!["{}"; type_params.len()].join(", ");
1536        let placeholders_format_string = format!("<{}>", placeholders);
1537        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1538
1539        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1540        (
1541            quote! {
1542                typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1543            },
1544            quote! {
1545                typeuri::cityhasher::hash(Self::typename())
1546            },
1547        )
1548    } else {
1549        (
1550            quote! {
1551                concat!(std::module_path!(), "::", stringify!(#behavior))
1552            },
1553            quote! {
1554                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1555                    typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
1556                });
1557                *TYPEHASH
1558            },
1559        )
1560    };
1561
1562    let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1563
1564    let expanded = quote! {
1565        #[doc = "The generated behavior struct."]
1566        #[derive(Debug, serde::Serialize, serde::Deserialize)]
1567        pub struct #behavior #impl_generics #where_clause {
1568            _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1569        }
1570
1571        impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
1572            fn typename() -> &'static str {
1573                #typename_impl
1574            }
1575
1576            fn typehash() -> u64 {
1577                #typehash_impl
1578            }
1579        }
1580
1581        impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1582
1583        impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1584        where
1585            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1586            #where_clause
1587        {
1588            fn bind(ports: &hyperactor::proc::Ports<A>) {
1589                #(
1590                    ports.bind::<#tys>();
1591                )*
1592            }
1593        }
1594
1595        #(
1596            impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1597        )*
1598    };
1599
1600    TokenStream::from(expanded)
1601}
1602
1603fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1604    let mut is_included = false;
1605    for attr in &field.attrs {
1606        if attr.path().is_ident("binding") {
1607            // parse #[binding(include)] and look for exactly "include"
1608            attr.parse_nested_meta(|meta| {
1609                if meta.path.is_ident("include") {
1610                    is_included = true;
1611                    Ok(())
1612                } else {
1613                    let path = meta.path.to_token_stream().to_string().replace(' ', "");
1614                    Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1615                }
1616            })?
1617        }
1618    }
1619    Ok(is_included)
1620}
1621
/// The field accessor in struct or enum variant, i.e. how one field is
/// addressed: by its name or by its position.
/// e.g.:
///   struct NamedStruct { foo: u32 } => FieldAccessor::Named(Ident::new("foo", Span::call_site()))
///   struct UnnamedStruct(u32) => FieldAccessor::Unnamed(Index::from(0))
enum FieldAccessor {
    /// A field declared with a name; holds that name.
    Named(Ident),
    /// A positional (tuple-style) field; holds its zero-based position.
    Unnamed(Index),
}
1630
/// Result of parsing a field in a struct, or an enum variant.
struct ParsedField {
    /// How the field is addressed (by name or by position).
    accessor: FieldAccessor,
    /// The declared type of the field.
    ty: Type,
    /// Whether the field is annotated with `#[binding(include)]`.
    included: bool,
}
1637
1638impl From<&ParsedField> for (Ident, Type) {
1639    fn from(field: &ParsedField) -> Self {
1640        let field_ident = match &field.accessor {
1641            FieldAccessor::Named(ident) => ident.clone(),
1642            FieldAccessor::Unnamed(i) => {
1643                Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1644            }
1645        };
1646        (field_ident, field.ty.clone())
1647    }
1648}
1649
1650fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1651    match fields {
1652        Fields::Named(named) => named
1653            .named
1654            .iter()
1655            .map(|f| {
1656                let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1657                Ok(ParsedField {
1658                    accessor,
1659                    ty: f.ty.clone(),
1660                    included: include_in_bind_unbind(f)?,
1661                })
1662            })
1663            .collect(),
1664        Fields::Unnamed(unnamed) => unnamed
1665            .unnamed
1666            .iter()
1667            .enumerate()
1668            .map(|(i, f)| {
1669                let accessor = FieldAccessor::Unnamed(Index::from(i));
1670                Ok(ParsedField {
1671                    accessor,
1672                    ty: f.ty.clone(),
1673                    included: include_in_bind_unbind(f)?,
1674                })
1675            })
1676            .collect(),
1677        Fields::Unit => Ok(Vec::new()),
1678    }
1679}
1680
1681fn gen_struct_items<F>(
1682    fields: &Fields,
1683    make_item: F,
1684    is_mutable: bool,
1685) -> syn::Result<Vec<proc_macro2::TokenStream>>
1686where
1687    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1688{
1689    let borrow = if is_mutable {
1690        quote! { &mut }
1691    } else {
1692        quote! { & }
1693    };
1694    let items: Vec<_> = collect_all_fields(fields)?
1695        .into_iter()
1696        .filter(|f| f.included)
1697        .map(
1698            |ParsedField {
1699                 accessor,
1700                 ty,
1701                 included,
1702             }| {
1703                assert!(included);
1704                let field_accessor = match accessor {
1705                    FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1706                    FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1707                };
1708                make_item(field_accessor, ty)
1709            },
1710        )
1711        .collect();
1712    Ok(items)
1713}
1714
1715/// Generate the field accessor for a enum variant in pattern matching. e.g.
1716/// the <GENERATED> parts in the following example:
1717///
1718///   match my_enum {
1719///     // e.g. MyEnum::Tuple(_, f1, f2, _)
1720///     MyEnum::Tuple(<GENERATED>) => { ... }
1721///     // e.g. MyEnum::Struct { field0: _, field1 }
1722///     MyEnum::Struct(<GENERATED>) => { ... }
1723///   }
1724fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1725    all_fields
1726        .iter()
1727        .map(
1728            |ParsedField {
1729                 accessor,
1730                 ty: _,
1731                 included,
1732             }| {
1733                match accessor {
1734                    FieldAccessor::Named(ident) => {
1735                        if *included {
1736                            quote! { #ident }
1737                        } else {
1738                            quote! { #ident: _ }
1739                        }
1740                    }
1741                    FieldAccessor::Unnamed(i) => {
1742                        if *included {
1743                            let ident = Ident::new(
1744                                &format!("f{}", i.index),
1745                                proc_macro2::Span::call_site(),
1746                            );
1747                            quote! { #ident }
1748                        } else {
1749                            quote! { _ }
1750                        }
1751                    }
1752                }
1753            },
1754        )
1755        .collect()
1756}
1757
1758/// Generate all the parts for enum variants. e.g. the <GENERATED> part in the
1759/// following example:
1760///
1761///   match my_enum {
1762///      <GENERATED>
1763///   }
1764fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1765where
1766    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1767{
1768    data.variants
1769        .iter()
1770        .map(|variant| {
1771            let name = &variant.ident;
1772            let all_fields = collect_all_fields(&variant.fields)?;
1773            let field_accessors = gen_enum_field_accessors(&all_fields);
1774            let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1775            let items = included_fields
1776                .iter()
1777                .map(|f| {
1778                    let (accessor, ty) = <(Ident, Type)>::from(*f);
1779                    make_item(quote! { #accessor }, ty)
1780                })
1781                .collect::<Vec<_>>();
1782
1783            Ok(match &variant.fields {
1784                Fields::Named(_) => {
1785                    quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1786                }
1787                Fields::Unnamed(_) => {
1788                    quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1789                }
1790                Fields::Unit => quote! { Self::#name => { #(#items)* } },
1791            })
1792        })
1793        .collect()
1794}
1795
1796/// Derive a custom implementation of [`hyperactor::message::Bind`] trait for
1797/// a struct or enum. This macro is normally used in tandem with [`fn derive_unbind`]
1798/// to make the applied struct or enum castable.
1799///
1800/// Specifically, the derived implementation iterates through fields annotated
1801/// with `#[binding(include)]` based on their order of declaration in the struct
1802/// or enum. These fields' types must implement `Bind` trait as well. During the
1803/// iteration, parameters from `bindings` are bound to these fields.
1804///
1805/// # Example
1806///
1807/// This macro supports named and unamed structs and enums. Below are examples
1808/// of the supported types:
1809///
1810/// ```
1811/// #[derive(Bind, Unbind)]
1812/// struct MyNamedStruct {
1813///     field0: u64,
1814///     field1: MyReply,
1815///     #[binding(include)]
1816/// nnnn     field2: PortRef<MyReply>,
1817///     field3: bool,
1818///     #[binding(include)]
1819///     field4: hyperactor::PortRef<u64>,
1820/// }
1821///
1822/// #[derive(Bind, Unbind)]
1823/// struct MyUnamedStruct(
1824///     u64,
1825///     MyReply,
1826///     #[binding(include)] hyperactor::PortRef<MyReply>,
1827///     bool,
1828///     #[binding(include)] PortRef<u64>,
1829/// );
1830///
1831/// #[derive(Bind, Unbind)]
1832/// enum MyEnum {
1833///     Unit,
1834///     NoopTuple(u64, bool),
1835///     NoopStruct {
1836///         field0: u64,
1837///         field1: bool,
1838///     },
1839///     Tuple(
1840///         u64,
1841///         MyReply,
1842///         #[binding(include)] PortRef<MyReply>,
1843///         bool,
1844///         #[binding(include)] hyperactor::PortRef<u64>,
1845///     ),
1846///     Struct {
1847///         field0: u64,
1848///         field1: MyReply,
1849///         #[binding(include)]
1850///         field2: PortRef<MyReply>,
1851///         field3: bool,
1852///         #[binding(include)]
1853///         field4: hyperactor::PortRef<u64>,
1854///     },
1855/// }
1856/// ```
1857///
1858/// The following shows what derived `Bind`` and `Unbind`` implementations for
1859/// `MyNamedStruct` will look like. The implementations of other types are
1860/// similar, and thus are not shown here.
1861/// ```ignore
1862/// impl Bind for MyNamedStruct {
1863/// fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
1864///     Bind::bind(&mut self.field2, bindings)?;
1865///     Bind::bind(&mut self.field4, bindings)?;
1866///     Ok(())
1867/// }
1868///
1869/// impl Unbind for MyNamedStruct {
1870///     fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
1871///         Unbind::unbind(&self.field2, bindings)?;
1872///         Unbind::unbind(&self.field4, bindings)?;
1873///         Ok(())
1874///     }
1875/// }
1876/// ```
1877#[proc_macro_derive(Bind, attributes(binding))]
1878pub fn derive_bind(input: TokenStream) -> TokenStream {
1879    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1880        quote! {
1881            hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1882        }
1883    }
1884
1885    let input = parse_macro_input!(input as DeriveInput);
1886    let name = &input.ident;
1887    let inner = match &input.data {
1888        Data::Struct(DataStruct { fields, .. }) => {
1889            match gen_struct_items(fields, make_item, true) {
1890                Ok(collects) => {
1891                    quote! { #(#collects)* }
1892                }
1893                Err(e) => {
1894                    return TokenStream::from(e.to_compile_error());
1895                }
1896            }
1897        }
1898        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1899            Ok(arms) => {
1900                quote! { match self { #(#arms),* } }
1901            }
1902            Err(e) => {
1903                return TokenStream::from(e.to_compile_error());
1904            }
1905        },
1906        _ => panic!("Bind can only be derived for structs and enums"),
1907    };
1908    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1909    let expand = quote! {
1910        #[automatically_derived]
1911        impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1912            fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1913                #inner
1914                Ok(())
1915            }
1916        }
1917    };
1918    TokenStream::from(expand)
1919}
1920
1921/// Derive a custom implementation of [`hyperactor::message::Unbind`] trait for
1922/// a struct or enum. This macro is normally used in tandem with [`fn derive_bind`]
1923/// to make the applied struct or enum castable.
1924///
1925/// Specifically, the derived implementation iterates through fields annoated
1926/// with `#[binding(include)]` based on their order of declaration in the struct
1927/// or enum. These fields' types must implement `Unbind` trait as well. During
1928/// the iteration, parameters from these fields are extracted and stored in
1929/// `bindings`.
1930///
1931/// # Example
1932///
1933/// See [`fn derive_bind`]'s documentation for examples.
1934#[proc_macro_derive(Unbind, attributes(binding))]
1935pub fn derive_unbind(input: TokenStream) -> TokenStream {
1936    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1937        quote! {
1938            hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1939        }
1940    }
1941
1942    let input = parse_macro_input!(input as DeriveInput);
1943    let name = &input.ident;
1944    let inner = match &input.data {
1945        Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1946        {
1947            Ok(collects) => {
1948                quote! { #(#collects)* }
1949            }
1950            Err(e) => {
1951                return TokenStream::from(e.to_compile_error());
1952            }
1953        },
1954        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1955            Ok(arms) => {
1956                quote! { match self { #(#arms),* } }
1957            }
1958            Err(e) => {
1959                return TokenStream::from(e.to_compile_error());
1960            }
1961        },
1962        _ => panic!("Unbind can only be derived for structs and enums"),
1963    };
1964    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1965    let expand = quote! {
1966        #[automatically_derived]
1967        impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1968            fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1969                #inner
1970                Ok(())
1971            }
1972        }
1973    };
1974    TokenStream::from(expand)
1975}
1976
1977// Helper function for common parsing and validation
1978fn parse_observe_function(
1979    attr: TokenStream,
1980    item: TokenStream,
1981) -> syn::Result<(ItemFn, String, String)> {
1982    let input = syn::parse::<ItemFn>(item)?;
1983
1984    if input.sig.asyncness.is_none() {
1985        return Err(syn::Error::new(
1986            input.sig.span(),
1987            "observe macros can only be applied to async functions",
1988        ));
1989    }
1990
1991    let fn_name_str = input.sig.ident.to_string();
1992    let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
1993
1994    Ok((input, fn_name_str, module_name_str))
1995}
1996
1997// Helper function for creating telemetry identifiers and setup code
1998fn create_telemetry_setup(
1999    module_name_str: &str,
2000    fn_name_str: &str,
2001    include_error: bool,
2002) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2003    let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2004    let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2005
2006    let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2007
2008    let error_ident = if include_error {
2009        Some(Ident::new(
2010            "error",
2011            Span::from(proc_macro::Span::def_site()),
2012        ))
2013    } else {
2014        None
2015    };
2016
2017    let error_declaration = if let Some(ref error_ident) = error_ident {
2018        quote! {
2019            hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2020        }
2021    } else {
2022        quote! {}
2023    };
2024
2025    let setup_code = quote! {
2026        use hyperactor_telemetry;
2027        hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2028        hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2029        #error_declaration
2030    };
2031
2032    (latency_ident, success_ident, error_ident, setup_code)
2033}
2034
2035/// A procedural macro that automatically injects telemetry code into async functions
2036/// that return a Result type.
2037///
2038/// This macro wraps async functions and adds instrumentation to measure:
2039/// 1. Latency - how long the function takes to execute
2040/// 2. Error counter - function error count
2041/// 3. Success counter - function completion count
2042///
2043/// # Example
2044///
2045/// ```rust
2046/// use hyperactor_actor::observe_result;
2047///
2048/// #[observe_result("my_module")]
2049/// async fn process_request(user_id: &str) -> Result<String, Error> {
2050///     // Function implementation
2051///     // Telemetry will be automatically collected
2052/// }
2053/// ```
2054#[proc_macro_attribute]
2055pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2056    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2057        Ok(parsed) => parsed,
2058        Err(err) => return err.to_compile_error().into(),
2059    };
2060
2061    let fn_name = &input.sig.ident;
2062    let vis = &input.vis;
2063    let args = &input.sig.inputs;
2064    let return_type = &input.sig.output;
2065    let body = &input.block;
2066    let attrs = &input.attrs;
2067    let generics = &input.sig.generics;
2068
2069    let (latency_ident, success_ident, error_ident, telemetry_setup) =
2070        create_telemetry_setup(&module_name_str, &fn_name_str, true);
2071    let error_ident = error_ident.unwrap();
2072
2073    let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2074
2075    // Generate the instrumented function
2076    let expanded = quote! {
2077        #(#attrs)*
2078        #vis async fn #fn_name #generics(#args) #return_type {
2079            #telemetry_setup
2080
2081            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2082            let _timer = #latency_ident.start(kv_pairs);
2083
2084            let #result_ident = async #body.await;
2085
2086            match &#result_ident {
2087                Ok(_) => {
2088                    #success_ident.add(
2089                        1,
2090                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2091                    );
2092                }
2093                Err(_) => {
2094                    #error_ident.add(
2095                        1,
2096                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2097                    );
2098                }
2099            }
2100
2101            #result_ident
2102        }
2103    };
2104
2105    expanded.into()
2106}
2107
2108/// A procedural macro that automatically injects telemetry code into async functions
2109/// that do not return a Result type.
2110///
2111/// This macro wraps async functions and adds instrumentation to measure:
2112/// 1. Latency - how long the function takes to execute
2113/// 2. Success counter - function completion count
2114///
2115/// # Example
2116///
2117/// ```rust
2118/// use hyperactor_actor::observe_async;
2119///
2120/// #[observe_async("my_module")]
2121/// async fn process_data(data: &str) -> String {
2122///     // Function implementation
2123///     // Telemetry will be automatically collected
2124/// }
2125/// ```
2126#[proc_macro_attribute]
2127pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2128    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2129        Ok(parsed) => parsed,
2130        Err(err) => return err.to_compile_error().into(),
2131    };
2132
2133    let fn_name = &input.sig.ident;
2134    let vis = &input.vis;
2135    let args = &input.sig.inputs;
2136    let return_type = &input.sig.output;
2137    let body = &input.block;
2138    let attrs = &input.attrs;
2139    let generics = &input.sig.generics;
2140
2141    let (latency_ident, success_ident, _, telemetry_setup) =
2142        create_telemetry_setup(&module_name_str, &fn_name_str, false);
2143
2144    let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2145
2146    // Generate the instrumented function
2147    let expanded = quote! {
2148        #(#attrs)*
2149        #vis async fn #fn_name #generics(#args) #return_type {
2150            #telemetry_setup
2151
2152            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2153            let _timer = #latency_ident.start(kv_pairs);
2154
2155            let #return_ident = async #body.await;
2156
2157            #success_ident.add(
2158                1,
2159                hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2160            );
2161            #return_ident
2162        }
2163    };
2164
2165    expanded.into()
2166}