serde_multipart/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Serde codec for multipart messages.
10//!
//! Using [`serialize_bincode`] / [`deserialize_bincode`], fields typed [`Part`] are
//! extracted from the main payload and appended to a list of `parts`. Each part is
//! backed by a [`Vec<bytes::Bytes>`] for cheap, zero-copy sharing.
14//!
15//! On decode, the body and its parts are reassembled into the original value
16//! without copying.
17//!
18//! The on-the-wire form is a [`Message`] (body + parts). Your transport sends
19//! and receives [`Message`]s; the codec reconstructs the value, enabling
20//! efficient network I/O without compacting data into a single buffer.
21//!
//! Implementation note: this crate uses Rust's `min_specialization` feature to enable
//! the use of [`Part`]s with any Serde serializer or deserializer. This feature
//! is fairly restrictive, so the API offered by [`serialize_bincode`] /
//! [`deserialize_bincode`] is not customizable. If customization is needed, add
//! specialization implementations for the relevant codecs. See [`part::PartSerializer`]
//! and [`part::PartDeserializer`] for details.
28
29#![feature(min_specialization)]
30#![feature(assert_matches)]
31
32use std::cell::UnsafeCell;
33use std::cmp::min;
34use std::collections::VecDeque;
35use std::io::IoSlice;
36use std::ptr::NonNull;
37
38use bincode::Options;
39use bytes::Buf;
40use bytes::BufMut;
41use bytes::buf::UninitSlice;
42
43mod de;
44mod part;
45mod ser;
46use bytes::Bytes;
47use bytes::BytesMut;
48pub use part::Part;
49use serde::Deserialize;
50use serde::Serialize;
51
52/// A multi-part message, comprising a message body and a list of parts.
53/// Messages only contain references to underlying byte buffers and are
54/// cheaply cloned.
55#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56pub struct Message {
57    body: Part,
58    parts: Vec<Part>,
59}
60
61impl Message {
62    /// Returns a new message with the given body and parts.
63    pub fn from_body_and_parts(body: Part, parts: Vec<Part>) -> Self {
64        Self { body, parts }
65    }
66
67    /// The body of the message.
68    pub fn body(&self) -> &Part {
69        &self.body
70    }
71
72    /// The list of parts of the message.
73    pub fn parts(&self) -> &[Part] {
74        &self.parts
75    }
76
77    /// Returns the total number of parts (excluding the body) in the message.
78    pub fn num_parts(&self) -> usize {
79        self.parts.len()
80    }
81
82    /// Returns the total size (in bytes) of the message.
83    pub fn len(&self) -> usize {
84        self.body.len() + self.parts.iter().map(|part| part.len()).sum::<usize>()
85    }
86
87    /// Returns whether the message is empty. It is always false, since the body
88    /// is always defined.
89    pub fn is_empty(&self) -> bool {
90        self.body.is_empty() && self.parts.iter().all(|part| part.is_empty())
91    }
92
93    /// Convert this message into its constituent components.
94    pub fn into_inner(self) -> (Part, Vec<Part>) {
95        (self.body, self.parts)
96    }
97
98    /// Returns the total size (in bytes) of the message when it is framed.
99    pub fn frame_len(&self) -> usize {
100        8 + self.body.len()
101            + (8 * self.parts.len())
102            + self.parts.iter().map(|p| p.len()).sum::<usize>()
103    }
104
105    /// Efficiently frames a message containing the body and all of its parts
106    /// using a simple frame-length encoding:
107    ///
108    /// ```text
109    /// +--------------------+-------------------+--------------------+-------------------+   ...   +
110    /// | body_len (u64 BE)  |   body bytes      | part1_len (u64 BE) |   part1 bytes     |         |
111    /// +--------------------+-------------------+--------------------+-------------------+         +
112    ///                                                                                      repeat
113    ///                                                                                        for
114    ///                                                                                      each part
115    /// ```
116    pub fn framed(self) -> Frame {
117        let (body, parts) = self.into_inner();
118
119        let mut buffers = Vec::with_capacity(
120            1 + body.num_fragments()
121                + parts.len()
122                + parts.iter().map(|part| part.num_fragments()).sum::<usize>(),
123        );
124
125        buffers.push(Bytes::from_owner(body.len().to_be_bytes()));
126        for fragment in body.into_inner() {
127            buffers.push(fragment);
128        }
129
130        for part in parts {
131            buffers.push(Bytes::from_owner(part.len().to_be_bytes()));
132            for fragment in part.into_inner() {
133                buffers.push(fragment);
134            }
135        }
136
137        Frame::from_buffers(buffers)
138    }
139
140    /// Reassembles a message from a framed encoding.
141    pub fn from_framed(mut buf: Bytes) -> Result<Self, std::io::Error> {
142        if buf.len() < 8 {
143            return Err(std::io::ErrorKind::UnexpectedEof.into());
144        }
145        let body_len = buf.get_u64();
146        let body = buf.split_to(body_len as usize);
147        let mut parts = Vec::new();
148        while !buf.is_empty() {
149            parts.push(Self::split_part(&mut buf)?.into());
150        }
151        Ok(Self {
152            body: body.into(),
153            parts,
154        })
155    }
156
157    fn split_part(buf: &mut Bytes) -> Result<Bytes, std::io::Error> {
158        if buf.len() < 8 {
159            return Err(std::io::ErrorKind::UnexpectedEof.into());
160        }
161        let at = buf.get_u64() as usize;
162        if buf.len() < at {
163            return Err(std::io::ErrorKind::UnexpectedEof.into());
164        }
165        Ok(buf.split_to(at))
166    }
167}
168
169/// An encoded [`Message`] frame. Implements [`bytes::Buf`],
170/// and supports vectored writes. Thus, `Frame` is like a reader
171/// of an encoded [`Message`].
172#[derive(Clone)]
173pub struct Frame {
174    buffers: VecDeque<Bytes>,
175}
176
177impl Frame {
178    /// Construct a new frame from the provided buffers. The frame is a
179    /// concatenation of these buffers.
180    fn from_buffers(buffers: Vec<Bytes>) -> Self {
181        let mut buffers: VecDeque<Bytes> = buffers.into();
182        buffers.retain(|buf| !buf.is_empty());
183        Self { buffers }
184    }
185
186    /// **DO NOT USE THIS**
187    pub fn illegal_unipart_frame(body: Bytes) -> Self {
188        Self {
189            buffers: vec![body].into(),
190        }
191    }
192}
193
194impl Buf for Frame {
195    fn remaining(&self) -> usize {
196        self.buffers.iter().map(|buf| buf.remaining()).sum()
197    }
198
199    fn chunk(&self) -> &[u8] {
200        match self.buffers.front() {
201            Some(buf) => buf.chunk(),
202            None => &[],
203        }
204    }
205
206    fn advance(&mut self, mut cnt: usize) {
207        while cnt > 0 {
208            let Some(buf) = self.buffers.front_mut() else {
209                panic!("advanced beyond the buffer size");
210            };
211
212            if cnt >= buf.remaining() {
213                cnt -= buf.remaining();
214                self.buffers.pop_front();
215                continue;
216            }
217
218            buf.advance(cnt);
219            cnt = 0;
220        }
221    }
222
223    // We implement our own chunks_vectored here, as the default implementation
224    // does not do any vectoring (returning only a single IoSlice at a time).
225    fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
226        let n = min(dst.len(), self.buffers.len());
227        for i in 0..n {
228            dst[i] = IoSlice::new(self.buffers[i].chunk());
229        }
230        n
231    }
232}
233
234/// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer
235/// for the serializer without exposing lifetime parameters (which cannot be)
236/// specialized.
237struct UnsafeBufCell {
238    buf: UnsafeCell<BytesMut>,
239}
240
241impl UnsafeBufCell {
242    /// Create a new cell from a [`BytesMut`].
243    fn from_bytes_mut(bytes: BytesMut) -> Self {
244        Self {
245            buf: UnsafeCell::new(bytes),
246        }
247    }
248
249    /// Convert this cell into its underlying [`BytesMut`].
250    fn into_inner(self) -> BytesMut {
251        self.buf.into_inner()
252    }
253
254    /// Borrow the cell, without lifetime checks. The caller must guarantee that
255    /// the returned cell cannot be used after the cell is dropped (usually through
256    /// [`UnsafeBufCell::into_inner`]).
257    unsafe fn borrow_unchecked(&self) -> UnsafeBufCellRef {
258        let ptr =
259            // SAFETY: the user is providing the necessary invariants
260            unsafe { NonNull::new_unchecked(self.buf.get()) };
261        UnsafeBufCellRef { ptr }
262    }
263}
264
/// A borrowed reference to an [`UnsafeBufCell`].
///
/// This is a lifetime-erased pointer to the cell's [`BytesMut`], created by
/// `UnsafeBufCell::borrow_unchecked`. Whoever created it must ensure it is
/// not used after the cell is moved or dropped.
struct UnsafeBufCellRef {
    ptr: NonNull<BytesMut>,
}

/// SAFETY: we're extending the implementation of the underlying [`BytesMut`];
/// adding an additional layer of danger by disregarding lifetimes.
///
/// Each method forwards to the same [`BufMut`] method of the pointed-to
/// [`BytesMut`], so the `BufMut` contract holds exactly as it does for
/// `BytesMut`, provided `ptr` still points at a live `BytesMut` and no other
/// reference to it is in use during the call.
unsafe impl BufMut for UnsafeBufCellRef {
    fn remaining_mut(&self) -> usize {
        // SAFETY: extending the implementation of the underlying [`BytesMut`]
        unsafe { self.ptr.as_ref().remaining_mut() }
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: extending the implementation of the underlying [`BytesMut`];
        // the caller of `advance_mut` upholds `BytesMut::advance_mut`'s contract.
        unsafe { self.ptr.as_mut().advance_mut(cnt) }
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        // SAFETY: extending the implementation of the underlying [`BytesMut`]
        unsafe { self.ptr.as_mut().chunk_mut() }
    }
}
288
289/// Serialize the provided value into a multipart message. The value is encoded using an
290/// extended version of [`bincode`] that skips serializing [`Part`]s, which are instead
291/// held directly by the returned message.
292///
293/// Serialize uses the same codec options as [`bincode::serialize`] / [`bincode::deserialize`].
294/// These are currently not customizable unless an explicit specialization is also provided.
295pub fn serialize_bincode<S: ?Sized + serde::Serialize>(
296    value: &S,
297) -> Result<Message, bincode::Error> {
298    let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new());
299    // SAFETY: we know here that, once the below "value.serialize()" is done, there are no more
300    // extant references to this buffer; we are thus safe to reclaim the buffer into the message
301    let buffer_borrow = unsafe { buffer.borrow_unchecked() };
302    let mut serializer: part::BincodeSerializer =
303        ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options()));
304    value.serialize(&mut serializer)?;
305    Ok(Message {
306        body: Part(vec![buffer.into_inner().freeze()]),
307        parts: serializer.into_parts(),
308    })
309}
310
311/// Deserialize a message serialized by `[serialize]`, stitching together the original
312/// message without copying the underlying buffers.
313pub fn deserialize_bincode<T>(message: Message) -> Result<T, bincode::Error>
314where
315    T: serde::de::DeserializeOwned,
316{
317    let (body, parts) = message.into_inner();
318    let mut deserializer = part::BincodeDeserializer::new(
319        bincode::Deserializer::with_reader(body.into_bytes().reader(), options()),
320        parts.into(),
321    );
322    let value = T::deserialize(&mut deserializer)?;
323    // Check that all parts were consumed:
324    deserializer.end()?;
325    Ok(value)
326}
327
328/// Construct the set of options used by the specialized serializer and deserializer.
329fn options() -> part::BincodeOptionsType {
330    bincode::DefaultOptions::new()
331        .with_fixint_encoding()
332        .allow_trailing_bytes()
333}
334
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::net::SocketAddrV6;

    use proptest::prelude::*;
    use proptest_derive::Arbitrary;
    use serde::Deserialize;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use super::*;

    /// Round-trips `value` three ways: through the multipart codec, through
    /// framing/unframing of the resulting [`Message`], and through plain
    /// `bincode`. Asserts equality at each step.
    ///
    /// `expected_parts` is the number of [`Part`]s the multipart serializer
    /// should extract from the body.
    fn test_roundtrip<T>(value: T, expected_parts: usize)
    where
        T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        // Test plain serialization roundtrip:
        let message = serialize_bincode(&value).unwrap();
        assert_eq!(message.num_parts(), expected_parts);
        let deserialized_value = deserialize_bincode(message.clone()).unwrap();
        assert_eq!(value, deserialized_value);

        // Framing roundtrip:
        let mut framed = message.clone().framed();
        let framed = framed.copy_to_bytes(framed.remaining());
        let unframed_message = Message::from_framed(framed).unwrap();
        assert_eq!(message, unframed_message);

        // Bincode passthrough: the same value must also work with the plain,
        // unspecialized bincode codec.
        let bincode_serialized = bincode::serialize(&value).unwrap();
        let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap();
        assert_eq!(value, bincode_deserialized);
    }

    #[test]
    fn test_specialized_serializer_basic() {
        test_roundtrip(Part::from("hello"), 1);
    }

    #[test]
    fn test_specialized_serializer_compound() {
        test_roundtrip(vec![Part::from("hello"), Part::from("world")], 2);
        test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 2);
        // Parts nested in structs, vectors and enum variants: 2 (field3/4)
        // + 2 + 3 (field5) + 1 (field6) = 8.
        test_roundtrip(
            {
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct U {
                    parts: Vec<Part>,
                }
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                enum E {
                    First(Part),
                    Second(String),
                }

                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct T {
                    field2: String,
                    field3: Part,
                    field4: Part,
                    field5: Vec<U>,
                    field6: E,
                }

                T {
                    field2: "hello".to_string(),
                    field3: Part::from("hello"),
                    field4: Part::from("world"),
                    field5: vec![
                        U {
                            parts: vec![Part::from("hello"), Part::from("world")],
                        },
                        U {
                            parts: vec![Part::from("five"), Part::from("six"), Part::from("seven")],
                        },
                    ],
                    field6: E::First(Part::from("eight")),
                }
            },
            8,
        );
        // Parts interleaved with ordinary fields.
        test_roundtrip(
            {
                #[derive(Serialize, Deserialize, Debug, PartialEq)]
                struct T {
                    field1: u64,
                    field2: String,
                    field3: Part,
                    field4: Part,
                    field5: u64,
                }
                T {
                    field1: 1,
                    field2: "hello".to_string(),
                    field3: Part::from("hello"),
                    field4: Part::from("world"),
                    field5: 2,
                }
            },
            2,
        );
    }

    /// Serializing a `Message` itself extracts its body and each of its parts
    /// as separate parts.
    #[test]
    fn test_recursive_message() {
        let message = serialize_bincode(&[Part::from("hello"), Part::from("world")]).unwrap();
        let message_message = serialize_bincode(&message).unwrap();

        // message.body + message.parts (x2):
        assert_eq!(message_message.num_parts(), 3);
    }

    #[test]
    fn test_malformed_messages() {
        // A 5-byte body cannot hold the 8-byte (fixint) length prefix of a String.
        let message = Message {
            body: Part::from("hello"),
            parts: vec![Part::from("world")],
        };
        let err = deserialize_bincode::<String>(message).unwrap_err();

        // Normal bincode errors work:
        assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof);

        // One more part than the value consumes.
        let mut message =
            serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
        message.parts.push(Part::from("foo"));
        let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
        assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding");

        // One fewer part than the value needs.
        let mut message =
            serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
        let _dropped_message = message.parts.pop().unwrap();
        let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
        assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
    }

    /// Exercises `Frame`'s `Buf` implementation across buffer boundaries,
    /// including an empty buffer that `from_buffers` discards.
    #[test]
    fn test_concat_buf() {
        let buffers = vec![
            Bytes::from("hello"),
            Bytes::from("world"),
            Bytes::from("1"),
            Bytes::from(""),
            Bytes::from("xyz"),
            Bytes::from("xyzd"),
        ];

        let mut concat = Frame::from_buffers(buffers.clone());

        // 5 + 5 + 1 + 0 + 3 + 4 bytes.
        assert_eq!(concat.remaining(), 18);
        concat.advance(2);
        assert_eq!(concat.remaining(), 16);
        assert_eq!(concat.chunk(), &b"llo"[..]);
        // Consumes "llo" plus the "w" of "world".
        concat.advance(4);
        assert_eq!(concat.chunk(), &b"orld"[..]);
        // Consumes "orld" and "1"; the empty buffer is gone, so "xyz" is next.
        concat.advance(5);
        assert_eq!(concat.chunk(), &b"xyz"[..]);

        let mut concat = Frame::from_buffers(buffers);
        let bytes = concat.copy_to_bytes(concat.remaining());
        assert_eq!(&*bytes, &b"helloworld1xyzxyzd"[..]);
    }

    /// Framing followed by unframing reproduces the message, including an
    /// empty part.
    #[test]
    fn test_framing() {
        let message = Message {
            body: Part::from("hello"),
            parts: vec![
                Part::from("world"),
                Part::from("1"),
                Part::from(""),
                Part::from("xyz"),
                Part::from("xyzd"),
            ],
        };

        let mut framed = message.clone().framed();
        let framed = framed.copy_to_bytes(framed.remaining());
        assert_eq!(Message::from_framed(framed).unwrap(), message);
    }

    /// Part-free std types (socket addresses, maps) round-trip through the
    /// multipart codec.
    #[test]
    fn test_socket_addr() {
        let socket_addr_v6: SocketAddrV6 =
            "[2401:db00:225c:2d09:face:0:223:0]:48483".parse().unwrap();
        {
            let message = serialize_bincode(&socket_addr_v6).unwrap();
            let deserialized: SocketAddrV6 = deserialize_bincode(message).unwrap();
            assert_eq!(socket_addr_v6, deserialized);
        }
        let socket_addr = SocketAddr::V6(socket_addr_v6);
        {
            let message = serialize_bincode(&socket_addr).unwrap();
            let deserialized: SocketAddr = deserialize_bincode(message).unwrap();
            assert_eq!(socket_addr, deserialized);
        }

        let mut address_book: HashMap<usize, SocketAddr> = HashMap::new();
        address_book.insert(1, socket_addr);
        {
            let message = serialize_bincode(&address_book).unwrap();
            let deserialized: HashMap<usize, SocketAddr> = deserialize_bincode(message).unwrap();
            assert_eq!(address_book, deserialized);
        }
    }

    // Byte buffers of up to ~1 MB, filled with the value 42.
    prop_compose! {
        fn arb_bytes()(len in 0..1000000usize) -> Bytes {
            Bytes::from(vec![42; len])
        }
    }

    prop_compose! {
        fn arb_part()(bytes in arb_bytes()) -> Part {
            bytes.into()
        }
    }

    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    enum TupleEnum {
        One,
        Two(String),
        Three(u32),
    }

    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    enum StructEnum {
        One {
            a: i32,
        },
        Two {
            s: String,
        },
        Three {
            e: TupleEnum,
            s: String,
            u: u32,
        },
        Four {
            #[proptest(strategy = "arb_part()")]
            part: Part,
        },
    }

    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    struct S {
        field: String,
        tup: (StructEnum, i32, String, u32, f32),
        tup2: Option<(String, String, String, i32)>,
        e: StructEnum,
        maybe_e: Option<StructEnum>,
        many_e: Vec<(StructEnum, Option<TupleEnum>)>,
        #[proptest(strategy = "arb_bytes()")]
        some_bytes: Bytes,
    }

    #[derive(Arbitrary, Serialize, Deserialize, Debug, PartialEq)]
    struct N(S);

    // Randomized roundtrip over nested enums, options, tuples and Parts.
    proptest! {
        #[test]
        fn test_arbitrary_roundtrip(value in any::<N>()) {
            let message = serialize_bincode(&value).unwrap();
            let deserialized_value = deserialize_bincode(message.clone()).unwrap();
            assert_eq!(value, deserialized_value);
        }
    }
}