iceberg/encryption/
file_decryptor.rs1use std::fmt;
21use std::sync::Arc;
22
23use super::crypto::{AesGcmCipher, SecureKey};
24use super::stream::AesGcmFileRead;
25use crate::Result;
26use crate::io::FileRead;
27
28pub struct AesGcmFileDecryptor {
33 cipher: Arc<AesGcmCipher>,
34 aad_prefix: Box<[u8]>,
35}
36
37impl AesGcmFileDecryptor {
38 pub fn new(dek: &[u8], aad_prefix: impl Into<Box<[u8]>>) -> Result<Self> {
40 let key = SecureKey::new(dek)?;
41 let cipher = Arc::new(AesGcmCipher::new(key));
42 Ok(Self {
43 cipher,
44 aad_prefix: aad_prefix.into(),
45 })
46 }
47
48 pub fn wrap_reader(
50 &self,
51 reader: Box<dyn FileRead>,
52 encrypted_file_length: u64,
53 ) -> Result<Box<dyn FileRead>> {
54 let decrypting = AesGcmFileRead::new(
55 reader,
56 Arc::clone(&self.cipher),
57 self.aad_prefix.clone(),
58 encrypted_file_length,
59 )?;
60 Ok(Box::new(decrypting))
61 }
62
63 pub fn plaintext_length(&self, encrypted_file_length: u64) -> Result<u64> {
65 AesGcmFileRead::calculate_plaintext_length(encrypted_file_length)
66 }
67}
68
69impl fmt::Debug for AesGcmFileDecryptor {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("AesGcmFileDecryptor")
72 .field("aad_prefix_len", &self.aad_prefix.len())
73 .finish_non_exhaustive()
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;

    use super::*;
    use crate::encryption::AesGcmFileEncryptor;
    use crate::io::FileWrite;

    /// 16-byte key accepted by both the encryptor and the decryptor in these tests.
    const KEY: &[u8] = b"0123456789abcdef";

    /// In-memory [`FileRead`] that serves byte ranges out of a [`Bytes`] buffer.
    struct MemoryFileRead(Bytes);

    #[async_trait::async_trait]
    impl FileRead for MemoryFileRead {
        async fn read(&self, range: Range<u64>) -> Result<Bytes> {
            let Range { start, end } = range;
            Ok(self.0.slice(start as usize..end as usize))
        }
    }

    /// In-memory [`FileWrite`] that appends everything written to a shared buffer,
    /// so the test can inspect the bytes after the writer has been consumed.
    struct MemoryFileWrite {
        buffer: Arc<Mutex<Vec<u8>>>,
    }

    #[async_trait::async_trait]
    impl FileWrite for MemoryFileWrite {
        async fn write(&mut self, bs: Bytes) -> Result<()> {
            let mut sink = self.buffer.lock().unwrap();
            sink.extend_from_slice(&bs);
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Encrypts a short message with `AesGcmFileEncryptor`, then reads it back
    /// through `AesGcmFileDecryptor::wrap_reader` and expects the original bytes.
    #[tokio::test]
    async fn test_wrap_reader_roundtrip() {
        let aad_prefix: &[u8] = b"test-aad-prefix!";
        let plaintext: &[u8] = b"Hello from file decryptor!";

        // Encrypt into an in-memory buffer.
        let encryptor = AesGcmFileEncryptor::new(KEY, aad_prefix).unwrap();
        let sink = Arc::new(Mutex::new(Vec::new()));
        let mut writer = encryptor.wrap_writer(Box::new(MemoryFileWrite {
            buffer: Arc::clone(&sink),
        }));
        writer.write(Bytes::copy_from_slice(plaintext)).await.unwrap();
        writer.close().await.unwrap();

        let encrypted = Bytes::from(sink.lock().unwrap().clone());
        let encrypted_len = encrypted.len() as u64;

        // Decrypt through a wrapped reader built with the same key and AAD prefix.
        let decryptor = AesGcmFileDecryptor::new(KEY, aad_prefix).unwrap();
        let reader = decryptor
            .wrap_reader(Box::new(MemoryFileRead(encrypted)), encrypted_len)
            .unwrap();

        let decrypted = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&decrypted[..], plaintext);
    }

    /// A 9-byte key must be rejected at construction time.
    #[tokio::test]
    async fn test_invalid_key_length() {
        let outcome = AesGcmFileDecryptor::new(b"too-short", b"aad".as_slice());
        assert!(outcome.is_err());
    }

    /// An encrypted length of 46 bytes is expected to map to 10 plaintext bytes.
    #[tokio::test]
    async fn test_plaintext_length() {
        let decryptor = AesGcmFileDecryptor::new(KEY, b"aad".as_slice()).unwrap();
        // NOTE(review): presumably header (8) + nonce (12) + ciphertext (10) + tag (16);
        // confirm against the stream format definition.
        let encrypted_len: u64 = 8 + 12 + 10 + 16;
        assert_eq!(decryptor.plaintext_length(encrypted_len).unwrap(), 10);
    }
}