hyperactor/
context.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module defines traits that are used as context arguments to various
10//! hyperactor APIs; usually [`crate::context::Actor`], implemented by
11//! [`crate::proc::Context`] (provided to actor handlers) and [`crate::proc::Instance`],
12//! representing a running actor instance.
13//!
14//! Context traits are sealed, and thus can only be implemented by data types in the
15//! core hyperactor crate.
16
17use std::mem::take;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::OnceLock;
21
22use async_trait::async_trait;
23use backoff::ExponentialBackoffBuilder;
24use backoff::backoff::Backoff;
25use dashmap::DashSet;
26use hyperactor_config::Flattrs;
27
28use crate::Instance;
29use crate::accum;
30use crate::accum::ErasedCommReducer;
31use crate::accum::ReducerMode;
32use crate::accum::ReducerSpec;
33use crate::config;
34use crate::mailbox;
35use crate::mailbox::MailboxSender;
36use crate::mailbox::MessageEnvelope;
37use crate::ordering::SEQ_INFO;
38use crate::reference;
39use crate::time::Alarm;
40
/// Controls how `MailboxExt::post` treats the `SEQ_INFO` header of an
/// outgoing message.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub(crate) enum SeqInfoPolicy {
    /// `post` stamps a fresh sequence number itself; a `SEQ_INFO` header that
    /// is already present is treated as a bug and causes a panic.
    AssignNew,
    /// A caller-supplied `SEQ_INFO` header is accepted as-is. Reserved for
    /// CommActor, which sets it while routing messages across meshes.
    AllowExternal,
}
49
50/// A mailbox context provides a mailbox.
51pub trait Mailbox: crate::private::Sealed + Send + Sync {
52    /// The mailbox associated with this context
53    fn mailbox(&self) -> &crate::Mailbox;
54}
55
56/// A typed actor context, providing both a [`Mailbox`] and an [`Instance`].
57///
58/// Note: Send and Sync markers are here only temporarily in order to bridge
59/// the transition to the context types, away from the [`crate::cap`] module.
60#[async_trait]
61pub trait Actor: Mailbox {
62    /// The type of actor associated with this context.
63    type A: crate::Actor;
64
65    /// The instance associated with this context.
66    fn instance(&self) -> &Instance<Self::A>;
67}
68
69/// An internal extension trait for Mailbox contexts.
70/// TODO: consider moving this to another module.
71pub(crate) trait MailboxExt: Mailbox {
72    /// Post a message to the provided destination with the provided headers, and data.
73    /// All messages posted from actors should use this implementation.
74    fn post(
75        &self,
76        dest: reference::PortId,
77        headers: Flattrs,
78        data: wirevalue::Any,
79        return_undeliverable: bool,
80        seq_info_policy: SeqInfoPolicy,
81    );
82
83    /// Split a port, using a provided reducer spec, if provided.
84    fn split(
85        &self,
86        port_id: reference::PortId,
87        reducer_spec: Option<ReducerSpec>,
88        reducer_mode: ReducerMode,
89        return_undeliverable: bool,
90    ) -> anyhow::Result<reference::PortId>;
91}
92
// Actor ids of mailboxes that have already logged a warning from
// `MailboxExt::post` because they posted without binding an
// `Undeliverable<MessageEnvelope>` port. `post` only logs when the id is newly
// inserted, so the warning (and its backtrace) appears at most once per actor
// id. Nothing in this module removes entries. Mailboxes are few and long-lived
// here, so unbounded growth is not a realistic concern.
static CAN_SEND_WARNED_MAILBOXES: OnceLock<DashSet<reference::ActorId>> = OnceLock::new();
98
/// Blanket [`MailboxExt`] implementation for every [`Actor`] context.
///
/// Only actor contexts get this implementation. Posting needs an actor
/// instance (for its message sequencer) and a return port for undeliverable
/// messages.
impl<T: Actor + Send + Sync> MailboxExt for T {
    /// Posts `data` to `dest` from this context's mailbox.
    ///
    /// If `headers` does not already carry `SEQ_INFO`, a sequence number for
    /// `dest` is assigned from the instance's sequencer and stored in the
    /// headers. A pre-set `SEQ_INFO` is only accepted under
    /// [`SeqInfoPolicy::AllowExternal`].
    ///
    /// # Panics
    ///
    /// Panics if `headers` already contains `SEQ_INFO` and `seq_info_policy`
    /// is [`SeqInfoPolicy::AssignNew`].
    fn post(
        &self,
        dest: reference::PortId,
        mut headers: Flattrs,
        data: wirevalue::Any,
        return_undeliverable: bool,
        seq_info_policy: SeqInfoPolicy,
    ) {
        // Prefer the mailbox's bound return handle. If none is bound, fall
        // back to the monitored return handle. The first time each actor id
        // does this, log a warning with a backtrace.
        let return_handle = self.mailbox().bound_return_handle().unwrap_or_else(|| {
            let actor_id = self.mailbox().actor_id();
            // `insert` returns true only when the id was not yet present.
            if CAN_SEND_WARNED_MAILBOXES
                .get_or_init(DashSet::new)
                .insert(actor_id.clone())
            {
                let bt = std::backtrace::Backtrace::force_capture();
                tracing::warn!(
                    actor_id = ?actor_id,
                    backtrace = ?bt,
                    "mailbox attempted to post a message without binding Undeliverable<MessageEnvelope>"
                );
            }
            mailbox::monitored_return_handle()
        });

        assert!(
            !headers.contains_key(SEQ_INFO) || seq_info_policy == SeqInfoPolicy::AllowExternal,
            "SEQ_INFO must not be set on headers outside of fn post unless explicitly allowed"
        );

        if !headers.contains_key(SEQ_INFO) {
            // This method is infallible so is okay to assign the sequence number
            // without worrying about rollback.
            let sequencer = self.instance().sequencer();
            let seq_info = sequencer.assign_seq(&dest);
            headers.set(SEQ_INFO, seq_info);
        }

        let mut envelope =
            MessageEnvelope::new(self.mailbox().actor_id().clone(), dest, data, headers);
        envelope.set_return_undeliverable(return_undeliverable);
        MailboxSender::post(self.mailbox(), envelope, return_handle);
    }

    /// Allocates a new port on this mailbox that forwards to `port_id`, and
    /// returns the new port's id.
    ///
    /// How messages sent to the returned port are handled depends on the
    /// resolved reducer and on `reducer_mode`:
    /// - no reducer: each message is forwarded to `port_id` immediately;
    /// - [`ReducerMode::Streaming`]: messages are buffered and reduced. The
    ///   reduced value is forwarded when the buffer reaches
    ///   `config::SPLIT_MAX_BUFFER_SIZE`, or when a background flush timer
    ///   fires;
    /// - `ReducerMode::Once(0)`: every message is rejected with an error;
    /// - `ReducerMode::Once(n)`: messages are reduced incrementally, and a
    ///   single reduced value is forwarded once `n` messages have been
    ///   accepted.
    ///
    /// # Errors
    ///
    /// Returns an error if `reducer_spec` is provided and
    /// `accum::resolve_reducer` fails for it.
    fn split(
        &self,
        port_id: reference::PortId,
        reducer_spec: Option<ReducerSpec>,
        reducer_mode: ReducerMode,
        return_undeliverable: bool,
    ) -> anyhow::Result<reference::PortId> {
        // Forwards `msg` to `port_id` from `mailbox` with empty headers (no
        // sequence info is assigned here). Undeliverable messages use the
        // monitored return handle.
        fn post(
            mailbox: &mailbox::Mailbox,
            port_id: reference::PortId,
            msg: wirevalue::Any,
            return_undeliverable: bool,
        ) {
            let mut envelope =
                MessageEnvelope::new(mailbox.actor_id().clone(), port_id, msg, Flattrs::new());
            envelope.set_return_undeliverable(return_undeliverable);
            mailbox::MailboxSender::post(
                mailbox,
                envelope,
                // TODO(pzhang) figure out how to use upstream's return handle,
                // instead of getting a new one like this.
                // This is okay for now because upstream is currently also using
                // the same handle singleton, but that could change in the future.
                mailbox::monitored_return_handle(),
            );
        }

        let port_index = self.mailbox().allocate_port();
        let split_port = self.mailbox().actor_id().port_id(port_index);
        let mailbox = self.mailbox().clone();
        // Turn the optional spec into an optional reducer. Resolution errors
        // are propagated. A successful resolution that yields `None` is handled
        // exactly like having no reducer (messages pass straight through).
        let reducer = reducer_spec
            .map(
                |ReducerSpec {
                     typehash,
                     builder_params,
                 }| { accum::resolve_reducer(typehash, builder_params) },
            )
            .transpose()?
            .flatten();
        // Per-message handler for the split port. `Ok(true)` keeps the port
        // open, and `Ok(false)` signals that the port is done (see the `Once`
        // arm). `Err` hands the message back together with the error.
        let enqueue: Box<
            dyn Fn(wirevalue::Any) -> Result<bool, (wirevalue::Any, anyhow::Error)> + Send + Sync,
        > = match reducer {
            None => Box::new(move |serialized: wirevalue::Any| {
                post(&mailbox, port_id.clone(), serialized, return_undeliverable);
                Ok(true)
            }),
            Some(reducer) => match reducer_mode {
                ReducerMode::Streaming(_) => {
                    let buffer: Arc<Mutex<UpdateBuffer>> =
                        Arc::new(Mutex::new(UpdateBuffer::new(reducer)));

                    let alarm = Alarm::new();

                    // Background flusher: every time the alarm's sleeper wakes,
                    // reduce whatever is buffered and forward the result. The
                    // loop ends once `sleep()` returns false.
                    {
                        let mut sleeper = alarm.sleeper();
                        let buffer = Arc::clone(&buffer);
                        let port_id = port_id.clone();
                        let mailbox = mailbox.clone();
                        tokio::spawn(async move {
                            while sleeper.sleep().await {
                                let mut buf = buffer.lock().unwrap();
                                match buf.reduce() {
                                    None => (),
                                    Some(Ok(reduced)) => post(
                                        &mailbox,
                                        port_id.clone(),
                                        reduced,
                                        return_undeliverable,
                                    ),
                                    // We simply ignore errors here, and let them be propagated
                                    // later in the enqueueing function.
                                    //
                                    // If this is the last update, then this strategy will cause a hang.
                                    // We should obtain a supervisor here from our send context and notify
                                    // it.
                                    Some(Err(e)) => tracing::error!(
                                        "error while reducing update: {}; waiting until the next send to propagate",
                                        e
                                    ),
                                }
                            }
                        });
                    }

                    // Note: alarm is held in the closure while the port is active;
                    // when it is dropped, the alarm terminates, and so does the sleeper
                    // task.
                    // It is wrapped in a Mutex because the enqueue closure is
                    // `Fn` and therefore only has shared access to its captures.
                    let alarm = Mutex::new(alarm);

                    let max_interval = reducer_mode.max_update_interval();
                    let initial_interval = reducer_mode.initial_update_interval();

                    // Create exponential backoff for buffer flush interval, starting at
                    // initial_interval and growing to max_interval
                    // With `max_elapsed_time(None)` the backoff never gives up,
                    // so `next_backoff()` below always returns `Some`. Nothing
                    // in this block resets it, so the interval stays capped
                    // at max_interval once it gets there.
                    let backoff = Mutex::new(
                        ExponentialBackoffBuilder::new()
                            .with_initial_interval(initial_interval)
                            .with_multiplier(2.0)
                            .with_max_interval(max_interval)
                            .with_max_elapsed_time(None)
                            .build(),
                    );

                    Box::new(move |update: wirevalue::Any| {
                        // Hold the lock until messages are sent. This is to avoid another
                        // invocation of this method trying to send message concurrently and
                        // cause messages delivered out of order.
                        //
                        // We also always acquire alarm *after* the buffer, to avoid deadlocks.
                        let mut buf = buffer.lock().unwrap();
                        match buf.push(update) {
                            // Buffered but not yet flushed: (re)arm the alarm
                            // with the next backoff interval so the background
                            // task flushes it later.
                            None => {
                                let interval = backoff.lock().unwrap().next_backoff().unwrap();
                                alarm.lock().unwrap().rearm(interval);
                                Ok(true)
                            }
                            // The buffer hit its size limit and was reduced:
                            // forward now and disarm the pending timed flush.
                            Some(Ok(reduced)) => {
                                alarm.lock().unwrap().disarm();
                                post(&mailbox, port_id.clone(), reduced, return_undeliverable);
                                Ok(true)
                            }
                            // Reduction failed: the buffer was restored, and the
                            // most recently pushed entry is handed back with the
                            // error (assumes the reducer returns its inputs in
                            // order — TODO confirm).
                            Some(Err(e)) => Err((buf.pop().unwrap(), e)),
                        }
                    })
                }
                ReducerMode::Once(0) => Box::new(move |update: wirevalue::Any| {
                    Err((
                        update,
                        anyhow::anyhow!(
                            "invalid ReducerMode: Once must specify at least one update"
                        ),
                    ))
                }),
                ReducerMode::Once(expected) => {
                    let buffer: Arc<Mutex<OnceBuffer>> =
                        Arc::new(Mutex::new(OnceBuffer::new(reducer, expected)));

                    Box::new(move |update: wirevalue::Any| {
                        let mut buf = buffer.lock().unwrap();
                        // Reject anything that arrives after the single
                        // reduced value has already been forwarded.
                        if buf.done {
                            return Err((
                                update,
                                anyhow::anyhow!("OnceReducer has already emitted"),
                            ));
                        }
                        match buf.push(update) {
                            Ok(Some(reduced)) => {
                                post(&mailbox, port_id.clone(), reduced, return_undeliverable);
                                Ok(false) // Done, tear down the port
                            }
                            Ok(None) => Ok(true),
                            Err(e) => Err(e),
                        }
                    })
                }
            },
        };
        // Install the handler on the freshly allocated port.
        self.mailbox().bind_untyped(
            &split_port,
            mailbox::UntypedUnboundedSender {
                sender: enqueue,
                port_id: split_port.clone(),
            },
        );
        Ok(split_port)
    }
}
311
312struct UpdateBuffer {
313    buffered: Vec<wirevalue::Any>,
314    reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
315}
316
317impl UpdateBuffer {
318    fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>) -> Self {
319        Self {
320            buffered: Vec::new(),
321            reducer,
322        }
323    }
324
325    fn pop(&mut self) -> Option<wirevalue::Any> {
326        self.buffered.pop()
327    }
328
329    /// Push a new item to the buffer, and optionally return any items that should
330    /// be flushed.
331    fn push(&mut self, serialized: wirevalue::Any) -> Option<anyhow::Result<wirevalue::Any>> {
332        let limit = hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_SIZE);
333
334        self.buffered.push(serialized);
335        if self.buffered.len() >= limit {
336            self.reduce()
337        } else {
338            None
339        }
340    }
341
342    fn reduce(&mut self) -> Option<anyhow::Result<wirevalue::Any>> {
343        if self.buffered.is_empty() {
344            None
345        } else {
346            match self.reducer.reduce_updates(take(&mut self.buffered)) {
347                Ok(reduced) => Some(Ok(reduced)),
348                Err((e, b)) => {
349                    self.buffered = b;
350                    Some(Err(e))
351                }
352            }
353        }
354    }
355}
356
357struct OnceBuffer {
358    accumulated: Option<wirevalue::Any>,
359    reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
360    expected: usize,
361    count: usize,
362    done: bool,
363}
364
365impl OnceBuffer {
366    fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>, expected: usize) -> Self {
367        Self {
368            accumulated: None,
369            reducer,
370            expected,
371            count: 0,
372            done: false,
373        }
374    }
375
376    /// Push a new value and reduce incrementally. Returns Ok(Some(reduced)) when
377    /// the expected count is reached, Ok(None) while still accumulating. On error,
378    /// the buffer is broken and returns the rejected value.
379    fn push(
380        &mut self,
381        value: wirevalue::Any,
382    ) -> Result<Option<wirevalue::Any>, (wirevalue::Any, anyhow::Error)> {
383        self.count += 1;
384        self.accumulated = match self.accumulated.take() {
385            None => Some(value),
386            Some(acc) => match self.reducer.reduce_updates(vec![acc, value]) {
387                Ok(reduced) => Some(reduced),
388                Err((e, mut rejected)) => {
389                    return Err((
390                        rejected
391                            .pop()
392                            .unwrap_or_else(|| wirevalue::Any::serialize(&()).unwrap()),
393                        e,
394                    ));
395                }
396            },
397        };
398        if self.count >= self.expected {
399            self.done = true;
400            Ok(self.accumulated.take())
401        } else {
402            Ok(None)
403        }
404    }
405}