hyperactor_config/
flattrs.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! Flat attribute storage for message headers.
//!
//! This module provides `Flattrs`, a type optimized for message passing scenarios
//! where headers are often forwarded without inspection. It uses a single contiguous
//! buffer with inline entry lengths, so headers can be passed through without
//! decoding any of their entries.
14//!
15//! # Wire Format
16//!
//! ```text
//! ┌─────────────┬──────────────────────────────────────────────────┐
//! │ num_entries │ entries...                                       │
//! │ (u16)       │ (key_hash: u64, len: u32, value: [u8])...        │
//! └─────────────┴──────────────────────────────────────────────────┘
//! ```
23//!
//! Each entry is self-describing with its length inline, allowing a linear scan
//! without a separate index section. All integers are little-endian, and `len` is
//! the length of `value` in bytes.
//!
//! - Key IDs are FNV-1a hashes of key names (stable, computed at compile time)
//! - Uses linear search (optimal for typical small header counts of 2-5 entries)
29//!
30//! # Design Benefits
31//!
//! - **Passthrough without parsing**: Forward the entire buffer without decoding its entries
//! - **Multipart-friendly serialization**: Serializes as a `Part`, so the buffer passes
//!   through the multipart codec as opaque bytes (building or reading the `Part`
//!   copies the buffer once)
//! - **Simple implementation**: No mode switching, just a single buffer
//! - **Compact wire format**: u64 key IDs instead of string names
36//!
37//! # Example
38//!
39//! ```ignore
40//! use hyperactor_config::flattrs::Flattrs;
41//! use hyperactor_config::attrs::declare_attrs;
42//!
43//! declare_attrs! {
44//!     pub attr TIMESTAMP: u64;
45//!     pub attr REQUEST_ID: String;
46//! }
47//!
48//! let mut headers = Flattrs::new();
49//! headers.set(TIMESTAMP, 1234567890u64);
50//! headers.set(REQUEST_ID, "req-123".to_string());
51//!
52//! // Lazy deserialization on access
53//! let ts: Option<u64> = headers.get(TIMESTAMP);
54//! ```
55
56use bytes::Bytes;
57use bytes::BytesMut;
58use serde::Deserialize;
59use serde::Serialize;
60use serde::de::DeserializeOwned;
61use serde_multipart::Part;
62
63use crate::attrs::AttrValue;
64use crate::attrs::Attrs;
65use crate::attrs::Key;
66
/// Size in bytes of the buffer header: the entry count, a little-endian `u16`.
const HEADER_SIZE: usize = std::mem::size_of::<u16>();

/// Size in bytes of every entry header: a little-endian `u64` key hash followed
/// by a little-endian `u32` value length (12 bytes in total).
const ENTRY_HEADER_SIZE: usize = std::mem::size_of::<u64>() + std::mem::size_of::<u32>();
72
73/// Flat attribute storage for message headers.
74///
75/// Uses a single contiguous buffer with inline entry lengths.
76/// Each entry is `[key_hash: u64][len: u32][value: bytes]`.
77/// Linear scan is used for lookup, which is optimal for small N.
78#[derive(Clone, Default)]
79pub struct Flattrs {
80    /// The buffer containing all entries.
81    /// Format: [num_entries: u16][entries...]
82    /// Each entry: [key_hash: u64][len: u32][value: bytes]
83    buffer: BytesMut,
84}
85
86impl Flattrs {
87    /// Create a new empty Flattrs.
88    pub fn new() -> Self {
89        let mut buffer = BytesMut::with_capacity(HEADER_SIZE);
90        buffer.extend_from_slice(&0u16.to_le_bytes());
91        Self { buffer }
92    }
93
94    /// Create from a `Part`.
95    pub fn from_part(part: Part) -> Self {
96        Self {
97            buffer: BytesMut::from(part.into_bytes().as_ref()),
98        }
99    }
100
101    /// Convert to wire format for transmission.
102    ///
103    /// Returns a [`Part`] for zero-copy serialization through the multipart codec.
104    pub fn to_part(&self) -> Part {
105        Part::from(Bytes::copy_from_slice(&self.buffer))
106    }
107
108    /// Serialize a value and store it.
109    ///
110    /// If the key already exists:
111    /// - Same size value: overwrite in place (no shifting)
112    /// - Different size: remove old entry and append new one
113    pub fn set<T: Serialize>(&mut self, key: Key<T>, value: T) {
114        let key_hash = key.key_hash();
115        let serialized = bincode::serialize(&value).expect("serialization failed");
116
117        // If key exists, either overwrite in place or compact + append
118        if let Some((offset, old_len)) = self.find_entry_location(key_hash) {
119            if serialized.len() == old_len {
120                // Same size - overwrite value in place
121                let value_start = offset + ENTRY_HEADER_SIZE;
122                self.buffer[value_start..value_start + old_len].copy_from_slice(&serialized);
123                return;
124            }
125
126            // Different size - remove old entry by shifting
127            let entry_size = ENTRY_HEADER_SIZE + old_len;
128            let end = offset + entry_size;
129
130            if end < self.buffer.len() {
131                self.buffer.copy_within(end.., offset);
132            }
133            self.buffer.truncate(self.buffer.len() - entry_size);
134
135            // Decrement entry count since `self.append_entry` will increment it
136            let count = self.len();
137            self.buffer[0..2].copy_from_slice(&((count - 1) as u16).to_le_bytes());
138        }
139
140        self.append_entry(key_hash, &serialized);
141    }
142
143    /// Get a value, deserializing from the buffer.
144    ///
145    /// Uses linear search which is optimal for the typical small
146    /// number of headers (2-5 entries).
147    pub fn get<T: AttrValue + DeserializeOwned>(&self, key: Key<T>) -> Option<T> {
148        let key_hash = key.key_hash();
149        let value_bytes = self.find_value(key_hash)?;
150        bincode::deserialize(value_bytes).ok()
151    }
152
153    /// Check if a key exists.
154    #[inline]
155    pub fn contains_key<T>(&self, key: Key<T>) -> bool {
156        self.find_value(key.key_hash()).is_some()
157    }
158
159    /// Returns true if empty.
160    #[inline]
161    pub fn is_empty(&self) -> bool {
162        self.len() == 0
163    }
164
165    /// Returns the number of entries.
166    #[inline]
167    pub fn len(&self) -> usize {
168        if self.buffer.len() < HEADER_SIZE {
169            return 0;
170        }
171        u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize
172    }
173
174    /// Convert from an existing Attrs by serializing all values.
175    pub fn from_attrs(attrs: &Attrs) -> Self {
176        let mut flattrs = Self::new();
177        for (name, value) in attrs.iter() {
178            let key_hash = crate::attrs::fnv1a_hash(name.as_bytes());
179            let serialized = value.serialize_bincode();
180            flattrs.append_entry(key_hash, &serialized);
181        }
182        flattrs
183    }
184
185    /// Find the value bytes for a given key_hash by scanning entries.
186    fn find_value(&self, key_hash: u64) -> Option<&[u8]> {
187        if self.buffer.len() < HEADER_SIZE {
188            return None;
189        }
190
191        let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
192        let mut offset = HEADER_SIZE;
193
194        for _ in 0..num_entries {
195            if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
196                return None;
197            }
198
199            let entry_key_hash =
200                u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
201            let entry_len = u32::from_le_bytes(
202                self.buffer[offset + 8..offset + 12]
203                    .try_into()
204                    .unwrap_or([0; 4]),
205            ) as usize;
206
207            let value_start = offset + ENTRY_HEADER_SIZE;
208            let value_end = value_start + entry_len;
209
210            if value_end > self.buffer.len() {
211                return None;
212            }
213
214            if entry_key_hash == key_hash {
215                return Some(&self.buffer[value_start..value_end]);
216            }
217
218            offset = value_end;
219        }
220
221        None
222    }
223
224    /// Find the location (offset, value_len) of an entry by key_hash.
225    fn find_entry_location(&self, key_hash: u64) -> Option<(usize, usize)> {
226        if self.buffer.len() < HEADER_SIZE {
227            return None;
228        }
229
230        let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
231        let mut offset = HEADER_SIZE;
232
233        for _ in 0..num_entries {
234            if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
235                return None;
236            }
237
238            let entry_key_hash =
239                u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
240            let entry_len = u32::from_le_bytes(
241                self.buffer[offset + 8..offset + 12]
242                    .try_into()
243                    .unwrap_or([0; 4]),
244            ) as usize;
245
246            if entry_key_hash == key_hash {
247                return Some((offset, entry_len));
248            }
249
250            offset += ENTRY_HEADER_SIZE + entry_len;
251        }
252
253        None
254    }
255
256    /// Append a new entry to the buffer.
257    fn append_entry(&mut self, key_hash: u64, value: &[u8]) {
258        let len = self.len();
259        self.buffer[0..2].copy_from_slice(&((len + 1) as u16).to_le_bytes());
260
261        // Append entry: key_hash + len + value
262        self.buffer.extend_from_slice(&key_hash.to_le_bytes());
263        self.buffer
264            .extend_from_slice(&(value.len() as u32).to_le_bytes());
265        self.buffer.extend_from_slice(value);
266    }
267}
268
269impl From<Attrs> for Flattrs {
270    fn from(attrs: Attrs) -> Self {
271        Self::from_attrs(&attrs)
272    }
273}
274
275impl From<&Attrs> for Flattrs {
276    fn from(attrs: &Attrs) -> Self {
277        Self::from_attrs(attrs)
278    }
279}
280
281impl std::fmt::Debug for Flattrs {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("Flattrs").field("len", &self.len()).finish()
284    }
285}
286
287impl std::fmt::Display for Flattrs {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        use crate::attrs::lookup_key_info;
290
291        let mut offset = HEADER_SIZE;
292        let mut first = true;
293
294        for _ in 0..self.len() {
295            let key_hash = u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap());
296            let entry_len =
297                u32::from_le_bytes(self.buffer[offset + 8..offset + 12].try_into().unwrap())
298                    as usize;
299            let value_bytes = &self.buffer[offset + ENTRY_HEADER_SIZE..][..entry_len];
300
301            if !first {
302                write!(f, ",")?;
303            }
304            first = false;
305
306            let info =
307                lookup_key_info(key_hash).expect("key should be registered via declare_attrs!");
308
309            let value = (info.deserialize_bincode)(value_bytes).expect("value should deserialize");
310            write!(f, "{}={}", info.name, (info.display)(value.as_ref()))?;
311
312            offset += ENTRY_HEADER_SIZE + entry_len;
313        }
314
315        Ok(())
316    }
317}
318
319impl Serialize for Flattrs {
320    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
321    where
322        S: serde::Serializer,
323    {
324        self.to_part().serialize(serializer)
325    }
326}
327
328impl<'de> Deserialize<'de> for Flattrs {
329    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330    where
331        D: serde::Deserializer<'de>,
332    {
333        let part: Part = Deserialize::deserialize(deserializer)?;
334        Ok(Self::from_part(part))
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::declare_attrs;

    declare_attrs! {
        attr TEST_U64: u64;
        attr TEST_STRING: String;
        attr TEST_BOOL: bool;
    }

    #[test]
    fn test_basic_set_get() {
        let mut attrs = Flattrs::new();

        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(42u64));
        assert_eq!(attrs.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
    }

    #[test]
    fn test_missing_key() {
        let attrs = Flattrs::new();
        assert_eq!(attrs.get::<u64>(TEST_U64), None);
    }

    #[test]
    fn test_set_replaces_existing() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_U64, 100u64);
        assert_eq!(attrs.get(TEST_U64), Some(100u64));
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn test_set_replaces_different_size() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_STRING, "a much longer string".to_string());
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a much longer string".to_string())
        );
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn test_set_resize_shifts_later_entries() {
        // The resized entry is first, so `set` must shift the later entries
        // down (the `copy_within` path) before re-appending it.
        let mut attrs = Flattrs::new();
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_U64, 7u64);
        attrs.set(TEST_BOOL, false);
        attrs.set(TEST_STRING, "a much longer string".to_string());

        assert_eq!(attrs.len(), 3);
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a much longer string".to_string())
        );
        assert_eq!(attrs.get(TEST_U64), Some(7u64));
        assert_eq!(attrs.get(TEST_BOOL), Some(false));
    }

    #[test]
    fn test_set_resize_moves_entry_to_end() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_U64, 7u64);
        attrs.set(TEST_STRING, "a much longer string".to_string());
        assert_eq!(
            format!("{}", attrs),
            "hyperactor_config::flattrs::tests::test_u64=7,hyperactor_config::flattrs::tests::test_string=a much longer string"
        );
    }

    #[test]
    fn test_set_same_size_keeps_position() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set(TEST_STRING, "hello".to_string());
        attrs.set(TEST_U64, 2u64);
        assert_eq!(attrs.len(), 2);
        assert_eq!(
            format!("{}", attrs),
            "hyperactor_config::flattrs::tests::test_u64=2,hyperactor_config::flattrs::tests::test_string=hello"
        );
    }

    #[test]
    fn test_contains_key() {
        let mut attrs = Flattrs::new();

        assert!(!attrs.contains_key(TEST_U64));
        attrs.set(TEST_U64, 42u64);
        assert!(attrs.contains_key(TEST_U64));
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let serialized = bincode::serialize(&attrs).expect("serialize");
        let deserialized: Flattrs = bincode::deserialize(&serialized).expect("deserialize");

        assert_eq!(deserialized.get(TEST_U64), Some(42u64));
        assert_eq!(deserialized.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(deserialized.len(), 2);
    }

    #[test]
    fn test_wire_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let wire = attrs.to_part();
        let received = Flattrs::from_part(wire);

        assert_eq!(received.get(TEST_U64), Some(42u64));
        assert_eq!(received.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(received.len(), 2);
    }

    #[test]
    fn test_empty_part_is_empty() {
        let attrs = Flattrs::from_part(Part::from(Bytes::new()));
        assert!(attrs.is_empty());
        assert_eq!(attrs.get(TEST_U64), None);
        assert_eq!(format!("{}", attrs), "");
    }

    #[test]
    fn test_truncated_buffer_lookup_is_none() {
        // The header claims one entry, but the entry header is cut short.
        let truncated = Flattrs::from_part(Part::from(Bytes::from_static(&[1, 0, 0, 0])));
        assert_eq!(truncated.len(), 1);
        assert_eq!(truncated.get(TEST_U64), None);
        assert!(!truncated.contains_key(TEST_U64));
    }

    #[test]
    fn test_from_attrs() {
        let mut source = Attrs::new();
        source.set(TEST_U64, 42u64);
        let flattrs = Flattrs::from(&source);
        assert_eq!(flattrs.len(), 1);
        assert_eq!(flattrs.get(TEST_U64), Some(42u64));
    }

    #[test]
    fn test_multiple_keys() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set(TEST_STRING, "two".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(1u64));
        assert_eq!(attrs.get(TEST_STRING), Some("two".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
        assert_eq!(attrs.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let attrs = Flattrs::new();
        assert!(attrs.is_empty());

        let mut attrs2 = Flattrs::new();
        attrs2.set(TEST_U64, 42u64);
        assert!(!attrs2.is_empty());
    }

    #[test]
    fn test_display() {
        use crate::attrs::Attrs;

        // Empty displays as empty string
        let empty_flattrs = Flattrs::new();
        let empty_attrs = Attrs::new();
        assert_eq!(format!("{}", empty_flattrs), format!("{}", empty_attrs));
        assert_eq!(format!("{}", empty_flattrs), "");

        // Single entry - Flattrs and Attrs should display the same
        let mut single_flattrs = Flattrs::new();
        single_flattrs.set(TEST_U64, 42u64);
        let mut single_attrs = Attrs::new();
        single_attrs.set(TEST_U64, 42u64);
        assert_eq!(format!("{}", single_flattrs), format!("{}", single_attrs));
        assert_eq!(
            format!("{}", single_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=42"
        );

        // Multiple entries - Flattrs maintains insertion order, Attrs uses HashMap order
        // So we only compare to the expected string for Flattrs
        let mut multi_flattrs = Flattrs::new();
        multi_flattrs.set(TEST_U64, 1u64);
        multi_flattrs.set(TEST_STRING, "hello".to_string());
        assert_eq!(
            format!("{}", multi_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=1,hyperactor_config::flattrs::tests::test_string=hello"
        );
    }
}