1#![feature(min_specialization)]
30#![feature(assert_matches)]
31
32use std::cell::UnsafeCell;
33use std::cmp::min;
34use std::collections::VecDeque;
35use std::io::IoSlice;
36use std::ptr::NonNull;
37
38use bincode::Options;
39use bytes::Buf;
40use bytes::BufMut;
41use bytes::buf::UninitSlice;
42
43mod de;
44mod part;
45mod ser;
46use bytes::Bytes;
47use bytes::BytesMut;
48pub use part::Part;
49use serde::Deserialize;
50use serde::Serialize;
51
52#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56pub struct Message {
57 body: Part,
58 parts: Vec<Part>,
59}
60
61impl Message {
62 pub fn from_body_and_parts(body: Part, parts: Vec<Part>) -> Self {
64 Self { body, parts }
65 }
66
67 pub fn body(&self) -> &Part {
69 &self.body
70 }
71
72 pub fn parts(&self) -> &[Part] {
74 &self.parts
75 }
76
77 pub fn num_parts(&self) -> usize {
79 self.parts.len()
80 }
81
82 pub fn len(&self) -> usize {
84 self.body.len() + self.parts.iter().map(|part| part.len()).sum::<usize>()
85 }
86
87 pub fn is_empty(&self) -> bool {
90 self.body.is_empty() && self.parts.iter().all(|part| part.is_empty())
91 }
92
93 pub fn into_inner(self) -> (Part, Vec<Part>) {
95 (self.body, self.parts)
96 }
97
98 pub fn frame_len(&self) -> usize {
100 8 + self.body.len()
101 + (8 * self.parts.len())
102 + self.parts.iter().map(|p| p.len()).sum::<usize>()
103 }
104
105 pub fn framed(self) -> Frame {
117 let (body, parts) = self.into_inner();
118
119 let mut buffers = Vec::with_capacity(
120 1 + body.num_fragments()
121 + parts.len()
122 + parts.iter().map(|part| part.num_fragments()).sum::<usize>(),
123 );
124
125 buffers.push(Bytes::from_owner(body.len().to_be_bytes()));
126 for fragment in body.into_inner() {
127 buffers.push(fragment);
128 }
129
130 for part in parts {
131 buffers.push(Bytes::from_owner(part.len().to_be_bytes()));
132 for fragment in part.into_inner() {
133 buffers.push(fragment);
134 }
135 }
136
137 Frame::from_buffers(buffers)
138 }
139
140 pub fn from_framed(mut buf: Bytes) -> Result<Self, std::io::Error> {
142 if buf.len() < 8 {
143 return Err(std::io::ErrorKind::UnexpectedEof.into());
144 }
145 let body_len = buf.get_u64();
146 let body = buf.split_to(body_len as usize);
147 let mut parts = Vec::new();
148 while !buf.is_empty() {
149 parts.push(Self::split_part(&mut buf)?.into());
150 }
151 Ok(Self {
152 body: body.into(),
153 parts,
154 })
155 }
156
157 fn split_part(buf: &mut Bytes) -> Result<Bytes, std::io::Error> {
158 if buf.len() < 8 {
159 return Err(std::io::ErrorKind::UnexpectedEof.into());
160 }
161 let at = buf.get_u64() as usize;
162 if buf.len() < at {
163 return Err(std::io::ErrorKind::UnexpectedEof.into());
164 }
165 Ok(buf.split_to(at))
166 }
167}
168
169#[derive(Clone)]
173pub struct Frame {
174 buffers: VecDeque<Bytes>,
175}
176
177impl Frame {
178 fn from_buffers(buffers: Vec<Bytes>) -> Self {
181 let mut buffers: VecDeque<Bytes> = buffers.into();
182 buffers.retain(|buf| !buf.is_empty());
183 Self { buffers }
184 }
185
186 pub fn illegal_unipart_frame(body: Bytes) -> Self {
188 Self {
189 buffers: vec![body].into(),
190 }
191 }
192}
193
194impl Buf for Frame {
195 fn remaining(&self) -> usize {
196 self.buffers.iter().map(|buf| buf.remaining()).sum()
197 }
198
199 fn chunk(&self) -> &[u8] {
200 match self.buffers.front() {
201 Some(buf) => buf.chunk(),
202 None => &[],
203 }
204 }
205
206 fn advance(&mut self, mut cnt: usize) {
207 while cnt > 0 {
208 let Some(buf) = self.buffers.front_mut() else {
209 panic!("advanced beyond the buffer size");
210 };
211
212 if cnt >= buf.remaining() {
213 cnt -= buf.remaining();
214 self.buffers.pop_front();
215 continue;
216 }
217
218 buf.advance(cnt);
219 cnt = 0;
220 }
221 }
222
223 fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
226 let n = min(dst.len(), self.buffers.len());
227 for i in 0..n {
228 dst[i] = IoSlice::new(self.buffers[i].chunk());
229 }
230 n
231 }
232}
233
/// Owns a `BytesMut` inside an `UnsafeCell`, so that a raw, lifetime-free handle
/// ([`UnsafeBufCellRef`], from `borrow_unchecked`) can write into it while the cell itself is
/// still owned by the caller. The caller later reclaims the written bytes with `into_inner`.
///
/// NOTE(review): this exists for `serialize_bincode`, where the serializer takes ownership of its
/// writer but the written bytes must be recovered afterwards.
struct UnsafeBufCell {
    buf: UnsafeCell<BytesMut>,
}

impl UnsafeBufCell {
    /// Wraps `bytes` in a new cell.
    fn from_bytes_mut(bytes: BytesMut) -> Self {
        Self {
            buf: UnsafeCell::new(bytes),
        }
    }

    /// Consumes the cell and returns the buffer, including everything written through handles
    /// obtained from `borrow_unchecked`.
    fn into_inner(self) -> BytesMut {
        self.buf.into_inner()
    }

    /// Returns a raw handle that writes directly into this cell's buffer.
    ///
    /// # Safety
    ///
    /// The returned [`UnsafeBufCellRef`] carries no lifetime, so the compiler checks nothing
    /// about its use. The caller must ensure that:
    /// - the handle is not used after `self` is moved, consumed (`into_inner`), or dropped,
    ///   because it points at `self`'s current location;
    /// - while the handle is in use, nothing else reads or writes the buffer, including any other
    ///   handle obtained from this method.
    unsafe fn borrow_unchecked(&self) -> UnsafeBufCellRef {
        // SAFETY: `UnsafeCell::get` returns a pointer derived from a reference to `self.buf`,
        // which is never null.
        let ptr =
            unsafe { NonNull::new_unchecked(self.buf.get()) };
        UnsafeBufCellRef { ptr }
    }
}

/// Raw writable handle to the `BytesMut` inside an [`UnsafeBufCell`].
///
/// Only obtainable through `UnsafeBufCell::borrow_unchecked`. Every dereference of `ptr` relies on
/// that method's safety contract.
struct UnsafeBufCellRef {
    ptr: NonNull<BytesMut>,
}

// SAFETY: every method forwards straight to the pointed-to `BytesMut`'s own `BufMut`
// implementation, so this impl inherits exactly its guarantees. This holds as long as `ptr` stays
// valid and unaliased, which is the contract of `UnsafeBufCell::borrow_unchecked`.
unsafe impl BufMut for UnsafeBufCellRef {
    fn remaining_mut(&self) -> usize {
        // SAFETY: `ptr` is valid for reads per the `borrow_unchecked` contract.
        unsafe { self.ptr.as_ref().remaining_mut() }
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: `ptr` is valid and unaliased per the `borrow_unchecked` contract. Our caller
        // upholds `BufMut::advance_mut`'s own requirements on `cnt`.
        unsafe { self.ptr.as_mut().advance_mut(cnt) }
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        // SAFETY: `ptr` is valid and unaliased per the `borrow_unchecked` contract. The returned
        // slice borrows `self`, so it cannot outlive this handle.
        unsafe { self.ptr.as_mut().chunk_mut() }
    }
}
288
289pub fn serialize_bincode<S: ?Sized + serde::Serialize>(
296 value: &S,
297) -> Result<Message, bincode::Error> {
298 let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new());
299 let buffer_borrow = unsafe { buffer.borrow_unchecked() };
302 let mut serializer: part::BincodeSerializer =
303 ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options()));
304 value.serialize(&mut serializer)?;
305 Ok(Message {
306 body: Part(vec![buffer.into_inner().freeze()]),
307 parts: serializer.into_parts(),
308 })
309}
310
311pub fn deserialize_bincode<T>(message: Message) -> Result<T, bincode::Error>
314where
315 T: serde::de::DeserializeOwned,
316{
317 let (body, parts) = message.into_inner();
318 let mut deserializer = part::BincodeDeserializer::new(
319 bincode::Deserializer::with_reader(body.into_bytes().reader(), options()),
320 parts.into(),
321 );
322 let value = T::deserialize(&mut deserializer)?;
323 deserializer.end()?;
325 Ok(value)
326}
327
328fn options() -> part::BincodeOptionsType {
330 bincode::DefaultOptions::new()
331 .with_fixint_encoding()
332 .allow_trailing_bytes()
333}
334
// Unit and property tests for multipart (de)serialization and for the framing round-trip.
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::net::SocketAddrV6;

    use proptest::prelude::*;
    use proptest_derive::Arbitrary;
    use serde::Deserialize;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use super::*;

    // Round-trips `value` three ways:
    // 1. `serialize_bincode` / `deserialize_bincode`, also checking that exactly
    //    `expected_parts` parts were produced;
    // 2. `Message::framed` / `Message::from_framed`, checking the message survives framing;
    // 3. plain `bincode::serialize` / `bincode::deserialize`, checking the same types also work
    //    with the stock bincode serializer.
    fn test_roundtrip<T>(value: T, expected_parts: usize)
    where
        T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let message = serialize_bincode(&value).unwrap();
        assert_eq!(message.num_parts(), expected_parts);
        let deserialized_value = deserialize_bincode(message.clone()).unwrap();
        assert_eq!(value, deserialized_value);

        // Flatten the frame into contiguous bytes before decoding it again.
        let mut framed = message.clone().framed();
        let framed = framed.copy_to_bytes(framed.remaining());
        let unframed_message = Message::from_framed(framed).unwrap();
        assert_eq!(message, unframed_message);

        let bincode_serialized = bincode::serialize(&value).unwrap();
        let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap();
        assert_eq!(value, bincode_deserialized);
    }

    // A lone `Part` is emitted as exactly one part.
    #[test]
    fn test_specialized_serializer_basic() {
        test_roundtrip(Part::from("hello"), 1);
    }

    // Each `Part` nested anywhere inside a value (sequences, tuples, struct fields, enum
    // variants) contributes one part.
    #[test]
    fn test_specialized_serializer_compound() {
        test_roundtrip(vec![Part::from("hello"), Part::from("world")], 2);
        test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 2);
        // 8 parts = field3 (1) + field4 (1) + field5 (2 + 3) + field6 (1).
        test_roundtrip(
            {
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct U {
                    parts: Vec<Part>,
                }
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                enum E {
                    First(Part),
                    Second(String),
                }

                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct T {
                    field2: String,
                    field3: Part,
                    field4: Part,
                    field5: Vec<U>,
                    field6: E,
                }

                T {
                    field2: "hello".to_string(),
                    field3: Part::from("hello"),
                    field4: Part::from("world"),
                    field5: vec![
                        U {
                            parts: vec![Part::from("hello"), Part::from("world")],
                        },
                        U {
                            parts: vec![Part::from("five"), Part::from("six"), Part::from("seven")],
                        },
                    ],
                    field6: E::First(Part::from("eight")),
                }
            },
            8,
        );
        // Parts interleaved with ordinary fields: only the two `Part` fields count.
        test_roundtrip(
            {
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct T {
                    field1: u64,
                    field2: String,
                    field3: Part,
                    field4: Part,
                    field5: u64,
                }
                T {
                    field1: 1,
                    field2: "hello".to_string(),
                    field3: Part::from("hello"),
                    field4: Part::from("world"),
                    field5: 2,
                }
            },
            2,
        );
    }

    // Serializing a `Message` value itself: its body and each of its two parts are all `Part`s,
    // so the outer message carries 1 + 2 = 3 parts.
    #[test]
    fn test_recursive_message() {
        let message = serialize_bincode(&[Part::from("hello"), Part::from("world")]).unwrap();
        let message_message = serialize_bincode(&message).unwrap();

        assert_eq!(message_message.num_parts(), 3);
    }

    #[test]
    fn test_malformed_messages() {
        // The body "hello" is only 5 bytes. With fixed-width integer encoding, that is too short
        // for the length prefix bincode reads before a `String`, so decoding hits EOF.
        let message = Message {
            body: Part::from("hello"),
            parts: vec![Part::from("world")],
        };
        let err = deserialize_bincode::<String>(message).unwrap_err();

        assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof);

        // One more part than the body references: reported as an overrun.
        let mut message =
            serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
        message.parts.push(Part::from("foo"));
        let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
        assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding");

        // One fewer part than the body references: reported as an underrun.
        let mut message =
            serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
        let _dropped_message = message.parts.pop().unwrap();
        let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
        assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
    }

    // `Frame` as a `Buf`: `remaining`, `chunk` and `advance` across buffer boundaries (including
    // an empty buffer, which `from_buffers` drops), then a full contiguous copy.
    #[test]
    fn test_concat_buf() {
        let buffers = vec![
            Bytes::from("hello"),
            Bytes::from("world"),
            Bytes::from("1"),
            Bytes::from(""),
            Bytes::from("xyz"),
            Bytes::from("xyzd"),
        ];

        let mut concat = Frame::from_buffers(buffers.clone());

        assert_eq!(concat.remaining(), 18);
        concat.advance(2);
        assert_eq!(concat.remaining(), 16);
        assert_eq!(concat.chunk(), &b"llo"[..]);
        // Crosses from "hello" into "world".
        concat.advance(4);
        assert_eq!(concat.chunk(), &b"orld"[..]);
        // Consumes the rest of "world", "1", and skips the dropped empty buffer.
        concat.advance(5);
        assert_eq!(concat.chunk(), &b"xyz"[..]);

        let mut concat = Frame::from_buffers(buffers);
        let bytes = concat.copy_to_bytes(concat.remaining());
        assert_eq!(&*bytes, &b"helloworld1xyzxyzd"[..]);
    }

    // A hand-built message, including an empty part, survives `framed` / `from_framed`.
    #[test]
    fn test_framing() {
        let message = Message {
            body: Part::from("hello"),
            parts: vec![
                Part::from("world"),
                Part::from("1"),
                Part::from(""),
                Part::from("xyz"),
                Part::from("xyzd"),
            ],
        };

        let mut framed = message.clone().framed();
        let framed = framed.copy_to_bytes(framed.remaining());
        assert_eq!(Message::from_framed(framed).unwrap(), message);
    }

    // Standard-library types with custom serde impls (socket addresses, maps) round-trip.
    #[test]
    fn test_socket_addr() {
        let socket_addr_v6: SocketAddrV6 =
            "[2401:db00:225c:2d09:face:0:223:0]:48483".parse().unwrap();
        {
            let message = serialize_bincode(&socket_addr_v6).unwrap();
            let deserialized: SocketAddrV6 = deserialize_bincode(message).unwrap();
            assert_eq!(socket_addr_v6, deserialized);
        }
        let socket_addr = SocketAddr::V6(socket_addr_v6);
        {
            let message = serialize_bincode(&socket_addr).unwrap();
            let deserialized: SocketAddr = deserialize_bincode(message).unwrap();
            assert_eq!(socket_addr, deserialized);
        }

        let mut address_book: HashMap<usize, SocketAddr> = HashMap::new();
        address_book.insert(1, socket_addr);
        {
            let message = serialize_bincode(&address_book).unwrap();
            let deserialized: HashMap<usize, SocketAddr> = deserialize_bincode(message).unwrap();
            assert_eq!(address_book, deserialized);
        }
    }

    // Strategy: a `Bytes` of 0..1_000_000 bytes, every byte 42.
    prop_compose! {
        fn arb_bytes()(len in 0..1000000usize) -> Bytes {
            Bytes::from(vec![42; len])
        }
    }

    // Strategy: a `Part` built from `arb_bytes`.
    prop_compose! {
        fn arb_part()(bytes in arb_bytes()) -> Part {
            bytes.into()
        }
    }

    // Unit, newtype and tuple variants.
    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    enum TupleEnum {
        One,
        Two(String),
        Three(u32),
    }

    // Struct variants, one of which carries a `Part`.
    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    enum StructEnum {
        One {
            a: i32,
        },
        Two {
            s: String,
        },
        Three {
            e: TupleEnum,
            s: String,
            u: u32,
        },
        Four {
            #[proptest(strategy = "arb_part()")]
            part: Part,
        },
    }

    // Mixes tuples, options, sequences, nested enums and raw `Bytes`.
    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    struct S {
        field: String,
        tup: (StructEnum, i32, String, u32, f32),
        tup2: Option<(String, String, String, i32)>,
        e: StructEnum,
        maybe_e: Option<StructEnum>,
        many_e: Vec<(StructEnum, Option<TupleEnum>)>,
        #[proptest(strategy = "arb_bytes()")]
        some_bytes: Bytes,
    }

    // Newtype wrapper, exercising serde's newtype-struct path.
    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    struct N(S);

    proptest! {
        // Any generated `N` survives `serialize_bincode` / `deserialize_bincode`.
        #[test]
        fn test_arbitrary_roundtrip(value in any::<N>()) {
            let message = serialize_bincode(&value).unwrap();
            let deserialized_value = deserialize_bincode(message.clone()).unwrap();
            assert_eq!(value, deserialized_value);
        }
    }
}