hyperactor_config/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Core configuration and attribute infrastructure for Hyperactor.
10//!
11//! This crate provides the core infrastructure for type-safe configuration
12//! management including:
13//! - `ConfigAttr`: Metadata for configuration keys
14//! - Helper functions to load/save `Attrs` (from env via `from_env`,
15//!   from YAML via `from_yaml`, and `to_yaml`)
16//! - Global layered configuration store under [`crate::global`]
17//!
18//! Individual crates should declare their own config keys using `declare_attrs!`
19//! and import `ConfigAttr`, `CONFIG`, and other infrastructure from this crate.
20
21use std::env;
22use std::fs::File;
23use std::io::Read;
24use std::path::Path;
25
26use serde::Deserialize;
27use serde::Serialize;
28use shell_quote::QuoteRefExt;
29use typeuri::Named;
30
31pub mod attrs;
32pub mod global;
33
34// Re-export commonly used items
35pub use attrs::AttrKeyInfo;
36pub use attrs::AttrValue;
37pub use attrs::Attrs;
38pub use attrs::Key;
39pub use attrs::SerializableValue;
40// Re-export AttrValue derive macro
41pub use hyperactor_config_macros::AttrValue;
42// Re-export macros needed by declare_attrs!
43pub use inventory::submit;
44pub use paste::paste;
45// Re-export typeuri for macro usage
46#[doc(hidden)]
47pub use typeuri;
48
49// declare_attrs is already exported via #[macro_export] in attrs.rs
50
51/// Metadata describing how a configuration key is exposed across
52/// environments.
53///
54/// Each `ConfigAttr` entry defines how a Rust configuration key maps
55/// to external representations:
56///  - `env_name`: the environment variable consulted by
57///    [`global::init_from_env()`] when loading configuration.
58///  - `py_name`: the Python keyword argument accepted by
59///    `monarch.configure(...)` and returned by `get_configuration()`.
60///
61/// All configuration keys should carry this meta-attribute via
62/// `@meta(CONFIG = ConfigAttr { ... })`.
63#[derive(Clone, Debug, Serialize, Deserialize)]
64pub struct ConfigAttr {
65    /// Environment variable consulted by `global::init_from_env()`.
66    pub env_name: Option<String>,
67
68    /// Python kwarg name used by `monarch.configure(...)` and
69    /// `get_configuration()`.
70    pub py_name: Option<String>,
71}
72
73impl Named for ConfigAttr {
74    fn typename() -> &'static str {
75        "hyperactor_config::ConfigAttr"
76    }
77}
78
79impl AttrValue for ConfigAttr {
80    fn display(&self) -> String {
81        serde_json::to_string(self).unwrap_or_else(|_| "<invalid ConfigAttr>".into())
82    }
83    fn parse(s: &str) -> Result<Self, anyhow::Error> {
84        Ok(serde_json::from_str(s)?)
85    }
86}
87
// Declare the CONFIG meta-attribute.
declare_attrs! {
    /// Meta-attribute that marks an attribute key as a configuration key.
    ///
    /// Its `ConfigAttr` value carries the metadata used to bridge Rust,
    /// environment variables, and Python:
    ///  - `env_name`: environment variable name consulted by
    ///    `global::init_from_env()` and by this crate's `from_env`;
    ///    `from_env` skips keys whose `env_name` is `None`.
    ///  - `py_name`: keyword argument name recognized by
    ///    `monarch.configure(...)`.
    ///
    /// All configuration keys should be annotated with this attribute,
    /// e.g. `@meta(CONFIG = ConfigAttr { env_name: ..., py_name: ... })`.
    /// `from_env` ignores keys that do not carry it.
    pub attr CONFIG: ConfigAttr;
}
103
104/// Load configuration from environment variables
105pub fn from_env() -> Attrs {
106    let mut config = Attrs::new();
107    let mut output = String::new();
108
109    fn export(env_var: &str, value: Option<&dyn SerializableValue>) -> String {
110        let env_var: String = env_var.quoted(shell_quote::Bash);
111        let value: String = value
112            .map_or("".to_string(), SerializableValue::display)
113            .quoted(shell_quote::Bash);
114        format!("export {}={}\n", env_var, value)
115    }
116
117    for key in inventory::iter::<AttrKeyInfo>() {
118        // Skip keys that are not marked as CONFIG or that do not
119        // declare an environment variable mapping. Only CONFIG-marked
120        // keys with an `env_name` participate in environment
121        // initialization.
122        let Some(cfg_meta) = key.meta.get(CONFIG) else {
123            continue;
124        };
125        let Some(env_var) = cfg_meta.env_name.as_deref() else {
126            continue;
127        };
128
129        let Ok(val) = env::var(env_var) else {
130            // Default value
131            output.push_str("# ");
132            output.push_str(&export(env_var, key.default));
133            continue;
134        };
135
136        match (key.parse)(&val) {
137            Err(e) => {
138                tracing::error!(
139                    "failed to override config key {} from value \"{}\" in ${}: {})",
140                    key.name,
141                    val,
142                    env_var,
143                    e
144                );
145                output.push_str("# ");
146                output.push_str(&export(env_var, key.default));
147            }
148            Ok(parsed) => {
149                output.push_str("# ");
150                output.push_str(&export(env_var, key.default));
151                output.push_str(&export(env_var, Some(parsed.as_ref())));
152                config.insert_value_by_name_unchecked(key.name, parsed);
153            }
154        }
155    }
156
157    tracing::info!(
158        "loaded configuration from environment:\n{}",
159        output.trim_end()
160    );
161
162    config
163}
164
165/// Load configuration from a YAML file
166pub fn from_yaml<P: AsRef<Path>>(path: P) -> Result<Attrs, anyhow::Error> {
167    let mut file = File::open(path)?;
168    let mut contents = String::new();
169    file.read_to_string(&mut contents)?;
170    Ok(serde_yaml::from_str(&contents)?)
171}
172
173/// Save configuration to a YAML file
174pub fn to_yaml<P: AsRef<Path>>(attrs: &Attrs, path: P) -> Result<(), anyhow::Error> {
175    let yaml = serde_yaml::to_string(attrs)?;
176    std::fs::write(path, yaml)?;
177    Ok(())
178}
179
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::net::Ipv4Addr;

    use indoc::indoc;

    use crate::CONFIG;
    use crate::ConfigAttr;
    use crate::attrs::declare_attrs;
    use crate::from_env;
    use crate::from_yaml;
    use crate::to_yaml;

    #[derive(
        Debug,
        Clone,
        Copy,
        PartialEq,
        Eq,
        serde::Serialize,
        serde::Deserialize
    )]
    pub(crate) enum TestMode {
        Development,
        Staging,
        Production,
    }

    impl std::fmt::Display for TestMode {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestMode::Development => write!(f, "dev"),
                TestMode::Staging => write!(f, "staging"),
                TestMode::Production => write!(f, "prod"),
            }
        }
    }

    impl std::str::FromStr for TestMode {
        type Err = anyhow::Error;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "dev" => Ok(TestMode::Development),
                "staging" => Ok(TestMode::Staging),
                "prod" => Ok(TestMode::Production),
                _ => Err(anyhow::anyhow!("unknown mode: {}", s)),
            }
        }
    }

    impl typeuri::Named for TestMode {
        fn typename() -> &'static str {
            "hyperactor_config::tests::TestMode"
        }
    }

    impl crate::attrs::AttrValue for TestMode {
        fn display(&self) -> String {
            self.to_string()
        }

        fn parse(s: &str) -> Result<Self, anyhow::Error> {
            s.parse()
        }
    }

    declare_attrs! {
        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_USIZE_KEY".to_string()),
            py_name: None,
        })
        pub attr USIZE_KEY: usize = 10;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_STRING_KEY".to_string()),
            py_name: None,
        })
        pub attr STRING_KEY: String = String::new();

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_BOOL_KEY".to_string()),
            py_name: None,
        })
        pub attr BOOL_KEY: bool = false;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_I64_KEY".to_string()),
            py_name: None,
        })
        pub attr I64_KEY: i64 = -42;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_F64_KEY".to_string()),
            py_name: None,
        })
        pub attr F64_KEY: f64 = 3.14;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_U32_KEY".to_string()),
            py_name: Some("test_u32_key".to_string()),
        })
        pub attr U32_KEY: u32 = 100;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_DURATION_KEY".to_string()),
            py_name: None,
        })
        pub attr DURATION_KEY: std::time::Duration = std::time::Duration::from_mins(1);

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_MODE_KEY".to_string()),
            py_name: None,
        })
        pub attr MODE_KEY: TestMode = TestMode::Development;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_IP_KEY".to_string()),
            py_name: None,
        })
        pub attr IP_KEY: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_SYSTEMTIME_KEY".to_string()),
            py_name: None,
        })
        pub attr SYSTEMTIME_KEY: std::time::SystemTime = std::time::UNIX_EPOCH;

        @meta(CONFIG = ConfigAttr {
            env_name: None,
            py_name: Some("test_no_env_key".to_string()),
        })
        pub attr NO_ENV_KEY: usize = 999;
    }

    /// Environment overrides applied by `test_from_env`, as
    /// `(variable, value)` pairs.
    const TEST_ENV_VARS: &[(&str, &str)] = &[
        ("TEST_USIZE_KEY", "1024"),
        ("TEST_STRING_KEY", "world"),
        ("TEST_BOOL_KEY", "true"),
        ("TEST_I64_KEY", "-999"),
        ("TEST_F64_KEY", "2.718"),
        ("TEST_U32_KEY", "500"),
        ("TEST_DURATION_KEY", "5s"),
        ("TEST_MODE_KEY", "prod"),
        ("TEST_IP_KEY", "192.168.1.1"),
        ("TEST_SYSTEMTIME_KEY", "2024-01-15T10:30:00Z"),
    ];

    #[tracing_test::traced_test]
    #[test]
    // TODO: OSS: The logs_assert function returned an error: missing log lines: {"# export HYPERACTOR_DEFAULT_ENCODING=serde_multipart", ...}
    #[cfg_attr(not(fbcode_build), ignore)]
    fn test_from_env() {
        // Set environment variables
        for &(name, value) in TEST_ENV_VARS {
            // SAFETY: TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { std::env::set_var(name, value) };
        }

        let config = from_env();

        // Clean up immediately: `from_env` has already read everything it
        // needs, and removing the variables before asserting means a failed
        // assertion does not leave them set for other tests in this process.
        for &(name, _) in TEST_ENV_VARS {
            // SAFETY: TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { std::env::remove_var(name) };
        }

        // Verify values loaded from environment
        assert_eq!(config[USIZE_KEY], 1024);
        assert_eq!(config[STRING_KEY], "world");
        assert!(config[BOOL_KEY]);
        assert_eq!(config[I64_KEY], -999);
        assert_eq!(config[F64_KEY], 2.718);
        assert_eq!(config[U32_KEY], 500);
        assert_eq!(config[DURATION_KEY], std::time::Duration::from_secs(5));
        assert_eq!(config[MODE_KEY], TestMode::Production);
        assert_eq!(config[IP_KEY], Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(
            config[SYSTEMTIME_KEY],
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1705314600) // 2024-01-15T10:30:00Z
        );

        // Verify key without env_name uses default
        assert_eq!(config[NO_ENV_KEY], 999);

        let expected_lines: HashSet<&str> = indoc! {"
            # export TEST_USIZE_KEY=10
            export TEST_USIZE_KEY=1024
            # export TEST_STRING_KEY=''
            export TEST_STRING_KEY=world
            # export TEST_BOOL_KEY=0
            export TEST_BOOL_KEY=1
            # export TEST_I64_KEY=-42
            export TEST_I64_KEY=-999
            # export TEST_F64_KEY=3.14
            export TEST_F64_KEY=2.718
            # export TEST_U32_KEY=100
            export TEST_U32_KEY=500
            # export TEST_DURATION_KEY=1m
            export TEST_DURATION_KEY=5s
            # export TEST_MODE_KEY=dev
            export TEST_MODE_KEY=prod
            # export TEST_IP_KEY=127.0.0.1
            export TEST_IP_KEY=192.168.1.1
        "}
        .trim_end()
        .lines()
        .collect();

        // For some reason, logs_contain fails to find these lines individually
        // (possibly to do with the fact that we have newlines in our log entries);
        // instead, we test it manually.
        logs_assert(|logged_lines: &[&str]| {
            let mut expected_lines = expected_lines.clone(); // this is an `Fn` closure
            for logged in logged_lines {
                expected_lines.remove(logged);
            }

            if expected_lines.is_empty() {
                Ok(())
            } else {
                Err(format!("missing log lines: {:?}", expected_lines))
            }
        });
    }

    #[test]
    fn test_yaml_round_trip() {
        // Use a crate-specific, per-process file name so this test cannot
        // collide with concurrent runs (or unrelated files) in the shared
        // temp directory.
        let temp_path = std::env::temp_dir().join(format!(
            "hyperactor_config_test_{}.yaml",
            std::process::id()
        ));

        let mut config = crate::Attrs::new();
        config.set(USIZE_KEY, 2048);
        config.set(STRING_KEY, "hello_yaml".to_string());
        config.set(BOOL_KEY, true);
        config.set(I64_KEY, -123);
        config.set(F64_KEY, 1.414);
        config.set(U32_KEY, 777);
        config.set(DURATION_KEY, std::time::Duration::from_mins(2));
        config.set(MODE_KEY, TestMode::Staging);
        config.set(IP_KEY, Ipv4Addr::new(10, 0, 0, 1));
        config.set(
            SYSTEMTIME_KEY,
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1609459200),
        );

        to_yaml(&config, &temp_path).unwrap();

        let yaml_content = std::fs::read_to_string(&temp_path).unwrap();
        let loaded_config = from_yaml(&temp_path).unwrap();

        // Remove the file before asserting so it is not left behind when an
        // assertion below fails.
        let _ = std::fs::remove_file(&temp_path);

        eprintln!("YAML content:\n{}", yaml_content);

        assert!(yaml_content.contains("2048"));
        assert!(yaml_content.contains("hello_yaml"));
        assert!(yaml_content.contains("Staging"));

        assert_eq!(loaded_config[USIZE_KEY], 2048);
        assert_eq!(loaded_config[STRING_KEY], "hello_yaml");
        assert!(loaded_config[BOOL_KEY]);
        assert_eq!(loaded_config[I64_KEY], -123);
        assert_eq!(loaded_config[F64_KEY], 1.414);
        assert_eq!(loaded_config[U32_KEY], 777);
        assert_eq!(
            loaded_config[DURATION_KEY],
            std::time::Duration::from_mins(2)
        );
        assert_eq!(loaded_config[MODE_KEY], TestMode::Staging);
        assert_eq!(loaded_config[IP_KEY], Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(
            loaded_config[SYSTEMTIME_KEY],
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1609459200)
        );
    }
}