monarch_hyperactor/
local_state_broker.rs1use std::collections::HashMap;
10
11use async_trait::async_trait;
12use hyperactor::Actor;
13use hyperactor::ActorHandle;
14use hyperactor::Context;
15use hyperactor::Endpoint as _;
16use hyperactor::Handler;
17use hyperactor::OncePortHandle;
18use pyo3::prelude::*;
19
20#[derive(Debug)]
21pub struct LocalState {
22 pub response_port: OncePortHandle<Result<Py<PyAny>, Py<PyAny>>>,
23 pub state: Vec<Py<PyAny>>,
24}
25
26#[derive(Debug)]
27pub enum LocalStateBrokerMessage {
28 Set(usize, LocalState),
29 Get(usize, OncePortHandle<LocalState>),
30}
31
32#[derive(Debug, Default)]
33#[hyperactor::export]
34#[hyperactor::spawnable]
35pub struct LocalStateBrokerActor {
36 states: HashMap<usize, LocalState>,
37 ports: HashMap<usize, OncePortHandle<LocalState>>,
38}
39
40impl Actor for LocalStateBrokerActor {}
41
42#[async_trait]
43impl Handler<LocalStateBrokerMessage> for LocalStateBrokerActor {
44 async fn handle(
45 &mut self,
46 cx: &Context<Self>,
47 message: LocalStateBrokerMessage,
48 ) -> anyhow::Result<()> {
49 match message {
50 LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) {
51 Some((_, port)) => {
52 port.post(cx, state);
53 }
54 None => {
55 self.states.insert(id, state);
56 }
57 },
58 LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) {
59 Some((_, state)) => {
60 port.post(cx, state);
61 }
62 None => {
63 self.ports.insert(id, port);
64 }
65 },
66 }
67 Ok(())
68 }
69}
70
/// Name of a [`LocalStateBrokerActor`] within a proc. It is turned into a local
/// actor handle by `BrokerId::resolve`.
#[derive(Debug, Clone)]
pub struct BrokerId(String);
73
74impl BrokerId {
75 pub fn new(broker_id: (String, usize)) -> Self {
76 BrokerId(broker_id.0)
77 }
78
79 pub async fn resolve<A: Actor>(
85 self,
86 cx: &Context<'_, A>,
87 ) -> ActorHandle<LocalStateBrokerActor> {
88 use std::time::Duration;
89
90 let broker_name = format!("{:?}", self);
91 let actor_id = cx.proc().proc_addr().actor_addr(&self.0);
92 let actor_ref: hyperactor::ActorRef<LocalStateBrokerActor> =
93 hyperactor::ActorRef::attest(actor_id);
94
95 let mut delay_ms = 1;
96 loop {
97 if let Some(handle) = actor_ref.downcast_handle(cx) {
98 return handle;
99 }
100
101 if delay_ms > 8192 {
102 panic!("Failed to resolve broker {} after retries", broker_name);
103 }
104
105 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
106 delay_ms *= 2;
107 }
108 }
109}