1use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::path::PathBuf;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::ensure;
16use async_once_cell::OnceCell;
17use async_trait::async_trait;
18use futures::FutureExt;
19use futures::StreamExt;
20use futures::TryFutureExt;
21use futures::TryStreamExt;
22use futures::try_join;
23use hyperactor::Actor;
24use hyperactor::ActorHandle;
25use hyperactor::Bind;
26use hyperactor::Context;
27use hyperactor::Handler;
28use hyperactor::RemoteSpawn;
29use hyperactor::Unbind;
30use hyperactor::context;
31use hyperactor::handle;
32use hyperactor::reference;
33use hyperactor_config::Flattrs;
34use hyperactor_mesh::connect::Connect;
35use hyperactor_mesh::connect::accept;
36use lazy_errors::ErrorStash;
37use lazy_errors::TryCollectOrStash;
38use monarch_conda::sync::sender;
39use ndslice::Shape;
40use ndslice::ShapeError;
41use ndslice::view::Ranked;
42use ndslice::view::RankedSliceable;
43use ndslice::view::ViewExt;
44use serde::Deserialize;
45use serde::Serialize;
46use tokio::io::AsyncReadExt;
47use tokio::io::AsyncWriteExt;
48use tokio::net::TcpListener;
49use tokio::net::TcpStream;
50use typeuri::Named;
51
52use crate::code_sync::WorkspaceLocation;
53use crate::code_sync::auto_reload::AutoReloadActor;
54use crate::code_sync::auto_reload::AutoReloadMessage;
55use crate::code_sync::conda_sync::CondaSyncActor;
56use crate::code_sync::conda_sync::CondaSyncMessage;
57use crate::code_sync::conda_sync::CondaSyncResult;
58use crate::code_sync::rsync::RsyncActor;
59use crate::code_sync::rsync::RsyncDaemon;
60use crate::code_sync::rsync::RsyncMessage;
61use crate::code_sync::rsync::RsyncResult;
62
/// How a [`CodeSyncManager`] receives the workspace from the syncing client.
///
/// Both variants carry a `connect` port. `code_sync_mesh` expects one
/// [`Connect`] request per owner rank on that port, accepts each one, and
/// serves the workspace over the resulting stream.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub enum Method {
    /// Transfer via rsync: each accepted connection is tunnelled to an rsync
    /// daemon running on the client.
    Rsync {
        connect: reference::PortRef<Connect>,
    },
    /// Transfer via the conda sync protocol: the client runs
    /// `monarch_conda::sync::sender` over each accepted connection.
    CondaSync {
        connect: reference::PortRef<Connect>,
        /// Forwarded unchanged to the `CondaSyncActor`; presumably maps local
        /// path prefixes to remote workspace locations — TODO confirm.
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
73
/// Describes which ranks of a mesh own a workspace copy and which ranks
/// depend on each owner.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceShape {
    /// The shape of the mesh the workspace is deployed across.
    pub shape: Shape,
    /// First label (in label order) from which ranks are grouped behind a
    /// single owner. Owners are the ranks at index 0 along this label and
    /// every later label (see `owners`). `None` makes every rank an owner.
    pub dimension: Option<String>,
}
82
83impl WorkspaceShape {
84 pub fn owners(&self) -> Result<Shape, ShapeError> {
90 let mut new_shape = self.shape.clone();
91 for label in self
92 .shape
93 .labels()
94 .iter()
95 .skip_while(|l| Some(*l) != self.dimension.as_ref())
96 {
97 new_shape = new_shape.select(label, 0)?;
98 }
100 Ok(new_shape)
101 }
102
103 pub fn downstream(&self, rank: usize) -> Result<Shape> {
111 let coords = self.shape.coordinates(rank)?;
112
113 for (label, value) in coords
114 .iter()
115 .skip_while(|(l, _)| Some(l) != self.dimension.as_ref())
116 {
117 ensure!(
118 *value == 0,
119 "Coordinate for dimension '{}' must be 0 for rank {}",
120 label,
121 rank
122 );
123 }
124
125 Ok(self.shape.index(
126 coords
127 .into_iter()
128 .take_while(|(l, _)| Some(l) != self.dimension.as_ref())
129 .collect::<Vec<_>>(),
130 )?)
131 }
132
133 fn downstream_mesh(
134 &self,
135 mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
136 rank: usize,
137 ) -> Result<hyperactor_mesh::ActorMeshRef<CodeSyncManager>> {
138 let shape = self.downstream(rank)?;
139 Ok(mesh.sliced(shape.region()))
140 }
141}
142
/// Remote workspace description used by `code_sync_mesh`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceConfig {
    /// Where the workspace is written on each owner rank.
    pub location: WorkspaceLocation,
    /// Which ranks own a copy and which ranks depend on each owner.
    pub shape: WorkspaceShape,
}
148
/// Messages handled by [`CodeSyncManager`].
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum CodeSyncMessage {
    /// Sync the workspace into `workspace` using `method`.
    Sync {
        workspace: WorkspaceLocation,
        method: Method,
        /// When set, a reload is triggered after the sync on every rank
        /// downstream of the receiving owner in this shape.
        reload: Option<WorkspaceShape>,
        /// Receives `Ok(())` or a formatted error chain when the handler finishes.
        result: reference::PortRef<Result<(), String>>,
    },
    /// Trigger a reload through the manager's `AutoReloadActor`.
    Reload {
        /// Rank that issued the reload. A manager whose own rank equals this
        /// skips the message without replying on `result`.
        sender_rank: Option<usize>,
        result: reference::PortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(CodeSyncMessage);
166
/// Tells a [`CodeSyncManager`] which mesh it belongs to, so it can resolve
/// its own rank and its downstream peers for reloads.
#[derive(Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub struct SetActorMeshMessage {
    pub actor_mesh: hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
}
wirevalue::register_type!(SetActorMeshMessage);
172
/// Spawn parameters for [`CodeSyncManager`]; currently empty.
#[derive(Debug, Named, Serialize, Deserialize)]
pub struct CodeSyncManagerParams {}
wirevalue::register_type!(CodeSyncManagerParams);
176
/// Per-rank actor that syncs code into a workspace and coordinates reloads.
///
/// Helper actors for each sync method and for reloading are spawned lazily.
/// Both message types are exported as castable.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CodeSyncMessage { cast = true },
        SetActorMeshMessage { cast = true }
    ],
)]
pub struct CodeSyncManager {
    /// Rsync helper, spawned on first use by `get_rsync_actor`.
    rsync: OnceCell<ActorHandle<RsyncActor>>,
    /// Reload helper, spawned on first use by `get_auto_reload_actor`.
    auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
    /// Conda sync helper, spawned on first use by `get_conda_sync_actor`.
    conda_sync: OnceCell<ActorHandle<CondaSyncActor>>,
    /// Mesh this manager belongs to; set by the first `SetActorMeshMessage`.
    self_mesh: once_cell::sync::OnceCell<hyperactor_mesh::ActorMeshRef<CodeSyncManager>>,
    /// This manager's rank within `self_mesh`; resolved alongside it.
    rank: once_cell::sync::OnceCell<usize>,
}

impl Actor for CodeSyncManager {}
194
195#[async_trait]
196impl RemoteSpawn for CodeSyncManager {
197 type Params = CodeSyncManagerParams;
198
199 async fn new(CodeSyncManagerParams {}: Self::Params, _environment: Flattrs) -> Result<Self> {
200 Ok(Self {
201 rsync: OnceCell::new(),
202 auto_reload: OnceCell::new(),
203 conda_sync: OnceCell::new(),
204 self_mesh: once_cell::sync::OnceCell::new(),
205 rank: once_cell::sync::OnceCell::new(),
206 })
207 }
208}
209
impl CodeSyncManager {
    /// Returns the [`RsyncActor`] helper, spawning it via `cx` on first use.
    ///
    /// A spawn error is returned to the caller. Per `get_or_try_init`, a
    /// failed spawn leaves the cell empty, so a later call can retry.
    async fn get_rsync_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<RsyncActor>> {
        self.rsync
            .get_or_try_init(async move { RsyncActor::default().spawn(cx) })
            .await
    }

    /// Returns the [`AutoReloadActor`] helper, constructing and spawning it via
    /// `cx` on first use. Construction and spawn errors are both propagated.
    async fn get_auto_reload_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<AutoReloadActor>> {
        self.auto_reload
            .get_or_try_init(async move { AutoReloadActor::new().await?.spawn(cx) })
            .await
    }

    /// Returns the [`CondaSyncActor`] helper, spawning it via `cx` on first use.
    async fn get_conda_sync_actor<'a>(
        &'a mut self,
        cx: &Context<'a, Self>,
    ) -> Result<&'a ActorHandle<CondaSyncActor>> {
        self.conda_sync
            .get_or_try_init(async move { CondaSyncActor::default().spawn(cx) })
            .await
    }
}
238
239#[async_trait]
240#[handle(CodeSyncMessage)]
241impl CodeSyncMessageHandler for CodeSyncManager {
242 async fn sync(
243 &mut self,
244 cx: &Context<Self>,
245 workspace: WorkspaceLocation,
246 method: Method,
247 reload: Option<WorkspaceShape>,
248 result: reference::PortRef<Result<(), String>>,
249 ) -> Result<()> {
250 let res = async move {
251 match method {
252 Method::Rsync { connect } => {
253 let (tx, mut rx) = cx.open_port::<Result<RsyncResult, String>>();
256 self.get_rsync_actor(cx).await?.send(
257 cx,
258 RsyncMessage {
259 connect,
260 result: tx.bind(),
261 workspace,
262 },
263 )?;
264 let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
266 }
267 Method::CondaSync {
268 connect,
269 path_prefix_replacements,
270 } => {
271 let (tx, mut rx) = cx.open_port::<Result<CondaSyncResult, String>>();
274 self.get_conda_sync_actor(cx).await?.send(
275 cx,
276 CondaSyncMessage {
277 connect,
278 result: tx.bind(),
279 workspace,
280 path_prefix_replacements,
281 },
282 )?;
283 let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
285 }
286 }
287
288 if let Some(workspace_shape) = reload {
290 let (tx, rx) = cx.open_port::<Result<(), String>>();
291 let tx = tx.bind();
292 let rank = self
293 .rank
294 .get()
295 .ok_or_else(|| anyhow::anyhow!("missing rank"))?;
296 let mesh = self
297 .self_mesh
298 .get()
299 .ok_or_else(|| anyhow::anyhow!("missing self mesh"))?;
300 let mesh = workspace_shape.downstream_mesh(mesh, *rank)?;
301 mesh.cast(
302 cx,
303 CodeSyncMessage::Reload {
304 sender_rank: Some(*rank),
305 result: tx.clone(),
306 },
307 )?;
308 let len = Ranked::region(&mesh).num_ranks() - 1;
310 let _: ((), Vec<()>) = try_join!(
311 self.reload(cx, self.rank.get().cloned(), tx),
313 rx.take(len)
314 .map(|res| res?.map_err(anyhow::Error::msg))
315 .try_collect(),
316 )?;
317 }
318
319 anyhow::Ok(())
320 }
321 .await;
322 result.send(
323 cx,
324 res.map_err(|e| {
325 format!(
326 "{:#?}",
327 Err::<(), _>(e)
328 .with_context(|| format!("code sync from {}", cx.self_id()))
329 .unwrap_err()
330 )
331 }),
332 )?;
333 Ok(())
334 }
335
336 async fn reload(
337 &mut self,
338 cx: &Context<Self>,
339 sender_rank: Option<usize>,
340 result: reference::PortRef<Result<(), String>>,
341 ) -> Result<()> {
342 if self
343 .rank
344 .get()
345 .is_some_and(|rank| sender_rank.is_some_and(|sender_rank| *rank == sender_rank))
346 {
347 return Ok(());
348 }
349 let res = async move {
350 let (tx, mut rx) = cx.open_port();
351 self.get_auto_reload_actor(cx)
352 .await?
353 .send(cx, AutoReloadMessage { result: tx.bind() })?;
354 rx.recv().await?.map_err(anyhow::Error::msg)?;
355 anyhow::Ok(())
356 }
357 .await;
358 result.send(
359 cx,
360 res.map_err(|e| {
361 format!(
362 "{:#?}",
363 Err::<(), _>(e)
364 .with_context(|| format!("module reload from {}", cx.self_id()))
365 .unwrap_err()
366 )
367 }),
368 )?;
369 Ok(())
370 }
371}
372
373#[async_trait]
374impl Handler<SetActorMeshMessage> for CodeSyncManager {
375 async fn handle(&mut self, cx: &Context<Self>, msg: SetActorMeshMessage) -> Result<()> {
376 let mesh = self.self_mesh.get_or_init(|| msg.actor_mesh);
377 self.rank.get_or_init(|| {
378 mesh.iter()
379 .find(|(_, actor)| actor.actor_id() == cx.self_id())
380 .unwrap()
381 .0
382 .rank()
383 });
384 Ok(())
385 }
386}
387
/// Client-side choice of sync method, passed to [`code_sync_mesh`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CodeSyncMethod {
    /// Serve the workspace through a loopback rsync daemon on the client.
    Rsync,
    /// Send the workspace with `monarch_conda::sync::sender`.
    CondaSync {
        /// Forwarded to each owner in [`Method::CondaSync`].
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
395
/// Syncs `local_workspace` from this client to every workspace owner in `actor_mesh`.
///
/// First, the mesh is narrowed to the owner ranks given by
/// `remote_workspace.shape.owners()`. Next, each owner is sent a
/// [`CodeSyncMessage::Sync`]. At the same time, this function serves exactly
/// one incoming connection per owner. How it serves that connection depends
/// on the method:
///
/// - [`CodeSyncMethod::Rsync`]: the connection is tunnelled to a local rsync daemon.
/// - [`CodeSyncMethod::CondaSync`]: the conda sync sender runs over the connection.
///
/// If `auto_reload` is set, owners are also asked to trigger a reload on
/// their downstream ranks after syncing.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// - The owner shape cannot be computed.
/// - The rsync daemon cannot be started.
/// - A connection fails.
/// - Any owner reports a failure. Remote failures are aggregated under
///   "remote failures".
pub async fn code_sync_mesh(
    cx: &impl context::Actor,
    actor_mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
    local_workspace: PathBuf,
    remote_workspace: WorkspaceConfig,
    method: CodeSyncMethod,
    auto_reload: bool,
) -> Result<()> {
    let instance = cx.instance();

    let owner_shape = remote_workspace.shape.owners()?;
    // Only owners receive the sync; the remaining ranks are reached via reload.
    let actor_mesh = actor_mesh.sliced(owner_shape.region());
    let num_ranks = Ranked::region(&actor_mesh).num_ranks();

    let (method, method_fut) = match method {
        CodeSyncMethod::Rsync => {
            // Bind the daemon to loopback only. `TcpListener::bind` tries the
            // addresses in order, so IPv6 is preferred and IPv4 is the fallback.
            let ipv6_lo: SocketAddr = "[::1]:0".parse()?;
            let ipv4_lo: SocketAddr = "127.0.0.1:0".parse()?;
            let addrs: [SocketAddr; 2] = [ipv6_lo, ipv4_lo];
            let daemon =
                RsyncDaemon::spawn(TcpListener::bind(&addrs[..]).await?, &local_workspace).await?;

            let daemon_addr = daemon.addr().clone();
            let (rsync_conns_tx, rsync_conns_rx) = instance.open_port::<Connect>();
            (
                Method::Rsync {
                    connect: rsync_conns_tx.bind(),
                },
                async move {
                    // Tunnel one connection per owner to the local daemon. The
                    // daemon is shut down before any tunnel error is reported.
                    let res = rsync_conns_rx
                        .take(num_ranks)
                        .err_into::<anyhow::Error>()
                        .try_for_each_concurrent(None, |connect| async move {
                            let (mut local, mut stream) = try_join!(
                                TcpStream::connect(daemon_addr.clone()).err_into(),
                                accept(instance, instance.self_id().clone(), connect),
                            )?;
                            tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
                            Ok(())
                        })
                        .await;
                    daemon.shutdown().await?;
                    res?;
                    anyhow::Ok(())
                }
                .boxed(),
            )
        }
        CodeSyncMethod::CondaSync {
            path_prefix_replacements,
        } => {
            let (conns_tx, conns_rx) = instance.open_port::<Connect>();
            (
                Method::CondaSync {
                    connect: conns_tx.bind(),
                    path_prefix_replacements,
                },
                async move {
                    conns_rx
                        .take(num_ranks)
                        .err_into::<anyhow::Error>()
                        .try_for_each_concurrent(None, |connect| async {
                            let (mut read, mut write) =
                                accept(instance, instance.self_id().clone(), connect)
                                    .await?
                                    .into_split();
                            let res = sender(&local_workspace, &mut read, &mut write).await;

                            // Half-close our side, then drain whatever the remote
                            // still sends until it closes its side.
                            write.shutdown().await?;
                            let mut buf = vec![];
                            read.read_to_end(&mut buf).await?;

                            res
                        })
                        .await
                }
                .boxed(),
            )
        }
    };

    let ((), ()) = try_join!(
        method_fut,
        async move {
            // Kick off the sync on every owner, then wait for one result per owner.
            let (result_tx, result_rx) = instance.open_port::<Result<(), String>>();
            actor_mesh.cast(
                instance,
                CodeSyncMessage::Sync {
                    method,
                    workspace: remote_workspace.location.clone(),
                    reload: if auto_reload {
                        Some(remote_workspace.shape)
                    } else {
                        None
                    },
                    result: result_tx.bind(),
                },
            )?;

            let results = result_rx.take(num_ranks).try_collect::<Vec<_>>().await?;

            // Report every remote failure instead of stopping at the first one.
            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
            results
                .into_iter()
                .map(|res| res.map_err(anyhow::Error::msg))
                .try_collect_or_stash::<()>(&mut errs);
            Ok(errs.into_result()?)
        },
    )?;

    Ok(())
}
519
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use hyperactor_mesh::context;
    use hyperactor_mesh::test_utils;
    use ndslice::shape;
    use tempfile::TempDir;
    use tokio::fs;

    use super::*;

    #[test]
    fn test_workspace_shape_owners() {
        let shape = shape! { host = 2, replica = 3 };

        // No grouping dimension: every rank owns its own copy.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 6);

        // Grouped from the first label: a single owner for the whole mesh.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 1);

        // Grouped across replicas: one owner per host.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 2);
    }

    #[test]
    fn test_workspace_shape_downstream() -> Result<()> {
        let shape = shape! { host = 2, replica = 3 };

        // No grouping dimension: each owner only serves itself.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 1);

        // Grouped from the first label: rank 0 serves every rank, and rank 3
        // (host = 1) is not an owner.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 6);
        assert!(ws_shape.downstream(3).is_err());

        // Grouped across replicas: each host's owner serves its replicas.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 3);
        let downstream = ws_shape.downstream(3)?;
        assert_eq!(downstream.slice().len(), 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_code_sync_manager_and_mesh() -> Result<()> {
        // Source workspace with a few files and a nested directory.
        let source_workspace = TempDir::new()?;
        fs::write(source_workspace.path().join("test1.txt"), "content1").await?;
        fs::write(source_workspace.path().join("test2.txt"), "content2").await?;
        fs::create_dir(source_workspace.path().join("subdir")).await?;
        fs::write(source_workspace.path().join("subdir/test3.txt"), "content3").await?;

        // Target workspace with stale content that must not survive the sync.
        let target_workspace = TempDir::new()?;
        fs::create_dir(target_workspace.path().join("subdir5")).await?;
        fs::write(target_workspace.path().join("foo.txt"), "something").await?;

        let cx = context().await;
        let instance = cx.actor_instance;
        let mut host_mesh = test_utils::local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(instance, "code_sync_test", ndslice::Extent::unity())
            .await
            .unwrap();

        let params = CodeSyncManagerParams {};
        let actor_mesh = proc_mesh
            .spawn_service(&instance, "code_sync_test", &params)
            .await?;

        // Each manager needs the mesh to resolve its own rank.
        actor_mesh.cast(
            &instance,
            SetActorMeshMessage {
                actor_mesh: (*actor_mesh).clone(),
            },
        )?;

        let remote_workspace_config = WorkspaceConfig {
            location: WorkspaceLocation::Constant(target_workspace.path().to_path_buf()),
            shape: WorkspaceShape {
                shape: shape! { replica = 2 },
                dimension: Some("replica".to_string()),
            },
        };

        code_sync_mesh(
            instance,
            &actor_mesh,
            source_workspace.path().to_path_buf(),
            remote_workspace_config.clone(),
            CodeSyncMethod::Rsync,
            false, // auto_reload
        )
        .await?;

        assert!(
            !dir_diff::is_different(&source_workspace, &target_workspace)
                .map_err(|e| anyhow!("{:?}", e))?,
            "Source and target workspaces should be identical after sync"
        );

        let _ = host_mesh.shutdown(instance).await;
        Ok(())
    }
}