1use std::fmt;
21use std::str::FromStr;
22
23use aes_gcm::aead::generic_array::typenum::U12;
24use aes_gcm::aead::rand_core::RngCore;
25use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng, Payload};
26use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, Nonce};
27use zeroize::Zeroizing;
28
29type Aes192Gcm = AesGcm<aes_gcm::aes::Aes192, U12>;
32
33use crate::{Error, ErrorKind, Result};
34
35#[derive(Clone, PartialEq, Eq)]
46struct SensitiveBytes(Zeroizing<Box<[u8]>>);
47
48impl SensitiveBytes {
49 pub fn new(bytes: impl Into<Box<[u8]>>) -> Self {
51 Self(Zeroizing::new(bytes.into()))
52 }
53
54 pub fn as_bytes(&self) -> &[u8] {
56 &self.0
57 }
58
59 #[allow(dead_code)] pub fn len(&self) -> usize {
62 self.0.len()
63 }
64
65 #[allow(dead_code)] pub fn is_empty(&self) -> bool {
68 self.0.is_empty()
69 }
70}
71
72impl fmt::Debug for SensitiveBytes {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "[{} bytes REDACTED]", self.0.len())
75 }
76}
77
78impl fmt::Display for SensitiveBytes {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "[{} bytes REDACTED]", self.0.len())
81 }
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq)]
89pub enum AesKeySize {
90 Bits128 = 128,
92 Bits192 = 192,
94 Bits256 = 256,
96}
97
98impl AesKeySize {
99 pub fn key_length(&self) -> usize {
101 match self {
102 Self::Bits128 => 16,
103 Self::Bits192 => 24,
104 Self::Bits256 => 32,
105 }
106 }
107
108 pub fn from_key_length(len: usize) -> Result<Self> {
113 match len {
114 16 => Ok(Self::Bits128),
115 24 => Ok(Self::Bits192),
116 32 => Ok(Self::Bits256),
117 _ => Err(Error::new(
118 ErrorKind::FeatureUnsupported,
119 format!("Unsupported data key length: {len} (must be 16, 24, or 32)"),
120 )),
121 }
122 }
123}
124
125impl FromStr for AesKeySize {
126 type Err = Error;
127
128 fn from_str(s: &str) -> Result<Self> {
129 match s {
130 "128" | "AES_GCM_128" | "AES128_GCM" => Ok(Self::Bits128),
131 "192" | "AES_GCM_192" | "AES192_GCM" => Ok(Self::Bits192),
132 "256" | "AES_GCM_256" | "AES256_GCM" => Ok(Self::Bits256),
133 _ => Err(Error::new(
134 ErrorKind::FeatureUnsupported,
135 format!("Unsupported AES key size: {s}"),
136 )),
137 }
138 }
139}
140
141pub struct SecureKey {
143 key: SensitiveBytes,
144 key_size: AesKeySize,
145}
146
147impl SecureKey {
148 pub fn new(key: &[u8]) -> Result<Self> {
153 let key_size = AesKeySize::from_key_length(key.len())?;
154 Ok(Self {
155 key: SensitiveBytes::new(key),
156 key_size,
157 })
158 }
159
160 pub fn generate(key_size: AesKeySize) -> Self {
162 let mut key = vec![0u8; key_size.key_length()];
163 OsRng.fill_bytes(&mut key);
164 Self {
165 key: SensitiveBytes::new(key),
166 key_size,
167 }
168 }
169
170 pub fn key_size(&self) -> AesKeySize {
172 self.key_size
173 }
174
175 pub fn as_bytes(&self) -> &[u8] {
177 self.key.as_bytes()
178 }
179}
180
181pub struct AesGcmCipher {
183 key: SensitiveBytes,
184 key_size: AesKeySize,
185}
186
187impl AesGcmCipher {
188 pub const NONCE_LEN: usize = 12;
190 pub const TAG_LEN: usize = 16;
192
193 pub fn new(key: SecureKey) -> Self {
195 Self {
196 key: SensitiveBytes::new(key.as_bytes()),
197 key_size: key.key_size(),
198 }
199 }
200
201 pub fn encrypt(&self, plaintext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>> {
211 match self.key_size {
212 AesKeySize::Bits128 => {
213 encrypt_aes_gcm::<Aes128Gcm>(self.key.as_bytes(), plaintext, aad)
214 }
215 AesKeySize::Bits192 => {
216 encrypt_aes_gcm::<Aes192Gcm>(self.key.as_bytes(), plaintext, aad)
217 }
218 AesKeySize::Bits256 => {
219 encrypt_aes_gcm::<Aes256Gcm>(self.key.as_bytes(), plaintext, aad)
220 }
221 }
222 }
223
224 pub fn decrypt(&self, ciphertext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>> {
233 if ciphertext.len() < Self::NONCE_LEN + Self::TAG_LEN {
234 return Err(Error::new(
235 ErrorKind::DataInvalid,
236 format!(
237 "Ciphertext too short: expected at least {} bytes, got {}",
238 Self::NONCE_LEN + Self::TAG_LEN,
239 ciphertext.len()
240 ),
241 ));
242 }
243
244 match self.key_size {
245 AesKeySize::Bits128 => {
246 decrypt_aes_gcm::<Aes128Gcm>(self.key.as_bytes(), ciphertext, aad)
247 }
248 AesKeySize::Bits192 => {
249 decrypt_aes_gcm::<Aes192Gcm>(self.key.as_bytes(), ciphertext, aad)
250 }
251 AesKeySize::Bits256 => {
252 decrypt_aes_gcm::<Aes256Gcm>(self.key.as_bytes(), ciphertext, aad)
253 }
254 }
255 }
256}
257
258fn encrypt_aes_gcm<C>(key_bytes: &[u8], plaintext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>>
259where C: Aead + AeadCore + KeyInit {
260 let cipher = C::new_from_slice(key_bytes).map_err(|e| {
261 Error::new(ErrorKind::DataInvalid, "Invalid AES key").with_source(anyhow::anyhow!(e))
262 })?;
263 let nonce = C::generate_nonce(&mut OsRng);
264
265 let ciphertext = if let Some(aad) = aad {
266 cipher.encrypt(&nonce, Payload {
267 msg: plaintext,
268 aad,
269 })
270 } else {
271 cipher.encrypt(&nonce, plaintext.as_ref())
272 }
273 .map_err(|e| {
274 Error::new(ErrorKind::Unexpected, "AES-GCM encryption failed")
275 .with_source(anyhow::anyhow!(e))
276 })?;
277
278 let mut result = Vec::with_capacity(nonce.len() + ciphertext.len());
280 result.extend_from_slice(&nonce);
281 result.extend_from_slice(&ciphertext);
282 Ok(result)
283}
284
285fn decrypt_aes_gcm<C>(key_bytes: &[u8], ciphertext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>>
286where C: Aead + AeadCore + KeyInit {
287 let cipher = C::new_from_slice(key_bytes).map_err(|e| {
288 Error::new(ErrorKind::DataInvalid, "Invalid AES key").with_source(anyhow::anyhow!(e))
289 })?;
290
291 let nonce = Nonce::from_slice(&ciphertext[..AesGcmCipher::NONCE_LEN]);
292 let encrypted_data = &ciphertext[AesGcmCipher::NONCE_LEN..];
293
294 let plaintext = if let Some(aad) = aad {
295 cipher.decrypt(nonce, Payload {
296 msg: encrypted_data,
297 aad,
298 })
299 } else {
300 cipher.decrypt(nonce, encrypted_data)
301 }
302 .map_err(|e| {
303 Error::new(ErrorKind::Unexpected, "AES-GCM decryption failed")
304 .with_source(anyhow::anyhow!(e))
305 })?;
306
307 Ok(plaintext)
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a message with and without AAD, and checks that a wrong AAD is rejected.
    fn assert_roundtrip(key_size: AesKeySize) {
        let cipher = AesGcmCipher::new(SecureKey::generate(key_size));
        let plaintext = b"Hello, Iceberg encryption!";
        let aad = b"additional authenticated data";

        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();
        assert_eq!(decrypted, plaintext);

        let ciphertext = cipher.encrypt(plaintext, Some(aad)).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, Some(aad)).unwrap();
        assert_eq!(decrypted, plaintext);

        assert!(cipher.decrypt(&ciphertext, Some(b"wrong aad")).is_err());
    }

    #[test]
    fn test_aes_key_size() {
        assert_eq!(AesKeySize::Bits128.key_length(), 16);
        assert_eq!(AesKeySize::Bits192.key_length(), 24);
        assert_eq!(AesKeySize::Bits256.key_length(), 32);

        assert_eq!(
            AesKeySize::from_key_length(16).unwrap(),
            AesKeySize::Bits128
        );
        assert_eq!(
            AesKeySize::from_key_length(24).unwrap(),
            AesKeySize::Bits192
        );
        assert_eq!(
            AesKeySize::from_key_length(32).unwrap(),
            AesKeySize::Bits256
        );
        assert!(AesKeySize::from_key_length(8).is_err());
        assert!(AesKeySize::from_key_length(0).is_err());

        assert_eq!(AesKeySize::from_str("128").unwrap(), AesKeySize::Bits128);
        assert_eq!(
            AesKeySize::from_str("AES_GCM_128").unwrap(),
            AesKeySize::Bits128
        );
        assert_eq!(
            AesKeySize::from_str("AES128_GCM").unwrap(),
            AesKeySize::Bits128
        );
        assert_eq!(AesKeySize::from_str("192").unwrap(), AesKeySize::Bits192);
        assert_eq!(
            AesKeySize::from_str("AES_GCM_256").unwrap(),
            AesKeySize::Bits256
        );
        assert!(AesKeySize::from_str("INVALID").is_err());
    }

    #[test]
    fn test_secure_key() {
        let key1 = SecureKey::generate(AesKeySize::Bits128);
        assert_eq!(key1.as_bytes().len(), 16);
        assert_eq!(key1.key_size(), AesKeySize::Bits128);

        assert_eq!(SecureKey::generate(AesKeySize::Bits192).as_bytes().len(), 24);
        assert_eq!(SecureKey::generate(AesKeySize::Bits256).as_bytes().len(), 32);

        let valid_key = [0u8; 16];
        assert!(SecureKey::new(valid_key.as_slice()).is_ok());

        let invalid_key = [0u8; 33];
        assert!(SecureKey::new(invalid_key.as_slice()).is_err());
    }

    #[test]
    fn test_sensitive_bytes_redaction_and_equality() {
        let secret = SensitiveBytes::new(vec![1u8, 2, 3]);
        assert_eq!(format!("{secret:?}"), "[3 bytes REDACTED]");
        assert_eq!(format!("{secret}"), "[3 bytes REDACTED]");

        assert!(secret == SensitiveBytes::new(vec![1u8, 2, 3]));
        assert!(secret != SensitiveBytes::new(vec![1u8, 2, 4]));
        assert!(secret != SensitiveBytes::new(vec![1u8, 2]));
    }

    #[test]
    fn test_aes128_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits128);

        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let plaintext = b"Hello, Iceberg encryption!";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();

        // The encrypted payload follows the nonce and must not echo the plaintext.
        assert!(ciphertext.len() > plaintext.len() + AesGcmCipher::NONCE_LEN);
        assert_ne!(&ciphertext[AesGcmCipher::NONCE_LEN..], plaintext);
    }

    #[test]
    fn test_aes192_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits192);
    }

    #[test]
    fn test_aes256_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits256);
    }

    #[test]
    fn test_cross_key_size_incompatibility() {
        let plaintext = b"Cross-key test";

        let cipher128 = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let cipher256 = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits256));

        let ciphertext = cipher128.encrypt(plaintext, None).unwrap();
        assert!(cipher256.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_encryption_with_empty_plaintext() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let plaintext = b"";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();

        assert_eq!(
            ciphertext.len(),
            AesGcmCipher::NONCE_LEN + AesGcmCipher::TAG_LEN
        );
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decryption_with_tampered_ciphertext() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let ciphertext = cipher.encrypt(b"Sensitive data", None).unwrap();

        // Flip bits in the first encrypted byte, the last tag byte, and the first nonce byte.
        for index in [AesGcmCipher::NONCE_LEN, ciphertext.len() - 1, 0] {
            let mut tampered = ciphertext.clone();
            tampered[index] ^= 0xFF;
            assert!(
                cipher.decrypt(&tampered, None).is_err(),
                "tampering byte {index} must be detected"
            );
        }
    }

    #[test]
    fn test_decrypt_rejects_short_ciphertext() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        assert!(cipher.decrypt(&[], None).is_err());
        let just_too_short = [0u8; AesGcmCipher::NONCE_LEN + AesGcmCipher::TAG_LEN - 1];
        assert!(cipher.decrypt(&just_too_short, None).is_err());
    }

    #[test]
    fn test_same_key_uses_fresh_nonce_per_encryption() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let first = cipher.encrypt(b"Same plaintext", None).unwrap();
        let second = cipher.encrypt(b"Same plaintext", None).unwrap();

        assert_ne!(
            &first[..AesGcmCipher::NONCE_LEN],
            &second[..AesGcmCipher::NONCE_LEN]
        );
        assert_eq!(
            cipher.decrypt(&first, None).unwrap(),
            cipher.decrypt(&second, None).unwrap()
        );
    }

    #[test]
    fn test_ciphers_from_same_key_bytes_interoperate() {
        let key_bytes = [7u8; 32];
        let encryptor = AesGcmCipher::new(SecureKey::new(&key_bytes).unwrap());
        let decryptor = AesGcmCipher::new(SecureKey::new(&key_bytes).unwrap());

        let ciphertext = encryptor.encrypt(b"shared key", Some(b"aad")).unwrap();
        assert_eq!(
            decryptor.decrypt(&ciphertext, Some(b"aad")).unwrap(),
            b"shared key"
        );
    }

    #[test]
    fn test_different_keys_produce_different_ciphertexts() {
        let cipher1 = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let cipher2 = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let plaintext = b"Same plaintext";

        let ciphertext1 = cipher1.encrypt(plaintext, None).unwrap();
        let ciphertext2 = cipher2.encrypt(plaintext, None).unwrap();

        assert_ne!(
            &ciphertext1[AesGcmCipher::NONCE_LEN..],
            &ciphertext2[AesGcmCipher::NONCE_LEN..]
        );
    }

    #[test]
    fn test_ciphertext_format_java_compatible() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let plaintext = b"Test data";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();

        assert_eq!(
            ciphertext.len(),
            12 + plaintext.len() + 16,
            "Ciphertext should be nonce + plaintext + tag length"
        );

        let (nonce, encrypted_with_tag) = ciphertext.split_at(12);
        assert_eq!(nonce.len(), 12, "Nonce should be 12 bytes");
        assert_eq!(
            encrypted_with_tag.len(),
            plaintext.len() + 16,
            "Encrypted portion should be plaintext length + 16-byte tag"
        );
    }
}