1use std::any::Any;
70use std::collections::HashMap;
71use std::marker::PhantomData;
72use std::ops::Index;
73use std::ops::IndexMut;
74
75use erased_serde::Deserializer as ErasedDeserializer;
76use erased_serde::Serialize as ErasedSerialize;
77use serde::Deserialize;
78use serde::Deserializer;
79use serde::Serialize;
80use serde::Serializer;
81use serde::de::DeserializeOwned;
82use serde::de::MapAccess;
83use serde::de::Visitor;
84use serde::ser::SerializeMap;
85
86use crate::data::Named;
87
/// Registry record describing one declared attribute key.
///
/// `declare_attrs!` submits one record per declared key. The `Attrs` deserializer
/// looks records up by `name` to decode type-erased values.
#[doc(hidden)]
pub struct AttrKeyInfo {
    /// Name of the key; also the map key used in the serialized form of `Attrs`.
    pub name: &'static str,
    /// Function returning the type hash of the key's value type.
    pub typehash: fn() -> u64,
    /// Deserializes a value of the key's type from a type-erased deserializer
    /// and returns it boxed as a `SerializableValue`.
    pub deserialize_erased:
        fn(&mut dyn ErasedDeserializer) -> Result<Box<dyn SerializableValue>, erased_serde::Error>,
}

// Register `AttrKeyInfo` with `inventory` so that submitted records can be iterated.
inventory::collect!(AttrKeyInfo);
103
/// A typed, named key identifying an attribute whose value has type `T`.
///
/// A key holds only static data: its name and an optional reference to a default value.
pub struct Key<T: 'static> {
    /// Name under which the value is stored and serialized.
    name: &'static str,
    /// Default reported for this key when no value has been stored, if any.
    default_value: Option<&'static T>,
    /// Ties the key to its value type `T`.
    _phantom: PhantomData<T>,
}
114
115impl<T: Named + 'static> Key<T> {
116 pub const fn new(name: &'static str) -> Self {
118 Self {
119 name,
120 default_value: None,
121 _phantom: PhantomData,
122 }
123 }
124
125 pub const fn with_default(name: &'static str, default_value: &'static T) -> Self {
127 Self {
128 name,
129 default_value: Some(default_value),
130 _phantom: PhantomData,
131 }
132 }
133
134 pub fn name(&self) -> &'static str {
136 self.name
137 }
138
139 pub fn default(&self) -> Option<&'static T> {
141 self.default_value
142 }
143
144 pub fn has_default(&self) -> bool {
146 self.default_value.is_some()
147 }
148
149 pub fn typehash(&self) -> u64 {
151 T::typehash()
152 }
153}
154
// `Clone` and `Copy` are implemented by hand rather than derived so that they
// hold for every `T: 'static`; a derive would add `T: Clone` / `T: Copy` bounds.
impl<T: 'static> Clone for Key<T> {
    /// Returns a bitwise copy of the key (all fields are `Copy`).
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: 'static> Copy for Key<T> {}
163
164impl<T: Send + Sync + Serialize + DeserializeOwned + Named + 'static> Index<Key<T>> for Attrs {
166 type Output = T;
167
168 fn index(&self, key: Key<T>) -> &Self::Output {
169 self.get(key).unwrap()
170 }
171}
172
173impl<T: Send + Sync + Serialize + DeserializeOwned + Named + Clone + 'static> IndexMut<Key<T>>
176 for Attrs
177{
178 fn index_mut(&mut self, key: Key<T>) -> &mut Self::Output {
179 self.get_mut(key).unwrap()
180 }
181}
182
/// Object-safe view of a stored attribute value.
///
/// It exposes what `Attrs` needs from a type-erased value: downcasting,
/// serialization, and cloning.
#[doc(hidden)]
pub trait SerializableValue: Send + Sync {
    /// Returns the value as `&dyn Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Returns the value as `&mut dyn Any` so it can be downcast mutably.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Returns the value as a type-erased serializable object.
    fn as_erased_serialize(&self) -> &dyn ErasedSerialize;
    /// Returns a boxed clone of the value.
    fn cloned(&self) -> Box<dyn SerializableValue>;
}
195
196impl<T: Serialize + Send + Sync + Clone + 'static> SerializableValue for T {
197 fn as_any(&self) -> &dyn Any {
198 self
199 }
200
201 fn as_any_mut(&mut self) -> &mut dyn Any {
202 self
203 }
204
205 fn as_erased_serialize(&self) -> &dyn ErasedSerialize {
206 self
207 }
208
209 fn cloned(&self) -> Box<dyn SerializableValue> {
210 Box::new(self.clone())
211 }
212}
213
/// A heterogeneous dictionary of attribute values addressed by typed `Key`s.
pub struct Attrs {
    /// Type-erased values, indexed by key name.
    values: HashMap<&'static str, Box<dyn SerializableValue>>,
}
234
235impl Attrs {
236 pub fn new() -> Self {
238 Self {
239 values: HashMap::new(),
240 }
241 }
242
243 pub fn set<T: Send + Sync + Serialize + DeserializeOwned + Named + Clone + 'static>(
245 &mut self,
246 key: Key<T>,
247 value: T,
248 ) {
249 self.values.insert(key.name, Box::new(value));
250 }
251
252 fn maybe_set_from_default<
253 T: Send + Sync + Serialize + DeserializeOwned + Named + Clone + 'static,
254 >(
255 &mut self,
256 key: Key<T>,
257 ) {
258 if self.contains_key(key) {
259 return;
260 }
261 let Some(default) = key.default() else { return };
262 self.set(key, default.clone());
263 }
264
265 pub fn get<T: Send + Sync + Serialize + DeserializeOwned + Named + 'static>(
268 &self,
269 key: Key<T>,
270 ) -> Option<&T> {
271 self.values
272 .get(key.name)
273 .and_then(|value| value.as_any().downcast_ref::<T>())
274 .or_else(|| key.default())
275 }
276
277 pub fn get_mut<T: Send + Sync + Serialize + DeserializeOwned + Named + Clone + 'static>(
280 &mut self,
281 key: Key<T>,
282 ) -> Option<&mut T> {
283 self.maybe_set_from_default(key);
284 self.values
285 .get_mut(key.name)
286 .and_then(|value| value.as_any_mut().downcast_mut::<T>())
287 }
288
289 pub fn remove<T: Send + Sync + Serialize + DeserializeOwned + Named + 'static>(
291 &mut self,
292 key: Key<T>,
293 ) -> bool {
294 self.values.remove(key.name).is_some()
296 }
297
298 pub fn contains_key<T: Send + Sync + Serialize + DeserializeOwned + Named + 'static>(
300 &self,
301 key: Key<T>,
302 ) -> bool {
303 self.values.contains_key(key.name)
304 }
305
306 pub fn len(&self) -> usize {
308 self.values.len()
309 }
310
311 pub fn is_empty(&self) -> bool {
313 self.values.is_empty()
314 }
315
316 pub fn clear(&mut self) {
318 self.values.clear();
319 }
320
321 pub(crate) fn take_value<T: 'static>(
324 &mut self,
325 key: Key<T>,
326 ) -> Option<Box<dyn SerializableValue>> {
327 self.values.remove(key.name)
328 }
329
330 pub(crate) fn restore_value<T: 'static>(
332 &mut self,
333 key: Key<T>,
334 value: Box<dyn SerializableValue>,
335 ) {
336 self.values.insert(key.name, value);
337 }
338
339 pub(crate) fn remove_value<T: 'static>(&mut self, key: Key<T>) -> bool {
341 self.values.remove(key.name).is_some()
342 }
343}
344
345impl Clone for Attrs {
346 fn clone(&self) -> Self {
347 let mut values = HashMap::new();
348 for (key, value) in &self.values {
349 values.insert(*key, value.cloned());
350 }
351 Self { values }
352 }
353}
354
355impl std::fmt::Debug for Attrs {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 let mut debug_map = std::collections::BTreeMap::new();
359 for (key, value) in &self.values {
360 match serde_json::to_string(value.as_erased_serialize()) {
361 Ok(json) => {
362 debug_map.insert(*key, json);
363 }
364 Err(_) => {
365 debug_map.insert(*key, "<serialization error>".to_string());
366 }
367 }
368 }
369
370 f.debug_struct("Attrs").field("values", &debug_map).finish()
371 }
372}
373
374impl Default for Attrs {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
380impl Serialize for Attrs {
381 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
382 where
383 S: Serializer,
384 {
385 let mut map = serializer.serialize_map(Some(self.values.len()))?;
386
387 for (key_name, value) in &self.values {
388 map.serialize_entry(key_name, value.as_erased_serialize())?;
389 }
390
391 map.end()
392 }
393}
394
395struct AttrsVisitor;
396
397impl<'de> Visitor<'de> for AttrsVisitor {
398 type Value = Attrs;
399
400 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
401 formatter.write_str("a map of attribute keys to their serialized values")
402 }
403
404 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
405 where
406 M: MapAccess<'de>,
407 {
408 static KEYS_BY_NAME: std::sync::LazyLock<HashMap<&'static str, &'static AttrKeyInfo>> =
409 std::sync::LazyLock::new(|| {
410 inventory::iter::<AttrKeyInfo>()
411 .map(|info| (info.name, info))
412 .collect()
413 });
414 let keys_by_name = &*KEYS_BY_NAME;
415
416 let mut attrs = Attrs::new();
417 while let Some(key_name) = access.next_key::<String>()? {
418 let Some(&key) = keys_by_name.get(key_name.as_str()) else {
419 access.next_value::<serde::de::IgnoredAny>()?;
421 continue;
422 };
423
424 let seed = ValueDeserializeSeed {
426 deserialize_erased: key.deserialize_erased,
427 };
428 match access.next_value_seed(seed) {
429 Ok(value) => {
430 attrs.values.insert(key.name, value);
431 }
432 Err(err) => {
433 return Err(serde::de::Error::custom(format!(
434 "failed to deserialize value for key {}: {}",
435 key_name, err
436 )));
437 }
438 }
439 }
440
441 Ok(attrs)
442 }
443}
444
445struct ValueDeserializeSeed {
447 deserialize_erased:
448 fn(&mut dyn ErasedDeserializer) -> Result<Box<dyn SerializableValue>, erased_serde::Error>,
449}
450
451impl<'de> serde::de::DeserializeSeed<'de> for ValueDeserializeSeed {
452 type Value = Box<dyn SerializableValue>;
453
454 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
455 where
456 D: serde::de::Deserializer<'de>,
457 {
458 let mut erased = <dyn erased_serde::Deserializer>::erase(deserializer);
459 (self.deserialize_erased)(&mut erased).map_err(serde::de::Error::custom)
460 }
461}
462
463impl<'de> Deserialize<'de> for Attrs {
464 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
465 where
466 D: Deserializer<'de>,
467 {
468 deserializer.deserialize_map(AttrsVisitor)
469 }
470}
471
/// Copies `input` into a byte array of length `N`, lowercasing ASCII letters.
///
/// Only the bytes `b'A'..=b'Z'` are changed; every other byte, including
/// non-ASCII UTF-8 bytes, is copied through unchanged. If `input` is longer
/// than `N` bytes the extra bytes are dropped; if it is shorter, the remaining
/// array elements are left as `0`.
#[doc(hidden)]
pub const fn ascii_to_lowercase_const<const N: usize>(input: &str) -> [u8; N] {
    let bytes = input.as_bytes();
    let mut result = [0u8; N];
    let mut i = 0;

    // Iterators and `for` loops are not available in `const fn`, hence the index loop.
    while i < bytes.len() && i < N {
        result[i] = bytes[i].to_ascii_lowercase();
        i += 1;
    }

    result
}
492
/// Evaluates to a `&'static str` holding the ASCII-lowercased form of the
/// constant string expression `$s`.
///
/// The lowercased bytes are computed in constants. The conversion back to `&str`
/// uses the checked `std::str::from_utf8`, so no `unsafe` is needed.
#[doc(hidden)]
#[macro_export]
macro_rules! const_ascii_lowercase {
    ($s:expr) => {{
        const INPUT: &str = $s;
        const LEN: usize = INPUT.len();
        const BYTES: [u8; LEN] = $crate::attrs::ascii_to_lowercase_const::<LEN>(INPUT);
        // ASCII-lowercasing only rewrites the bytes `A`..=`Z`, which never occur
        // inside a multi-byte UTF-8 sequence, so this check is expected to succeed.
        match std::str::from_utf8(&BYTES) {
            Ok(lowered) => lowered,
            Err(_) => panic!("ASCII-lowercased string is not valid UTF-8"),
        }
    }};
}
505
/// Declares one or more attribute keys as `static` items of type `Key<T>`.
///
/// Each declaration has the form `[attributes] [visibility] attr NAME: Type [= default];`.
/// For every declaration the macro:
///
/// * defines `static NAME: Key<Type>`, whose key name is the ASCII-lowercased
///   string `<module path>::NAME`;
/// * when a default is given, also defines `static NAME_DEFAULT: Type` and points
///   the key at it;
/// * submits an `AttrKeyInfo` record for the key.
///
/// ```ignore
/// declare_attrs! {
///     attr TIMEOUT: Duration = Duration::from_secs(10);
///     pub(crate) attr LABEL: String;
/// }
/// ```
#[macro_export]
macro_rules! declare_attrs {
    // Entry point: split the list and forward each declaration to an `@single` arm.
    ($(
        $(#[$attr:meta])*
        $vis:vis attr $name:ident: $type:ty $(= $default:expr)?;
    )*) => {
        $(
            $crate::declare_attrs! { @single $(#[$attr])* ; $vis attr $name: $type $(= $default)?; }
        )*
    };

    // A single declaration with a default value.
    (@single $(#[$attr:meta])* ; $vis:vis attr $name:ident: $type:ty = $default:expr;) => {
        // Backing storage for the default; the key holds a `&'static` reference to it.
        $crate::paste! {
            static [<$name _DEFAULT>]: $type = $default;
        }

        $(#[$attr])*
        $vis static $name: $crate::attrs::Key<$type> = {
            const FULL_NAME: &str = concat!(std::module_path!(), "::", stringify!($name));
            const LOWER_NAME: &str = $crate::const_ascii_lowercase!(FULL_NAME);
            $crate::paste! {
                $crate::attrs::Key::with_default(
                    LOWER_NAME,
                    &[<$name _DEFAULT>]
                )
            }
        };

        // Submit the registry record, under the same name as the key above.
        // NOTE(review): `erased_serde` is not `$crate`-qualified, so it presumably
        // must be resolvable in the invoking crate -- confirm.
        $crate::submit! {
            $crate::attrs::AttrKeyInfo {
                name: {
                    const FULL_NAME: &str = concat!(std::module_path!(), "::", stringify!($name));
                    $crate::const_ascii_lowercase!(FULL_NAME)
                },
                typehash: <$type as $crate::data::Named>::typehash,
                deserialize_erased: |deserializer| {
                    let value: $type = erased_serde::deserialize(deserializer)?;
                    Ok(Box::new(value) as Box<dyn $crate::attrs::SerializableValue>)
                },
            }
        }
    };

    // A single declaration without a default value.
    (@single $(#[$attr:meta])* ; $vis:vis attr $name:ident: $type:ty;) => {
        $(#[$attr])*
        $vis static $name: $crate::attrs::Key<$type> = {
            const FULL_NAME: &str = concat!(std::module_path!(), "::", stringify!($name));
            const LOWER_NAME: &str = $crate::const_ascii_lowercase!(FULL_NAME);
            $crate::attrs::Key::new(LOWER_NAME)
        };

        // Submit the registry record, under the same name as the key above.
        // NOTE(review): `erased_serde` is not `$crate`-qualified, so it presumably
        // must be resolvable in the invoking crate -- confirm.
        $crate::submit! {
            $crate::attrs::AttrKeyInfo {
                name: {
                    const FULL_NAME: &str = concat!(std::module_path!(), "::", stringify!($name));
                    $crate::const_ascii_lowercase!(FULL_NAME)
                },
                typehash: <$type as $crate::data::Named>::typehash,
                deserialize_erased: |deserializer| {
                    let value: $type = erased_serde::deserialize(deserializer)?;
                    Ok(Box::new(value) as Box<dyn $crate::attrs::SerializableValue>)
                },
            }
        }
    };
}
620
621pub use declare_attrs;
622
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    // Keys without defaults, shared by most tests below.
    declare_attrs! {
        attr TEST_TIMEOUT: Duration;
        attr TEST_COUNT: u32;
        attr TEST_NAME: String;
    }

    /// Values that were set can be read back and are counted.
    #[test]
    fn test_basic_operations() {
        let mut attrs = Attrs::new();

        attrs.set(TEST_TIMEOUT, Duration::from_secs(5));
        attrs.set(TEST_COUNT, 42u32);
        attrs.set(TEST_NAME, "test".to_string());

        assert_eq!(attrs.get(TEST_TIMEOUT), Some(&Duration::from_secs(5)));
        assert_eq!(attrs.get(TEST_COUNT), Some(&42u32));
        assert_eq!(attrs.get(TEST_NAME), Some(&"test".to_string()));

        assert!(attrs.contains_key(TEST_TIMEOUT));
        assert!(attrs.contains_key(TEST_COUNT));
        assert!(attrs.contains_key(TEST_NAME));

        assert_eq!(attrs.len(), 3);
        assert!(!attrs.is_empty());
    }

    /// Mutating through `get_mut` is visible through `get`.
    #[test]
    fn test_get_mut() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_COUNT, 10u32);

        if let Some(count) = attrs.get_mut(TEST_COUNT) {
            *count += 5;
        }

        assert_eq!(attrs.get(TEST_COUNT), Some(&15u32));
    }

    /// `remove` reports whether a value was present and actually drops it.
    #[test]
    fn test_remove() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_COUNT, 42u32);

        let removed = attrs.remove(TEST_COUNT);
        assert!(removed);
        assert_eq!(attrs.get(TEST_COUNT), None);
        assert!(!attrs.contains_key(TEST_COUNT));

        // Removing a key that is no longer present reports `false`.
        assert!(!attrs.remove(TEST_COUNT));
    }

    /// `clear` empties the dictionary.
    #[test]
    fn test_clear() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_TIMEOUT, Duration::from_secs(1));
        attrs.set(TEST_COUNT, 42u32);

        attrs.clear();
        assert!(attrs.is_empty());
        assert_eq!(attrs.len(), 0);
    }

    /// Key names are the lowercased module path plus the declared identifier.
    #[test]
    fn test_key_properties() {
        assert_eq!(
            TEST_TIMEOUT.name(),
            "hyperactor::attrs::tests::test_timeout"
        );
    }

    /// The serialized form mentions every key name.
    #[test]
    fn test_serialization() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_TIMEOUT, Duration::from_secs(5));
        attrs.set(TEST_COUNT, 42u32);
        attrs.set(TEST_NAME, "test".to_string());

        let serialized = serde_json::to_string(&attrs).expect("Failed to serialize");

        assert!(serialized.contains("hyperactor::attrs::tests::test_timeout"));
        assert!(serialized.contains("hyperactor::attrs::tests::test_count"));
        assert!(serialized.contains("hyperactor::attrs::tests::test_name"));
    }

    /// Values survive a JSON serialize/deserialize cycle.
    #[test]
    fn test_deserialization() {
        let mut original_attrs = Attrs::new();
        original_attrs.set(TEST_TIMEOUT, Duration::from_secs(5));
        original_attrs.set(TEST_COUNT, 42u32);
        original_attrs.set(TEST_NAME, "test".to_string());

        let serialized = serde_json::to_string(&original_attrs).expect("Failed to serialize");

        let deserialized_attrs: Attrs =
            serde_json::from_str(&serialized).expect("Failed to deserialize");

        assert_eq!(
            deserialized_attrs.get(TEST_TIMEOUT),
            Some(&Duration::from_secs(5))
        );
        assert_eq!(deserialized_attrs.get(TEST_COUNT), Some(&42u32));
        assert_eq!(deserialized_attrs.get(TEST_NAME), Some(&"test".to_string()));
    }

    /// Same round trip as above with a different set of values.
    #[test]
    fn test_roundtrip_serialization() {
        let mut original = Attrs::new();
        original.set(TEST_TIMEOUT, Duration::from_secs(10));
        original.set(TEST_COUNT, 5u32);
        original.set(TEST_NAME, "test-service".to_string());

        let serialized = serde_json::to_string(&original).unwrap();

        let deserialized: Attrs = serde_json::from_str(&serialized).unwrap();

        assert_eq!(
            deserialized.get(TEST_TIMEOUT),
            Some(&Duration::from_secs(10))
        );
        assert_eq!(deserialized.get(TEST_COUNT), Some(&5u32));
        assert_eq!(
            deserialized.get(TEST_NAME),
            Some(&"test-service".to_string())
        );
    }

    /// An empty dictionary serializes to an empty JSON object and back.
    #[test]
    fn test_empty_attrs_serialization() {
        let attrs = Attrs::new();
        let serialized = serde_json::to_string(&attrs).unwrap();

        assert_eq!(serialized, "{}");

        let deserialized: Attrs = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.is_empty());
    }

    /// Values are serialized natively by each format rather than as embedded strings.
    #[test]
    fn test_format_independence() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_COUNT, 42u32);
        attrs.set(TEST_NAME, "test".to_string());

        let json_output = serde_json::to_string(&attrs).unwrap();
        let yaml_output = serde_yaml::to_string(&attrs).unwrap();

        assert!(json_output.contains(":"));
        assert!(json_output.contains("\""));

        // Numbers must be emitted as numbers, not as quoted strings.
        assert!(json_output.contains("42"));
        assert!(!json_output.contains("\"42\""));

        assert!(yaml_output.contains(":"));
        assert!(yaml_output.contains("42"));

        assert!(!yaml_output.contains("\"42\""));

        assert_ne!(json_output, yaml_output);

        let from_json: Attrs = serde_json::from_str(&json_output).unwrap();
        let from_yaml: Attrs = serde_yaml::from_str(&yaml_output).unwrap();

        assert_eq!(from_json.get(TEST_COUNT), Some(&42u32));
        assert_eq!(from_yaml.get(TEST_COUNT), Some(&42u32));
        assert_eq!(from_json.get(TEST_NAME), Some(&"test".to_string()));
        assert_eq!(from_yaml.get(TEST_NAME), Some(&"test".to_string()));
    }

    /// Clones are deep: changing one dictionary does not affect the other.
    #[test]
    fn test_clone() {
        let mut original = Attrs::new();
        original.set(TEST_COUNT, 42u32);
        original.set(TEST_NAME, "test".to_string());
        original.set(TEST_TIMEOUT, Duration::from_secs(10));

        let cloned = original.clone();

        assert_eq!(cloned.get(TEST_COUNT), Some(&42u32));
        assert_eq!(cloned.get(TEST_NAME), Some(&"test".to_string()));
        assert_eq!(cloned.get(TEST_TIMEOUT), Some(&Duration::from_secs(10)));

        // Changing the original leaves the clone untouched.
        original.set(TEST_COUNT, 100u32);
        assert_eq!(original.get(TEST_COUNT), Some(&100u32));
        assert_eq!(cloned.get(TEST_COUNT), Some(&42u32));

        // Changing a clone leaves the original untouched.
        let mut cloned_mut = cloned.clone();
        cloned_mut.set(TEST_NAME, "modified".to_string());
        assert_eq!(cloned_mut.get(TEST_NAME), Some(&"modified".to_string()));
        assert_eq!(original.get(TEST_NAME), Some(&"test".to_string()));
    }

    /// `Debug` output names the type, the keys, and the JSON-rendered values.
    #[test]
    fn test_debug_with_json() {
        let mut attrs = Attrs::new();
        attrs.set(TEST_COUNT, 42u32);
        attrs.set(TEST_NAME, "test".to_string());

        let debug_output = format!("{:?}", attrs);

        assert!(debug_output.contains("Attrs"));

        assert!(debug_output.contains("42"));

        assert!(debug_output.contains("hyperactor::attrs::tests::test_count"));
        assert!(debug_output.contains("hyperactor::attrs::tests::test_name"));

        // The string value is rendered as the JSON text `"test"`, which `Debug`
        // prints with escaped quotes. Checking for a bare `test` would pass
        // vacuously, because the key names already contain it.
        assert!(debug_output.contains(r#"\"test\""#));
    }

    // One key with a default and one with an explicit visibility.
    declare_attrs! {
        attr TIMEOUT_WITH_DEFAULT: Duration = Duration::from_secs(10);

        pub(crate) attr CRATE_LOCAL_ATTR: String;
    }

    /// A key's default is reported when nothing has been set.
    #[test]
    fn test_defaults() {
        assert!(TIMEOUT_WITH_DEFAULT.has_default());
        assert!(!CRATE_LOCAL_ATTR.has_default());

        assert_eq!(
            Attrs::new().get(TIMEOUT_WITH_DEFAULT),
            Some(&Duration::from_secs(10))
        );
    }

    /// Indexing reads the default, and index-assignment overrides it.
    #[test]
    fn test_indexing() {
        let mut attrs = Attrs::new();

        assert_eq!(attrs[TIMEOUT_WITH_DEFAULT], Duration::from_secs(10));
        attrs[TIMEOUT_WITH_DEFAULT] = Duration::from_secs(100);
        assert_eq!(attrs[TIMEOUT_WITH_DEFAULT], Duration::from_secs(100));

        attrs.set(CRATE_LOCAL_ATTR, "test".to_string());
        assert_eq!(attrs[CRATE_LOCAL_ATTR], "test".to_string());
    }
}