1use bytes::Bytes;
57use bytes::BytesMut;
58use serde::Deserialize;
59use serde::Serialize;
60use serde::de::DeserializeOwned;
61use serde_multipart::Part;
62
63use crate::attrs::AttrValue;
64use crate::attrs::Attrs;
65use crate::attrs::Key;
66
/// Size in bytes of the buffer header: a little-endian `u16` entry count.
const HEADER_SIZE: usize = std::mem::size_of::<u16>();

/// Size in bytes of each entry's header: a little-endian `u64` key hash
/// followed by a little-endian `u32` value length.
const ENTRY_HEADER_SIZE: usize = std::mem::size_of::<u64>() + std::mem::size_of::<u32>();
72
73#[derive(Clone, Default)]
79pub struct Flattrs {
80 buffer: BytesMut,
84}
85
86impl Flattrs {
87 pub fn new() -> Self {
89 let mut buffer = BytesMut::with_capacity(HEADER_SIZE);
90 buffer.extend_from_slice(&0u16.to_le_bytes());
91 Self { buffer }
92 }
93
94 pub fn from_part(part: Part) -> Self {
96 Self {
97 buffer: BytesMut::from(part.into_bytes().as_ref()),
98 }
99 }
100
101 pub fn to_part(&self) -> Part {
105 Part::from(Bytes::copy_from_slice(&self.buffer))
106 }
107
108 pub fn set<T: Serialize>(&mut self, key: Key<T>, value: T) {
114 let key_hash = key.key_hash();
115 let serialized = bincode::serde::encode_to_vec(&value, bincode::config::legacy())
116 .expect("serialization failed");
117 self.set_serialized(key_hash, &serialized);
118 }
119
120 pub fn set_serialized(&mut self, key_hash: u64, serialized: &[u8]) {
135 if let Some((offset, old_len)) = self.find_entry_location(key_hash) {
137 if serialized.len() == old_len {
138 let value_start = offset + ENTRY_HEADER_SIZE;
140 self.buffer[value_start..value_start + old_len].copy_from_slice(serialized);
141 return;
142 }
143
144 let entry_size = ENTRY_HEADER_SIZE + old_len;
146 let end = offset + entry_size;
147
148 if end < self.buffer.len() {
149 self.buffer.copy_within(end.., offset);
150 }
151 self.buffer.truncate(self.buffer.len() - entry_size);
152
153 let count = self.len();
155 self.buffer[0..2].copy_from_slice(&((count - 1) as u16).to_le_bytes());
156 }
157
158 self.append_entry(key_hash, serialized);
159 }
160
161 pub fn get<T: AttrValue + DeserializeOwned>(&self, key: Key<T>) -> Option<T> {
166 let key_hash = key.key_hash();
167 let value_bytes = self.find_value(key_hash)?;
168 bincode::serde::decode_from_slice(value_bytes, bincode::config::legacy())
169 .map(|(v, _)| v)
170 .ok()
171 }
172
173 #[inline]
175 pub fn contains_key<T>(&self, key: Key<T>) -> bool {
176 self.find_value(key.key_hash()).is_some()
177 }
178
179 #[inline]
181 pub fn is_empty(&self) -> bool {
182 self.len() == 0
183 }
184
185 #[inline]
187 pub fn len(&self) -> usize {
188 if self.buffer.len() < HEADER_SIZE {
189 return 0;
190 }
191 u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize
192 }
193
194 pub fn from_attrs(attrs: &Attrs) -> Self {
196 let mut flattrs = Self::new();
197 for (name, value) in attrs.iter() {
198 let key_hash = crate::attrs::fnv1a_hash(name.as_bytes());
199 let serialized = value.serialize_bincode();
200 flattrs.append_entry(key_hash, &serialized);
201 }
202 flattrs
203 }
204
205 pub fn iter(&self) -> FlattrsIter<'_> {
214 FlattrsIter {
215 buffer: &self.buffer,
216 remaining: self.len(),
217 offset: HEADER_SIZE,
218 }
219 }
220
221 fn find_value(&self, key_hash: u64) -> Option<&[u8]> {
223 if self.buffer.len() < HEADER_SIZE {
224 return None;
225 }
226
227 let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
228 let mut offset = HEADER_SIZE;
229
230 for _ in 0..num_entries {
231 if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
232 return None;
233 }
234
235 let entry_key_hash =
236 u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
237 let entry_len = u32::from_le_bytes(
238 self.buffer[offset + 8..offset + 12]
239 .try_into()
240 .unwrap_or([0; 4]),
241 ) as usize;
242
243 let value_start = offset + ENTRY_HEADER_SIZE;
244 let value_end = value_start + entry_len;
245
246 if value_end > self.buffer.len() {
247 return None;
248 }
249
250 if entry_key_hash == key_hash {
251 return Some(&self.buffer[value_start..value_end]);
252 }
253
254 offset = value_end;
255 }
256
257 None
258 }
259
260 fn find_entry_location(&self, key_hash: u64) -> Option<(usize, usize)> {
262 if self.buffer.len() < HEADER_SIZE {
263 return None;
264 }
265
266 let num_entries = u16::from_le_bytes([self.buffer[0], self.buffer[1]]) as usize;
267 let mut offset = HEADER_SIZE;
268
269 for _ in 0..num_entries {
270 if offset + ENTRY_HEADER_SIZE > self.buffer.len() {
271 return None;
272 }
273
274 let entry_key_hash =
275 u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap_or([0; 8]));
276 let entry_len = u32::from_le_bytes(
277 self.buffer[offset + 8..offset + 12]
278 .try_into()
279 .unwrap_or([0; 4]),
280 ) as usize;
281
282 if entry_key_hash == key_hash {
283 return Some((offset, entry_len));
284 }
285
286 offset += ENTRY_HEADER_SIZE + entry_len;
287 }
288
289 None
290 }
291
292 fn append_entry(&mut self, key_hash: u64, value: &[u8]) {
294 let len = self.len();
295 self.buffer[0..2].copy_from_slice(&((len + 1) as u16).to_le_bytes());
296
297 self.buffer.extend_from_slice(&key_hash.to_le_bytes());
299 self.buffer
300 .extend_from_slice(&(value.len() as u32).to_le_bytes());
301 self.buffer.extend_from_slice(value);
302 }
303}
304
/// Iterator over the entries of a [`Flattrs`], created by [`Flattrs::iter`].
///
/// Yields `(key_hash, encoded_value)` pairs in buffer order. It only borrows
/// the buffer, so cloning it is cheap.
#[derive(Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct FlattrsIter<'a> {
    /// The whole encoded buffer, header included.
    buffer: &'a [u8],
    /// Entries still to yield, according to the header count.
    remaining: usize,
    /// Byte offset of the next entry header within `buffer`.
    offset: usize,
}

impl std::fmt::Debug for FlattrsIter<'_> {
    /// Shows the cursor state and buffer length rather than dumping raw bytes,
    /// in keeping with `Flattrs`'s own `Debug`, which shows only its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlattrsIter")
            .field("remaining", &self.remaining)
            .field("offset", &self.offset)
            .field("buffer_len", &self.buffer.len())
            .finish()
    }
}
313
314impl<'a> Iterator for FlattrsIter<'a> {
315 type Item = (u64, &'a [u8]);
316
317 fn next(&mut self) -> Option<Self::Item> {
318 if self.remaining == 0 {
319 return None;
320 }
321 if self.offset + ENTRY_HEADER_SIZE > self.buffer.len() {
322 return None;
323 }
324 let key_hash = u64::from_le_bytes(
325 self.buffer[self.offset..self.offset + 8]
326 .try_into()
327 .unwrap_or([0; 8]),
328 );
329 let entry_len = u32::from_le_bytes(
330 self.buffer[self.offset + 8..self.offset + 12]
331 .try_into()
332 .unwrap_or([0; 4]),
333 ) as usize;
334 let value_start = self.offset + ENTRY_HEADER_SIZE;
335 let value_end = value_start + entry_len;
336 if value_end > self.buffer.len() {
337 return None;
338 }
339 let value = &self.buffer[value_start..value_end];
340 self.offset = value_end;
341 self.remaining -= 1;
342 Some((key_hash, value))
343 }
344}
345
346impl From<Attrs> for Flattrs {
347 fn from(attrs: Attrs) -> Self {
348 Self::from_attrs(&attrs)
349 }
350}
351
352impl From<&Attrs> for Flattrs {
353 fn from(attrs: &Attrs) -> Self {
354 Self::from_attrs(attrs)
355 }
356}
357
358impl std::fmt::Debug for Flattrs {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.debug_struct("Flattrs").field("len", &self.len()).finish()
361 }
362}
363
364impl std::fmt::Display for Flattrs {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 use crate::attrs::lookup_key_info;
367
368 let mut offset = HEADER_SIZE;
369 let mut first = true;
370
371 for _ in 0..self.len() {
372 let key_hash = u64::from_le_bytes(self.buffer[offset..offset + 8].try_into().unwrap());
373 let entry_len =
374 u32::from_le_bytes(self.buffer[offset + 8..offset + 12].try_into().unwrap())
375 as usize;
376 let value_bytes = &self.buffer[offset + ENTRY_HEADER_SIZE..][..entry_len];
377
378 if !first {
379 write!(f, ",")?;
380 }
381 first = false;
382
383 let info =
384 lookup_key_info(key_hash).expect("key should be registered via declare_attrs!");
385
386 let value = (info.deserialize_bincode)(value_bytes).expect("value should deserialize");
387 write!(f, "{}={}", info.name, (info.display)(value.as_ref()))?;
388
389 offset += ENTRY_HEADER_SIZE + entry_len;
390 }
391
392 Ok(())
393 }
394}
395
396impl Serialize for Flattrs {
397 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
398 where
399 S: serde::Serializer,
400 {
401 self.to_part().serialize(serializer)
402 }
403}
404
405impl<'de> Deserialize<'de> for Flattrs {
406 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
407 where
408 D: serde::Deserializer<'de>,
409 {
410 let part: Part = Deserialize::deserialize(deserializer)?;
411 Ok(Self::from_part(part))
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::declare_attrs;

    declare_attrs! {
        attr TEST_U64: u64;
        attr TEST_STRING: String;
        attr TEST_BOOL: bool;
    }

    /// Bincode-encodes `value` with the legacy configuration that `Flattrs` uses.
    fn encode<T: serde::Serialize>(value: T) -> Vec<u8> {
        bincode::serde::encode_to_vec(value, bincode::config::legacy()).expect("encode")
    }

    #[test]
    fn test_basic_set_get() {
        let mut attrs = Flattrs::new();

        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(42u64));
        assert_eq!(attrs.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
    }

    #[test]
    fn test_missing_key() {
        let attrs = Flattrs::new();
        assert_eq!(attrs.get::<u64>(TEST_U64), None);
    }

    #[test]
    fn test_set_replaces_existing() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_U64, 100u64);
        assert_eq!(attrs.get(TEST_U64), Some(100u64));
        assert_eq!(attrs.len(), 1);
    }

    #[test]
    fn test_set_replaces_different_size() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_STRING, "a much longer string".to_string());
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a much longer string".to_string())
        );
        assert_eq!(attrs.len(), 1);
    }

    /// Resizing an entry that is not last must shift its successors down
    /// without corrupting them.
    #[test]
    fn test_set_resizes_middle_entry_without_disturbing_neighbors() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set(TEST_STRING, "short".to_string());
        attrs.set(TEST_BOOL, true);

        attrs.set(TEST_STRING, "a considerably longer string".to_string());

        assert_eq!(attrs.get(TEST_U64), Some(1u64));
        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a considerably longer string".to_string())
        );
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
        assert_eq!(attrs.len(), 3);
    }

    /// `set_serialized` overwrites an existing value of the same encoded size.
    #[test]
    fn test_set_serialized_overwrites_existing_same_size() {
        let mut attrs = Flattrs::new();
        let key_hash = TEST_U64.key_hash();
        let first = encode(42u64);
        attrs.set_serialized(key_hash, &first);
        assert_eq!(attrs.get(TEST_U64), Some(42u64));
        assert_eq!(attrs.len(), 1);

        let second = encode(100u64);
        assert_eq!(first.len(), second.len(), "same-size precondition");
        attrs.set_serialized(key_hash, &second);

        assert_eq!(attrs.get(TEST_U64), Some(100u64));
        assert_eq!(attrs.len(), 1);
    }

    /// `set_serialized` replaces an existing value of a different encoded size.
    #[test]
    fn test_set_serialized_overwrites_existing_different_size() {
        let mut attrs = Flattrs::new();
        let key_hash = TEST_STRING.key_hash();
        let short = encode("short".to_string());
        attrs.set_serialized(key_hash, &short);
        assert_eq!(attrs.get(TEST_STRING), Some("short".to_string()));
        assert_eq!(attrs.len(), 1);

        let long = encode("a much longer string".to_string());
        assert_ne!(short.len(), long.len(), "different-size precondition");
        attrs.set_serialized(key_hash, &long);

        assert_eq!(
            attrs.get(TEST_STRING),
            Some("a much longer string".to_string())
        );
        assert_eq!(attrs.len(), 1);
    }

    /// Typed `set` and raw `set_serialized` address the same entry.
    #[test]
    fn test_set_serialized_interops_with_typed_set() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set_serialized(TEST_U64.key_hash(), &encode(2u64));
        assert_eq!(attrs.get(TEST_U64), Some(2u64));
        assert_eq!(attrs.len(), 1);

        let mut attrs = Flattrs::new();
        attrs.set_serialized(TEST_U64.key_hash(), &encode(1u64));
        attrs.set(TEST_U64, 2u64);
        assert_eq!(attrs.get(TEST_U64), Some(2u64));
        assert_eq!(attrs.len(), 1);
    }

    /// A new hash passed to `set_serialized` appends a new entry.
    #[test]
    fn test_set_serialized_new_key_appends() {
        let mut attrs = Flattrs::new();
        attrs.set_serialized(TEST_U64.key_hash(), &encode(7u64));
        attrs.set_serialized(TEST_STRING.key_hash(), &encode("x".to_string()));
        assert_eq!(attrs.get(TEST_U64), Some(7u64));
        assert_eq!(attrs.get(TEST_STRING), Some("x".to_string()));
        assert_eq!(attrs.len(), 2);
    }

    #[test]
    fn test_contains_key() {
        let mut attrs = Flattrs::new();

        assert!(!attrs.contains_key(TEST_U64));
        attrs.set(TEST_U64, 42u64);
        assert!(attrs.contains_key(TEST_U64));
    }

    /// `iter` yields raw `(key_hash, encoded_value)` pairs in insertion order.
    #[test]
    fn test_iter_yields_entries_in_order() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 5u64);
        attrs.set(TEST_BOOL, false);

        let entries: Vec<(u64, Vec<u8>)> = attrs
            .iter()
            .map(|(key_hash, value)| (key_hash, value.to_vec()))
            .collect();
        assert_eq!(
            entries,
            vec![
                (TEST_U64.key_hash(), encode(5u64)),
                (TEST_BOOL.key_hash(), encode(false)),
            ]
        );
    }

    /// Entries converted from `Attrs` are readable through typed keys.
    #[test]
    fn test_from_attrs_matches_typed_keys() {
        use crate::attrs::Attrs;

        let mut source = Attrs::new();
        source.set(TEST_U64, 7u64);
        source.set(TEST_STRING, "seven".to_string());

        let flattrs = Flattrs::from(&source);
        assert_eq!(flattrs.len(), 2);
        assert_eq!(flattrs.get(TEST_U64), Some(7u64));
        assert_eq!(flattrs.get(TEST_STRING), Some("seven".to_string()));
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let serialized = encode(&attrs);
        let deserialized: Flattrs =
            bincode::serde::decode_from_slice(&serialized, bincode::config::legacy())
                .map(|(v, _)| v)
                .expect("deserialize");

        assert_eq!(deserialized.get(TEST_U64), Some(42u64));
        assert_eq!(deserialized.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(deserialized.len(), 2);
    }

    #[test]
    fn test_wire_roundtrip() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 42u64);
        attrs.set(TEST_STRING, "hello".to_string());

        let wire = attrs.to_part();
        let received = Flattrs::from_part(wire);

        assert_eq!(received.get(TEST_U64), Some(42u64));
        assert_eq!(received.get(TEST_STRING), Some("hello".to_string()));
        assert_eq!(received.len(), 2);
    }

    #[test]
    fn test_multiple_keys() {
        let mut attrs = Flattrs::new();
        attrs.set(TEST_U64, 1u64);
        attrs.set(TEST_STRING, "two".to_string());
        attrs.set(TEST_BOOL, true);

        assert_eq!(attrs.get(TEST_U64), Some(1u64));
        assert_eq!(attrs.get(TEST_STRING), Some("two".to_string()));
        assert_eq!(attrs.get(TEST_BOOL), Some(true));
        assert_eq!(attrs.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let attrs = Flattrs::new();
        assert!(attrs.is_empty());

        let mut attrs2 = Flattrs::new();
        attrs2.set(TEST_U64, 42u64);
        assert!(!attrs2.is_empty());
    }

    /// `Flattrs` and `Attrs` render identically for registered keys.
    #[test]
    fn test_display() {
        use crate::attrs::Attrs;

        let empty_flattrs = Flattrs::new();
        let empty_attrs = Attrs::new();
        assert_eq!(format!("{}", empty_flattrs), format!("{}", empty_attrs));
        assert_eq!(format!("{}", empty_flattrs), "");

        let mut single_flattrs = Flattrs::new();
        single_flattrs.set(TEST_U64, 42u64);
        let mut single_attrs = Attrs::new();
        single_attrs.set(TEST_U64, 42u64);
        assert_eq!(format!("{}", single_flattrs), format!("{}", single_attrs));
        assert_eq!(
            format!("{}", single_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=42"
        );

        let mut multi_flattrs = Flattrs::new();
        multi_flattrs.set(TEST_U64, 1u64);
        multi_flattrs.set(TEST_STRING, "hello".to_string());
        assert_eq!(
            format!("{}", multi_flattrs),
            "hyperactor_config::flattrs::tests::test_u64=1,hyperactor_config::flattrs::tests::test_string=hello"
        );
    }
}