monarch_hyperactor/code_sync/
manager.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::path::PathBuf;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::ensure;
16use async_once_cell::OnceCell;
17use async_trait::async_trait;
18use futures::FutureExt;
19use futures::StreamExt;
20use futures::TryFutureExt;
21use futures::TryStreamExt;
22use futures::try_join;
23use hyperactor::Actor;
24use hyperactor::ActorHandle;
25use hyperactor::Bind;
26use hyperactor::Context;
27use hyperactor::Handler;
28use hyperactor::RemoteSpawn;
29use hyperactor::Unbind;
30use hyperactor::context;
31use hyperactor::handle;
32use hyperactor::reference;
33use hyperactor_config::Flattrs;
34use hyperactor_mesh::connect::Connect;
35use hyperactor_mesh::connect::accept;
36use lazy_errors::ErrorStash;
37use lazy_errors::TryCollectOrStash;
38use monarch_conda::sync::sender;
39use ndslice::Shape;
40use ndslice::ShapeError;
41use ndslice::view::Ranked;
42use ndslice::view::RankedSliceable;
43use ndslice::view::ViewExt;
44use serde::Deserialize;
45use serde::Serialize;
46use tokio::io::AsyncReadExt;
47use tokio::io::AsyncWriteExt;
48use tokio::net::TcpListener;
49use tokio::net::TcpStream;
50use typeuri::Named;
51
52use crate::code_sync::WorkspaceLocation;
53use crate::code_sync::auto_reload::AutoReloadActor;
54use crate::code_sync::auto_reload::AutoReloadMessage;
55use crate::code_sync::conda_sync::CondaSyncActor;
56use crate::code_sync::conda_sync::CondaSyncMessage;
57use crate::code_sync::conda_sync::CondaSyncResult;
58use crate::code_sync::rsync::RsyncActor;
59use crate::code_sync::rsync::RsyncDaemon;
60use crate::code_sync::rsync::RsyncMessage;
61use crate::code_sync::rsync::RsyncResult;
62
/// Transport-specific parameters for a single code-sync operation, carried
/// inside [`CodeSyncMessage::Sync`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub enum Method {
    /// Sync files via rsync. `connect` is the port the remote actor uses to
    /// establish the rsync data connection back to the client.
    Rsync {
        connect: reference::PortRef<Connect>,
    },
    /// Sync a conda environment. In addition to the connection port, carries
    /// path prefixes to rewrite so the environment works at its destination.
    CondaSync {
        connect: reference::PortRef<Connect>,
        /// Maps a path prefix found in the env to the workspace location that
        /// should replace it on the remote side.
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
73
/// Describe the shape of the workspace.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceShape {
    /// All actors accessing the workspace.
    pub shape: Shape,
    /// Starting dimension in the shape denoting all ranks that share the same workspace.
    /// `None` means no dimension is shared: every rank owns its own workspace.
    pub dimension: Option<String>,
}
82
83impl WorkspaceShape {
84    /// Reduce the shape to contain only the "owners" of the remote workspace
85    ///
86    /// This is relevant when e.g. multiple worker on the same host share a workspace, in which case,
87    /// we'll reduce the share to only contain one worker per workspace, so that we don't have multiple
88    /// workers trying to sync to the same workspace at the same time.
89    pub fn owners(&self) -> Result<Shape, ShapeError> {
90        let mut new_shape = self.shape.clone();
91        for label in self
92            .shape
93            .labels()
94            .iter()
95            .skip_while(|l| Some(*l) != self.dimension.as_ref())
96        {
97            new_shape = new_shape.select(label, 0)?;
98            //new_shape = new_shape.slice(label, 0..1)?;
99        }
100        Ok(new_shape)
101    }
102
103    /// Return a new shape that contains all ranks that share the same workspace with the given "owning" rank.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the given rank's coordinates aren't all zero starting at the specified dimension
108    /// and continuing until the end. For example, if the dimension is Some("host"), then the coordinates
109    /// for the rank must be 0 in the host dimension and all subsequent dimensions.
110    pub fn downstream(&self, rank: usize) -> Result<Shape> {
111        let coords = self.shape.coordinates(rank)?;
112
113        for (label, value) in coords
114            .iter()
115            .skip_while(|(l, _)| Some(l) != self.dimension.as_ref())
116        {
117            ensure!(
118                *value == 0,
119                "Coordinate for dimension '{}' must be 0 for rank {}",
120                label,
121                rank
122            );
123        }
124
125        Ok(self.shape.index(
126            coords
127                .into_iter()
128                .take_while(|(l, _)| Some(l) != self.dimension.as_ref())
129                .collect::<Vec<_>>(),
130        )?)
131    }
132
133    fn downstream_mesh(
134        &self,
135        mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
136        rank: usize,
137    ) -> Result<hyperactor_mesh::ActorMeshRef<CodeSyncManager>> {
138        let shape = self.downstream(rank)?;
139        Ok(mesh.sliced(shape.region()))
140    }
141}
142
/// Pairing of a workspace's remote location with the shape describing which
/// ranks access (and share) it.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceConfig {
    /// Where the workspace lives on the remote side.
    pub location: WorkspaceLocation,
    /// Which ranks access the workspace, and along which dimension it is shared.
    pub shape: WorkspaceShape,
}
148
/// Messages handled by [`CodeSyncManager`].
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum CodeSyncMessage {
    /// Sync code into `workspace` using `method`, optionally triggering a
    /// hot-reload on the ranks that share the workspace afterwards.
    Sync {
        workspace: WorkspaceLocation,
        /// The method to use for syncing.
        method: Method,
        /// Whether to hot-reload code after syncing. When `Some`, describes
        /// which ranks share the synced workspace and should reload.
        reload: Option<WorkspaceShape>,
        /// A port to send back the result of the sync operation.
        result: reference::PortRef<Result<(), String>>,
    },
    /// Trigger a module hot-reload on this actor.
    Reload {
        /// Rank of the actor that initiated the reload; an actor whose own
        /// rank matches skips the reload and sends no result.
        sender_rank: Option<usize>,
        /// A port to send back the result of the reload operation.
        result: reference::PortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(CodeSyncMessage);
166
/// One-time bootstrap message giving each [`CodeSyncManager`] a reference to
/// the actor mesh it belongs to, from which it derives its own rank.
#[derive(Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub struct SetActorMeshMessage {
    /// The mesh containing the receiving actor.
    pub actor_mesh: hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
}
wirevalue::register_type!(SetActorMeshMessage);
172
/// Spawn parameters for [`CodeSyncManager`]; currently empty.
#[derive(Debug, Named, Serialize, Deserialize)]
pub struct CodeSyncManagerParams {}
wirevalue::register_type!(CodeSyncManagerParams);
176
/// Actor that coordinates code synchronization (rsync or conda sync) and
/// optional module hot-reload across a mesh of peers.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CodeSyncMessage { cast = true },
        SetActorMeshMessage { cast = true }
    ],
)]
pub struct CodeSyncManager {
    // Child actors, spawned lazily on first use via `get_or_try_init`.
    rsync: OnceCell<ActorHandle<RsyncActor>>,
    auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
    conda_sync: OnceCell<ActorHandle<CondaSyncActor>>,
    // Set at most once, via `SetActorMeshMessage`: the mesh this actor is a
    // member of, and this actor's rank within it.
    self_mesh: once_cell::sync::OnceCell<hyperactor_mesh::ActorMeshRef<CodeSyncManager>>,
    rank: once_cell::sync::OnceCell<usize>,
}
192
193impl Actor for CodeSyncManager {}
194
#[async_trait]
impl RemoteSpawn for CodeSyncManager {
    type Params = CodeSyncManagerParams;

    /// Construct a fresh manager with all lazily-initialized state unset.
    async fn new(CodeSyncManagerParams {}: Self::Params, _environment: Flattrs) -> Result<Self> {
        Ok(Self {
            rsync: OnceCell::new(),
            auto_reload: OnceCell::new(),
            conda_sync: OnceCell::new(),
            self_mesh: once_cell::sync::OnceCell::new(),
            rank: once_cell::sync::OnceCell::new(),
        })
    }
}
209
impl CodeSyncManager {
    /// Get the child `RsyncActor`, spawning it on first use.
    async fn get_rsync_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<RsyncActor>> {
        self.rsync
            .get_or_try_init(async move { RsyncActor::default().spawn(cx) })
            .await
    }

    /// Get the child `AutoReloadActor`, spawning it on first use.
    async fn get_auto_reload_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<AutoReloadActor>> {
        self.auto_reload
            .get_or_try_init(async move { AutoReloadActor::new().await?.spawn(cx) })
            .await
    }

    /// Get the child `CondaSyncActor`, spawning it on first use.
    async fn get_conda_sync_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<CondaSyncActor>> {
        self.conda_sync
            .get_or_try_init(async move { CondaSyncActor::default().spawn(cx) })
            .await
    }
}
238
#[async_trait]
#[handle(CodeSyncMessage)]
impl CodeSyncMessageHandler for CodeSyncManager {
    /// Run a code sync using `method` into `workspace`, then (optionally) cast a
    /// reload to all ranks sharing the workspace. Any error is rendered to a
    /// `String` and reported via `result` rather than failing the handler.
    async fn sync(
        &mut self,
        cx: &Context<Self>,
        workspace: WorkspaceLocation,
        method: Method,
        reload: Option<WorkspaceShape>,
        result: reference::PortRef<Result<(), String>>,
    ) -> Result<()> {
        let res = async move {
            match method {
                Method::Rsync { connect } => {
                    // Forward rsync connection port to the RsyncActor, which will do the actual
                    // connection and run the client.
                    let (tx, mut rx) = cx.open_port::<Result<RsyncResult, String>>();
                    self.get_rsync_actor(cx).await?.send(
                        cx,
                        RsyncMessage {
                            connect,
                            result: tx.bind(),
                            workspace,
                        },
                    )?;
                    // Observe any errors.
                    let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
                }
                Method::CondaSync {
                    connect,
                    path_prefix_replacements,
                } => {
                    // Forward the connection port to the CondaSyncActor, which will do the
                    // actual connection and run the sync.
                    let (tx, mut rx) = cx.open_port::<Result<CondaSyncResult, String>>();
                    self.get_conda_sync_actor(cx).await?.send(
                        cx,
                        CondaSyncMessage {
                            connect,
                            result: tx.bind(),
                            workspace,
                            path_prefix_replacements,
                        },
                    )?;
                    // Observe any errors.
                    let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
                }
            }

            // Trigger hot reload on all ranks that use/share this workspace.
            if let Some(workspace_shape) = reload {
                let (tx, rx) = cx.open_port::<Result<(), String>>();
                let tx = tx.bind();
                // Both `rank` and `self_mesh` are populated by `SetActorMeshMessage`;
                // reload cannot proceed until that bootstrap message has arrived.
                let rank = self
                    .rank
                    .get()
                    .ok_or_else(|| anyhow::anyhow!("missing rank"))?;
                let mesh = self
                    .self_mesh
                    .get()
                    .ok_or_else(|| anyhow::anyhow!("missing self mesh"))?;
                let mesh = workspace_shape.downstream_mesh(mesh, *rank)?;
                // Cast to every rank sharing the workspace (including self; the
                // `sender_rank` check in `reload` makes self a no-op).
                mesh.cast(
                    cx,
                    CodeSyncMessage::Reload {
                        sender_rank: Some(*rank),
                        result: tx.clone(),
                    },
                )?;
                // Exclude self from the sync.
                let len = Ranked::region(&mesh).num_ranks() - 1;
                let _: ((), Vec<()>) = try_join!(
                    // NOTE(review): this passes our *own* rank as `sender_rank`, so
                    // `reload`'s self-skip triggers and no local reload (or result
                    // send) actually happens — the `len` accounting above relies on
                    // that no-send. If a local reload was intended here, `sender_rank`
                    // would need to be `None` — TODO confirm intended behavior.
                    self.reload(cx, self.rank.get().cloned(), tx),
                    // Collect results from the other `len` ranks, failing fast on
                    // the first reported error.
                    rx.take(len)
                        .map(|res| res?.map_err(anyhow::Error::msg))
                        .try_collect(),
                )?;
            }

            anyhow::Ok(())
        }
        .await;
        // Render any failure to a string (with context identifying this actor)
        // and report it back over the result port.
        result.send(
            cx,
            res.map_err(|e| {
                format!(
                    "{:#?}",
                    Err::<(), _>(e)
                        .with_context(|| format!("code sync from {}", cx.self_id()))
                        .unwrap_err()
                )
            }),
        )?;
        Ok(())
    }

    /// Run a module hot-reload via the `AutoReloadActor`, reporting the outcome
    /// over `result`. If `sender_rank` matches this actor's own rank, the reload
    /// is skipped entirely and *no* result is sent (the sender accounts for this).
    async fn reload(
        &mut self,
        cx: &Context<Self>,
        sender_rank: Option<usize>,
        result: reference::PortRef<Result<(), String>>,
    ) -> Result<()> {
        if self
            .rank
            .get()
            .is_some_and(|rank| sender_rank.is_some_and(|sender_rank| *rank == sender_rank))
        {
            return Ok(());
        }
        let res = async move {
            let (tx, mut rx) = cx.open_port();
            self.get_auto_reload_actor(cx)
                .await?
                .send(cx, AutoReloadMessage { result: tx.bind() })?;
            rx.recv().await?.map_err(anyhow::Error::msg)?;
            anyhow::Ok(())
        }
        .await;
        // Render any failure to a string (with context) for the caller.
        result.send(
            cx,
            res.map_err(|e| {
                format!(
                    "{:#?}",
                    Err::<(), _>(e)
                        .with_context(|| format!("module reload from {}", cx.self_id()))
                        .unwrap_err()
                )
            }),
        )?;
        Ok(())
    }
}
372
373#[async_trait]
374impl Handler<SetActorMeshMessage> for CodeSyncManager {
375    async fn handle(&mut self, cx: &Context<Self>, msg: SetActorMeshMessage) -> Result<()> {
376        let mesh = self.self_mesh.get_or_init(|| msg.actor_mesh);
377        self.rank.get_or_init(|| {
378            mesh.iter()
379                .find(|(_, actor)| actor.actor_id() == cx.self_id())
380                .unwrap()
381                .0
382                .rank()
383        });
384        Ok(())
385    }
386}
387
/// Caller-facing selection of the sync transport, used by [`code_sync_mesh`] to
/// construct the corresponding wire-level [`Method`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CodeSyncMethod {
    Rsync,
    CondaSync {
        /// Path prefixes to rewrite when materializing the conda env remotely.
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
395
/// Drive a code sync from `local_workspace` to `remote_workspace` across an
/// actor mesh of [`CodeSyncManager`]s.
///
/// Only workspace "owner" ranks participate directly (see
/// [`WorkspaceShape::owners`]); if `auto_reload` is set, each owner then fans a
/// reload out to the ranks sharing its workspace. Runs the transport-serving
/// future and the sync-cast/result-collection future concurrently, and returns
/// an error aggregating all remote failures.
pub async fn code_sync_mesh(
    cx: &impl context::Actor,
    actor_mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
    local_workspace: PathBuf,
    remote_workspace: WorkspaceConfig,
    method: CodeSyncMethod,
    auto_reload: bool,
) -> Result<()> {
    let instance = cx.instance();

    // Create a slice of the actor mesh that only includes workspace "owners" (e.g. on multi-GPU hosts,
    // only one of the ranks on that host will participate in the code sync).
    let owner_shape = remote_workspace.shape.owners()?;
    let actor_mesh = actor_mesh.sliced(owner_shape.region());
    let num_ranks = Ranked::region(&actor_mesh).num_ranks();

    // Build the wire-level `Method` plus a future that serves the chosen
    // transport for exactly `num_ranks` incoming connections.
    let (method, method_fut) = match method {
        CodeSyncMethod::Rsync => {
            // Spawn a rsync daemon to accept incoming connections from actors.
            // some machines (e.g. github CI) do not have ipv6, so try ipv6 then fallback to ipv4
            let ipv6_lo: SocketAddr = "[::1]:0".parse()?;
            let ipv4_lo: SocketAddr = "127.0.0.1:0".parse()?;
            let addrs: [SocketAddr; 2] = [ipv6_lo, ipv4_lo];
            let daemon =
                RsyncDaemon::spawn(TcpListener::bind(&addrs[..]).await?, &local_workspace).await?;

            let daemon_addr = daemon.addr().clone();
            let (rsync_conns_tx, rsync_conns_rx) = instance.open_port::<Connect>();
            (
                Method::Rsync {
                    connect: rsync_conns_tx.bind(),
                },
                // This async task will process rsync connection attempts concurrently, forwarding
                // them to the rsync daemon above.
                async move {
                    let res = rsync_conns_rx
                        .take(num_ranks)
                        .err_into::<anyhow::Error>()
                        .try_for_each_concurrent(None, |connect| async move {
                            // Bridge each accepted actor connection to the local
                            // rsync daemon, proxying bytes in both directions.
                            let (mut local, mut stream) = try_join!(
                                TcpStream::connect(daemon_addr.clone()).err_into(),
                                accept(instance, instance.self_id().clone(), connect),
                            )?;
                            tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
                            Ok(())
                        })
                        .await;
                    // Shut the daemon down even if proxying failed, then surface
                    // the proxy error (if any).
                    daemon.shutdown().await?;
                    res?;
                    anyhow::Ok(())
                }
                .boxed(),
            )
        }
        CodeSyncMethod::CondaSync {
            path_prefix_replacements,
        } => {
            let (conns_tx, conns_rx) = instance.open_port::<Connect>();
            (
                Method::CondaSync {
                    connect: conns_tx.bind(),
                    path_prefix_replacements,
                },
                async move {
                    conns_rx
                        .take(num_ranks)
                        .err_into::<anyhow::Error>()
                        .try_for_each_concurrent(None, |connect| async {
                            let (mut read, mut write) =
                                accept(instance, instance.self_id().clone(), connect)
                                    .await?
                                    .into_split();
                            // Run the conda sync sender protocol over the stream.
                            let res = sender(&local_workspace, &mut read, &mut write).await;

                            // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
                            // message spam.
                            write.shutdown().await?;
                            let mut buf = vec![];
                            read.read_to_end(&mut buf).await?;

                            res
                        })
                        .await
                }
                .boxed(),
            )
        }
    };

    let ((), ()) = try_join!(
        method_fut,
        // This async task will cast the code sync message to workspace owners, and process any errors.
        async move {
            let (result_tx, result_rx) = instance.open_port::<Result<(), String>>();
            actor_mesh.cast(
                instance,
                CodeSyncMessage::Sync {
                    method,
                    workspace: remote_workspace.location.clone(),
                    reload: if auto_reload {
                        Some(remote_workspace.shape)
                    } else {
                        None
                    },
                    result: result_tx.bind(),
                },
            )?;

            // Wait for all actors to report result.
            let results = result_rx.take(num_ranks).try_collect::<Vec<_>>().await?;

            // Combine all errors into one.
            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
            results
                .into_iter()
                .map(|res| res.map_err(anyhow::Error::msg))
                .try_collect_or_stash::<()>(&mut errs);
            Ok(errs.into_result()?)
        },
    )?;

    Ok(())
}
519
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use hyperactor_mesh::context;
    use hyperactor_mesh::test_utils;
    use ndslice::shape;
    use tempfile::TempDir;
    use tokio::fs;

    use super::*;

    #[test]
    fn test_workspace_shape_owners() {
        // Create a shape with multiple dimensions
        let shape = shape! { host = 2, replica = 3 };

        // Test case 1: dimension is None (should return the original shape)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 6); // 2 hosts * 3 replicas = 6 ranks

        // Test case 2: dimension is "host" (collapses "host" and every later
        // dimension to index 0, leaving a single owner rank)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 1); // host and replica both collapsed -> 1 rank

        // Test case 3: dimension is "replica" (collapses only "replica",
        // leaving one owner rank per host)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 2); // 2 hosts, 1 owner rank per host
    }

    #[test]
    fn test_workspace_shape_downstream() -> Result<()> {
        // Create a shape with multiple dimensions
        let shape = shape! { host = 2, replica = 3 };

        // Test case 1: dimension is None (should return a shape with just the specified rank)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 1); // Just rank 0

        // Test case 2: dimension is "host" (should return a shape with all ranks on the same host)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 6); // All ranks in the shape
        // Rank 3 has host=1, which is nonzero at/after the shared dimension,
        // so it is not an owner and `downstream` must fail.
        assert!(ws_shape.downstream(3).is_err());

        // Test case 3: dimension is "replica" (should return a shape with all ranks on the same host)
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 3);
        let downstream = ws_shape.downstream(3)?;
        assert_eq!(downstream.slice().len(), 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_code_sync_manager_and_mesh() -> Result<()> {
        // Create source workspace with test files
        let source_workspace = TempDir::new()?;
        fs::write(source_workspace.path().join("test1.txt"), "content1").await?;
        fs::write(source_workspace.path().join("test2.txt"), "content2").await?;
        fs::create_dir(source_workspace.path().join("subdir")).await?;
        fs::write(source_workspace.path().join("subdir/test3.txt"), "content3").await?;

        // Create target workspace for the actors, pre-populated with content
        // that should be removed/overwritten by the sync.
        let target_workspace = TempDir::new()?;
        fs::create_dir(target_workspace.path().join("subdir5")).await?;
        fs::write(target_workspace.path().join("foo.txt"), "something").await?;

        // TODO: thread through context, or access the actual python context;
        // for now this is basically equivalent (arguably better) to using the proc mesh client.
        let cx = context().await;
        let instance = cx.actor_instance;
        // Set up actor mesh with CodeSyncManager actors
        let mut host_mesh = test_utils::local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(instance, "code_sync_test", ndslice::Extent::unity())
            .await
            .unwrap();

        // Create CodeSyncManagerParams
        let params = CodeSyncManagerParams {};

        // Spawn actor mesh with CodeSyncManager actors
        let actor_mesh = proc_mesh
            .spawn_service(&instance, "code_sync_test", &params)
            .await?;

        // Set up the mesh reference on each actor (required before any reload
        // can be triggered; see `SetActorMeshMessage`).
        actor_mesh.cast(
            &instance,
            SetActorMeshMessage {
                actor_mesh: (*actor_mesh).clone(),
            },
        )?;

        // Create workspace configuration
        let remote_workspace_config = WorkspaceConfig {
            location: WorkspaceLocation::Constant(target_workspace.path().to_path_buf()),
            shape: WorkspaceShape {
                shape: shape! { replica = 2 },
                dimension: Some("replica".to_string()),
            },
        };

        // Test code_sync_mesh function - this coordinates sync operations across the mesh
        // Test without auto-reload first
        code_sync_mesh(
            instance,
            &actor_mesh,
            source_workspace.path().to_path_buf(),
            remote_workspace_config.clone(),
            CodeSyncMethod::Rsync,
            false, // no auto-reload
        )
        .await?;

        // Verify that files were synchronized correctly
        assert!(
            !dir_diff::is_different(&source_workspace, &target_workspace)
                .map_err(|e| anyhow!("{:?}", e))?,
            "Source and target workspaces should be identical after sync"
        );

        let _ = host_mesh.shutdown(instance).await;
        Ok(())
    }
}