hyperactor/
config.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Configuration for Hyperactor.
10//!
11//! This module provides a centralized way to manage configuration settings for Hyperactor.
12//! It uses the attrs system for type-safe, flexible configuration management that supports
13//! environment variables, YAML files, and temporary modifications for tests.
14
15use std::env;
16use std::fs::File;
17use std::io::Read;
18use std::path::Path;
19use std::sync::Arc;
20use std::sync::LazyLock;
21use std::sync::RwLock;
22use std::time::Duration;
23
24use crate::attrs::Attrs;
25use crate::attrs::declare_attrs;
26
// Configuration keys, declared with the attrs system. Each key carries a default
// value, which an `Attrs` reports for that key when it has not been explicitly set.
declare_attrs! {
    /// Maximum frame length for codec, in bytes
    pub attr CODEC_MAX_FRAME_LENGTH: usize = 1024 * 1024 * 1024; // 1GB

    /// Message delivery timeout
    pub attr MESSAGE_DELIVERY_TIMEOUT: Duration = Duration::from_secs(30);

    /// Timeout used by allocator for stopping a proc.
    pub attr PROCESS_EXIT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Message acknowledgment interval
    pub attr MESSAGE_ACK_TIME_INTERVAL: Duration = Duration::from_millis(500);

    /// Number of messages after which to send an acknowledgment
    pub attr MESSAGE_ACK_EVERY_N_MESSAGES: u64 = 1000;

    /// Maximum buffer size for split port messages
    pub attr SPLIT_MAX_BUFFER_SIZE: usize = 5;

    /// Timeout used by proc mesh for stopping an actor.
    pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(1);

    /// Heartbeat interval for remote allocator
    pub attr REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
}
53
54/// Load configuration from environment variables
55pub fn from_env() -> Attrs {
56    let mut config = Attrs::new();
57
58    // Load codec max frame length
59    if let Ok(val) = env::var("HYPERACTOR_CODEC_MAX_FRAME_LENGTH") {
60        if let Ok(parsed) = val.parse::<usize>() {
61            config[CODEC_MAX_FRAME_LENGTH] = parsed;
62        }
63    }
64
65    // Load message delivery timeout
66    if let Ok(val) = env::var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS") {
67        if let Ok(parsed) = val.parse::<u64>() {
68            config[MESSAGE_DELIVERY_TIMEOUT] = Duration::from_secs(parsed);
69        }
70    }
71
72    // Load message ack time interval
73    if let Ok(val) = env::var("HYPERACTOR_MESSAGE_ACK_TIME_INTERVAL_MS") {
74        if let Ok(parsed) = val.parse::<u64>() {
75            config[MESSAGE_ACK_TIME_INTERVAL] = Duration::from_millis(parsed);
76        }
77    }
78
79    // Load message ack every n messages
80    if let Ok(val) = env::var("HYPERACTOR_MESSAGE_ACK_EVERY_N_MESSAGES") {
81        if let Ok(parsed) = val.parse::<u64>() {
82            config[MESSAGE_ACK_EVERY_N_MESSAGES] = parsed;
83        }
84    }
85
86    // Load split max buffer size
87    if let Ok(val) = env::var("HYPERACTOR_SPLIT_MAX_BUFFER_SIZE") {
88        if let Ok(parsed) = val.parse::<usize>() {
89            config[SPLIT_MAX_BUFFER_SIZE] = parsed;
90        }
91    }
92
93    // Load remote allocator heartbeat interval
94    if let Ok(val) = env::var("HYPERACTOR_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS") {
95        if let Ok(parsed) = val.parse::<u64>() {
96            config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = Duration::from_secs(parsed);
97        }
98    }
99
100    config
101}
102
103/// Load configuration from a YAML file
104pub fn from_yaml<P: AsRef<Path>>(path: P) -> Result<Attrs, anyhow::Error> {
105    let mut file = File::open(path)?;
106    let mut contents = String::new();
107    file.read_to_string(&mut contents)?;
108    Ok(serde_yaml::from_str(&contents)?)
109}
110
111/// Save configuration to a YAML file
112pub fn to_yaml<P: AsRef<Path>>(attrs: &Attrs, path: P) -> Result<(), anyhow::Error> {
113    let yaml = serde_yaml::to_string(attrs)?;
114    std::fs::write(path, yaml)?;
115    Ok(())
116}
117
118/// Merge with another configuration, with the other taking precedence
119pub fn merge(config: &mut Attrs, other: &Attrs) {
120    if other.contains_key(CODEC_MAX_FRAME_LENGTH) {
121        config[CODEC_MAX_FRAME_LENGTH] = other[CODEC_MAX_FRAME_LENGTH];
122    }
123    if other.contains_key(MESSAGE_DELIVERY_TIMEOUT) {
124        config[MESSAGE_DELIVERY_TIMEOUT] = other[MESSAGE_DELIVERY_TIMEOUT];
125    }
126    if other.contains_key(MESSAGE_ACK_TIME_INTERVAL) {
127        config[MESSAGE_ACK_TIME_INTERVAL] = other[MESSAGE_ACK_TIME_INTERVAL];
128    }
129    if other.contains_key(MESSAGE_ACK_EVERY_N_MESSAGES) {
130        config[MESSAGE_ACK_EVERY_N_MESSAGES] = other[MESSAGE_ACK_EVERY_N_MESSAGES];
131    }
132    if other.contains_key(SPLIT_MAX_BUFFER_SIZE) {
133        config[SPLIT_MAX_BUFFER_SIZE] = other[SPLIT_MAX_BUFFER_SIZE];
134    }
135    if other.contains_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) {
136        config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = other[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL];
137    }
138}
139
140/// Global configuration functions
141///
142/// This module provides global configuration access and testing utilities.
143///
144/// # Testing with Global Configuration
145///
146/// Tests can override global configuration using [`global::lock`]. This ensures that
147/// such tests are serialized (and cannot clobber each other's overrides).
148///
149/// ```ignore rust
150/// #[test]
151/// fn test_my_feature() {
152///     let config = hyperactor::config::global::lock();
153///     let _guard = config.override_key(SOME_CONFIG_KEY, test_value);
154///     // ... test logic here ...
155/// }
156/// ```
157pub mod global {
158    use std::marker::PhantomData;
159
160    use super::*;
161    use crate::attrs::Key;
162
163    /// Global configuration instance, initialized from environment variables.
164    static CONFIG: LazyLock<Arc<RwLock<Attrs>>> =
165        LazyLock::new(|| Arc::new(RwLock::new(from_env())));
166
167    /// Acquire the global configuration lock for testing.
168    ///
169    /// This function returns a ConfigLock that acts as both a write lock guard (preventing
170    /// other tests from modifying global config concurrently) and as the only way to
171    /// create configuration overrides.
172    ///
173    /// Example usage:
174    /// ```ignore rust
175    /// let config = hyperactor::config::global::lock();
176    /// let _guard = config.override_key(CONFIG_KEY, "value");
177    /// // ... test code using the overridden config ...
178    /// ```
179    pub fn lock() -> ConfigLock {
180        static MUTEX: LazyLock<std::sync::Mutex<()>> = LazyLock::new(|| std::sync::Mutex::new(()));
181        ConfigLock {
182            _guard: MUTEX.lock().unwrap(),
183        }
184    }
185
186    /// Initialize the global configuration from environment variables
187    pub fn init_from_env() {
188        let config = from_env();
189        let mut global_config = CONFIG.write().unwrap();
190        *global_config = config;
191    }
192
193    /// Initialize the global configuration from a YAML file
194    pub fn init_from_yaml<P: AsRef<Path>>(path: P) -> Result<(), anyhow::Error> {
195        let config = from_yaml(path)?;
196        let mut global_config = CONFIG.write().unwrap();
197        *global_config = config;
198        Ok(())
199    }
200
201    /// Get a key from the global configuration. Currently only available for Copy types.
202    /// `get` assumes that the key has a default value.
203    pub fn get<
204        T: Send
205            + Sync
206            + Copy
207            + serde::Serialize
208            + serde::de::DeserializeOwned
209            + crate::data::Named
210            + 'static,
211    >(
212        key: Key<T>,
213    ) -> T {
214        *CONFIG.read().unwrap().get(key).unwrap()
215    }
216
217    /// Get the global attrs
218    pub fn attrs() -> Attrs {
219        CONFIG.read().unwrap().clone()
220    }
221
222    /// Reset the global configuration to defaults (for testing only)
223    ///
224    /// Note: This should be called from within with_test_lock() to ensure thread safety.
225    /// Available in all builds to support tests in other crates.
226    pub fn reset_to_defaults() {
227        let mut config = CONFIG.write().unwrap();
228        *config = Attrs::new();
229    }
230
231    /// A guard that holds the global configuration lock and provides override functionality.
232    ///
233    /// This struct acts as both a lock guard (preventing other tests from modifying global config)
234    /// and as the only way to create configuration overrides. Override guards cannot outlive
235    /// this ConfigLock, ensuring proper synchronization.
236    pub struct ConfigLock {
237        _guard: std::sync::MutexGuard<'static, ()>,
238    }
239
240    impl ConfigLock {
241        /// Create a configuration override that will be restored when the guard is dropped.
242        ///
243        /// The returned guard must not outlive this ConfigLock.
244        pub fn override_key<
245            'a,
246            T: Send
247                + Sync
248                + serde::Serialize
249                + serde::de::DeserializeOwned
250                + crate::data::Named
251                + Clone
252                + 'static,
253        >(
254            &'a self,
255            key: crate::attrs::Key<T>,
256            value: T,
257        ) -> ConfigValueGuard<'a, T> {
258            let orig = {
259                let mut config = CONFIG.write().unwrap();
260                let orig = config.take_value(key);
261                config.set(key, value);
262                orig
263            };
264
265            ConfigValueGuard {
266                key,
267                orig,
268                _phantom: PhantomData,
269            }
270        }
271    }
272
273    /// A guard that restores a single configuration value when dropped
274    pub struct ConfigValueGuard<'a, T: 'static> {
275        key: crate::attrs::Key<T>,
276        orig: Option<Box<dyn crate::attrs::SerializableValue>>,
277        // This is here so we can hold onto a 'a lifetime.
278        _phantom: PhantomData<&'a ()>,
279    }
280
281    impl<T: 'static> Drop for ConfigValueGuard<'_, T> {
282        fn drop(&mut self) {
283            let mut config = CONFIG.write().unwrap();
284            if let Some(orig) = self.orig.take() {
285                config.restore_value(self.key, orig);
286            } else {
287                config.remove_value(self.key);
288            }
289        }
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the listed environment variables when dropped. This cleans up
    /// env state set by a test even if one of its assertions fails.
    struct EnvVarCleanup(&'static [&'static str]);

    impl Drop for EnvVarCleanup {
        fn drop(&mut self) {
            for &name in self.0 {
                // SAFETY: Only constructed in `test_from_env`, after it acquires
                // `global::lock()`. It is declared after that lock, so it is
                // dropped (and this runs) before the lock is released. See the
                // SAFETY note in that test for the limits of this guarantee.
                unsafe { std::env::remove_var(name) };
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = Attrs::new();
        assert_eq!(config[CODEC_MAX_FRAME_LENGTH], 1024 * 1024 * 1024);
        assert_eq!(config[MESSAGE_DELIVERY_TIMEOUT], Duration::from_secs(30));
        assert_eq!(
            config[MESSAGE_ACK_TIME_INTERVAL],
            Duration::from_millis(500)
        );
        assert_eq!(config[MESSAGE_ACK_EVERY_N_MESSAGES], 1000);
        assert_eq!(config[SPLIT_MAX_BUFFER_SIZE], 5);
        assert_eq!(
            config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL],
            Duration::from_secs(5)
        );
    }

    #[test]
    fn test_from_env() {
        const VARS: &[&str] = &[
            "HYPERACTOR_CODEC_MAX_FRAME_LENGTH",
            "HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS",
        ];

        // Serialize with the other tests here that touch the global config; its
        // lazy initialization reads the environment through `from_env`.
        let _config_lock = global::lock();
        // Declared after the lock so it is dropped (removing the variables)
        // before the lock is released, even if an assertion below fails.
        let _cleanup = EnvVarCleanup(VARS);

        // SAFETY: Mutating the environment is only sound if no other thread
        // accesses it concurrently. Holding `global::lock()` serializes this with
        // the tests in this module that touch the global config. It cannot
        // exclude unrelated env access elsewhere in the test binary.
        // TODO: complete that audit.
        unsafe { std::env::set_var("HYPERACTOR_CODEC_MAX_FRAME_LENGTH", "1024") };
        // SAFETY: same as above.
        unsafe { std::env::set_var("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", "60") };

        let config = from_env();

        assert_eq!(config[CODEC_MAX_FRAME_LENGTH], 1024);
        assert_eq!(config[MESSAGE_DELIVERY_TIMEOUT], Duration::from_secs(60));
        assert_eq!(
            config[MESSAGE_ACK_TIME_INTERVAL],
            Duration::from_millis(500)
        ); // Default value
    }

    #[test]
    fn test_merge() {
        let mut config1 = Attrs::new();
        let mut config2 = Attrs::new();
        config2[CODEC_MAX_FRAME_LENGTH] = 1024;
        config2[MESSAGE_DELIVERY_TIMEOUT] = Duration::from_secs(60);

        merge(&mut config1, &config2);

        assert_eq!(config1[CODEC_MAX_FRAME_LENGTH], 1024);
        assert_eq!(config1[MESSAGE_DELIVERY_TIMEOUT], Duration::from_secs(60));
    }

    #[test]
    fn test_global_config() {
        let config = global::lock();

        // Reset global config to defaults to avoid interference from other tests
        global::reset_to_defaults();

        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);
        {
            let _guard = config.override_key(CODEC_MAX_FRAME_LENGTH, 1024);
            assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024);
        }
        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);

        {
            let _guard = config.override_key(CODEC_MAX_FRAME_LENGTH, 1024);
            assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024);

            // The configuration will be automatically restored when _guard goes out of scope
        }

        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);
    }

    #[test]
    fn test_defaults() {
        // An empty config reports each key's declared default.
        let config = Attrs::new();

        // Verify that the config is empty (no values explicitly set)
        assert!(config.is_empty());

        // But getters should still return the defaults from the keys
        assert_eq!(config[CODEC_MAX_FRAME_LENGTH], 1024 * 1024 * 1024);
        assert_eq!(config[MESSAGE_DELIVERY_TIMEOUT], Duration::from_secs(30));
        assert_eq!(
            config[MESSAGE_ACK_TIME_INTERVAL],
            Duration::from_millis(500)
        );
        assert_eq!(config[MESSAGE_ACK_EVERY_N_MESSAGES], 1000);
        assert_eq!(config[SPLIT_MAX_BUFFER_SIZE], 5);

        // Verify the keys have defaults
        assert!(CODEC_MAX_FRAME_LENGTH.has_default());
        assert!(MESSAGE_DELIVERY_TIMEOUT.has_default());
        assert!(MESSAGE_ACK_TIME_INTERVAL.has_default());
        assert!(MESSAGE_ACK_EVERY_N_MESSAGES.has_default());
        assert!(SPLIT_MAX_BUFFER_SIZE.has_default());

        // Verify we can get defaults directly from keys
        assert_eq!(
            CODEC_MAX_FRAME_LENGTH.default(),
            Some(&(1024 * 1024 * 1024))
        );
        assert_eq!(
            MESSAGE_DELIVERY_TIMEOUT.default(),
            Some(&Duration::from_secs(30))
        );
        assert_eq!(
            MESSAGE_ACK_TIME_INTERVAL.default(),
            Some(&Duration::from_millis(500))
        );
        assert_eq!(MESSAGE_ACK_EVERY_N_MESSAGES.default(), Some(&1000));
        assert_eq!(SPLIT_MAX_BUFFER_SIZE.default(), Some(&5));
    }

    #[test]
    fn test_serialization_only_includes_set_values() {
        let mut config = Attrs::new();

        // Initially empty, serialization should be empty
        let serialized = serde_json::to_string(&config).unwrap();
        assert_eq!(serialized, "{}");

        config[CODEC_MAX_FRAME_LENGTH] = 1024;

        let serialized = serde_json::to_string(&config).unwrap();
        assert!(serialized.contains("codec_max_frame_length"));
        assert!(!serialized.contains("message_delivery_timeout")); // Default not serialized

        // Deserialize back
        let restored_config: Attrs = serde_json::from_str(&serialized).unwrap();

        // Custom value should be preserved
        assert_eq!(restored_config[CODEC_MAX_FRAME_LENGTH], 1024);

        // Defaults should still work for other values
        assert_eq!(
            restored_config[MESSAGE_DELIVERY_TIMEOUT],
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_overrides() {
        let config = global::lock();

        // Reset global config to defaults to avoid interference from other tests
        global::reset_to_defaults();

        // Test the lock/override API for individual config values
        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);
        assert_eq!(
            global::get(MESSAGE_DELIVERY_TIMEOUT),
            Duration::from_secs(30)
        );

        // Test single value override
        {
            let _guard = config.override_key(CODEC_MAX_FRAME_LENGTH, 2048);
            assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 2048);
            assert_eq!(
                global::get(MESSAGE_DELIVERY_TIMEOUT),
                Duration::from_secs(30)
            ); // Unchanged
        }

        // Values should be restored after guard is dropped
        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);

        // Test multiple overrides
        {
            let _guard1 = config.override_key(CODEC_MAX_FRAME_LENGTH, 4096);
            let _guard2 = config.override_key(MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(60));

            assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 4096);
            assert_eq!(
                global::get(MESSAGE_DELIVERY_TIMEOUT),
                Duration::from_secs(60)
            );
        }

        // All values should be restored
        assert_eq!(global::get(CODEC_MAX_FRAME_LENGTH), 1024 * 1024 * 1024);
        assert_eq!(
            global::get(MESSAGE_DELIVERY_TIMEOUT),
            Duration::from_secs(30)
        );
    }
}