1use std::io::{Read, Write};
21
22use flate2::Compression;
23use flate2::read::GzDecoder;
24use flate2::write::GzEncoder;
25use serde::{Deserialize, Serialize};
26
27use crate::{Error, ErrorKind, Result};
28
29#[derive(Debug, PartialEq, Eq, Clone, Copy, Default, Serialize, Deserialize)]
31#[serde(rename_all = "lowercase")]
32pub enum CompressionCodec {
33 #[default]
34 None,
36 Lz4,
38 Zstd,
40 Gzip,
42}
43
44impl CompressionCodec {
45 pub(crate) fn decompress(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
46 match self {
47 CompressionCodec::None => Ok(bytes),
48 CompressionCodec::Lz4 => Err(Error::new(
49 ErrorKind::FeatureUnsupported,
50 "LZ4 decompression is not supported currently",
51 )),
52 CompressionCodec::Zstd => Ok(zstd::stream::decode_all(&bytes[..])?),
53 CompressionCodec::Gzip => {
54 let mut decoder = GzDecoder::new(&bytes[..]);
55 let mut decompressed = Vec::new();
56 decoder.read_to_end(&mut decompressed)?;
57 Ok(decompressed)
58 }
59 }
60 }
61
62 pub(crate) fn compress(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
63 match self {
64 CompressionCodec::None => Ok(bytes),
65 CompressionCodec::Lz4 => Err(Error::new(
66 ErrorKind::FeatureUnsupported,
67 "LZ4 compression is not supported currently",
68 )),
69 CompressionCodec::Zstd => {
70 let writer = Vec::<u8>::new();
71 let mut encoder = zstd::stream::Encoder::new(writer, 3)?;
72 encoder.include_checksum(true)?;
73 encoder.set_pledged_src_size(Some(bytes.len().try_into()?))?;
74 std::io::copy(&mut &bytes[..], &mut encoder)?;
75 Ok(encoder.finish()?)
76 }
77 CompressionCodec::Gzip => {
78 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
79 encoder.write_all(&bytes)?;
80 Ok(encoder.finish()?)
81 }
82 }
83 }
84
85 pub(crate) fn is_none(&self) -> bool {
86 matches!(self, CompressionCodec::None)
87 }
88
89 pub fn suffix(&self) -> Result<&'static str> {
96 match self {
97 CompressionCodec::None => Ok(""),
98 CompressionCodec::Gzip => Ok(".gz"),
99 codec @ (CompressionCodec::Lz4 | CompressionCodec::Zstd) => Err(Error::new(
100 ErrorKind::FeatureUnsupported,
101 format!("suffix not defined for {codec:?}"),
102 )),
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::CompressionCodec;

    /// `None` must pass bytes through untouched in both directions.
    #[test]
    fn test_compression_codec_none() {
        let bytes_vec = [0_u8; 100].to_vec();

        let codec = CompressionCodec::None;
        let compressed = codec.compress(bytes_vec.clone()).unwrap();
        assert_eq!(bytes_vec, compressed);
        let decompressed = codec.decompress(compressed).unwrap();
        assert_eq!(bytes_vec, decompressed);
    }

    /// Supported codecs must shrink highly repetitive input and round-trip it exactly.
    #[test]
    fn test_compression_codec_compress() {
        let bytes_vec = [0_u8; 100].to_vec();

        let compression_codecs = [CompressionCodec::Zstd, CompressionCodec::Gzip];

        for codec in compression_codecs {
            let compressed = codec.compress(bytes_vec.clone()).unwrap();
            assert!(compressed.len() < bytes_vec.len());
            let decompressed = codec.decompress(compressed).unwrap();
            assert_eq!(decompressed, bytes_vec);
        }
    }

    /// Round-trips must also hold for empty, single-byte and varied inputs.
    #[test]
    fn test_compression_codec_roundtrip_edge_cases() {
        let inputs: [Vec<u8>; 3] = [
            Vec::new(),
            vec![42_u8],
            (0..=u8::MAX).cycle().take(4096).collect(),
        ];
        let codecs = [
            CompressionCodec::None,
            CompressionCodec::Zstd,
            CompressionCodec::Gzip,
        ];

        for codec in codecs {
            for input in &inputs {
                let compressed = codec.compress(input.clone()).unwrap();
                let decompressed = codec.decompress(compressed).unwrap();
                assert_eq!(
                    &decompressed,
                    input,
                    "round-trip failed for {codec:?} with {} input bytes",
                    input.len()
                );
            }
        }
    }

    /// Unsupported codecs must reject both directions with a descriptive error.
    #[test]
    fn test_compression_codec_unsupported() {
        let unsupported_codecs = [(CompressionCodec::Lz4, "LZ4")];
        let bytes_vec = [0_u8; 100].to_vec();

        for (codec, name) in unsupported_codecs {
            assert_eq!(
                codec.compress(bytes_vec.clone()).unwrap_err().to_string(),
                format!("FeatureUnsupported => {name} compression is not supported currently"),
            );

            assert_eq!(
                codec.decompress(bytes_vec.clone()).unwrap_err().to_string(),
                format!("FeatureUnsupported => {name} decompression is not supported currently"),
            );
        }
    }

    /// Only `None` and `Gzip` define a file-name suffix.
    #[test]
    fn test_suffix() {
        assert_eq!(CompressionCodec::None.suffix().unwrap(), "");
        assert_eq!(CompressionCodec::Gzip.suffix().unwrap(), ".gz");

        for (codec, name) in [
            (CompressionCodec::Lz4, "Lz4"),
            (CompressionCodec::Zstd, "Zstd"),
        ] {
            let err = codec
                .suffix()
                .expect_err("suffix should be undefined for this codec");
            assert!(err
                .to_string()
                .contains(&format!("suffix not defined for {name}")));
        }
    }
}