Skip to main content

monarch_hyperactor/code_sync/
auto_reload.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::sync::Arc;
10
11use anyhow::Result;
12use async_trait::async_trait;
13use hyperactor as reference;
14use hyperactor::Actor;
15use hyperactor::Context;
16use hyperactor::Endpoint as _;
17use hyperactor::Handler;
18use hyperactor::RemoteSpawn;
19use hyperactor_config::Flattrs;
20use monarch_types::SerializablePyErr;
21use pyo3::prelude::*;
22use serde::Deserialize;
23use serde::Serialize;
24use typeuri::Named;
25
26use crate::runtime::monarch_with_gil_blocking;
27
28/// Message to trigger module reloading
29#[derive(Debug, Clone, Named, Serialize, Deserialize)]
30pub struct AutoReloadMessage {
31    pub result: reference::PortRef<Result<(), String>>,
32}
33wirevalue::register_type!(AutoReloadMessage);
34
35/// Parameters for creating an AutoReloadActor
36#[derive(Debug, Clone, Named, Serialize, Deserialize)]
37pub struct AutoReloadParams {}
38wirevalue::register_type!(AutoReloadParams);
39
40/// Simple Rust Actor that wraps the Python AutoReloader class via pyo3
41#[derive(Debug)]
42#[hyperactor::export(handlers = [AutoReloadMessage])]
43#[hyperactor::spawnable]
44pub struct AutoReloadActor {
45    state: Result<(Arc<Py<PyAny>>, Py<PyAny>), SerializablePyErr>,
46}
47
48impl Actor for AutoReloadActor {}
49
50#[async_trait]
51impl RemoteSpawn for AutoReloadActor {
52    type Params = AutoReloadParams;
53
54    async fn new(Self::Params {}: Self::Params, _environment: Flattrs) -> Result<Self> {
55        AutoReloadActor::new().await
56    }
57}
58
59impl AutoReloadActor {
60    pub(crate) async fn new() -> Result<Self, anyhow::Error> {
61        Ok(Self {
62            state: tokio::task::spawn_blocking(move || {
63                monarch_with_gil_blocking(|py| {
64                    Self::create_state(py).map_err(SerializablePyErr::from_fn(py))
65                })
66            })
67            .await?,
68        })
69    }
70
71    fn create_state(py: Python) -> PyResult<(Arc<Py<PyAny>>, Py<PyAny>)> {
72        // Import the Python AutoReloader class
73        let auto_reload_module = py.import("monarch._src.actor.code_sync.auto_reload")?;
74        let auto_reloader_class = auto_reload_module.getattr("AutoReloader")?;
75
76        let reloader = auto_reloader_class.call0()?;
77
78        // Install the audit import hook: SysAuditImportHook.install(reloader.import_callback)
79        let sys_audit_import_hook_class = auto_reload_module.getattr("SysAuditImportHook")?;
80        let import_callback = reloader.getattr("import_callback")?;
81        let hook_guard = sys_audit_import_hook_class.call_method1("install", (import_callback,))?;
82
83        Ok((Arc::new(reloader.into()), hook_guard.into()))
84    }
85
86    fn reload(py: Python, py_reloader: &Py<PyAny>) -> PyResult<()> {
87        let reloader = py_reloader.bind(py);
88        let changed_modules: Vec<String> = reloader.call_method0("reload_changes")?.extract()?;
89        if !changed_modules.is_empty() {
90            eprintln!("reloaded modules: {:?}", changed_modules);
91        }
92        Ok(())
93    }
94}
95
96#[async_trait]
97impl Handler<AutoReloadMessage> for AutoReloadActor {
98    async fn handle(
99        &mut self,
100        cx: &Context<Self>,
101        AutoReloadMessage { result }: AutoReloadMessage,
102    ) -> Result<()> {
103        // Call the Python reloader's reload_changes method
104        let res = async {
105            let py_reloader: Arc<_> = self.state.as_ref().map_err(Clone::clone)?.0.clone();
106            tokio::task::spawn_blocking(move || {
107                monarch_with_gil_blocking(|py| {
108                    Self::reload(py, py_reloader.as_ref()).map_err(SerializablePyErr::from_fn(py))
109                })
110            })
111            .await??;
112            anyhow::Ok(())
113        }
114        .await;
115        result.post(cx, res.map_err(|e| format!("{:#?}", e)));
116        Ok(())
117    }
118}