1use std::fmt;
21use std::str::FromStr;
22
23use aes_gcm::aead::generic_array::typenum::U12;
24use aes_gcm::aead::rand_core::RngCore;
25use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng, Payload};
26use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, Nonce};
27use zeroize::Zeroizing;
28
/// AES-192 in GCM mode with a 12-byte (96-bit) nonce (`U12`), matching
/// `AesGcmCipher::NONCE_LEN`. The 128- and 256-bit variants are imported as ready-made
/// aliases from `aes_gcm`; the 192-bit one is assembled here from the `Aes192` block cipher.
type Aes192Gcm = AesGcm<aes_gcm::aes::Aes192, U12>;
32
33use crate::{Error, ErrorKind, Result};
34
35#[derive(Clone, PartialEq, Eq)]
46pub struct SensitiveBytes(Zeroizing<Box<[u8]>>);
47
48impl SensitiveBytes {
49 pub fn new(bytes: impl Into<Box<[u8]>>) -> Self {
51 Self(Zeroizing::new(bytes.into()))
52 }
53
54 pub fn as_bytes(&self) -> &[u8] {
56 &self.0
57 }
58
59 pub fn len(&self) -> usize {
61 self.0.len()
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.0.is_empty()
67 }
68}
69
70impl fmt::Debug for SensitiveBytes {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 write!(f, "[{} bytes REDACTED]", self.0.len())
73 }
74}
75
76impl fmt::Display for SensitiveBytes {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 write!(f, "[{} bytes REDACTED]", self.0.len())
79 }
80}
81
/// AES key size accepted by [`AesGcmCipher`].
///
/// Each variant's discriminant is the key size in bits. `Bits128` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AesKeySize {
    /// AES-128: 16-byte key.
    #[default]
    Bits128 = 128,
    /// AES-192: 24-byte key.
    Bits192 = 192,
    /// AES-256: 32-byte key.
    Bits256 = 256,
}
96
97impl AesKeySize {
98 pub fn key_length(&self) -> usize {
100 match self {
101 Self::Bits128 => 16,
102 Self::Bits192 => 24,
103 Self::Bits256 => 32,
104 }
105 }
106
107 pub fn from_key_length(len: usize) -> Result<Self> {
112 match len {
113 16 => Ok(Self::Bits128),
114 24 => Ok(Self::Bits192),
115 32 => Ok(Self::Bits256),
116 _ => Err(Error::new(
117 ErrorKind::FeatureUnsupported,
118 format!("Unsupported data key length: {len} (must be 16, 24, or 32)"),
119 )),
120 }
121 }
122}
123
124impl FromStr for AesKeySize {
125 type Err = Error;
126
127 fn from_str(s: &str) -> Result<Self> {
128 match s {
129 "128" | "AES_GCM_128" | "AES128_GCM" => Ok(Self::Bits128),
130 "192" | "AES_GCM_192" | "AES192_GCM" => Ok(Self::Bits192),
131 "256" | "AES_GCM_256" | "AES256_GCM" => Ok(Self::Bits256),
132 _ => Err(Error::new(
133 ErrorKind::FeatureUnsupported,
134 format!("Unsupported AES key size: {s}"),
135 )),
136 }
137 }
138}
139
140pub struct SecureKey {
142 key: SensitiveBytes,
143 key_size: AesKeySize,
144}
145
146impl SecureKey {
147 pub fn new(key: &[u8]) -> Result<Self> {
152 let key_size = AesKeySize::from_key_length(key.len())?;
153 Ok(Self {
154 key: SensitiveBytes::new(key),
155 key_size,
156 })
157 }
158
159 pub fn generate(key_size: AesKeySize) -> Self {
161 let mut key = vec![0u8; key_size.key_length()];
162 OsRng.fill_bytes(&mut key);
163 Self {
164 key: SensitiveBytes::new(key),
165 key_size,
166 }
167 }
168
169 pub fn key_size(&self) -> AesKeySize {
171 self.key_size
172 }
173
174 pub fn as_bytes(&self) -> &[u8] {
176 self.key.as_bytes()
177 }
178}
179
180impl TryFrom<SensitiveBytes> for SecureKey {
181 type Error = Error;
182
183 fn try_from(key: SensitiveBytes) -> Result<Self> {
184 let key_size = AesKeySize::from_key_length(key.len())?;
185 Ok(Self { key, key_size })
186 }
187}
188
189pub struct AesGcmCipher {
191 key: SensitiveBytes,
192 key_size: AesKeySize,
193}
194
195impl AesGcmCipher {
196 pub const NONCE_LEN: usize = 12;
198 pub const TAG_LEN: usize = 16;
200
201 pub fn new(key: SecureKey) -> Self {
203 Self {
204 key: SensitiveBytes::new(key.as_bytes()),
205 key_size: key.key_size(),
206 }
207 }
208
209 pub fn encrypt(&self, plaintext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>> {
219 match self.key_size {
220 AesKeySize::Bits128 => {
221 encrypt_aes_gcm::<Aes128Gcm>(self.key.as_bytes(), plaintext, aad)
222 }
223 AesKeySize::Bits192 => {
224 encrypt_aes_gcm::<Aes192Gcm>(self.key.as_bytes(), plaintext, aad)
225 }
226 AesKeySize::Bits256 => {
227 encrypt_aes_gcm::<Aes256Gcm>(self.key.as_bytes(), plaintext, aad)
228 }
229 }
230 }
231
232 pub fn decrypt(&self, ciphertext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>> {
241 if ciphertext.len() < Self::NONCE_LEN + Self::TAG_LEN {
242 return Err(Error::new(
243 ErrorKind::DataInvalid,
244 format!(
245 "Ciphertext too short: expected at least {} bytes, got {}",
246 Self::NONCE_LEN + Self::TAG_LEN,
247 ciphertext.len()
248 ),
249 ));
250 }
251
252 match self.key_size {
253 AesKeySize::Bits128 => {
254 decrypt_aes_gcm::<Aes128Gcm>(self.key.as_bytes(), ciphertext, aad)
255 }
256 AesKeySize::Bits192 => {
257 decrypt_aes_gcm::<Aes192Gcm>(self.key.as_bytes(), ciphertext, aad)
258 }
259 AesKeySize::Bits256 => {
260 decrypt_aes_gcm::<Aes256Gcm>(self.key.as_bytes(), ciphertext, aad)
261 }
262 }
263 }
264}
265
266fn encrypt_aes_gcm<C>(key_bytes: &[u8], plaintext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>>
267where C: Aead + AeadCore + KeyInit {
268 let cipher = C::new_from_slice(key_bytes).map_err(|e| {
269 Error::new(ErrorKind::DataInvalid, "Invalid AES key").with_source(anyhow::anyhow!(e))
270 })?;
271 let nonce = C::generate_nonce(&mut OsRng);
272
273 let ciphertext = if let Some(aad) = aad {
274 cipher.encrypt(&nonce, Payload {
275 msg: plaintext,
276 aad,
277 })
278 } else {
279 cipher.encrypt(&nonce, plaintext.as_ref())
280 }
281 .map_err(|e| {
282 Error::new(ErrorKind::Unexpected, "AES-GCM encryption failed")
283 .with_source(anyhow::anyhow!(e))
284 })?;
285
286 let mut result = Vec::with_capacity(nonce.len() + ciphertext.len());
288 result.extend_from_slice(&nonce);
289 result.extend_from_slice(&ciphertext);
290 Ok(result)
291}
292
293fn decrypt_aes_gcm<C>(key_bytes: &[u8], ciphertext: &[u8], aad: Option<&[u8]>) -> Result<Vec<u8>>
294where C: Aead + AeadCore + KeyInit {
295 let cipher = C::new_from_slice(key_bytes).map_err(|e| {
296 Error::new(ErrorKind::DataInvalid, "Invalid AES key").with_source(anyhow::anyhow!(e))
297 })?;
298
299 let nonce = Nonce::from_slice(&ciphertext[..AesGcmCipher::NONCE_LEN]);
300 let encrypted_data = &ciphertext[AesGcmCipher::NONCE_LEN..];
301
302 let plaintext = if let Some(aad) = aad {
303 cipher.decrypt(nonce, Payload {
304 msg: encrypted_data,
305 aad,
306 })
307 } else {
308 cipher.decrypt(nonce, encrypted_data)
309 }
310 .map_err(|e| {
311 Error::new(ErrorKind::Unexpected, "AES-GCM decryption failed")
312 .with_source(anyhow::anyhow!(e))
313 })?;
314
315 Ok(plaintext)
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    const GREETING: &[u8] = b"Hello, Iceberg encryption!";
    const AAD: &[u8] = b"additional authenticated data";

    /// Round-trips `GREETING` through a cipher with a fresh key of `key_size`, first
    /// without and then with AAD, and checks that decrypting with the wrong AAD fails.
    fn assert_roundtrip(key_size: AesKeySize) {
        let cipher = AesGcmCipher::new(SecureKey::generate(key_size));

        let sealed = cipher.encrypt(GREETING, None).unwrap();
        assert_eq!(cipher.decrypt(&sealed, None).unwrap(), GREETING);

        let sealed = cipher.encrypt(GREETING, Some(AAD)).unwrap();
        assert_eq!(cipher.decrypt(&sealed, Some(AAD)).unwrap(), GREETING);

        assert!(cipher.decrypt(&sealed, Some(b"wrong aad")).is_err());
    }

    #[test]
    fn test_aes_key_size() {
        let table = [
            (AesKeySize::Bits128, 16),
            (AesKeySize::Bits192, 24),
            (AesKeySize::Bits256, 32),
        ];
        for (size, len) in table {
            assert_eq!(size.key_length(), len);
            assert_eq!(AesKeySize::from_key_length(len).unwrap(), size);
        }
        assert!(AesKeySize::from_key_length(8).is_err());

        let names = [
            ("128", AesKeySize::Bits128),
            ("AES_GCM_128", AesKeySize::Bits128),
            ("AES_GCM_256", AesKeySize::Bits256),
        ];
        for (name, expected) in names {
            assert_eq!(AesKeySize::from_str(name).unwrap(), expected);
        }
        assert!(AesKeySize::from_str("INVALID").is_err());
    }

    #[test]
    fn test_secure_key() {
        let generated = SecureKey::generate(AesKeySize::Bits128);
        assert_eq!(generated.as_bytes().len(), 16);
        assert_eq!(generated.key_size(), AesKeySize::Bits128);

        assert!(SecureKey::new(&[0u8; 16]).is_ok());
        assert!(SecureKey::new(&[0u8; 33]).is_err());
    }

    #[test]
    fn test_aes128_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits128);

        // Also check that the output is longer than nonce + plaintext and does not simply
        // echo the plaintext after the nonce.
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let sealed = cipher.encrypt(GREETING, None).unwrap();
        assert!(sealed.len() > GREETING.len() + 12);
        assert_ne!(&sealed[12..], GREETING);
    }

    #[test]
    fn test_aes192_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits192);
    }

    #[test]
    fn test_aes256_gcm_encryption_roundtrip() {
        assert_roundtrip(AesKeySize::Bits256);
    }

    #[test]
    fn test_cross_key_size_incompatibility() {
        let encryptor = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let decryptor = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits256));

        let sealed = encryptor.encrypt(b"Cross-key test", None).unwrap();
        assert!(decryptor.decrypt(&sealed, None).is_err());
    }

    #[test]
    fn test_encryption_with_empty_plaintext() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let sealed = cipher.encrypt(b"", None).unwrap();
        // Nothing but the 12-byte nonce and the 16-byte tag.
        assert_eq!(sealed.len(), 12 + 16);
        assert_eq!(cipher.decrypt(&sealed, None).unwrap(), b"");
    }

    #[test]
    fn test_decryption_with_tampered_ciphertext() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let mut sealed = cipher.encrypt(b"Sensitive data", None).unwrap();

        // Flip the first byte after the nonce.
        if let Some(first_encrypted_byte) = sealed.get_mut(12) {
            *first_encrypted_byte ^= 0xFF;
        }

        assert!(cipher.decrypt(&sealed, None).is_err());
    }

    #[test]
    fn test_different_keys_produce_different_ciphertexts() {
        let first = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));
        let second = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let message = b"Same plaintext";
        let sealed_first = first.encrypt(message, None).unwrap();
        let sealed_second = second.encrypt(message, None).unwrap();

        assert_ne!(&sealed_first[12..], &sealed_second[12..]);
    }

    #[test]
    fn test_ciphertext_format_java_compatible() {
        let cipher = AesGcmCipher::new(SecureKey::generate(AesKeySize::Bits128));

        let message = b"Test data";
        let sealed = cipher.encrypt(message, None).unwrap();

        assert_eq!(
            sealed.len(),
            12 + message.len() + 16,
            "Ciphertext should be nonce + plaintext + tag length"
        );

        let (nonce, encrypted_with_tag) = sealed.split_at(12);
        assert_eq!(nonce.len(), 12, "Nonce should be 12 bytes");
        assert_eq!(
            encrypted_with_tag.len(),
            message.len() + 16,
            "Encrypted portion should be plaintext length + 16-byte tag"
        );
    }
}