1use std::fmt;
21use std::io::{Read, Write};
22
23use flate2::Compression;
24use flate2::read::GzDecoder;
25use flate2::write::GzEncoder;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27
28use crate::{Error, ErrorKind, Result};
29
/// Gzip level used when no explicit level is requested
/// (see `CompressionCodec::gzip_default`).
const GZIP_DEFAULT_LEVEL: u8 = 6;
/// Highest gzip level passed to the encoder; larger requested levels are
/// capped to this value in `CompressionCodec::compress`.
const GZIP_MAX_LEVEL: u8 = 9;
/// Zstd level used when no explicit level is requested
/// (see `CompressionCodec::zstd_default`).
const ZSTD_DEFAULT_LEVEL: u8 = 3;
36
/// Compression codecs that can be selected for stored bytes.
///
/// The `u8` carried by [`CompressionCodec::Zstd`] and [`CompressionCodec::Gzip`]
/// is the compression level. [`CompressionCodec::None`] is the `Default` value.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CompressionCodec {
    /// No compression: bytes are passed through unchanged.
    #[default]
    None,
    /// LZ4. Compressing and decompressing with this codec currently fails
    /// with `FeatureUnsupported`.
    Lz4,
    /// Zstandard, with the given compression level.
    Zstd(u8),
    /// Gzip, with the given compression level.
    Gzip(u8),
    /// Snappy. Compressing and decompressing with this codec currently fails
    /// with `FeatureUnsupported`.
    Snappy,
}
55
56impl CompressionCodec {
57 pub const fn zstd_default() -> Self {
59 CompressionCodec::Zstd(ZSTD_DEFAULT_LEVEL)
60 }
61
62 pub const fn gzip_default() -> Self {
64 CompressionCodec::Gzip(GZIP_DEFAULT_LEVEL)
65 }
66
67 pub fn name(&self) -> &'static str {
69 match self {
70 CompressionCodec::None => "none",
71 CompressionCodec::Lz4 => "lz4",
72 CompressionCodec::Zstd(_) => "zstd",
73 CompressionCodec::Gzip(_) => "gzip",
74 CompressionCodec::Snappy => "snappy",
75 }
76 }
77}
78
79impl Serialize for CompressionCodec {
84 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
85 serializer.serialize_str(self.name())
86 }
87}
88
89impl<'de> Deserialize<'de> for CompressionCodec {
90 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
91 let s = String::deserialize(deserializer)?;
92 match s.to_lowercase().as_str() {
93 "none" => Ok(CompressionCodec::None),
94 "lz4" => Ok(CompressionCodec::Lz4),
95 "zstd" => Ok(CompressionCodec::zstd_default()),
96 "gzip" => Ok(CompressionCodec::gzip_default()),
97 "snappy" => Ok(CompressionCodec::Snappy),
98 other => Err(serde::de::Error::unknown_variant(other, &[
99 "none", "lz4", "zstd", "gzip", "snappy",
100 ])),
101 }
102 }
103}
104
105impl fmt::Display for CompressionCodec {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 match self {
108 CompressionCodec::None => write!(f, "None"),
109 CompressionCodec::Lz4 => write!(f, "Lz4"),
110 CompressionCodec::Zstd(level) => write!(f, "Zstd(level={level})"),
111 CompressionCodec::Gzip(level) => write!(f, "Gzip(level={level})"),
112 CompressionCodec::Snappy => write!(f, "Snappy"),
113 }
114 }
115}
116
117impl CompressionCodec {
118 pub(crate) fn decompress(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
119 match self {
120 CompressionCodec::None => Ok(bytes),
121 CompressionCodec::Lz4 => Err(Error::new(
122 ErrorKind::FeatureUnsupported,
123 "LZ4 decompression is not supported currently",
124 )),
125 CompressionCodec::Zstd(_) => Ok(zstd::stream::decode_all(&bytes[..])?),
126 CompressionCodec::Gzip(_) => {
127 let mut decoder = GzDecoder::new(&bytes[..]);
128 let mut decompressed = Vec::new();
129 decoder.read_to_end(&mut decompressed)?;
130 Ok(decompressed)
131 }
132 CompressionCodec::Snappy => Err(Error::new(
133 ErrorKind::FeatureUnsupported,
134 "Snappy decompression is not supported currently",
135 )),
136 }
137 }
138
139 pub(crate) fn compress(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
140 match self {
141 CompressionCodec::None => Ok(bytes),
142 CompressionCodec::Lz4 => Err(Error::new(
143 ErrorKind::FeatureUnsupported,
144 "LZ4 compression is not supported currently",
145 )),
146 CompressionCodec::Zstd(level) => {
147 let writer = Vec::<u8>::new();
148 let mut encoder = zstd::stream::Encoder::new(writer, *level as i32)?;
149 encoder.include_checksum(true)?;
150 encoder.set_pledged_src_size(Some(bytes.len().try_into()?))?;
151 std::io::copy(&mut &bytes[..], &mut encoder)?;
152 Ok(encoder.finish()?)
153 }
154 CompressionCodec::Gzip(level) => {
155 let compression = Compression::new((*level).min(GZIP_MAX_LEVEL) as u32);
156 let mut encoder = GzEncoder::new(Vec::new(), compression);
157 encoder.write_all(&bytes)?;
158 Ok(encoder.finish()?)
159 }
160 CompressionCodec::Snappy => Err(Error::new(
161 ErrorKind::FeatureUnsupported,
162 "Snappy compression is not supported currently",
163 )),
164 }
165 }
166
167 pub(crate) fn is_none(&self) -> bool {
168 matches!(self, CompressionCodec::None)
169 }
170
171 pub fn suffix(&self) -> Result<&'static str> {
178 match self {
179 CompressionCodec::None => Ok(""),
180 CompressionCodec::Gzip(_) => Ok(".gz"),
181 codec @ (CompressionCodec::Lz4
182 | CompressionCodec::Zstd(_)
183 | CompressionCodec::Snappy) => Err(Error::new(
184 ErrorKind::FeatureUnsupported,
185 format!("suffix not defined for {codec:?}"),
186 )),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::CompressionCodec;

    /// Highly compressible fixture shared by the tests: 100 zero bytes.
    fn sample_bytes() -> Vec<u8> {
        vec![0_u8; 100]
    }

    // The codec API is fully synchronous, so these are plain `#[test]`s:
    // no async runtime is needed to exercise it.

    #[test]
    fn test_compression_codec_none() {
        let bytes_vec = sample_bytes();

        let codec = CompressionCodec::None;
        let compressed = codec.compress(bytes_vec.clone()).unwrap();
        assert_eq!(bytes_vec, compressed);
        let decompressed = codec.decompress(compressed).unwrap();
        assert_eq!(bytes_vec, decompressed);
    }

    #[test]
    fn test_compression_codec_compress() {
        let bytes_vec = sample_bytes();

        let compression_codecs = [
            CompressionCodec::zstd_default(),
            CompressionCodec::gzip_default(),
        ];

        for codec in compression_codecs {
            let compressed = codec.compress(bytes_vec.clone()).unwrap();
            assert!(compressed.len() < bytes_vec.len());
            let decompressed = codec.decompress(compressed).unwrap();
            assert_eq!(decompressed, bytes_vec);
        }
    }

    #[test]
    fn test_compression_codec_gzip_level_above_max() {
        // Levels above the gzip maximum are capped by `compress`, so an
        // out-of-range level must still round-trip instead of failing.
        let bytes_vec = sample_bytes();

        let codec = CompressionCodec::Gzip(u8::MAX);
        let compressed = codec.compress(bytes_vec.clone()).unwrap();
        let decompressed = codec.decompress(compressed).unwrap();
        assert_eq!(decompressed, bytes_vec);
    }

    #[test]
    fn test_compression_codec_unsupported() {
        let unsupported_codecs = [
            (CompressionCodec::Lz4, "LZ4"),
            (CompressionCodec::Snappy, "Snappy"),
        ];
        let bytes_vec = sample_bytes();

        for (codec, name) in unsupported_codecs {
            assert_eq!(
                codec.compress(bytes_vec.clone()).unwrap_err().to_string(),
                format!("FeatureUnsupported => {name} compression is not supported currently"),
            );

            assert_eq!(
                codec.decompress(bytes_vec.clone()).unwrap_err().to_string(),
                format!("FeatureUnsupported => {name} decompression is not supported currently"),
            );
        }
    }

    #[test]
    fn test_default_levels() {
        assert_eq!(CompressionCodec::default(), CompressionCodec::None);
        assert_eq!(CompressionCodec::zstd_default(), CompressionCodec::Zstd(3));
        assert_eq!(CompressionCodec::gzip_default(), CompressionCodec::Gzip(6));
    }

    #[test]
    fn test_name() {
        assert_eq!(CompressionCodec::None.name(), "none");
        assert_eq!(CompressionCodec::Lz4.name(), "lz4");
        assert_eq!(CompressionCodec::zstd_default().name(), "zstd");
        assert_eq!(CompressionCodec::gzip_default().name(), "gzip");
        assert_eq!(CompressionCodec::Snappy.name(), "snappy");

        // The level never leaks into the name.
        assert_eq!(CompressionCodec::Zstd(19).name(), "zstd");
        assert_eq!(CompressionCodec::Gzip(1).name(), "gzip");
    }

    #[test]
    fn test_is_none() {
        assert!(CompressionCodec::None.is_none());
        assert!(CompressionCodec::default().is_none());

        assert!(!CompressionCodec::Lz4.is_none());
        assert!(!CompressionCodec::zstd_default().is_none());
        assert!(!CompressionCodec::gzip_default().is_none());
        assert!(!CompressionCodec::Snappy.is_none());
    }

    #[test]
    fn test_suffix() {
        assert_eq!(CompressionCodec::None.suffix().unwrap(), "");
        assert_eq!(CompressionCodec::gzip_default().suffix().unwrap(), ".gz");

        assert!(CompressionCodec::Lz4.suffix().is_err());
        assert!(CompressionCodec::zstd_default().suffix().is_err());
        assert!(CompressionCodec::Snappy.suffix().is_err());

        let lz4_err = CompressionCodec::Lz4.suffix().unwrap_err();
        assert!(lz4_err.to_string().contains("suffix not defined for Lz4"));

        let zstd_err = CompressionCodec::zstd_default().suffix().unwrap_err();
        assert!(zstd_err.to_string().contains("suffix not defined for Zstd"));
    }

    #[test]
    fn test_display() {
        assert_eq!(CompressionCodec::None.to_string(), "None");
        assert_eq!(CompressionCodec::Lz4.to_string(), "Lz4");
        assert_eq!(
            CompressionCodec::zstd_default().to_string(),
            "Zstd(level=3)"
        );
        assert_eq!(CompressionCodec::Zstd(5).to_string(), "Zstd(level=5)");
        assert_eq!(
            CompressionCodec::gzip_default().to_string(),
            "Gzip(level=6)"
        );
        assert_eq!(CompressionCodec::Gzip(9).to_string(), "Gzip(level=9)");
        assert_eq!(CompressionCodec::Snappy.to_string(), "Snappy");
    }
}