1use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::path::PathBuf;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::ensure;
16use async_once_cell::OnceCell;
17use async_trait::async_trait;
18use futures::FutureExt;
19use futures::StreamExt;
20use futures::TryFutureExt;
21use futures::TryStreamExt;
22use futures::try_join;
23use hyperactor as reference;
24use hyperactor::Actor;
25use hyperactor::ActorHandle;
26use hyperactor::Bind;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::RemoteSpawn;
31use hyperactor::Unbind;
32use hyperactor::context;
33use hyperactor::handle;
34use hyperactor_config::Flattrs;
35use hyperactor_mesh::connect::Connect;
36use hyperactor_mesh::connect::accept;
37use lazy_errors::ErrorStash;
38use lazy_errors::TryCollectOrStash;
39use monarch_conda::sync::sender;
40use ndslice::Shape;
41use ndslice::ShapeError;
42use ndslice::view::Ranked;
43use ndslice::view::RankedSliceable;
44use ndslice::view::ViewExt;
45use serde::Deserialize;
46use serde::Serialize;
47use tokio::io::AsyncReadExt;
48use tokio::io::AsyncWriteExt;
49use tokio::net::TcpListener;
50use tokio::net::TcpStream;
51use typeuri::Named;
52
53use crate::code_sync::WorkspaceLocation;
54use crate::code_sync::auto_reload::AutoReloadActor;
55use crate::code_sync::auto_reload::AutoReloadMessage;
56use crate::code_sync::conda_sync::CondaSyncActor;
57use crate::code_sync::conda_sync::CondaSyncMessage;
58use crate::code_sync::conda_sync::CondaSyncResult;
59use crate::code_sync::rsync::RsyncActor;
60use crate::code_sync::rsync::RsyncDaemon;
61use crate::code_sync::rsync::RsyncMessage;
62use crate::code_sync::rsync::RsyncResult;
63
/// Wire-level description of how a remote [`CodeSyncManager`] should pull a
/// workspace from the client.
///
/// Each variant carries the port on which the client receives [`Connect`]
/// messages; `code_sync_mesh` opens that port and `accept`s one connection per
/// participating rank.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub enum Method {
    /// Sync by forwarding the request to the manager's `RsyncActor`.
    Rsync {
        /// Client-side port that receives `Connect` messages.
        connect: reference::PortRef<Connect>,
    },
    /// Sync by forwarding the request to the manager's `CondaSyncActor`.
    CondaSync {
        /// Client-side port that receives `Connect` messages.
        connect: reference::PortRef<Connect>,
        /// Passed through unchanged in `CondaSyncMessage`.
        /// NOTE(review): presumably maps source path prefixes to their
        /// replacement locations on the remote side — confirm in `conda_sync`.
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
74
/// Describes which ranks of a mesh share a single workspace.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceShape {
    /// Shape whose labels and coordinates are used to locate ranks.
    pub shape: Shape,
    /// Label of the first dimension (in label order) across which a workspace
    /// is shared. `owners` pins this dimension and every later one to index 0;
    /// `downstream` keeps only the coordinates that come before it. With
    /// `None` no label matches, so no dimension is treated as shared.
    pub dimension: Option<String>,
}
83
84impl WorkspaceShape {
85 pub fn owners(&self) -> Result<Shape, ShapeError> {
91 let mut new_shape = self.shape.clone();
92 for label in self
93 .shape
94 .labels()
95 .iter()
96 .skip_while(|l| Some(*l) != self.dimension.as_ref())
97 {
98 new_shape = new_shape.select(label, 0)?;
99 }
101 Ok(new_shape)
102 }
103
104 pub fn downstream(&self, rank: usize) -> Result<Shape> {
112 let coords = self.shape.coordinates(rank)?;
113
114 for (label, value) in coords
115 .iter()
116 .skip_while(|(l, _)| Some(l) != self.dimension.as_ref())
117 {
118 ensure!(
119 *value == 0,
120 "Coordinate for dimension '{}' must be 0 for rank {}",
121 label,
122 rank
123 );
124 }
125
126 Ok(self.shape.index(
127 coords
128 .into_iter()
129 .take_while(|(l, _)| Some(l) != self.dimension.as_ref())
130 .collect::<Vec<_>>(),
131 )?)
132 }
133
134 fn downstream_mesh(
135 &self,
136 mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
137 rank: usize,
138 ) -> Result<hyperactor_mesh::ActorMeshRef<CodeSyncManager>> {
139 let shape = self.downstream(rank)?;
140 Ok(mesh.sliced(shape.region()))
141 }
142}
143
/// A remote workspace: where it lives and which ranks share it.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct WorkspaceConfig {
    /// Location of the workspace on the remote side; sent to each owner rank
    /// as the `workspace` field of `CodeSyncMessage::Sync`.
    pub location: WorkspaceLocation,
    /// Sharing layout, used to pick the owner ranks and (when auto-reload is
    /// requested) the ranks that must reload after a sync.
    pub shape: WorkspaceShape,
}
149
/// Messages handled by [`CodeSyncManager`].
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
#[expect(
    clippy::large_enum_variant,
    reason = "actor message enum with Handler/Bind/Unbind derives; boxing fields ripples into handler call sites and may require derive-macro changes — separate diff"
)]
pub enum CodeSyncMessage {
    /// Pull the workspace from the client and report the outcome.
    Sync {
        /// Workspace location forwarded to the rsync / conda-sync actor.
        workspace: WorkspaceLocation,
        /// Transfer mechanism, including the client port to connect back to.
        method: Method,
        /// When `Some`, a `Reload` is cast to the ranks downstream of the
        /// receiving rank once the transfer has succeeded.
        reload: Option<WorkspaceShape>,
        /// Port that receives `Ok(())` or a formatted error string.
        result: reference::PortRef<Result<(), String>>,
    },
    /// Trigger the auto-reload actor on the receiving rank.
    Reload {
        /// Rank that issued the reload. A receiver whose own rank equals this
        /// value ignores the message and posts nothing to `result`.
        sender_rank: Option<usize>,
        /// Port that receives `Ok(())` or a formatted error string.
        result: reference::PortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(CodeSyncMessage);
171
/// Tells a [`CodeSyncManager`] which mesh it belongs to.
///
/// The handler stores the mesh and derives the actor's own rank by looking up
/// its address in it.
#[derive(Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub struct SetActorMeshMessage {
    /// Reference to the mesh of `CodeSyncManager` actors.
    pub actor_mesh: hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
}
wirevalue::register_type!(SetActorMeshMessage);
177
/// Spawn parameters for [`CodeSyncManager`]; currently carries no data.
#[derive(Debug, Named, Serialize, Deserialize)]
pub struct CodeSyncManagerParams {}
wirevalue::register_type!(CodeSyncManagerParams);
181
/// Per-rank actor that handles code-sync requests by delegating to child
/// actors, which are spawned on first use.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        CodeSyncMessage { cast = true },
        SetActorMeshMessage { cast = true }
    ],
)]
#[hyperactor::spawnable]
pub struct CodeSyncManager {
    /// Child actor used for `Method::Rsync`; spawned on first use.
    rsync: OnceCell<ActorHandle<RsyncActor>>,
    /// Child actor used for `CodeSyncMessage::Reload`; spawned on first use.
    auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
    /// Child actor used for `Method::CondaSync`; spawned on first use.
    conda_sync: OnceCell<ActorHandle<CondaSyncActor>>,
    /// Mesh this actor belongs to; set once by `SetActorMeshMessage`.
    self_mesh: once_cell::sync::OnceCell<hyperactor_mesh::ActorMeshRef<CodeSyncManager>>,
    /// This actor's rank within `self_mesh`; set once by `SetActorMeshMessage`.
    rank: once_cell::sync::OnceCell<usize>,
}

// No custom lifecycle hooks: the default `Actor` behaviour is used.
impl Actor for CodeSyncManager {}
199
200#[async_trait]
201impl RemoteSpawn for CodeSyncManager {
202 type Params = CodeSyncManagerParams;
203
204 async fn new(CodeSyncManagerParams {}: Self::Params, _environment: Flattrs) -> Result<Self> {
205 Ok(Self {
206 rsync: OnceCell::new(),
207 auto_reload: OnceCell::new(),
208 conda_sync: OnceCell::new(),
209 self_mesh: once_cell::sync::OnceCell::new(),
210 rank: once_cell::sync::OnceCell::new(),
211 })
212 }
213}
214
215impl CodeSyncManager {
216 async fn get_rsync_actor<'a>(
217 &'a mut self,
218 cx: &Context<'a, Self>,
219 ) -> Result<&'a ActorHandle<RsyncActor>> {
220 self.rsync
221 .get_or_try_init(async move { RsyncActor::default().spawn(cx) })
222 .await
223 }
224
225 async fn get_auto_reload_actor<'a>(
226 &'a mut self,
227 cx: &Context<'a, Self>,
228 ) -> Result<&'a ActorHandle<AutoReloadActor>> {
229 self.auto_reload
230 .get_or_try_init(async move { AutoReloadActor::new().await?.spawn(cx) })
231 .await
232 }
233
234 async fn get_conda_sync_actor<'a>(
235 &'a mut self,
236 cx: &Context<'a, Self>,
237 ) -> Result<&'a ActorHandle<CondaSyncActor>> {
238 self.conda_sync
239 .get_or_try_init(async move { CondaSyncActor::default().spawn(cx) })
240 .await
241 }
242}
243
244#[async_trait]
245#[handle(CodeSyncMessage)]
246impl CodeSyncMessageHandler for CodeSyncManager {
247 async fn sync(
248 &mut self,
249 cx: &Context<Self>,
250 workspace: WorkspaceLocation,
251 method: Method,
252 reload: Option<WorkspaceShape>,
253 result: reference::PortRef<Result<(), String>>,
254 ) -> Result<()> {
255 let res = async move {
256 match method {
257 Method::Rsync { connect } => {
258 let (tx, mut rx) = cx.open_port::<Result<RsyncResult, String>>();
261 self.get_rsync_actor(cx).await?.post(
262 cx,
263 RsyncMessage {
264 connect,
265 result: tx.bind(),
266 workspace,
267 },
268 );
269 let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
271 }
272 Method::CondaSync {
273 connect,
274 path_prefix_replacements,
275 } => {
276 let (tx, mut rx) = cx.open_port::<Result<CondaSyncResult, String>>();
279 self.get_conda_sync_actor(cx).await?.post(
280 cx,
281 CondaSyncMessage {
282 connect,
283 result: tx.bind(),
284 workspace,
285 path_prefix_replacements,
286 },
287 );
288 let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
290 }
291 }
292
293 if let Some(workspace_shape) = reload {
295 let (tx, rx) = cx.open_port::<Result<(), String>>();
296 let tx = tx.bind();
297 let rank = self
298 .rank
299 .get()
300 .ok_or_else(|| anyhow::anyhow!("missing rank"))?;
301 let mesh = self
302 .self_mesh
303 .get()
304 .ok_or_else(|| anyhow::anyhow!("missing self mesh"))?;
305 let mesh = workspace_shape.downstream_mesh(mesh, *rank)?;
306 mesh.cast(
307 cx,
308 CodeSyncMessage::Reload {
309 sender_rank: Some(*rank),
310 result: tx.clone(),
311 },
312 )?;
313 let len = Ranked::region(&mesh).num_ranks() - 1;
315 let _: ((), Vec<()>) = try_join!(
316 self.reload(cx, self.rank.get().cloned(), tx),
318 rx.take(len)
319 .map(|res| res?.map_err(anyhow::Error::msg))
320 .try_collect(),
321 )?;
322 }
323
324 anyhow::Ok(())
325 }
326 .await;
327 result.post(
328 cx,
329 res.map_err(|e| {
330 format!(
331 "{:#?}",
332 Err::<(), _>(e)
333 .with_context(|| format!("code sync from {}", cx.self_addr()))
334 .unwrap_err()
335 )
336 }),
337 );
338 Ok(())
339 }
340
341 async fn reload(
342 &mut self,
343 cx: &Context<Self>,
344 sender_rank: Option<usize>,
345 result: reference::PortRef<Result<(), String>>,
346 ) -> Result<()> {
347 if self
348 .rank
349 .get()
350 .is_some_and(|rank| sender_rank.is_some_and(|sender_rank| *rank == sender_rank))
351 {
352 return Ok(());
353 }
354 let res = async move {
355 let (tx, mut rx) = cx.open_port::<Result<(), String>>();
356 self.get_auto_reload_actor(cx)
357 .await?
358 .post(cx, AutoReloadMessage { result: tx.bind() });
359 rx.recv().await?.map_err(anyhow::Error::msg)?;
360 anyhow::Ok(())
361 }
362 .await;
363 result.post(
364 cx,
365 res.map_err(|e| {
366 format!(
367 "{:#?}",
368 Err::<(), _>(e)
369 .with_context(|| format!("module reload from {}", cx.self_addr()))
370 .unwrap_err()
371 )
372 }),
373 );
374 Ok(())
375 }
376}
377
378#[async_trait]
379impl Handler<SetActorMeshMessage> for CodeSyncManager {
380 async fn handle(&mut self, cx: &Context<Self>, msg: SetActorMeshMessage) -> Result<()> {
381 let mesh = self.self_mesh.get_or_init(|| msg.actor_mesh);
382 self.rank.get_or_init(|| {
383 mesh.iter()
384 .find(|(_, actor)| *actor.actor_addr() == *cx.self_addr())
385 .unwrap()
386 .0
387 .rank()
388 });
389 Ok(())
390 }
391}
392
/// Client-facing choice of sync mechanism passed to `code_sync_mesh`.
///
/// `code_sync_mesh` turns this into the wire-level [`Method`] by attaching the
/// connection port it opens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CodeSyncMethod {
    /// Serve the local workspace through an `RsyncDaemon`.
    Rsync,
    /// Stream the local workspace with `monarch_conda::sync::sender`.
    CondaSync {
        /// Forwarded unchanged into `Method::CondaSync`.
        path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
    },
}
400
401pub async fn code_sync_mesh(
402 cx: &impl context::Actor,
403 actor_mesh: &hyperactor_mesh::ActorMeshRef<CodeSyncManager>,
404 local_workspace: PathBuf,
405 remote_workspace: WorkspaceConfig,
406 method: CodeSyncMethod,
407 auto_reload: bool,
408) -> Result<()> {
409 let instance = cx.instance();
410
411 let owner_shape = remote_workspace.shape.owners()?;
414 let actor_mesh = actor_mesh.sliced(owner_shape.region());
415 let num_ranks = Ranked::region(&actor_mesh).num_ranks();
416
417 let (method, method_fut) = match method {
418 CodeSyncMethod::Rsync => {
419 let ipv6_lo: SocketAddr = "[::1]:0".parse()?;
422 let ipv4_lo: SocketAddr = "127.0.0.1:0".parse()?;
423 let addrs: [SocketAddr; 2] = [ipv6_lo, ipv4_lo];
424 let daemon =
425 RsyncDaemon::spawn(TcpListener::bind(&addrs[..]).await?, &local_workspace).await?;
426
427 let daemon_addr = *daemon.addr();
428 let (rsync_conns_tx, rsync_conns_rx) = instance.open_port::<Connect>();
429 (
430 Method::Rsync {
431 connect: rsync_conns_tx.bind(),
432 },
433 async move {
436 let res = rsync_conns_rx
437 .take(num_ranks)
438 .err_into::<anyhow::Error>()
439 .try_for_each_concurrent(None, |connect| async move {
440 let (mut local, mut stream) = try_join!(
441 TcpStream::connect(daemon_addr).err_into(),
442 accept(instance, instance.self_addr().clone(), connect),
443 )?;
444 tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
445 Ok(())
446 })
447 .await;
448 daemon.shutdown().await?;
449 res?;
450 anyhow::Ok(())
451 }
452 .boxed(),
453 )
454 }
455 CodeSyncMethod::CondaSync {
456 path_prefix_replacements,
457 } => {
458 let (conns_tx, conns_rx) = instance.open_port::<Connect>();
459 (
460 Method::CondaSync {
461 connect: conns_tx.bind(),
462 path_prefix_replacements,
463 },
464 async move {
465 conns_rx
466 .take(num_ranks)
467 .err_into::<anyhow::Error>()
468 .try_for_each_concurrent(None, |connect| async {
469 let (mut read, mut write) =
470 accept(instance, instance.self_addr().clone(), connect)
471 .await?
472 .into_split();
473 let res = sender(&local_workspace, &mut read, &mut write).await;
474
475 write.shutdown().await?;
478 let mut buf = vec![];
479 read.read_to_end(&mut buf).await?;
480
481 res
482 })
483 .await
484 }
485 .boxed(),
486 )
487 }
488 };
489
490 let ((), ()) = try_join!(
491 method_fut,
492 async move {
494 let (result_tx, result_rx) = instance.open_port::<Result<(), String>>();
495 actor_mesh.cast(
496 instance,
497 CodeSyncMessage::Sync {
498 method,
499 workspace: remote_workspace.location.clone(),
500 reload: if auto_reload {
501 Some(remote_workspace.shape)
502 } else {
503 None
504 },
505 result: result_tx.bind(),
506 },
507 )?;
508
509 let results = result_rx.take(num_ranks).try_collect::<Vec<_>>().await?;
511
512 let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
514 results
515 .into_iter()
516 .map(|res| res.map_err(anyhow::Error::msg))
517 .try_collect_or_stash::<()>(&mut errs);
518 Ok(errs.into_result()?)
519 },
520 )?;
521
522 Ok(())
523}
524
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use hyperactor_mesh::context;
    use hyperactor_mesh::test_utils;
    use ndslice::shape;
    use tempfile::TempDir;
    use tokio::fs;

    use super::*;

    /// `owners` pins the shared dimension and every later one to index 0.
    #[test]
    fn test_workspace_shape_owners() {
        let shape = shape! { host = 2, replica = 3 };

        // No shared dimension: every rank is an owner.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 6);

        // Shared from `host` onward: a single owner remains.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 1);

        // Shared across `replica` only: one owner per host.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let owners = ws_shape.owners().unwrap();
        assert_eq!(owners.slice().len(), 2);
    }

    /// `downstream` returns the ranks sharing an owner's workspace and
    /// rejects ranks that are not owners.
    #[test]
    fn test_workspace_shape_downstream() -> Result<()> {
        let shape = shape! { host = 2, replica = 3 };

        // No shared dimension: an owner's downstream is just itself.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: None,
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 1);

        // Shared from `host` onward: rank 0 covers all ranks, rank 3 is not
        // an owner.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("host".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 6);
        assert!(ws_shape.downstream(3).is_err());

        // Shared across `replica`: each host's first rank covers its replicas.
        let ws_shape = WorkspaceShape {
            shape: shape.clone(),
            dimension: Some("replica".to_string()),
        };
        let downstream = ws_shape.downstream(0)?;
        assert_eq!(downstream.slice().len(), 3);
        let downstream = ws_shape.downstream(3)?;
        assert_eq!(downstream.slice().len(), 3);

        Ok(())
    }

    /// End-to-end rsync of a temp workspace through a two-host local mesh.
    #[cfg_attr(not(target_os = "linux"), ignore = "linux-only")]
    #[tokio::test]
    async fn test_code_sync_manager_and_mesh() -> Result<()> {
        // Source workspace with a few files and a nested directory.
        let source_workspace = TempDir::new()?;
        fs::write(source_workspace.path().join("test1.txt"), "content1").await?;
        fs::write(source_workspace.path().join("test2.txt"), "content2").await?;
        fs::create_dir(source_workspace.path().join("subdir")).await?;
        fs::write(source_workspace.path().join("subdir/test3.txt"), "content3").await?;

        // Target workspace pre-populated with content absent from the source;
        // the final comparison requires both trees to match.
        let target_workspace = TempDir::new()?;
        fs::create_dir(target_workspace.path().join("subdir5")).await?;
        fs::write(target_workspace.path().join("foo.txt"), "something").await?;

        let cx = context().await;
        let instance = cx.actor_instance;
        let mut host_mesh = test_utils::local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(
                instance,
                "code_sync_test",
                ndslice::Extent::unity(),
                None,
                None,
            )
            .await
            .unwrap();

        let params = CodeSyncManagerParams {};

        let actor_mesh = proc_mesh
            .spawn_service(&instance, "code_sync_test", &params)
            .await?;

        // Let every manager learn its mesh and rank.
        actor_mesh.cast(
            &instance,
            SetActorMeshMessage {
                actor_mesh: (*actor_mesh).clone(),
            },
        )?;

        let remote_workspace_config = WorkspaceConfig {
            location: WorkspaceLocation::Constant(target_workspace.path().to_path_buf()),
            shape: WorkspaceShape {
                shape: shape! { replica = 2 },
                dimension: Some("replica".to_string()),
            },
        };

        code_sync_mesh(
            instance,
            &actor_mesh,
            source_workspace.path().to_path_buf(),
            remote_workspace_config.clone(),
            CodeSyncMethod::Rsync,
            false,
        )
        .await?;

        assert!(
            !dir_diff::is_different(&source_workspace, &target_workspace)
                .map_err(|e| anyhow!("{:?}", e))?,
            "Source and target workspaces should be identical after sync"
        );

        let _ = host_mesh.shutdown(instance).await;
        Ok(())
    }
}