Skip to main content

hyperactor/
ref_.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Typed capability references for Hyperactor actors and ports.
10
11use std::cmp::Ordering;
12use std::fmt;
13use std::hash::Hash;
14use std::hash::Hasher;
15use std::marker::PhantomData;
16
17use derivative::Derivative;
18use hyperactor_config::Flattrs;
19use serde::Deserialize;
20use serde::Deserializer;
21use serde::Serialize;
22use serde::Serializer;
23use typeuri::Named;
24
25use crate::Actor;
26use crate::ActorAddr;
27use crate::ActorHandle;
28use crate::Endpoint;
29use crate::EndpointLocation;
30use crate::PortAddr;
31use crate::RemoteEndpoint;
32use crate::RemoteHandles;
33use crate::RemoteMessage;
34use crate::accum::ReducerSpec;
35use crate::accum::StreamingReducerOpts;
36use crate::actor::Referable;
37use crate::context;
38use crate::context::MailboxExt;
39use crate::mailbox::LostMessage;
40use crate::mailbox::MailboxSenderError;
41use crate::mailbox::MailboxSenderErrorKind;
42use crate::mailbox::PortSink;
43use crate::message::Bind;
44use crate::message::Bindings;
45use crate::message::Unbind;
46use crate::port::Port;
47
/// ActorRefs are typed references to actors.
///
/// An `ActorRef<A>` pairs an [`ActorAddr`] with the (compile-time only)
/// actor type `A`; no value of type `A` is stored.
#[derive(typeuri::Named)]
pub struct ActorRef<A: Referable> {
    /// Address of the referenced actor.
    pub(crate) actor_addr: ActorAddr,
    // fn() -> A so that the struct remains Send
    phantom: PhantomData<fn() -> A>,
}
55
56impl<A: Referable> ActorRef<A> {
57    /// Get the remote port for message type [`M`] for the referenced actor.
58    pub fn port<M: RemoteMessage>(&self) -> PortRef<M>
59    where
60        A: RemoteHandles<M>,
61    {
62        PortRef::attest(self.actor_addr.port_addr(Port::from(<M as Named>::port())))
63    }
64
65    /// The caller guarantees that the provided actor ID is also a valid,
66    /// typed reference.  This is usually invoked to provide a guarantee
67    /// that an externally-provided actor ID (e.g., through a command
68    /// line argument) is a valid reference.
69    pub fn attest(actor_addr: ActorAddr) -> Self {
70        Self {
71            actor_addr,
72            phantom: PhantomData,
73        }
74    }
75
76    /// The actor address corresponding with this reference.
77    pub fn actor_addr(&self) -> &ActorAddr {
78        &self.actor_addr
79    }
80
81    /// Convert this actor reference into its corresponding actor address.
82    pub fn into_actor_addr(self) -> ActorAddr {
83        self.actor_addr
84    }
85
86    /// Attempt to downcast this reference into a (local) actor handle.
87    /// This will only succeed when the referenced actor is in the same
88    /// proc as the caller.
89    pub fn downcast_handle(&self, cx: &impl context::Actor) -> Option<ActorHandle<A>>
90    where
91        A: Actor,
92    {
93        cx.instance().proc().resolve_actor_ref(self)
94    }
95}
96
97impl<A, M> Endpoint<M> for &ActorRef<A>
98where
99    A: Referable + RemoteHandles<M>,
100    M: RemoteMessage,
101{
102    fn endpoint_location(&self) -> EndpointLocation {
103        EndpointLocation::Actor(self.actor_addr.clone())
104    }
105
106    fn post<C>(self, cx: &C, message: M)
107    where
108        C: context::Actor,
109    {
110        RemoteEndpoint::post_with_headers(self, cx, Flattrs::new(), message)
111    }
112}
113
114impl<A, M> RemoteEndpoint<M> for &ActorRef<A>
115where
116    A: Referable + RemoteHandles<M>,
117    M: RemoteMessage,
118{
119    fn post_with_headers<C>(self, cx: &C, headers: Flattrs, message: M)
120    where
121        C: context::Actor,
122    {
123        RemoteEndpoint::post_with_headers(&self.port(), cx, headers, message)
124    }
125}
126
127// Implement Serialize manually, without requiring A: Serialize
128impl<A: Referable> Serialize for ActorRef<A> {
129    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
130    where
131        S: Serializer,
132    {
133        // Serialize only the fields that don't depend on A
134        self.actor_addr.serialize(serializer)
135    }
136}
137
138// Implement Deserialize manually, without requiring A: Deserialize
139impl<'de, A: Referable> Deserialize<'de> for ActorRef<A> {
140    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
141    where
142        D: Deserializer<'de>,
143    {
144        let actor_addr = <ActorAddr>::deserialize(deserializer)?;
145        Ok(ActorRef {
146            actor_addr,
147            phantom: PhantomData,
148        })
149    }
150}
151
152// Implement Debug manually, without requiring A: Debug
153impl<A: Referable> fmt::Debug for ActorRef<A> {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        f.debug_struct("ActorRef")
156            .field("actor_addr", &self.actor_addr)
157            .field("type", &std::any::type_name::<A>())
158            .finish()
159    }
160}
161
162impl<A: Referable> fmt::Display for ActorRef<A> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        fmt::Display::fmt(&self.actor_addr, f)?;
165        write!(f, "<{}>", std::any::type_name::<A>())
166    }
167}
168
169// We implement Clone manually to avoid imposing A: Clone.
170impl<A: Referable> Clone for ActorRef<A> {
171    fn clone(&self) -> Self {
172        Self {
173            actor_addr: self.actor_addr.clone(),
174            phantom: PhantomData,
175        }
176    }
177}
178
179impl<A: Referable> PartialEq for ActorRef<A> {
180    fn eq(&self, other: &Self) -> bool {
181        self.actor_addr == other.actor_addr
182    }
183}
184
// Marker impl: equality is defined by the manual `PartialEq` on the actor address.
impl<A: Referable> Eq for ActorRef<A> {}
186
187impl<A: Referable> PartialOrd for ActorRef<A> {
188    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
189        Some(self.cmp(other))
190    }
191}
192
193impl<A: Referable> Ord for ActorRef<A> {
194    fn cmp(&self, other: &Self) -> Ordering {
195        self.actor_addr.cmp(&other.actor_addr)
196    }
197}
198
199impl<A: Referable> Hash for ActorRef<A> {
200    fn hash<H: Hasher>(&self, state: &mut H) {
201        self.actor_addr.hash(state);
202    }
203}
204
/// A reference to a remote port. All messages passed through
/// PortRefs will be serialized. PortRefs are always streaming.
///
/// Equality, ordering, and hashing are derived through `derivative` and
/// skip the fields marked `ignore` below (`reducer_spec`,
/// `streaming_opts`, and `unsplit`).
#[derive(Debug, Serialize, Deserialize, Derivative, typeuri::Named)]
#[derivative(PartialEq, Eq, PartialOrd, Hash, Ord)]
pub struct PortRef<M> {
    /// Address of the referenced port.
    port_addr: PortAddr,
    /// Optional reducer specification for this port; not part of the
    /// reference's identity.
    #[derivative(
        PartialEq = "ignore",
        PartialOrd = "ignore",
        Ord = "ignore",
        Hash = "ignore"
    )]
    reducer_spec: Option<ReducerSpec>,
    /// Streaming reducer options; not part of the reference's identity.
    #[derivative(
        PartialEq = "ignore",
        PartialOrd = "ignore",
        Ord = "ignore",
        Hash = "ignore"
    )]
    streaming_opts: StreamingReducerOpts,
    // Ties the reference to its message type without storing an `M`.
    phantom: PhantomData<M>,
    /// Whether messages sent to this port that are undeliverable should be
    /// returned to the sender.
    return_undeliverable: bool,
    /// When set, the port is prevented from being split; not part of the
    /// reference's identity.
    #[derivative(
        PartialEq = "ignore",
        PartialOrd = "ignore",
        Ord = "ignore",
        Hash = "ignore"
    )]
    unsplit: bool,
}
235
236impl<M: RemoteMessage> PortRef<M> {
237    /// The caller attests that the provided port address identifies a
238    /// reachable typed port for message type `M`.
239    pub fn attest(port_addr: PortAddr) -> Self {
240        Self {
241            port_addr,
242            reducer_spec: None,
243            streaming_opts: StreamingReducerOpts::default(),
244            phantom: PhantomData,
245            return_undeliverable: true,
246            unsplit: false,
247        }
248    }
249
250    /// The caller attests that the provided port address identifies a
251    /// reachable typed port for message type `M`.
252    pub fn attest_reducible(
253        port_addr: PortAddr,
254        reducer_spec: Option<ReducerSpec>,
255        streaming_opts: StreamingReducerOpts,
256    ) -> Self {
257        Self {
258            port_addr,
259            reducer_spec,
260            streaming_opts,
261            phantom: PhantomData,
262            return_undeliverable: true,
263            unsplit: false,
264        }
265    }
266
267    /// Prevents the port from being split.
268    pub fn unsplit(mut self) -> Self {
269        self.unsplit = true;
270        self
271    }
272
273    /// The caller attests that the provided actor exposes a reachable handler
274    /// port for message type `M`.
275    pub fn attest_handler_port(actor: &ActorAddr) -> Self {
276        PortRef::<M>::attest(actor.port_addr(Port::from(<M as Named>::port())))
277    }
278
279    /// The typehash of this port's reducer, if any. Reducers
280    /// may be used to coalesce messages sent to a port.
281    pub fn reducer_spec(&self) -> &Option<ReducerSpec> {
282        &self.reducer_spec
283    }
284
285    /// This port's address.
286    pub fn port_addr(&self) -> &PortAddr {
287        &self.port_addr
288    }
289
290    /// Convert this PortRef into its corresponding port address.
291    pub fn into_port_addr(self) -> PortAddr {
292        self.port_addr
293    }
294
295    /// coerce it into OncePortRef so we can send messages to this port from
296    /// APIs requires OncePortRef.
297    pub fn into_once(self) -> OncePortRef<M> {
298        let return_undeliverable = self.return_undeliverable;
299        let unsplit = self.unsplit;
300        let mut once = OncePortRef::attest(self.into_port_addr());
301        once.return_undeliverable = return_undeliverable;
302        once.unsplit = unsplit;
303        once
304    }
305
306    /// Post a serialized message to this port, provided a sending capability, such as
307    /// [`crate::actor::Instance`].
308    pub fn post_serialized(
309        &self,
310        cx: &impl context::Actor,
311        mut headers: Flattrs,
312        message: wirevalue::Any,
313    ) {
314        crate::mailbox::headers::set_send_timestamp(&mut headers);
315        crate::mailbox::headers::set_rust_message_type::<M>(&mut headers);
316        cx.post(
317            self.port_addr.clone(),
318            headers,
319            message,
320            self.return_undeliverable,
321            context::SeqInfoPolicy::AssignNew,
322        );
323    }
324
325    /// Convert this port into a sink that can be used to send messages using the given capability.
326    pub fn into_sink<C: context::Actor>(self, cx: C) -> PortSink<C, M> {
327        PortSink::new(cx, self)
328    }
329
330    /// Get whether or not messages sent to this port that are undeliverable should
331    /// be returned to the sender.
332    pub fn get_return_undeliverable(&self) -> bool {
333        self.return_undeliverable
334    }
335
336    /// Set whether or not messages sent to this port that are undeliverable
337    /// should be returned to the sender.
338    pub fn return_undeliverable(&mut self, return_undeliverable: bool) {
339        self.return_undeliverable = return_undeliverable;
340    }
341}
342
343impl<M> Endpoint<M> for &PortRef<M>
344where
345    M: RemoteMessage,
346{
347    fn endpoint_location(&self) -> EndpointLocation {
348        EndpointLocation::Port(self.port_addr.clone())
349    }
350
351    fn post<C>(self, cx: &C, message: M)
352    where
353        C: context::Actor,
354    {
355        RemoteEndpoint::post_with_headers(self, cx, Flattrs::new(), message)
356    }
357}
358
359impl<M> RemoteEndpoint<M> for &PortRef<M>
360where
361    M: RemoteMessage,
362{
363    fn post_with_headers<C>(self, cx: &C, headers: Flattrs, message: M)
364    where
365        C: context::Actor,
366    {
367        let serialized = match wirevalue::Any::serialize(&message).map_err(|err| {
368            MailboxSenderError::new_bound(
369                self.port_addr.clone(),
370                MailboxSenderErrorKind::Serialize(err.into()),
371            )
372        }) {
373            Ok(serialized) => serialized,
374            Err(err) => {
375                cx.instance()
376                    .report_lost_message(LostMessage::from_send_error::<M>(
377                        cx.mailbox().actor_addr().clone(),
378                        self.endpoint_location(),
379                        &err,
380                    ));
381                return;
382            }
383        };
384        self.post_serialized(cx, headers, serialized);
385    }
386}
387
388impl<M: RemoteMessage> Clone for PortRef<M> {
389    fn clone(&self) -> Self {
390        Self {
391            port_addr: self.port_addr.clone(),
392            reducer_spec: self.reducer_spec.clone(),
393            streaming_opts: self.streaming_opts.clone(),
394            phantom: PhantomData,
395            return_undeliverable: self.return_undeliverable,
396            unsplit: self.unsplit,
397        }
398    }
399}
400
401impl<M: RemoteMessage> fmt::Display for PortRef<M> {
402    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403        fmt::Display::fmt(&self.port_addr, f)
404    }
405}
406
/// The kind of unbound port.
///
/// Carried inside [`UnboundPort`] so that binding can reject a mismatch
/// between streaming and once port references.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
pub enum UnboundPortKind {
    /// A streaming port, which should be reduced with the provided options.
    /// When binding to a `PortRef`, `None` is replaced by the default options.
    Streaming(Option<StreamingReducerOpts>),
    /// A OncePort, which must be one-shot aggregated.
    Once,
}
415
/// The parameters extracted from a [`PortRef`] or [`OncePortRef`] when it is
/// unbound into [`Bindings`].
///
/// Fields, in order: the port address, the optional reducer specification,
/// the `return_undeliverable` flag, the port kind, and the `unsplit` flag.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, typeuri::Named)]
pub struct UnboundPort(
    pub PortAddr,
    pub Option<ReducerSpec>,
    pub bool, // return_undeliverable
    pub UnboundPortKind,
    pub bool, // unsplit
);
wirevalue::register_type!(UnboundPort);
426
427impl UnboundPort {
428    /// Update the port id of this binding.
429    pub fn update(&mut self, port_addr: PortAddr) {
430        self.0 = port_addr;
431    }
432}
433
434impl<M: RemoteMessage> From<&PortRef<M>> for UnboundPort {
435    fn from(port_ref: &PortRef<M>) -> Self {
436        UnboundPort(
437            port_ref.port_addr.clone(),
438            port_ref.reducer_spec.clone(),
439            port_ref.return_undeliverable,
440            UnboundPortKind::Streaming(Some(port_ref.streaming_opts.clone())),
441            port_ref.unsplit,
442        )
443    }
444}
445
446impl<M: RemoteMessage> Unbind for PortRef<M> {
447    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
448        bindings.push_back(&UnboundPort::from(self))
449    }
450}
451
452impl<M: RemoteMessage> Bind for PortRef<M> {
453    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
454        let UnboundPort(port_addr, reducer_spec, return_undeliverable, port_kind, unsplit) =
455            bindings.try_pop_front::<UnboundPort>()?;
456        self.port_addr = port_addr;
457        self.reducer_spec = reducer_spec;
458        self.return_undeliverable = return_undeliverable;
459        self.unsplit = unsplit;
460        self.streaming_opts = match port_kind {
461            UnboundPortKind::Streaming(opts) => opts.unwrap_or_default(),
462            UnboundPortKind::Once => {
463                anyhow::bail!("OncePortRef cannot be bound to PortRef")
464            }
465        };
466        Ok(())
467    }
468}
469
/// A remote reference to a [`OncePort`]. References are serializable
/// and may be passed to remote actors, which can then use it to send
/// a message to this port.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct OncePortRef<M> {
    /// Address of the referenced port.
    port_addr: PortAddr,
    /// Optional reducer specification for this port.
    reducer_spec: Option<ReducerSpec>,
    /// Whether messages sent to this port that are undeliverable should be
    /// returned to the sender.
    return_undeliverable: bool,
    /// When set, the port is prevented from being split.
    unsplit: bool,
    // Ties the reference to its message type without storing an `M`.
    phantom: PhantomData<M>,
}
481
482impl<M: RemoteMessage> OncePortRef<M> {
483    pub(crate) fn attest(port_addr: PortAddr) -> Self {
484        Self {
485            port_addr,
486            reducer_spec: None,
487            return_undeliverable: true,
488            unsplit: false,
489            phantom: PhantomData,
490        }
491    }
492
493    /// The caller attests that the provided PortId can be
494    /// converted to a reachable, typed once port reference.
495    pub fn attest_reducible(port_addr: PortAddr, reducer_spec: Option<ReducerSpec>) -> Self {
496        Self {
497            port_addr,
498            reducer_spec,
499            return_undeliverable: true,
500            unsplit: false,
501            phantom: PhantomData,
502        }
503    }
504
505    /// Prevents the port from being split.
506    pub fn unsplit(mut self) -> Self {
507        self.unsplit = true;
508        self
509    }
510
511    /// The typehash of this port's reducer, if any.
512    pub fn reducer_spec(&self) -> &Option<ReducerSpec> {
513        &self.reducer_spec
514    }
515
516    /// This port's address.
517    pub fn port_addr(&self) -> &PortAddr {
518        &self.port_addr
519    }
520
521    /// Convert this OncePortRef into its corresponding port address.
522    pub fn into_port_addr(self) -> PortAddr {
523        self.port_addr
524    }
525
526    /// Get whether or not messages sent to this port that are undeliverable should
527    /// be returned to the sender.
528    pub fn get_return_undeliverable(&self) -> bool {
529        self.return_undeliverable
530    }
531
532    /// Set whether or not messages sent to this port that are undeliverable
533    /// should be returned to the sender.
534    pub fn return_undeliverable(&mut self, return_undeliverable: bool) {
535        self.return_undeliverable = return_undeliverable;
536    }
537}
538
539impl<M> Endpoint<M> for OncePortRef<M>
540where
541    M: RemoteMessage,
542{
543    fn endpoint_location(&self) -> EndpointLocation {
544        EndpointLocation::Port(self.port_addr.clone())
545    }
546
547    fn post<C>(self, cx: &C, message: M)
548    where
549        C: context::Actor,
550    {
551        RemoteEndpoint::post_with_headers(self, cx, Flattrs::new(), message)
552    }
553}
554
555impl<M> RemoteEndpoint<M> for OncePortRef<M>
556where
557    M: RemoteMessage,
558{
559    fn post_with_headers<C>(self, cx: &C, mut headers: Flattrs, message: M)
560    where
561        C: context::Actor,
562    {
563        crate::mailbox::headers::set_send_timestamp(&mut headers);
564        let serialized = match wirevalue::Any::serialize(&message).map_err(|err| {
565            MailboxSenderError::new_bound(
566                self.port_addr.clone(),
567                MailboxSenderErrorKind::Serialize(err.into()),
568            )
569        }) {
570            Ok(serialized) => serialized,
571            Err(err) => {
572                cx.instance()
573                    .report_lost_message(LostMessage::from_send_error::<M>(
574                        cx.mailbox().actor_addr().clone(),
575                        self.endpoint_location(),
576                        &err,
577                    ));
578                return;
579            }
580        };
581        cx.post(
582            self.port_addr.clone(),
583            headers,
584            serialized,
585            self.return_undeliverable,
586            context::SeqInfoPolicy::AssignNew,
587        );
588    }
589}
590
591impl<M: RemoteMessage> Clone for OncePortRef<M> {
592    fn clone(&self) -> Self {
593        Self {
594            port_addr: self.port_addr.clone(),
595            reducer_spec: self.reducer_spec.clone(),
596            return_undeliverable: self.return_undeliverable,
597            unsplit: self.unsplit,
598            phantom: PhantomData,
599        }
600    }
601}
602
603impl<M: RemoteMessage> fmt::Display for OncePortRef<M> {
604    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
605        fmt::Display::fmt(&self.port_addr, f)
606    }
607}
608
// `Named` is implemented by hand (rather than derived) so that the type name
// is parameterized by the message type `M`.
impl<M: RemoteMessage> Named for OncePortRef<M> {
    /// The interned type name, formatted from the template below with `M`.
    fn typename() -> &'static str {
        // NOTE(review): the template uses the `hyperactor::mailbox::` path
        // rather than this module's path; presumably kept for type-name
        // stability — confirm before changing.
        wirevalue::intern_typename!(Self, "hyperactor::mailbox::OncePortRef<{}>", M)
    }
}
614
615impl<M: RemoteMessage> From<&OncePortRef<M>> for UnboundPort {
616    fn from(port_ref: &OncePortRef<M>) -> Self {
617        UnboundPort(
618            port_ref.port_addr.clone(),
619            port_ref.reducer_spec.clone(),
620            true, // return_undeliverable
621            UnboundPortKind::Once,
622            port_ref.unsplit,
623        )
624    }
625}
626
627impl<M: RemoteMessage> Unbind for OncePortRef<M> {
628    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
629        bindings.push_back(&UnboundPort::from(self))
630    }
631}
632
633impl<M: RemoteMessage> Bind for OncePortRef<M> {
634    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
635        let UnboundPort(port_addr, reducer_spec, _return_undeliverable, port_kind, unsplit) =
636            bindings.try_pop_front::<UnboundPort>()?;
637        match port_kind {
638            UnboundPortKind::Once => {
639                self.port_addr = port_addr;
640                self.reducer_spec = reducer_spec;
641                self.unsplit = unsplit;
642                Ok(())
643            }
644            UnboundPortKind::Streaming(_) => {
645                anyhow::bail!("PortRef cannot be bound to OncePortRef")
646            }
647        }
648    }
649}