1use std::mem::take;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::OnceLock;
21
22use async_trait::async_trait;
23use backoff::ExponentialBackoffBuilder;
24use backoff::backoff::Backoff;
25use dashmap::DashSet;
26use hyperactor_config::Flattrs;
27
28use crate::ActorAddr;
29use crate::Instance;
30use crate::PortAddr;
31use crate::Proc;
32use crate::accum;
33use crate::accum::ErasedCommReducer;
34use crate::accum::ReducerMode;
35use crate::accum::ReducerSpec;
36use crate::config;
37use crate::mailbox;
38use crate::mailbox::MailboxSender;
39use crate::mailbox::MessageEnvelope;
40use crate::ordering::SEQ_INFO;
41use crate::port::Port;
42use crate::time::Alarm;
43
/// Controls how `MailboxExt::post` treats a `SEQ_INFO` header that the caller has
/// already placed on the outgoing headers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SeqInfoPolicy {
    /// The caller must not set `SEQ_INFO`; `post` panics if it is present and otherwise
    /// assigns a fresh one from the instance's sequencer.
    AssignNew,
    /// A caller-supplied `SEQ_INFO` is passed through unchanged; if none is present,
    /// `post` still assigns one.
    AllowExternal,
}
52
53pub trait Mailbox: crate::private::Sealed + Send + Sync {
55 fn mailbox(&self) -> &crate::Mailbox;
57}
58
59#[async_trait]
64pub trait Actor: Mailbox {
65 type A: crate::Actor;
67
68 fn instance(&self) -> &Instance<Self::A>;
70}
71
72pub(crate) trait MailboxExt: Mailbox {
75 fn post(
78 &self,
79 dest: PortAddr,
80 headers: Flattrs,
81 data: wirevalue::Any,
82 return_undeliverable: bool,
83 seq_info_policy: SeqInfoPolicy,
84 );
85
86 fn split(
88 &self,
89 port_id: PortAddr,
90 reducer_spec: Option<ReducerSpec>,
91 reducer_mode: ReducerMode,
92 return_undeliverable: bool,
93 ) -> anyhow::Result<PortAddr>;
94}
95
/// Actors that have already been warned about posting without a bound
/// `Undeliverable<MessageEnvelope>` return handle. `MailboxExt::post` inserts into this
/// set and only logs when the insert is new, so the warning fires at most once per actor.
/// Initialized lazily on first use.
static CAN_SEND_WARNED_MAILBOXES: OnceLock<DashSet<ActorAddr>> = OnceLock::new();
101
/// Blanket implementation of [`MailboxExt`] for every [`Actor`] handle.
impl<T: Actor + Send + Sync> MailboxExt for T {
    /// Wraps `data` in a [`MessageEnvelope`] addressed to `dest` and posts it through the
    /// actor's proc.
    ///
    /// If the mailbox has no bound return handle, `mailbox::monitored_return_handle()` is
    /// used instead and a warning (with a captured backtrace) is logged the first time this
    /// happens for a given actor.
    ///
    /// # Panics
    ///
    /// Panics if `headers` already contains `SEQ_INFO` and `seq_info_policy` is not
    /// [`SeqInfoPolicy::AllowExternal`].
    fn post(
        &self,
        dest: PortAddr,
        mut headers: Flattrs,
        data: wirevalue::Any,
        return_undeliverable: bool,
        seq_info_policy: SeqInfoPolicy,
    ) {
        let return_handle = self.mailbox().bound_return_handle().unwrap_or_else(|| {
            let actor_id = self.mailbox().actor_addr();
            // `DashSet::insert` returns true only when the actor was not yet recorded, so
            // the warning below is emitted at most once per actor.
            if CAN_SEND_WARNED_MAILBOXES
                .get_or_init(DashSet::new)
                .insert(actor_id.clone())
            {
                let bt = std::backtrace::Backtrace::force_capture();
                tracing::warn!(
                    actor_id = ?actor_id,
                    backtrace = ?bt,
                    "mailbox attempted to post a message without binding Undeliverable<MessageEnvelope>"
                );
            }
            mailbox::monitored_return_handle()
        });

        assert!(
            !headers.contains_key(SEQ_INFO) || seq_info_policy == SeqInfoPolicy::AllowExternal,
            "SEQ_INFO must not be set on headers outside of fn post unless explicitly allowed"
        );

        // Assign a sequence number for `dest` from this instance's sequencer unless the
        // caller supplied one (permitted only under `AllowExternal`, checked above).
        if !headers.contains_key(SEQ_INFO) {
            let sequencer = self.instance().sequencer();
            let seq_info = sequencer.assign_seq(&dest);
            headers.set(SEQ_INFO, seq_info);
        }

        let mut envelope =
            MessageEnvelope::new(self.mailbox().actor_addr().clone(), dest, data, headers);
        envelope.set_return_undeliverable(return_undeliverable);
        MailboxSender::post(self.instance().proc(), envelope, return_handle);
    }

    /// Allocates a new port on this actor's mailbox whose messages are forwarded to
    /// `port_id`, and returns the new port's address.
    ///
    /// If `reducer_spec` is `None`, or resolves to no reducer, each message is forwarded
    /// as-is. Otherwise `reducer_mode` selects how messages are combined first:
    /// - `Streaming`: updates are buffered. The buffer is reduced and forwarded
    ///   immediately once it reaches `config::SPLIT_MAX_BUFFER_SIZE`; otherwise an alarm
    ///   is (re)armed with an exponentially growing interval, capped at
    ///   `max_update_interval`, and a background task flushes the buffer when it fires.
    /// - `Once(0)`: every send fails with an "invalid ReducerMode" error.
    /// - `Once(n)`: updates are folded together by `OnceBuffer`; once it yields a result,
    ///   the result is forwarded and the send reports `DeliveredAndExhausted`. Later
    ///   sends are rejected as `Dead`.
    ///
    /// Forwarded messages are sent with empty headers and the monitored return handle.
    ///
    /// # Errors
    ///
    /// Returns an error if `accum::resolve_reducer` fails for `reducer_spec`.
    fn split(
        &self,
        port_id: PortAddr,
        reducer_spec: Option<ReducerSpec>,
        reducer_mode: ReducerMode,
        return_undeliverable: bool,
    ) -> anyhow::Result<PortAddr> {
        // Posts `msg` to `port_id` from `sender` via `proc`, with fresh (empty) headers and
        // the monitored return handle rather than the actor's bound one.
        fn post(
            proc: &Proc,
            sender: &ActorAddr,
            port_id: PortAddr,
            msg: wirevalue::Any,
            return_undeliverable: bool,
        ) {
            let mut envelope = MessageEnvelope::new(sender.clone(), port_id, msg, Flattrs::new());
            envelope.set_return_undeliverable(return_undeliverable);
            mailbox::MailboxSender::post(
                proc,
                envelope,
                mailbox::monitored_return_handle(),
            );
        }

        let port_index = self.mailbox().allocate_port();
        let split_port = self
            .mailbox()
            .actor_addr()
            .port_addr(Port::from(port_index));
        let proc = self.instance().proc().clone();
        let sender = self.mailbox().actor_addr().clone();
        // `Option<Result<Option<_>>>` -> `Result<Option<Option<_>>>` -> `Option<_>`:
        // a resolution error is propagated, and "no spec" and "spec resolved to no
        // reducer" both become `None`.
        let reducer = reducer_spec
            .map(
                |ReducerSpec {
                     typehash,
                     builder_params,
                 }| { accum::resolve_reducer(typehash, builder_params) },
            )
            .transpose()?
            .flatten();
        // Per-message handler bound to `split_port`; it receives each message's headers
        // and serialized payload.
        let enqueue: Box<
            dyn Fn(
                    Flattrs,
                    wirevalue::Any,
                )
                    -> Result<mailbox::SerializedSendDisposition, mailbox::SerializedSendFailure>
                + Send
                + Sync,
        > = match reducer {
            // No reducer: forward every message unchanged.
            None => {
                let proc = proc.clone();
                let sender = sender.clone();
                Box::new(move |_headers: Flattrs, serialized: wirevalue::Any| {
                    post(
                        &proc,
                        &sender,
                        port_id.clone(),
                        serialized,
                        return_undeliverable,
                    );
                    Ok(mailbox::SerializedSendDisposition::Delivered)
                })
            }
            Some(reducer) => match reducer_mode {
                ReducerMode::Streaming(_) => {
                    // Shared between the timer task below and the enqueue closure.
                    let buffer: Arc<Mutex<UpdateBuffer>> =
                        Arc::new(Mutex::new(UpdateBuffer::new(reducer)));

                    let alarm = Alarm::new();

                    {
                        let mut sleeper = alarm.sleeper();
                        let buffer = Arc::clone(&buffer);
                        let port_id = port_id.clone();
                        let proc = proc.clone();
                        let sender = sender.clone();
                        // Timer task: each time the alarm fires, reduce whatever is
                        // buffered and forward it. It loops until `sleeper.sleep()`
                        // returns false (presumably once the `Alarm` owned by the enqueue
                        // closure is dropped; confirm against `crate::time::Alarm`).
                        // The buffer guard is dropped at the end of each iteration,
                        // before the next `.await`.
                        tokio::spawn(async move {
                            while sleeper.sleep().await {
                                let mut buf = buffer.lock().unwrap();
                                match buf.reduce() {
                                    None => (),
                                    Some(Ok(reduced)) => post(
                                        &proc,
                                        &sender,
                                        port_id.clone(),
                                        reduced,
                                        return_undeliverable,
                                    ),
                                    // `UpdateBuffer::reduce` restores the buffered updates on
                                    // error, so they are retried on a later flush.
                                    Some(Err(e)) => tracing::error!(
                                        "error while reducing update: {}; waiting until the next send to propagate",
                                        e
                                    ),
                                }
                            }
                        });
                    }

                    let alarm = Mutex::new(alarm);

                    let max_interval = reducer_mode.max_update_interval();
                    let initial_interval = reducer_mode.initial_update_interval();

                    // Delay before the timer flush; it grows by 2x on every buffered
                    // update, is capped at `max_interval`, and has no elapsed-time limit,
                    // so `next_backoff()` below always returns `Some`.
                    let backoff = Mutex::new(
                        ExponentialBackoffBuilder::new()
                            .with_initial_interval(initial_interval)
                            .with_multiplier(2.0)
                            .with_max_interval(max_interval)
                            .with_max_elapsed_time(None)
                            .build(),
                    );

                    let error_port_id = split_port.clone();
                    // Lock order in this closure: buffer, then backoff / alarm. The buffer
                    // lock is held while posting, as it is in the timer task.
                    Box::new(move |headers: Flattrs, update: wirevalue::Any| {
                        let mut buf = buffer.lock().unwrap();
                        match buf.push(update) {
                            // Still below the size limit: push the timer flush further out.
                            None => {
                                let interval = backoff.lock().unwrap().next_backoff().unwrap();
                                alarm.lock().unwrap().rearm(interval);
                                Ok(mailbox::SerializedSendDisposition::Delivered)
                            }
                            // Size limit reached and reduced: forward now, cancel the timer.
                            Some(Ok(reduced)) => {
                                alarm.lock().unwrap().disarm();
                                post(
                                    &proc,
                                    &sender,
                                    port_id.clone(),
                                    reduced,
                                    return_undeliverable,
                                );
                                Ok(mailbox::SerializedSendDisposition::Delivered)
                            }
                            // Reduction failed: the buffer was restored by
                            // `UpdateBuffer::reduce`. Pop its last entry (presumably the
                            // update just pushed) and hand it back with the error. Any
                            // earlier entries stay buffered.
                            Some(Err(error)) => Err(mailbox::SerializedSendFailure::Error(
                                mailbox::SerializedSendError {
                                    data: buf
                                        .pop()
                                        .expect("reducer error should leave update buffered"),
                                    error: crate::mailbox::MailboxSenderError::new_bound(
                                        error_port_id.clone(),
                                        crate::mailbox::MailboxSenderErrorKind::Other(error),
                                    ),
                                    headers,
                                },
                            )),
                        }
                    })
                }
                // Zero expected updates can never complete; reject every send.
                ReducerMode::Once(0) => {
                    let error_port_id = split_port.clone();
                    Box::new(move |headers: Flattrs, update: wirevalue::Any| {
                        Err(mailbox::SerializedSendFailure::Error(
                            mailbox::SerializedSendError {
                                data: update,
                                error: crate::mailbox::MailboxSenderError::new_bound(
                                    error_port_id.clone(),
                                    crate::mailbox::MailboxSenderErrorKind::Other(anyhow::anyhow!(
                                        "invalid ReducerMode: Once must specify at least one update"
                                    )),
                                ),
                                headers,
                            },
                        ))
                    })
                }
                ReducerMode::Once(expected) => {
                    let buffer: Arc<Mutex<OnceBuffer>> =
                        Arc::new(Mutex::new(OnceBuffer::new(reducer, expected)));
                    let error_port_id = split_port.clone();
                    let proc = proc.clone();
                    let sender = sender.clone();

                    Box::new(move |headers: Flattrs, update: wirevalue::Any| {
                        let mut buf = buffer.lock().unwrap();
                        // The single reduced result has already been forwarded.
                        if buf.done {
                            return Err(mailbox::SerializedSendFailure::Dead {
                                data: update,
                                headers,
                            });
                        }
                        match buf.push(update) {
                            // Final result produced: forward it and report exhaustion.
                            Ok(Some(reduced)) => {
                                post(
                                    &proc,
                                    &sender,
                                    port_id.clone(),
                                    reduced,
                                    return_undeliverable,
                                );
                                Ok(mailbox::SerializedSendDisposition::DeliveredAndExhausted)
                            }
                            Ok(None) => Ok(mailbox::SerializedSendDisposition::Delivered),
                            // The reducer rejected the update; it is returned to the sender.
                            Err((data, error)) => Err(mailbox::SerializedSendFailure::Error(
                                mailbox::SerializedSendError {
                                    data,
                                    error: crate::mailbox::MailboxSenderError::new_bound(
                                        error_port_id.clone(),
                                        crate::mailbox::MailboxSenderErrorKind::Other(error),
                                    ),
                                    headers,
                                },
                            )),
                        }
                    })
                }
            },
        };
        self.mailbox().bind_untyped(
            &split_port,
            mailbox::UntypedUnboundedSender { sender: enqueue },
        );
        Ok(split_port)
    }
}
378
379struct UpdateBuffer {
380 buffered: Vec<wirevalue::Any>,
381 reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
382}
383
384impl UpdateBuffer {
385 fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>) -> Self {
386 Self {
387 buffered: Vec::new(),
388 reducer,
389 }
390 }
391
392 fn pop(&mut self) -> Option<wirevalue::Any> {
393 self.buffered.pop()
394 }
395
396 fn push(&mut self, serialized: wirevalue::Any) -> Option<anyhow::Result<wirevalue::Any>> {
399 let limit = hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_SIZE);
400
401 self.buffered.push(serialized);
402 if self.buffered.len() >= limit {
403 self.reduce()
404 } else {
405 None
406 }
407 }
408
409 fn reduce(&mut self) -> Option<anyhow::Result<wirevalue::Any>> {
410 if self.buffered.is_empty() {
411 None
412 } else {
413 match self.reducer.reduce_updates(take(&mut self.buffered)) {
414 Ok(reduced) => Some(Ok(reduced)),
415 Err((e, b)) => {
416 self.buffered = b;
417 Some(Err(e))
418 }
419 }
420 }
421 }
422}
423
424struct OnceBuffer {
425 accumulated: Option<wirevalue::Any>,
426 reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>,
427 expected: usize,
428 count: usize,
429 done: bool,
430}
431
432impl OnceBuffer {
433 fn new(reducer: Box<dyn ErasedCommReducer + Send + Sync + 'static>, expected: usize) -> Self {
434 Self {
435 accumulated: None,
436 reducer,
437 expected,
438 count: 0,
439 done: false,
440 }
441 }
442
443 fn push(
447 &mut self,
448 value: wirevalue::Any,
449 ) -> Result<Option<wirevalue::Any>, (wirevalue::Any, anyhow::Error)> {
450 self.count += 1;
451 self.accumulated = match self.accumulated.take() {
452 None => Some(value),
453 Some(acc) => match self.reducer.reduce_updates(vec![acc, value]) {
454 Ok(reduced) => Some(reduced),
455 Err((e, mut rejected)) => {
456 return Err((
457 rejected
458 .pop()
459 .unwrap_or_else(|| wirevalue::Any::serialize(&()).unwrap()),
460 e,
461 ));
462 }
463 },
464 };
465 if self.count >= self.expected {
466 self.done = true;
467 Ok(self.accumulated.take())
468 } else {
469 Ok(None)
470 }
471 }
472}