// monarch_hyperactor/local_state_broker.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10
11use async_trait::async_trait;
12use hyperactor::Actor;
13use hyperactor::ActorHandle;
14use hyperactor::Context;
15use hyperactor::Handler;
16use hyperactor::OncePortHandle;
17use hyperactor::reference;
18use pyo3::prelude::*;
19
20#[derive(Debug)]
21pub struct LocalState {
22    pub response_port: OncePortHandle<Result<Py<PyAny>, Py<PyAny>>>,
23    pub state: Vec<Py<PyAny>>,
24}
25
26#[derive(Debug)]
27pub enum LocalStateBrokerMessage {
28    Set(usize, LocalState),
29    Get(usize, OncePortHandle<LocalState>),
30}
31
32#[derive(Debug, Default)]
33#[hyperactor::export(spawn = true)]
34pub struct LocalStateBrokerActor {
35    states: HashMap<usize, LocalState>,
36    ports: HashMap<usize, OncePortHandle<LocalState>>,
37}
38
39impl Actor for LocalStateBrokerActor {}
40
41#[async_trait]
42impl Handler<LocalStateBrokerMessage> for LocalStateBrokerActor {
43    async fn handle(
44        &mut self,
45        cx: &Context<Self>,
46        message: LocalStateBrokerMessage,
47    ) -> anyhow::Result<()> {
48        match message {
49            LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) {
50                Some((_, port)) => {
51                    port.send(cx, state)?;
52                }
53                None => {
54                    self.states.insert(id, state);
55                }
56            },
57            LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) {
58                Some((_, state)) => {
59                    port.send(cx, state)?;
60                }
61                None => {
62                    self.ports.insert(id, port);
63                }
64            },
65        }
66        Ok(())
67    }
68}
69
70#[derive(Debug, Clone)]
71pub struct BrokerId(String, usize);
72
73impl BrokerId {
74    pub fn new(broker_id: (String, usize)) -> Self {
75        BrokerId(broker_id.0, broker_id.1)
76    }
77
78    /// Resolve the broker with exponential backoff retry.
79    /// Broker creation can race with messages that will use the broker,
80    /// so we retry with exponential backoff before panicking.
81    /// A better solution would be to figure out some way to get the real broker reference threaded to the client,  but
82    /// that is more difficult to figure out right now.
83    pub async fn resolve<A: Actor>(
84        self,
85        cx: &Context<'_, A>,
86    ) -> ActorHandle<LocalStateBrokerActor> {
87        use std::time::Duration;
88
89        let broker_name = format!("{:?}", self);
90        let actor_id = reference::ActorId::new(cx.proc().proc_id().clone(), self.0.clone(), self.1);
91        let actor_ref: reference::ActorRef<LocalStateBrokerActor> =
92            reference::ActorRef::attest(actor_id);
93
94        let mut delay_ms = 1;
95        loop {
96            if let Some(handle) = actor_ref.downcast_handle(cx) {
97                return handle;
98            }
99
100            if delay_ms > 8192 {
101                panic!("Failed to resolve broker {} after retries", broker_name);
102            }
103
104            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
105            delay_ms *= 2;
106        }
107    }
108}