// monarch_hyperactor/code_sync/conda_sync.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10use std::path::PathBuf;
11
12use anyhow::Result;
13use async_trait::async_trait;
14use futures::FutureExt;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use hyperactor::Actor;
18use hyperactor::Bind;
19use hyperactor::Handler;
20use hyperactor::Instance;
21use hyperactor::Unbind;
22use hyperactor::context::Mailbox;
23use hyperactor::reference;
24use hyperactor_mesh::ActorMeshRef;
25use hyperactor_mesh::connect::Connect;
26use hyperactor_mesh::connect::accept;
27use lazy_errors::ErrorStash;
28use lazy_errors::OrStash;
29use lazy_errors::StashedResult;
30use lazy_errors::TryCollectOrStash;
31use monarch_conda::sync::Action;
32use monarch_conda::sync::receiver;
33use monarch_conda::sync::sender;
34use ndslice::view::Ranked;
35use serde::Deserialize;
36use serde::Serialize;
37use tokio::io::AsyncReadExt;
38use tokio::io::AsyncWriteExt;
39use typeuri::Named;
40
41use crate::code_sync::WorkspaceLocation;
42
43/// Represents the result of an conda sync operation with details about what was transferred
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
45pub struct CondaSyncResult {
46    /// All changes that occurred during the sync operation
47    pub changes: HashMap<PathBuf, Action>,
48}
49wirevalue::register_type!(CondaSyncResult);
50
51#[derive(Debug, Clone, Named, Serialize, Deserialize, Bind, Unbind)]
52pub struct CondaSyncMessage {
53    /// The connect message to create a duplex bytestream with the client.
54    pub connect: reference::PortRef<Connect>,
55    /// A port to send back the result or any errors.
56    pub result: reference::PortRef<Result<CondaSyncResult, String>>,
57    /// The location of the workspace to sync.
58    pub workspace: WorkspaceLocation,
59    /// Path prefixes to fixup/replace when copying.
60    pub path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
61}
62wirevalue::register_type!(CondaSyncMessage);
63
64#[derive(Debug, Named, Serialize, Deserialize)]
65pub struct CondaSyncParams {}
66wirevalue::register_type!(CondaSyncParams);
67
68#[derive(Debug, Default)]
69#[hyperactor::export(spawn = true, handlers = [CondaSyncMessage { cast = true }])]
70pub struct CondaSyncActor {}
71
72impl Actor for CondaSyncActor {}
73
74#[async_trait]
75impl Handler<CondaSyncMessage> for CondaSyncActor {
76    async fn handle(
77        &mut self,
78        cx: &hyperactor::Context<Self>,
79        CondaSyncMessage {
80            workspace,
81            path_prefix_replacements,
82            connect,
83            result,
84        }: CondaSyncMessage,
85    ) -> Result<(), anyhow::Error> {
86        let res = async {
87            let workspace = workspace.resolve()?;
88            let (connect_msg, completer) = Connect::allocate(cx.self_id().clone(), cx);
89            connect.send(cx, connect_msg)?;
90            let (mut read, mut write) = completer.complete().await?.into_split();
91            let path_prefix_replacements = path_prefix_replacements
92                .into_iter()
93                .map(|(l, r)| Ok((l, r.resolve()?)))
94                .collect::<Result<Vec<_>>>()?
95                .into_iter()
96                .collect::<HashMap<_, _>>();
97            let changes_result =
98                receiver(&workspace, &mut read, &mut write, path_prefix_replacements).await;
99
100            // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
101            // message spam.
102            write.shutdown().await?;
103            let mut buf = vec![];
104            read.read_to_end(&mut buf).await?;
105
106            anyhow::Ok(CondaSyncResult {
107                changes: changes_result?,
108            })
109        }
110        .await;
111        result.send(cx, res.map_err(|e| format!("{:#?}", e)))?;
112        Ok(())
113    }
114}
115
116pub async fn conda_sync_mesh(
117    instance: &Instance<()>,
118    actor_mesh: &ActorMeshRef<CondaSyncActor>,
119    local_workspace: PathBuf,
120    remote_workspace: WorkspaceLocation,
121    path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
122) -> Result<Vec<CondaSyncResult>> {
123    let (conns_tx, conns_rx) = instance.mailbox().open_port();
124
125    let (res1, res2) = futures::future::join(
126        conns_rx
127            .take(actor_mesh.region().slice().len())
128            .err_into::<anyhow::Error>()
129            .try_for_each_concurrent(None, |connect| async {
130                let (mut read, mut write) = accept(instance, instance.self_id().clone(), connect)
131                    .await?
132                    .into_split();
133                let res = sender(&local_workspace, &mut read, &mut write).await;
134
135                // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
136                // message spam.
137                write.shutdown().await?;
138                let mut buf = vec![];
139                read.read_to_end(&mut buf).await?;
140
141                res
142            })
143            .boxed(),
144        async move {
145            let (result_tx, result_rx) = instance
146                .mailbox()
147                .open_port::<Result<CondaSyncResult, String>>();
148            actor_mesh.cast(
149                instance,
150                CondaSyncMessage {
151                    connect: conns_tx.bind(),
152                    result: result_tx.bind(),
153                    workspace: remote_workspace,
154                    path_prefix_replacements,
155                },
156            )?;
157
158            // Wait for all actors to report result.
159            let results = result_rx
160                .take(actor_mesh.region().slice().len())
161                .try_collect::<Vec<_>>()
162                .await?;
163
164            // Combine all errors into one.
165            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
166            match results
167                .into_iter()
168                .map(|res| res.map_err(anyhow::Error::msg))
169                .try_collect_or_stash::<Vec<_>>(&mut errs)
170            {
171                StashedResult::Ok(results) => anyhow::Ok(results),
172                StashedResult::Err(_) => Err(errs.into_result().unwrap_err().into()),
173            }
174        },
175    )
176    .await;
177
178    // Combine code sync handler and cast errors into one.
179    let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "code sync failed");
180    res1.or_stash(&mut errs);
181    if let StashedResult::Ok(results) = res2.or_stash(&mut errs) {
182        errs.into_result()?;
183        return Ok(results);
184    }
185    Err(errs.into_result().unwrap_err().into())
186}