iceberg/encryption/
file_encryptor.rs1use std::fmt;
21use std::sync::Arc;
22
23use super::crypto::{AesGcmCipher, SecureKey};
24use super::stream::AesGcmFileWrite;
25use crate::Result;
26use crate::io::FileWrite;
27
28pub struct AesGcmFileEncryptor {
35 cipher: Arc<AesGcmCipher>,
36 aad_prefix: Box<[u8]>,
37}
38
39impl AesGcmFileEncryptor {
40 pub fn new(dek: &[u8], aad_prefix: impl Into<Box<[u8]>>) -> Result<Self> {
42 let key = SecureKey::new(dek)?;
43 let cipher = Arc::new(AesGcmCipher::new(key));
44 Ok(Self {
45 cipher,
46 aad_prefix: aad_prefix.into(),
47 })
48 }
49
50 pub fn wrap_writer(&self, writer: Box<dyn FileWrite>) -> Box<dyn FileWrite> {
52 Box::new(AesGcmFileWrite::new(
53 writer,
54 Arc::clone(&self.cipher),
55 self.aad_prefix.clone(),
56 ))
57 }
58}
59
60impl fmt::Debug for AesGcmFileEncryptor {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("AesGcmFileEncryptor")
63 .field("aad_prefix_len", &self.aad_prefix.len())
64 .finish_non_exhaustive()
65 }
66}
67
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;

    use super::*;
    use crate::encryption::AesGcmFileDecryptor;
    use crate::io::FileRead;

    /// In-memory [`FileRead`] that serves ranges out of a single `Bytes` buffer.
    struct MemoryFileRead(Bytes);

    #[async_trait::async_trait]
    impl FileRead for MemoryFileRead {
        async fn read(&self, range: Range<u64>) -> Result<Bytes> {
            let start = range.start as usize;
            let end = range.end as usize;
            Ok(self.0.slice(start..end))
        }
    }

    /// In-memory [`FileWrite`] that appends every write to a shared buffer.
    ///
    /// The buffer sits behind `Arc<Mutex<..>>` so the test can still read the
    /// written bytes after the writer itself has been moved into `wrap_writer`.
    struct MemoryFileWrite {
        buffer: Arc<Mutex<Vec<u8>>>,
    }

    #[async_trait::async_trait]
    impl FileWrite for MemoryFileWrite {
        async fn write(&mut self, bs: Bytes) -> Result<()> {
            self.buffer.lock().unwrap().extend_from_slice(&bs[..]);
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Writes through `wrap_writer`, reads back through the matching
    /// decryptor, and expects the original plaintext.
    #[tokio::test]
    async fn test_wrap_writer_roundtrip() {
        let key = b"0123456789abcdef";
        let aad_prefix = b"test-aad-prefix!";
        let plaintext = b"Hello from file encryptor!";

        // Encrypt into the shared in-memory sink.
        let sink = Arc::new(Mutex::new(Vec::new()));
        let encryptor = AesGcmFileEncryptor::new(&key[..], &aad_prefix[..]).unwrap();
        let inner = MemoryFileWrite {
            buffer: Arc::clone(&sink),
        };
        let mut writer = encryptor.wrap_writer(Box::new(inner));
        writer.write(Bytes::from(plaintext.to_vec())).await.unwrap();
        writer.close().await.unwrap();

        let encrypted = sink.lock().unwrap().clone();
        let encrypted_len = encrypted.len() as u64;

        // Decrypt with the same key and AAD prefix.
        let decryptor = AesGcmFileDecryptor::new(&key[..], &aad_prefix[..]).unwrap();
        let source = MemoryFileRead(Bytes::from(encrypted));
        let reader = decryptor
            .wrap_reader(Box::new(source), encrypted_len)
            .unwrap();

        let plaintext_len = plaintext.len() as u64;
        let result = reader.read(0..plaintext_len).await.unwrap();
        assert_eq!(&result[..], plaintext);
    }

    /// A 7-byte key must be rejected by the constructor.
    #[tokio::test]
    async fn test_invalid_key_length() {
        let outcome = AesGcmFileEncryptor::new(b"bad-key", b"aad".as_slice());
        assert!(outcome.is_err());
    }
}