1use std::ops::Range;
47use std::sync::Arc;
48
49use bytes::{Bytes, BytesMut};
50
51use super::AesGcmCipher;
52use crate::io::{FileRead, FileWrite};
53use crate::{Error, ErrorKind, Result};
54
/// Number of plaintext bytes in every AGS1 block except possibly the last (1 MiB).
pub const PLAIN_BLOCK_SIZE: u32 = 1 << 20;

/// Length in bytes of the nonce stored with each encrypted block.
pub const NONCE_LENGTH: u32 = 12;

/// Length in bytes of the GCM authentication tag stored with each encrypted block.
pub const GCM_TAG_LENGTH: u32 = 16;

/// Length in bytes of one full encrypted block: a full plaintext block plus its nonce and tag.
pub const CIPHER_BLOCK_SIZE: u32 = NONCE_LENGTH + PLAIN_BLOCK_SIZE + GCM_TAG_LENGTH;

/// Magic bytes that open every AGS1 stream.
pub const GCM_STREAM_MAGIC: [u8; 4] = [b'A', b'G', b'S', b'1'];

/// Length in bytes of the stream header: the magic followed by the plaintext block
/// size encoded as a little-endian `u32`.
pub const GCM_STREAM_HEADER_LENGTH: u32 =
    (GCM_STREAM_MAGIC.len() + std::mem::size_of::<u32>()) as u32;

/// Length of a stream holding a single empty block (header + nonce + tag), which is
/// what an empty plaintext encodes to.
#[cfg(test)]
pub const MIN_STREAM_LENGTH: u32 = GCM_STREAM_HEADER_LENGTH + NONCE_LENGTH + GCM_TAG_LENGTH;
76
/// Builds the additional authenticated data (AAD) for one block of an AGS1 stream.
///
/// The result is `aad_prefix` followed by `block_index` as four little-endian bytes.
/// With an empty prefix it is just those four bytes. Binding the index into the AAD
/// ties each encrypted block to its position in the stream.
pub(crate) fn stream_block_aad(aad_prefix: &[u8], block_index: u32) -> Vec<u8> {
    let index_bytes = block_index.to_le_bytes();
    [aad_prefix, &index_bytes[..]].concat()
}
93
94pub struct AesGcmFileRead {
116 inner: Box<dyn FileRead>,
118 cipher: Arc<AesGcmCipher>,
120 aad_prefix: Box<[u8]>,
122 plain_stream_size: u64,
124 num_blocks: u64,
126 last_cipher_block_size: u32,
128}
129
130impl AesGcmFileRead {
131 pub fn new(
144 inner: Box<dyn FileRead>,
145 cipher: Arc<AesGcmCipher>,
146 aad_prefix: Box<[u8]>,
147 encrypted_file_length: u64,
148 ) -> Result<Self> {
149 let plain_stream_size = Self::calculate_plaintext_length(encrypted_file_length)?;
150 let stream_length = encrypted_file_length - GCM_STREAM_HEADER_LENGTH as u64;
151
152 if stream_length == 0 {
153 return Ok(Self {
154 inner,
155 cipher,
156 aad_prefix,
157 plain_stream_size: 0,
158 num_blocks: 0,
159 last_cipher_block_size: 0,
160 });
161 }
162
163 let num_full_blocks = stream_length / CIPHER_BLOCK_SIZE as u64;
164 let cipher_bytes_in_last_block = (stream_length % CIPHER_BLOCK_SIZE as u64) as u32;
165 let full_blocks_only = cipher_bytes_in_last_block == 0;
166
167 let num_blocks = if full_blocks_only {
168 num_full_blocks
169 } else {
170 num_full_blocks + 1
171 };
172
173 if num_blocks > u32::MAX as u64 {
174 return Err(Error::new(
175 ErrorKind::DataInvalid,
176 format!(
177 "AGS1 format supports at most {} blocks (~4 TiB per file), but file requires {num_blocks} blocks",
178 u32::MAX
179 ),
180 ));
181 }
182
183 let last_cipher_block_size = if full_blocks_only {
184 CIPHER_BLOCK_SIZE
185 } else {
186 cipher_bytes_in_last_block
187 };
188
189 Ok(Self {
190 inner,
191 cipher,
192 aad_prefix,
193 plain_stream_size,
194 num_blocks,
195 last_cipher_block_size,
196 })
197 }
198
199 pub fn plaintext_length(&self) -> u64 {
201 self.plain_stream_size
202 }
203
204 pub fn calculate_plaintext_length(encrypted_file_length: u64) -> Result<u64> {
209 if encrypted_file_length < GCM_STREAM_HEADER_LENGTH as u64 {
210 return Err(Error::new(
211 ErrorKind::DataInvalid,
212 format!(
213 "Encrypted file too short: {encrypted_file_length} bytes (minimum {GCM_STREAM_HEADER_LENGTH})"
214 ),
215 ));
216 }
217
218 let stream_length = encrypted_file_length - GCM_STREAM_HEADER_LENGTH as u64;
219
220 if stream_length == 0 {
221 return Ok(0);
222 }
223
224 let num_full_blocks = stream_length / CIPHER_BLOCK_SIZE as u64;
225 let cipher_bytes_in_last_block = stream_length % CIPHER_BLOCK_SIZE as u64;
226 let full_blocks_only = cipher_bytes_in_last_block == 0;
227
228 let plain_bytes_in_last_block = if full_blocks_only {
229 0
230 } else {
231 if cipher_bytes_in_last_block < (NONCE_LENGTH + GCM_TAG_LENGTH) as u64 {
232 return Err(Error::new(
233 ErrorKind::DataInvalid,
234 format!(
235 "Truncated encrypted file: last block is {} bytes (minimum {})",
236 cipher_bytes_in_last_block,
237 NONCE_LENGTH + GCM_TAG_LENGTH
238 ),
239 ));
240 }
241 cipher_bytes_in_last_block - NONCE_LENGTH as u64 - GCM_TAG_LENGTH as u64
242 };
243
244 Ok(num_full_blocks * PLAIN_BLOCK_SIZE as u64 + plain_bytes_in_last_block)
245 }
246
247 fn encrypted_block_offset(block_index: u64) -> u64 {
249 block_index * CIPHER_BLOCK_SIZE as u64 + GCM_STREAM_HEADER_LENGTH as u64
250 }
251
252 fn cipher_block_size(&self, block_index: u64) -> u32 {
254 if block_index == self.num_blocks - 1 {
255 self.last_cipher_block_size
256 } else {
257 CIPHER_BLOCK_SIZE
258 }
259 }
260}
261
262#[async_trait::async_trait]
263impl FileRead for AesGcmFileRead {
264 async fn read(&self, range: Range<u64>) -> Result<Bytes> {
291 if range.start == range.end {
292 return Ok(Bytes::new());
293 }
294
295 if range.start > range.end {
296 return Err(Error::new(
297 ErrorKind::DataInvalid,
298 format!(
299 "Invalid read range: start ({}) is greater than end ({})",
300 range.start, range.end
301 ),
302 ));
303 }
304
305 if range.end > self.plain_stream_size {
306 return Err(Error::new(
307 ErrorKind::DataInvalid,
308 format!(
309 "Read range {}..{} exceeds plaintext size {}",
310 range.start, range.end, self.plain_stream_size
311 ),
312 ));
313 }
314
315 if self.num_blocks == 0 {
316 return Ok(Bytes::new());
317 }
318
319 let first_block = range.start / PLAIN_BLOCK_SIZE as u64;
320 let last_block = (range.end - 1) / PLAIN_BLOCK_SIZE as u64;
321
322 let encrypted_start = Self::encrypted_block_offset(first_block);
324 let encrypted_end =
325 Self::encrypted_block_offset(last_block) + self.cipher_block_size(last_block) as u64;
326
327 let all_encrypted = self.inner.read(encrypted_start..encrypted_end).await?;
328
329 let result_len = (range.end - range.start) as usize;
331 let mut result = BytesMut::with_capacity(result_len);
332 let mut encrypted_offset = 0usize;
333
334 for block_idx in first_block..=last_block {
335 let block_size = self.cipher_block_size(block_idx) as usize;
336 let cipher_block = &all_encrypted[encrypted_offset..encrypted_offset + block_size];
337 encrypted_offset += block_size;
338
339 let aad = stream_block_aad(&self.aad_prefix, block_idx as u32);
340 let decrypted = self.cipher.decrypt(cipher_block, Some(&aad))?;
341
342 let block_plain_start = block_idx * PLAIN_BLOCK_SIZE as u64;
344 let slice_start = if block_idx == first_block {
345 (range.start - block_plain_start) as usize
346 } else {
347 0
348 };
349 let slice_end = if block_idx == last_block {
350 (range.end - block_plain_start) as usize
351 } else {
352 decrypted.len()
353 };
354
355 result.extend_from_slice(&decrypted[slice_start..slice_end]);
356 }
357
358 Ok(result.freeze())
359 }
360}
361
362pub struct AesGcmFileWrite {
382 inner: Box<dyn FileWrite>,
384 cipher: Arc<AesGcmCipher>,
386 aad_prefix: Box<[u8]>,
388 buffer: Vec<u8>,
390 block_index: u32,
392 header_written: bool,
394 closed: bool,
396 poisoned: bool,
400}
401
402impl AesGcmFileWrite {
403 pub fn new(
407 inner: Box<dyn FileWrite>,
408 cipher: Arc<AesGcmCipher>,
409 aad_prefix: impl Into<Box<[u8]>>,
410 ) -> Self {
411 Self {
412 inner,
413 cipher,
414 aad_prefix: aad_prefix.into(),
415 buffer: Vec::new(),
416 block_index: 0,
417 header_written: false,
418 closed: false,
419 poisoned: false,
420 }
421 }
422
423 async fn write_header(&mut self) -> Result<()> {
425 let mut header = Vec::with_capacity(GCM_STREAM_HEADER_LENGTH as usize);
426 header.extend_from_slice(&GCM_STREAM_MAGIC);
427 header.extend_from_slice(&PLAIN_BLOCK_SIZE.to_le_bytes());
428 if let Err(e) = self.inner.write(Bytes::from(header)).await {
429 self.poisoned = true;
430 return Err(e);
431 }
432 self.header_written = true;
433 Ok(())
434 }
435
436 async fn encrypt_and_write_block(&mut self, block_data: &[u8]) -> Result<()> {
438 let aad = stream_block_aad(&self.aad_prefix, self.block_index);
439 let encrypted = self.cipher.encrypt(block_data, Some(&aad))?;
440 if let Err(e) = self.inner.write(Bytes::from(encrypted)).await {
441 self.poisoned = true;
442 return Err(e);
443 }
444 self.block_index = self.block_index.checked_add(1).ok_or_else(|| {
445 Error::new(
446 ErrorKind::DataInvalid,
447 "AGS1 block index overflow: file exceeds the maximum supported size (~4 TiB)",
448 )
449 })?;
450 Ok(())
451 }
452
453 async fn encrypt_and_drain_block(&mut self) -> Result<()> {
456 let aad = stream_block_aad(&self.aad_prefix, self.block_index);
457 let encrypted = self
458 .cipher
459 .encrypt(&self.buffer[..PLAIN_BLOCK_SIZE as usize], Some(&aad))?;
460 if let Err(e) = self.inner.write(Bytes::from(encrypted)).await {
461 self.poisoned = true;
462 return Err(e);
463 }
464 self.block_index = self.block_index.checked_add(1).ok_or_else(|| {
465 Error::new(
466 ErrorKind::DataInvalid,
467 "AGS1 block index overflow: file exceeds the maximum supported size (~4 TiB)",
468 )
469 })?;
470 self.buffer.drain(..PLAIN_BLOCK_SIZE as usize);
471 Ok(())
472 }
473}
474
475#[async_trait::async_trait]
476impl FileWrite for AesGcmFileWrite {
477 async fn write(&mut self, bs: Bytes) -> Result<()> {
478 if self.closed {
479 return Err(Error::new(
480 ErrorKind::Unexpected,
481 "Cannot write to a closed AesGcmFileWrite",
482 ));
483 }
484 if self.poisoned {
485 return Err(Error::new(
486 ErrorKind::Unexpected,
487 "AesGcmFileWrite is in a poisoned state due to a previous write failure",
488 ));
489 }
490
491 if !self.header_written {
492 self.write_header().await?;
493 }
494
495 self.buffer.extend_from_slice(&bs);
496
497 while self.buffer.len() >= PLAIN_BLOCK_SIZE as usize {
499 self.encrypt_and_drain_block().await?;
500 }
501
502 Ok(())
503 }
504
505 async fn close(&mut self) -> Result<()> {
506 if self.closed {
507 return Err(Error::new(
508 ErrorKind::Unexpected,
509 "AesGcmFileWrite already closed",
510 ));
511 }
512 if self.poisoned {
513 return Err(Error::new(
514 ErrorKind::Unexpected,
515 "AesGcmFileWrite is in a poisoned state due to a previous write failure",
516 ));
517 }
518
519 if !self.header_written {
520 self.write_header().await?;
521 }
522
523 if !self.buffer.is_empty() || self.block_index == 0 {
527 let final_block = std::mem::take(&mut self.buffer);
528 self.encrypt_and_write_block(&final_block).await?;
529 }
530 self.closed = true;
531
532 self.inner.close().await
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// 128-bit key used by every test that does not exercise a key mismatch.
    const KEY: &[u8] = b"0123456789abcdef";

    /// Reference AGS1 encoder used to cross-check the reader and the writer.
    ///
    /// Emits the header, then one encrypted block per `PLAIN_BLOCK_SIZE` chunk of
    /// `plaintext`. An empty plaintext is encoded as a single empty block, and
    /// block-aligned input gets no trailing empty block.
    fn encrypt_ags1(plaintext: &[u8], cipher: &AesGcmCipher, aad_prefix: &[u8]) -> Vec<u8> {
        let mut result = Vec::new();
        result.extend_from_slice(&GCM_STREAM_MAGIC);
        result.extend_from_slice(&PLAIN_BLOCK_SIZE.to_le_bytes());

        let mut blocks: Vec<&[u8]> = plaintext.chunks(PLAIN_BLOCK_SIZE as usize).collect();
        if blocks.is_empty() {
            blocks.push(&[]);
        }

        for (block_index, block) in blocks.into_iter().enumerate() {
            let aad = stream_block_aad(aad_prefix, block_index as u32);
            let encrypted = cipher.encrypt(block, Some(&aad)).unwrap();
            result.extend_from_slice(&encrypted);
        }

        result
    }

    fn make_cipher(key: &[u8]) -> AesGcmCipher {
        use super::super::SecureKey;
        let secure_key = SecureKey::new(key).unwrap();
        AesGcmCipher::new(secure_key)
    }

    fn memory_reader(data: Vec<u8>) -> Box<dyn FileRead> {
        Box::new(MemoryFileRead(Bytes::from(data)))
    }

    /// Opens a reader over an in-memory encrypted stream, using a fresh cipher for `key`.
    fn open_reader(encrypted: Vec<u8>, key: &[u8], aad_prefix: &[u8]) -> AesGcmFileRead {
        let encrypted_len = encrypted.len() as u64;
        AesGcmFileRead::new(
            memory_reader(encrypted),
            Arc::new(make_cipher(key)),
            aad_prefix.into(),
            encrypted_len,
        )
        .unwrap()
    }

    /// Deterministic plaintext of `size` bytes: `0, 1, ..., 255, 0, 1, ...`.
    fn patterned(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// In-memory `FileRead` that fails on out-of-bounds ranges.
    struct MemoryFileRead(Bytes);

    #[async_trait::async_trait]
    impl FileRead for MemoryFileRead {
        async fn read(&self, range: Range<u64>) -> Result<Bytes> {
            let start = range.start as usize;
            let end = range.end as usize;
            if end > self.0.len() {
                return Err(Error::new(
                    ErrorKind::DataInvalid,
                    format!(
                        "Range {}..{} out of bounds for {} bytes",
                        start,
                        end,
                        self.0.len()
                    ),
                ));
            }
            Ok(self.0.slice(start..end))
        }
    }

    #[tokio::test]
    async fn test_empty_file_roundtrip() {
        let aad_prefix = b"test-aad-prefix!";
        let encrypted = encrypt_ags1(b"", &make_cipher(KEY), aad_prefix);

        // Header plus a single empty block.
        assert_eq!(encrypted.len(), MIN_STREAM_LENGTH as usize);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), 0);

        let result = reader.read(0..0).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_small_file_roundtrip() {
        let aad_prefix = b"test-aad-prefix!";
        let plaintext = b"Hello, Iceberg encryption!";
        let encrypted = encrypt_ags1(plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);

        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], plaintext);
    }

    #[tokio::test]
    async fn test_partial_read() {
        let aad_prefix = b"aad-prefix-here!";
        let plaintext = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let encrypted = encrypt_ags1(plaintext, &make_cipher(KEY), aad_prefix);
        let reader = open_reader(encrypted, KEY, aad_prefix);

        // Middle of the block.
        let result = reader.read(10..20).await.unwrap();
        assert_eq!(&result[..], &plaintext[10..20]);

        // First byte.
        let result = reader.read(0..1).await.unwrap();
        assert_eq!(&result[..], &plaintext[0..1]);

        // Last byte.
        let last = plaintext.len() as u64;
        let result = reader.read(last - 1..last).await.unwrap();
        assert_eq!(&result[..], &plaintext[plaintext.len() - 1..]);
    }

    #[tokio::test]
    async fn test_multi_block_roundtrip() {
        let aad_prefix = b"multi-block-aad!";
        // One full block plus a half block.
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2);
        let encrypted = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);

        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_cross_block_read() {
        let aad_prefix = b"cross-block-aad!";
        // Two full blocks plus a half block.
        let plaintext =
            patterned(PLAIN_BLOCK_SIZE as usize * 2 + PLAIN_BLOCK_SIZE as usize / 2);
        let encrypted = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);
        let reader = open_reader(encrypted, KEY, aad_prefix);

        let boundary = PLAIN_BLOCK_SIZE as u64;

        // Straddles the first block boundary.
        let result = reader.read(boundary - 100..boundary + 100).await.unwrap();
        assert_eq!(
            &result[..],
            &plaintext[(boundary - 100) as usize..(boundary + 100) as usize]
        );

        // Touches all three blocks.
        let result = reader.read(boundary - 50..boundary * 2 + 50).await.unwrap();
        assert_eq!(
            &result[..],
            &plaintext[(boundary - 50) as usize..(boundary * 2 + 50) as usize]
        );
    }

    #[tokio::test]
    async fn test_exact_block_size() {
        let aad_prefix = b"exact-block-aad!";
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize);
        let encrypted = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);

        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_block_size_plus_one() {
        let aad_prefix = b"block-plus-one!!";
        let size = PLAIN_BLOCK_SIZE as usize + 1;
        let plaintext = patterned(size);
        let encrypted = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), size as u64);

        // The only byte in the second block.
        let result = reader.read(size as u64 - 1..size as u64).await.unwrap();
        assert_eq!(result[0], plaintext[size - 1]);

        let result = reader.read(0..size as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_block_size_minus_one() {
        let aad_prefix = b"block-minus-one!";
        let size = PLAIN_BLOCK_SIZE as usize - 1;
        let plaintext = patterned(size);
        let encrypted = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), size as u64);

        let result = reader.read(0..size as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_wrong_aad_fails() {
        let aad_prefix = b"correct-aad-here";
        let plaintext = b"sensitive data here";
        let encrypted = encrypt_ags1(plaintext, &make_cipher(KEY), aad_prefix);

        // Flip the first byte of the AAD prefix.
        let mut bad_aad = aad_prefix.to_vec();
        bad_aad[0] ^= 0xFF;

        let reader = open_reader(encrypted, KEY, &bad_aad);
        let result = reader.read(0..plaintext.len() as u64).await;
        assert!(result.is_err(), "Decryption with wrong AAD should fail");
    }

    #[tokio::test]
    async fn test_wrong_key_fails() {
        let wrong_key = b"fedcba9876543210";
        let aad_prefix = b"test-aad-prefix!";
        let plaintext = b"sensitive data";
        let encrypted = encrypt_ags1(plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, wrong_key, aad_prefix);
        let result = reader.read(0..plaintext.len() as u64).await;
        assert!(result.is_err(), "Decryption with wrong key should fail");
    }

    #[tokio::test]
    async fn test_out_of_bounds_read() {
        let aad_prefix = b"test-aad-prefix!";
        let plaintext = b"short data";
        let encrypted = encrypt_ags1(plaintext, &make_cipher(KEY), aad_prefix);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        let result = reader.read(0..plaintext.len() as u64 + 1).await;
        assert!(result.is_err(), "Reading past end should fail");
    }

    #[test]
    fn test_calculate_plaintext_length() {
        // Header only.
        assert_eq!(
            AesGcmFileRead::calculate_plaintext_length(GCM_STREAM_HEADER_LENGTH as u64).unwrap(),
            0
        );

        // Header plus one empty block.
        assert_eq!(
            AesGcmFileRead::calculate_plaintext_length(MIN_STREAM_LENGTH as u64).unwrap(),
            0
        );

        // Exactly one full block.
        let one_full = GCM_STREAM_HEADER_LENGTH as u64 + CIPHER_BLOCK_SIZE as u64;
        assert_eq!(
            AesGcmFileRead::calculate_plaintext_length(one_full).unwrap(),
            PLAIN_BLOCK_SIZE as u64
        );

        // One full block plus a block holding a single plaintext byte.
        let one_full_plus_one = one_full + NONCE_LENGTH as u64 + 1 + GCM_TAG_LENGTH as u64;
        assert_eq!(
            AesGcmFileRead::calculate_plaintext_length(one_full_plus_one).unwrap(),
            PLAIN_BLOCK_SIZE as u64 + 1
        );
    }

    #[test]
    fn test_stream_block_aad() {
        let aad = stream_block_aad(b"prefix", 0);
        assert_eq!(&aad[..6], b"prefix");
        assert_eq!(&aad[6..], &0u32.to_le_bytes());

        let aad = stream_block_aad(b"prefix", 1);
        assert_eq!(&aad[..6], b"prefix");
        assert_eq!(&aad[6..], &1u32.to_le_bytes());

        // An empty prefix yields just the index bytes.
        let aad = stream_block_aad(b"", 42);
        assert_eq!(&aad[..], &42u32.to_le_bytes());
    }

    #[test]
    fn test_encrypted_file_too_short() {
        let result = AesGcmFileRead::new(
            memory_reader(vec![0; 4]),
            Arc::new(make_cipher(KEY)),
            [].into(),
            4,
        );
        assert!(result.is_err());
    }

    /// In-memory `FileWrite` whose output remains observable after the writer is dropped.
    struct SharedMemoryWrite {
        buffer: Arc<std::sync::Mutex<Vec<u8>>>,
    }

    /// `FileWrite` that succeeds `writes_before_failure` times and then fails.
    struct FailingFileWrite {
        writes_before_failure: usize,
        write_count: usize,
    }

    #[async_trait::async_trait]
    impl FileWrite for FailingFileWrite {
        async fn write(&mut self, _bs: Bytes) -> Result<()> {
            if self.write_count >= self.writes_before_failure {
                return Err(Error::new(ErrorKind::Unexpected, "simulated write failure"));
            }
            self.write_count += 1;
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    #[async_trait::async_trait]
    impl FileWrite for SharedMemoryWrite {
        async fn write(&mut self, bs: Bytes) -> Result<()> {
            self.buffer.lock().unwrap().extend_from_slice(&bs);
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Feeds `chunks` through `AesGcmFileWrite` (one `write` call per chunk), closes the
    /// writer, and returns the bytes it produced.
    async fn write_chunks_through_ags1(chunks: &[&[u8]], key: &[u8], aad_prefix: &[u8]) -> Vec<u8> {
        let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
        let inner: Box<dyn FileWrite> = Box::new(SharedMemoryWrite {
            buffer: buffer.clone(),
        });
        let cipher = Arc::new(make_cipher(key));
        let mut writer = AesGcmFileWrite::new(inner, cipher, aad_prefix.to_vec());

        for chunk in chunks {
            writer.write(Bytes::from(chunk.to_vec())).await.unwrap();
        }
        writer.close().await.unwrap();

        buffer.lock().unwrap().clone()
    }

    /// Encrypts `plaintext` through `AesGcmFileWrite` with a single `write` call.
    async fn write_through_ags1(plaintext: &[u8], key: &[u8], aad_prefix: &[u8]) -> Vec<u8> {
        write_chunks_through_ags1(&[plaintext], key, aad_prefix).await
    }

    #[tokio::test]
    async fn test_write_empty_roundtrip() {
        let aad_prefix = b"test-aad-prefix!";
        let encrypted = write_through_ags1(b"", KEY, aad_prefix).await;

        // Header plus a single empty block.
        assert_eq!(encrypted.len(), MIN_STREAM_LENGTH as usize);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), 0);
    }

    #[tokio::test]
    async fn test_write_small_roundtrip() {
        let aad_prefix = b"test-aad-prefix!";
        let plaintext = b"Hello, Iceberg encryption!";
        let encrypted = write_through_ags1(plaintext, KEY, aad_prefix).await;

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], plaintext);
    }

    #[tokio::test]
    async fn test_write_multi_block_roundtrip() {
        let aad_prefix = b"multi-block-aad!";
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2);
        let encrypted = write_through_ags1(&plaintext, KEY, aad_prefix).await;

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_cross_block_accumulation() {
        let aad_prefix = b"cross-block-aad!";
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2);

        // Many small writes that must be accumulated across block boundaries.
        let chunks: Vec<&[u8]> = plaintext.chunks(1000).collect();
        let encrypted = write_chunks_through_ags1(&chunks, KEY, aad_prefix).await;

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_mixed_chunk_sizes() {
        let aad_prefix = b"mixed-chunk-aad!";
        let block = PLAIN_BLOCK_SIZE as usize;
        let plaintext = patterned(block * 2 + 300);

        // The small first chunk leaves a partial block buffered. The middle chunk
        // crosses two block boundaries, and the tail stays inside the final
        // partial block.
        let (head, rest) = plaintext.split_at(100);
        let (middle, tail) = rest.split_at(block * 2);
        let encrypted = write_chunks_through_ags1(&[head, middle, tail], KEY, aad_prefix).await;

        // Two full blocks plus one 300-byte block, with no spurious extra block.
        let expected_len = GCM_STREAM_HEADER_LENGTH as u64
            + 2 * CIPHER_BLOCK_SIZE as u64
            + 300
            + (NONCE_LENGTH + GCM_TAG_LENGTH) as u64;
        assert_eq!(encrypted.len() as u64, expected_len);

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_exact_block_size() {
        let aad_prefix = b"exact-block-aad!";
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize);
        let encrypted = write_through_ags1(&plaintext, KEY, aad_prefix).await;

        let reader = open_reader(encrypted, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);
        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_block_aligned_no_spurious_empty_block() {
        let aad_prefix = b"block-align-aad!";
        let plaintext = patterned(PLAIN_BLOCK_SIZE as usize);

        let encrypted_via_writer = write_through_ags1(&plaintext, KEY, aad_prefix).await;
        let encrypted_via_reference = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        assert_eq!(
            encrypted_via_writer.len(),
            encrypted_via_reference.len(),
            "Writer output should match reference encryption length (no spurious trailing block)"
        );

        let reader = open_reader(encrypted_via_writer, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);
        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_two_blocks_aligned_no_spurious_empty_block() {
        let aad_prefix = b"2blk-align-aad!!";
        let size = PLAIN_BLOCK_SIZE as usize * 2;
        let plaintext = patterned(size);

        let encrypted_via_writer = write_through_ags1(&plaintext, KEY, aad_prefix).await;
        let encrypted_via_reference = encrypt_ags1(&plaintext, &make_cipher(KEY), aad_prefix);

        assert_eq!(
            encrypted_via_writer.len(),
            encrypted_via_reference.len(),
            "Writer output should match reference encryption length (no spurious trailing block)"
        );

        let reader = open_reader(encrypted_via_writer, KEY, aad_prefix);
        assert_eq!(reader.plaintext_length(), size as u64);
        let result = reader.read(0..size as u64).await.unwrap();
        assert_eq!(&result[..], &plaintext[..]);
    }

    #[tokio::test]
    async fn test_write_poisoned_after_inner_write_failure() {
        // The header write succeeds; the first block write fails.
        let inner: Box<dyn FileWrite> = Box::new(FailingFileWrite {
            writes_before_failure: 1,
            write_count: 0,
        });
        let cipher = Arc::new(make_cipher(KEY));
        let mut writer = AesGcmFileWrite::new(inner, cipher, b"aad-prefix-here!".to_vec());

        let data = vec![0u8; PLAIN_BLOCK_SIZE as usize];
        let result = writer.write(Bytes::from(data)).await;
        assert!(result.is_err());

        // Later writes are rejected because the writer is poisoned.
        let result = writer.write(Bytes::from(b"more data".to_vec())).await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("poisoned"),
            "expected poisoned error"
        );

        // So is close.
        let result = writer.close().await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("poisoned"),
            "expected poisoned error on close"
        );
    }
}