hyperactor_macros/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines macros used by the [`hyperactor`] crate.
10
11#![feature(proc_macro_def_site)]
12#![deny(missing_docs)]
13
14extern crate proc_macro;
15
16use convert_case::Case;
17use convert_case::Casing;
18use indoc::indoc;
19use proc_macro::TokenStream;
20use proc_macro2::Span;
21use quote::ToTokens;
22use quote::format_ident;
23use quote::quote;
24use syn::Attribute;
25use syn::Data;
26use syn::DataEnum;
27use syn::DataStruct;
28use syn::DeriveInput;
29use syn::Expr;
30use syn::ExprLit;
31use syn::Field;
32use syn::Fields;
33use syn::Ident;
34use syn::Index;
35use syn::ItemFn;
36use syn::ItemImpl;
37use syn::Lit;
38use syn::Meta;
39use syn::MetaNameValue;
40use syn::Token;
41use syn::Type;
42use syn::bracketed;
43use syn::parse::Parse;
44use syn::parse::ParseStream;
45use syn::parse_macro_input;
46use syn::punctuated::Punctuated;
47use syn::spanned::Spanned;
48
49const REPLY_VARIANT_ERROR: &str = indoc! {r#"
50`call` message expects a typed `OncePortRef` or `OncePortHandle` argument in the last position
51
52= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortRef<ReplyType>)`
53= help: use `MyCall(Arg1Type, Arg2Type, .., OncePortHandle<ReplyType>)`
54"#};
55
56const REPLY_USAGE_ERROR: &str = indoc! {r#"
57`call` message expects at most one `reply` argument
58
59= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
60= help: use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
61"#};
62
/// Per-field marker recorded while parsing a message variant.
///
/// The derives make the flag cheap to copy, comparable with `==`, and
/// printable in debugging output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldFlag {
    /// An ordinary message argument.
    None,
    /// The field carries a bare `#[reply]` attribute and is the reply port.
    Reply,
}
67
68/// Represents a variant of an enum.
69enum Variant {
70    /// A named variant (i.e., `MyVariant { .. }`).
71    Named {
72        enum_name: Ident,
73        name: Ident,
74        field_names: Vec<Ident>,
75        field_types: Vec<Type>,
76        field_flags: Vec<FieldFlag>,
77    },
78    /// An anonymous variant (i.e., `MyVariant(..)`).
79    Anon {
80        enum_name: Ident,
81        name: Ident,
82        field_types: Vec<Type>,
83        field_flags: Vec<FieldFlag>,
84    },
85}
86
87impl Variant {
88    /// The number of fields in the variant.
89    fn len(&self) -> usize {
90        self.field_types().len()
91    }
92
93    /// The name of the enum containing the variant.
94    fn enum_name(&self) -> &Ident {
95        match self {
96            Variant::Named { enum_name, .. } => enum_name,
97            Variant::Anon { enum_name, .. } => enum_name,
98        }
99    }
100
101    /// The name of the variant itself.
102    fn name(&self) -> &Ident {
103        match self {
104            Variant::Named { name, .. } => name,
105            Variant::Anon { name, .. } => name,
106        }
107    }
108
109    /// The snake_name of the variant itself.
110    fn snake_name(&self) -> Ident {
111        Ident::new(
112            &self.name().to_string().to_case(Case::Snake),
113            self.name().span(),
114        )
115    }
116
117    /// The variant's qualified name.
118    fn qualified_name(&self) -> proc_macro2::TokenStream {
119        let enum_name = self.enum_name();
120        let name = self.name();
121        quote! { #enum_name::#name }
122    }
123
124    /// Names of the fields in the variant. Anonymous variants are named
125    /// according to their position in the argument list.
126    fn field_names(&self) -> Vec<Ident> {
127        match self {
128            Variant::Named { field_names, .. } => field_names.clone(),
129            Variant::Anon { field_types, .. } => (0usize..field_types.len())
130                .map(|idx| format_ident!("arg{}", idx))
131                .collect(),
132        }
133    }
134
135    /// The types of the fields int the variant.
136    fn field_types(&self) -> &Vec<Type> {
137        match self {
138            Variant::Named { field_types, .. } => field_types,
139            Variant::Anon { field_types, .. } => field_types,
140        }
141    }
142
143    /// Return the field flags for this variant.
144    fn field_flags(&self) -> &Vec<FieldFlag> {
145        match self {
146            Variant::Named { field_flags, .. } => field_flags,
147            Variant::Anon { field_flags, .. } => field_flags,
148        }
149    }
150
151    /// The constructor for the variant, using the field names directly.
152    fn constructor(&self) -> proc_macro2::TokenStream {
153        let qualified_name = self.qualified_name();
154        let field_names = self.field_names();
155        match self {
156            Variant::Named { .. } => quote! { #qualified_name { #(#field_names),* } },
157            Variant::Anon { .. } => quote! { #qualified_name(#(#field_names),*) },
158        }
159    }
160}
161
162/// Represents a message that can be sent to a handler, each message is associated with
163/// a variant.
164#[allow(clippy::large_enum_variant)]
165enum Message {
166    /// A call message is a request-response message, the last argument is
167    /// a [`hyperactor::OncePortRef`] or [`hyperactor::OncePortHandle`].
168    Call {
169        variant: Variant,
170        /// Tells whether the reply argument is a handle.
171        reply_port_is_handle: bool,
172        /// The underlying return type (i.e., the type of the reply port).
173        return_type: Type,
174        /// the log level for generated instrumentation for handlers of this message.
175        log_level: Option<Ident>,
176    },
177    OneWay {
178        variant: Variant,
179        /// the log level for generated instrumentation for handlers of this message.
180        log_level: Option<Ident>,
181    },
182}
183
184impl Message {
185    fn new(span: Span, variant: Variant, log_level: Option<Ident>) -> Result<Self, syn::Error> {
186        match &variant
187            .field_flags()
188            .iter()
189            .zip(variant.field_types())
190            .filter_map(|(flag, ty)| match flag {
191                FieldFlag::Reply => Some(ty),
192                FieldFlag::None => None,
193            })
194            .collect::<Vec<&Type>>()[..]
195        {
196            [] => Ok(Self::OneWay { variant, log_level }),
197            [reply_port_ty] => {
198                let syn::Type::Path(type_path) = reply_port_ty else {
199                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
200                };
201                let Some(last_segment) = type_path.path.segments.last() else {
202                    return Err(syn::Error::new(span, REPLY_VARIANT_ERROR));
203                };
204                if last_segment.ident != "OncePortRef" && last_segment.ident != "OncePortHandle" {
205                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
206                }
207                let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments else {
208                    return Err(syn::Error::new_spanned(last_segment, REPLY_VARIANT_ERROR));
209                };
210                let Some(syn::GenericArgument::Type(return_ty)) = args.args.first() else {
211                    return Err(syn::Error::new_spanned(&args.args, REPLY_VARIANT_ERROR));
212                };
213                let reply_port_is_handle = last_segment.ident == "OncePortHandle";
214                let return_type = return_ty.clone();
215                Ok(Self::Call {
216                    variant,
217                    reply_port_is_handle,
218                    return_type,
219                    log_level,
220                })
221            }
222            _ => Err(syn::Error::new(span, REPLY_USAGE_ERROR)),
223        }
224    }
225
226    /// The arguments of this message.
227    fn args(&self) -> Vec<(Ident, Type)> {
228        match self {
229            Message::Call { variant, .. } => variant
230                .field_names()
231                .into_iter()
232                .zip(variant.field_types().clone())
233                .take(variant.len() - 1)
234                .collect(),
235            Message::OneWay { variant, .. } => variant
236                .field_names()
237                .into_iter()
238                .zip(variant.field_types().clone())
239                .collect(),
240        }
241    }
242
243    fn variant(&self) -> &Variant {
244        match self {
245            Message::Call { variant, .. } => variant,
246            Message::OneWay { variant, .. } => variant,
247        }
248    }
249
250    fn reply_port_position(&self) -> Option<usize> {
251        self.variant()
252            .field_flags()
253            .iter()
254            .position(|flag| matches!(flag, FieldFlag::Reply))
255    }
256
257    /// The reply port argument of this message.
258    fn reply_port_arg(&self) -> Option<(Ident, Type)> {
259        match self {
260            Message::Call { variant, .. } => {
261                let pos = self.reply_port_position()?;
262                Some((
263                    variant.field_names()[pos].clone(),
264                    variant.field_types()[pos].clone(),
265                ))
266            }
267            Message::OneWay { .. } => None,
268        }
269    }
270}
271
272fn parse_log_level(attrs: &[Attribute]) -> Result<Option<Ident>, syn::Error> {
273    let level: Option<String> = match attrs.iter().find(|attr| attr.path().is_ident("log_level")) {
274        Some(attr) => {
275            let Ok(meta) = attr.meta.require_list() else {
276                return Err(syn::Error::new(
277                    Span::call_site(),
278                    indoc! {"
279                            `log_level` attribute must specify level. Supported levels = error, warn, info, debug, trace
280
281                            = help use `#[log_level(info)]` or `#[log_level(error)]`
282                        "},
283                ));
284            };
285            let parsed = meta.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?;
286            if parsed.len() != 1 {
287                return Err(syn::Error::new(
288                    Span::call_site(),
289                    indoc! {"
290                            `log_level` attribute must specify exactly one level
291
292                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
293                        "},
294                ));
295            };
296            Some(parsed.first().unwrap().to_string())
297        }
298        None => None,
299    };
300
301    if level.is_none() {
302        return Ok(None);
303    }
304    let level = level.unwrap();
305
306    match level.as_str() {
307        "error" | "warn" | "info" | "debug" | "trace" => {}
308        _ => {
309            return Err(syn::Error::new(
310                Span::call_site(),
311                indoc! {"
312                            `log_level` attribute must be one of 'error, warn, info, debug, trace'
313
314                            = help use `#[log_level(warn)]` or `#[log_level(info)]`
315                        "},
316            ));
317        }
318    }
319
320    Ok(Some(Ident::new(
321        level.to_ascii_uppercase().as_str(),
322        Span::call_site(),
323    )))
324}
325
326fn parse_field_flag(field: &Field) -> FieldFlag {
327    for attr in field.attrs.iter() {
328        match &attr.meta {
329            syn::Meta::Path(path) if path.is_ident("reply") => return FieldFlag::Reply,
330            _ => {}
331        }
332    }
333    FieldFlag::None
334}
335
336/// Parse a message enum into its constituent messages.
337fn parse_message_enum(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
338    let variants = if let Data::Enum(data_enum) = &input.data {
339        &data_enum.variants
340    } else {
341        return Err(syn::Error::new_spanned(
342            input,
343            "handlers can only be derived for enums",
344        ));
345    };
346
347    let mut messages = Vec::new();
348
349    for variant in variants {
350        let name = variant.ident.clone();
351        let attrs = &variant.attrs;
352
353        let message_variant = match &variant.fields {
354            syn::Fields::Unnamed(fields_) => Variant::Anon {
355                enum_name: input.ident.clone(),
356                name,
357                field_types: fields_
358                    .unnamed
359                    .iter()
360                    .map(|field| field.ty.clone())
361                    .collect(),
362                field_flags: fields_.unnamed.iter().map(parse_field_flag).collect(),
363            },
364            syn::Fields::Named(fields_) => Variant::Named {
365                enum_name: input.ident.clone(),
366                name,
367                field_names: fields_
368                    .named
369                    .iter()
370                    .map(|field| field.ident.clone().unwrap())
371                    .collect(),
372                field_types: fields_.named.iter().map(|field| field.ty.clone()).collect(),
373                field_flags: fields_.named.iter().map(parse_field_flag).collect(),
374            },
375            _ => {
376                return Err(syn::Error::new_spanned(
377                    variant,
378                    indoc! {r#"
379                      `Handler` currently only supports named or tuple struct variants
380
381                      = help use `MyCall(Arg1Type, Arg2Type, ..)`,
382                      = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .. }`,
383                      = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortRef<ReplyType>)`
384                      = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortRef<ReplyType>)`
385                      = help use `MyCall(Arg1Type, Arg2Type, .., #[reply] OncePortHandle<ReplyType>)`
386                      = help use `MyCall { arg1: Arg1Type, arg2: Arg2Type, .., reply: #[reply] OncePortHandle<ReplyType>)`
387                    "#},
388                ));
389            }
390        };
391        let log_level = parse_log_level(attrs)?;
392
393        messages.push(Message::new(
394            variant.fields.span(),
395            message_variant,
396            log_level,
397        )?);
398    }
399
400    Ok(messages)
401}
402
/// Derive a custom handler trait for an enum whose variants are tuple or
/// struct variants. The handler trait defines a method corresponding
/// to each of the enum's variants, and a `handle` function
/// that dispatches messages to the correct method. The macro
/// supports two messaging patterns: "call" and "oneway". A call is a
/// request-response message; a `#[reply]`-annotated
/// [`hyperactor::mailbox::OncePortRef`] or [`hyperactor::mailbox::OncePortHandle`]
/// in the last position is used to send the return value.
411///
/// The macro also derives a client trait that can be automatically implemented
/// by specifying [`HandleClient`] for `ActorHandle<Actor>` and [`RefClient`]
/// for `ActorRef<Actor>` accordingly. Two implementations are required because
/// only `ActorRef`s require the message type to be serializable.
416///
417/// The associated [`hyperactor_macros::handler`] macro can be used to add
418/// a dispatching handler directly to an [`hyperactor::Actor`].
419///
420/// # Example
421///
422/// The following example creates a "shopping list" actor responsible for
423/// maintaining a shopping list.
424///
425/// ```
426/// use std::collections::HashSet;
427/// use std::time::Duration;
428///
429/// use async_trait::async_trait;
/// use hyperactor::Actor;
/// use hyperactor::Context;
431/// use hyperactor::HandleClient;
432/// use hyperactor::Handler;
433/// use hyperactor::Instance;
434/// use hyperactor::Named;
435/// use hyperactor::OncePortRef;
436/// use hyperactor::RefClient;
437/// use hyperactor::proc::Proc;
438/// use serde::Deserialize;
439/// use serde::Serialize;
440///
441/// #[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
442/// enum ShoppingList {
443///     // Oneway messages dispatch messages asynchronously, with no reply.
444///     Add(String),
445///     Remove(String),
446///
447///     // Call messages dispatch a request, expecting a reply to the
448///     // provided port, which must be in the last position.
449///     Exists(String, #[reply] OncePortRef<bool>),
450///
451///     List(#[reply] OncePortRef<Vec<String>>),
452/// }
453///
454/// // Define an actor.
455/// #[derive(Debug)]
456/// #[hyperactor::export(
457///     spawn = true,
458///     handlers = [
459///         ShoppingList,
460///     ],
461/// )]
462/// struct ShoppingListActor(HashSet<String>);
463///
464/// #[async_trait]
465/// impl Actor for ShoppingListActor {
466///     type Params = ();
467///
468///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
469///         Ok(Self(HashSet::new()))
470///     }
471/// }
472///
473/// // ShoppingListHandler is the trait generated by derive(Handler) above.
474/// // We implement the trait here for the actor, defining a handler for
475/// // each ShoppingList message.
476/// //
477/// // The `forward` attribute installs a handler that forwards messages
478/// // to the `ShoppingListHandler` implementation directly. This can also
479/// // be done manually:
480/// //
481/// // ```ignore
482/// //<ShoppingListActor as ShoppingListHandler>
483/// //     ::handle(self, comm, message).await
484/// // ```
485/// #[async_trait]
486/// #[hyperactor::forward(ShoppingList)]
487/// impl ShoppingListHandler for ShoppingListActor {
488///     async fn add(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
489///         eprintln!("insert {}", item);
490///         self.0.insert(item);
491///         Ok(())
492///     }
493///
494///     async fn remove(&mut self, _cx: &Context<Self>, item: String) -> Result<(), anyhow::Error> {
495///         eprintln!("remove {}", item);
496///         self.0.remove(&item);
497///         Ok(())
498///     }
499///
500///     async fn exists(
501///         &mut self,
502///         _cx: &Context<Self>,
503///         item: String,
504///     ) -> Result<bool, anyhow::Error> {
505///         Ok(self.0.contains(&item))
506///     }
507///
508///     async fn list(&mut self, _cx: &Context<Self>) -> Result<Vec<String>, anyhow::Error> {
509///         Ok(self.0.iter().cloned().collect())
510///     }
511/// }
512///
513/// #[tokio::main]
514/// async fn main() -> Result<(), anyhow::Error> {
515///     let mut proc = Proc::local();
516///
517///     // Spawn our actor, and get a handle for rank 0.
518///     let shopping_list_actor: hyperactor::ActorHandle<ShoppingListActor> =
519///         proc.spawn("shopping", ()).await?;
520///
521///     // We join the system, so that we can send messages to actors.
522///     let client = proc.attach("client").unwrap();
523///
524///     // todo: consider making this a macro to remove the magic names
525///
526///     // Derive(Handler) generates client methods, which call the
527///     // remote handler provided an instance (send + open capability),
528///     // the destination actor, and the method arguments.
529///
530///     shopping_list_actor.add(&client, "milk".into()).await?;
531///     shopping_list_actor.add(&client, "eggs".into()).await?;
532///
533///     println!(
534///         "got milk? {}",
535///         shopping_list_actor.exists(&client, "milk".into()).await?
536///     );
537///     println!(
538///         "got yoghurt? {}",
539///         shopping_list_actor
540///             .exists(&client, "yoghurt".into())
541///             .await?
542///     );
543///
544///     shopping_list_actor.remove(&client, "milk".into()).await?;
545///     println!(
546///         "got milk now? {}",
547///         shopping_list_actor.exists(&client, "milk".into()).await?
548///     );
549///
550///     println!(
551///         "shopping list: {:?}",
552///         shopping_list_actor.list(&client).await?
553///     );
554///
555///     let _ = proc.destroy_and_wait(Duration::from_secs(1), None).await?;
556///     Ok(())
557/// }
558/// ```
559#[proc_macro_derive(Handler, attributes(reply))]
560pub fn derive_handler(input: TokenStream) -> TokenStream {
561    let input = parse_macro_input!(input as DeriveInput);
562    let name: Ident = input.ident.clone();
563    let (impl_generics, ty_generics, _) = input.generics.split_for_impl();
564
565    let messages = match parse_message_enum(input.clone()) {
566        Ok(messages) => messages,
567        Err(err) => return TokenStream::from(err.to_compile_error()),
568    };
569
570    // Trait definition methods for the handler.
571    let mut handler_trait_methods = Vec::new();
572
573    // The arms of the match used in the message dispatcher.
574    let mut match_arms = Vec::new();
575
576    // Trait implemented by clients.
577    let mut client_trait_methods = Vec::new();
578
579    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
580
581    for message in messages {
582        match message {
583            Message::Call {
584                ref variant,
585                ref reply_port_is_handle,
586                ref return_type,
587                ref log_level,
588            } => {
589                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
590                let variant_name_snake = variant.snake_name();
591                let enum_name = variant.enum_name();
592                let _variant_qualified_name = variant.qualified_name();
593                let log_level = match (&global_log_level, log_level) {
594                    (_, Some(local)) => local.clone(),
595                    (Some(global), None) => global.clone(),
596                    _ => Ident::new("DEBUG", Span::call_site()),
597                };
598                let _log_level = if *reply_port_is_handle {
599                    quote! {
600                        tracing::Level::#log_level
601                    }
602                } else {
603                    quote! {
604                        tracing::Level::TRACE
605                    }
606                };
607                let log_message = quote! {
608                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
609                            "rpc" => "call",
610                            "actor_id" => cx.self_id().to_string(),
611                            "message_type" => stringify!(#enum_name),
612                            "variant" => stringify!(#variant_name_snake),
613                        ));
614                };
615
616                handler_trait_methods.push(quote! {
617                    #[doc = "The generated handler method for this enum variant."]
618                    async fn #variant_name_snake(
619                        &mut self,
620                        cx: &hyperactor::Context<Self>,
621                        #(#arg_names: #arg_types),*)
622                        -> Result<#return_type, hyperactor::anyhow::Error>;
623                });
624
625                client_trait_methods.push(quote! {
626                    #[doc = "The generated client method for this enum variant."]
627                    async fn #variant_name_snake(
628                        &self,
629                        caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
630                        #(#arg_names: #arg_types),*)
631                        -> Result<#return_type, hyperactor::anyhow::Error>;
632                });
633
634                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
635                let constructor = variant.constructor();
636                let result_ident = Ident::new("result", Span::mixed_site());
637                let construct_result_future = quote! { use hyperactor::Message; let #result_ident = self.#variant_name_snake(cx, #(#arg_names),*).await?; };
638                if *reply_port_is_handle {
639                    match_arms.push(quote! {
640                        #constructor => {
641                            #log_message
642                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
643                            // This would require Result<Result<..., in order to handle RPC errors.
644                            #construct_result_future
645                            #reply_port_arg.send(#result_ident).map_err(hyperactor::anyhow::Error::from)
646                        }
647                    });
648                } else {
649                    match_arms.push(quote! {
650                        #constructor => {
651                            #log_message
652                            // TODO: should we propagate this error (to supervision), or send it back as an "RPC error"?
653                            // This would require Result<Result<..., in order to handle RPC errors.
654                            #construct_result_future
655                            #reply_port_arg.send(cx, #result_ident).map_err(hyperactor::anyhow::Error::from)
656                        }
657                    });
658                }
659            }
660            Message::OneWay {
661                ref variant,
662                ref log_level,
663            } => {
664                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
665                let variant_name_snake = variant.snake_name();
666                let enum_name = variant.enum_name();
667                let log_level = match (&global_log_level, log_level) {
668                    (_, Some(local)) => local.clone(),
669                    (Some(global), None) => global.clone(),
670                    _ => Ident::new("TRACE", Span::call_site()),
671                };
672                let _log_level = quote! {
673                    tracing::Level::#log_level
674                };
675                let log_message = quote! {
676                        hyperactor::metrics::ACTOR_MESSAGES_RECEIVED.add(1, hyperactor::kv_pairs!(
677                            "rpc" => "call",
678                            "actor_id" => cx.self_id().to_string(),
679                            "message_type" => stringify!(#enum_name),
680                            "variant" => stringify!(#variant_name_snake),
681                        ));
682                };
683
684                handler_trait_methods.push(quote! {
685                    #[doc = "The generated handler method for this enum variant."]
686                    async fn #variant_name_snake(
687                        &mut self,
688                        cx: &hyperactor::Context<Self>,
689                        #(#arg_names: #arg_types),*)
690                        -> Result<(), hyperactor::anyhow::Error>;
691                });
692
693                client_trait_methods.push(quote! {
694                    #[doc = "The generated client method for this enum variant."]
695                    async fn #variant_name_snake(
696                        &self,
697                        caps: &impl hyperactor::cap::CanSend,
698                        #(#arg_names: #arg_types),*)
699                        -> Result<(), hyperactor::anyhow::Error>;
700                });
701
702                let constructor = variant.constructor();
703
704                match_arms.push(quote! {
705                    #constructor => {
706                        #log_message
707                        self.#variant_name_snake(cx, #(#arg_names),*).await
708                    },
709                });
710            }
711        }
712    }
713
714    let handler_trait_name = format_ident!("{}Handler", name);
715    let client_trait_name = format_ident!("{}Client", name);
716
717    let expanded = quote! {
718        #[doc = "The custom handler trait for this message type."]
719        #[hyperactor::async_trait::async_trait]
720        pub trait #handler_trait_name #impl_generics: hyperactor::Actor + Send + Sync  {
721            #(#handler_trait_methods)*
722
723            #[doc = "Handle the next message."]
724            async fn handle(
725                &mut self,
726                cx: &hyperactor::Context<Self>,
727                message: #name #ty_generics,
728            ) -> hyperactor::anyhow::Result<()>  {
729                 // Dispatch based on message type.
730                 match message {
731                     #(#match_arms)*
732                }
733            }
734        }
735
736        #[doc = "The custom client trait for this message type."]
737        #[hyperactor::async_trait::async_trait]
738        pub trait #client_trait_name #impl_generics: Send + Sync  {
739            #(#client_trait_methods)*
740        }
741    };
742
743    TokenStream::from(expanded)
744}
745
746/// Derives a client implementation on `ActorHandle<Actor>`.
747/// See [`Handler`] documentation for details.
748#[proc_macro_derive(HandleClient, attributes(log_level))]
749pub fn derive_handle_client(input: TokenStream) -> TokenStream {
750    derive_client(input, true)
751}
752
753/// Derives a client implementation on `ActorRef<Actor>`.
754/// See [`Handler`] documentation for details.
755#[proc_macro_derive(RefClient, attributes(log_level))]
756pub fn derive_ref_client(input: TokenStream) -> TokenStream {
757    derive_client(input, false)
758}
759
760fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
761    let input = parse_macro_input!(input as DeriveInput);
762    let name = input.ident.clone();
763
764    let messages = match parse_message_enum(input.clone()) {
765        Ok(messages) => messages,
766        Err(err) => return TokenStream::from(err.to_compile_error()),
767    };
768
769    // The client implementation methods.
770    let mut impl_methods = Vec::new();
771
772    let send_message = if is_handle {
773        quote! { self.send(message)? }
774    } else {
775        quote! { self.send(caps, message)? }
776    };
777    let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None);
778
779    for message in messages {
780        match message {
781            Message::Call {
782                ref variant,
783                ref reply_port_is_handle,
784                ref return_type,
785                ref log_level,
786            } => {
787                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
788                let variant_name_snake = variant.snake_name();
789                let enum_name = variant.enum_name();
790
791                let (reply_port_arg, _) = message.reply_port_arg().unwrap();
792                let constructor = variant.constructor();
793                let log_level = match (&global_log_level, log_level) {
794                    (_, Some(local)) => local.clone(),
795                    (Some(global), None) => global.clone(),
796                    _ => Ident::new("DEBUG", Span::call_site()),
797                };
798                let log_level = if is_handle {
799                    quote! {
800                        tracing::Level::#log_level
801                    }
802                } else {
803                    quote! {
804                        tracing::Level::TRACE
805                    }
806                };
807                let log_message = quote! {
808                        hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
809                            "rpc" => "call",
810                            "actor_id" => self.actor_id().to_string(),
811                            "message_type" => stringify!(#enum_name),
812                            "variant" => stringify!(#variant_name_snake),
813                        ));
814
815                };
816                if *reply_port_is_handle {
817                    impl_methods.push(quote! {
818                        #[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
819                        async fn #variant_name_snake(
820                            &self,
821                            caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
822                            #(#arg_names: #arg_types),*)
823                            -> Result<#return_type, hyperactor::anyhow::Error> {
824                            let (#reply_port_arg, reply_receiver) =
825                                hyperactor::mailbox::open_once_port::<#return_type>(caps);
826                            let message = #constructor;
827                            #log_message;
828                            #send_message;
829                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
830                        }
831                    });
832                } else {
833                    impl_methods.push(quote! {
834                        #[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
835                        async fn #variant_name_snake(
836                            &self,
837                            caps: &(impl hyperactor::cap::CanSend + hyperactor::cap::CanOpenPort),
838                            #(#arg_names: #arg_types),*)
839                            -> Result<#return_type, hyperactor::anyhow::Error> {
840                            let (#reply_port_arg, reply_receiver) =
841                                hyperactor::mailbox::open_once_port::<#return_type>(caps);
842                            let #reply_port_arg = #reply_port_arg.bind();
843                            let message = #constructor;
844                            #log_message;
845                            #send_message;
846                            reply_receiver.recv().await.map_err(hyperactor::anyhow::Error::from)
847                        }
848                    });
849                }
850            }
851            Message::OneWay {
852                ref variant,
853                ref log_level,
854            } => {
855                let (arg_names, arg_types): (Vec<_>, Vec<_>) = message.args().into_iter().unzip();
856                let variant_name_snake = variant.snake_name();
857                let enum_name = variant.enum_name();
858                let constructor = variant.constructor();
859                let log_level = match (&global_log_level, log_level) {
860                    (_, Some(local)) => local.clone(),
861                    (Some(global), None) => global.clone(),
862                    _ => Ident::new("DEBUG", Span::call_site()),
863                };
864                let _log_level = if is_handle {
865                    quote! {
866                        tracing::Level::TRACE
867                    }
868                } else {
869                    quote! {
870                        tracing::Level::#log_level
871                    }
872                };
873                let log_message = quote! {
874                    hyperactor::metrics::ACTOR_MESSAGES_SENT.add(1, hyperactor::kv_pairs!(
875                        "rpc" => "oneway",
876                        "actor_id" => self.actor_id().to_string(),
877                        "message_type" => stringify!(#enum_name),
878                        "variant" => stringify!(#variant_name_snake),
879                    ));
880                };
881                impl_methods.push(quote! {
882                    async fn #variant_name_snake(
883                        &self,
884                        caps: &impl hyperactor::cap::CanSend,
885                        #(#arg_names: #arg_types),*)
886                        -> Result<(), hyperactor::anyhow::Error> {
887                        let message = #constructor;
888                        #log_message;
889                        #send_message;
890                        Ok(())
891                    }
892                });
893            }
894        }
895    }
896
897    let trait_name = format_ident!("{}Client", name);
898
899    let (_, ty_generics, _) = input.generics.split_for_impl();
900
901    // Add a new generic parameter 'A'
902    let a_ident = Ident::new("A", proc_macro2::Span::from(proc_macro::Span::def_site()));
903    let mut trait_generics = input.generics.clone();
904    trait_generics.params.insert(
905        0,
906        syn::GenericParam::Type(syn::TypeParam {
907            ident: a_ident.clone(),
908            attrs: vec![],
909            colon_token: None,
910            bounds: Punctuated::new(),
911            eq_token: None,
912            default: None,
913        }),
914    );
915    let (impl_generics, _, _) = trait_generics.split_for_impl();
916
917    let expanded = if is_handle {
918        quote! {
919            #[hyperactor::async_trait::async_trait]
920            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorHandle<#a_ident>
921              where #a_ident: hyperactor::Handler<#name #ty_generics> {
922                #(#impl_methods)*
923            }
924        }
925    } else {
926        quote! {
927            #[hyperactor::async_trait::async_trait]
928            impl #impl_generics #trait_name #ty_generics for hyperactor::ActorRef<#a_ident>
929              where #a_ident: hyperactor::actor::RemoteHandles<#name #ty_generics> {
930                #(#impl_methods)*
931            }
932        }
933    };
934
935    TokenStream::from(expanded)
936}
937
938const FORWARD_ARGUMENT_ERROR: &str = indoc! {r#"
939`forward` expects the message type that is being forwarded
940
941= help: use `#[forward(MessageType)]`
942"#};
943
944/// Forward messages of the provided type to this handler implementation.
945#[proc_macro_attribute]
946pub fn forward(attr: TokenStream, item: TokenStream) -> TokenStream {
947    let attr_args = parse_macro_input!(attr with Punctuated::<syn::PathSegment, syn::Token![,]>::parse_terminated);
948    if attr_args.len() != 1 {
949        return TokenStream::from(
950            syn::Error::new_spanned(attr_args, FORWARD_ARGUMENT_ERROR).to_compile_error(),
951        );
952    }
953
954    let message_type = attr_args.first().unwrap();
955    let input = parse_macro_input!(item as ItemImpl);
956
957    let self_type = match *input.self_ty {
958        syn::Type::Path(ref type_path) => {
959            let segment = type_path.path.segments.last().unwrap();
960            segment.clone() //ident.clone()
961        }
962        _ => {
963            return TokenStream::from(
964                syn::Error::new_spanned(input.self_ty, "`forward` argument must be a type")
965                    .to_compile_error(),
966            );
967        }
968    };
969
970    let trait_name = match input.trait_ {
971        Some((_, ref trait_path, _)) => trait_path.segments.last().unwrap().clone(),
972        None => {
973            return TokenStream::from(
974                syn::Error::new_spanned(input.self_ty, "no trait in implementation block")
975                    .to_compile_error(),
976            );
977        }
978    };
979
980    let expanded = quote! {
981        #input
982
983        #[hyperactor::async_trait::async_trait]
984        impl hyperactor::Handler<#message_type> for #self_type {
985            async fn handle(
986                &mut self,
987                cx: &hyperactor::Context<Self>,
988                message: #message_type,
989            ) -> hyperactor::anyhow::Result<()> {
990                <Self as #trait_name>::handle(self, cx, message).await
991            }
992        }
993    };
994
995    TokenStream::from(expanded)
996}
997
998/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
999/// We set a default level of INFO while always setting ERROR if the function returns Result::Err giving us
1000/// consistent and high quality structured logs. Because this wraps around tracing::instrument, all parameters
1001/// mentioned in https://fburl.com/9jlkb5q4 should be valid. For functions that don't return a [`Result`] type, use
1002/// [`instrument_infallible`]
1003///
1004/// ```
1005/// #[telemetry::instrument]
1006/// async fn yolo() -> anyhow::Result<i32> {
1007///     Ok(420)
1008/// }
1009/// ```
1010#[proc_macro_attribute]
1011pub fn instrument(args: TokenStream, input: TokenStream) -> TokenStream {
1012    let args =
1013        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1014    let input = parse_macro_input!(input as ItemFn);
1015    let output = quote! {
1016        #[hyperactor::tracing::instrument(err, skip_all, #args)]
1017        #input
1018    };
1019
1020    TokenStream::from(output)
1021}
1022
1023/// Use this macro in place of tracing::instrument to prevent spamming our tracing table.
1024/// Because this wraps around tracing::instrument, all parameters mentioned in
1025/// https://fburl.com/9jlkb5q4 should be valid.
1026///
1027/// ```
1028/// #[telemetry::instrument]
1029/// async fn yolo() -> i32 {
1030///     420
1031/// }
1032/// ```
1033#[proc_macro_attribute]
1034pub fn instrument_infallible(args: TokenStream, input: TokenStream) -> TokenStream {
1035    let args =
1036        parse_macro_input!(args with Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated);
1037    let input = parse_macro_input!(input as ItemFn);
1038
1039    let output = quote! {
1040        #[hyperactor::tracing::instrument(skip_all, #args)]
1041        #input
1042    };
1043
1044    TokenStream::from(output)
1045}
1046
1047/// Derive the [`hyperactor::data::Named`] trait for a struct with the
1048/// provided type URI. The name of the type is its fully-qualified Rust
1049/// path. The name may be overridden by providing a string value for the
1050/// `name` attribute.
1051#[proc_macro_derive(Named, attributes(named))]
1052pub fn named_derive(input: TokenStream) -> TokenStream {
1053    // Parse the input struct or enum
1054    let input = parse_macro_input!(input as DeriveInput);
1055    let struct_name = &input.ident;
1056
1057    let mut typename = quote! {
1058        concat!(std::module_path!(), "::", stringify!(#struct_name))
1059    };
1060
1061    for attr in &input.attrs {
1062        if attr.path().is_ident("named") {
1063            if let Ok(meta) = attr.parse_args_with(
1064                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
1065            ) {
1066                for item in meta {
1067                    if let Meta::NameValue(MetaNameValue {
1068                        path,
1069                        value: Expr::Lit(expr_lit),
1070                        ..
1071                    }) = item
1072                    {
1073                        if path.is_ident("name") {
1074                            if let Lit::Str(name) = expr_lit.lit {
1075                                typename = quote! { #name };
1076                            }
1077                        } else {
1078                            return TokenStream::from(
1079                                syn::Error::new_spanned(
1080                                    path,
1081                                    "unsupported attribute (only `name` is supported)",
1082                                )
1083                                .to_compile_error(),
1084                            );
1085                        }
1086                    }
1087                }
1088            }
1089        }
1090    }
1091
1092    // Extract type parameters and add Named bounds
1093    let type_params: Vec<_> = input.generics.type_params().collect();
1094    let has_generics = !type_params.is_empty();
1095
1096    // Create a version of generics with Named bounds for the impl block
1097    let mut generics_with_bounds = input.generics.clone();
1098    if has_generics {
1099        for param in generics_with_bounds.type_params_mut() {
1100            param
1101                .bounds
1102                .push(syn::parse_quote!(hyperactor::data::Named));
1103        }
1104    }
1105    let (impl_generics_with_bounds, _, _) = generics_with_bounds.split_for_impl();
1106
1107    // Generate typename implementation based on whether we have generics
1108    let (typename_impl, typehash_impl) = if has_generics {
1109        // Create format string with placeholders for each generic parameter
1110        let placeholders = vec!["{}"; type_params.len()].join(", ");
1111        let placeholders_format_string = format!("<{}>", placeholders);
1112        let format_string = quote! { concat!(std::module_path!(), "::", stringify!(#struct_name), #placeholders_format_string) };
1113
1114        let type_param_idents: Vec<_> = type_params.iter().map(|p| &p.ident).collect();
1115        (
1116            quote! {
1117                hyperactor::data::intern_typename!(Self, #format_string, #(#type_param_idents),*)
1118            },
1119            quote! {
1120                hyperactor::cityhasher::hash(Self::typename())
1121            },
1122        )
1123    } else {
1124        (
1125            typename,
1126            quote! {
1127                static TYPEHASH: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
1128                    hyperactor::cityhasher::hash(<#struct_name as hyperactor::data::Named>::typename())
1129                });
1130                *TYPEHASH
1131            },
1132        )
1133    };
1134
1135    // Generate 'arm' for enums only.
1136    let arm_impl = match &input.data {
1137        Data::Enum(DataEnum { variants, .. }) => {
1138            let match_arms = variants.iter().map(|v| {
1139                let variant_name = &v.ident;
1140                let variant_str = variant_name.to_string();
1141                match &v.fields {
1142                    Fields::Unit => quote! { Self::#variant_name => Some(#variant_str) },
1143                    Fields::Unnamed(_) => quote! { Self::#variant_name(..) => Some(#variant_str) },
1144                    Fields::Named(_) => quote! { Self::#variant_name { .. } => Some(#variant_str) },
1145                }
1146            });
1147            quote! {
1148                fn arm(&self) -> Option<&'static str> {
1149                    match self {
1150                        #(#match_arms,)*
1151                    }
1152                }
1153            }
1154        }
1155        _ => quote! {},
1156    };
1157
1158    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
1159    // Ideally we would compute the has directly in the macro itself, however, we don't
1160    // have access to the fully expanded pathname here as we use the intrinsic std::module_path!() macro.
1161    let expanded = quote! {
1162        impl #impl_generics_with_bounds hyperactor::data::Named for #struct_name #ty_generics #where_clause {
1163            fn typename() -> &'static str { #typename_impl }
1164            fn typehash() -> u64 { #typehash_impl }
1165            #arm_impl
1166        }
1167    };
1168
1169    TokenStream::from(expanded)
1170}
1171
1172struct HandlerSpec {
1173    ty: Type,
1174    cast: bool,
1175}
1176
1177impl Parse for HandlerSpec {
1178    fn parse(input: ParseStream) -> syn::Result<Self> {
1179        let ty: Type = input.parse()?;
1180
1181        if input.peek(syn::token::Brace) {
1182            let content;
1183            syn::braced!(content in input);
1184            let key: Ident = content.parse()?;
1185            content.parse::<Token![=]>()?;
1186            let expr: Expr = content.parse()?;
1187
1188            let cast = if key == "cast" {
1189                if let Expr::Lit(ExprLit {
1190                    lit: Lit::Bool(b), ..
1191                }) = expr
1192                {
1193                    b.value
1194                } else {
1195                    return Err(syn::Error::new_spanned(expr, "expected boolean for `cast`"));
1196                }
1197            } else {
1198                return Err(syn::Error::new_spanned(
1199                    key,
1200                    "unsupported field (expected `cast`)",
1201                ));
1202            };
1203
1204            Ok(HandlerSpec { ty, cast })
1205        } else if input.is_empty() || input.peek(Token![,]) {
1206            Ok(HandlerSpec { ty, cast: false })
1207        } else {
1208            // Something unexpected follows the type
1209            let unexpected: proc_macro2::TokenTree = input.parse()?;
1210            Err(syn::Error::new_spanned(
1211                unexpected,
1212                "unexpected token after type — expected `{ ... }` or nothing",
1213            ))
1214        }
1215    }
1216}
1217
1218impl HandlerSpec {
1219    fn add_indexed(handlers: Vec<HandlerSpec>) -> Vec<Type> {
1220        let mut tys = Vec::new();
1221        for HandlerSpec { ty, cast } in handlers {
1222            if cast {
1223                let wrapped = quote! { hyperactor::message::IndexedErasedUnbound<#ty> };
1224                let wrapped_ty: Type = syn::parse2(wrapped).unwrap();
1225                tys.push(wrapped_ty);
1226            }
1227            tys.push(ty);
1228        }
1229        tys
1230    }
1231}
1232
1233/// Attribute Struct for [`fn export`] macro.
1234struct ExportAttr {
1235    spawn: bool,
1236    handlers: Vec<HandlerSpec>,
1237}
1238
1239impl Parse for ExportAttr {
1240    fn parse(input: ParseStream) -> syn::Result<Self> {
1241        let mut spawn = false;
1242        let mut handlers: Vec<HandlerSpec> = vec![];
1243
1244        while !input.is_empty() {
1245            let key: Ident = input.parse()?;
1246            input.parse::<Token![=]>()?;
1247
1248            if key == "spawn" {
1249                let expr: Expr = input.parse()?;
1250                if let Expr::Lit(ExprLit {
1251                    lit: Lit::Bool(b), ..
1252                }) = expr
1253                {
1254                    spawn = b.value;
1255                } else {
1256                    return Err(syn::Error::new_spanned(
1257                        expr,
1258                        "expected boolean for `spawn`",
1259                    ));
1260                }
1261            } else if key == "handlers" {
1262                let content;
1263                bracketed!(content in input);
1264                let raw_handlers = content.parse_terminated(HandlerSpec::parse, Token![,])?;
1265                handlers = raw_handlers.into_iter().collect();
1266            } else {
1267                return Err(syn::Error::new_spanned(
1268                    key,
1269                    "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
1270                ));
1271            }
1272
1273            // optional trailing comma
1274            let _ = input.parse::<Token![,]>();
1275        }
1276
1277        Ok(ExportAttr { spawn, handlers })
1278    }
1279}
1280
1281/// Exports handlers for this actor. The set of exported handlers
1282/// determine the messages that may be sent to remote references of
1283/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
1284/// [`hyperactor::RemoteMessage`] may be exported.
1285///
1286/// Additionally, an exported actor may be remotely spawned,
1287/// indicated by `spawn = true`. Such actors must also ensure that
1288/// their parameter type implements [`hyperactor::RemoteMessage`].
1289///
1290/// # Example
1291///
1292/// In the following example, `MyActor` can be spawned remotely. It also has
1293/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`.
1294/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these
1295/// types.
1296///
1297/// ```ignore
1298/// #[export(
1299///     spawn = true,
1300///     handlers = [
1301///         MyMessage,
1302///         MyOtherMessage,
1303///     ],
1304/// )]
1305/// struct MyActor {}
1306/// ```
1307#[proc_macro_attribute]
1308pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
1309    let input: DeriveInput = parse_macro_input!(item as DeriveInput);
1310    let data_type_name = &input.ident;
1311
1312    let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
1313    let tys = HandlerSpec::add_indexed(handlers);
1314
1315    let mut handles = Vec::new();
1316    let mut bindings = Vec::new();
1317
1318    for ty in &tys {
1319        handles.push(quote! {
1320            impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
1321        });
1322        bindings.push(quote! {
1323            ports.bind::<#ty>();
1324        });
1325    }
1326
1327    let mut expanded = quote! {
1328        #input
1329
1330        impl hyperactor::actor::RemoteActor for #data_type_name {}
1331
1332        #(#handles)*
1333
1334        // Always export the `Signal` type.
1335        impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
1336
1337        impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
1338            fn bind(ports: &hyperactor::proc::Ports<Self>) {
1339                #(#bindings)*
1340            }
1341        }
1342
1343        // TODO: just use Named derive directly here.
1344        impl hyperactor::data::Named for #data_type_name {
1345            fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
1346        }
1347    };
1348
1349    if spawn {
1350        expanded.extend(quote! {
1351
1352            hyperactor::remote!(#data_type_name);
1353        });
1354    }
1355
1356    TokenStream::from(expanded)
1357}
1358
1359/// Represents the full input to [`fn alias`].
1360struct AliasInput {
1361    alias: Ident,
1362    handlers: Vec<HandlerSpec>,
1363}
1364
1365impl syn::parse::Parse for AliasInput {
1366    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1367        let alias: Ident = input.parse()?;
1368        let _: Token![,] = input.parse()?;
1369        let raw_handlers = input.parse_terminated(HandlerSpec::parse, Token![,])?;
1370        let handlers = raw_handlers.into_iter().collect();
1371        Ok(AliasInput { alias, handlers })
1372    }
1373}
1374
1375/// Create a [`RemoteActor`] handling a specific set of message types.
1376/// This is used to create an [`ActorRef`] without having to depend on the
1377/// actor's implementation. If the message type need to be cast, add `castable`
1378/// flag to those types. e.g. the following example creats an alias with 5
1379/// message types, and 4 of which need to be cast.
1380///
1381/// ```
1382/// hyperactor::alias!(
1383///     TestActorAlias,
1384///     TestMessage { castable = true },
1385///     () {castable = true },
1386///     MyGeneric<()> {castable = true },
1387///     u64,
1388/// );
1389/// ```
1390#[proc_macro]
1391pub fn alias(input: TokenStream) -> TokenStream {
1392    let AliasInput { alias, handlers } = parse_macro_input!(input as AliasInput);
1393    let tys = HandlerSpec::add_indexed(handlers);
1394
1395    let expanded = quote! {
1396        #[doc = "The generated alias struct."]
1397        #[derive(Debug, Named)]
1398        pub struct #alias;
1399        impl hyperactor::actor::RemoteActor for #alias {}
1400
1401        impl<A> hyperactor::actor::Binds<A> for #alias
1402        where
1403            A: hyperactor::Actor #(+ hyperactor::Handler<#tys>)* {
1404            fn bind(ports: &hyperactor::proc::Ports<A>) {
1405                #(
1406                    ports.bind::<#tys>();
1407                )*
1408            }
1409        }
1410
1411        #(
1412            impl hyperactor::actor::RemoteHandles<#tys> for #alias {}
1413        )*
1414    };
1415
1416    TokenStream::from(expanded)
1417}
1418
1419fn include_in_bind_unbind(field: &Field) -> syn::Result<bool> {
1420    let mut is_included = false;
1421    for attr in &field.attrs {
1422        if attr.path().is_ident("binding") {
1423            // parse #[binding(include)] and look for exactly "include"
1424            attr.parse_nested_meta(|meta| {
1425                if meta.path.is_ident("include") {
1426                    is_included = true;
1427                    Ok(())
1428                } else {
1429                    let path = meta.path.to_token_stream().to_string().replace(' ', "");
1430                    Err(meta.error(format_args!("unknown binding variant attribute `{}`", path)))
1431                }
1432            })?
1433        }
1434    }
1435    Ok(is_included)
1436}
1437
1438/// The field accessor in struct or enum variant.
1439/// e.g.:
1440///   struct NamedStruct { foo: u32 } => FieldAccessor::Named(Ident::new("foo", Span::call_site()))
1441///   struct UnnamedStruct(u32) => FieldAccessor::Unnamed(Index::from(0))
1442enum FieldAccessor {
1443    Named(Ident),
1444    Unnamed(Index),
1445}
1446
1447/// Result of parsing a field in a struct, or a enum variant.
1448struct ParsedField {
1449    accessor: FieldAccessor,
1450    ty: Type,
1451    included: bool,
1452}
1453
1454impl From<&ParsedField> for (Ident, Type) {
1455    fn from(field: &ParsedField) -> Self {
1456        let field_ident = match &field.accessor {
1457            FieldAccessor::Named(ident) => ident.clone(),
1458            FieldAccessor::Unnamed(i) => {
1459                Ident::new(&format!("f{}", i.index), proc_macro2::Span::call_site())
1460            }
1461        };
1462        (field_ident, field.ty.clone())
1463    }
1464}
1465
1466fn collect_all_fields(fields: &Fields) -> syn::Result<Vec<ParsedField>> {
1467    match fields {
1468        Fields::Named(named) => named
1469            .named
1470            .iter()
1471            .map(|f| {
1472                let accessor = FieldAccessor::Named(f.ident.clone().unwrap());
1473                Ok(ParsedField {
1474                    accessor,
1475                    ty: f.ty.clone(),
1476                    included: include_in_bind_unbind(f)?,
1477                })
1478            })
1479            .collect(),
1480        Fields::Unnamed(unnamed) => unnamed
1481            .unnamed
1482            .iter()
1483            .enumerate()
1484            .map(|(i, f)| {
1485                let accessor = FieldAccessor::Unnamed(Index::from(i));
1486                Ok(ParsedField {
1487                    accessor,
1488                    ty: f.ty.clone(),
1489                    included: include_in_bind_unbind(f)?,
1490                })
1491            })
1492            .collect(),
1493        Fields::Unit => Ok(Vec::new()),
1494    }
1495}
1496
1497fn gen_struct_items<F>(
1498    fields: &Fields,
1499    make_item: F,
1500    is_mutable: bool,
1501) -> syn::Result<Vec<proc_macro2::TokenStream>>
1502where
1503    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1504{
1505    let borrow = if is_mutable {
1506        quote! { &mut }
1507    } else {
1508        quote! { & }
1509    };
1510    let items: Vec<_> = collect_all_fields(fields)?
1511        .into_iter()
1512        .filter(|f| f.included)
1513        .map(
1514            |ParsedField {
1515                 accessor,
1516                 ty,
1517                 included,
1518             }| {
1519                assert!(included);
1520                let field_accessor = match accessor {
1521                    FieldAccessor::Named(ident) => quote! { #borrow self.#ident },
1522                    FieldAccessor::Unnamed(index) => quote! { #borrow self.#index },
1523                };
1524                make_item(field_accessor, ty)
1525            },
1526        )
1527        .collect();
1528    Ok(items)
1529}
1530
1531/// Generate the field accessor for a enum variant in pattern matching. e.g.
1532/// the <GENERATED> parts in the following example:
1533///
1534///   match my_enum {
1535///     // e.g. MyEnum::Tuple(_, f1, f2, _)
1536///     MyEnum::Tuple(<GENERATED>) => { ... }
1537///     // e.g. MyEnum::Struct { field0: _, field1 }
1538///     MyEnum::Struct(<GENERATED>) => { ... }
1539///   }
1540fn gen_enum_field_accessors(all_fields: &[ParsedField]) -> Vec<proc_macro2::TokenStream> {
1541    all_fields
1542        .iter()
1543        .map(
1544            |ParsedField {
1545                 accessor,
1546                 ty: _,
1547                 included,
1548             }| {
1549                match accessor {
1550                    FieldAccessor::Named(ident) => {
1551                        if *included {
1552                            quote! { #ident }
1553                        } else {
1554                            quote! { #ident: _ }
1555                        }
1556                    }
1557                    FieldAccessor::Unnamed(i) => {
1558                        if *included {
1559                            let ident = Ident::new(
1560                                &format!("f{}", i.index),
1561                                proc_macro2::Span::call_site(),
1562                            );
1563                            quote! { #ident }
1564                        } else {
1565                            quote! { _ }
1566                        }
1567                    }
1568                }
1569            },
1570        )
1571        .collect()
1572}
1573
1574/// Generate all the parts for enum variants. e.g. the <GENERATED> part in the
1575/// following example:
1576///
1577///   match my_enum {
1578///      <GENERATED>
1579///   }
1580fn gen_enum_arms<F>(data: &DataEnum, make_item: F) -> syn::Result<Vec<proc_macro2::TokenStream>>
1581where
1582    F: Fn(proc_macro2::TokenStream, Type) -> proc_macro2::TokenStream,
1583{
1584    data.variants
1585        .iter()
1586        .map(|variant| {
1587            let name = &variant.ident;
1588            let all_fields = collect_all_fields(&variant.fields)?;
1589            let field_accessors = gen_enum_field_accessors(&all_fields);
1590            let included_fields = all_fields.iter().filter(|f| f.included).collect::<Vec<_>>();
1591            let items = included_fields
1592                .iter()
1593                .map(|f| {
1594                    let (accessor, ty) = <(Ident, Type)>::from(*f);
1595                    make_item(quote! { #accessor }, ty)
1596                })
1597                .collect::<Vec<_>>();
1598
1599            Ok(match &variant.fields {
1600                Fields::Named(_) => {
1601                    quote! { Self::#name { #(#field_accessors),* } => { #(#items)* } }
1602                }
1603                Fields::Unnamed(_) => {
1604                    quote! { Self::#name( #(#field_accessors),* ) => { #(#items)* } }
1605                }
1606                Fields::Unit => quote! { Self::#name => { #(#items)* } },
1607            })
1608        })
1609        .collect()
1610}
1611
1612/// Derive a custom implementation of [`hyperactor::message::Bind`] trait for
1613/// a struct or enum. This macro is normally used in tandem with [`fn derive_unbind`]
1614/// to make the applied struct or enum castable.
1615///
1616/// Specifically, the derived implementation iterates through fields annotated
1617/// with `#[binding(include)]` based on their order of declaration in the struct
1618/// or enum. These fields' types must implement `Bind` trait as well. During the
1619/// iteration, parameters from `bindings` are bound to these fields.
1620///
1621/// # Example
1622///
/// This macro supports named and unnamed structs and enums. Below are examples
1624/// of the supported types:
1625///
1626/// ```
1627/// #[derive(Bind, Unbind)]
1628/// struct MyNamedStruct {
1629///     field0: u64,
1630///     field1: MyReply,
1631///     #[binding(include)]
///     field2: PortRef<MyReply>,
1633///     field3: bool,
1634///     #[binding(include)]
1635///     field4: hyperactor::PortRef<u64>,
1636/// }
1637///
1638/// #[derive(Bind, Unbind)]
1639/// struct MyUnamedStruct(
1640///     u64,
1641///     MyReply,
1642///     #[binding(include)] hyperactor::PortRef<MyReply>,
1643///     bool,
1644///     #[binding(include)] PortRef<u64>,
1645/// );
1646///
1647/// #[derive(Bind, Unbind)]
1648/// enum MyEnum {
1649///     Unit,
1650///     NoopTuple(u64, bool),
1651///     NoopStruct {
1652///         field0: u64,
1653///         field1: bool,
1654///     },
1655///     Tuple(
1656///         u64,
1657///         MyReply,
1658///         #[binding(include)] PortRef<MyReply>,
1659///         bool,
1660///         #[binding(include)] hyperactor::PortRef<u64>,
1661///     ),
1662///     Struct {
1663///         field0: u64,
1664///         field1: MyReply,
1665///         #[binding(include)]
1666///         field2: PortRef<MyReply>,
1667///         field3: bool,
1668///         #[binding(include)]
1669///         field4: hyperactor::PortRef<u64>,
1670///     },
1671/// }
1672/// ```
1673///
1674/// The following shows what derived `Bind`` and `Unbind`` implementations for
1675/// `MyNamedStruct` will look like. The implementations of other types are
1676/// similar, and thus are not shown here.
1677/// ```ignore
1678/// impl Bind for MyNamedStruct {
1679/// fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
1680///     Bind::bind(&mut self.field2, bindings)?;
1681///     Bind::bind(&mut self.field4, bindings)?;
1682///     Ok(())
1683/// }
1684///
1685/// impl Unbind for MyNamedStruct {
1686///     fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
1687///         Unbind::unbind(&self.field2, bindings)?;
1688///         Unbind::unbind(&self.field4, bindings)?;
1689///         Ok(())
1690///     }
1691/// }
1692/// ```
1693#[proc_macro_derive(Bind, attributes(binding))]
1694pub fn derive_bind(input: TokenStream) -> TokenStream {
1695    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1696        quote! {
1697            hyperactor::message::Bind::bind(#field_accessor, bindings)?;
1698        }
1699    }
1700
1701    let input = parse_macro_input!(input as DeriveInput);
1702    let name = &input.ident;
1703    let inner = match &input.data {
1704        Data::Struct(DataStruct { fields, .. }) => {
1705            match gen_struct_items(fields, make_item, true) {
1706                Ok(collects) => {
1707                    quote! { #(#collects)* }
1708                }
1709                Err(e) => {
1710                    return TokenStream::from(e.to_compile_error());
1711                }
1712            }
1713        }
1714        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1715            Ok(arms) => {
1716                quote! { match self { #(#arms),* } }
1717            }
1718            Err(e) => {
1719                return TokenStream::from(e.to_compile_error());
1720            }
1721        },
1722        _ => panic!("Bind can only be derived for structs and enums"),
1723    };
1724    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1725    let expand = quote! {
1726        #[automatically_derived]
1727        impl #impl_generics hyperactor::message::Bind for #name #ty_generics #where_clause {
1728            fn bind(&mut self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1729                #inner
1730                Ok(())
1731            }
1732        }
1733    };
1734    TokenStream::from(expand)
1735}
1736
1737/// Derive a custom implementation of [`hyperactor::message::Unbind`] trait for
1738/// a struct or enum. This macro is normally used in tandem with [`fn derive_bind`]
1739/// to make the applied struct or enum castable.
1740///
1741/// Specifically, the derived implementation iterates through fields annoated
1742/// with `#[binding(include)]` based on their order of declaration in the struct
1743/// or enum. These fields' types must implement `Unbind` trait as well. During
1744/// the iteration, parameters from these fields are extracted and stored in
1745/// `bindings`.
1746///
1747/// # Example
1748///
1749/// See [`fn derive_bind`]'s documentation for examples.
1750#[proc_macro_derive(Unbind, attributes(binding))]
1751pub fn derive_unbind(input: TokenStream) -> TokenStream {
1752    fn make_item(field_accessor: proc_macro2::TokenStream, _ty: Type) -> proc_macro2::TokenStream {
1753        quote! {
1754            hyperactor::message::Unbind::unbind(#field_accessor, bindings)?;
1755        }
1756    }
1757
1758    let input = parse_macro_input!(input as DeriveInput);
1759    let name = &input.ident;
1760    let inner = match &input.data {
1761        Data::Struct(DataStruct { fields, .. }) => match gen_struct_items(fields, make_item, false)
1762        {
1763            Ok(collects) => {
1764                quote! { #(#collects)* }
1765            }
1766            Err(e) => {
1767                return TokenStream::from(e.to_compile_error());
1768            }
1769        },
1770        Data::Enum(data_enum) => match gen_enum_arms(data_enum, make_item) {
1771            Ok(arms) => {
1772                quote! { match self { #(#arms),* } }
1773            }
1774            Err(e) => {
1775                return TokenStream::from(e.to_compile_error());
1776            }
1777        },
1778        _ => panic!("Unbind can only be derived for structs and enums"),
1779    };
1780    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1781    let expand = quote! {
1782        #[automatically_derived]
1783        impl #impl_generics hyperactor::message::Unbind for #name #ty_generics #where_clause {
1784            fn unbind(&self, bindings: &mut hyperactor::message::Bindings) -> anyhow::Result<()> {
1785                #inner
1786                Ok(())
1787            }
1788        }
1789    };
1790    TokenStream::from(expand)
1791}
1792
1793/// Derives the `Actor` trait for a struct. By default, generates an implementation
1794/// with no params (`type Params = ()` and `async fn new(_params: ()) -> Result<Self, anyhow::Error>`).
1795/// This requires that the Actor implements [`Default`].
1796///
1797/// If the `#[actor(passthrough)]` attribute is specified, generates an implementation
1798/// with where the parameter type is `Self`
1799/// (`type Params = Self` and `async fn new(instance: Self) -> Result<Self, anyhow::Error>`).
1800///
1801/// # Examples
1802///
1803/// Default behavior:
1804/// ```
1805/// #[derive(Actor, Default)]
1806/// struct MyActor(u64);
1807/// ```
1808///
1809/// Generates:
1810/// ```ignore
1811/// #[async_trait]
1812/// impl Actor for MyActor {
1813///     type Params = ();
1814///
1815///     async fn new(_params: ()) -> Result<Self, anyhow::Error> {
1816///         Ok(Default::default())
1817///     }
1818/// }
1819/// ```
1820///
1821/// Passthrough behavior:
1822/// ```
1823/// #[derive(Actor, Default)]
1824/// #[actor(passthrough)]
1825/// struct MyActor(u64);
1826/// ```
1827///
1828/// Generates:
1829/// ```ignore
1830/// #[async_trait]
1831/// impl Actor for MyActor {
1832///     type Params = Self;
1833///
1834///     async fn new(instance: Self) -> Result<Self, anyhow::Error> {
1835///         Ok(instance)
1836///     }
1837/// }
1838/// ```
1839#[proc_macro_derive(Actor, attributes(actor))]
1840pub fn derive_actor(input: TokenStream) -> TokenStream {
1841    let input = parse_macro_input!(input as DeriveInput);
1842    let name = &input.ident;
1843    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1844
1845    let is_passthrough = input.attrs.iter().any(|attr| {
1846        if attr.path().is_ident("actor") {
1847            if let Ok(meta) = attr.parse_args_with(
1848                syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
1849            ) {
1850                return meta.iter().any(|ident| ident == "passthrough");
1851            }
1852        }
1853        false
1854    });
1855
1856    let expanded = if is_passthrough {
1857        quote! {
1858            #[hyperactor::async_trait::async_trait]
1859            impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1860                type Params = Self;
1861
1862                async fn new(instance: Self) -> Result<Self, hyperactor::anyhow::Error> {
1863                    Ok(instance)
1864                }
1865            }
1866        }
1867    } else {
1868        quote! {
1869            #[hyperactor::async_trait::async_trait]
1870            impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause {
1871                type Params = ();
1872
1873                async fn new(_params: ()) -> Result<Self, hyperactor::anyhow::Error> {
1874                    Ok(Default::default())
1875                }
1876            }
1877        }
1878    };
1879
1880    TokenStream::from(expanded)
1881}