1use bytes::Bytes;
57use bytes::BytesMut;
58use serde::Deserialize;
59use serde::Serialize;
60use serde::de::DeserializeOwned;
61use serde_multipart::Part;
62
63use crate::attrs::AttrValue;
64use crate::attrs::Attrs;
65use crate::attrs::Key;
66
/// Size in bytes of the buffer header: a little-endian `u16` entry count.
const HEADER_SIZE: usize = std::mem::size_of::<u16>();

/// Size in bytes of each entry header: a little-endian `u64` key hash followed by a
/// little-endian `u32` value length.
const ENTRY_HEADER_SIZE: usize = std::mem::size_of::<u64>() + std::mem::size_of::<u32>();
72
73#[derive(Clone, Default)]
79pub struct Flattrs {
80 buffer: BytesMut,
84}
85
86impl Flattrs {
87 pub fn new() -> Self {
89 let mut buffer = BytesMut::with_capacity(HEADER_SIZE);
90 buffer.extend_from_slice(&0u16.to_le_bytes());
91 Self { buffer }
92 }
93
94 pub fn from_part(part: Part) -> Self {
96 Self {
97 buffer: BytesMut::from(part.into_bytes().as_ref()),
98 }
99 }
100
101 pub fn to_part(&self) -> Part {
105 Part::from(Bytes::copy_from_slice(&self.buffer))
106 }
107
108 pub fn set<T: Serialize>(&mut self, key: Key<T>, value: T) {
114 let key_hash = key.key_hash();
115 let serialized = bincode::serialize(&value).expect("serialization failed");
116
117 if let Some((offset, old_len)) = self.find_entry_location(key_hash) {
119 if serialized.len() == old_len {
120 let value_start = offset + ENTRY_HEADER_SIZE;
122 self.buffer[value_start..value_start + old_len].copy_from_slice(&serialized);
123 return;
124 }
125
126 let entry_size = ENTRY_HEADER_SIZE + old_len;
128 let end = offset + entry_size;
129
130 if end < self.buffer.len() {
131 self.buffer.copy_within(end.., offset);
132 }
133 self.buffer.truncate(self.buffer.len() - entry_size);
134
135 let count = self.len();
137 self.buffer[0..2].copy_from_slice(&((count - 1) as u16).to_le_bytes());
138 }
139
140 self.append_entry(key_hash, &serialized);
141 }
142
143 pub fn get<T: AttrValue + DeserializeOwned>(&self, key: Key<T>) -> Option<T> {
148 let key_hash = key.key_hash();
149 let value_bytes = self.find_value(key_hash)?;
150 bincode::deserialize(value_bytes).ok()
151 }
152
153 #[inline]
155 pub fn contains_key<T>(&self, key: Key<T>) -> bool {
156 self.find_value(key.key_hash()).is_some()
157 }
158
159 #[inline]
161 pub fn is_empty(&self) -> bool {
162 self.len() == 0
163 }
164
165 #[inline]
167 pub fn len(&self) -> usize {
168 if self.buffer.len() < HEADER_SIZE {
169 return 0;
170 }
171 u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize
172 }
173
174 pub fn from_attrs(attrs: &Attrs) -> Self {
176 let mut flattrs = Self::new();
177 for (name, value) in attrs.iter() {
178 let key_hash = crate::attrs::fnv1a_hash(name.as_bytes());
179 let serialized = value.serialize_bincode();
180 flattrs.append_entry(key_hash, &serialized);
181 }
182 flattrs
183 }
184
185 fn find_value(&self, key_hash: u64) -> Option<&[u8]> {
187 if self.buffer.len() < HEADER_SIZE {
188 return None;
189 }
190
191 let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
192 let mut offset = HEADER_SIZE;
193
194 for _ in 0..num_entries {
195 if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
196 return None;
197 }
198
199 let entry_key_hash =
200 u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
201 let entry_len = u32::from_le_bytes(
202 self.buffer[offset + 8..offset + 12]
203 .try_into()
204 .unwrap_or([0; 4]),
205 ) as usize;
206
207 let value_start = offset + ENTRY_HEADER_SIZE;
208 let value_end = value_start + entry_len;
209
210 if value_end > self.buffer.len() {
211 return None;
212 }
213
214 if entry_key_hash == key_hash {
215 return Some(&self.buffer[value_start..value_end]);
216 }
217
218 offset = value_end;
219 }
220
221 None
222 }
223
224 fn find_entry_location(&self, key_hash: u64) -> Option<(usize, usize)> {
226 if self.buffer.len() < HEADER_SIZE {
227 return None;
228 }
229
230 let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
231 let mut offset = HEADER_SIZE;
232
233 for _ in 0..num_entries {
234 if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
235 return None;
236 }
237
238 let entry_key_hash =
239 u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
240 let entry_len = u32::from_le_bytes(
241 self.buffer[offset + 8..offset + 12]
242 .try_into()
243 .unwrap_or([0; 4]),
244 ) as usize;
245
246 if entry_key_hash == key_hash {
247 return Some((offset, entry_len));
248 }
249
250 offset += ENTRY_HEADER_SIZE + entry_len;
251 }
252
253 None
254 }
255
256 fn append_entry(&mut self, key_hash: u64, value: &[u8]) {
258 let len = self.len();
259 self.buffer[0..2].copy_from_slice(&((len + 1) as u16).to_le_bytes());
260
261 self.buffer.extend_from_slice(&key_hash.to_le_bytes());
263 self.buffer
264 .extend_from_slice(&(value.len() as u32).to_le_bytes());
265 self.buffer.extend_from_slice(value);
266 }
267}
268
269impl From<Attrs> for Flattrs {
270 fn from(attrs: Attrs) -> Self {
271 Self::from_attrs(&attrs)
272 }
273}
274
275impl From<&Attrs> for Flattrs {
276 fn from(attrs: &Attrs) -> Self {
277 Self::from_attrs(attrs)
278 }
279}
280
281impl std::fmt::Debug for Flattrs {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.debug_struct("Flattrs").field("len", &self.len()).finish()
284 }
285}
286
287impl std::fmt::Display for Flattrs {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 use crate::attrs::lookup_key_info;
290
291 let mut offset = HEADER_SIZE;
292 let mut first = true;
293
294 for _ in 0..self.len() {
295 let key_hash = u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap());
296 let entry_len =
297 u32::from_le_bytes(self.buffer[offset + 8..offset + 12].try_into().unwrap())
298 as usize;
299 let value_bytes = &self.buffer[offset + ENTRY_HEADER_SIZE..][..entry_len];
300
301 if !first {
302 write!(f, ",")?;
303 }
304 first = false;
305
306 let info =
307 lookup_key_info(key_hash).expect("key should be registered via declare_attrs!");
308
309 let value = (info.deserialize_bincode)(value_bytes).expect("value should deserialize");
310 write!(f, "{}={}", info.name, (info.display)(value.as_ref()))?;
311
312 offset += ENTRY_HEADER_SIZE + entry_len;
313 }
314
315 Ok(())
316 }
317}
318
319impl Serialize for Flattrs {
320 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
321 where
322 S: serde::Serializer,
323 {
324 self.to_part().serialize(serializer)
325 }
326}
327
328impl<'de> Deserialize<'de> for Flattrs {
329 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330 where
331 D: serde::Deserializer<'de>,
332 {
333 let part: Part = Deserialize::deserialize(deserializer)?;
334 Ok(Self::from_part(part))
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::declare_attrs;

    declare_attrs! {
        attr TEST_U64: u64;
        attr TEST_STRING: String;
        attr TEST_BOOL: bool;
    }

    #[test]
    fn test_basic_set_get() {
        let mut attrs = Flattrs::new();

        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(42u64));
        assert_eq!(attrs.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
    }

    #[test]
    fn test_missing_key() {
        let attrs = Flattrs::new();
        assert_eq!(attrs.get::<u64>(TEST_U64), None);
    }

    #[test]
    fn test_set_replaces_existing() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_U64, 100u64);
        assert_eq!(attrs.get(TEST_U64), Some(100u64));
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn test_set_replaces_different_size() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_STRING, "a much longer string".to_string());
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a much longer string".to_string())
        );
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn test_replace_middle_entry_with_different_size() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 7u64);
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_BOOL, false);

        // Growing a non-last entry removes it, compacts the entries after it, and
        // re-appends it at the end.
        attrs.set(TEST_STRING, "a considerably longer string".to_string());
        assert_eq!(attrs.len(), 3);
        assert_eq!(attrs.get(TEST_U64), Some(7u64));
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a considerably longer string".to_string())
        );
        assert_eq!(attrs.get(TEST_BOOL), Some(false));

        // The same key is now last, so shrinking it exercises the tail-removal path.
        attrs.set(TEST_STRING, "x".to_string());
        assert_eq!(attrs.len(), 3);
        assert_eq!(attrs.get(TEST_U64), Some(7u64));
        assert_eq!(attrs.get(TEST_STRING), Some("x".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(false));
    }

    #[test]
    fn test_default_is_empty() {
        let attrs = Flattrs::default();
        assert!(attrs.is_empty());
        assert_eq!(attrs.len(), 0);
        assert_eq!(attrs.get(TEST_U64), None);
        assert!(!attrs.contains_key(TEST_STRING));
        assert_eq!(format!("{}", attrs), "");
    }

    #[test]
    fn test_contains_key() {
        let mut attrs = Flattrs::new();

        assert!(!attrs.contains_key(TEST_U64));
        attrs.set(TEST_U64, 42u64);
        assert!(attrs.contains_key(TEST_U64));
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let serialized = bincode::serialize(&attrs).expect("serialize");
        let deserialized: Flattrs = bincode::deserialize(&serialized).expect("deserialize");

        assert_eq!(deserialized.get(TEST_U64), Some(42u64));
        assert_eq!(deserialized.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(deserialized.len(), 2);
    }

    #[test]
    fn test_wire_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let wire = attrs.to_part();
        let received = Flattrs::from_part(wire);

        assert_eq!(received.get(TEST_U64), Some(42u64));
        assert_eq!(received.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(received.len(), 2);
    }

    #[test]
    fn test_multiple_keys() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set(TEST_STRING, "two".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(1u64));
        assert_eq!(attrs.get(TEST_STRING), Some("two".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
        assert_eq!(attrs.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let attrs = Flattrs::new();
        assert!(attrs.is_empty());

        let mut attrs2 = Flattrs::new();
        attrs2.set(TEST_U64, 42u64);
        assert!(!attrs2.is_empty());
    }

    #[test]
    fn test_display() {
        use crate::attrs::Attrs;

        // Empty maps render as the empty string, matching `Attrs`.
        let empty_flattrs = Flattrs::new();
        let empty_attrs = Attrs::new();
        assert_eq!(format!("{}", empty_flattrs), format!("{}", empty_attrs));
        assert_eq!(format!("{}", empty_flattrs), "");

        // A single entry renders as `fully::qualified::name=value`, matching `Attrs`.
        let mut single_flattrs = Flattrs::new();
        single_flattrs.set(TEST_U64, 42u64);
        let mut single_attrs = Attrs::new();
        single_attrs.set(TEST_U64, 42u64);
        assert_eq!(format!("{}", single_flattrs), format!("{}", single_attrs));
        assert_eq!(
            format!("{}", single_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=42"
        );

        // Multiple entries are comma-separated in insertion order.
        let mut multi_flattrs = Flattrs::new();
        multi_flattrs.set(TEST_U64, 1u64);
        multi_flattrs.set(TEST_STRING, "hello".to_string());
        assert_eq!(
            format!("{}", multi_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=1,hyperactor_config::flattrs::tests::test_string=hello"
        );
    }
}
480}