1use std::mem::take;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::OnceLock;
21
22use async_trait::async_trait;
23use backoff::ExponentialBackoffBuilder;
24use backoff::backoff::Backoff;
25use dashmap::DashSet;
26use hyperactor_config::Flattrs;
27
28use crate::Instance;
29use crate::accum;
30use crate::accum::ErasedCommReducer;
31use crate::accum::ReducerMode;
32use crate::accum::ReducerSpec;
33use crate::config;
34use crate::mailbox;
35use crate::mailbox::MailboxSender;
36use crate::mailbox::MessageEnvelope;
37use crate::ordering::SEQ_INFO;
38use crate::reference;
39use crate::time::Alarm;
40
/// Policy telling [`MailboxExt::post`] whether the caller is allowed to hand
/// it headers that already carry a `SEQ_INFO` entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SeqInfoPolicy {
    /// `SEQ_INFO` must not be present on the incoming headers; `post` assigns
    /// it. `post` panics if the header is already set under this policy.
    AssignNew,
    /// A `SEQ_INFO` header set by the caller is accepted and left untouched.
    /// If the header is absent, `post` still assigns one.
    AllowExternal,
}
49
/// Gives access to a [`crate::Mailbox`].
///
/// The trait requires `crate::private::Sealed`, so it is presumably only
/// implementable inside this crate (confirm against the `private` module).
pub trait Mailbox: crate::private::Sealed + Send + Sync {
    /// The mailbox associated with `self`.
    fn mailbox(&self) -> &crate::Mailbox;
}
55
/// Extends [`Mailbox`] with access to the [`Instance`] of a concrete actor
/// type.
#[async_trait]
pub trait Actor: Mailbox {
    /// The actor type whose instance is exposed by [`Actor::instance`].
    type A: crate::Actor;

    /// The [`Instance`] of `Self::A` associated with `self`.
    fn instance(&self) -> &Instance<Self::A>;
}
68
/// Crate-internal sending helpers layered on top of [`Mailbox`].
///
/// A blanket implementation for every `T: Actor + Send + Sync` is provided in
/// this file; the method docs below describe that implementation.
pub(crate) trait MailboxExt: Mailbox {
    /// Wrap `data` and `headers` in a [`MessageEnvelope`] addressed to `dest`
    /// and post it through this mailbox.
    ///
    /// - `return_undeliverable` is recorded on the envelope via
    ///   `set_return_undeliverable`.
    /// - `seq_info_policy` states whether `headers` may already contain
    ///   `SEQ_INFO`; when the header is absent one is assigned.
    fn post(
        &self,
        dest: reference::PortId,
        headers: Flattrs,
        data: wirevalue::Any,
        return_undeliverable: bool,
        seq_info_policy: SeqInfoPolicy,
    );

    /// Allocate a new port on this mailbox whose sender forwards to `port_id`,
    /// and return the id of the new port.
    ///
    /// When `reducer_spec` resolves to a reducer, incoming updates are
    /// combined according to `reducer_mode` before being forwarded; otherwise
    /// each update is forwarded unchanged. `return_undeliverable` is recorded
    /// on every forwarded envelope.
    ///
    /// Returns an error if the reducer described by `reducer_spec` cannot be
    /// resolved.
    fn split(
        &self,
        port_id: reference::PortId,
        reducer_spec: Option<ReducerSpec>,
        reducer_mode: ReducerMode,
        return_undeliverable: bool,
    ) -> anyhow::Result<reference::PortId>;
}
92
/// Actor ids for which the "posted without a bound return handle" warning has
/// already been logged. `MailboxExt::post` inserts into this set and only
/// warns when the insertion is new, so the warning is emitted at most once per
/// actor id. The set is created lazily on first use.
static CAN_SEND_WARNED_MAILBOXES: OnceLock<DashSet<reference::ActorId>> = OnceLock::new();
98
99impl<T: Actor + Send + Sync> MailboxExt for T {
101 fn post(
102 &self,
103 dest: reference::PortId,
104 mut headers: Flattrs,
105 data: wirevalue::Any,
106 return_undeliverable: bool,
107 seq_info_policy: SeqInfoPolicy,
108 ) {
109 let return_handle = self.mailbox().bound_return_handle().unwrap_or_else(|| {
110 let actor_id = self.mailbox().actor_id();
111 if CAN_SEND_WARNED_MAILBOXES
112 .get_or_init(DashSet::new)
113 .insert(actor_id.clone())
114 {
115 let bt = std::backtrace::Backtrace::force_capture();
116 tracing::warn!(
117 actor_id = ?actor_id,
118 backtrace = ?bt,
119 "mailbox attempted to post a message without binding Undeliverable<MessageEnvelope>"
120 );
121 }
122 mailbox::monitored_return_handle()
123 });
124
125 assert!(
126 !headers.contains_key(SEQ_INFO) || seq_info_policy == SeqInfoPolicy::AllowExternal,
127 "SEQ_INFO must not be set on headers outside of fn post unless explicitly allowed"
128 );
129
130 if !headers.contains_key(SEQ_INFO) {
131 let sequencer = self.instance().sequencer();
134 let seq_info = sequencer.assign_seq(&dest);
135 headers.set(SEQ_INFO, seq_info);
136 }
137
138 let mut envelope =
139 MessageEnvelope::new(self.mailbox().actor_id().clone(), dest, data, headers);
140 envelope.set_return_undeliverable(return_undeliverable);
141 MailboxSender::post(self.mailbox(), envelope, return_handle);
142 }
143
144 fn split(
145 &self,
146 port_id: reference::PortId,
147 reducer_spec: Option<ReducerSpec>,
148 reducer_mode: ReducerMode,
149 return_undeliverable: bool,
150 ) -> anyhow::Result<reference::PortId> {
151 fn post(
152 mailbox: &mailbox::Mailbox,
153 port_id: reference::PortId,
154 msg: wirevalue::Any,
155 return_undeliverable: bool,
156 ) {
157 let mut envelope =
158 MessageEnvelope::new(mailbox.actor_id().clone(), port_id, msg, Flattrs::new());
159 envelope.set_return_undeliverable(return_undeliverable);
160 mailbox::MailboxSender::post(
161 mailbox,
162 envelope,
163 mailbox::monitored_return_handle(),
168 );
169 }
170
171 let port_index = self.mailbox().allocate_port();
172 let split_port = self.mailbox().actor_id().port_id(port_index);
173 let mailbox = self.mailbox().clone();
174 let reducer = reducer_spec
175 .map(
176 |ReducerSpec {
177 typehash,
178 builder_params,
179 }| { accum::resolve_reducer(typehash, builder_params) },
180 )
181 .transpose()?
182 .flatten();
183 let enqueue: Box<
184 dyn Fn(wirevalue::Any) -> Result<bool, (wirevalue::Any, anyhow::Error)> + Send + Sync,
185 > = match reducer {
186 None => Box::new(move |serialized: wirevalue::Any| {
187 post(&mailbox, port_id.clone(), serialized, return_undeliverable);
188 Ok(true)
189 }),
190 Some(reducer) => match reducer_mode {
191 ReducerMode::Streaming(_) => {
192 let buffer: Arc<Mutex<UpdateBuffer>> =
193 Arc::new(Mutex::new(UpdateBuffer::new(reducer)));
194
195 let alarm = Alarm::new();
196
197 {
198 let mut sleeper = alarm.sleeper();
199 let buffer = Arc::clone(&buffer);
200 let port_id = port_id.clone();
201 let mailbox = mailbox.clone();
202 tokio::spawn(async move {
203 while sleeper.sleep().await {
204 let mut buf = buffer.lock().unwrap();
205 match buf.reduce() {
206 None => (),
207 Some(Ok(reduced)) => post(
208 &mailbox,
209 port_id.clone(),
210 reduced,
211 return_undeliverable,
212 ),
213 Some(Err(e)) => tracing::error!(
220 "error while reducing update: {}; waiting until the next send to propagate",
221 e
222 ),
223 }
224 }
225 });
226 }
227
228 let alarm = Mutex::new(alarm);
232
233 let max_interval = reducer_mode.max_update_interval();
234 let initial_interval = reducer_mode.initial_update_interval();
235
236 let backoff = Mutex::new(
239 ExponentialBackoffBuilder::new()
240 .with_initial_interval(initial_interval)
241 .with_multiplier(2.0)
242 .with_max_interval(max_interval)
243 .with_max_elapsed_time(None)
244 .build(),
245 );
246
247 Box::new(move |update: wirevalue::Any| {
248 let mut buf = buffer.lock().unwrap();
254 match buf.push(update) {
255 None => {
256 let interval = backoff.lock().unwrap().next_backoff().unwrap();
257 alarm.lock().unwrap().rearm(interval);
258 Ok(true)
259 }
260 Some(Ok(reduced)) => {
261 alarm.lock().unwrap().disarm();
262 post(&mailbox, port_id.clone(), reduced, return_undeliverable);
263 Ok(true)
264 }
265 Some(Err(e)) => Err((buf.pop().unwrap(), e)),
266 }
267 })
268 }
269 ReducerMode::Once(0) => Box::new(move |update: wirevalue::Any| {
270 Err((
271 update,
272 anyhow::anyhow!(
273 "invalid ReducerMode: Once must specify at least one update"
274 ),
275 ))
276 }),
277 ReducerMode::Once(expected) => {
278 let buffer: Arc<Mutex<OnceBuffer>> =
279 Arc::new(Mutex::new(OnceBuffer::new(reducer, expected)));
280
281 Box::new(move |update: wirevalue::Any| {
282 let mut buf = buffer.lock().unwrap();
283 if buf.done {
284 return Err((
285 update,
286 anyhow::anyhow!("OnceReducer has already emitted"),
287 ));
288 }
289 match buf.push(update) {
290 Ok(Some(reduced)) => {
291 post(&mailbox, port_id.clone(), reduced, return_undeliverable);
292 Ok(false) }
294 Ok(None) => Ok(true),
295 Err(e) => Err(e),
296 }
297 })
298 }
299 },
300 };
301 self.mailbox().bind_untyped(
302 &split_port,
303 mailbox::UntypedUnboundedSender {
304 sender: enqueue,
305 port_id: split_port.clone(),
306 },
307 );
308 Ok(split_port)
309 }
310}
311
/// Serialized updates waiting to be combined, together with the reducer that
/// combines them.
struct UpdateBuffer {
    /// Updates pushed since the last successful reduction, in push order.
    buffered: Vec<wirevalue::Any>,
    /// Reducer applied to the whole `buffered` batch by `reduce`.
    reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
}
316
317impl UpdateBuffer {
318 fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>) -> Self {
319 Self {
320 buffered: Vec::new(),
321 reducer,
322 }
323 }
324
325 fn pop(&mut self) -> Option<wirevalue::Any> {
326 self.buffered.pop()
327 }
328
329 fn push(&mut self, serialized: wirevalue::Any) -> Option<anyhow::Result<wirevalue::Any>> {
332 let limit = hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_SIZE);
333
334 self.buffered.push(serialized);
335 if self.buffered.len() >= limit {
336 self.reduce()
337 } else {
338 None
339 }
340 }
341
342 fn reduce(&mut self) -> Option<anyhow::Result<wirevalue::Any>> {
343 if self.buffered.is_empty() {
344 None
345 } else {
346 match self.reducer.reduce_updates(take(&mut self.buffered)) {
347 Ok(reduced) => Some(Ok(reduced)),
348 Err((e, b)) => {
349 self.buffered = b;
350 Some(Err(e))
351 }
352 }
353 }
354 }
355}
356
/// Accumulator that folds a fixed number of updates into a single value and
/// hands that value out once.
struct OnceBuffer {
    /// Running reduction of the values pushed so far, if any.
    accumulated: Option<wirevalue::Any>,
    /// Reducer used to fold each new value into `accumulated`.
    reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
    /// Number of pushes after which the accumulated value is emitted.
    expected: usize,
    /// Number of `push` calls made so far.
    count: usize,
    /// Set when the final value has been emitted by `push`.
    done: bool,
}
364
365impl OnceBuffer {
366 fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>, expected: usize) -> Self {
367 Self {
368 accumulated: None,
369 reducer,
370 expected,
371 count: 0,
372 done: false,
373 }
374 }
375
376 fn push(
380 &mut self,
381 value: wirevalue::Any,
382 ) -> Result<Option<wirevalue::Any>, (wirevalue::Any, anyhow::Error)> {
383 self.count += 1;
384 self.accumulated = match self.accumulated.take() {
385 None => Some(value),
386 Some(acc) => match self.reducer.reduce_updates(vec![acc, value]) {
387 Ok(reduced) => Some(reduced),
388 Err((e, mut rejected)) => {
389 return Err((
390 rejected
391 .pop()
392 .unwrap_or_else(|| wirevalue::Any::serialize(&()).unwrap()),
393 e,
394 ));
395 }
396 },
397 };
398 if self.count >= self.expected {
399 self.done = true;
400 Ok(self.accumulated.take())
401 } else {
402 Ok(None)
403 }
404 }
405}