1use std::env;
22use std::fs::File;
23use std::io::Read;
24use std::path::Path;
25
26use serde::Deserialize;
27use serde::Serialize;
28use shell_quote::QuoteRefExt;
29use typeuri::Named;
30
31pub mod attrs;
32pub mod global;
33
34pub use attrs::AttrKeyInfo;
36pub use attrs::AttrValue;
37pub use attrs::Attrs;
38pub use attrs::Key;
39pub use attrs::SerializableValue;
40pub use hyperactor_config_macros::AttrValue;
42pub use inventory::submit;
44pub use paste::paste;
45#[doc(hidden)]
47pub use typeuri;
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
64pub struct ConfigAttr {
65 pub env_name: Option<String>,
67
68 pub py_name: Option<String>,
71}
72
73impl Named for ConfigAttr {
74 fn typename() -> &'static str {
75 "hyperactor_config::ConfigAttr"
76 }
77}
78
79impl AttrValue for ConfigAttr {
80 fn display(&self) -> String {
81 serde_json::to_string(self).unwrap_or_else(|_| "<invalid ConfigAttr>".into())
82 }
83 fn parse(s: &str) -> Result<Self, anyhow::Error> {
84 Ok(serde_json::from_str(s)?)
85 }
86}
87
// Declares the `CONFIG` attribute key. Other attribute keys attach a
// `ConfigAttr` under this key (e.g. `@meta(CONFIG = ConfigAttr { .. })`) to
// mark themselves as configuration; `from_env` skips any key whose metadata
// lacks it.
declare_attrs! {
    // The stored value is a `ConfigAttr`; unlike keys declared with
    // `= <expr>`, no default value is given here.
    pub attr CONFIG: ConfigAttr;
}
103
104pub fn from_env() -> Attrs {
106 let mut config = Attrs::new();
107 let mut output = String::new();
108
109 fn export(env_var: &str, value: Option<&dyn SerializableValue>) -> String {
110 let env_var: String = env_var.quoted(shell_quote::Bash);
111 let value: String = value
112 .map_or("".to_string(), SerializableValue::display)
113 .quoted(shell_quote::Bash);
114 format!("export {}={}\n", env_var, value)
115 }
116
117 for key in inventory::iter::<AttrKeyInfo>() {
118 let Some(cfg_meta) = key.meta.get(CONFIG) else {
123 continue;
124 };
125 let Some(env_var) = cfg_meta.env_name.as_deref() else {
126 continue;
127 };
128
129 let Ok(val) = env::var(env_var) else {
130 output.push_str("# ");
132 output.push_str(&export(env_var, key.default));
133 continue;
134 };
135
136 match (key.parse)(&val) {
137 Err(e) => {
138 tracing::error!(
139 "failed to override config key {} from value \"{}\" in ${}: {})",
140 key.name,
141 val,
142 env_var,
143 e
144 );
145 output.push_str("# ");
146 output.push_str(&export(env_var, key.default));
147 }
148 Ok(parsed) => {
149 output.push_str("# ");
150 output.push_str(&export(env_var, key.default));
151 output.push_str(&export(env_var, Some(parsed.as_ref())));
152 config.insert_value_by_name_unchecked(key.name, parsed);
153 }
154 }
155 }
156
157 tracing::info!(
158 "loaded configuration from environment:\n{}",
159 output.trim_end()
160 );
161
162 config
163}
164
165pub fn from_yaml<P: AsRef<Path>>(path: P) -> Result<Attrs, anyhow::Error> {
167 let mut file = File::open(path)?;
168 let mut contents = String::new();
169 file.read_to_string(&mut contents)?;
170 Ok(serde_yaml::from_str(&contents)?)
171}
172
173pub fn to_yaml<P: AsRef<Path>>(attrs: &Attrs, path: P) -> Result<(), anyhow::Error> {
175 let yaml = serde_yaml::to_string(attrs)?;
176 std::fs::write(path, yaml)?;
177 Ok(())
178}
179
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::net::Ipv4Addr;
    use std::path::PathBuf;

    use indoc::indoc;

    use crate::CONFIG;
    use crate::ConfigAttr;
    use crate::attrs::declare_attrs;
    use crate::from_env;
    use crate::from_yaml;
    use crate::to_yaml;

    /// Test-only user-defined attribute type. Its textual form
    /// (`dev`/`staging`/`prod`) differs from its serde variant names.
    #[derive(
        Debug,
        Clone,
        Copy,
        PartialEq,
        Eq,
        serde::Serialize,
        serde::Deserialize
    )]
    pub(crate) enum TestMode {
        Development,
        Staging,
        Production,
    }

    impl std::fmt::Display for TestMode {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestMode::Development => write!(f, "dev"),
                TestMode::Staging => write!(f, "staging"),
                TestMode::Production => write!(f, "prod"),
            }
        }
    }

    impl std::str::FromStr for TestMode {
        type Err = anyhow::Error;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "dev" => Ok(TestMode::Development),
                "staging" => Ok(TestMode::Staging),
                "prod" => Ok(TestMode::Production),
                _ => Err(anyhow::anyhow!("unknown mode: {}", s)),
            }
        }
    }

    impl typeuri::Named for TestMode {
        fn typename() -> &'static str {
            "hyperactor_config::tests::TestMode"
        }
    }

    impl crate::attrs::AttrValue for TestMode {
        fn display(&self) -> String {
            self.to_string()
        }

        fn parse(s: &str) -> Result<Self, anyhow::Error> {
            s.parse()
        }
    }

    declare_attrs! {
        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_USIZE_KEY".to_string()),
            py_name: None,
        })
        pub attr USIZE_KEY: usize = 10;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_STRING_KEY".to_string()),
            py_name: None,
        })
        pub attr STRING_KEY: String = String::new();

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_BOOL_KEY".to_string()),
            py_name: None,
        })
        pub attr BOOL_KEY: bool = false;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_I64_KEY".to_string()),
            py_name: None,
        })
        pub attr I64_KEY: i64 = -42;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_F64_KEY".to_string()),
            py_name: None,
        })
        pub attr F64_KEY: f64 = 3.14;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_U32_KEY".to_string()),
            py_name: Some("test_u32_key".to_string()),
        })
        pub attr U32_KEY: u32 = 100;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_DURATION_KEY".to_string()),
            py_name: None,
        })
        pub attr DURATION_KEY: std::time::Duration = std::time::Duration::from_mins(1);

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_MODE_KEY".to_string()),
            py_name: None,
        })
        pub attr MODE_KEY: TestMode = TestMode::Development;

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_IP_KEY".to_string()),
            py_name: None,
        })
        pub attr IP_KEY: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);

        @meta(CONFIG = ConfigAttr {
            env_name: Some("TEST_SYSTEMTIME_KEY".to_string()),
            py_name: None,
        })
        pub attr SYSTEMTIME_KEY: std::time::SystemTime = std::time::UNIX_EPOCH;

        @meta(CONFIG = ConfigAttr {
            env_name: None,
            py_name: Some("test_no_env_key".to_string()),
        })
        pub attr NO_ENV_KEY: usize = 999;
    }

    /// Sets environment variables and removes them again when dropped.
    /// Cleanup therefore still happens if an assertion panics mid-test.
    struct EnvVarsGuard {
        names: Vec<&'static str>,
    }

    impl EnvVarsGuard {
        /// Sets every `(name, value)` pair and remembers the names for cleanup.
        fn set(vars: &[(&'static str, &str)]) -> Self {
            for &(name, value) in vars {
                // SAFETY: `set_var` is only sound if no other thread reads or
                // writes the environment concurrently. NOTE(review): the test
                // harness runs tests on parallel threads, so this relies on no
                // concurrently running test touching the environment — confirm.
                unsafe { std::env::set_var(name, value) };
            }
            Self {
                names: vars.iter().map(|&(name, _)| name).collect(),
            }
        }
    }

    impl Drop for EnvVarsGuard {
        fn drop(&mut self) {
            for name in &self.names {
                // SAFETY: same requirement as in `EnvVarsGuard::set`.
                unsafe { std::env::remove_var(name) };
            }
        }
    }

    /// Best-effort deletes the wrapped path when dropped. Temporary files are
    /// therefore cleaned up even when an assertion fails.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tracing_test::traced_test]
    #[test]
    #[cfg_attr(not(fbcode_build), ignore)]
    fn test_from_env() {
        // Bound to a named variable (not `_`) so the guard lives until the
        // end of the test.
        let _env = EnvVarsGuard::set(&[
            ("TEST_USIZE_KEY", "1024"),
            ("TEST_STRING_KEY", "world"),
            ("TEST_BOOL_KEY", "true"),
            ("TEST_I64_KEY", "-999"),
            ("TEST_F64_KEY", "2.718"),
            ("TEST_U32_KEY", "500"),
            ("TEST_DURATION_KEY", "5s"),
            ("TEST_MODE_KEY", "prod"),
            ("TEST_IP_KEY", "192.168.1.1"),
            ("TEST_SYSTEMTIME_KEY", "2024-01-15T10:30:00Z"),
        ]);

        let config = from_env();

        assert_eq!(config[USIZE_KEY], 1024);
        assert_eq!(config[STRING_KEY], "world");
        assert!(config[BOOL_KEY]);
        assert_eq!(config[I64_KEY], -999);
        assert_eq!(config[F64_KEY], 2.718);
        assert_eq!(config[U32_KEY], 500);
        assert_eq!(config[DURATION_KEY], std::time::Duration::from_secs(5));
        assert_eq!(config[MODE_KEY], TestMode::Production);
        assert_eq!(config[IP_KEY], Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(
            config[SYSTEMTIME_KEY],
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1705314600)
        );

        // A key without `env_name` is never read from the environment.
        assert_eq!(config[NO_ENV_KEY], 999);

        let expected_lines: HashSet<&str> = indoc! {"
            # export TEST_USIZE_KEY=10
            export TEST_USIZE_KEY=1024
            # export TEST_STRING_KEY=''
            export TEST_STRING_KEY=world
            # export TEST_BOOL_KEY=0
            export TEST_BOOL_KEY=1
            # export TEST_I64_KEY=-42
            export TEST_I64_KEY=-999
            # export TEST_F64_KEY=3.14
            export TEST_F64_KEY=2.718
            # export TEST_U32_KEY=100
            export TEST_U32_KEY=500
            # export TEST_DURATION_KEY=1m
            export TEST_DURATION_KEY=5s
            # export TEST_MODE_KEY=dev
            export TEST_MODE_KEY=prod
            # export TEST_IP_KEY=127.0.0.1
            export TEST_IP_KEY=192.168.1.1
        "}
        .trim_end()
        .lines()
        .collect();

        logs_assert(|logged_lines: &[&str]| {
            // Every expected line must appear somewhere in the captured logs;
            // extra log lines are allowed.
            let mut missing = expected_lines.clone();
            for logged in logged_lines {
                missing.remove(logged);
            }

            if missing.is_empty() {
                Ok(())
            } else {
                Err(format!("missing log lines: {:?}", missing))
            }
        });
    }

    #[test]
    fn test_yaml_round_trip() {
        // Include the process id so concurrent test runs on the same machine
        // do not read or delete each other's file.
        let temp_file = TempFile(
            std::env::temp_dir()
                .join(format!("hyperactor_config_test_{}.yaml", std::process::id())),
        );
        let temp_path = &temp_file.0;

        let mut config = crate::Attrs::new();
        config.set(USIZE_KEY, 2048);
        config.set(STRING_KEY, "hello_yaml".to_string());
        config.set(BOOL_KEY, true);
        config.set(I64_KEY, -123);
        config.set(F64_KEY, 1.414);
        config.set(U32_KEY, 777);
        config.set(DURATION_KEY, std::time::Duration::from_mins(2));
        config.set(MODE_KEY, TestMode::Staging);
        config.set(IP_KEY, Ipv4Addr::new(10, 0, 0, 1));
        config.set(
            SYSTEMTIME_KEY,
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1609459200),
        );

        to_yaml(&config, temp_path).unwrap();

        let yaml_content = std::fs::read_to_string(temp_path).unwrap();

        eprintln!("YAML content:\n{}", yaml_content);

        assert!(yaml_content.contains("2048"));
        assert!(yaml_content.contains("hello_yaml"));
        assert!(yaml_content.contains("Staging"));

        let loaded_config = from_yaml(temp_path).unwrap();

        assert_eq!(loaded_config[USIZE_KEY], 2048);
        assert_eq!(loaded_config[STRING_KEY], "hello_yaml");
        assert!(loaded_config[BOOL_KEY]);
        assert_eq!(loaded_config[I64_KEY], -123);
        assert_eq!(loaded_config[F64_KEY], 1.414);
        assert_eq!(loaded_config[U32_KEY], 777);
        assert_eq!(
            loaded_config[DURATION_KEY],
            std::time::Duration::from_mins(2)
        );
        assert_eq!(loaded_config[MODE_KEY], TestMode::Staging);
        assert_eq!(loaded_config[IP_KEY], Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(
            loaded_config[SYSTEMTIME_KEY],
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1609459200)
        );
    }
}