monarch_hyperactor/code_sync/workspace.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::path::PathBuf;
10
11use anyhow::Context;
12use anyhow::Result;
13use serde::Deserialize;
14use serde::Serialize;
15
16#[derive(Clone, Debug, Serialize, Deserialize)]
17pub enum WorkspaceLocation {
18    //// Workspace directory specified by the given path.
19    Constant(PathBuf),
20
21    /// Workspace directory specified by dereferencing the value of the environment variable
22    /// and appending the relative path to it.
23    ///
24    /// Example: `WorkspaceLocation::FromEnvVar{ env:"WORKSPACE_DIR", relpath: PathBuf::from("github/torchtitan) }`
25    /// points to `$WORKSPACE_DIR/github/torchtitan`.
26    FromEnvVar {
27        env: String,
28        relpath: PathBuf,
29    },
30}
31
32impl WorkspaceLocation {
33    pub fn resolve(&self) -> Result<PathBuf> {
34        Ok(match self {
35            WorkspaceLocation::Constant(p) => p.clone(),
36            WorkspaceLocation::FromEnvVar { env, relpath } => PathBuf::from(
37                std::env::var_os(env)
38                    .with_context(|| format!("workspace env var not set: {}", env))?,
39            )
40            .join(relpath),
41        })
42    }
43}
#[cfg(test)]
mod tests {
    use std::env;
    use std::ffi::OsStr;
    use std::sync::Mutex;
    use std::sync::MutexGuard;

    use tempfile::tempdir;

    use super::*;

    /// Serializes the tests in this module that read or mutate the process environment.
    ///
    /// `cargo test` runs tests on several threads by default. A test calling
    /// `env::set_var` is therefore *not* single-threaded, and without this lock it could
    /// race with a sibling test's `env::var_os`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Env var name used by these tests.
    ///
    /// The name is deliberately obscure so the tests never clobber a real `WORKSPACE_DIR`
    /// in the developer's environment.
    const TEST_ENV_VAR: &str = "__MONARCH_WORKSPACE_LOCATION_TEST_DIR__";

    /// Acquires [`ENV_LOCK`], ignoring poisoning so one failing test does not cascade
    /// into failures of unrelated tests.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// Because removal happens in `Drop`, the variable is cleaned up even if an assertion
    /// in the test panics.
    struct ScopedEnvVar {
        key: &'static str,
    }

    impl ScopedEnvVar {
        /// Callers must hold the guard returned by [`lock_env`] for this value's whole
        /// lifetime.
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            // SAFETY: the caller holds ENV_LOCK, so no other test in this module reads or
            // writes the environment concurrently. NOTE(review): tests in other modules of
            // this crate are not covered by this lock.
            unsafe { env::set_var(key, value) };
            Self { key }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: same invariant as `ScopedEnvVar::set`. Tests declare this guard
            // after the ENV_LOCK guard, so it is dropped (and the var removed) while the
            // lock is still held.
            unsafe { env::remove_var(self.key) };
        }
    }

    #[test]
    fn test_constant_workspace_location_constant() {
        let dir = tempdir().unwrap();
        let path = dir.path().to_path_buf();
        let loc = WorkspaceLocation::Constant(path.clone());
        let resolved = loc.resolve().unwrap();
        assert_eq!(resolved, path);
    }

    #[test]
    fn test_from_env_var_workspace_location() {
        // Declaration order matters: `_var` must be dropped before `_env_lock`.
        let _env_lock = lock_env();
        let tmpdir = tempdir().unwrap();
        let _var = ScopedEnvVar::set(TEST_ENV_VAR, tmpdir.path());

        assert_eq!(
            tmpdir.path().join("github/torchtitan"),
            WorkspaceLocation::FromEnvVar {
                env: TEST_ENV_VAR.to_string(),
                relpath: PathBuf::from("github/torchtitan"),
            }
            .resolve()
            .unwrap(),
        );

        // An empty relpath resolves to the env var's value itself.
        assert_eq!(
            tmpdir.path().to_path_buf(),
            WorkspaceLocation::FromEnvVar {
                env: TEST_ENV_VAR.to_string(),
                relpath: PathBuf::new(),
            }
            .resolve()
            .unwrap(),
        );
    }

    #[test]
    fn test_from_env_var_missing_env() {
        // Reads the environment, so serialize with the tests that mutate it.
        let _env_lock = lock_env();
        let loc = WorkspaceLocation::FromEnvVar {
            env: "__NON_EXISTENT__".to_string(),
            relpath: PathBuf::from("foo"),
        };
        let err = loc.resolve().unwrap_err();
        assert!(format!("{:?}", err).contains("__NON_EXISTENT__"));
    }
}