serde_multipart/
part.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::ops::Deref;
10
11use bytes::Bytes;
12use bytes::BytesMut;
13use bytes::buf::Reader as BufReader;
14use bytes::buf::Writer as BufWriter;
15use serde::Deserialize;
16use serde::Serialize;
17
18use crate::UnsafeBufCellRef;
19use crate::de;
20use crate::ser;
21
22/// Part represents a single part of a multipart message. Its type is simple:
23/// it is just a newtype of the byte buffer [`Bytes`], which permits zero copy
24/// shared ownership of the underlying buffers. Part itself provides a customized
25/// serialization implementation that is specialized for the multipart codecs in
26/// this crate, skipping copying the bytes whenever possible.
27#[derive(Clone, Debug, PartialEq, Eq, Default)]
28pub struct Part(pub(crate) Vec<Bytes>);
29
30impl Part {
31    /// Consumes the part, returning its underlying byte buffers.
32    pub fn into_inner(self) -> Vec<Bytes> {
33        self.0
34    }
35
36    /// Consumes the part, concatenating fragments if necessary into a single byte buffer.
37    pub fn into_bytes(self) -> Bytes {
38        match self.0.len() {
39            0 => Bytes::new(),
40            1 => self.0.into_iter().next().unwrap(),
41            _ => {
42                let total_len: usize = self.0.iter().map(|p| p.len()).sum();
43                let mut result = BytesMut::with_capacity(total_len);
44                for fragment in self.0 {
45                    result.extend_from_slice(&fragment);
46                }
47                result.freeze()
48            }
49        }
50    }
51
52    /// Get bytes as a reference, concatenating fragments if necessary.
53    pub fn to_bytes(&self) -> Bytes {
54        match self.0.len() {
55            0 => Bytes::new(),
56            1 => self.0.first().unwrap().clone(),
57            _ => {
58                let total_len: usize = self.0.iter().map(|p| p.len()).sum();
59                let mut result = BytesMut::with_capacity(total_len);
60                for fragment in &self.0 {
61                    result.extend_from_slice(fragment);
62                }
63                result.freeze()
64            }
65        }
66    }
67
68    /// Returns the total length in bytes.
69    pub fn len(&self) -> usize {
70        self.0.iter().map(|b| b.len()).sum()
71    }
72
73    /// Returns the number of fragments
74    pub fn num_fragments(&self) -> usize {
75        self.0.len()
76    }
77
78    /// Returns whether the part is empty.
79    pub fn is_empty(&self) -> bool {
80        self.0.iter().all(|b| b.is_empty())
81    }
82
83    pub fn from_fragments(fragments: Vec<Bytes>) -> Self {
84        Self(fragments)
85    }
86}
87
88impl<T: Into<Bytes>> From<T> for Part {
89    fn from(bytes: T) -> Self {
90        Self(vec![bytes.into()])
91    }
92}
93
94impl Deref for Part {
95    type Target = Vec<Bytes>;
96
97    fn deref(&self) -> &Self::Target {
98        &self.0
99    }
100}
101
102impl Serialize for Part {
103    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
104        <Part as PartSerializer<S>>::serialize(self, s)
105    }
106}
107
108impl<'de> Deserialize<'de> for Part {
109    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
110        <Part as PartDeserializer<'de, D>>::deserialize(d)
111    }
112}
113
114/// PartSerializer is the trait that selects serialization strategy based on the
115/// the serializer's type.
116pub trait PartSerializer<S: serde::Serializer> {
117    fn serialize(this: &Part, s: S) -> Result<S::Ok, S::Error>;
118}
119
120/// By default, we use the underlying byte serializer, which copies the underlying bytes
121/// into the serialization buffer.
122impl<S: serde::Serializer> PartSerializer<S> for Part {
123    default fn serialize(this: &Part, s: S) -> Result<S::Ok, S::Error> {
124        // Normal serializer: concatenate into contiguous byte chunk (requires copy).
125        this.to_bytes().serialize(s)
126    }
127}
128
129/// The options type used by the underlying bincode codec. We capture this here to make sure
130/// we consistently use the type, which is required to correctly specialize the multipart codec.
131pub(crate) type BincodeOptionsType = bincode::config::WithOtherTrailing<
132    bincode::config::WithOtherIntEncoding<bincode::DefaultOptions, bincode::config::FixintEncoding>,
133    bincode::config::AllowTrailing,
134>;
135
136/// The serializer type used by the underlying bincode codec. We capture this here to make sure
137/// we consistently use the type, which is required to correctly specialize the multipart codec.
138pub(crate) type BincodeSerializer =
139    ser::bincode::Serializer<BufWriter<UnsafeBufCellRef>, BincodeOptionsType>;
140
141/// Specialized implementaiton for our multipart serializer.
142impl<'a> PartSerializer<&'a mut BincodeSerializer> for Part {
143    fn serialize(this: &Part, s: &'a mut BincodeSerializer) -> Result<(), bincode::Error> {
144        s.serialize_part(this);
145        Ok(())
146    }
147}
148
149/// PartDeserializer is the trait that selects serialization strategy based on the
150/// the deserializer's type.
151trait PartDeserializer<'de, S: serde::Deserializer<'de>>: Sized {
152    fn deserialize(this: S) -> Result<Self, S::Error>;
153}
154
155/// By default, we use the underlying byte deserializer, which copies the serialized bytes
156/// into the value directly.
157impl<'de, D: serde::Deserializer<'de>> PartDeserializer<'de, D> for Part {
158    default fn deserialize(deserializer: D) -> Result<Self, D::Error> {
159        Ok(Part(vec![Bytes::deserialize(deserializer)?]))
160    }
161}
162
163/// The deserializer type used by the underlying bincode codec. We capture this here to make sure
164/// we consistently use the type, which is required to correctly specialize the multipart codec.
165pub(crate) type BincodeDeserializer =
166    de::bincode::Deserializer<bincode::de::read::IoReader<BufReader<Bytes>>, BincodeOptionsType>;
167
168/// Specialized implementation for our multipart deserializer.
169impl<'de, 'a> PartDeserializer<'de, &'a mut BincodeDeserializer> for Part {
170    fn deserialize(deserializer: &'a mut BincodeDeserializer) -> Result<Self, bincode::Error> {
171        deserializer.deserialize_part()
172    }
173}