Skip to main content

monarch_hyperactor/code_sync/
conda_sync.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10use std::path::PathBuf;
11
12use anyhow::Result;
13use async_trait::async_trait;
14use futures::FutureExt;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use hyperactor as reference;
18use hyperactor::Actor;
19use hyperactor::Bind;
20use hyperactor::Endpoint as _;
21use hyperactor::Handler;
22use hyperactor::Instance;
23use hyperactor::Unbind;
24use hyperactor::context::Mailbox;
25use hyperactor_mesh::ActorMeshRef;
26use hyperactor_mesh::connect::Connect;
27use hyperactor_mesh::connect::accept;
28use lazy_errors::ErrorStash;
29use lazy_errors::OrStash;
30use lazy_errors::StashedResult;
31use lazy_errors::TryCollectOrStash;
32use monarch_conda::sync::Action;
33use monarch_conda::sync::receiver;
34use monarch_conda::sync::sender;
35use ndslice::view::Ranked;
36use serde::Deserialize;
37use serde::Serialize;
38use tokio::io::AsyncReadExt;
39use tokio::io::AsyncWriteExt;
40use typeuri::Named;
41
42use crate::code_sync::WorkspaceLocation;
43
44/// Represents the result of an conda sync operation with details about what was transferred
45#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
46pub struct CondaSyncResult {
47    /// All changes that occurred during the sync operation
48    pub changes: HashMap<PathBuf, Action>,
49}
50wirevalue::register_type!(CondaSyncResult);
51
52#[derive(Debug, Clone, Named, Serialize, Deserialize, Bind, Unbind)]
53pub struct CondaSyncMessage {
54    /// The connect message to create a duplex bytestream with the client.
55    pub connect: reference::PortRef<Connect>,
56    /// A port to send back the result or any errors.
57    pub result: reference::PortRef<Result<CondaSyncResult, String>>,
58    /// The location of the workspace to sync.
59    pub workspace: WorkspaceLocation,
60    /// Path prefixes to fixup/replace when copying.
61    pub path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
62}
63wirevalue::register_type!(CondaSyncMessage);
64
65#[derive(Debug, Named, Serialize, Deserialize)]
66pub struct CondaSyncParams {}
67wirevalue::register_type!(CondaSyncParams);
68
69#[derive(Debug, Default)]
70#[hyperactor::export(handlers = [CondaSyncMessage { cast = true }])]
71#[hyperactor::spawnable]
72pub struct CondaSyncActor {}
73
74impl Actor for CondaSyncActor {}
75
76#[async_trait]
77impl Handler<CondaSyncMessage> for CondaSyncActor {
78    async fn handle(
79        &mut self,
80        cx: &hyperactor::Context<Self>,
81        CondaSyncMessage {
82            workspace,
83            path_prefix_replacements,
84            connect,
85            result,
86        }: CondaSyncMessage,
87    ) -> Result<(), anyhow::Error> {
88        let res = async {
89            let workspace = workspace.resolve()?;
90            let (connect_msg, completer) = Connect::allocate(cx.self_addr().clone(), cx);
91            connect.post(cx, connect_msg);
92            let (mut read, mut write) = completer.complete().await?.into_split();
93            let path_prefix_replacements = path_prefix_replacements
94                .into_iter()
95                .map(|(l, r)| Ok((l, r.resolve()?)))
96                .collect::<Result<Vec<_>>>()?
97                .into_iter()
98                .collect::<HashMap<_, _>>();
99            let changes_result =
100                receiver(&workspace, &mut read, &mut write, path_prefix_replacements).await;
101
102            // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
103            // message spam.
104            write.shutdown().await?;
105            let mut buf = vec![];
106            read.read_to_end(&mut buf).await?;
107
108            anyhow::Ok(CondaSyncResult {
109                changes: changes_result?,
110            })
111        }
112        .await;
113        result.post(cx, res.map_err(|e| format!("{:#?}", e)));
114        Ok(())
115    }
116}
117
118pub async fn conda_sync_mesh(
119    instance: &Instance<()>,
120    actor_mesh: &ActorMeshRef<CondaSyncActor>,
121    local_workspace: PathBuf,
122    remote_workspace: WorkspaceLocation,
123    path_prefix_replacements: HashMap<PathBuf, WorkspaceLocation>,
124) -> Result<Vec<CondaSyncResult>> {
125    let (conns_tx, conns_rx) = instance.mailbox().open_port();
126
127    let (res1, res2) = futures::future::join(
128        conns_rx
129            .take(actor_mesh.region().slice().len())
130            .err_into::<anyhow::Error>()
131            .try_for_each_concurrent(None, |connect| async {
132                let (mut read, mut write) = accept(instance, instance.self_addr().clone(), connect)
133                    .await?
134                    .into_split();
135                let res = sender(&local_workspace, &mut read, &mut write).await;
136
137                // Shutdown our end, then read from the other end till exhaustion to avoid undeliverable
138                // message spam.
139                write.shutdown().await?;
140                let mut buf = vec![];
141                read.read_to_end(&mut buf).await?;
142
143                res
144            })
145            .boxed(),
146        async move {
147            let (result_tx, result_rx) = instance
148                .mailbox()
149                .open_port::<Result<CondaSyncResult, String>>();
150            actor_mesh.cast(
151                instance,
152                CondaSyncMessage {
153                    connect: conns_tx.bind(),
154                    result: result_tx.bind(),
155                    workspace: remote_workspace,
156                    path_prefix_replacements,
157                },
158            )?;
159
160            // Wait for all actors to report result.
161            let results = result_rx
162                .take(actor_mesh.region().slice().len())
163                .try_collect::<Vec<_>>()
164                .await?;
165
166            // Combine all errors into one.
167            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
168            match results
169                .into_iter()
170                .map(|res| res.map_err(anyhow::Error::msg))
171                .try_collect_or_stash::<Vec<_>>(&mut errs)
172            {
173                StashedResult::Ok(results) => anyhow::Ok(results),
174                StashedResult::Err(_) => Err(errs.into_result().unwrap_err().into()),
175            }
176        },
177    )
178    .await;
179
180    // Combine code sync handler and cast errors into one.
181    let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "code sync failed");
182    res1.or_stash(&mut errs);
183    if let StashedResult::Ok(results) = res2.or_stash(&mut errs) {
184        errs.into_result()?;
185        return Ok(results);
186    }
187    Err(errs.into_result().unwrap_err().into())
188}