Skip to main content

monarch_hyperactor/code_sync/
manager.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::path::PathBuf;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::ensure;
16use async_once_cell::OnceCell;
17use async_trait::async_trait;
18use futures::FutureExt;
19use futures::StreamExt;
20use futures::TryFutureExt;
21use futures::TryStreamExt;
22use futures::try_join;
23use hyperactor as reference;
24use hyperactor::Actor;
25use hyperactor::ActorHandle;
26use hyperactor::Bind;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::RemoteSpawn;
31use hyperactor::Unbind;
32use hyperactor::context;
33use hyperactor::handle;
34use hyperactor_config::Flattrs;
35use hyperactor_mesh::connect::Connect;
36use hyperactor_mesh::connect::accept;
37use lazy_errors::ErrorStash;
38use lazy_errors::TryCollectOrStash;
39use monarch_conda::sync::sender;
40use ndslice::Shape;
41use ndslice::ShapeError;
42use ndslice::view::Ranked;
43use ndslice::view::RankedSliceable;
44use ndslice::view::ViewExt;
45use serde::Deserialize;
46use serde::Serialize;
47use tokio::io::AsyncReadExt;
48use tokio::io::AsyncWriteExt;
49use tokio::net::TcpListener;
50use tokio::net::TcpStream;
51use typeuri::Named;
52
53use crate::code_sync::WorkspaceLocation;
54use crate::code_sync::auto_reload::AutoReloadActor;
55use crate::code_sync::auto_reload::AutoReloadMessage;
56use crate::code_sync::conda_sync::CondaSyncActor;
57use crate::code_sync::conda_sync::CondaSyncMessage;
58use crate::code_sync::conda_sync::CondaSyncResult;
59use crate::code_sync::rsync::RsyncActor;
60use crate::code_sync::rsync::RsyncDaemon;
61use crate::code_sync::rsync::RsyncMessage;
62use crate::code_sync::rsync::RsyncResult;
63
64#[derive(Clone, Serialize, Deserialize, Debug)]
65pub enum Method {
66    Rsync {
67        connect: reference::PortRef<Connect>,
68    },
69    CondaSync {
70        connect: reference::PortRef<Connect>,
71        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
72    },
73}
74
75/// Describe the shape of the workspace.
76#[derive(Clone, Serialize, Deserialize, Debug)]
77pub struct WorkspaceShape {
78    /// All actors accessing the workspace.
79    pub shape: Shape,
80    /// Starting dimension in the shape denoting all ranks that share the same workspace.
81    pub dimension: Option<String>,
82}
83
84impl WorkspaceShape {
85    /// Reduce the shape to contain only the "owners" of the remote workspace
86    ///
87    /// This is relevant when e.g. multiple worker on the same host share a workspace, in which case,
88    /// we'll reduce the share to only contain one worker per workspace, so that we don't have multiple
89    /// workers trying to sync to the same workspace at the same time.
90    pub fn owners(&self) -> Result<Shape, ShapeError> {
91        let mut new_shape = self.shape.clone();
92        for label in self
93            .shape
94            .labels()
95            .iter()
96            .skip_while(|l| Some(*l) != self.dimension.as_ref())
97        {
98            new_shape = new_shape.select(label, 0)?;
99            //new_shape = new_shape.slice(label, 0..1)?;
100        }
101        Ok(new_shape)
102    }
103
104    /// Return a new shape that contains all ranks that share the same workspace with the given "owning" rank.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the given rank's coordinates aren't all zero starting at the specified dimension
109    /// and continuing until the end. For example, if the dimension is Some("host"), then the coordinates
110    /// for the rank must be 0 in the host dimension and all subsequent dimensions.
111    pub fn downstream(&self, rank: usize) -> Result<Shape> {
112        let coords = self.shape.coordinates(rank)?;
113
114        for (label, value) in coords
115            .iter()
116            .skip_while(|(l, _)| Some(l) != self.dimension.as_ref())
117        {
118            ensure!(
119                *value == 0,
120                "Coordinate for dimension '{}' must be 0 for rank {}",
121                label,
122                rank
123            );
124        }
125
126        Ok(self.shape.index(
127            coords
128                .into_iter()
129                .take_while(|(l, _)| Some(l) != self.dimension.as_ref())
130                .collect::<Vec<_>>(),
131        )?)
132    }
133
134    fn downstream_mesh(
135        &self,
136        mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
137        rank: usize,
138    ) -> Result<hyperactor_mesh::ActorMeshRef<CodeSyncManager>> {
139        let shape = self.downstream(rank)?;
140        Ok(mesh.sliced(shape.region()))
141    }
142}
143
144#[derive(Clone, Serialize, Deserialize, Debug)]
145pub struct WorkspaceConfig {
146    pub location: WorkspaceLocation,
147    pub shape: WorkspaceShape,
148}
149
150#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
151#[expect(
152    clippy::large_enum_variant,
153    reason = "actor message enum with Handler/Bind/Unbind derives; boxing fields ripples into handler call sites and may require derive-macro changes — separate diff"
154)]
155pub enum CodeSyncMessage {
156    Sync {
157        workspace: WorkspaceLocation,
158        /// The method to use for syncing.
159        method: Method,
160        /// Whether to hot-reload code after syncing.
161        reload: Option<WorkspaceShape>,
162        /// A port to send back the result of the sync operation.
163        result: reference::PortRef<Result<(), String>>,
164    },
165    Reload {
166        sender_rank: Option<usize>,
167        result: reference::PortRef<Result<(), String>>,
168    },
169}
170wirevalue::register_type!(CodeSyncMessage);
171
172#[derive(Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
173pub struct SetActorMeshMessage {
174    pub actor_mesh: hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
175}
176wirevalue::register_type!(SetActorMeshMessage);
177
178#[derive(Debug, Named, Serialize, Deserialize)]
179pub struct CodeSyncManagerParams {}
180wirevalue::register_type!(CodeSyncManagerParams);
181
182#[derive(Debug)]
183#[hyperactor::export(
184    handlers = [
185        CodeSyncMessage { cast = true },
186        SetActorMeshMessage { cast = true }
187    ],
188)]
189#[hyperactor::spawnable]
190pub struct CodeSyncManager {
191    rsync: OnceCell<ActorHandle<RsyncActor>>,
192    auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
193    conda_sync: OnceCell<ActorHandle<CondaSyncActor>>,
194    self_mesh: once_cell::sync::OnceCell<hyperactor_mesh::ActorMeshRef<CodeSyncManager>>,
195    rank: once_cell::sync::OnceCell<usize>,
196}
197
198impl Actor for CodeSyncManager {}
199
200#[async_trait]
201impl RemoteSpawn for CodeSyncManager {
202    type Params = CodeSyncManagerParams;
203
204    async fn new(CodeSyncManagerParams {}: Self::Params, _environment: Flattrs) -> Result<Self> {
205        Ok(Self {
206            rsync: OnceCell::new(),
207            auto_reload: OnceCell::new(),
208            conda_sync: OnceCell::new(),
209            self_mesh: once_cell::sync::OnceCell::new(),
210            rank: once_cell::sync::OnceCell::new(),
211        })
212    }
213}
214
215impl CodeSyncManager {
216    async fn get_rsync_actor<'a>(
217        &'a mut self,
218        cx: &Context<'a, Self>,
219    ) -> Result<&'a ActorHandle<RsyncActor>> {
220        self.rsync
221            .get_or_try_init(async move { RsyncActor::default().spawn(cx) })
222            .await
223    }
224
225    async fn get_auto_reload_actor<'a>(
226        &'a mut self,
227        cx: &Context<'a, Self>,
228    ) -> Result<&'a ActorHandle<AutoReloadActor>> {
229        self.auto_reload
230            .get_or_try_init(async move { AutoReloadActor::new().await?.spawn(cx) })
231            .await
232    }
233
234    async fn get_conda_sync_actor<'a>(
235        &'a mut self,
236        cx: &Context<'a, Self>,
237    ) -> Result<&'a ActorHandle<CondaSyncActor>> {
238        self.conda_sync
239            .get_or_try_init(async move { CondaSyncActor::default().spawn(cx) })
240            .await
241    }
242}
243
244#[async_trait]
245#[handle(CodeSyncMessage)]
246impl CodeSyncMessageHandler for CodeSyncManager {
247    async fn sync(
248        &mut self,
249        cx: &Context<Self>,
250        workspace: WorkspaceLocation,
251        method: Method,
252        reload: Option<WorkspaceShape>,
253        result: reference::PortRef<Result<(), String>>,
254    ) -> Result<()> {
255        let res = async move {
256            match method {
257                Method::Rsync { connect } => {
258                    // Forward rsync connection port to the RsyncActor, which will do the actual
259                    // connection and run the client.
260                    let (tx, mut rx) = cx.open_port::<Result<RsyncResult, String>>();
261                    self.get_rsync_actor(cx).await?.post(
262                        cx,
263                        RsyncMessage {
264                            connect,
265                            result: tx.bind(),
266                            workspace,
267                        },
268                    );
269                    // Observe any errors.
270                    let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
271                }
272                Method::CondaSync {
273                    connect,
274                    path_prefix_replacements,
275                } => {
276                    // Forward rsync connection port to the RsyncActor, which will do the actual
277                    // connection and run the client.
278                    let (tx, mut rx) = cx.open_port::<Result<CondaSyncResult, String>>();
279                    self.get_conda_sync_actor(cx).await?.post(
280                        cx,
281                        CondaSyncMessage {
282                            connect,
283                            result: tx.bind(),
284                            workspace,
285                            path_prefix_replacements,
286                        },
287                    );
288                    // Observe any errors.
289                    let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
290                }
291            }
292
293            // Trigger hot reload on all ranks that use/share this workspace.
294            if let Some(workspace_shape) = reload {
295                let (tx, rx) = cx.open_port::<Result<(), String>>();
296                let tx = tx.bind();
297                let rank = self
298                    .rank
299                    .get()
300                    .ok_or_else(|| anyhow::anyhow!("missing rank"))?;
301                let mesh = self
302                    .self_mesh
303                    .get()
304                    .ok_or_else(|| anyhow::anyhow!("missing self mesh"))?;
305                let mesh = workspace_shape.downstream_mesh(mesh, *rank)?;
306                mesh.cast(
307                    cx,
308                    CodeSyncMessage::Reload {
309                        sender_rank: Some(*rank),
310                        result: tx.clone(),
311                    },
312                )?;
313                // Exclude self from the sync.
314                let len = Ranked::region(&mesh).num_ranks() - 1;
315                let _: ((), Vec<()>) = try_join!(
316                    // Run reload for this rank.
317                    self.reload(cx, self.rank.get().cloned(), tx),
318                    rx.take(len)
319                        .map(|res| res?.map_err(anyhow::Error::msg))
320                        .try_collect(),
321                )?;
322            }
323
324            anyhow::Ok(())
325        }
326        .await;
327        result.post(
328            cx,
329            res.map_err(|e| {
330                format!(
331                    "{:#?}",
332                    Err::<(), _>(e)
333                        .with_context(|| format!("code sync from {}", cx.self_addr()))
334                        .unwrap_err()
335                )
336            }),
337        );
338        Ok(())
339    }
340
341    async fn reload(
342        &mut self,
343        cx: &Context<Self>,
344        sender_rank: Option<usize>,
345        result: reference::PortRef<Result<(), String>>,
346    ) -> Result<()> {
347        if self
348            .rank
349            .get()
350            .is_some_and(|rank| sender_rank.is_some_and(|sender_rank| *rank == sender_rank))
351        {
352            return Ok(());
353        }
354        let res = async move {
355            let (tx, mut rx) = cx.open_port::<Result<(), String>>();
356            self.get_auto_reload_actor(cx)
357                .await?
358                .post(cx, AutoReloadMessage { result: tx.bind() });
359            rx.recv().await?.map_err(anyhow::Error::msg)?;
360            anyhow::Ok(())
361        }
362        .await;
363        result.post(
364            cx,
365            res.map_err(|e| {
366                format!(
367                    "{:#?}",
368                    Err::<(), _>(e)
369                        .with_context(|| format!("module reload from {}", cx.self_addr()))
370                        .unwrap_err()
371                )
372            }),
373        );
374        Ok(())
375    }
376}
377
378#[async_trait]
379impl Handler<SetActorMeshMessage> for CodeSyncManager {
380    async fn handle(&mut self, cx: &Context<Self>, msg: SetActorMeshMessage) -> Result<()> {
381        let mesh = self.self_mesh.get_or_init(|| msg.actor_mesh);
382        self.rank.get_or_init(|| {
383            mesh.iter()
384                .find(|(_, actor)| *actor.actor_addr() == *cx.self_addr())
385                .unwrap()
386                .0
387                .rank()
388        });
389        Ok(())
390    }
391}
392
393#[derive(Debug, Clone, Serialize, Deserialize)]
394pub enum CodeSyncMethod {
395    Rsync,
396    CondaSync {
397        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
398    },
399}
400
401pub async fn code_sync_mesh(
402    cx: &impl context::Actor,
403    actor_mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
404    local_workspace: PathBuf,
405    remote_workspace: WorkspaceConfig,
406    method: CodeSyncMethod,
407    auto_reload: bool,
408) -> Result<()> {
409    let instance = cx.instance();
410
411    // Create a slice of the actor mesh that only includes workspace "owners" (e.g. on multi-GPU hosts,
412    // only one of the ranks on that host will participate in the code sync).
413    let owner_shape = remote_workspace.shape.owners()?;
414    let actor_mesh = actor_mesh.sliced(owner_shape.region());
415    let num_ranks = Ranked::region(&actor_mesh).num_ranks();
416
417    let (method, method_fut) = match method {
418        CodeSyncMethod::Rsync => {
419            // Spawn a rsync daemon to accept incoming connections from actors.
420            // some machines (e.g. github CI) do not have ipv6, so try ipv6 then fallback to ipv4
421            let ipv6_lo: SocketAddr = "[::1]:0".parse()?;
422            let ipv4_lo: SocketAddr = "127.0.0.1:0".parse()?;
423            let addrs: [SocketAddr; 2] = [ipv6_lo, ipv4_lo];
424            let daemon =
425                RsyncDaemon::spawn(TcpListener::bind(&addrs[..]).await?, &local_workspace).await?;
426
427            let daemon_addr = *daemon.addr();
428            let (rsync_conns_tx, rsync_conns_rx) = instance.open_port::<Connect>();
429            (
430                Method::Rsync {
431                    connect: rsync_conns_tx.bind(),
432                },
433                // This async task will process rsync connection attempts concurrently, forwarding
434                // them to the rsync daemon above.
435                async move {
436                    let res = rsync_conns_rx
437                        .take(num_ranks)
438                        .err_into::<anyhow::Error>()
439                        .try_for_each_concurrent(None, |connect| async move {
440                            let (mut local, mut stream) = try_join!(
441                                TcpStream::connect(daemon_addr).err_into(),
442                                accept(instance, instance.self_addr().clone(), connect),
443                            )?;
444                            tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
445                            Ok(())
446                        })
447                        .await;
448                    daemon.shutdown().await?;
449                    res?;
450                    anyhow::Ok(())
451                }
452                .boxed(),
453            )
454        }
455        CodeSyncMethod::CondaSync {
456            path_prefix_replacements,
457        } => {
458            let (conns_tx, conns_rx) = instance.open_port::<Connect>();
459            (
460                Method::CondaSync {
461                    connect: conns_tx.bind(),
462                    path_prefix_replacements,
463                },
464                async move {
465                    conns_rx
466                        .take(num_ranks)
467                        .err_into::<anyhow::Error>()
468                        .try_for_each_concurrent(None, |connect| async {
469                            let (mut read, mut write) =
470                                accept(instance, instance.self_addr().clone(), connect)
471                                    .await?
472                                    .into_split();
473                            let res = sender(&local_workspace, &mut read, &mut write).await;
474
475                            // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
476                            // message spam.
477                            write.shutdown().await?;
478                            let mut buf = vec![];
479                            read.read_to_end(&mut buf).await?;
480
481                            res
482                        })
483                        .await
484                }
485                .boxed(),
486            )
487        }
488    };
489
490    let ((), ()) = try_join!(
491        method_fut,
492        // This async task will cast the code sync message to workspace owners, and process any errors.
493        async move {
494            let (result_tx, result_rx) = instance.open_port::<Result<(), String>>();
495            actor_mesh.cast(
496                instance,
497                CodeSyncMessage::Sync {
498                    method,
499                    workspace: remote_workspace.location.clone(),
500                    reload: if auto_reload {
501                        Some(remote_workspace.shape)
502                    } else {
503                        None
504                    },
505                    result: result_tx.bind(),
506                },
507            )?;
508
509            // Wait for all actors to report result.
510            let results = result_rx.take(num_ranks).try_collect::<Vec<_>>().await?;
511
512            // Combine all errors into one.
513            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
514            results
515                .into_iter()
516                .map(|res| res.map_err(anyhow::Error::msg))
517                .try_collect_or_stash::<()>(&mut errs);
518            Ok(errs.into_result()?)
519        },
520    )?;
521
522    Ok(())
523}
524
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use hyperactor_mesh::context;
    use hyperactor_mesh::test_utils;
    use ndslice::shape;
    use tempfile::TempDir;
    use tokio::fs;

    use super::*;

    #[test]
    fn test_workspace_shape_owners() {
        // A shape with multiple dimensions: 2 hosts * 3 replicas = 6 ranks.
        let shape = shape! { host = 2, replica = 3 };
        let owner_count = |dimension: Option<&str>| {
            let ws_shape = WorkspaceShape {
                shape: shape.clone(),
                dimension: dimension.map(str::to_string),
            };
            let owners = ws_shape.owners().unwrap();
            owners.slice().len()
        };

        // No shared dimension: the original shape is returned, so all 6 ranks are owners.
        assert_eq!(owner_count(None), 6);

        // Sharing starts at "host": both host and replica are pinned to index 0, leaving a
        // single owner for the whole shape.
        assert_eq!(owner_count(Some("host")), 1);

        // Sharing starts at "replica": only replica is pinned to index 0, leaving one owner
        // per host.
        assert_eq!(owner_count(Some("replica")), 2);
    }

    #[test]
    fn test_workspace_shape_downstream() -> Result<()> {
        // A shape with multiple dimensions: 2 hosts * 3 replicas = 6 ranks.
        let shape = shape! { host = 2, replica = 3 };
        let with_dimension = |dimension: Option<&str>| WorkspaceShape {
            shape: shape.clone(),
            dimension: dimension.map(str::to_string),
        };

        // No shared dimension: the downstream shape is just the given rank.
        let ws_shape = with_dimension(None);
        assert_eq!(ws_shape.downstream(0)?.slice().len(), 1);

        // Sharing starts at "host": rank 0 owns every rank in the shape, and rank 3 is
        // rejected because it is not an owner.
        let ws_shape = with_dimension(Some("host"));
        assert_eq!(ws_shape.downstream(0)?.slice().len(), 6);
        assert!(ws_shape.downstream(3).is_err());

        // Sharing starts at "replica": ranks 0 and 3 are both accepted as owners, each with
        // a downstream of 3 ranks.
        let ws_shape = with_dimension(Some("replica"));
        assert_eq!(ws_shape.downstream(0)?.slice().len(), 3);
        assert_eq!(ws_shape.downstream(3)?.slice().len(), 3);

        Ok(())
    }

    #[cfg_attr(not(target_os = "linux"), ignore = "linux-only")]
    #[tokio::test]
    async fn test_code_sync_manager_and_mesh() -> Result<()> {
        // Populate the source workspace with a few files, including a nested one.
        let src = TempDir::new()?;
        fs::write(src.path().join("test1.txt"), "content1").await?;
        fs::write(src.path().join("test2.txt"), "content2").await?;
        fs::create_dir(src.path().join("subdir")).await?;
        fs::write(src.path().join("subdir/test3.txt"), "content3").await?;

        // Pre-populate the target workspace with content that is absent from the source.
        let dst = TempDir::new()?;
        fs::create_dir(dst.path().join("subdir5")).await?;
        fs::write(dst.path().join("foo.txt"), "something").await?;

        // TODO: thread through context, or access the actual python context;
        // for now this is basically equivalent (arguably better) to using the proc mesh client.
        let cx = context().await;
        let instance = cx.actor_instance;

        // Bring up a local host mesh and a proc mesh on top of it.
        let mut host_mesh = test_utils::local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(
                instance,
                "code_sync_test",
                ndslice::Extent::unity(),
                None,
                None,
            )
            .await
            .unwrap();

        // Spawn the CodeSyncManager actors.
        let params = CodeSyncManagerParams {};
        let actor_mesh = proc_mesh
            .spawn_service(&instance, "code_sync_test", &params)
            .await?;

        // Hand every actor a reference to the mesh it belongs to.
        let set_mesh = SetActorMeshMessage {
            actor_mesh: (*actor_mesh).clone(),
        };
        actor_mesh.cast(&instance, set_mesh)?;

        // Describe the remote workspace.
        let remote_workspace_config = WorkspaceConfig {
            location: WorkspaceLocation::Constant(dst.path().to_path_buf()),
            shape: WorkspaceShape {
                shape: shape! { replica = 2 },
                dimension: Some("replica".to_string()),
            },
        };

        // Run the mesh-wide sync over rsync, without auto-reload.
        code_sync_mesh(
            instance,
            &actor_mesh,
            src.path().to_path_buf(),
            remote_workspace_config.clone(),
            CodeSyncMethod::Rsync,
            false, // no auto-reload
        )
        .await?;

        // The two workspaces must now have identical contents.
        let differs = dir_diff::is_different(&src, &dst).map_err(|e| anyhow!("{:?}", e))?;
        assert!(
            !differs,
            "Source and target workspaces should be identical after sync"
        );

        let _ = host_mesh.shutdown(instance).await;
        Ok(())
    }
}