hyperactor_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines macros used by the [`hyperactor`] crate.
10
11#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Meta;
39use syn::MetaNameValue;
40use syn::Token;
41use syn::Type;
42use syn::bracketed;
43use syn::parse::Parse;
44use syn::parse::ParseStream;
45use syn::parse_macro_input;
46use syn::punctuated::Punctuated;
47use syn::spanned::Spanned;
48
49const REPLY_VARIANT_ERROR: &str = indoc! {r#"
50`call` message expects a typed port ref (`OncePortRef` or `PortRef`) or handle (`OncePortHandle` or `PortHandle`) argument in the last position
51
52= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
53= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
54"#};
55
56const REPLY_USAGE_ERROR: &str = indoc! {r#"
57`call` message expects at most one `reply` argument
58
59= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
60= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
61"#};
62
/// Per-field marker recorded while parsing a message definition.
enum FieldFlag {
    /// An ordinary message argument.
    None,
    /// The field carries a bare `#[reply]` attribute and holds the reply port.
    Reply,
}
67
68/// Represents a variant of an enum.
69#[allow(dead_code)]
70enum Variant {
71    /// A named variant (i.e., `MyVariant { .. }`).
72    Named {
73        enum_name: Ident,
74        name: Ident,
75        field_names: Vec<Ident>,
76        field_types: Vec<Type>,
77        field_flags: Vec<FieldFlag>,
78        is_struct: bool,
79        generics: syn::Generics,
80    },
81    /// An anonymous variant (i.e., `MyVariant(..)`).
82    Anon {
83        enum_name: Ident,
84        name: Ident,
85        field_types: Vec<Type>,
86        field_flags: Vec<FieldFlag>,
87        is_struct: bool,
88        generics: syn::Generics,
89    },
90}
91
92impl Variant {
93    /// The number of fields in the variant.
94    fn len(&self) -> usize {
95        self.field_types().len()
96    }
97
98    /// Returns whether this variant was defined as a struct.
99    fn is_struct(&self) -> bool {
100        match self {
101            Variant::Named { is_struct, .. } => *is_struct,
102            Variant::Anon { is_struct, .. } => *is_struct,
103        }
104    }
105
106    /// The name of the enum containing the variant.
107    fn enum_name(&self) -> &Ident {
108        match self {
109            Variant::Named { enum_name, .. } => enum_name,
110            Variant::Anon { enum_name, .. } => enum_name,
111        }
112    }
113
114    /// The name of the variant itself.
115    fn name(&self) -> &Ident {
116        match self {
117            Variant::Named { name, .. } => name,
118            Variant::Anon { name, .. } => name,
119        }
120    }
121
122    /// The generics of the variant itself.
123    #[allow(dead_code)]
124    fn generics(&self) -> &syn::Generics {
125        match self {
126            Variant::Named { generics, .. } => generics,
127            Variant::Anon { generics, .. } => generics,
128        }
129    }
130
131    /// The snake_name of the variant itself.
132    fn snake_name(&self) -> Ident {
133        Ident::new(
134            &self.name().to_string().to_case(Case::Snake),
135            self.name().span(),
136        )
137    }
138
139    /// The variant's qualified name.
140    fn qualified_name(&self) -> proc_macro2::TokenStream {
141        let enum_name = self.enum_name();
142        let name = self.name();
143
144        if self.is_struct() {
145            quote! { #enum_name }
146        } else {
147            quote! { #enum_name::#name }
148        }
149    }
150
151    /// Names of the fields in the variant. Anonymous variants are named
152    /// according to their position in the argument list.
153    fn field_names(&self) -> Vec<Ident> {
154        match self {
155            Variant::Named { field_names, .. } => field_names.clone(),
156            Variant::Anon { field_types, .. } => (0usize..field_types.len())
157                .map(|idx| format_ident!("arg{}", idx))
158                .collect(),
159        }
160    }
161
162    /// The types of the fields int the variant.
163    fn field_types(&self) -> &Vec<Type> {
164        match self {
165            Variant::Named { field_types, .. } => field_types,
166            Variant::Anon { field_types, .. } => field_types,
167        }
168    }
169
170    /// Return the field flags for this variant.
171    fn field_flags(&self) -> &Vec<FieldFlag> {
172        match self {
173            Variant::Named { field_flags, .. } => field_flags,
174            Variant::Anon { field_flags, .. } => field_flags,
175        }
176    }
177
178    /// The constructor for the variant, using the field names directly.
179    fn constructor(&self) -> proc_macro2::TokenStream {
180        let qualified_name = self.qualified_name();
181        let field_names = self.field_names();
182        match self {
183            Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
184            Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
185        }
186    }
187}
188
/// Describes the kind of port a `call` message replies through. It is derived from the name
/// of the `#[reply]` field's type.
struct ReplyPort {
    /// `true` for `PortHandle` / `OncePortHandle`, `false` for `PortRef` / `OncePortRef`.
    is_handle: bool,
    /// `true` for the single-use `OncePortRef` / `OncePortHandle` kinds.
    is_once: bool,
}
193
194impl ReplyPort {
195    fn from_last_segment(last_segment: &proc_macro2::Ident) -> ReplyPort {
196        ReplyPort {
197            is_handle: last_segment == "PortHandle" || last_segment == "OncePortHandle",
198            is_once: last_segment == "OncePortHandle" || last_segment == "OncePortRef",
199        }
200    }
201
202    fn open_op(&self) -> proc_macro2::TokenStream {
203        if self.is_once {
204            quote! { hyperactor::mailbox::open_once_port }
205        } else {
206            quote! { hyperactor::mailbox::open_port }
207        }
208    }
209
210    fn rx_modifier(&self) -> proc_macro2::TokenStream {
211        if self.is_once {
212            quote! {}
213        } else {
214            quote! { mut }
215        }
216    }
217}
218
219/// Represents a message that can be sent to a handler, each message is associated with
220/// a variant.
221#[allow(clippy::large_enum_variant)]
222enum Message {
223    /// A call message is a request-response message, the last argument is
224    /// a [`hyperactor::OncePortRef`] or [`hyperactor::OncePortHandle`].
225    Call {
226        variant: Variant,
227        /// Tells whether the reply argument is a handle.
228        reply_port: ReplyPort,
229        /// The underlying return type (i.e., the type of the reply port).
230        return_type: Type,
231        /// the log level for generated instrumentation for handlers of this message.
232        log_level: Option<Ident>,
233    },
234    OneWay {
235        variant: Variant,
236        /// the log level for generated instrumentation for handlers of this message.
237        log_level: Option<Ident>,
238    },
239}
240
241impl Message {
242    fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
243        match &variant
244            .field_flags()
245            .iter()
246            .zip(variant.field_types())
247            .filter_map(|(flag, ty)| match flag {
248                FieldFlag::Reply => Some(ty),
249                FieldFlag::None => None,
250            })
251            .collect::<Vec<&Type>>()[..]
252        {
253            [] => Ok(Self::OneWay { variant, log_level }),
254            [reply_port_ty] => {
255                let syn::Type::Path(type_path) = reply_port_ty else {
256                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
257                };
258                let Some(last_segment) = type_path.path.segments.last() else {
259                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
260                };
261                if last_segment.ident != "OncePortRef"
262                    && last_segment.ident != "OncePortHandle"
263                    && last_segment.ident != "PortRef"
264                    && last_segment.ident != "PortHandle"
265                {
266                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
267                }
268                let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
269                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
270                };
271                let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
272                    return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
273                };
274                let reply_port = ReplyPort::from_last_segment(&last_segment.ident);
275                let return_type = return_ty.clone();
276                Ok(Self::Call {
277                    variant,
278                    reply_port,
279                    return_type,
280                    log_level,
281                })
282            }
283            _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
284        }
285    }
286
287    /// The arguments of this message.
288    fn args(&self) -> Vec<(Ident, Type)> {
289        match self {
290            Message::Call { variant, .. } => variant
291                .field_names()
292                .into_iter()
293                .zip(variant.field_types().clone())
294                .take(variant.len() - 1)
295                .collect(),
296            Message::OneWay { variant, .. } => variant
297                .field_names()
298                .into_iter()
299                .zip(variant.field_types().clone())
300                .collect(),
301        }
302    }
303
304    fn variant(&self) -> &Variant {
305        match self {
306            Message::Call { variant, .. } => variant,
307            Message::OneWay { variant, .. } => variant,
308        }
309    }
310
311    fn reply_port_position(&self) -> Option<usize> {
312        self.variant()
313            .field_flags()
314            .iter()
315            .position(|flag| matches!(flag, FieldFlag::Reply))
316    }
317
318    /// The reply port argument of this message.
319    fn reply_port_arg(&self) -> Option<(Ident, Type)> {
320        match self {
321            Message::Call { variant, .. } => {
322                let pos = self.reply_port_position()?;
323                Some((
324                    variant.field_names()[pos].clone(),
325                    variant.field_types()[pos].clone(),
326                ))
327            }
328            Message::OneWay { .. } => None,
329        }
330    }
331}
332
333fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
334    let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
335        Some(attr) => {
336            let Ok(meta) = attr.meta.require_list() else {
337                return Err(syn::Error::new(
338                    Span::call_site(),
339                    indoc! {"
340                            `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
341
342                            = help use `#[log_level(info)]` or `#[log_level(error)]`
343                        "},
344                ));
345            };
346            let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
347            if parsed.len() != 1 {
348                return Err(syn::Error::new(
349                    Span::call_site(),
350                    indoc! {"
351                            `log_level` attribute must specify exactly one level
352
353                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
354                        "},
355                ));
356            };
357            Some(parsed.first().unwrap().to_string())
358        }
359        None => None,
360    };
361
362    if level.is_none() {
363        return Ok(None);
364    }
365    let level = level.unwrap();
366
367    match level.as_str() {
368        "error" | "warn" | "info" | "debug" | "trace" => {}
369        _ => {
370            return Err(syn::Error::new(
371                Span::call_site(),
372                indoc! {"
373                            `log_level` attribute must be one of 'error, warn, info, debug, trace'
374
375                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
376                        "},
377            ));
378        }
379    }
380
381    Ok(Some(Ident::new(
382        level.to_ascii_uppercase().as_str(),
383        Span::call_site(),
384    )))
385}
386
387fn parse_field_flag(field: &Field) -> FieldFlag {
388    for attr in field.attrs.iter() {
389        match &attr.meta {
390            syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
391            _ => {}
392        }
393    }
394    FieldFlag::None
395}
396
397/// Parse a message enum or struct into its constituent messages.
398fn parse_messages(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
399    match &input.data {
400        Data::Enum(data_enum) => {
401            let mut messages = Vec::new();
402
403            for variant in &data_enum.variants {
404                let name = variant.ident.clone();
405                let attrs = &variant.attrs;
406
407                let message_variant = match &variant.fields {
408                    syn::Fields::Unnamed(fields_) => Variant::Anon {
409                        enum_name: input.ident.clone(),
410                        name,
411                        field_types: fields_
412                            .unnamed
413                            .iter()
414                            .map(|field| field.ty.clone())
415                            .collect(),
416                        field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
417                        is_struct: false,
418                        generics: input.generics.clone(),
419                    },
420                    syn::Fields::Named(fields_) => Variant::Named {
421                        enum_name: input.ident.clone(),
422                        name,
423                        field_names: fields_
424                            .named
425                            .iter()
426                            .map(|field| field.ident.clone().unwrap())
427                            .collect(),
428                        field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
429                        field_flags: fields_.named.iter().map(parse_field_flag).collect(),
430                        is_struct: false,
431                        generics: input.generics.clone(),
432                    },
433                    _ => {
434                        return Err(syn::Error::new_spanned(
435                            variant,
436                            indoc! {r#"
437                                `Handler` currently only supports named or tuple struct variants
438
439                                = help use `MyCall(Arg1Type, Arg2Type, ..)`,
440                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
441                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
442                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>}`
443                                = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
444                                = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>}`
445                              "#},
446                        ));
447                    }
448                };
449                let log_level = parse_log_level(attrs)?;
450
451                messages.push(Message::new(
452                    variant.fields.span(),
453                    message_variant,
454                    log_level,
455                )?);
456            }
457
458            Ok(messages)
459        }
460        Data::Struct(data_struct) => {
461            let struct_name = input.ident.clone();
462            let attrs = &input.attrs;
463
464            let message_variant = match &data_struct.fields {
465                syn::Fields::Unnamed(fields_) => Variant::Anon {
466                    enum_name: struct_name.clone(),
467                    name: struct_name,
468                    field_types: fields_
469                        .unnamed
470                        .iter()
471                        .map(|field| field.ty.clone())
472                        .collect(),
473                    field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
474                    is_struct: true,
475                    generics: input.generics.clone(),
476                },
477                syn::Fields::Named(fields_) => Variant::Named {
478                    enum_name: struct_name.clone(),
479                    name: struct_name,
480                    field_names: fields_
481                        .named
482                        .iter()
483                        .map(|field| field.ident.clone().unwrap())
484                        .collect(),
485                    field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
486                    field_flags: fields_.named.iter().map(parse_field_flag).collect(),
487                    is_struct: true,
488                    generics: input.generics.clone(),
489                },
490                syn::Fields::Unit => Variant::Anon {
491                    enum_name: struct_name.clone(),
492                    name: struct_name,
493                    field_types: Vec::new(),
494                    field_flags: Vec::new(),
495                    is_struct: true,
496                    generics: input.generics.clone(),
497                },
498            };
499
500            let log_level = parse_log_level(attrs)?;
501            let message = Message::new(data_struct.fields.span(), message_variant, log_level)?;
502
503            Ok(vec![message])
504        }
505        _ => Err(syn::Error::new_spanned(
506            input,
507            "handlers can only be derived for enums and structs",
508        )),
509    }
510}
511
512/// Derive a custom handler trait for a given message enum (or struct).  The
513/// handler trait defines a method corresponding
514/// to each of the enum's variants, and a `handle` function
515/// that dispatches messages to the correct method.  The macro
516/// supports two messaging patterns: "call" and "oneway". A call is a
517/// request-response message; a [`hyperactor::mailbox::OncePortRef`] or
518/// [`hyperactor::mailbox::OncePortHandle`] in the last position is used
519/// to send the return value.
520///
521/// The macro also derives a client trait that can be automatically implemented
522/// by specifying [`HandleClient`] for `ActorHandle<Actor>` and [`RefClient`]
523/// for `ActorRef<Actor>` accordingly. We require two implementations because
524/// only `ActorRef`s require their message type to be serializable.
525///
526/// The associated [`hyperactor_macros::handler`] macro can be used to add
527/// a dispatching handler directly to an [`hyperactor::Actor`].
528///
529/// # Example
530///
531/// The following example creates a "shopping list" actor responsible for
532/// maintaining a shopping list.
533///
534/// ```
535/// use std::collections::HashSet;
536/// use std::time::Duration;
537///
538/// use async_trait::async_trait;
539/// use hyperactor::Actor;
540/// use hyperactor::HandleClient;
541/// use hyperactor::Handler;
542/// use hyperactor::Instance;
543/// use hyperactor::Named;
544/// use hyperactor::OncePortRef;
545/// use hyperactor::RefClient;
546/// use hyperactor::proc::Proc;
547/// use serde::Deserialize;
548/// use serde::Serialize;
549///
550/// #[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
551/// enum ShoppingList {
552///     // Oneway messages dispatch messages asynchronously, with no reply.
553///     Add(String),
554///     Remove(String),
555///
556///     // Call messages dispatch a request, expecting a reply to the
557///     // provided port, which must be in the last position.
558///     Exists(String, #[reply] OncePortRef<bool>),
559///
560///     List(#[reply] OncePortRef<Vec<String>>),
561/// }
562///
563/// // Define an actor.
564/// #[derive(Debug)]
565/// #[hyperactor::export(
566///     spawn = true,
567///     handlers = [
568///         ShoppingList,
569///     ],
570/// )]
571/// struct ShoppingListActor(HashSet<String>);
572///
573/// #[async_trait]
574/// impl Actor for ShoppingListActor {
575///     type Params = ();
576///
577///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
578///         Ok(Self(HashSet::new()))
579///     }
580/// }
581///
582/// // ShoppingListHandler is the trait generated by derive(Handler) above.
583/// // We implement the trait here for the actor, defining a handler for
584/// // each ShoppingList message.
585/// //
586/// // The `forward` attribute installs a handler that forwards messages
587/// // to the `ShoppingListHandler` implementation directly. This can also
588/// // be done manually:
589/// //
590/// // ```ignore
591/// //<ShoppingListActor as ShoppingListHandler>
592/// //     ::handle(self, comm, message).await
593/// // ```
594/// #[async_trait]
595/// #[hyperactor::forward(ShoppingList)]
596/// impl ShoppingListHandler for ShoppingListActor {
597///     async fn add(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
598///         eprintln!("insert {}", item);
599///         self.0.insert(item);
600///         Ok(())
601///     }
602///
603///     async fn remove(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
604///         eprintln!("remove {}", item);
605///         self.0.remove(&item);
606///         Ok(())
607///     }
608///
609///     async fn exists(
610///         &mut self,
611///         _cx: &Context<Self>,
612///         item: String,
613///     ) -> Result<bool, anyhow::Error> {
614///         Ok(self.0.contains(&item))
615///     }
616///
617///     async fn list(&mut self, _cx: &Context<Self>) -> Result<Vec<String>, anyhow::Error> {
618///         Ok(self.0.iter().cloned().collect())
619///     }
620/// }
621///
622/// #[tokio::main]
623/// async fn main() -> Result<(), anyhow::Error> {
624///     let mut proc = Proc::local();
625///
626///     // Spawn our actor, and get a handle for rank 0.
627///     let shopping_list_actor: hyperactor::ActorHandle<ShoppingListActor> =
628///         proc.spawn("shopping", ()).await?;
629///
630///     // We join the system, so that we can send messages to actors.
631///     let client = proc.attach("client").unwrap();
632///
633///     // todo: consider making this a macro to remove the magic names
634///
635///     // Derive(Handler) generates client methods, which call the
636///     // remote handler provided an actor instance,
637///     // the destination actor, and the method arguments.
638///
639///     shopping_list_actor.add(&client, "milk".into()).await?;
640///     shopping_list_actor.add(&client, "eggs".into()).await?;
641///
642///     println!(
643///         "got milk? {}",
644///         shopping_list_actor.exists(&client, "milk".into()).await?
645///     );
646///     println!(
647///         "got yoghurt? {}",
648///         shopping_list_actor
649///             .exists(&client, "yoghurt".into())
650///             .await?
651///     );
652///
653///     shopping_list_actor.remove(&client, "milk".into()).await?;
654///     println!(
655///         "got milk now? {}",
656///         shopping_list_actor.exists(&client, "milk".into()).await?
657///     );
658///
659///     println!(
660///         "shopping list: {:?}",
661///         shopping_list_actor.list(&client).await?
662///     );
663///
664///     let _ = proc
665///         .destroy_and_wait::<()>(Duration::from_secs(1), None)
666///         .await?;
667///     Ok(())
668/// }
669/// ```
670#[proc_macro_derive(Handler, attributes(reply))]
671pub fn derive_handler(input: TokenStream) -> TokenStream {
672    let input = parse_macro_input!(input as DeriveInput);
673    let name: Ident = input.ident.clone();
674    let (_, ty_generics, _) = input.generics.split_for_impl();
675
676    let messages = match parse_messages(input.clone()) {
677        Ok(messages) => messages,
678        Err(err) => return TokenStream::from(err.to_compile_error()),
679    };
680
681    // Trait definition methods for the handler.
682    let mut handler_trait_methods = Vec::new();
683
684    // The arms of the match used in the message dispatcher.
685    let mut match_arms = Vec::new();
686
687    // Trait implemented by clients.
688    let mut client_trait_methods = Vec::new();
689
690    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
691
692    for message in &messages {
693        match message {
694            Message::Call {
695                variant,
696                reply_port,
697                return_type,
698                log_level,
699            } => {
700                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
701                let variant_name_snake = variant.snake_name();
702                let variant_name_snake_deprecated =
703                    format_ident!("{}_deprecated", variant_name_snake);
704                let enum_name = variant.enum_name();
705                let _variant_qualified_name = variant.qualified_name();
706                let log_level = match (&global_log_level, log_level) {
707                    (_, Some(local)) => local.clone(),
708                    (Some(global), None) => global.clone(),
709                    _ => Ident::new("DEBUG", Span::call_site()),
710                };
711                let _log_level = if reply_port.is_handle {
712                    quote! {
713                        tracing::Level::#log_level
714                    }
715                } else {
716                    quote! {
717                        tracing::Level::TRACE
718                    }
719                };
720                let log_message = quote! {
721                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
722                            "rpc" => "call",
723                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
724                            "message_type" => stringify!(#enum_name),
725                            "variant" => stringify!(#variant_name_snake),
726                        ));
727                };
728
729                handler_trait_methods.push(quote! {
730                    #[doc = "The generated handler method for this enum variant."]
731                    async fn #variant_name_snake(
732                        &mut self,
733                        cx: &hyperactor::Context<Self>,
734                        #(#arg_names: #arg_types),*)
735                        -> Result<#return_type, hyperactor::anyhow::Error>;
736                });
737
738                client_trait_methods.push(quote! {
739                    #[doc = "The generated client method for this enum variant."]
740                    async fn #variant_name_snake(
741                        &self,
742                        cx: &impl hyperactor::context::Actor,
743                        #(#arg_names: #arg_types),*)
744                        -> Result<#return_type, hyperactor::anyhow::Error>;
745
746                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
747                    async fn #variant_name_snake_deprecated(
748                        &self,
749                        cx: &impl hyperactor::context::Actor,
750                        #(#arg_names: #arg_types),*)
751                        -> Result<#return_type, hyperactor::anyhow::Error>;
752                });
753
754                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
755                let constructor = variant.constructor();
756                let result_ident = Ident::new("result", Span::mixed_site());
757                let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
758                if reply_port.is_handle {
759                    match_arms.push(quote! {
760                        #constructor => {
761                            #log_message
762                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
763                            // This would require Result<Result<..., in order to handle RPC errors.
764                            #construct_result_future
765                            #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
766                        }
767                    });
768                } else {
769                    match_arms.push(quote! {
770                        #constructor => {
771                            #log_message
772                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
773                            // This would require Result<Result<..., in order to handle RPC errors.
774                            #construct_result_future
775                            #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
776                        }
777                    });
778                }
779            }
780            Message::OneWay { variant, log_level } => {
781                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
782                let variant_name_snake = variant.snake_name();
783                let variant_name_snake_deprecated =
784                    format_ident!("{}_deprecated", variant_name_snake);
785                let enum_name = variant.enum_name();
786                let log_level = match (&global_log_level, log_level) {
787                    (_, Some(local)) => local.clone(),
788                    (Some(global), None) => global.clone(),
789                    _ => Ident::new("TRACE", Span::call_site()),
790                };
791                let _log_level = quote! {
792                    tracing::Level::#log_level
793                };
794                let log_message = quote! {
795                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
796                            "rpc" => "call",
797                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
798                            "message_type" => stringify!(#enum_name),
799                            "variant" => stringify!(#variant_name_snake),
800                        ));
801                };
802
803                handler_trait_methods.push(quote! {
804                    #[doc = "The generated handler method for this enum variant."]
805                    async fn #variant_name_snake(
806                        &mut self,
807                        cx: &hyperactor::Context<Self>,
808                        #(#arg_names: #arg_types),*)
809                        -> Result<(), hyperactor::anyhow::Error>;
810                });
811
812                client_trait_methods.push(quote! {
813                    #[doc = "The generated client method for this enum variant."]
814                    async fn #variant_name_snake(
815                        &self,
816                        cx: &impl hyperactor::context::Actor,
817                        #(#arg_names: #arg_types),*)
818                        -> Result<(), hyperactor::anyhow::Error>;
819
820                    #[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
821                    async fn #variant_name_snake_deprecated(
822                        &self,
823                        cx: &impl hyperactor::context::Actor,
824                        #(#arg_names: #arg_types),*)
825                        -> Result<(), hyperactor::anyhow::Error>;
826                });
827
828                let constructor = variant.constructor();
829
830                match_arms.push(quote! {
831                    #constructor => {
832                        #log_message
833                        self.#variant_name_snake(cx, #(#arg_names),*).await
834                    },
835                });
836            }
837        }
838    }
839
840    let handler_trait_name = format_ident!("{}Handler", name);
841    let client_trait_name = format_ident!("{}Client", name);
842
843    // We impose additional constraints on the generics in the implementation;
844    // but the trait itself should not impose additional constraints:
845
846    let mut handler_generics = input.generics.clone();
847    for param in handler_generics.type_params_mut() {
848        param.bounds.push(syn::parse_quote!(serde::Serialize));
849        param
850            .bounds
851            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
852        param.bounds.push(syn::parse_quote!(Send));
853        param.bounds.push(syn::parse_quote!(Sync));
854        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
855        param.bounds.push(syn::parse_quote!(hyperactor::Named));
856    }
857    let (handler_impl_generics, _, _) = handler_generics.split_for_impl();
858    let (client_impl_generics, _, _) = input.generics.split_for_impl();
859
860    let expanded = quote! {
861        #[doc = "The custom handler trait for this message type."]
862        #[hyperactor::async_trait::async_trait]
863        pub trait #handler_trait_name #handler_impl_generics: hyperactor::Actor + Send + Sync  {
864            #(#handler_trait_methods)*
865
866            #[doc = "Handle the next message."]
867            async fn handle(
868                &mut self,
869                cx: &hyperactor::Context<Self>,
870                message: #name #ty_generics,
871            ) -> hyperactor::anyhow::Result<()>  {
872                 // Dispatch based on message type.
873                 match message {
874                     #(#match_arms)*
875                }
876            }
877        }
878
879        #[doc = "The custom client trait for this message type."]
880        #[hyperactor::async_trait::async_trait]
881        pub trait #client_trait_name #client_impl_generics: Send + Sync  {
882            #(#client_trait_methods)*
883        }
884    };
885
886    TokenStream::from(expanded)
887}
888
889/// Derives a client implementation on `ActorHandle<Actor>`.
890/// See [`Handler`] documentation for details.
891#[proc_macro_derive(HandleClient, attributes(log_level))]
892pub fn derive_handle_client(input: TokenStream) -> TokenStream {
893    derive_client(input, true)
894}
895
896/// Derives a client implementation on `ActorRef<Actor>`.
897/// See [`Handler`] documentation for details.
898#[proc_macro_derive(RefClient, attributes(log_level))]
899pub fn derive_ref_client(input: TokenStream) -> TokenStream {
900    derive_client(input, false)
901}
902
903fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
904    let input = parse_macro_input!(input as DeriveInput);
905    let name = input.ident.clone();
906
907    let messages = match parse_messages(input.clone()) {
908        Ok(messages) => messages,
909        Err(err) => return TokenStream::from(err.to_compile_error()),
910    };
911
912    // The client implementation methods.
913    let mut impl_methods = Vec::new();
914
915    let send_message = if is_handle {
916        quote! { self.send(message)? }
917    } else {
918        quote! { self.send(cx, message)? }
919    };
920    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
921
922    for message in &messages {
923        match message {
924            Message::Call {
925                variant,
926                reply_port,
927                return_type,
928                log_level,
929            } => {
930                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
931                let variant_name_snake = variant.snake_name();
932                let variant_name_snake_deprecated =
933                    format_ident!("{}_deprecated", variant_name_snake);
934                let enum_name = variant.enum_name();
935
936                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
937                let constructor = variant.constructor();
938                let log_level = match (&global_log_level, log_level) {
939                    (_, Some(local)) => local.clone(),
940                    (Some(global), None) => global.clone(),
941                    _ => Ident::new("DEBUG", Span::call_site()),
942                };
943                let log_level = if is_handle {
944                    quote! {
945                        tracing::Level::#log_level
946                    }
947                } else {
948                    quote! {
949                        tracing::Level::TRACE
950                    }
951                };
952                let log_message = quote! {
953                        hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
954                            "rpc" => "call",
955                            "actor_id" => hyperactor::context::Mailbox::mailbox(cx).actor_id().to_string(),
956                            "message_type" => stringify!(#enum_name),
957                            "variant" => stringify!(#variant_name_snake),
958                        ));
959
960                };
961                let open_port = reply_port.open_op();
962                let rx_mod = reply_port.rx_modifier();
963                if reply_port.is_handle {
964                    impl_methods.push(quote! {
965                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
966                        async fn #variant_name_snake(
967                            &self,
968                            cx: &impl hyperactor::context::Actor,
969                            #(#arg_names: #arg_types),*)
970                            -> Result<#return_type, hyperactor::anyhow::Error> {
971                            let (#reply_port_arg, #rx_mod reply_receiver) =
972                                #open_port::<#return_type>(cx);
973                            let message = #constructor;
974                            #log_message;
975                            #send_message;
976                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
977                        }
978
979                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
980                        async fn #variant_name_snake_deprecated(
981                            &self,
982                            cx: &impl hyperactor::context::Actor,
983                            #(#arg_names: #arg_types),*)
984                            -> Result<#return_type, hyperactor::anyhow::Error> {
985                            let (#reply_port_arg, #rx_mod reply_receiver) =
986                                #open_port::<#return_type>(cx);
987                            let message = #constructor;
988                            #log_message;
989                            #send_message;
990                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
991                        }
992                    });
993                } else {
994                    impl_methods.push(quote! {
995                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
996                        async fn #variant_name_snake(
997                            &self,
998                            cx: &impl hyperactor::context::Actor,
999                            #(#arg_names: #arg_types),*)
1000                            -> Result<#return_type, hyperactor::anyhow::Error> {
1001                            let (#reply_port_arg, #rx_mod reply_receiver) =
1002                                #open_port::<#return_type>(cx);
1003                            let #reply_port_arg = #reply_port_arg.bind();
1004                            let message = #constructor;
1005                            #log_message;
1006                            #send_message;
1007                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1008                        }
1009
1010                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
1011                        async fn #variant_name_snake_deprecated(
1012                            &self,
1013                            cx: &impl hyperactor::context::Actor,
1014                            #(#arg_names: #arg_types),*)
1015                            -> Result<#return_type, hyperactor::anyhow::Error> {
1016                            let (#reply_port_arg, #rx_mod reply_receiver) =
1017                                #open_port::<#return_type>(cx);
1018                            let #reply_port_arg = #reply_port_arg.bind();
1019                            let message = #constructor;
1020                            #log_message;
1021                            #send_message;
1022                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
1023                        }
1024                    });
1025                }
1026            }
1027            Message::OneWay { variant, log_level } => {
1028                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
1029                let variant_name_snake = variant.snake_name();
1030                let variant_name_snake_deprecated =
1031                    format_ident!("{}_deprecated", variant_name_snake);
1032                let enum_name = variant.enum_name();
1033                let constructor = variant.constructor();
1034                let log_level = match (&global_log_level, log_level) {
1035                    (_, Some(local)) => local.clone(),
1036                    (Some(global), None) => global.clone(),
1037                    _ => Ident::new("DEBUG", Span::call_site()),
1038                };
1039                let _log_level = if is_handle {
1040                    quote! {
1041                        tracing::Level::TRACE
1042                    }
1043                } else {
1044                    quote! {
1045                        tracing::Level::#log_level
1046                    }
1047                };
1048                let log_message = quote! {
1049                    hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
1050                        "rpc" => "oneway",
1051                        "actor_id" => self.actor_id().to_string(),
1052                        "message_type" => stringify!(#enum_name),
1053                        "variant" => stringify!(#variant_name_snake),
1054                    ));
1055                };
1056                impl_methods.push(quote! {
1057                    async fn #variant_name_snake(
1058                        &self,
1059                        cx: &impl hyperactor::context::Actor,
1060                        #(#arg_names: #arg_types),*)
1061                        -> Result<(), hyperactor::anyhow::Error> {
1062                        let message = #constructor;
1063                        #log_message;
1064                        #send_message;
1065                        Ok(())
1066                    }
1067
1068                    async fn #variant_name_snake_deprecated(
1069                        &self,
1070                        cx: &impl hyperactor::context::Actor,
1071                        #(#arg_names: #arg_types),*)
1072                        -> Result<(), hyperactor::anyhow::Error> {
1073                        let message = #constructor;
1074                        #log_message;
1075                        #send_message;
1076                        Ok(())
1077                    }
1078                });
1079            }
1080        }
1081    }
1082
1083    let trait_name = format_ident!("{}Client", name);
1084
1085    let (_, ty_generics, _) = input.generics.split_for_impl();
1086
1087    // Add a new generic parameter 'A'
1088    let actor_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
1089    let mut trait_generics = input.generics.clone();
1090    trait_generics.params.insert(
1091        0,
1092        syn::GenericParam::Type(syn::TypeParam {
1093            ident: actor_ident.clone(),
1094            attrs: vec![],
1095            colon_token: None,
1096            bounds: Punctuated::new(),
1097            eq_token: None,
1098            default: None,
1099        }),
1100    );
1101
1102    for param in trait_generics.type_params_mut() {
1103        if param.ident == actor_ident {
1104            continue;
1105        }
1106        param.bounds.push(syn::parse_quote!(serde::Serialize));
1107        param
1108            .bounds
1109            .push(syn::parse_quote!(for<'de> serde::Deserialize<'de>));
1110        param.bounds.push(syn::parse_quote!(Send));
1111        param.bounds.push(syn::parse_quote!(Sync));
1112        param.bounds.push(syn::parse_quote!(std::fmt::Debug));
1113        param.bounds.push(syn::parse_quote!(hyperactor::Named));
1114    }
1115
1116    let (impl_generics, _, _) = trait_generics.split_for_impl();
1117
1118    let expanded = if is_handle {
1119        quote! {
1120            #[hyperactor::async_trait::async_trait]
1121            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#actor_ident>
1122              where #actor_ident: hyperactor::Handler<#name #ty_generics> {
1123                #(#impl_methods)*
1124            }
1125        }
1126    } else {
1127        quote! {
1128            #[hyperactor::async_trait::async_trait]
1129            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#actor_ident>
1130              where #actor_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
1131                #(#impl_methods)*
1132            }
1133        }
1134    };
1135
1136    TokenStream::from(expanded)
1137}
1138
/// Compile-error text reported by [`fn forward`] when its attribute arguments
/// are not exactly one message type.
const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
`forward` expects the message type that is being forwarded

= help: use `#[forward(MessageType)]`
"#};
1144
1145/// Forward messages of the provided type to this handler implementation.
1146#[proc_macro_attribute]
1147pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
1148    let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
1149    if attr_args.len() != 1 {
1150        return TokenStream::from(
1151            syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
1152        );
1153    }
1154
1155    let message_type = attr_args.first().unwrap();
1156    let input = parse_macro_input!(item as ItemImpl);
1157
1158    let self_type = match *input.self_ty {
1159        syn::Type::Path(ref type_path) => {
1160            let segment = type_path.path.segments.last().unwrap();
1161            segment.clone() //ident.clone()
1162        }
1163        _ => {
1164            return TokenStream::from(
1165                syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
1166                    .to_compile_error(),
1167            );
1168        }
1169    };
1170
1171    let trait_name = match input.trait_ {
1172        Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
1173        None => {
1174            return TokenStream::from(
1175                syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
1176                    .to_compile_error(),
1177            );
1178        }
1179    };
1180
1181    let expanded = quote! {
1182        #input
1183
1184        #[hyperactor::async_trait::async_trait]
1185        impl hyperactor::Handler<#message_type> for #self_type {
1186            async fn handle(
1187                &mut self,
1188                cx: &hyperactor::Context<Self>,
1189                message: #message_type,
1190            ) -> hyperactor::anyhow::Result<()> {
1191                <Self as #trait_name>::handle(self, cx, message).await
1192            }
1193        }
1194    };
1195
1196    TokenStream::from(expanded)
1197}
1198
1199/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1200/// We set a default level of INFO while always setting ERROR if the function returns Result::Err giving us
1201/// consistent and high quality structured logs. Because this wraps around tracing::instrument, all parameters
1202/// mentioned in https://fburl.com/9jlkb5q4 should be valid. For functions that don't return a [`Result`] type, use
1203/// [`instrument_infallible`]
1204///
1205/// ```
1206/// #[telemetry::instrument]
1207/// async fn yolo() -> anyhow::Result<i32> {
1208///     Ok(420)
1209/// }
1210/// ```
1211#[proc_macro_attribute]
1212pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1213    let args =
1214        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1215    let input = parse_macro_input!(input as ItemFn);
1216    let output = quote! {
1217        #[hyperactor::tracing::instrument(err, skip_all, #args)]
1218        #input
1219    };
1220
1221    TokenStream::from(output)
1222}
1223
1224/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1225/// Because this wraps around tracing::instrument, all parameters mentioned in
1226/// https://fburl.com/9jlkb5q4 should be valid.
1227///
1228/// ```
1229/// #[telemetry::instrument]
1230/// async fn yolo() -> i32 {
1231///     420
1232/// }
1233/// ```
1234#[proc_macro_attribute]
1235pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1236    let args =
1237        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1238    let input = parse_macro_input!(input as ItemFn);
1239
1240    let output = quote! {
1241        #[hyperactor::tracing::instrument(skip_all, #args)]
1242        #input
1243    };
1244
1245    TokenStream::from(output)
1246}
1247
1248/// Derive the [`hyperactor::data::Named`] trait for a struct with the
1249/// provided type URI. The name of the type is its fully-qualified Rust
1250/// path. The name may be overridden by providing a string value for the
1251/// `name` attribute.
1252///
1253/// In addition to deriving [`hyperactor::data::Named`], this macro will
1254/// register the type using the [`hyperactor::register_type`] macro for
1255/// concrete types. This behavior can be overridden by providing a literal
1256/// booolean for the `register` attribute.
1257///
1258/// This also requires the type to implement [`serde::Serialize`]
1259/// and [`serde::Deserialize`].
1260#[proc_macro_derive(Named, attributes(named))]
1261pub fn derive_named(input: TokenStream) -> TokenStream {
1262    // Parse the input struct or enum
1263    let input = parse_macro_input!(input as DeriveInput);
1264    let struct_name = &input.ident;
1265
1266    let mut typename = quote! {
1267        concat!(std::module_path!(), "::", stringify!(#struct_name))
1268    };
1269
1270    let type_params: Vec<_> = input.generics.type_params().collect();
1271    let has_generics = !type_params.is_empty();
1272    // By default, register concrete types.
1273    let mut register = !has_generics;
1274
1275    for attr in &input.attrs {
1276        if attr.path().is_ident("named") {
1277            if let Ok(meta) = attr.parse_args_with(
1278                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
1279            ) {
1280                for item in meta {
1281                    if let Meta::NameValue(MetaNameValue {
1282                        path,
1283                        value: Expr::Lit(expr_lit),
1284                        ..
1285                    }) = item
1286                    {
1287                        if path.is_ident("name") {
1288                            if let Lit::Str(name) = expr_lit.lit {
1289                                typename = quote! { #name };
1290                            } else {
1291                                return TokenStream::from(
1292                                    syn::Error::new_spanned(path, "invalid name")
1293                                        .to_compile_error(),
1294                                );
1295                            }
1296                        } else if path.is_ident("register") {
1297                            if let Lit::Bool(flag) = expr_lit.lit {
1298                                register = flag.value;
1299                            } else {
1300                                return TokenStream::from(
1301                                    syn::Error::new_spanned(path, "invalid registration flag")
1302                                        .to_compile_error(),
1303                                );
1304                            }
1305                        } else {
1306                            return TokenStream::from(
1307                                syn::Error::new_spanned(
1308                                    path,
1309                                    "unsupported attribute (only `name` or `register` is supported)",
1310                                )
1311                                .to_compile_error(),
1312                            );
1313                        }
1314                    }
1315                }
1316            }
1317        }
1318    }
1319
1320    // Create a version of generics with Named bounds for the impl block
1321    let mut generics_with_bounds = input.generics.clone();
1322    if has_generics {
1323        for param in generics_with_bounds.type_params_mut() {
1324            param
1325                .bounds
1326                .push(syn::parse_quote!(hyperactor::data::Named));
1327        }
1328    }
1329    let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1330
1331    // Generate typename implementation based on whether we have generics
1332    let (typename_impl, typehash_impl) = if has_generics {
1333        // Create format string with placeholders for each generic parameter
1334        let placeholders = vec!["{}"; type_params.len()].join(", ");
1335        let placeholders_format_string = format!("<{}>", placeholders);
1336        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#struct_name), #placeholders_format_string) };
1337
1338        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1339        (
1340            quote! {
1341                hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1342            },
1343            quote! {
1344                hyperactor::cityhasher::hash(Self::typename())
1345            },
1346        )
1347    } else {
1348        (
1349            typename,
1350            quote! {
1351                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1352                    hyperactor::cityhasher::hash(<#struct_name as hyperactor::data::Named>::typename())
1353                });
1354                *TYPEHASH
1355            },
1356        )
1357    };
1358
1359    // Generate 'arm' for enums only.
1360    let arm_impl = match &input.data {
1361        Data::Enum(DataEnum { variants, .. }) => {
1362            let match_arms = variants.iter().map(|v| {
1363                let variant_name = &v.ident;
1364                let variant_str = variant_name.to_string();
1365                match &v.fields {
1366                    Fields::Unit => quote! { Self::#variant_name => Some(#variant_str) },
1367                    Fields::Unnamed(_) => quote! { Self::#variant_name(..) => Some(#variant_str) },
1368                    Fields::Named(_) => quote! { Self::#variant_name { .. } => Some(#variant_str) },
1369                }
1370            });
1371            quote! {
1372                fn arm(&self) -> Option<&'static str> {
1373                    match self {
1374                        #(#match_arms,)*
1375                    }
1376                }
1377            }
1378        }
1379        _ => quote! {},
1380    };
1381
1382    // Try to register the type so we can get runtime TypeInfo.
1383    // We can only do this for concrete types.
1384    //
1385    // TODO: explore making type hashes "structural", so that we
1386    // can derive generic type hashes and reconstruct their runtime
1387    // TypeInfos.
1388    let registration = if register {
1389        quote! {
1390            hyperactor::register_type!(#struct_name);
1391        }
1392    } else {
1393        quote! {
1394            // Registration not requested
1395        }
1396    };
1397
1398    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
1399    // Ideally we would compute the has directly in the macro itself, however, we don't
1400    // have access to the fully expanded pathname here as we use the intrinsic std::module_path!() macro.
1401    let expanded = quote! {
1402        impl #impl_generics_with_bounds hyperactor::data::Named for #struct_name #ty_generics #where_clause {
1403            fn typename() -> &'static str { #typename_impl }
1404            fn typehash() -> u64 { #typehash_impl }
1405            #arm_impl
1406        }
1407
1408        #registration
1409    };
1410
1411    TokenStream::from(expanded)
1412}
1413
1414struct HandlerSpec {
1415    ty: Type,
1416    cast: bool,
1417}
1418
1419impl Parse for HandlerSpec {
1420    fn parse(input: ParseStream) -> syn::Result<Self> {
1421        let ty: Type = input.parse()?;
1422
1423        if input.peek(syn::token::Brace) {
1424            let content;
1425            syn::braced!(content in input);
1426            let key: Ident = content.parse()?;
1427            content.parse::<Token![=]>()?;
1428            let expr: Expr = content.parse()?;
1429
1430            let cast = if key == "cast" {
1431                if let Expr::Lit(ExprLit {
1432                    lit: Lit::Bool(b), ..
1433                }) = expr
1434                {
1435                    b.value
1436                } else {
1437                    return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1438                }
1439            } else {
1440                return Err(syn::Error::new_spanned(
1441                    key,
1442                    "unsupported field (expected `cast`)",
1443                ));
1444            };
1445
1446            Ok(HandlerSpec { ty, cast })
1447        } else if input.is_empty() || input.peek(Token![,]) {
1448            Ok(HandlerSpec { ty, cast: false })
1449        } else {
1450            // Something unexpected follows the type
1451            let unexpected: proc_macro2::TokenTree = input.parse()?;
1452            Err(syn::Error::new_spanned(
1453                unexpected,
1454                "unexpected token after type — expected `{ ... }` or nothing",
1455            ))
1456        }
1457    }
1458}
1459
1460impl HandlerSpec {
1461    fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1462        let mut tys = Vec::new();
1463        for HandlerSpec { ty, cast } in handlers {
1464            if cast {
1465                let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1466                let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1467                tys.push(wrapped_ty);
1468            }
1469            tys.push(ty);
1470        }
1471        tys
1472    }
1473}
1474
1475/// Attribute Struct for [`fn export`] macro.
1476struct ExportAttr {
1477    spawn: bool,
1478    handlers: Vec<HandlerSpec>,
1479}
1480
1481impl Parse for ExportAttr {
1482    fn parse(input: ParseStream) -> syn::Result<Self> {
1483        let mut spawn = false;
1484        let mut handlers: Vec<HandlerSpec> = vec![];
1485
1486        while !input.is_empty() {
1487            let key: Ident = input.parse()?;
1488            input.parse::<Token![=]>()?;
1489
1490            if key == "spawn" {
1491                let expr: Expr = input.parse()?;
1492                if let Expr::Lit(ExprLit {
1493                    lit: Lit::Bool(b), ..
1494                }) = expr
1495                {
1496                    spawn = b.value;
1497                } else {
1498                    return Err(syn::Error::new_spanned(
1499                        expr,
1500                        "expected boolean for `spawn`",
1501                    ));
1502                }
1503            } else if key == "handlers" {
1504                let content;
1505                bracketed!(content in input);
1506                let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1507                handlers = raw_handlers.into_iter().collect();
1508            } else {
1509                return Err(syn::Error::new_spanned(
1510                    key,
1511                    "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1512                ));
1513            }
1514
1515            // optional trailing comma
1516            let _ = input.parse::<Token![,]>();
1517        }
1518
1519        Ok(ExportAttr { spawn, handlers })
1520    }
1521}
1522
1523/// Exports handlers for this actor. The set of exported handlers
1524/// determine the messages that may be sent to remote references of
1525/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
1526/// [`hyperactor::RemoteMessage`] may be exported.
1527///
1528/// Additionally, an exported actor may be remotely spawned,
1529/// indicated by `spawn = true`. Such actors must also ensure that
1530/// their parameter type implements [`hyperactor::RemoteMessage`].
1531///
1532/// # Example
1533///
1534/// In the following example, `MyActor` can be spawned remotely. It also has
1535/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`.
1536/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these
1537/// types.
1538///
1539/// ```ignore
1540/// #[export(
1541///     spawn = true,
1542///     handlers = [
1543///         MyMessage,
1544///         MyOtherMessage,
1545///     ],
1546/// )]
1547/// struct MyActor {}
1548/// ```
1549#[proc_macro_attribute]
1550pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1551    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1552    let data_type_name = &input.ident;
1553
1554    let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1555    let tys = HandlerSpec::add_indexed(handlers);
1556
1557    let mut handles = Vec::new();
1558    let mut bindings = Vec::new();
1559    let mut type_registrations = Vec::new();
1560
1561    for ty in &tys {
1562        handles.push(quote! {
1563            impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
1564        });
1565        bindings.push(quote! {
1566            ports.bind::<#ty>();
1567        });
1568        type_registrations.push(quote! {
1569            hyperactor::register_type!(#ty);
1570        });
1571    }
1572
1573    let mut expanded = quote! {
1574        #input
1575
1576        impl hyperactor::actor::Referable for #data_type_name {}
1577
1578        #(#handles)*
1579
1580        #(#type_registrations)*
1581
1582        // Always export the `Signal` type.
1583        impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
1584
1585        impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
1586            fn bind(ports: &hyperactor::proc::Ports<Self>) {
1587                #(#bindings)*
1588            }
1589        }
1590
1591        // TODO: just use Named derive directly here.
1592        impl hyperactor::data::Named for #data_type_name {
1593            fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
1594        }
1595    };
1596
1597    if spawn {
1598        expanded.extend(quote! {
1599            hyperactor::remote!(#data_type_name);
1600        });
1601    }
1602
1603    TokenStream::from(expanded)
1604}
1605
1606/// Represents the full input to [`fn alias`].
1607struct AliasInput {
1608    alias: Ident,
1609    handlers: Vec<HandlerSpec>,
1610}
1611
1612impl syn::parse::Parse for AliasInput {
1613    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1614        let alias: Ident = input.parse()?;
1615        let _: Token![,] = input.parse()?;
1616        let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1617        let handlers = raw_handlers.into_iter().collect();
1618        Ok(AliasInput { alias, handlers })
1619    }
1620}
1621
1622/// Create a [`Referable`] handling a specific set of message types.
1623/// This is used to create an [`ActorRef`] without having to depend on the
1624/// actor's implementation. If the message type need to be cast, add `castable`
1625/// flag to those types. e.g. the following example creats an alias with 5
1626/// message types, and 4 of which need to be cast.
1627///
1628/// ```
1629/// hyperactor::alias!(
1630///     TestActorAlias,
1631///     TestMessage { castable = true },
1632///     () {castable = true },
1633///     MyGeneric<()> {castable = true },
1634///     u64,
1635/// );
1636/// ```
1637#[proc_macro]
1638pub fn alias(input: TokenStream) -> TokenStream {
1639    let AliasInput { alias, handlers } = parse_macro_input!(input as AliasInput);
1640    let tys = HandlerSpec::add_indexed(handlers);
1641
1642    let expanded = quote! {
1643        #[doc = "The generated alias struct."]
1644        #[derive(Debug, hyperactor::Named, serde::Serialize, serde::Deserialize)]
1645        pub struct #alias;
1646        impl hyperactor::actor::Referable for #alias {}
1647
1648        impl<A> hyperactor::actor::Binds<A> for #alias
1649        where
1650            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)* {
1651            fn bind(ports: &hyperactor::proc::Ports<A>) {
1652                #(
1653                    ports.bind::<#tys>();
1654                )*
1655            }
1656        }
1657
1658        #(
1659            impl hyperactor::actor::RemoteHandles<#tys> for #alias {}
1660        )*
1661    };
1662
1663    TokenStream::from(expanded)
1664}
1665
1666fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1667    let mut is_included = false;
1668    for attr in &field.attrs {
1669        if attr.path().is_ident("binding") {
1670            // parse #[binding(include)] and look for exactly "include"
1671            attr.parse_nested_meta(|meta| {
1672                if meta.path.is_ident("include") {
1673                    is_included = true;
1674                    Ok(())
1675                } else {
1676                    let path = meta.path.to_token_stream().to_string().replace(' ', "");
1677                    Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1678                }
1679            })?
1680        }
1681    }
1682    Ok(is_included)
1683}
1684
1685/// The field accessor in struct or enum variant.
1686/// e.g.:
1687///   struct NamedStruct { foo: u32 } => FieldAccessor::Named(Ident::new("foo", Span::call_site()))
1688///   struct UnnamedStruct(u32) => FieldAccessor::Unnamed(Index::from(0))
1689enum FieldAccessor {
1690    Named(Ident),
1691    Unnamed(Index),
1692}
1693
1694/// Result of parsing a field in a struct, or a enum variant.
1695struct ParsedField {
1696    accessor: FieldAccessor,
1697    ty: Type,
1698    included: bool,
1699}
1700
1701impl From<&ParsedField> for (Ident, Type) {
1702    fn from(field: &ParsedField) -> Self {
1703        let field_ident = match &field.accessor {
1704            FieldAccessor::Named(ident) => ident.clone(),
1705            FieldAccessor::Unnamed(i) => {
1706                Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1707            }
1708        };
1709        (field_ident, field.ty.clone())
1710    }
1711}
1712
1713fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1714    match fields {
1715        Fields::Named(named) => named
1716            .named
1717            .iter()
1718            .map(|f| {
1719                let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1720                Ok(ParsedField {
1721                    accessor,
1722                    ty: f.ty.clone(),
1723                    included: include_in_bind_unbind(f)?,
1724                })
1725            })
1726            .collect(),
1727        Fields::Unnamed(unnamed) => unnamed
1728            .unnamed
1729            .iter()
1730            .enumerate()
1731            .map(|(i, f)| {
1732                let accessor = FieldAccessor::Unnamed(Index::from(i));
1733                Ok(ParsedField {
1734                    accessor,
1735                    ty: f.ty.clone(),
1736                    included: include_in_bind_unbind(f)?,
1737                })
1738            })
1739            .collect(),
1740        Fields::Unit => Ok(Vec::new()),
1741    }
1742}
1743
1744fn gen_struct_items<F>(
1745    fields: &Fields,
1746    make_item: F,
1747    is_mutable: bool,
1748) -> syn::Result<Vec<proc_macro2::TokenStream>>
1749where
1750    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1751{
1752    let borrow = if is_mutable {
1753        quote! { &mut }
1754    } else {
1755        quote! { & }
1756    };
1757    let items: Vec<_> = collect_all_fields(fields)?
1758        .into_iter()
1759        .filter(|f| f.included)
1760        .map(
1761            |ParsedField {
1762                 accessor,
1763                 ty,
1764                 included,
1765             }| {
1766                assert!(included);
1767                let field_accessor = match accessor {
1768                    FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1769                    FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1770                };
1771                make_item(field_accessor, ty)
1772            },
1773        )
1774        .collect();
1775    Ok(items)
1776}
1777
1778/// Generate the field accessor for a enum variant in pattern matching. e.g.
1779/// the <GENERATED> parts in the following example:
1780///
1781///   match my_enum {
1782///     // e.g. MyEnum::Tuple(_, f1, f2, _)
1783///     MyEnum::Tuple(<GENERATED>) => { ... }
1784///     // e.g. MyEnum::Struct { field0: _, field1 }
1785///     MyEnum::Struct(<GENERATED>) => { ... }
1786///   }
1787fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1788    all_fields
1789        .iter()
1790        .map(
1791            |ParsedField {
1792                 accessor,
1793                 ty: _,
1794                 included,
1795             }| {
1796                match accessor {
1797                    FieldAccessor::Named(ident) => {
1798                        if *included {
1799                            quote! { #ident }
1800                        } else {
1801                            quote! { #ident: _ }
1802                        }
1803                    }
1804                    FieldAccessor::Unnamed(i) => {
1805                        if *included {
1806                            let ident = Ident::new(
1807                                &format!("f{}", i.index),
1808                                proc_macro2::Span::call_site(),
1809                            );
1810                            quote! { #ident }
1811                        } else {
1812                            quote! { _ }
1813                        }
1814                    }
1815                }
1816            },
1817        )
1818        .collect()
1819}
1820
1821/// Generate all the parts for enum variants. e.g. the <GENERATED> part in the
1822/// following example:
1823///
1824///   match my_enum {
1825///      <GENERATED>
1826///   }
1827fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1828where
1829    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1830{
1831    data.variants
1832        .iter()
1833        .map(|variant| {
1834            let name = &variant.ident;
1835            let all_fields = collect_all_fields(&variant.fields)?;
1836            let field_accessors = gen_enum_field_accessors(&all_fields);
1837            let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1838            let items = included_fields
1839                .iter()
1840                .map(|f| {
1841                    let (accessor, ty) = <(Ident, Type)>::from(*f);
1842                    make_item(quote! { #accessor }, ty)
1843                })
1844                .collect::<Vec<_>>();
1845
1846            Ok(match &variant.fields {
1847                Fields::Named(_) => {
1848                    quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1849                }
1850                Fields::Unnamed(_) => {
1851                    quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1852                }
1853                Fields::Unit => quote! { Self::#name => { #(#items)* } },
1854            })
1855        })
1856        .collect()
1857}
1858
1859/// Derive a custom implementation of [`hyperactor::message::Bind`] trait for
1860/// a struct or enum. This macro is normally used in tandem with [`fn derive_unbind`]
1861/// to make the applied struct or enum castable.
1862///
1863/// Specifically, the derived implementation iterates through fields annotated
1864/// with `#[binding(include)]` based on their order of declaration in the struct
1865/// or enum. These fields' types must implement `Bind` trait as well. During the
1866/// iteration, parameters from `bindings` are bound to these fields.
1867///
1868/// # Example
1869///
1870/// This macro supports named and unamed structs and enums. Below are examples
1871/// of the supported types:
1872///
1873/// ```
1874/// #[derive(Bind, Unbind)]
1875/// struct MyNamedStruct {
1876///     field0: u64,
1877///     field1: MyReply,
1878///     #[binding(include)]
1879/// nnnn     field2: PortRef<MyReply>,
1880///     field3: bool,
1881///     #[binding(include)]
1882///     field4: hyperactor::PortRef<u64>,
1883/// }
1884///
1885/// #[derive(Bind, Unbind)]
1886/// struct MyUnamedStruct(
1887///     u64,
1888///     MyReply,
1889///     #[binding(include)] hyperactor::PortRef<MyReply>,
1890///     bool,
1891///     #[binding(include)] PortRef<u64>,
1892/// );
1893///
1894/// #[derive(Bind, Unbind)]
1895/// enum MyEnum {
1896///     Unit,
1897///     NoopTuple(u64, bool),
1898///     NoopStruct {
1899///         field0: u64,
1900///         field1: bool,
1901///     },
1902///     Tuple(
1903///         u64,
1904///         MyReply,
1905///         #[binding(include)] PortRef<MyReply>,
1906///         bool,
1907///         #[binding(include)] hyperactor::PortRef<u64>,
1908///     ),
1909///     Struct {
1910///         field0: u64,
1911///         field1: MyReply,
1912///         #[binding(include)]
1913///         field2: PortRef<MyReply>,
1914///         field3: bool,
1915///         #[binding(include)]
1916///         field4: hyperactor::PortRef<u64>,
1917///     },
1918/// }
1919/// ```
1920///
1921/// The following shows what derived `Bind`` and `Unbind`` implementations for
1922/// `MyNamedStruct` will look like. The implementations of other types are
1923/// similar, and thus are not shown here.
1924/// ```ignore
1925/// impl Bind for MyNamedStruct {
1926/// fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
1927///     Bind::bind(&mut self.field2, bindings)?;
1928///     Bind::bind(&mut self.field4, bindings)?;
1929///     Ok(())
1930/// }
1931///
1932/// impl Unbind for MyNamedStruct {
1933///     fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
1934///         Unbind::unbind(&self.field2, bindings)?;
1935///         Unbind::unbind(&self.field4, bindings)?;
1936///         Ok(())
1937///     }
1938/// }
1939/// ```
1940#[proc_macro_derive(Bind, attributes(binding))]
1941pub fn derive_bind(input: TokenStream) -> TokenStream {
1942    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1943        quote! {
1944            hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1945        }
1946    }
1947
1948    let input = parse_macro_input!(input as DeriveInput);
1949    let name = &input.ident;
1950    let inner = match &input.data {
1951        Data::Struct(DataStruct { fields, .. }) => {
1952            match gen_struct_items(fields, make_item, true) {
1953                Ok(collects) => {
1954                    quote! { #(#collects)* }
1955                }
1956                Err(e) => {
1957                    return TokenStream::from(e.to_compile_error());
1958                }
1959            }
1960        }
1961        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1962            Ok(arms) => {
1963                quote! { match self { #(#arms),* } }
1964            }
1965            Err(e) => {
1966                return TokenStream::from(e.to_compile_error());
1967            }
1968        },
1969        _ => panic!("Bind can only be derived for structs and enums"),
1970    };
1971    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1972    let expand = quote! {
1973        #[automatically_derived]
1974        impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1975            fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1976                #inner
1977                Ok(())
1978            }
1979        }
1980    };
1981    TokenStream::from(expand)
1982}
1983
1984/// Derive a custom implementation of [`hyperactor::message::Unbind`] trait for
1985/// a struct or enum. This macro is normally used in tandem with [`fn derive_bind`]
1986/// to make the applied struct or enum castable.
1987///
1988/// Specifically, the derived implementation iterates through fields annoated
1989/// with `#[binding(include)]` based on their order of declaration in the struct
1990/// or enum. These fields' types must implement `Unbind` trait as well. During
1991/// the iteration, parameters from these fields are extracted and stored in
1992/// `bindings`.
1993///
1994/// # Example
1995///
1996/// See [`fn derive_bind`]'s documentation for examples.
1997#[proc_macro_derive(Unbind, attributes(binding))]
1998pub fn derive_unbind(input: TokenStream) -> TokenStream {
1999    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
2000        quote! {
2001            hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
2002        }
2003    }
2004
2005    let input = parse_macro_input!(input as DeriveInput);
2006    let name = &input.ident;
2007    let inner = match &input.data {
2008        Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
2009        {
2010            Ok(collects) => {
2011                quote! { #(#collects)* }
2012            }
2013            Err(e) => {
2014                return TokenStream::from(e.to_compile_error());
2015            }
2016        },
2017        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
2018            Ok(arms) => {
2019                quote! { match self { #(#arms),* } }
2020            }
2021            Err(e) => {
2022                return TokenStream::from(e.to_compile_error());
2023            }
2024        },
2025        _ => panic!("Unbind can only be derived for structs and enums"),
2026    };
2027    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2028    let expand = quote! {
2029        #[automatically_derived]
2030        impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
2031            fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
2032                #inner
2033                Ok(())
2034            }
2035        }
2036    };
2037    TokenStream::from(expand)
2038}
2039
2040/// Derives the `Actor` trait for a struct. By default, generates an implementation
2041/// with no params (`type Params = ()` and `async fn new(_params: ()) -> Result<Self, anyhow::Error>`).
2042/// This requires that the Actor implements [`Default`].
2043///
2044/// If the `#[actor(passthrough)]` attribute is specified, generates an implementation
2045/// with where the parameter type is `Self`
2046/// (`type Params = Self` and `async fn new(instance: Self) -> Result<Self, anyhow::Error>`).
2047///
2048/// # Examples
2049///
2050/// Default behavior:
2051/// ```
2052/// #[derive(Actor, Default)]
2053/// struct MyActor(u64);
2054/// ```
2055///
2056/// Generates:
2057/// ```ignore
2058/// #[async_trait]
2059/// impl Actor for MyActor {
2060///     type Params = ();
2061///
2062///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
2063///         Ok(Default::default())
2064///     }
2065/// }
2066/// ```
2067///
2068/// Passthrough behavior:
2069/// ```
2070/// #[derive(Actor, Default)]
2071/// #[actor(passthrough)]
2072/// struct MyActor(u64);
2073/// ```
2074///
2075/// Generates:
2076/// ```ignore
2077/// #[async_trait]
2078/// impl Actor for MyActor {
2079///     type Params = Self;
2080///
2081///     async fn new(instance: Self) -> Result<Self, anyhow::Error> {
2082///         Ok(instance)
2083///     }
2084/// }
2085/// ```
2086#[proc_macro_derive(Actor, attributes(actor))]
2087pub fn derive_actor(input: TokenStream) -> TokenStream {
2088    let input = parse_macro_input!(input as DeriveInput);
2089    let name = &input.ident;
2090    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2091
2092    let is_passthrough = input.attrs.iter().any(|attr| {
2093        if attr.path().is_ident("actor") {
2094            if let Ok(meta) = attr.parse_args_with(
2095                syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
2096            ) {
2097                return meta.iter().any(|ident| ident == "passthrough");
2098            }
2099        }
2100        false
2101    });
2102
2103    let expanded = if is_passthrough {
2104        quote! {
2105            #[hyperactor::async_trait::async_trait]
2106            impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2107                type Params = Self;
2108
2109                async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
2110                    Ok(instance)
2111                }
2112            }
2113        }
2114    } else {
2115        quote! {
2116            #[hyperactor::async_trait::async_trait]
2117            impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
2118                type Params = ();
2119
2120                async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
2121                    Ok(Default::default())
2122                }
2123            }
2124        }
2125    };
2126
2127    TokenStream::from(expanded)
2128}
2129
2130// Helper function for common parsing and validation
2131fn parse_observe_function(
2132    attr: TokenStream,
2133    item: TokenStream,
2134) -> syn::Result<(ItemFn, String, String)> {
2135    let input = syn::parse::<ItemFn>(item)?;
2136
2137    if input.sig.asyncness.is_none() {
2138        return Err(syn::Error::new(
2139            input.sig.span(),
2140            "observe macros can only be applied to async functions",
2141        ));
2142    }
2143
2144    let fn_name_str = input.sig.ident.to_string();
2145    let module_name_str = syn::parse::<syn::LitStr>(attr)?.value();
2146
2147    Ok((input, fn_name_str, module_name_str))
2148}
2149
2150// Helper function for creating telemetry identifiers and setup code
2151fn create_telemetry_setup(
2152    module_name_str: &str,
2153    fn_name_str: &str,
2154    include_error: bool,
2155) -> (Ident, Ident, Option<Ident>, proc_macro2::TokenStream) {
2156    let module_and_fn = format!("{}_{}", module_name_str, fn_name_str);
2157    let latency_ident = Ident::new("latency", Span::from(proc_macro::Span::def_site()));
2158
2159    let success_ident = Ident::new("success", Span::from(proc_macro::Span::def_site()));
2160
2161    let error_ident = if include_error {
2162        Some(Ident::new(
2163            "error",
2164            Span::from(proc_macro::Span::def_site()),
2165        ))
2166    } else {
2167        None
2168    };
2169
2170    let error_declaration = if let Some(ref error_ident) = error_ident {
2171        quote! {
2172            hyperactor_telemetry::declare_static_counter!(#error_ident, concat!(#module_and_fn, ".error"));
2173        }
2174    } else {
2175        quote! {}
2176    };
2177
2178    let setup_code = quote! {
2179        use hyperactor_telemetry;
2180        hyperactor_telemetry::declare_static_timer!(#latency_ident, concat!(#module_and_fn, ".latency"), hyperactor_telemetry::TimeUnit::Micros);
2181        hyperactor_telemetry::declare_static_counter!(#success_ident, concat!(#module_and_fn, ".success"));
2182        #error_declaration
2183    };
2184
2185    (latency_ident, success_ident, error_ident, setup_code)
2186}
2187
2188/// A procedural macro that automatically injects telemetry code into async functions
2189/// that return a Result type.
2190///
2191/// This macro wraps async functions and adds instrumentation to measure:
2192/// 1. Latency - how long the function takes to execute
2193/// 2. Error counter - function error count
2194/// 3. Success counter - function completion count
2195///
2196/// # Example
2197///
2198/// ```rust
2199/// use hyperactor_actor::observe_result;
2200///
2201/// #[observe_result("my_module")]
2202/// async fn process_request(user_id: &str) -> Result<String, Error> {
2203///     // Function implementation
2204///     // Telemetry will be automatically collected
2205/// }
2206/// ```
2207#[proc_macro_attribute]
2208pub fn observe_result(attr: TokenStream, item: TokenStream) -> TokenStream {
2209    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2210        Ok(parsed) => parsed,
2211        Err(err) => return err.to_compile_error().into(),
2212    };
2213
2214    let fn_name = &input.sig.ident;
2215    let vis = &input.vis;
2216    let args = &input.sig.inputs;
2217    let return_type = &input.sig.output;
2218    let body = &input.block;
2219    let attrs = &input.attrs;
2220    let generics = &input.sig.generics;
2221
2222    let (latency_ident, success_ident, error_ident, telemetry_setup) =
2223        create_telemetry_setup(&module_name_str, &fn_name_str, true);
2224    let error_ident = error_ident.unwrap();
2225
2226    let result_ident = Ident::new("result", Span::from(proc_macro::Span::def_site()));
2227
2228    // Generate the instrumented function
2229    let expanded = quote! {
2230        #(#attrs)*
2231        #vis async fn #fn_name #generics(#args) #return_type {
2232            #telemetry_setup
2233
2234            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2235            let _timer = #latency_ident.start(kv_pairs);
2236
2237            let #result_ident = async #body.await;
2238
2239            match &#result_ident {
2240                Ok(_) => {
2241                    #success_ident.add(
2242                        1,
2243                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2244                    );
2245                }
2246                Err(_) => {
2247                    #error_ident.add(
2248                        1,
2249                        hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2250                    );
2251                }
2252            }
2253
2254            #result_ident
2255        }
2256    };
2257
2258    expanded.into()
2259}
2260
2261/// A procedural macro that automatically injects telemetry code into async functions
2262/// that do not return a Result type.
2263///
2264/// This macro wraps async functions and adds instrumentation to measure:
2265/// 1. Latency - how long the function takes to execute
2266/// 2. Success counter - function completion count
2267///
2268/// # Example
2269///
2270/// ```rust
2271/// use hyperactor_actor::observe_async;
2272///
2273/// #[observe_async("my_module")]
2274/// async fn process_data(data: &str) -> String {
2275///     // Function implementation
2276///     // Telemetry will be automatically collected
2277/// }
2278/// ```
2279#[proc_macro_attribute]
2280pub fn observe_async(attr: TokenStream, item: TokenStream) -> TokenStream {
2281    let (input, fn_name_str, module_name_str) = match parse_observe_function(attr, item) {
2282        Ok(parsed) => parsed,
2283        Err(err) => return err.to_compile_error().into(),
2284    };
2285
2286    let fn_name = &input.sig.ident;
2287    let vis = &input.vis;
2288    let args = &input.sig.inputs;
2289    let return_type = &input.sig.output;
2290    let body = &input.block;
2291    let attrs = &input.attrs;
2292    let generics = &input.sig.generics;
2293
2294    let (latency_ident, success_ident, _, telemetry_setup) =
2295        create_telemetry_setup(&module_name_str, &fn_name_str, false);
2296
2297    let return_ident = Ident::new("ret", Span::from(proc_macro::Span::def_site()));
2298
2299    // Generate the instrumented function
2300    let expanded = quote! {
2301        #(#attrs)*
2302        #vis async fn #fn_name #generics(#args) #return_type {
2303            #telemetry_setup
2304
2305            let kv_pairs = hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone());
2306            let _timer = #latency_ident.start(kv_pairs);
2307
2308            let #return_ident = async #body.await;
2309
2310            #success_ident.add(
2311                1,
2312                hyperactor_telemetry::kv_pairs!("function" => #fn_name_str.clone())
2313            );
2314            #return_ident
2315        }
2316    };
2317
2318    expanded.into()
2319}
2320
2321/// Derive the [`hyperactor::attrs::AttrValue`] trait for a struct or enum.
2322///
2323/// This macro generates an implementation that uses the type's `ToString` and `FromStr`
2324/// implementations for the `display` and `parse` methods respectively.
2325///
2326/// The type must already implement the required super-traits:
2327/// `Named + Sized + Serialize + DeserializeOwned + Send + Sync + Clone + 'static`
2328/// as well as `ToString` and `FromStr`.
2329///
2330/// # Example
2331///
2332/// ```
2333/// #[derive(AttrValue, Named, Serialize, Deserialize, Clone)]
2334/// struct MyCustomType {
2335///     value: String,
2336/// }
2337///
2338/// impl std::fmt::Display for MyCustomType {
2339///     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
2340///         write!(f, "{}", self.value)
2341///     }
2342/// }
2343///
2344/// impl std::str::FromStr for MyCustomType {
2345///     type Err = std::io::Error;
2346///
2347///     fn from_str(s: &str) -> Result<Self, Self::Err> {
2348///         Ok(MyCustomType {
2349///             value: s.to_string(),
2350///         })
2351///     }
2352/// }
2353/// ```
2354#[proc_macro_derive(AttrValue)]
2355pub fn derive_attr_value(input: TokenStream) -> TokenStream {
2356    let input = parse_macro_input!(input as DeriveInput);
2357    let name = &input.ident;
2358
2359    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2360
2361    TokenStream::from(quote! {
2362        impl #impl_generics hyperactor::attrs::AttrValue for #name #ty_generics #where_clause {
2363            fn display(&self) -> String {
2364                self.to_string()
2365            }
2366
2367            fn parse(value: &str) -> Result<Self, anyhow::Error> {
2368                value.parse().map_err(|e| anyhow::anyhow!("failed to parse {}: {}", stringify!(#name), e))
2369            }
2370        }
2371    })
2372}