monarch_hyperactor/
local_state_broker.rs1use std::collections::HashMap;
10
11use async_trait::async_trait;
12use hyperactor::Actor;
13use hyperactor::ActorHandle;
14use hyperactor::Context;
15use hyperactor::Handler;
16use hyperactor::OncePortHandle;
17use hyperactor::reference;
18use pyo3::prelude::*;
19
20#[derive(Debug)]
21pub struct LocalState {
22 pub response_port: OncePortHandle<Result<Py<PyAny>, Py<PyAny>>>,
23 pub state: Vec<Py<PyAny>>,
24}
25
26#[derive(Debug)]
27pub enum LocalStateBrokerMessage {
28 Set(usize, LocalState),
29 Get(usize, OncePortHandle<LocalState>),
30}
31
32#[derive(Debug, Default)]
33#[hyperactor::export(spawn = true)]
34pub struct LocalStateBrokerActor {
35 states: HashMap<usize, LocalState>,
36 ports: HashMap<usize, OncePortHandle<LocalState>>,
37}
38
39impl Actor for LocalStateBrokerActor {}
40
41#[async_trait]
42impl Handler<LocalStateBrokerMessage> for LocalStateBrokerActor {
43 async fn handle(
44 &mut self,
45 cx: &Context<Self>,
46 message: LocalStateBrokerMessage,
47 ) -> anyhow::Result<()> {
48 match message {
49 LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) {
50 Some((_, port)) => {
51 port.send(cx, state)?;
52 }
53 None => {
54 self.states.insert(id, state);
55 }
56 },
57 LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) {
58 Some((_, state)) => {
59 port.send(cx, state)?;
60 }
61 None => {
62 self.ports.insert(id, port);
63 }
64 },
65 }
66 Ok(())
67 }
68}
69
70#[derive(Debug, Clone)]
71pub struct BrokerId(String, usize);
72
73impl BrokerId {
74 pub fn new(broker_id: (String, usize)) -> Self {
75 BrokerId(broker_id.0, broker_id.1)
76 }
77
78 pub async fn resolve<A: Actor>(
84 self,
85 cx: &Context<'_, A>,
86 ) -> ActorHandle<LocalStateBrokerActor> {
87 use std::time::Duration;
88
89 let broker_name = format!("{:?}", self);
90 let actor_id = reference::ActorId::new(cx.proc().proc_id().clone(), self.0.clone(), self.1);
91 let actor_ref: reference::ActorRef<LocalStateBrokerActor> =
92 reference::ActorRef::attest(actor_id);
93
94 let mut delay_ms = 1;
95 loop {
96 if let Some(handle) = actor_ref.downcast_handle(cx) {
97 return handle;
98 }
99
100 if delay_ms > 8192 {
101 panic!("Failed to resolve broker {} after retries", broker_name);
102 }
103
104 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
105 delay_ms *= 2;
106 }
107 }
108}