hyperactor_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines macros used by the [`hyperactor`] crate.
10
11#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Token;
39use syn::Type;
40use syn::bracketed;
41use syn::parse::Parse;
42use syn::parse::ParseStream;
43use syn::parse_macro_input;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46
/// Diagnostic reported when the `#[reply]` argument of a `call` message is not a
/// supported typed port (`OncePortRef`, `PortRef`, `OncePortHandle`, `PortHandle`).
const REPLY_VARIANT_ERROR: &str = concat!(
    "`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position\n",
    "\n",
    "= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`\n",
    "= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`\n"
);

/// Diagnostic reported when more than one field of a message is marked `#[reply]`.
const REPLY_USAGE_ERROR: &str = concat!(
    "`call` message expects at most one `reply` argument\n",
    "\n",
    "= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`\n",
    "= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`\n"
);
60
/// Per-field marker recording whether a message field carries the `#[reply]`
/// attribute.
enum FieldFlag {
    /// An ordinary argument field.
    None,
    /// The field annotated with `#[reply]` (the reply port).
    Reply,
}
65
66/// Represents a variant of an enum.
67#[allow(dead_code)]
68enum Variant {
69    /// A named variant (i.e., `MyVariant { .. }`).
70    Named {
71        enum_name: Ident,
72        name: Ident,
73        field_names: Vec<Ident>,
74        field_types: Vec<Type>,
75        field_flags: Vec<FieldFlag>,
76        is_struct: bool,
77        generics: syn::Generics,
78    },
79    /// An anonymous variant (i.e., `MyVariant(..)`).
80    Anon {
81        enum_name: Ident,
82        name: Ident,
83        field_types: Vec<Type>,
84        field_flags: Vec<FieldFlag>,
85        is_struct: bool,
86        generics: syn::Generics,
87    },
88}
89
90impl Variant {
91    /// The number of fields in the variant.
92    fn len(&self) -> usize {
93        self.field_types().len()
94    }
95
96    /// Returns whether this variant was defined as a struct.
97    fn is_struct(&self) -> bool {
98        match self {
99            Variant::Named { is_struct, .. } => *is_struct,
100            Variant::Anon { is_struct, .. } => *is_struct,
101        }
102    }
103
104    /// The name of the enum containing the variant.
105    fn enum_name(&self) -> &Ident {
106        match self {
107            Variant::Named { enum_name, .. } => enum_name,
108            Variant::Anon { enum_name, .. } => enum_name,
109        }
110    }
111
112    /// The name of the variant itself.
113    fn name(&self) -> &Ident {
114        match self {
115            Variant::Named { name, .. } => name,
116            Variant::Anon { name, .. } => name,
117        }
118    }
119
120    /// The generics of the variant itself.
121    #[allow(dead_code)]
122    fn generics(&self) -> &syn::Generics {
123        match self {
124            Variant::Named { generics, .. } => generics,
125            Variant::Anon { generics, .. } => generics,
126        }
127    }
128
129    /// The snake_name of the variant itself.
130    fn snake_name(&self) -> Ident {
131        Ident::new(
132            &self.name().to_string().to_case(Case::Snake),
133            self.name().span(),
134        )
135    }
136
137    /// The variant's qualified name.
138    fn qualified_name(&self) -> proc_macro2::TokenStream {
139        let enum_name = self.enum_name();
140        let name = self.name();
141
142        if self.is_struct() {
143            quote! { #enum_name }
144        } else {
145            quote! { #enum_name::#name }
146        }
147    }
148
149    /// Names of the fields in the variant. Anonymous variants are named
150    /// according to their position in the argument list.
151    fn field_names(&self) -> Vec<Ident> {
152        match self {
153            Variant::Named { field_names, .. } => field_names.clone(),
154            Variant::Anon { field_types, .. } => (0usize..field_types.len())
155                .map(|idx| format_ident!("arg{}", idx))
156                .collect(),
157        }
158    }
159
160    /// The types of the fields int the variant.
161    fn field_types(&self) -> &Vec<Type> {
162        match self {
163            Variant::Named { field_types, .. } => field_types,
164            Variant::Anon { field_types, .. } => field_types,
165        }
166    }
167
168    /// Return the field flags for this variant.
169    fn field_flags(&self) -> &Vec<FieldFlag> {
170        match self {
171            Variant::Named { field_flags, .. } => field_flags,
172            Variant::Anon { field_flags, .. } => field_flags,
173        }
174    }
175
176    /// The constructor for the variant, using the field names directly.
177    fn constructor(&self) -> proc_macro2::TokenStream {
178        let qualified_name = self.qualified_name();
179        let field_names = self.field_names();
180        match self {
181            Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
182            Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
183        }
184    }
185}
186
187struct ReplyPort {
188    is_handle: bool,
189    is_once: bool,
190}
191
192impl ReplyPort {
193    fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
194        ReplyPort {
195            is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
196            is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
197        }
198    }
199
200    fn open_op(&self) -> proc_macro2::TokenStream {
201        if self.is_once {
202            quote! { hyperactor::mailbox::open_once_port }
203        } else {
204            quote! { hyperactor::mailbox::open_port }
205        }
206    }
207
208    fn rx_modifier(&self) -> proc_macro2::TokenStream {
209        if self.is_once {
210            quote! {}
211        } else {
212            quote! { mut }
213        }
214    }
215}
216
217/// Represents a message that can be sent to a handler, each message is associated with
218/// a variant.
219#[allow(clippy::large_enum_variant)]
220enum Message {
221    /// A call message is a request-response message, the last argument is
222    /// a [`hyperactor::OncePortRef`] or [`hyperactor::OncePortHandle`].
223    Call {
224        variant: Variant,
225        /// Tells whether the reply argument is a handle.
226        reply_port: ReplyPort,
227        /// The underlying return type (i.e., the type of the reply port).
228        return_type: Type,
229        /// the log level for generated instrumentation for handlers of this message.
230        log_level: Option<Ident>,
231    },
232    OneWay {
233        variant: Variant,
234        /// the log level for generated instrumentation for handlers of this message.
235        log_level: Option<Ident>,
236    },
237}
238
239impl Message {
240    fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
241        match &variant
242            .field_flags()
243            .iter()
244            .zip(variant.field_types())
245            .filter_map(|(flag, ty)| match flag {
246                FieldFlag::Reply => Some(ty),
247                FieldFlag::None => None,
248            })
249            .collect::<Vec<&Type>>()[..]
250        {
251            [] => Ok(Self::OneWay { variant, log_level }),
252            [reply_port_ty] => {
253                let syn::Type::Path(type_path) = reply_port_ty else {
254                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
255                };
256                let Some(last_segment) = type_path.path.segments.last() else {
257                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
258                };
259                if last_segment.ident != "OncePortRef"
260                    && last_segment.ident != "OncePortHandle"
261                    && last_segment.ident != "PortRef"
262                    && last_segment.ident != "PortHandle"
263                {
264                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
265                }
266                let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
267                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
268                };
269                let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
270                    return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
271                };
272                let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
273                let return_type = return_ty.clone();
274                Ok(Self::Call {
275                    variant,
276                    reply_port,
277                    return_type,
278                    log_level,
279                })
280            }
281            _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
282        }
283    }
284
285    /// The arguments of this message.
286    fn args(&self) -> Vec<(Ident, Type)> {
287        match self {
288            Message::Call { variant, .. } => variant
289                .field_names()
290                .into_iter()
291                .zip(variant.field_types().clone())
292                .take(variant.len() - 1)
293                .collect(),
294            Message::OneWay { variant, .. } => variant
295                .field_names()
296                .into_iter()
297                .zip(variant.field_types().clone())
298                .collect(),
299        }
300    }
301
302    fn variant(&self) -> &Variant {
303        match self {
304            Message::Call { variant, .. } => variant,
305            Message::OneWay { variant, .. } => variant,
306        }
307    }
308
309    fn reply_port_position(&self) -> Option<usize> {
310        self.variant()
311            .field_flags()
312            .iter()
313            .position(|flag| matches!(flag, FieldFlag::Reply))
314    }
315
316    /// The reply port argument of this message.
317    fn reply_port_arg(&self) -> Option<(Ident, Type)> {
318        match self {
319            Message::Call { variant, .. } => {
320                let pos = self.reply_port_position()?;
321                Some((
322                    variant.field_names()[pos].clone(),
323                    variant.field_types()[pos].clone(),
324                ))
325            }
326            Message::OneWay { .. } => None,
327        }
328    }
329}
330
331fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
332    let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
333        Some(attr) => {
334            let Ok(meta) = attr.meta.require_list() else {
335                return Err(syn::Error::new(
336                    Span::call_site(),
337                    indoc! {"
338                            `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
339
340                            = help use `#[log_level(info)]` or `#[log_level(error)]`
341                        "},
342                ));
343            };
344            let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
345            if parsed.len() != 1 {
346                return Err(syn::Error::new(
347                    Span::call_site(),
348                    indoc! {"
349                            `log_level` attribute must specify exactly one level
350
351                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
352                        "},
353                ));
354            };
355            Some(parsed.first().unwrap().to_string())
356        }
357        None => None,
358    };
359
360    if level.is_none() {
361        return Ok(None);
362    }
363    let level = level.unwrap();
364
365    match level.as_str() {
366        "error" | "warn" | "info" | "debug" | "trace" => {}
367        _ => {
368            return Err(syn::Error::new(
369                Span::call_site(),
370                indoc! {"
371                            `log_level` attribute must be one of 'error, warn, info, debug, trace'
372
373                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
374                        "},
375            ));
376        }
377    }
378
379    Ok(Some(Ident::new(
380        level.to_ascii_uppercase().as_str(),
381        Span::call_site(),
382    )))
383}
384
385fn parse_field_flag(field: &Field) -> FieldFlag {
386    for attr in field.attrs.iter() {
387        match &attr.meta {
388            syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
389            _ => {}
390        }
391    }
392    FieldFlag::None
393}
394
395/// Parse a message enum or struct into its constituent messages.
396fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
397    match &input.data {
398        Data::Enum(data_enum) => {
399            let mut messages = Vec::new();
400
401            for variant in &data_enum.variants {
402                let name = variant.ident.clone();
403                let attrs = &variant.attrs;
404
405                let message_variant = match &variant.fields {
406                    syn::Fields::Unnamed(fields_) => Variant::Anon {
407                        enum_name: input.ident.clone(),
408                        name,
409                        field_types: fields_
410                            .unnamed
411                            .iter()
412                            .map(|field| field.ty.clone())
413                            .collect(),
414                        field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
415                        is_struct: false,
416                        generics: input.generics.clone(),
417                    },
418                    syn::Fields::Named(fields_) => Variant::Named {
419                        enum_name: input.ident.clone(),
420                        name,
421                        field_names: fields_
422                            .named
423                            .iter()
424                            .map(|field| field.ident.clone().unwrap())
425                            .collect(),
426                        field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
427                        field_flags: fields_.named.iter().map(parse_field_flag).collect(),
428                        is_struct: false,
429                        generics: input.generics.clone(),
430                    },
431                    _ => {
432                        return Err(syn::Error::new_spanned(
433                            variant,
434                            indoc! {r#"
435                                `Handler` currently only supports named or tuple struct variants
436
437                                = help use `MyCall(Arg1Type, Arg2Type, ..)`,
438                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
439                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
440                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
441                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
442                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
443                              "#},
444                        ));
445                    }
446                };
447                let log_level = parse_log_level(attrs)?;
448
449                messages.push(Message::new(
450                    variant.fields.span(),
451                    message_variant,
452                    log_level,
453                )?);
454            }
455
456            Ok(messages)
457        }
458        Data::Struct(data_struct) => {
459            let struct_name = input.ident.clone();
460            let attrs = &input.attrs;
461
462            let message_variant = match &data_struct.fields {
463                syn::Fields::Unnamed(fields_) => Variant::Anon {
464                    enum_name: struct_name.clone(),
465                    name: struct_name,
466                    field_types: fields_
467                        .unnamed
468                        .iter()
469                        .map(|field| field.ty.clone())
470                        .collect(),
471                    field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
472                    is_struct: true,
473                    generics: input.generics.clone(),
474                },
475                syn::Fields::Named(fields_) => Variant::Named {
476                    enum_name: struct_name.clone(),
477                    name: struct_name,
478                    field_names: fields_
479                        .named
480                        .iter()
481                        .map(|field| field.ident.clone().unwrap())
482                        .collect(),
483                    field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
484                    field_flags: fields_.named.iter().map(parse_field_flag).collect(),
485                    is_struct: true,
486                    generics: input.generics.clone(),
487                },
488                syn::Fields::Unit => Variant::Anon {
489                    enum_name: struct_name.clone(),
490                    name: struct_name,
491                    field_types: Vec::new(),
492                    field_flags: Vec::new(),
493                    is_struct: true,
494                    generics: input.generics.clone(),
495                },
496            };
497
498            let log_level = parse_log_level(attrs)?;
499            let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
500
501            Ok(vec![message])
502        }
503        _ => Err(syn::Error::new_spanned(
504            input,
505            "handlers can only be derived for enums and structs",
506        )),
507    }
508}
509
/// Derive a custom handler trait for a given message enum (with tuple or
/// named variants) or message struct.  The handler trait defines a method
/// corresponding to each of the enum's variants, and a `handle` function
/// that dispatches messages to the correct method.  The macro
/// supports two messaging patterns: "call" and "oneway". A call is a
/// request-response message; a `#[reply]`-annotated port in the last
/// position (e.g. a [`hyperactor::mailbox::OncePortRef`] or
/// [`hyperactor::mailbox::OncePortHandle`]) is used to send the return value.
518///
/// The macro also derives a client trait that can be automatically implemented
/// by specifying [`HandleClient`] for `ActorHandle<Actor>` and [`RefClient`]
/// for `ActorRef<Actor>` accordingly. We require two implementations because
/// only `ActorRef`s require their message type to be serializable.
523///
524/// The associated [`hyperactor_macros::handler`] macro can be used to add
525/// a dispatching handler directly to an [`hyperactor::Actor`].
526///
527/// # Example
528///
529/// The following example creates a "shopping list" actor responsible for
530/// maintaining a shopping list.
531///
532/// ```
533/// use std::collections::HashSet;
534/// use std::time::Duration;
535///
536/// use async_trait::async_trait;
537/// use hyperactor::Actor;
538/// use hyperactor::HandleClient;
539/// use hyperactor::Handler;
540/// use hyperactor::Instance;
541/// use hyperactor::RefClient;
542/// use hyperactor::proc::Proc;
543/// use hyperactor::reference;
544/// use serde::Deserialize;
545/// use serde::Serialize;
546/// use typeuri::Named;
547///
548/// #[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
549/// enum ShoppingList {
550///     // Oneway messages dispatch messages asynchronously, with no reply.
551///     Add(String),
552///     Remove(String),
553///
554///     // Call messages dispatch a request, expecting a reply to the
555///     // provided port, which must be in the last position.
556///     Exists(String, #[reply] reference::OncePortRef<bool>),
557///
558///     List(#[reply] reference::OncePortRef<Vec<String>>),
559/// }
560///
561/// // Define an actor.
562/// #[derive(Debug)]
563/// #[hyperactor::export(
564///     spawn = true,
565///     handlers = [
566///         ShoppingList,
567///     ],
568/// )]
569/// struct ShoppingListActor(HashSet<String>);
570///
571/// #[async_trait]
572/// impl Actor for ShoppingListActor {
573///     type Params = ();
574///
575///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
576///         Ok(Self(HashSet::new()))
577///     }
578/// }
579///
580/// // ShoppingListHandler is the trait generated by derive(Handler) above.
581/// // We implement the trait here for the actor, defining a handler for
582/// // each ShoppingList message.
583/// //
584/// // The `handle` attribute installs a handler that routes messages
585/// // to the `ShoppingListHandler` implementation directly. This can also
586/// // be done manually:
587/// //
588/// // ```ignore
589/// //<ShoppingListActor as ShoppingListHandler>
590/// //     ::handle(self, comm, message).await
591/// // ```
592/// #[async_trait]
593/// #[hyperactor::handle(ShoppingList)]
594/// impl ShoppingListHandler for ShoppingListActor {
595///     async fn add(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
596///         eprintln!("insert {}", item);
597///         self.0.insert(item);
598///         Ok(())
599///     }
600///
601///     async fn remove(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
602///         eprintln!("remove {}", item);
603///         self.0.remove(&item);
604///         Ok(())
605///     }
606///
607///     async fn exists(
608///         &mut self,
609///         _cx: &Context<Self>,
610///         item: String,
611///     ) -> Result<bool, anyhow::Error> {
612///         Ok(self.0.contains(&item))
613///     }
614///
615///     async fn list(&mut self, _cx: &Context<Self>) -> Result<Vec<String>, anyhow::Error> {
616///         Ok(self.0.iter().cloned().collect())
617///     }
618/// }
619///
620/// #[tokio::main]
621/// async fn main() -> Result<(), anyhow::Error> {
622///     let mut proc = Proc::local();
623///
624///     // Spawn our actor, and get a handle for rank 0.
625///     let shopping_list_actor: hyperactor::ActorHandle<ShoppingListActor> =
626///         proc.spawn("shopping", ()).await?;
627///
628///     // We join the system, so that we can send messages to actors.
629///     let client = proc.attach("client").unwrap();
630///
631///     // todo: consider making this a macro to remove the magic names
632///
633///     // Derive(Handler) generates client methods, which call the
634///     // remote handler provided an actor instance,
635///     // the destination actor, and the method arguments.
636///
637///     shopping_list_actor.add(&client, "milk".into()).await?;
638///     shopping_list_actor.add(&client, "eggs".into()).await?;
639///
640///     println!(
641///         "got milk? {}",
642///         shopping_list_actor.exists(&client, "milk".into()).await?
643///     );
644///     println!(
645///         "got yoghurt? {}",
646///         shopping_list_actor
647///             .exists(&client, "yoghurt".into())
648///             .await?
649///     );
650///
651///     shopping_list_actor.remove(&client, "milk".into()).await?;
652///     println!(
653///         "got milk now? {}",
654///         shopping_list_actor.exists(&client, "milk".into()).await?
655///     );
656///
657///     println!(
658///         "shopping list: {:?}",
659///         shopping_list_actor.list(&client).await?
660///     );
661///
662///     let _ = proc
663///         .destroy_and_wait::<()>(Duration::from_secs(1), None)
664///         .await?;
665///     Ok(())
666/// }
667/// ```
668#[proc_macro_derive(Handler, attributes(reply))]
669pub fn derive_handler(input: TokenStream) -> TokenStream {
670    let input = parse_macro_input!(input as DeriveInput);
671    let name: Ident = input.ident.clone();
672    let (_, ty_generics, _) = input.generics.split_for_impl();
673
674    let messages = match parse_messages(input.clone()) {
675        Ok(messages) => messages,
676        Err(err) => return TokenStream::from(err.to_compile_error()),
677    };
678
679    // Trait definition methods for the handler.
680    let mut handler_trait_methods = Vec::new();
681
682    // The arms of the match used in the message dispatcher.
683    let mut match_arms = Vec::new();
684
685    // Trait implemented by clients.
686    let mut client_trait_methods = Vec::new();
687
688    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
689
690    for message in &messages {
691        match message {
692            Message::Call {
693                variant,
694                reply_port,
695                return_type,
696                log_level,
697            } => {
698                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
699                let variant_name_snake = variant.snake_name();
700                let variant_name_snake_deprecated =
701                    format_ident!("{}_deprecated", variant_name_snake);
702                let enum_name = variant.enum_name();
703                let _variant_qualified_name = variant.qualified_name();
704                let log_level = match (&global_log_level, log_level) {
705                    (_, Some(local)) => local.clone(),
706                    (Some(global), None) => global.clone(),
707                    _ => Ident::new("DEBUG", Span::call_site()),
708                };
709                let _log_level = if reply_port.is_handle {
710                    quote! {
711                        tracing::Level::#log_level
712                    }
713                } else {
714                    quote! {
715                        tracing::Level::TRACE
716                    }
717                };
718                let log_message = quote! {
719                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
720                            "rpc" => "call",
721                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
722                            "message_type" => stringify!(#enum_name),
723                            "variant" => stringify!(#variant_name_snake),
724                        ));
725                };
726
727                handler_trait_methods.push(quote! {
728                    #[doc = "The generated handler method for this enum variant."]
729                    async fn #variant_name_snake(
730                        &mut self,
731                        cx: &hyperactor::Context<Self>,
732                        #(#arg_names: #arg_types),*)
733                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
734                });
735
736                client_trait_methods.push(quote! {
737                    #[doc = "The generated client method for this enum variant."]
738                    async fn #variant_name_snake(
739                        &self,
740                        cx: &impl hyperactor::context::Actor,
741                        #(#arg_names: #arg_types),*)
742                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
743
744                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
745                    async fn #variant_name_snake_deprecated(
746                        &self,
747                        cx: &impl hyperactor::context::Actor,
748                        #(#arg_names: #arg_types),*)
749                        -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error>;
750                });
751
752                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
753                let constructor = variant.constructor();
754                let result_ident = Ident::new("result", Span::mixed_site());
755                let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
756                match_arms.push(quote! {
757                    #constructor => {
758                        #log_message
759                        // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
760                        // This would require Result<Result<..., in order to handle RPC errors.
761                        #construct_result_future
762                        #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::internal_macro_support::anyhow::Error::from)
763                    }
764                });
765            }
766            Message::OneWay { variant, log_level } => {
767                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
768                let variant_name_snake = variant.snake_name();
769                let variant_name_snake_deprecated =
770                    format_ident!("{}_deprecated", variant_name_snake);
771                let enum_name = variant.enum_name();
772                let log_level = match (&global_log_level, log_level) {
773                    (_, Some(local)) => local.clone(),
774                    (Some(global), None) => global.clone(),
775                    _ => Ident::new("TRACE", Span::call_site()),
776                };
777                let _log_level = quote! {
778                    tracing::Level::#log_level
779                };
780                let log_message = quote! {
781                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
782                            "rpc" => "call",
783                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
784                            "message_type" => stringify!(#enum_name),
785                            "variant" => stringify!(#variant_name_snake),
786                        ));
787                };
788
789                handler_trait_methods.push(quote! {
790                    #[doc = "The generated handler method for this enum variant."]
791                    async fn #variant_name_snake(
792                        &mut self,
793                        cx: &hyperactor::Context<Self>,
794                        #(#arg_names: #arg_types),*)
795                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
796                });
797
798                client_trait_methods.push(quote! {
799                    #[doc = "The generated client method for this enum variant."]
800                    async fn #variant_name_snake(
801                        &self,
802                        cx: &impl hyperactor::context::Actor,
803                        #(#arg_names: #arg_types),*)
804                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
805
806                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
807                    async fn #variant_name_snake_deprecated(
808                        &self,
809                        cx: &impl hyperactor::context::Actor,
810                        #(#arg_names: #arg_types),*)
811                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error>;
812                });
813
814                let constructor = variant.constructor();
815
816                match_arms.push(quote! {
817                    #constructor => {
818                        #log_message
819                        self.#variant_name_snake(cx, #(#arg_names),*).await
820                    },
821                });
822            }
823        }
824    }
825
826    let handler_trait_name = format_ident!("{}Handler", name);
827    let client_trait_name = format_ident!("{}Client", name);
828
829    // We impose additional constraints on the generics in the implementation;
830    // but the trait itself should not impose additional constraints:
831
832    let mut handler_generics = input.generics.clone();
833    for param in handler_generics.type_params_mut() {
834        param.bounds.push(syn::parse_quote!(serde::Serialize));
835        param
836            .bounds
837            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
838        param.bounds.push(syn::parse_quote!(Send));
839        param.bounds.push(syn::parse_quote!(Sync));
840        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
841        param.bounds.push(syn::parse_quote!(typeuri::Named));
842    }
843    let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
844    let (client_impl_generics, _, _) = input.generics.split_for_impl();
845
846    let expanded = quote! {
847        #[doc = "The custom handler trait for this message type."]
848        #[hyperactor::internal_macro_support::async_trait::async_trait]
849        pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync  {
850            #(#handler_trait_methods)*
851
852            #[doc = "Handle the next message."]
853            async fn handle(
854                &mut self,
855                cx: &hyperactor::Context<Self>,
856                message: #name #ty_generics,
857            ) -> hyperactor::internal_macro_support::anyhow::Result<()>  {
858                 // Dispatch based on message type.
859                 match message {
860                     #(#match_arms)*
861                }
862            }
863        }
864
865        #[doc = "The custom client trait for this message type."]
866        #[hyperactor::internal_macro_support::async_trait::async_trait]
867        pub trait #client_trait_name #client_impl_generics: Send + Sync  {
868            #(#client_trait_methods)*
869        }
870    };
871
872    TokenStream::from(expanded)
873}
874
875/// Derives a client implementation on `ActorHandle<Actor>`.
876/// See [`Handler`] documentation for details.
877#[proc_macro_derive(HandleClient, attributes(log_level))]
878pub fn derive_handle_client(input: TokenStream) -> TokenStream {
879    derive_client(input, true)
880}
881
882/// Derives a client implementation on `ActorRef<Actor>`.
883/// See [`Handler`] documentation for details.
884#[proc_macro_derive(RefClient, attributes(log_level))]
885pub fn derive_ref_client(input: TokenStream) -> TokenStream {
886    derive_client(input, false)
887}
888
889fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
890    let input = parse_macro_input!(input as DeriveInput);
891    let name = input.ident.clone();
892
893    let messages = match parse_messages(input.clone()) {
894        Ok(messages) => messages,
895        Err(err) => return TokenStream::from(err.to_compile_error()),
896    };
897
898    // The client implementation methods.
899    let mut impl_methods = Vec::new();
900
901    let send_message = quote! { self.send(cx, message)? };
902    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
903
904    for message in &messages {
905        match message {
906            Message::Call {
907                variant,
908                reply_port,
909                return_type,
910                log_level,
911            } => {
912                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
913                let variant_name_snake = variant.snake_name();
914                let variant_name_snake_deprecated =
915                    format_ident!("{}_deprecated", variant_name_snake);
916                let enum_name = variant.enum_name();
917
918                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
919                let constructor = variant.constructor();
920                let log_level = match (&global_log_level, log_level) {
921                    (_, Some(local)) => local.clone(),
922                    (Some(global), None) => global.clone(),
923                    _ => Ident::new("DEBUG", Span::call_site()),
924                };
925                let log_level = if is_handle {
926                    quote! {
927                        tracing::Level::#log_level
928                    }
929                } else {
930                    quote! {
931                        tracing::Level::TRACE
932                    }
933                };
934                let log_message = quote! {
935                        hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
936                            "rpc" => "call",
937                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
938                            "message_type" => stringify!(#enum_name),
939                            "variant" => stringify!(#variant_name_snake),
940                        ));
941
942                };
943                let open_port = reply_port.open_op();
944                let rx_mod = reply_port.rx_modifier();
945                if reply_port.is_handle {
946                    impl_methods.push(quote! {
947                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
948                        async fn #variant_name_snake(
949                            &self,
950                            cx: &impl hyperactor::context::Actor,
951                            #(#arg_names: #arg_types),*)
952                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
953                            let (#reply_port_arg, #rx_mod reply_receiver) =
954                                #open_port::<#return_type>(cx);
955                            let message = #constructor;
956                            #log_message;
957                            #send_message;
958                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
959                        }
960
961                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
962                        async fn #variant_name_snake_deprecated(
963                            &self,
964                            cx: &impl hyperactor::context::Actor,
965                            #(#arg_names: #arg_types),*)
966                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
967                            let (#reply_port_arg, #rx_mod reply_receiver) =
968                                #open_port::<#return_type>(cx);
969                            let message = #constructor;
970                            #log_message;
971                            #send_message;
972                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
973                        }
974                    });
975                } else {
976                    impl_methods.push(quote! {
977                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
978                        async fn #variant_name_snake(
979                            &self,
980                            cx: &impl hyperactor::context::Actor,
981                            #(#arg_names: #arg_types),*)
982                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
983                            let (#reply_port_arg, #rx_mod reply_receiver) =
984                                #open_port::<#return_type>(cx);
985                            let #reply_port_arg = #reply_port_arg.bind();
986                            let message = #constructor;
987                            #log_message;
988                            #send_message;
989                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
990                        }
991
992                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
993                        async fn #variant_name_snake_deprecated(
994                            &self,
995                            cx: &impl hyperactor::context::Actor,
996                            #(#arg_names: #arg_types),*)
997                            -> Result<#return_type, hyperactor::internal_macro_support::anyhow::Error> {
998                            let (#reply_port_arg, #rx_mod reply_receiver) =
999                                #open_port::<#return_type>(cx);
1000                            let #reply_port_arg = #reply_port_arg.bind();
1001                            let message = #constructor;
1002                            #log_message;
1003                            #send_message;
1004                            reply_receiver.recv().await.map_err(hyperactor::internal_macro_support::anyhow::Error::from)
1005                        }
1006                    });
1007                }
1008            }
1009            Message::OneWay { variant, log_level } => {
1010                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1011                let variant_name_snake = variant.snake_name();
1012                let variant_name_snake_deprecated =
1013                    format_ident!("{}_deprecated", variant_name_snake);
1014                let enum_name = variant.enum_name();
1015                let constructor = variant.constructor();
1016                let log_level = match (&global_log_level, log_level) {
1017                    (_, Some(local)) => local.clone(),
1018                    (Some(global), None) => global.clone(),
1019                    _ => Ident::new("DEBUG", Span::call_site()),
1020                };
1021                let _log_level = if is_handle {
1022                    quote! {
1023                        tracing::Level::TRACE
1024                    }
1025                } else {
1026                    quote! {
1027                        tracing::Level::#log_level
1028                    }
1029                };
1030                let log_message = quote! {
1031                    hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1032                        "rpc" => "oneway",
1033                        "actor_id" => self.actor_id().to_string(),
1034                        "message_type" => stringify!(#enum_name),
1035                        "variant" => stringify!(#variant_name_snake),
1036                    ));
1037                };
1038                impl_methods.push(quote! {
1039                    async fn #variant_name_snake(
1040                        &self,
1041                        cx: &impl hyperactor::context::Actor,
1042                        #(#arg_names: #arg_types),*)
1043                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1044                        let message = #constructor;
1045                        #log_message;
1046                        #send_message;
1047                        Ok(())
1048                    }
1049
1050                    async fn #variant_name_snake_deprecated(
1051                        &self,
1052                        cx: &impl hyperactor::context::Actor,
1053                        #(#arg_names: #arg_types),*)
1054                        -> Result<(), hyperactor::internal_macro_support::anyhow::Error> {
1055                        let message = #constructor;
1056                        #log_message;
1057                        #send_message;
1058                        Ok(())
1059                    }
1060                });
1061            }
1062        }
1063    }
1064
1065    let trait_name = format_ident!("{}Client", name);
1066
1067    let (_, ty_generics, _) = input.generics.split_for_impl();
1068
1069    // Add a new generic parameter 'A'
1070    let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1071    let mut trait_generics = input.generics.clone();
1072    trait_generics.params.insert(
1073        0,
1074        syn::GenericParam::Type(syn::TypeParam {
1075            ident: actor_ident.clone(),
1076            attrs: vec![],
1077            colon_token: None,
1078            bounds: Punctuated::new(),
1079            eq_token: None,
1080            default: None,
1081        }),
1082    );
1083
1084    for param in trait_generics.type_params_mut() {
1085        if param.ident == actor_ident {
1086            continue;
1087        }
1088        param.bounds.push(syn::parse_quote!(serde::Serialize));
1089        param
1090            .bounds
1091            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1092        param.bounds.push(syn::parse_quote!(Send));
1093        param.bounds.push(syn::parse_quote!(Sync));
1094        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1095        param.bounds.push(syn::parse_quote!(typeuri::Named));
1096    }
1097
1098    let (impl_generics, _, _) = trait_generics.split_for_impl();
1099
1100    let expanded = if is_handle {
1101        quote! {
1102            #[hyperactor::internal_macro_support::async_trait::async_trait]
1103            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1104              where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1105                #(#impl_methods)*
1106            }
1107        }
1108    } else {
1109        quote! {
1110            #[hyperactor::internal_macro_support::async_trait::async_trait]
1111            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1112              where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1113                #(#impl_methods)*
1114            }
1115        }
1116    };
1117
1118    TokenStream::from(expanded)
1119}
1120
/// Compile error reported by [`handle`] when its attribute arguments are not
/// exactly one message type.
const HANDLE_ARGUMENT_ERROR: &str = indoc! {r#"
`handle` expects the message type that is being handled

= help: use `#[handle(MessageType)]`
"#};
1126
1127/// Install a [`Handler`] that routes messages of the provided type to this handler trait implementation.
1128#[proc_macro_attribute]
1129pub fn handle(attr: TokenStream, item: TokenStream) -> TokenStream {
1130    let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1131    if attr_args.len() != 1 {
1132        return TokenStream::from(
1133            syn::Error::new_spanned(attr_args, HANDLE_ARGUMENT_ERROR).to_compile_error(),
1134        );
1135    }
1136
1137    let message_type = attr_args.first().unwrap();
1138    let input = parse_macro_input!(item as ItemImpl);
1139
1140    let self_type = match *input.self_ty {
1141        syn::Type::Path(ref type_path) => {
1142            let segment = type_path.path.segments.last().unwrap();
1143            segment.clone() //ident.clone()
1144        }
1145        _ => {
1146            return TokenStream::from(
1147                syn::Error::new_spanned(input.self_ty, "`handle` argument must be a type")
1148                    .to_compile_error(),
1149            );
1150        }
1151    };
1152
1153    let trait_name = match input.trait_ {
1154        Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1155        None => {
1156            return TokenStream::from(
1157                syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1158                    .to_compile_error(),
1159            );
1160        }
1161    };
1162
1163    let expanded = quote! {
1164        #input
1165
1166        #[hyperactor::internal_macro_support::async_trait::async_trait]
1167        impl hyperactor::Handler<#message_type> for #self_type {
1168            async fn handle(
1169                &mut self,
1170                cx: &hyperactor::Context<Self>,
1171                message: #message_type,
1172            ) -> hyperactor::internal_macro_support::anyhow::Result<()> {
1173                <Self as #trait_name>::handle(self, cx, message).await
1174            }
1175        }
1176    };
1177
1178    TokenStream::from(expanded)
1179}
1180
1181/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1182/// We set a default level of INFO while always setting ERROR if the function returns Result::Err giving us
1183/// consistent and high quality structured logs. Because this wraps around tracing::instrument, all parameters
1184/// mentioned in https://fburl.com/9jlkb5q4 should be valid. For functions that don't return a [`Result`] type, use
1185/// [`instrument_infallible`]
1186///
1187/// ```
1188/// #[telemetry::instrument]
1189/// async fn yolo() -> anyhow::Result<i32> {
1190///     Ok(420)
1191/// }
1192/// ```
1193#[proc_macro_attribute]
1194pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1195    let args =
1196        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1197    let input = parse_macro_input!(input as ItemFn);
1198    let output = quote! {
1199        #[hyperactor::internal_macro_support::tracing::instrument(err, skip_all, #args)]
1200        #input
1201    };
1202
1203    TokenStream::from(output)
1204}
1205
1206/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1207/// Because this wraps around tracing::instrument, all parameters mentioned in
1208/// https://fburl.com/9jlkb5q4 should be valid.
1209///
1210/// ```
1211/// #[telemetry::instrument]
1212/// async fn yolo() -> i32 {
1213///     420
1214/// }
1215/// ```
1216#[proc_macro_attribute]
1217pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1218    let args =
1219        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1220    let input = parse_macro_input!(input as ItemFn);
1221
1222    let output = quote! {
1223        #[hyperactor::internal_macro_support::tracing::instrument(skip_all, #args)]
1224        #input
1225    };
1226
1227    TokenStream::from(output)
1228}
1229
/// One entry in a handler list (as parsed for `#[export(handlers = [...])]` and
/// `behavior!`): a message type, optionally followed by `{ cast = <bool> }`.
struct HandlerSpec {
    /// The handled message type.
    ty: Type,
    /// When `true`, `HandlerSpec::add_indexed` also emits
    /// `hyperactor::message::IndexedErasedUnbound<ty>` for this entry.
    cast: bool,
}
1234
1235impl Parse for HandlerSpec {
1236    fn parse(input: ParseStream) -> syn::Result<Self> {
1237        let ty: Type = input.parse()?;
1238
1239        if input.peek(syn::token::Brace) {
1240            let content;
1241            syn::braced!(content in input);
1242            let key: Ident = content.parse()?;
1243            content.parse::<Token![=]>()?;
1244            let expr: Expr = content.parse()?;
1245
1246            let cast = if key == "cast" {
1247                if let Expr::Lit(ExprLit {
1248                    lit: Lit::Bool(b), ..
1249                }) = expr
1250                {
1251                    b.value
1252                } else {
1253                    return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1254                }
1255            } else {
1256                return Err(syn::Error::new_spanned(
1257                    key,
1258                    "unsupported field (expected `cast`)",
1259                ));
1260            };
1261
1262            Ok(HandlerSpec { ty, cast })
1263        } else if input.is_empty() || input.peek(Token![,]) {
1264            Ok(HandlerSpec { ty, cast: false })
1265        } else {
1266            // Something unexpected follows the type
1267            let unexpected: proc_macro2::TokenTree = input.parse()?;
1268            Err(syn::Error::new_spanned(
1269                unexpected,
1270                "unexpected token after type — expected `{ ... }` or nothing",
1271            ))
1272        }
1273    }
1274}
1275
1276impl HandlerSpec {
1277    fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1278        let mut tys = Vec::new();
1279        for HandlerSpec { ty, cast } in handlers {
1280            if cast {
1281                let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1282                let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1283                tys.push(wrapped_ty);
1284            }
1285            tys.push(ty);
1286        }
1287        tys
1288    }
1289}
1290
/// Parsed arguments of the `#[export(...)]` attribute macro: optional
/// `spawn = <bool>` and `handlers = [ ... ]` entries.
struct ExportAttr {
    /// Whether to also emit `hyperactor::remote!` for the actor (`spawn = true`).
    spawn: bool,
    /// The entries listed in `handlers = [ ... ]` (empty when omitted).
    handlers: Vec<HandlerSpec>,
}
1296
1297impl Parse for ExportAttr {
1298    fn parse(input: ParseStream) -> syn::Result<Self> {
1299        let mut spawn = false;
1300        let mut handlers: Vec<HandlerSpec> = vec![];
1301
1302        while !input.is_empty() {
1303            let key: Ident = input.parse()?;
1304            input.parse::<Token![=]>()?;
1305
1306            if key == "spawn" {
1307                let expr: Expr = input.parse()?;
1308                if let Expr::Lit(ExprLit {
1309                    lit: Lit::Bool(b), ..
1310                }) = expr
1311                {
1312                    spawn = b.value;
1313                } else {
1314                    return Err(syn::Error::new_spanned(
1315                        expr,
1316                        "expected boolean for `spawn`",
1317                    ));
1318                }
1319            } else if key == "handlers" {
1320                let content;
1321                bracketed!(content in input);
1322                let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1323                handlers = raw_handlers.into_iter().collect();
1324            } else {
1325                return Err(syn::Error::new_spanned(
1326                    key,
1327                    "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1328                ));
1329            }
1330
1331            // optional trailing comma
1332            let _ = input.parse::<Token![,]>();
1333        }
1334
1335        Ok(ExportAttr { spawn, handlers })
1336    }
1337}
1338
1339/// Exports handlers for this actor. The set of exported handlers
1340/// determine the messages that may be sent to remote references of
1341/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
1342/// [`hyperactor::RemoteMessage`] may be exported.
1343///
1344/// Additionally, an exported actor may be remotely spawned,
1345/// indicated by `spawn = true`. Such actors must also ensure that
1346/// their parameter type implements [`hyperactor::RemoteMessage`].
1347///
1348/// # Example
1349///
1350/// In the following example, `MyActor` can be spawned remotely. It also has
1351/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`.
1352/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these
1353/// types.
1354///
1355/// ```ignore
1356/// #[export(
1357///     spawn = true,
1358///     handlers = [
1359///         MyMessage,
1360///         MyOtherMessage,
1361///     ],
1362/// )]
1363/// struct MyActor {}
1364/// ```
1365#[proc_macro_attribute]
1366pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1367    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1368    let data_type_name = &input.ident;
1369    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1370
1371    let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1372    let tys = HandlerSpec::add_indexed(handlers);
1373
1374    let mut handles = Vec::new();
1375    let mut bindings = Vec::new();
1376    let mut type_registrations = Vec::new();
1377
1378    for ty in &tys {
1379        handles.push(quote! {
1380            impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
1381            impl #impl_generics hyperactor::remote::Accepts<#ty> for #data_type_name #ty_generics #where_clause {}
1382        });
1383        bindings.push(quote! {
1384            ports.bind::<#ty>();
1385        });
1386        type_registrations.push(quote! {
1387            wirevalue::register_type!(#ty);
1388        });
1389    }
1390
1391    let mut expanded = quote! {
1392        #input
1393
1394        impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
1395
1396        #(#handles)*
1397
1398        #(#type_registrations)*
1399
1400        // Always export the `Signal` type.
1401        impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1402        impl #impl_generics hyperactor::remote::Accepts<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
1403
1404        // Always export the `IntrospectMessage` type.
1405        impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::introspect::IntrospectMessage> for #data_type_name #ty_generics #where_clause {}
1406        impl #impl_generics hyperactor::remote::Accepts<hyperactor::introspect::IntrospectMessage> for #data_type_name #ty_generics #where_clause {}
1407
1408        impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
1409            fn bind(ports: &hyperactor::proc::Ports<Self>) {
1410                #(#bindings)*
1411            }
1412        }
1413
1414        // TODO: just use Named derive directly here.
1415        impl #impl_generics typeuri::Named for #data_type_name #ty_generics #where_clause {
1416            fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
1417        }
1418    };
1419
1420    if spawn {
1421        expanded.extend(quote! {
1422            hyperactor::remote!(#data_type_name);
1423        });
1424    }
1425
1426    TokenStream::from(expanded)
1427}
1428
/// The full input to the `behavior!` macro:
/// `Name<Generics>, Handler, Handler { cast = true }, ...`.
struct BehaviorInput {
    /// Name of the generated behavior struct.
    behavior: Ident,
    /// The behavior's generic parameters (the `<...>` after the name, if any).
    generics: syn::Generics,
    /// The handled message types, in the order given.
    handlers: Vec<HandlerSpec>,
}
1435
1436impl syn::parse::Parse for BehaviorInput {
1437    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1438        let behavior: Ident = input.parse()?;
1439        let generics: syn::Generics = input.parse()?;
1440        let _: Token![,] = input.parse()?;
1441        let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1442        let handlers = raw_handlers.into_iter().collect();
1443        Ok(BehaviorInput {
1444            behavior,
1445            generics,
1446            handlers,
1447        })
1448    }
1449}
1450
1451/// Create a [`Referable`] definition, handling a specific set of message types.
1452/// Behaviors are used to create an [`ActorRef`] without having to depend on the
1453/// actor's implementation. If the message type need to be cast, add `castable`
1454/// flag to those types. e.g. the following example creates a behavior with 5
1455/// message types, and 4 of which need to be cast.
1456///
1457/// ```
1458/// hyperactor::behavior!(
1459///     TestActorBehavior,
1460///     TestMessage { castable = true },
1461///     () {castable = true },
1462///     MyGeneric<()> {castable = true },
1463///     u64,
1464/// );
1465/// ```
1466///
1467/// This macro also supports generic behaviors:
1468/// ```
1469/// hyperactor::behavior!(
1470///     TestBehavior<T>,
1471///     Message<T> { castable = true },
1472///     u64,
1473/// );
1474/// ```
1475#[proc_macro]
1476pub fn behavior(input: TokenStream) -> TokenStream {
1477    let BehaviorInput {
1478        behavior,
1479        generics,
1480        handlers,
1481    } = parse_macro_input!(input as BehaviorInput);
1482    let tys = HandlerSpec::add_indexed(handlers);
1483
1484    // Add bounds to generics for Named, Serialize, Deserialize
1485    let mut bounded_generics = generics.clone();
1486    for param in bounded_generics.type_params_mut() {
1487        param.bounds.push(syn::parse_quote!(typeuri::Named));
1488        param.bounds.push(syn::parse_quote!(serde::Serialize));
1489        param.bounds.push(syn::parse_quote!(std::marker::Send));
1490        param.bounds.push(syn::parse_quote!(std::marker::Sync));
1491        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1492        // Note: lifetime parameters are not *actually* hygienic.
1493        // https://github.com/rust-lang/rust/issues/54727
1494        let lifetime =
1495            syn::Lifetime::new("'hyperactor_behavior_de", proc_macro2::Span::mixed_site());
1496        param
1497            .bounds
1498            .push(syn::parse_quote!(for<#lifetime> serde::Deserialize<#lifetime>));
1499    }
1500
1501    // Split the generics for use in different contexts
1502    let (impl_generics, ty_generics, where_clause) = bounded_generics.split_for_impl();
1503
1504    // Create a combined generics for the Binds impl that includes both A and the behavior's generics
1505    let mut binds_generics = bounded_generics.clone();
1506    binds_generics.params.insert(
1507        0,
1508        syn::GenericParam::Type(syn::TypeParam {
1509            attrs: vec![],
1510            ident: Ident::new("A", proc_macro2::Span::call_site()),
1511            colon_token: None,
1512            bounds: Punctuated::new(),
1513            eq_token: None,
1514            default: None,
1515        }),
1516    );
1517    let (binds_impl_generics, _, _) = binds_generics.split_for_impl();
1518
1519    // Determine typename and typehash implementation based on whether we have generics
1520    let type_params: Vec<_> = bounded_generics.type_params().collect();
1521    let has_generics = !type_params.is_empty();
1522
1523    let (typename_impl, typehash_impl) = if has_generics {
1524        // Create format string with placeholders for each generic parameter
1525        let placeholders = vec!["{}"; type_params.len()].join(", ");
1526        let placeholders_format_string = format!("<{}>", placeholders);
1527        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#behavior), #placeholders_format_string) };
1528
1529        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1530        (
1531            quote! {
1532                typeuri::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1533            },
1534            quote! {
1535                typeuri::cityhasher::hash(Self::typename())
1536            },
1537        )
1538    } else {
1539        (
1540            quote! {
1541                concat!(std::module_path!(), "::", stringify!(#behavior))
1542            },
1543            quote! {
1544                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1545                    typeuri::cityhasher::hash(<#behavior as typeuri::Named>::typename())
1546                });
1547                *TYPEHASH
1548            },
1549        )
1550    };
1551
1552    let type_param_idents = generics.type_params().map(|p| &p.ident).collect::<Vec<_>>();
1553
1554    let expanded = quote! {
1555        #[doc = "The generated behavior struct."]
1556        #[derive(Debug, serde::Serialize, serde::Deserialize)]
1557        pub struct #behavior #impl_generics #where_clause {
1558            _phantom: std::marker::PhantomData<(#(#type_param_idents),*)>
1559        }
1560
1561        impl #impl_generics typeuri::Named for #behavior #ty_generics #where_clause {
1562            fn typename() -> &'static str {
1563                #typename_impl
1564            }
1565
1566            fn typehash() -> u64 {
1567                #typehash_impl
1568            }
1569        }
1570
1571        impl #impl_generics hyperactor::actor::Referable for #behavior #ty_generics #where_clause {}
1572
1573        impl #binds_impl_generics hyperactor::actor::Binds<A> for #behavior #ty_generics
1574        where
1575            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)*,
1576            #where_clause
1577        {
1578            fn bind(ports: &hyperactor::proc::Ports<A>) {
1579                #(
1580                    ports.bind::<#tys>();
1581                )*
1582            }
1583        }
1584
1585        #(
1586            impl #impl_generics hyperactor::actor::RemoteHandles<#tys> for #behavior #ty_generics #where_clause {}
1587            impl #impl_generics hyperactor::remote::Accepts<#tys> for #behavior #ty_generics #where_clause {}
1588        )*
1589    };
1590
1591    TokenStream::from(expanded)
1592}
1593
1594fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1595    let mut is_included = false;
1596    for attr in &field.attrs {
1597        if attr.path().is_ident("binding") {
1598            // parse #[binding(include)] and look for exactly "include"
1599            attr.parse_nested_meta(|meta| {
1600                if meta.path.is_ident("include") {
1601                    is_included = true;
1602                    Ok(())
1603                } else {
1604                    let path = meta.path.to_token_stream().to_string().replace(' ', "");
1605                    Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1606                }
1607            })?
1608        }
1609    }
1610    Ok(is_included)
1611}
1612
/// How a field is addressed within its struct or enum variant.
///
/// ```text
/// struct NamedStruct { foo: u32 } => FieldAccessor::Named(Ident::new("foo", Span::call_site()))
/// struct UnnamedStruct(u32)       => FieldAccessor::Unnamed(Index::from(0))
/// ```
enum FieldAccessor {
    /// A named field, addressed by its identifier (e.g. `self.foo`).
    Named(Ident),
    /// A positional (tuple) field, addressed by its index (e.g. `self.0`).
    Unnamed(Index),
}
1621
/// Result of parsing a field in a struct or an enum variant.
struct ParsedField {
    /// How the field is addressed (by name or by position).
    accessor: FieldAccessor,
    /// The field's declared type.
    ty: Type,
    /// Whether the field carries `#[binding(include)]` and therefore
    /// participates in the generated `Bind`/`Unbind` code.
    included: bool,
}
1628
1629impl From<&ParsedField> for (Ident, Type) {
1630    fn from(field: &ParsedField) -> Self {
1631        let field_ident = match &field.accessor {
1632            FieldAccessor::Named(ident) => ident.clone(),
1633            FieldAccessor::Unnamed(i) => {
1634                Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1635            }
1636        };
1637        (field_ident, field.ty.clone())
1638    }
1639}
1640
1641fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1642    match fields {
1643        Fields::Named(named) => named
1644            .named
1645            .iter()
1646            .map(|f| {
1647                let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1648                Ok(ParsedField {
1649                    accessor,
1650                    ty: f.ty.clone(),
1651                    included: include_in_bind_unbind(f)?,
1652                })
1653            })
1654            .collect(),
1655        Fields::Unnamed(unnamed) => unnamed
1656            .unnamed
1657            .iter()
1658            .enumerate()
1659            .map(|(i, f)| {
1660                let accessor = FieldAccessor::Unnamed(Index::from(i));
1661                Ok(ParsedField {
1662                    accessor,
1663                    ty: f.ty.clone(),
1664                    included: include_in_bind_unbind(f)?,
1665                })
1666            })
1667            .collect(),
1668        Fields::Unit => Ok(Vec::new()),
1669    }
1670}
1671
1672fn gen_struct_items<F>(
1673    fields: &Fields,
1674    make_item: F,
1675    is_mutable: bool,
1676) -> syn::Result<Vec<proc_macro2::TokenStream>>
1677where
1678    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1679{
1680    let borrow = if is_mutable {
1681        quote! { &mut }
1682    } else {
1683        quote! { & }
1684    };
1685    let items: Vec<_> = collect_all_fields(fields)?
1686        .into_iter()
1687        .filter(|f| f.included)
1688        .map(
1689            |ParsedField {
1690                 accessor,
1691                 ty,
1692                 included,
1693             }| {
1694                assert!(included);
1695                let field_accessor = match accessor {
1696                    FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1697                    FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1698                };
1699                make_item(field_accessor, ty)
1700            },
1701        )
1702        .collect();
1703    Ok(items)
1704}
1705
1706/// Generate the field accessor for a enum variant in pattern matching. e.g.
1707/// the <GENERATED> parts in the following example:
1708///
1709///   match my_enum {
1710///     // e.g. MyEnum::Tuple(_, f1, f2, _)
1711///     MyEnum::Tuple(<GENERATED>) => { ... }
1712///     // e.g. MyEnum::Struct { field0: _, field1 }
1713///     MyEnum::Struct(<GENERATED>) => { ... }
1714///   }
1715fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1716    all_fields
1717        .iter()
1718        .map(
1719            |ParsedField {
1720                 accessor,
1721                 ty: _,
1722                 included,
1723             }| {
1724                match accessor {
1725                    FieldAccessor::Named(ident) => {
1726                        if *included {
1727                            quote! { #ident }
1728                        } else {
1729                            quote! { #ident: _ }
1730                        }
1731                    }
1732                    FieldAccessor::Unnamed(i) => {
1733                        if *included {
1734                            let ident = Ident::new(
1735                                &format!("f{}", i.index),
1736                                proc_macro2::Span::call_site(),
1737                            );
1738                            quote! { #ident }
1739                        } else {
1740                            quote! { _ }
1741                        }
1742                    }
1743                }
1744            },
1745        )
1746        .collect()
1747}
1748
1749/// Generate all the parts for enum variants. e.g. the <GENERATED> part in the
1750/// following example:
1751///
1752///   match my_enum {
1753///      <GENERATED>
1754///   }
1755fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1756where
1757    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1758{
1759    data.variants
1760        .iter()
1761        .map(|variant| {
1762            let name = &variant.ident;
1763            let all_fields = collect_all_fields(&variant.fields)?;
1764            let field_accessors = gen_enum_field_accessors(&all_fields);
1765            let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1766            let items = included_fields
1767                .iter()
1768                .map(|f| {
1769                    let (accessor, ty) = <(Ident, Type)>::from(*f);
1770                    make_item(quote! { #accessor }, ty)
1771                })
1772                .collect::<Vec<_>>();
1773
1774            Ok(match &variant.fields {
1775                Fields::Named(_) => {
1776                    quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1777                }
1778                Fields::Unnamed(_) => {
1779                    quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1780                }
1781                Fields::Unit => quote! { Self::#name => { #(#items)* } },
1782            })
1783        })
1784        .collect()
1785}
1786
1787/// Derive a custom implementation of [`hyperactor::message::Bind`] trait for
1788/// a struct or enum. This macro is normally used in tandem with [`fn derive_unbind`]
1789/// to make the applied struct or enum castable.
1790///
1791/// Specifically, the derived implementation iterates through fields annotated
1792/// with `#[binding(include)]` based on their order of declaration in the struct
1793/// or enum. These fields' types must implement `Bind` trait as well. During the
1794/// iteration, parameters from `bindings` are bound to these fields.
1795///
1796/// # Example
1797///
1798/// This macro supports named and unamed structs and enums. Below are examples
1799/// of the supported types:
1800///
1801/// ```
1802/// #[derive(Bind, Unbind)]
1803/// struct MyNamedStruct {
1804///     field0: u64,
1805///     field1: MyReply,
1806///     #[binding(include)]
1807/// nnnn     field2: PortRef<MyReply>,
1808///     field3: bool,
1809///     #[binding(include)]
1810///     field4: hyperactor::PortRef<u64>,
1811/// }
1812///
1813/// #[derive(Bind, Unbind)]
1814/// struct MyUnamedStruct(
1815///     u64,
1816///     MyReply,
1817///     #[binding(include)] hyperactor::PortRef<MyReply>,
1818///     bool,
1819///     #[binding(include)] PortRef<u64>,
1820/// );
1821///
1822/// #[derive(Bind, Unbind)]
1823/// enum MyEnum {
1824///     Unit,
1825///     NoopTuple(u64, bool),
1826///     NoopStruct {
1827///         field0: u64,
1828///         field1: bool,
1829///     },
1830///     Tuple(
1831///         u64,
1832///         MyReply,
1833///         #[binding(include)] PortRef<MyReply>,
1834///         bool,
1835///         #[binding(include)] hyperactor::PortRef<u64>,
1836///     ),
1837///     Struct {
1838///         field0: u64,
1839///         field1: MyReply,
1840///         #[binding(include)]
1841///         field2: PortRef<MyReply>,
1842///         field3: bool,
1843///         #[binding(include)]
1844///         field4: hyperactor::PortRef<u64>,
1845///     },
1846/// }
1847/// ```
1848///
1849/// The following shows what derived `Bind`` and `Unbind`` implementations for
1850/// `MyNamedStruct` will look like. The implementations of other types are
1851/// similar, and thus are not shown here.
1852/// ```ignore
1853/// impl Bind for MyNamedStruct {
1854/// fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
1855///     Bind::bind(&mut self.field2, bindings)?;
1856///     Bind::bind(&mut self.field4, bindings)?;
1857///     Ok(())
1858/// }
1859///
1860/// impl Unbind for MyNamedStruct {
1861///     fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
1862///         Unbind::unbind(&self.field2, bindings)?;
1863///         Unbind::unbind(&self.field4, bindings)?;
1864///         Ok(())
1865///     }
1866/// }
1867/// ```
1868#[proc_macro_derive(Bind, attributes(binding))]
1869pub fn derive_bind(input: TokenStream) -> TokenStream {
1870    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1871        quote! {
1872            hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1873        }
1874    }
1875
1876    let input = parse_macro_input!(input as DeriveInput);
1877    let name = &input.ident;
1878    let inner = match &input.data {
1879        Data::Struct(DataStruct { fields, .. }) => {
1880            match gen_struct_items(fields, make_item, true) {
1881                Ok(collects) => {
1882                    quote! { #(#collects)* }
1883                }
1884                Err(e) => {
1885                    return TokenStream::from(e.to_compile_error());
1886                }
1887            }
1888        }
1889        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1890            Ok(arms) => {
1891                quote! { match self { #(#arms),* } }
1892            }
1893            Err(e) => {
1894                return TokenStream::from(e.to_compile_error());
1895            }
1896        },
1897        _ => panic!("Bind can only be derived for structs and enums"),
1898    };
1899    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1900    let expand = quote! {
1901        #[automatically_derived]
1902        impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1903            fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1904                #inner
1905                Ok(())
1906            }
1907        }
1908    };
1909    TokenStream::from(expand)
1910}
1911
1912/// Derive a custom implementation of [`hyperactor::message::Unbind`] trait for
1913/// a struct or enum. This macro is normally used in tandem with [`fn derive_bind`]
1914/// to make the applied struct or enum castable.
1915///
1916/// Specifically, the derived implementation iterates through fields annoated
1917/// with `#[binding(include)]` based on their order of declaration in the struct
1918/// or enum. These fields' types must implement `Unbind` trait as well. During
1919/// the iteration, parameters from these fields are extracted and stored in
1920/// `bindings`.
1921///
1922/// # Example
1923///
1924/// See [`fn derive_bind`]'s documentation for examples.
1925#[proc_macro_derive(Unbind, attributes(binding))]
1926pub fn derive_unbind(input: TokenStream) -> TokenStream {
1927    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1928        quote! {
1929            hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1930        }
1931    }
1932
1933    let input = parse_macro_input!(input as DeriveInput);
1934    let name = &input.ident;
1935    let inner = match &input.data {
1936        Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1937        {
1938            Ok(collects) => {
1939                quote! { #(#collects)* }
1940            }
1941            Err(e) => {
1942                return TokenStream::from(e.to_compile_error());
1943            }
1944        },
1945        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1946            Ok(arms) => {
1947                quote! { match self { #(#arms),* } }
1948            }
1949            Err(e) => {
1950                return TokenStream::from(e.to_compile_error());
1951            }
1952        },
1953        _ => panic!("Unbind can only be derived for structs and enums"),
1954    };
1955    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1956    let expand = quote! {
1957        #[automatically_derived]
1958        impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1959            fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1960                #inner
1961                Ok(())
1962            }
1963        }
1964    };
1965    TokenStream::from(expand)
1966}
1967
1968// Helper function for common parsing and validation
1969fn parse_observe_function(
1970    attr: TokenStream,
1971    item: TokenStream,
1972) -> syn::Result<(ItemFn, String, String)> {
1973    let input = syn::parse::<ItemFn>(item)?;
1974
1975    if input.sig.asyncness.is_none() {
1976        return Err(syn::Error::new(
1977            input.sig.span(),
1978            "observe macros can only be applied to async functions",
1979        ));
1980    }
1981
1982    let fn_name_str = input.sig.ident.to_string();
1983    let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
1984
1985    Ok((input, fn_name_str, module_name_str))
1986}
1987
1988// Helper function for creating telemetry identifiers and setup code
1989fn create_telemetry_setup(
1990    module_name_str: &str,
1991    fn_name_str: &str,
1992    include_error: bool,
1993) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
1994    let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
1995    let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
1996
1997    let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
1998
1999    let error_ident = if include_error {
2000        Some(Ident::new(
2001            "error",
2002            Span::from(proc_macro::Span::def_site()),
2003        ))
2004    } else {
2005        None
2006    };
2007
2008    let error_declaration = if let Some(ref error_ident) = error_ident {
2009        quote! {
2010            hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2011        }
2012    } else {
2013        quote! {}
2014    };
2015
2016    let setup_code = quote! {
2017        use hyperactor_telemetry;
2018        hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2019        hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2020        #error_declaration
2021    };
2022
2023    (latency_ident, success_ident, error_ident, setup_code)
2024}
2025
2026/// A procedural macro that automatically injects telemetry code into async functions
2027/// that return a Result type.
2028///
2029/// This macro wraps async functions and adds instrumentation to measure:
2030/// 1. Latency - how long the function takes to execute
2031/// 2. Error counter - function error count
2032/// 3. Success counter - function completion count
2033///
2034/// # Example
2035///
2036/// ```rust
2037/// use hyperactor_actor::observe_result;
2038///
2039/// #[observe_result("my_module")]
2040/// async fn process_request(user_id: &str) -> Result<String, Error> {
2041///     // Function implementation
2042///     // Telemetry will be automatically collected
2043/// }
2044/// ```
2045#[proc_macro_attribute]
2046pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2047    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2048        Ok(parsed) => parsed,
2049        Err(err) => return err.to_compile_error().into(),
2050    };
2051
2052    let fn_name = &input.sig.ident;
2053    let vis = &input.vis;
2054    let args = &input.sig.inputs;
2055    let return_type = &input.sig.output;
2056    let body = &input.block;
2057    let attrs = &input.attrs;
2058    let generics = &input.sig.generics;
2059
2060    let (latency_ident, success_ident, error_ident, telemetry_setup) =
2061        create_telemetry_setup(&module_name_str, &fn_name_str, true);
2062    let error_ident = error_ident.unwrap();
2063
2064    let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2065
2066    // Generate the instrumented function
2067    let expanded = quote! {
2068        #(#attrs)*
2069        #vis async fn #fn_name #generics(#args) #return_type {
2070            #telemetry_setup
2071
2072            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2073            let _timer = #latency_ident.start(kv_pairs);
2074
2075            let #result_ident = async #body.await;
2076
2077            match &#result_ident {
2078                Ok(_) => {
2079                    #success_ident.add(
2080                        1,
2081                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2082                    );
2083                }
2084                Err(_) => {
2085                    #error_ident.add(
2086                        1,
2087                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2088                    );
2089                }
2090            }
2091
2092            #result_ident
2093        }
2094    };
2095
2096    expanded.into()
2097}
2098
2099/// A procedural macro that automatically injects telemetry code into async functions
2100/// that do not return a Result type.
2101///
2102/// This macro wraps async functions and adds instrumentation to measure:
2103/// 1. Latency - how long the function takes to execute
2104/// 2. Success counter - function completion count
2105///
2106/// # Example
2107///
2108/// ```rust
2109/// use hyperactor_actor::observe_async;
2110///
2111/// #[observe_async("my_module")]
2112/// async fn process_data(data: &str) -> String {
2113///     // Function implementation
2114///     // Telemetry will be automatically collected
2115/// }
2116/// ```
2117#[proc_macro_attribute]
2118pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2119    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2120        Ok(parsed) => parsed,
2121        Err(err) => return err.to_compile_error().into(),
2122    };
2123
2124    let fn_name = &input.sig.ident;
2125    let vis = &input.vis;
2126    let args = &input.sig.inputs;
2127    let return_type = &input.sig.output;
2128    let body = &input.block;
2129    let attrs = &input.attrs;
2130    let generics = &input.sig.generics;
2131
2132    let (latency_ident, success_ident, _, telemetry_setup) =
2133        create_telemetry_setup(&module_name_str, &fn_name_str, false);
2134
2135    let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2136
2137    // Generate the instrumented function
2138    let expanded = quote! {
2139        #(#attrs)*
2140        #vis async fn #fn_name #generics(#args) #return_type {
2141            #telemetry_setup
2142
2143            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2144            let _timer = #latency_ident.start(kv_pairs);
2145
2146            let #return_ident = async #body.await;
2147
2148            #success_ident.add(
2149                1,
2150                hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2151            );
2152            #return_ident
2153        }
2154    };
2155
2156    expanded.into()
2157}
2158
/// Checks that `s` is a valid label.
///
/// A valid label has these properties:
/// - it is 1-63 bytes long;
/// - it uses only lowercase ASCII letters, ASCII digits and `-`;
/// - it starts with a lowercase letter;
/// - it ends with a lowercase letter or digit.
///
/// On failure, returns a message describing the first rule violated.
fn validate_label(s: &str) -> Result<(), String> {
    let bytes = s.as_bytes();
    let (Some(&head), Some(&tail)) = (bytes.first(), bytes.last()) else {
        return Err("label must not be empty".to_string());
    };
    if bytes.len() > 63 {
        return Err("label exceeds 63 characters".to_string());
    }
    if !head.is_ascii_lowercase() {
        return Err("label must start with a lowercase letter".to_string());
    }
    if !(tail.is_ascii_lowercase() || tail.is_ascii_digit()) {
        return Err("label must end with a lowercase letter or digit".to_string());
    }
    match s
        .chars()
        .find(|&c| !(c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-'))
    {
        Some(ch) => Err(format!("label contains invalid character '{ch}'")),
        None => Ok(()),
    }
}
2181
/// Parses `s` as an instance uid: 1-16 ASCII hex digits (either case),
/// interpreted as a base-16 `u64`.
///
/// Returns a descriptive message on an out-of-range length or a non-hex
/// character.
fn validate_hex_uid(s: &str) -> Result<u64, String> {
    if !(1..=16).contains(&s.len()) {
        return Err(format!("hex uid must be 1-16 hex characters, got '{s}'"));
    }
    if let Some(ch) = s.chars().find(|c| !c.is_ascii_hexdigit()) {
        return Err(format!("hex uid contains invalid character '{ch}'"));
    }
    // At most 16 validated hex digits always fit in a u64, but keep the error path.
    u64::from_str_radix(s, 16).map_err(|e| format!("invalid hex uid '{s}': {e}"))
}
2193
2194/// Compile-time validated [`hyperactor::id::Uid`] construction.
2195///
2196/// Accepts two forms:
2197/// - `uid!(_my-singleton)` — a singleton Uid
2198/// - `uid!(d5d54d7201103869)` — an instance Uid
2199#[proc_macro]
2200pub fn uid(input: TokenStream) -> TokenStream {
2201    let input2: proc_macro2::TokenStream = input.into();
2202    let combined: String = input2.into_iter().map(|tt| tt.to_string()).collect();
2203
2204    if combined.is_empty() {
2205        return TokenStream::from(quote! { compile_error!("uid! macro requires an argument") });
2206    }
2207
2208    // Singleton: starts with '_'
2209    if let Some(rest) = combined.strip_prefix('_') {
2210        return match validate_label(rest) {
2211            Ok(()) => TokenStream::from(quote! {
2212                hyperactor::id::Uid::Singleton(
2213                    hyperactor::id::Label::new(#rest).unwrap()
2214                )
2215            }),
2216            Err(e) => {
2217                let msg = format!("invalid singleton uid: {e}");
2218                TokenStream::from(quote! { compile_error!(#msg) })
2219            }
2220        };
2221    }
2222
2223    // Instance: bare hex
2224    match validate_hex_uid(&combined) {
2225        Ok(uid_val) => TokenStream::from(quote! {
2226            hyperactor::id::Uid::Instance(#uid_val)
2227        }),
2228        Err(e) => {
2229            let msg = format!("invalid uid: {e}");
2230            TokenStream::from(quote! { compile_error!(#msg) })
2231        }
2232    }
2233}