1use std::collections::HashMap;
27use std::fmt;
28use std::sync::{Arc, RwLock};
29use std::time::Duration;
30
31use aes_gcm::aead::OsRng;
32use aes_gcm::aead::rand_core::RngCore;
33use chrono::Utc;
34use moka::future::Cache;
35use uuid::Uuid;
36
/// Number of milliseconds in one day (24 h x 60 min x 60 s x 1000 ms).
const MILLIS_IN_DAY: i64 = 86_400_000;
38
39use super::crypto::{AesGcmCipher, AesKeySize, SecureKey, SensitiveBytes};
40use super::io::EncryptedOutputFile;
41use super::key_metadata::StandardKeyMetadata;
42use super::kms::KeyManagementClient;
43use crate::io::OutputFile;
44use crate::spec::EncryptedKey;
45use crate::{Error, ErrorKind, Result};
46
/// Property key under which a KEK's creation time is stored.
///
/// The value is the decimal string of a millisecond timestamp; its raw bytes
/// are also used as AAD when key metadata is wrapped with that KEK.
pub const KEK_CREATED_AT_PROPERTY: &str = "KEY_TIMESTAMP";

/// Age, in days, at which a KEK is considered expired.
const DEFAULT_KEK_LIFESPAN_DAYS: i64 = 2 * 365;

/// Time-to-live of entries in the plaintext KEK cache (one hour).
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60 * 60);

/// Length, in bytes, of the random AAD prefix generated per encrypted file.
const AAD_PREFIX_LENGTH: usize = 16;
60
61#[derive(typed_builder::TypedBuilder)]
65#[builder(mutators(
66 pub fn add_encryption_key(&mut self, key: EncryptedKey) {
68 self.encryption_keys
69 .write()
70 .expect("encryption_keys lock poisoned")
71 .insert(key.key_id().to_string(), key);
72 }
73 pub fn encryption_keys(&mut self, keys: HashMap<String, EncryptedKey>) {
75 self.encryption_keys = RwLock::new(keys);
76 }
77))]
78pub struct EncryptionManager {
79 kms_client: Arc<dyn KeyManagementClient>,
80 #[builder(
81 default = Cache::builder().time_to_live(DEFAULT_CACHE_TTL).build(),
82 setter(skip)
83 )]
84 kek_cache: Cache<String, SensitiveBytes>,
85 #[builder(default = AesKeySize::default())]
87 key_size: AesKeySize,
88 #[builder(setter(into))]
90 table_key_id: String,
91 #[builder(default = RwLock::new(HashMap::new()), via_mutators)]
95 encryption_keys: RwLock<HashMap<String, EncryptedKey>>,
96}
97
98impl fmt::Debug for EncryptionManager {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("EncryptionManager")
101 .field("key_size", &self.key_size)
102 .field("table_key_id", &self.table_key_id)
103 .finish_non_exhaustive()
104 }
105}
106
107impl EncryptionManager {
108 pub fn encrypt(&self, raw_output: OutputFile) -> EncryptedOutputFile {
113 let dek = SecureKey::generate(self.key_size);
114 let aad_prefix = Self::generate_aad_prefix();
115 let metadata = StandardKeyMetadata::new(dek.as_bytes()).with_aad_prefix(&aad_prefix);
116 EncryptedOutputFile::new(raw_output, metadata)
117 }
118
119 pub async fn encrypt_manifest_list_key_metadata(
128 &self,
129 key_metadata: &StandardKeyMetadata,
130 ) -> Result<String> {
131 let kek = match self.find_active_kek()? {
132 Some(existing) => existing,
133 None => self.create_kek().await?,
134 };
135
136 let kek_bytes = self.unwrap_key_encryption_key(&kek).await?;
137
138 let aad = Self::kek_timestamp_aad(&kek)?;
140 let serialized = key_metadata.encode()?;
141 let wrapped_metadata = self.wrap_dek_with_kek(&serialized, &kek_bytes, Some(aad))?;
142
143 let wrapped_key = EncryptedKey::builder()
144 .key_id(Uuid::new_v4().to_string())
145 .encrypted_key_metadata(wrapped_metadata)
146 .encrypted_by_id(kek.key_id())
147 .build();
148
149 let wrapped_key_id = wrapped_key.key_id().to_string();
150 self.insert_encryption_key(wrapped_key);
151 Ok(wrapped_key_id)
152 }
153
154 pub async fn decrypt_manifest_list_key_metadata(
160 &self,
161 encryption_key_id: &str,
162 ) -> Result<StandardKeyMetadata> {
163 let encrypted_key = self
164 .encryption_keys
165 .read()
166 .expect("encryption_keys lock poisoned")
167 .get(encryption_key_id)
168 .cloned()
169 .ok_or_else(|| {
170 Error::new(
171 ErrorKind::DataInvalid,
172 format!("Encryption key '{encryption_key_id}' not found"),
173 )
174 })?;
175
176 let kek_key_id = encrypted_key.encrypted_by_id().ok_or_else(|| {
177 Error::new(
178 ErrorKind::DataInvalid,
179 format!(
180 "EncryptedKey '{}' has no encrypted_by_id",
181 encrypted_key.key_id()
182 ),
183 )
184 })?;
185
186 let bytes = self
187 .decrypt_dek(kek_key_id, encrypted_key.encrypted_key_metadata())
188 .await?;
189
190 StandardKeyMetadata::decode(bytes.as_bytes())
191 }
192
193 pub fn with_encryption_keys<F, R>(&self, f: F) -> R
198 where F: FnOnce(&HashMap<String, EncryptedKey>) -> R {
199 let keys = self
200 .encryption_keys
201 .read()
202 .expect("encryption_keys lock poisoned");
203 f(&keys)
204 }
205
206 fn insert_encryption_key(&self, key: EncryptedKey) {
207 self.encryption_keys
208 .write()
209 .expect("encryption_keys lock poisoned")
210 .insert(key.key_id().to_string(), key);
211 }
212
213 async fn create_kek(&self) -> Result<EncryptedKey> {
216 let (plaintext_kek, wrapped_kek) = if self.kms_client.supports_key_generation() {
217 let result = self.kms_client.generate_key(&self.table_key_id).await?;
218 (result.key().clone(), result.wrapped_key().to_vec())
219 } else {
220 let plaintext_key = SecureKey::generate(self.key_size);
221 let wrapped = self
222 .kms_client
223 .wrap_key(plaintext_key.as_bytes(), &self.table_key_id)
224 .await?;
225
226 (SensitiveBytes::new(plaintext_key.as_bytes()), wrapped)
227 };
228
229 let key_id = Uuid::new_v4().to_string();
230 let now_ms = Utc::now().timestamp_millis();
231
232 let mut properties = HashMap::new();
233 properties.insert(KEK_CREATED_AT_PROPERTY.to_string(), now_ms.to_string());
234
235 self.kek_cache.insert(key_id.clone(), plaintext_kek).await;
236
237 let kek = EncryptedKey::builder()
238 .key_id(key_id)
239 .encrypted_key_metadata(wrapped_kek)
240 .encrypted_by_id(&self.table_key_id)
241 .properties(properties)
242 .build();
243
244 self.insert_encryption_key(kek.clone());
245 Ok(kek)
246 }
247
248 fn is_kek_expired(&self, kek: &EncryptedKey) -> bool {
250 let created_at_ms = match kek
251 .properties()
252 .get(KEK_CREATED_AT_PROPERTY)
253 .and_then(|ts| ts.parse::<i64>().ok())
254 {
255 Some(ts) => ts,
256 None => return true, };
258
259 let now_ms = Utc::now().timestamp_millis();
260 let lifespan_ms = DEFAULT_KEK_LIFESPAN_DAYS * MILLIS_IN_DAY;
261 (now_ms - created_at_ms) >= lifespan_ms
262 }
263
264 fn find_active_kek(&self) -> Result<Option<EncryptedKey>> {
266 let keys = self
267 .encryption_keys
268 .read()
269 .expect("encryption_keys lock poisoned");
270 Ok(keys
271 .values()
272 .filter(|kek| {
273 kek.encrypted_by_id()
274 .map(|id| id == self.table_key_id)
275 .unwrap_or(false)
276 && !self.is_kek_expired(kek)
277 })
278 .max_by_key(|kek| {
279 kek.properties()
280 .get(KEK_CREATED_AT_PROPERTY)
281 .and_then(|ts| ts.parse::<i64>().ok())
282 .unwrap_or(0)
283 })
284 .cloned())
285 }
286
287 async fn unwrap_key_encryption_key(&self, kek: &EncryptedKey) -> Result<SensitiveBytes> {
289 let cache_key = kek.key_id().to_string();
290
291 if let Some(cached) = self.kek_cache.get(&cache_key).await {
292 return Ok(cached);
293 }
294
295 let master_key_id = kek.encrypted_by_id().ok_or_else(|| {
296 Error::new(
297 ErrorKind::DataInvalid,
298 format!("KEK '{}' has no encrypted_by_id", kek.key_id()),
299 )
300 })?;
301
302 let plaintext = self
303 .kms_client
304 .unwrap_key(kek.encrypted_key_metadata(), master_key_id)
305 .await?;
306
307 self.kek_cache.insert(cache_key, plaintext.clone()).await;
308
309 Ok(plaintext)
310 }
311
312 async fn decrypt_dek(&self, kek_key_id: &str, wrapped_dek: &[u8]) -> Result<SensitiveBytes> {
315 let kek = self
316 .encryption_keys
317 .read()
318 .expect("encryption_keys lock poisoned")
319 .get(kek_key_id)
320 .cloned()
321 .ok_or_else(|| {
322 Error::new(
323 ErrorKind::DataInvalid,
324 format!("KEK not found in encryption keys: {kek_key_id}"),
325 )
326 })?;
327
328 let aad = Self::kek_timestamp_aad(&kek)?;
330
331 let kek_bytes = self.unwrap_key_encryption_key(&kek).await?;
332 self.unwrap_dek_with_kek(wrapped_dek, &kek_bytes, Some(aad))
333 .map_err(|e| {
334 Error::new(
335 e.kind(),
336 format!("Failed to unwrap key metadata with KEK '{kek_key_id}'"),
337 )
338 .with_source(e)
339 })
340 }
341
342 fn kek_timestamp_aad(kek: &EncryptedKey) -> Result<&[u8]> {
344 kek.properties()
345 .get(KEK_CREATED_AT_PROPERTY)
346 .map(|ts| ts.as_bytes())
347 .ok_or_else(|| {
348 Error::new(
349 ErrorKind::DataInvalid,
350 format!(
351 "KEK '{}' is missing required '{}' property",
352 kek.key_id(),
353 KEK_CREATED_AT_PROPERTY
354 ),
355 )
356 })
357 }
358
359 fn generate_aad_prefix() -> Box<[u8]> {
361 let mut prefix = vec![0u8; AAD_PREFIX_LENGTH];
362 OsRng.fill_bytes(&mut prefix);
363 prefix.into_boxed_slice()
364 }
365
366 fn wrap_dek_with_kek(
368 &self,
369 dek: &[u8],
370 kek: &SensitiveBytes,
371 aad: Option<&[u8]>,
372 ) -> Result<Vec<u8>> {
373 let key = SecureKey::try_from(kek.clone())?;
374 let cipher = AesGcmCipher::new(key);
375 cipher.encrypt(dek, aad)
376 }
377
378 fn unwrap_dek_with_kek(
380 &self,
381 wrapped_dek: &[u8],
382 kek: &SensitiveBytes,
383 aad: Option<&[u8]>,
384 ) -> Result<SensitiveBytes> {
385 let key = SecureKey::try_from(kek.clone())?;
386 let cipher = AesGcmCipher::new(key);
387 cipher.decrypt(wrapped_dek, aad).map(SensitiveBytes::new)
388 }
389}
390
391#[cfg(test)]
392mod tests {
393 use super::*;
394 use crate::encryption::EncryptedInputFile;
395 use crate::encryption::kms::MemoryKeyManagementClient;
396
397 fn create_test_kms() -> Arc<dyn KeyManagementClient> {
398 let kms = MemoryKeyManagementClient::new();
399 kms.add_master_key("master-1").unwrap();
400 Arc::new(kms)
401 }
402
403 fn create_test_manager() -> EncryptionManager {
404 EncryptionManager::builder()
405 .kms_client(create_test_kms())
406 .table_key_id("master-1")
407 .build()
408 }
409
410 #[tokio::test]
411 async fn test_create_kek() {
412 let mgr = create_test_manager();
413 let kek = mgr.create_kek().await.unwrap();
414
415 assert!(!kek.key_id().is_empty());
416 assert!(!kek.encrypted_key_metadata().is_empty());
417 assert_eq!(kek.encrypted_by_id(), Some("master-1"));
418 assert!(kek.properties().contains_key(KEK_CREATED_AT_PROPERTY));
419 }
420
421 fn sample_key_metadata() -> StandardKeyMetadata {
422 StandardKeyMetadata::new(b"0123456789abcdef").with_aad_prefix(b"test-aad-prefix!")
423 }
424
425 #[tokio::test]
426 async fn test_wrap_unwrap_key_metadata_roundtrip() {
427 let mgr = create_test_manager();
428 let plaintext = sample_key_metadata();
429
430 let key_id = mgr
431 .encrypt_manifest_list_key_metadata(&plaintext)
432 .await
433 .unwrap();
434
435 assert_eq!(mgr.with_encryption_keys(|k| k.len()), 2);
437
438 let decrypted = mgr
439 .decrypt_manifest_list_key_metadata(&key_id)
440 .await
441 .unwrap();
442 assert_eq!(decrypted, plaintext);
443 }
444
445 #[tokio::test]
446 async fn test_kek_reuse_when_not_expired() {
447 let mgr = create_test_manager();
448
449 let _id1 = mgr
451 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
452 .await
453 .unwrap();
454 let kek_id = mgr.with_encryption_keys(|keys| {
455 assert_eq!(keys.len(), 2);
456 keys.values()
457 .find(|k| k.encrypted_by_id() == Some("master-1"))
458 .unwrap()
459 .key_id()
460 .to_string()
461 });
462
463 let id2 = mgr
465 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
466 .await
467 .unwrap();
468 let entry2 = mgr.with_encryption_keys(|keys| {
469 assert_eq!(keys.len(), 3);
470 keys.get(&id2).cloned().unwrap()
471 });
472 assert_eq!(entry2.encrypted_by_id(), Some(kek_id.as_str()));
473 }
474
475 #[tokio::test]
476 async fn test_kek_rotation_when_expired() {
477 let kms = create_test_kms();
478
479 let three_years_ago_ms = Utc::now().timestamp_millis() - (3 * 365 * MILLIS_IN_DAY);
481 let mut properties = HashMap::new();
482 properties.insert(
483 KEK_CREATED_AT_PROPERTY.to_string(),
484 three_years_ago_ms.to_string(),
485 );
486
487 let kek_key = SecureKey::generate(AesKeySize::Bits128);
489 let wrapped = kms.wrap_key(kek_key.as_bytes(), "master-1").await.unwrap();
490
491 let old_kek = EncryptedKey::builder()
492 .key_id("expired-kek")
493 .encrypted_key_metadata(wrapped)
494 .encrypted_by_id("master-1")
495 .properties(properties)
496 .build();
497
498 let mgr = EncryptionManager::builder()
500 .kms_client(kms)
501 .table_key_id("master-1")
502 .add_encryption_key(old_kek.clone())
503 .build();
504
505 let new_entry_id = mgr
507 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
508 .await
509 .unwrap();
510 let entry = mgr
511 .with_encryption_keys(|keys| keys.get(&new_entry_id).cloned())
512 .unwrap();
513 let used_kek_id = entry.encrypted_by_id().unwrap();
514 assert_ne!(used_kek_id, old_kek.key_id());
515 }
516
517 #[tokio::test]
518 async fn test_is_kek_expired_no_timestamp() {
519 let mgr = create_test_manager();
520
521 let kek = EncryptedKey::builder()
523 .key_id("no-ts")
524 .encrypted_key_metadata(vec![0u8; 32])
525 .build();
526
527 assert!(mgr.is_kek_expired(&kek));
528 }
529
530 #[tokio::test]
531 async fn test_decrypt_with_unknown_key_id() {
532 let mgr = create_test_manager();
533 let result = mgr.decrypt_manifest_list_key_metadata("nonexistent").await;
534 assert!(result.is_err());
535 }
536
537 #[tokio::test]
538 async fn test_kek_cache_hit() {
539 let mgr = create_test_manager();
540
541 let key_id = mgr
543 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
544 .await
545 .unwrap();
546
547 let _ = mgr
549 .decrypt_manifest_list_key_metadata(&key_id)
550 .await
551 .unwrap();
552 }
553
554 #[tokio::test]
555 async fn test_unwrap_fails_when_kek_missing_timestamp() {
556 let mgr = create_test_manager();
557
558 let entry_id = mgr
560 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
561 .await
562 .unwrap();
563
564 let mut keys = mgr.with_encryption_keys(|k| k.clone());
567 let kek_id = keys
568 .get(&entry_id)
569 .unwrap()
570 .encrypted_by_id()
571 .unwrap()
572 .to_string();
573 let kek = keys.remove(&kek_id).unwrap();
574 let kek_no_ts = EncryptedKey::builder()
575 .key_id(kek.key_id())
576 .encrypted_key_metadata(kek.encrypted_key_metadata())
577 .encrypted_by_id(kek.encrypted_by_id().unwrap())
578 .build();
579 keys.insert(kek_no_ts.key_id().to_string(), kek_no_ts);
580
581 let mgr = EncryptionManager::builder()
582 .kms_client(create_test_kms())
583 .table_key_id("master-1")
584 .encryption_keys(keys)
585 .build();
586
587 let result = mgr.decrypt_manifest_list_key_metadata(&entry_id).await;
588 assert!(result.is_err());
589 let err = result.unwrap_err();
590 assert_eq!(err.kind(), ErrorKind::DataInvalid);
591 assert!(
592 err.to_string().contains(KEK_CREATED_AT_PROPERTY),
593 "error should mention the missing property: {err}"
594 );
595 }
596
597 #[tokio::test]
598 async fn test_unwrap_fails_when_kek_timestamp_tampered() {
599 let mgr = create_test_manager();
600
601 let entry_id = mgr
603 .encrypt_manifest_list_key_metadata(&sample_key_metadata())
604 .await
605 .unwrap();
606
607 let mut keys = mgr.with_encryption_keys(|k| k.clone());
609 let kek_id = keys
610 .get(&entry_id)
611 .unwrap()
612 .encrypted_by_id()
613 .unwrap()
614 .to_string();
615 let kek = keys.remove(&kek_id).unwrap();
616 let mut tampered_properties = kek.properties().clone();
617 tampered_properties.insert(KEK_CREATED_AT_PROPERTY.to_string(), "9999999".to_string());
618 let tampered_kek = EncryptedKey::builder()
619 .key_id(kek.key_id())
620 .encrypted_key_metadata(kek.encrypted_key_metadata())
621 .encrypted_by_id(kek.encrypted_by_id().unwrap())
622 .properties(tampered_properties)
623 .build();
624 keys.insert(tampered_kek.key_id().to_string(), tampered_kek);
625
626 let mgr = EncryptionManager::builder()
627 .kms_client(create_test_kms())
628 .table_key_id("master-1")
629 .encryption_keys(keys)
630 .build();
631
632 let result = mgr.decrypt_manifest_list_key_metadata(&entry_id).await;
634 assert!(
635 result.is_err(),
636 "tampered timestamp should cause decryption failure"
637 );
638 }
639
640 #[tokio::test]
641 async fn test_encrypt_decrypt_roundtrip() {
642 use crate::io::FileIO;
643
644 let io = FileIO::new_with_memory();
645 let path = "memory:///test/encrypt_roundtrip.bin";
646
647 let kms = MemoryKeyManagementClient::new();
648 kms.add_master_key("master-1").unwrap();
649 let mgr = EncryptionManager::builder()
650 .kms_client(Arc::new(kms) as Arc<dyn KeyManagementClient>)
651 .table_key_id("master-1")
652 .build();
653
654 let output = io.new_output(path).unwrap();
655 let encrypted_output = mgr.encrypt(output);
656
657 let plaintext = b"Hello, encrypted Iceberg round-trip!";
658 let serialized_metadata = encrypted_output.key_metadata().encode().unwrap();
659 encrypted_output
660 .write(bytes::Bytes::from(plaintext.to_vec()))
661 .await
662 .unwrap();
663
664 let input = io.new_input(path).unwrap();
665 let parsed_metadata = StandardKeyMetadata::decode(&serialized_metadata).unwrap();
666 let decrypted_file = EncryptedInputFile::new(input, parsed_metadata);
667
668 let content = decrypted_file.read().await.unwrap();
669 assert_eq!(&content[..], plaintext);
670 }
671}