Skip to main content

monarch_hyperactor/
local_state_broker.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10
11use async_trait::async_trait;
12use hyperactor::Actor;
13use hyperactor::ActorHandle;
14use hyperactor::Context;
15use hyperactor::Endpoint as _;
16use hyperactor::Handler;
17use hyperactor::OncePortHandle;
18use pyo3::prelude::*;
19
20#[derive(Debug)]
21pub struct LocalState {
22    pub response_port: OncePortHandle<Result<Py<PyAny>, Py<PyAny>>>,
23    pub state: Vec<Py<PyAny>>,
24}
25
26#[derive(Debug)]
27pub enum LocalStateBrokerMessage {
28    Set(usize, LocalState),
29    Get(usize, OncePortHandle<LocalState>),
30}
31
32#[derive(Debug, Default)]
33#[hyperactor::export]
34#[hyperactor::spawnable]
35pub struct LocalStateBrokerActor {
36    states: HashMap<usize, LocalState>,
37    ports: HashMap<usize, OncePortHandle<LocalState>>,
38}
39
40impl Actor for LocalStateBrokerActor {}
41
42#[async_trait]
43impl Handler<LocalStateBrokerMessage> for LocalStateBrokerActor {
44    async fn handle(
45        &mut self,
46        cx: &Context<Self>,
47        message: LocalStateBrokerMessage,
48    ) -> anyhow::Result<()> {
49        match message {
50            LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) {
51                Some((_, port)) => {
52                    port.post(cx, state);
53                }
54                None => {
55                    self.states.insert(id, state);
56                }
57            },
58            LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) {
59                Some((_, state)) => {
60                    port.post(cx, state);
61                }
62                None => {
63                    self.ports.insert(id, port);
64                }
65            },
66        }
67        Ok(())
68    }
69}
70
/// Identifier of a local state broker, wrapping the name string that is used
/// to look the broker actor up within a proc.
#[derive(Debug, Clone)]
pub struct BrokerId(String);
73
74impl BrokerId {
75    pub fn new(broker_id: (String, usize)) -> Self {
76        BrokerId(broker_id.0)
77    }
78
79    /// Resolve the broker with exponential backoff retry.
80    /// Broker creation can race with messages that will use the broker,
81    /// so we retry with exponential backoff before panicking.
82    /// A better solution would be to figure out some way to get the real broker reference threaded to the client,  but
83    /// that is more difficult to figure out right now.
84    pub async fn resolve<A: Actor>(
85        self,
86        cx: &Context<'_, A>,
87    ) -> ActorHandle<LocalStateBrokerActor> {
88        use std::time::Duration;
89
90        let broker_name = format!("{:?}", self);
91        let actor_id = cx.proc().proc_addr().actor_addr(&self.0);
92        let actor_ref: hyperactor::ActorRef<LocalStateBrokerActor> =
93            hyperactor::ActorRef::attest(actor_id);
94
95        let mut delay_ms = 1;
96        loop {
97            if let Some(handle) = actor_ref.downcast_handle(cx) {
98                return handle;
99            }
100
101            if delay_ms > 8192 {
102                panic!("Failed to resolve broker {} after retries", broker_name);
103            }
104
105            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
106            delay_ms *= 2;
107        }
108    }
109}