iceberg/encryption/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! AGS1 stream encryption/decryption for Iceberg.
19//!
20//! Implements the block-based AES-GCM stream format used by Iceberg for
21//! encrypting manifest lists and manifest files. The format is
22//! byte-compatible with Java's `AesGcmInputStream` / `AesGcmOutputStream`.
23//!
24//! # AGS1 File Format
25//!
26//! ```text
27//! ┌─────────────────────────────────────────────┐
28//! │ Header (8 bytes)                            │
29//! │   Magic: "AGS1" (4 bytes, ASCII)            │
30//! │   Plain block size: u32 LE (4 bytes)        │
31//! │     Default: 1,048,576 (1 MiB)              │
32//! ├─────────────────────────────────────────────┤
33//! │ Block 0                                     │
34//! │   Nonce (12 bytes)                          │
35//! │   Ciphertext (up to plain_block_size bytes) │
36//! │   GCM Tag (16 bytes)                        │
37//! ├─────────────────────────────────────────────┤
38//! │ Block 1..N (same structure)                 │
39//! ├─────────────────────────────────────────────┤
40//! │ Final block (may be shorter)                │
41//! └─────────────────────────────────────────────┘
42//! ```
43//!
44//! Each block's AAD is: `aad_prefix || block_index (4 bytes, LE)`.
45
46use std::ops::Range;
47use std::sync::Arc;
48
49use bytes::{Bytes, BytesMut};
50
51use super::AesGcmCipher;
52use crate::io::{FileRead, FileWrite};
53use crate::{Error, ErrorKind, Result};
54
/// Default plaintext block size (1 MiB), matching Java's `Ciphers.PLAIN_BLOCK_SIZE`.
///
/// This is the number of plaintext bytes carried by every block except
/// possibly the last one, and the value written into the stream header.
pub const PLAIN_BLOCK_SIZE: u32 = 1024 * 1024;

/// AES-GCM nonce length in bytes. Each block starts with its own nonce.
pub const NONCE_LENGTH: u32 = 12;

/// AES-GCM authentication tag length in bytes. Each block ends with its tag.
pub const GCM_TAG_LENGTH: u32 = 16;

/// Cipher block size = plaintext block size + nonce + GCM tag.
///
/// This is the on-disk size of every full block (1,048,604 bytes).
pub const CIPHER_BLOCK_SIZE: u32 = PLAIN_BLOCK_SIZE + NONCE_LENGTH + GCM_TAG_LENGTH;

/// AGS1 stream magic bytes.
pub const GCM_STREAM_MAGIC: [u8; 4] = *b"AGS1";

/// AGS1 stream header length (4-byte magic + 4-byte block size).
pub const GCM_STREAM_HEADER_LENGTH: u32 = 8;

/// Length of the smallest stream the writer produces: header + one empty
/// block (8 + 12 + 16 = 36 bytes).
///
/// Only compiled for tests. Note that the reader additionally accepts a
/// header-only stream (8 bytes) and treats it as empty plaintext.
#[cfg(test)]
pub const MIN_STREAM_LENGTH: u32 = GCM_STREAM_HEADER_LENGTH + NONCE_LENGTH + GCM_TAG_LENGTH;
76
/// Builds the additional authenticated data (AAD) for a single AGS1 block.
///
/// Layout: `aad_prefix || block_index (4 bytes, little-endian)`. With an
/// empty prefix the result is just the four index bytes.
///
/// This matches Java's `Ciphers.streamBlockAAD()`.
pub(crate) fn stream_block_aad(aad_prefix: &[u8], block_index: u32) -> Vec<u8> {
    let index_suffix = block_index.to_le_bytes();
    // Concatenating with an empty prefix naturally yields only the suffix,
    // so no special case is needed.
    [aad_prefix, &index_suffix[..]].concat()
}
93
/// Transparent decryption of AGS1 stream-encrypted files.
///
/// Implements the [`FileRead`] trait, providing random-access reads over
/// encrypted data. Each `read()` call determines which encrypted blocks
/// overlap the requested plaintext range, reads and decrypts them, then
/// returns the requested plaintext bytes.
///
/// # Usage
///
/// ```ignore
/// // (ignored: requires async runtime and concrete FileRead/FileWrite impls)
/// let reader = AesGcmFileRead::new(
///     inner_reader,       // Box<dyn FileRead> for the encrypted file
///     cipher,             // Arc<AesGcmCipher> with the DEK
///     aad_prefix.into(),  // Box<[u8]> AAD prefix (e.g. converted from a `&[u8]`)
///     encrypted_file_length,
/// )?;
///
/// // Read plaintext bytes transparently
/// let plaintext = reader.read(0..1024).await?;
/// ```
pub struct AesGcmFileRead {
    /// The underlying encrypted file reader.
    inner: Box<dyn FileRead>,
    /// The AES-GCM cipher holding the DEK.
    cipher: Arc<AesGcmCipher>,
    /// AAD prefix from the key metadata; the per-block index is appended to it.
    aad_prefix: Box<[u8]>,
    /// Total plaintext stream size in bytes.
    plain_stream_size: u64,
    /// Total number of encrypted blocks (0 for a header-only stream).
    num_blocks: u64,
    /// Size of the last cipher block, nonce and tag included (may be smaller
    /// than `CIPHER_BLOCK_SIZE`; 0 when the stream has no blocks).
    last_cipher_block_size: u32,
}
129
130impl AesGcmFileRead {
131    /// Creates a new `AesGcmFileRead` for decrypting an AGS1 stream.
132    ///
133    /// Computes the plaintext size and block layout from the encrypted file
134    /// length. No I/O is performed; header validation happens implicitly
135    /// when blocks are decrypted (GCM authentication will fail on corrupt data).
136    ///
137    /// # Arguments
138    ///
139    /// * `inner` - Reader for the underlying encrypted file
140    /// * `cipher` - AES-GCM cipher initialized with the file's DEK
141    /// * `aad_prefix` - AAD prefix from the file's `StandardKeyMetadata`
142    /// * `encrypted_file_length` - Total byte length of the encrypted file
143    pub fn new(
144        inner: Box<dyn FileRead>,
145        cipher: Arc<AesGcmCipher>,
146        aad_prefix: Box<[u8]>,
147        encrypted_file_length: u64,
148    ) -> Result<Self> {
149        let plain_stream_size = Self::calculate_plaintext_length(encrypted_file_length)?;
150        let stream_length = encrypted_file_length - GCM_STREAM_HEADER_LENGTH as u64;
151
152        if stream_length == 0 {
153            return Ok(Self {
154                inner,
155                cipher,
156                aad_prefix,
157                plain_stream_size: 0,
158                num_blocks: 0,
159                last_cipher_block_size: 0,
160            });
161        }
162
163        let num_full_blocks = stream_length / CIPHER_BLOCK_SIZE as u64;
164        let cipher_bytes_in_last_block = (stream_length % CIPHER_BLOCK_SIZE as u64) as u32;
165        let full_blocks_only = cipher_bytes_in_last_block == 0;
166
167        let num_blocks = if full_blocks_only {
168            num_full_blocks
169        } else {
170            num_full_blocks + 1
171        };
172
173        if num_blocks > u32::MAX as u64 {
174            return Err(Error::new(
175                ErrorKind::DataInvalid,
176                format!(
177                    "AGS1 format supports at most {} blocks (~4 TiB per file), but file requires {num_blocks} blocks",
178                    u32::MAX
179                ),
180            ));
181        }
182
183        let last_cipher_block_size = if full_blocks_only {
184            CIPHER_BLOCK_SIZE
185        } else {
186            cipher_bytes_in_last_block
187        };
188
189        Ok(Self {
190            inner,
191            cipher,
192            aad_prefix,
193            plain_stream_size,
194            num_blocks,
195            last_cipher_block_size,
196        })
197    }
198
    /// Returns the total plaintext size of the stream in bytes.
    ///
    /// The value is computed once in `new` from the encrypted file length;
    /// calling this method performs no I/O.
    pub fn plaintext_length(&self) -> u64 {
        self.plain_stream_size
    }
203
204    /// Calculates the plaintext length from an encrypted file's total length.
205    ///
206    /// This is a static calculation matching Java's
207    /// `AesGcmInputStream.calculatePlaintextLength()`.
208    pub fn calculate_plaintext_length(encrypted_file_length: u64) -> Result<u64> {
209        if encrypted_file_length < GCM_STREAM_HEADER_LENGTH as u64 {
210            return Err(Error::new(
211                ErrorKind::DataInvalid,
212                format!(
213                    "Encrypted file too short: {encrypted_file_length} bytes (minimum {GCM_STREAM_HEADER_LENGTH})"
214                ),
215            ));
216        }
217
218        let stream_length = encrypted_file_length - GCM_STREAM_HEADER_LENGTH as u64;
219
220        if stream_length == 0 {
221            return Ok(0);
222        }
223
224        let num_full_blocks = stream_length / CIPHER_BLOCK_SIZE as u64;
225        let cipher_bytes_in_last_block = stream_length % CIPHER_BLOCK_SIZE as u64;
226        let full_blocks_only = cipher_bytes_in_last_block == 0;
227
228        let plain_bytes_in_last_block = if full_blocks_only {
229            0
230        } else {
231            if cipher_bytes_in_last_block < (NONCE_LENGTH + GCM_TAG_LENGTH) as u64 {
232                return Err(Error::new(
233                    ErrorKind::DataInvalid,
234                    format!(
235                        "Truncated encrypted file: last block is {} bytes (minimum {})",
236                        cipher_bytes_in_last_block,
237                        NONCE_LENGTH + GCM_TAG_LENGTH
238                    ),
239                ));
240            }
241            cipher_bytes_in_last_block - NONCE_LENGTH as u64 - GCM_TAG_LENGTH as u64
242        };
243
244        Ok(num_full_blocks * PLAIN_BLOCK_SIZE as u64 + plain_bytes_in_last_block)
245    }
246
247    /// Returns the encrypted byte offset for a given block index.
248    fn encrypted_block_offset(block_index: u64) -> u64 {
249        block_index * CIPHER_BLOCK_SIZE as u64 + GCM_STREAM_HEADER_LENGTH as u64
250    }
251
252    /// Returns the cipher block size for a given block index.
253    fn cipher_block_size(&self, block_index: u64) -> u32 {
254        if block_index == self.num_blocks - 1 {
255            self.last_cipher_block_size
256        } else {
257            CIPHER_BLOCK_SIZE
258        }
259    }
260}
261
262#[async_trait::async_trait]
263impl FileRead for AesGcmFileRead {
264    /// Reads and decrypts a plaintext byte range from the encrypted AGS1 stream.
265    ///
266    /// The caller specifies a range in **plaintext** coordinates (e.g. "bytes 0..1024
267    /// of the original file"). This method translates that into the encrypted file
268    /// layout and performs the following steps:
269    ///
270    /// 1. **Map to blocks** — divides the plaintext range by `PLAIN_BLOCK_SIZE` to
271    ///    find which encrypted blocks (`first_block..=last_block`) contain the
272    ///    requested data.
273    ///
274    /// 2. **Single I/O read** — calculates the contiguous byte range in the
275    ///    encrypted file that covers all needed blocks (including the 8-byte AGS1
276    ///    header offset, 12-byte nonces, and 16-byte GCM tags) and fetches them in
277    ///    one call to the inner `FileRead`.
278    ///
279    /// 3. **Decrypt per block** — iterates over each cipher block in the response,
280    ///    decrypts it with AES-GCM using the per-block AAD (`aad_prefix || block_index`),
281    ///    and slices out only the plaintext bytes that overlap the requested range.
282    ///
283    /// 4. **Assemble result** — concatenates the slices into a single `Bytes` buffer
284    ///    matching exactly `range.end - range.start` bytes.
285    ///
286    /// Because each block is independently encrypted with its own nonce and AAD,
287    /// arbitrary random-access reads are supported without decrypting the entire
288    /// file. GCM authentication is verified per-block, so any tampering is detected
289    /// at the granularity of individual blocks.
290    async fn read(&self, range: Range<u64>) -> Result<Bytes> {
291        if range.start == range.end {
292            return Ok(Bytes::new());
293        }
294
295        if range.start > range.end {
296            return Err(Error::new(
297                ErrorKind::DataInvalid,
298                format!(
299                    "Invalid read range: start ({}) is greater than end ({})",
300                    range.start, range.end
301                ),
302            ));
303        }
304
305        if range.end > self.plain_stream_size {
306            return Err(Error::new(
307                ErrorKind::DataInvalid,
308                format!(
309                    "Read range {}..{} exceeds plaintext size {}",
310                    range.start, range.end, self.plain_stream_size
311                ),
312            ));
313        }
314
315        if self.num_blocks == 0 {
316            return Ok(Bytes::new());
317        }
318
319        let first_block = range.start / PLAIN_BLOCK_SIZE as u64;
320        let last_block = (range.end - 1) / PLAIN_BLOCK_SIZE as u64;
321
322        // Read all needed encrypted blocks in a single I/O call
323        let encrypted_start = Self::encrypted_block_offset(first_block);
324        let encrypted_end =
325            Self::encrypted_block_offset(last_block) + self.cipher_block_size(last_block) as u64;
326
327        let all_encrypted = self.inner.read(encrypted_start..encrypted_end).await?;
328
329        // Decrypt each block and extract the requested plaintext range
330        let result_len = (range.end - range.start) as usize;
331        let mut result = BytesMut::with_capacity(result_len);
332        let mut encrypted_offset = 0usize;
333
334        for block_idx in first_block..=last_block {
335            let block_size = self.cipher_block_size(block_idx) as usize;
336            let cipher_block = &all_encrypted[encrypted_offset..encrypted_offset + block_size];
337            encrypted_offset += block_size;
338
339            let aad = stream_block_aad(&self.aad_prefix, block_idx as u32);
340            let decrypted = self.cipher.decrypt(cipher_block, Some(&aad))?;
341
342            // Calculate which slice of this decrypted block we need
343            let block_plain_start = block_idx * PLAIN_BLOCK_SIZE as u64;
344            let slice_start = if block_idx == first_block {
345                (range.start - block_plain_start) as usize
346            } else {
347                0
348            };
349            let slice_end = if block_idx == last_block {
350                (range.end - block_plain_start) as usize
351            } else {
352                decrypted.len()
353            };
354
355            result.extend_from_slice(&decrypted[slice_start..slice_end]);
356        }
357
358        Ok(result.freeze())
359    }
360}
361
/// Transparent encryption of AGS1 stream-encrypted files.
///
/// Implements the [`FileWrite`] trait, buffering plaintext and emitting
/// encrypted AGS1 blocks. This is the streaming write counterpart to
/// [`AesGcmFileRead`].
///
/// # Usage
///
/// ```ignore
/// // (ignored: requires async runtime and concrete FileRead/FileWrite impls)
/// // `write` and `close` take `&mut self`, so the binding must be mutable.
/// let mut writer = AesGcmFileWrite::new(
///     inner_writer,       // Box<dyn FileWrite> for the output file
///     cipher,             // Arc<AesGcmCipher> with the DEK
///     aad_prefix.to_vec(),
/// );
///
/// writer.write(plaintext_chunk).await?;
/// writer.close().await?;
/// ```
pub struct AesGcmFileWrite {
    /// The underlying output writer.
    inner: Box<dyn FileWrite>,
    /// The AES-GCM cipher holding the DEK.
    cipher: Arc<AesGcmCipher>,
    /// AAD prefix from the key metadata; the per-block index is appended to it.
    aad_prefix: Box<[u8]>,
    /// Plaintext buffer accumulating data before block encryption.
    buffer: Vec<u8>,
    /// Index of the next block to be written, used for AAD construction.
    block_index: u32,
    /// Whether the AGS1 header has been written.
    header_written: bool,
    /// Whether close() has been called.
    closed: bool,
    /// Whether the writer is in a poisoned state due to a failed inner write.
    /// Once poisoned, all subsequent operations are rejected because the inner
    /// writer may have received partial data.
    poisoned: bool,
}
401
402impl AesGcmFileWrite {
403    /// Creates a new `AesGcmFileWrite` for encrypting to AGS1 format.
404    ///
405    /// No I/O is performed until `write()` or `close()` is called.
406    pub fn new(
407        inner: Box<dyn FileWrite>,
408        cipher: Arc<AesGcmCipher>,
409        aad_prefix: impl Into<Box<[u8]>>,
410    ) -> Self {
411        Self {
412            inner,
413            cipher,
414            aad_prefix: aad_prefix.into(),
415            buffer: Vec::new(),
416            block_index: 0,
417            header_written: false,
418            closed: false,
419            poisoned: false,
420        }
421    }
422
423    /// Writes the AGS1 header (magic + plain block size) to the inner writer.
424    async fn write_header(&mut self) -> Result<()> {
425        let mut header = Vec::with_capacity(GCM_STREAM_HEADER_LENGTH as usize);
426        header.extend_from_slice(&GCM_STREAM_MAGIC);
427        header.extend_from_slice(&PLAIN_BLOCK_SIZE.to_le_bytes());
428        if let Err(e) = self.inner.write(Bytes::from(header)).await {
429            self.poisoned = true;
430            return Err(e);
431        }
432        self.header_written = true;
433        Ok(())
434    }
435
436    /// Encrypts a plaintext block and writes it to the inner writer.
437    async fn encrypt_and_write_block(&mut self, block_data: &[u8]) -> Result<()> {
438        let aad = stream_block_aad(&self.aad_prefix, self.block_index);
439        let encrypted = self.cipher.encrypt(block_data, Some(&aad))?;
440        if let Err(e) = self.inner.write(Bytes::from(encrypted)).await {
441            self.poisoned = true;
442            return Err(e);
443        }
444        self.block_index = self.block_index.checked_add(1).ok_or_else(|| {
445            Error::new(
446                ErrorKind::DataInvalid,
447                "AGS1 block index overflow: file exceeds the maximum supported size (~4 TiB)",
448            )
449        })?;
450        Ok(())
451    }
452
453    /// Encrypts the first `PLAIN_BLOCK_SIZE` bytes of the buffer in-place
454    /// and drains them, avoiding a 1 MiB temporary copy.
455    async fn encrypt_and_drain_block(&mut self) -> Result<()> {
456        let aad = stream_block_aad(&self.aad_prefix, self.block_index);
457        let encrypted = self
458            .cipher
459            .encrypt(&self.buffer[..PLAIN_BLOCK_SIZE as usize], Some(&aad))?;
460        if let Err(e) = self.inner.write(Bytes::from(encrypted)).await {
461            self.poisoned = true;
462            return Err(e);
463        }
464        self.block_index = self.block_index.checked_add(1).ok_or_else(|| {
465            Error::new(
466                ErrorKind::DataInvalid,
467                "AGS1 block index overflow: file exceeds the maximum supported size (~4 TiB)",
468            )
469        })?;
470        self.buffer.drain(..PLAIN_BLOCK_SIZE as usize);
471        Ok(())
472    }
473}
474
475#[async_trait::async_trait]
476impl FileWrite for AesGcmFileWrite {
477    async fn write(&mut self, bs: Bytes) -> Result<()> {
478        if self.closed {
479            return Err(Error::new(
480                ErrorKind::Unexpected,
481                "Cannot write to a closed AesGcmFileWrite",
482            ));
483        }
484        if self.poisoned {
485            return Err(Error::new(
486                ErrorKind::Unexpected,
487                "AesGcmFileWrite is in a poisoned state due to a previous write failure",
488            ));
489        }
490
491        if !self.header_written {
492            self.write_header().await?;
493        }
494
495        self.buffer.extend_from_slice(&bs);
496
497        // Flush full blocks
498        while self.buffer.len() >= PLAIN_BLOCK_SIZE as usize {
499            self.encrypt_and_drain_block().await?;
500        }
501
502        Ok(())
503    }
504
505    async fn close(&mut self) -> Result<()> {
506        if self.closed {
507            return Err(Error::new(
508                ErrorKind::Unexpected,
509                "AesGcmFileWrite already closed",
510            ));
511        }
512        if self.poisoned {
513            return Err(Error::new(
514                ErrorKind::Unexpected,
515                "AesGcmFileWrite is in a poisoned state due to a previous write failure",
516            ));
517        }
518
519        if !self.header_written {
520            self.write_header().await?;
521        }
522
523        // Write the final block if there's remaining data, or if this is an empty file
524        // (block_index == 0). Skip writing a spurious empty block when the plaintext was
525        // exactly block-aligned (buffer empty, blocks already written).
526        if !self.buffer.is_empty() || self.block_index == 0 {
527            let final_block = std::mem::take(&mut self.buffer);
528            self.encrypt_and_write_block(&final_block).await?;
529        }
530        self.closed = true;
531
532        self.inner.close().await
533    }
534}
535
536#[cfg(test)]
537mod tests {
538    use super::*;
539
540    /// Encrypts plaintext into AGS1 format for testing.
541    ///
542    /// Mirrors Java's `AesGcmOutputStream` behavior:
543    /// - Always writes header + at least one block (even for empty input)
544    /// - Full blocks are `PLAIN_BLOCK_SIZE` bytes; last block may be shorter
545    fn encrypt_ags1(plaintext: &[u8], cipher: &AesGcmCipher, aad_prefix: &[u8]) -> Vec<u8> {
546        let mut result = Vec::new();
547
548        // Write header: "AGS1" + PLAIN_BLOCK_SIZE (LE)
549        result.extend_from_slice(&GCM_STREAM_MAGIC);
550        result.extend_from_slice(&PLAIN_BLOCK_SIZE.to_le_bytes());
551
552        // Write blocks
553        let mut offset = 0;
554        let mut block_index = 0u32;
555
556        loop {
557            let remaining = plaintext.len() - offset;
558            let block_size = std::cmp::min(remaining, PLAIN_BLOCK_SIZE as usize);
559
560            // Block 0 is always written (even if empty); subsequent empty blocks are skipped
561            if block_size == 0 && block_index > 0 {
562                break;
563            }
564
565            let block_data = &plaintext[offset..offset + block_size];
566            let aad = stream_block_aad(aad_prefix, block_index);
567            let encrypted = cipher.encrypt(block_data, Some(&aad)).unwrap();
568            result.extend_from_slice(&encrypted);
569
570            offset += block_size;
571            block_index += 1;
572
573            // A partial block is always the last
574            if block_size < PLAIN_BLOCK_SIZE as usize {
575                break;
576            }
577        }
578
579        result
580    }
581
582    /// Helper to create an AesGcmCipher from raw key bytes.
583    fn make_cipher(key: &[u8]) -> AesGcmCipher {
584        use super::super::SecureKey;
585        let secure_key = SecureKey::new(key).unwrap();
586        AesGcmCipher::new(secure_key)
587    }
588
589    /// Helper to create an in-memory FileRead from bytes.
590    fn memory_reader(data: Vec<u8>) -> Box<dyn FileRead> {
591        Box::new(MemoryFileRead(Bytes::from(data)))
592    }
593
594    /// Simple in-memory FileRead for tests.
595    struct MemoryFileRead(Bytes);
596
597    #[async_trait::async_trait]
598    impl FileRead for MemoryFileRead {
599        async fn read(&self, range: Range<u64>) -> Result<Bytes> {
600            let start = range.start as usize;
601            let end = range.end as usize;
602            if end > self.0.len() {
603                return Err(Error::new(
604                    ErrorKind::DataInvalid,
605                    format!(
606                        "Range {}..{} out of bounds for {} bytes",
607                        start,
608                        end,
609                        self.0.len()
610                    ),
611                ));
612            }
613            Ok(self.0.slice(start..end))
614        }
615    }
616
617    #[tokio::test]
618    async fn test_empty_file_roundtrip() {
619        let key = b"0123456789abcdef";
620        let aad_prefix = b"test-aad-prefix!";
621        let cipher = make_cipher(key);
622
623        let encrypted = encrypt_ags1(b"", &cipher, aad_prefix);
624
625        // Verify minimum length: header(8) + nonce(12) + tag(16) = 36
626        assert_eq!(encrypted.len(), MIN_STREAM_LENGTH as usize);
627
628        let reader = AesGcmFileRead::new(
629            memory_reader(encrypted.clone()),
630            Arc::new(make_cipher(key)),
631            aad_prefix.as_slice().into(),
632            encrypted.len() as u64,
633        )
634        .unwrap();
635
636        assert_eq!(reader.plaintext_length(), 0);
637
638        // Reading empty range should return empty bytes
639        let result = reader.read(0..0).await.unwrap();
640        assert!(result.is_empty());
641    }
642
643    #[tokio::test]
644    async fn test_small_file_roundtrip() {
645        let key = b"0123456789abcdef";
646        let aad_prefix = b"test-aad-prefix!";
647        let plaintext = b"Hello, Iceberg encryption!";
648        let cipher = make_cipher(key);
649
650        let encrypted = encrypt_ags1(plaintext, &cipher, aad_prefix);
651
652        let reader = AesGcmFileRead::new(
653            memory_reader(encrypted.clone()),
654            Arc::new(make_cipher(key)),
655            aad_prefix.as_slice().into(),
656            encrypted.len() as u64,
657        )
658        .unwrap();
659
660        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
661
662        // Read entire file
663        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
664        assert_eq!(&result[..], plaintext);
665    }
666
667    #[tokio::test]
668    async fn test_partial_read() {
669        let key = b"0123456789abcdef";
670        let aad_prefix = b"aad-prefix-here!";
671        let plaintext = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
672        let cipher = make_cipher(key);
673
674        let encrypted = encrypt_ags1(plaintext, &cipher, aad_prefix);
675
676        let reader = AesGcmFileRead::new(
677            memory_reader(encrypted.clone()),
678            Arc::new(make_cipher(key)),
679            aad_prefix.as_slice().into(),
680            encrypted.len() as u64,
681        )
682        .unwrap();
683
684        // Read a slice from the middle
685        let result = reader.read(10..20).await.unwrap();
686        assert_eq!(&result[..], &plaintext[10..20]);
687
688        // Read first byte
689        let result = reader.read(0..1).await.unwrap();
690        assert_eq!(&result[..], &plaintext[0..1]);
691
692        // Read last byte
693        let last = plaintext.len() as u64;
694        let result = reader.read(last - 1..last).await.unwrap();
695        assert_eq!(&result[..], &plaintext[plaintext.len() - 1..]);
696    }
697
698    #[tokio::test]
699    async fn test_multi_block_roundtrip() {
700        let key = b"0123456789abcdef";
701        let aad_prefix = b"multi-block-aad!";
702
703        // 1.5 blocks of data
704        let size = PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2;
705        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
706        let cipher = make_cipher(key);
707
708        let encrypted = encrypt_ags1(&plaintext, &cipher, aad_prefix);
709
710        let reader = AesGcmFileRead::new(
711            memory_reader(encrypted.clone()),
712            Arc::new(make_cipher(key)),
713            aad_prefix.as_slice().into(),
714            encrypted.len() as u64,
715        )
716        .unwrap();
717
718        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
719
720        // Read entire file
721        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
722        assert_eq!(&result[..], &plaintext[..]);
723    }
724
725    #[tokio::test]
726    async fn test_cross_block_read() {
727        let key = b"0123456789abcdef";
728        let aad_prefix = b"cross-block-aad!";
729
730        // 2.5 blocks of data
731        let size = PLAIN_BLOCK_SIZE as usize * 2 + PLAIN_BLOCK_SIZE as usize / 2;
732        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
733        let cipher = make_cipher(key);
734
735        let encrypted = encrypt_ags1(&plaintext, &cipher, aad_prefix);
736
737        let reader = AesGcmFileRead::new(
738            memory_reader(encrypted.clone()),
739            Arc::new(make_cipher(key)),
740            aad_prefix.as_slice().into(),
741            encrypted.len() as u64,
742        )
743        .unwrap();
744
745        // Read across block boundary (last 100 bytes of block 0 + first 100 bytes of block 1)
746        let boundary = PLAIN_BLOCK_SIZE as u64;
747        let result = reader.read(boundary - 100..boundary + 100).await.unwrap();
748        assert_eq!(
749            &result[..],
750            &plaintext[(boundary - 100) as usize..(boundary + 100) as usize]
751        );
752
753        // Read across two block boundaries (spans blocks 0, 1, and 2)
754        let result = reader.read(boundary - 50..boundary * 2 + 50).await.unwrap();
755        assert_eq!(
756            &result[..],
757            &plaintext[(boundary - 50) as usize..(boundary * 2 + 50) as usize]
758        );
759    }
760
761    #[tokio::test]
762    async fn test_exact_block_size() {
763        let key = b"0123456789abcdef";
764        let aad_prefix = b"exact-block-aad!";
765
766        // Exactly 1 block
767        let plaintext: Vec<u8> = (0..PLAIN_BLOCK_SIZE as usize)
768            .map(|i| (i % 256) as u8)
769            .collect();
770        let cipher = make_cipher(key);
771
772        let encrypted = encrypt_ags1(&plaintext, &cipher, aad_prefix);
773
774        let reader = AesGcmFileRead::new(
775            memory_reader(encrypted.clone()),
776            Arc::new(make_cipher(key)),
777            aad_prefix.as_slice().into(),
778            encrypted.len() as u64,
779        )
780        .unwrap();
781
782        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);
783
784        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
785        assert_eq!(&result[..], &plaintext[..]);
786    }
787
788    #[tokio::test]
789    async fn test_block_size_plus_one() {
790        let key = b"0123456789abcdef";
791        let aad_prefix = b"block-plus-one!!";
792
793        // 1 block + 1 byte
794        let size = PLAIN_BLOCK_SIZE as usize + 1;
795        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
796        let cipher = make_cipher(key);
797
798        let encrypted = encrypt_ags1(&plaintext, &cipher, aad_prefix);
799
800        let reader = AesGcmFileRead::new(
801            memory_reader(encrypted.clone()),
802            Arc::new(make_cipher(key)),
803            aad_prefix.as_slice().into(),
804            encrypted.len() as u64,
805        )
806        .unwrap();
807
808        assert_eq!(reader.plaintext_length(), size as u64);
809
810        // Read the last byte (in block 1)
811        let result = reader.read(size as u64 - 1..size as u64).await.unwrap();
812        assert_eq!(result[0], plaintext[size - 1]);
813
814        // Read all
815        let result = reader.read(0..size as u64).await.unwrap();
816        assert_eq!(&result[..], &plaintext[..]);
817    }
818
819    #[tokio::test]
820    async fn test_block_size_minus_one() {
821        let key = b"0123456789abcdef";
822        let aad_prefix = b"block-minus-one!";
823
824        // 1 block - 1 byte
825        let size = PLAIN_BLOCK_SIZE as usize - 1;
826        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
827        let cipher = make_cipher(key);
828
829        let encrypted = encrypt_ags1(&plaintext, &cipher, aad_prefix);
830
831        let reader = AesGcmFileRead::new(
832            memory_reader(encrypted.clone()),
833            Arc::new(make_cipher(key)),
834            aad_prefix.as_slice().into(),
835            encrypted.len() as u64,
836        )
837        .unwrap();
838
839        assert_eq!(reader.plaintext_length(), size as u64);
840
841        let result = reader.read(0..size as u64).await.unwrap();
842        assert_eq!(&result[..], &plaintext[..]);
843    }
844
845    #[tokio::test]
846    async fn test_wrong_aad_fails() {
847        let key = b"0123456789abcdef";
848        let aad_prefix = b"correct-aad-here";
849        let plaintext = b"sensitive data here";
850        let cipher = make_cipher(key);
851
852        let encrypted = encrypt_ags1(plaintext, &cipher, aad_prefix);
853
854        // Try to decrypt with wrong AAD
855        let mut bad_aad = aad_prefix.to_vec();
856        bad_aad[0] ^= 0xFF;
857
858        let reader = AesGcmFileRead::new(
859            memory_reader(encrypted.clone()),
860            Arc::new(make_cipher(key)),
861            bad_aad.as_slice().into(),
862            encrypted.len() as u64,
863        )
864        .unwrap();
865
866        let result = reader.read(0..plaintext.len() as u64).await;
867        assert!(result.is_err(), "Decryption with wrong AAD should fail");
868    }
869
870    #[tokio::test]
871    async fn test_wrong_key_fails() {
872        let key = b"0123456789abcdef";
873        let wrong_key = b"fedcba9876543210";
874        let aad_prefix = b"test-aad-prefix!";
875        let plaintext = b"sensitive data";
876        let cipher = make_cipher(key);
877
878        let encrypted = encrypt_ags1(plaintext, &cipher, aad_prefix);
879
880        let reader = AesGcmFileRead::new(
881            memory_reader(encrypted.clone()),
882            Arc::new(make_cipher(wrong_key)),
883            aad_prefix.as_slice().into(),
884            encrypted.len() as u64,
885        )
886        .unwrap();
887
888        let result = reader.read(0..plaintext.len() as u64).await;
889        assert!(result.is_err(), "Decryption with wrong key should fail");
890    }
891
892    #[tokio::test]
893    async fn test_out_of_bounds_read() {
894        let key = b"0123456789abcdef";
895        let aad_prefix = b"test-aad-prefix!";
896        let plaintext = b"short data";
897        let cipher = make_cipher(key);
898
899        let encrypted = encrypt_ags1(plaintext, &cipher, aad_prefix);
900
901        let reader = AesGcmFileRead::new(
902            memory_reader(encrypted.clone()),
903            Arc::new(make_cipher(key)),
904            aad_prefix.as_slice().into(),
905            encrypted.len() as u64,
906        )
907        .unwrap();
908
909        let result = reader.read(0..plaintext.len() as u64 + 1).await;
910        assert!(result.is_err(), "Reading past end should fail");
911    }
912
913    #[tokio::test]
914    async fn test_calculate_plaintext_length() {
915        // Empty file: header only (not valid per Java, but handled)
916        assert_eq!(
917            AesGcmFileRead::calculate_plaintext_length(GCM_STREAM_HEADER_LENGTH as u64).unwrap(),
918            0
919        );
920
921        // Empty file with one empty block: header(8) + nonce(12) + tag(16) = 36
922        assert_eq!(
923            AesGcmFileRead::calculate_plaintext_length(MIN_STREAM_LENGTH as u64).unwrap(),
924            0
925        );
926
927        // One full block: header(8) + cipher_block(1048604) = 1048612
928        let one_full = GCM_STREAM_HEADER_LENGTH as u64 + CIPHER_BLOCK_SIZE as u64;
929        assert_eq!(
930            AesGcmFileRead::calculate_plaintext_length(one_full).unwrap(),
931            PLAIN_BLOCK_SIZE as u64
932        );
933
934        // One full block + 1 byte: need partial second block
935        // Second block = nonce(12) + 1 byte ciphertext + tag(16) = 29
936        let one_full_plus_one = one_full + NONCE_LENGTH as u64 + 1 + GCM_TAG_LENGTH as u64;
937        assert_eq!(
938            AesGcmFileRead::calculate_plaintext_length(one_full_plus_one).unwrap(),
939            PLAIN_BLOCK_SIZE as u64 + 1
940        );
941    }
942
943    #[tokio::test]
944    async fn test_stream_block_aad() {
945        // With prefix
946        let aad = stream_block_aad(b"prefix", 0);
947        assert_eq!(&aad[..6], b"prefix");
948        assert_eq!(&aad[6..], &0u32.to_le_bytes());
949
950        let aad = stream_block_aad(b"prefix", 1);
951        assert_eq!(&aad[..6], b"prefix");
952        assert_eq!(&aad[6..], &1u32.to_le_bytes());
953
954        // Without prefix
955        let aad = stream_block_aad(b"", 42);
956        assert_eq!(&aad[..], &42u32.to_le_bytes());
957    }
958
959    #[tokio::test]
960    async fn test_encrypted_file_too_short() {
961        let result = AesGcmFileRead::new(
962            memory_reader(vec![0; 4]),
963            Arc::new(make_cipher(b"0123456789abcdef")),
964            [].into(),
965            4,
966        );
967        assert!(result.is_err());
968    }
969
970    // --- AesGcmFileWrite tests ---
971
972    /// Shared-buffer FileWrite for testing AesGcmFileWrite output.
973    struct SharedMemoryWrite {
974        buffer: std::sync::Arc<std::sync::Mutex<Vec<u8>>>,
975    }
976
977    /// FileWrite that fails after a configured number of successful writes.
978    struct FailingFileWrite {
979        writes_before_failure: usize,
980        write_count: usize,
981    }
982
983    #[async_trait::async_trait]
984    impl FileWrite for FailingFileWrite {
985        async fn write(&mut self, _bs: Bytes) -> Result<()> {
986            if self.write_count >= self.writes_before_failure {
987                return Err(Error::new(ErrorKind::Unexpected, "simulated write failure"));
988            }
989            self.write_count += 1;
990            Ok(())
991        }
992
993        async fn close(&mut self) -> Result<()> {
994            Ok(())
995        }
996    }
997
998    #[async_trait::async_trait]
999    impl FileWrite for SharedMemoryWrite {
1000        async fn write(&mut self, bs: Bytes) -> Result<()> {
1001            self.buffer.lock().unwrap().extend_from_slice(&bs);
1002            Ok(())
1003        }
1004
1005        async fn close(&mut self) -> Result<()> {
1006            Ok(())
1007        }
1008    }
1009
1010    /// Helper: one-shot encrypt through AesGcmFileWrite, return encrypted bytes.
1011    async fn write_through_ags1(plaintext: &[u8], key: &[u8], aad_prefix: &[u8]) -> Vec<u8> {
1012        let buffer = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
1013        let inner: Box<dyn FileWrite> = Box::new(SharedMemoryWrite {
1014            buffer: buffer.clone(),
1015        });
1016        let cipher = Arc::new(make_cipher(key));
1017        let mut writer = AesGcmFileWrite::new(inner, cipher, aad_prefix.to_vec());
1018
1019        writer.write(Bytes::from(plaintext.to_vec())).await.unwrap();
1020        writer.close().await.unwrap();
1021
1022        buffer.lock().unwrap().clone()
1023    }
1024
1025    #[tokio::test]
1026    async fn test_write_empty_roundtrip() {
1027        let key = b"0123456789abcdef";
1028        let aad_prefix = b"test-aad-prefix!";
1029
1030        let encrypted = write_through_ags1(b"", key, aad_prefix).await;
1031
1032        // Should produce header + one empty encrypted block
1033        assert_eq!(encrypted.len(), MIN_STREAM_LENGTH as usize);
1034
1035        let reader = AesGcmFileRead::new(
1036            memory_reader(encrypted.clone()),
1037            Arc::new(make_cipher(key)),
1038            aad_prefix.as_slice().into(),
1039            encrypted.len() as u64,
1040        )
1041        .unwrap();
1042
1043        assert_eq!(reader.plaintext_length(), 0);
1044    }
1045
1046    #[tokio::test]
1047    async fn test_write_small_roundtrip() {
1048        let key = b"0123456789abcdef";
1049        let aad_prefix = b"test-aad-prefix!";
1050        let plaintext = b"Hello, Iceberg encryption!";
1051
1052        let encrypted = write_through_ags1(plaintext, key, aad_prefix).await;
1053
1054        let reader = AesGcmFileRead::new(
1055            memory_reader(encrypted.clone()),
1056            Arc::new(make_cipher(key)),
1057            aad_prefix.as_slice().into(),
1058            encrypted.len() as u64,
1059        )
1060        .unwrap();
1061
1062        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
1063        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
1064        assert_eq!(&result[..], plaintext);
1065    }
1066
1067    #[tokio::test]
1068    async fn test_write_multi_block_roundtrip() {
1069        let key = b"0123456789abcdef";
1070        let aad_prefix = b"multi-block-aad!";
1071
1072        // 1.5 blocks of data
1073        let size = PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2;
1074        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
1075
1076        let encrypted = write_through_ags1(&plaintext, key, aad_prefix).await;
1077
1078        let reader = AesGcmFileRead::new(
1079            memory_reader(encrypted.clone()),
1080            Arc::new(make_cipher(key)),
1081            aad_prefix.as_slice().into(),
1082            encrypted.len() as u64,
1083        )
1084        .unwrap();
1085
1086        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
1087        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
1088        assert_eq!(&result[..], &plaintext[..]);
1089    }
1090
1091    #[tokio::test]
1092    async fn test_write_cross_block_accumulation() {
1093        let key = b"0123456789abcdef";
1094        let aad_prefix = b"cross-block-aad!";
1095
1096        let buffer = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
1097        let inner: Box<dyn FileWrite> = Box::new(SharedMemoryWrite {
1098            buffer: buffer.clone(),
1099        });
1100        let cipher = Arc::new(make_cipher(key));
1101        let mut writer = AesGcmFileWrite::new(inner, cipher, aad_prefix.to_vec());
1102
1103        // Write 1.5 blocks in 1000-byte chunks
1104        let total_size = PLAIN_BLOCK_SIZE as usize + PLAIN_BLOCK_SIZE as usize / 2;
1105        let plaintext: Vec<u8> = (0..total_size).map(|i| (i % 256) as u8).collect();
1106        let chunk_size = 1000;
1107        for chunk in plaintext.chunks(chunk_size) {
1108            writer.write(Bytes::from(chunk.to_vec())).await.unwrap();
1109        }
1110        writer.close().await.unwrap();
1111
1112        let encrypted = buffer.lock().unwrap().clone();
1113
1114        let reader = AesGcmFileRead::new(
1115            memory_reader(encrypted.clone()),
1116            Arc::new(make_cipher(key)),
1117            aad_prefix.as_slice().into(),
1118            encrypted.len() as u64,
1119        )
1120        .unwrap();
1121
1122        assert_eq!(reader.plaintext_length(), plaintext.len() as u64);
1123        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
1124        assert_eq!(&result[..], &plaintext[..]);
1125    }
1126
1127    #[tokio::test]
1128    async fn test_write_exact_block_size() {
1129        let key = b"0123456789abcdef";
1130        let aad_prefix = b"exact-block-aad!";
1131
1132        // Exactly 1 block
1133        let plaintext: Vec<u8> = (0..PLAIN_BLOCK_SIZE as usize)
1134            .map(|i| (i % 256) as u8)
1135            .collect();
1136
1137        let encrypted = write_through_ags1(&plaintext, key, aad_prefix).await;
1138
1139        let reader = AesGcmFileRead::new(
1140            memory_reader(encrypted.clone()),
1141            Arc::new(make_cipher(key)),
1142            aad_prefix.as_slice().into(),
1143            encrypted.len() as u64,
1144        )
1145        .unwrap();
1146
1147        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);
1148        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
1149        assert_eq!(&result[..], &plaintext[..]);
1150    }
1151
1152    #[tokio::test]
1153    async fn test_write_block_aligned_no_spurious_empty_block() {
1154        let key = b"0123456789abcdef";
1155        let aad_prefix = b"block-align-aad!";
1156
1157        // Write exactly one block of plaintext — close() should NOT add
1158        // a trailing empty encrypted block (28 bytes: 12-byte nonce + 16-byte tag).
1159        let plaintext: Vec<u8> = (0..PLAIN_BLOCK_SIZE as usize)
1160            .map(|i| (i % 256) as u8)
1161            .collect();
1162
1163        let encrypted_via_writer = write_through_ags1(&plaintext, key, aad_prefix).await;
1164        let encrypted_via_reference = encrypt_ags1(&plaintext, &make_cipher(key), aad_prefix);
1165
1166        // Both should be the same length — no extra 28-byte empty block
1167        assert_eq!(
1168            encrypted_via_writer.len(),
1169            encrypted_via_reference.len(),
1170            "Writer output should match reference encryption length (no spurious trailing block)"
1171        );
1172
1173        // Verify roundtrip
1174        let reader = AesGcmFileRead::new(
1175            memory_reader(encrypted_via_writer.clone()),
1176            Arc::new(make_cipher(key)),
1177            aad_prefix.as_slice().into(),
1178            encrypted_via_writer.len() as u64,
1179        )
1180        .unwrap();
1181
1182        assert_eq!(reader.plaintext_length(), PLAIN_BLOCK_SIZE as u64);
1183        let result = reader.read(0..PLAIN_BLOCK_SIZE as u64).await.unwrap();
1184        assert_eq!(&result[..], &plaintext[..]);
1185    }
1186
1187    #[tokio::test]
1188    async fn test_write_two_blocks_aligned_no_spurious_empty_block() {
1189        let key = b"0123456789abcdef";
1190        let aad_prefix = b"2blk-align-aad!!";
1191
1192        // Exactly 2 blocks
1193        let size = PLAIN_BLOCK_SIZE as usize * 2;
1194        let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
1195
1196        let encrypted_via_writer = write_through_ags1(&plaintext, key, aad_prefix).await;
1197        let encrypted_via_reference = encrypt_ags1(&plaintext, &make_cipher(key), aad_prefix);
1198
1199        assert_eq!(
1200            encrypted_via_writer.len(),
1201            encrypted_via_reference.len(),
1202            "Writer output should match reference encryption length (no spurious trailing block)"
1203        );
1204
1205        let reader = AesGcmFileRead::new(
1206            memory_reader(encrypted_via_writer.clone()),
1207            Arc::new(make_cipher(key)),
1208            aad_prefix.as_slice().into(),
1209            encrypted_via_writer.len() as u64,
1210        )
1211        .unwrap();
1212
1213        assert_eq!(reader.plaintext_length(), size as u64);
1214        let result = reader.read(0..size as u64).await.unwrap();
1215        assert_eq!(&result[..], &plaintext[..]);
1216    }
1217
1218    #[tokio::test]
1219    async fn test_write_poisoned_after_inner_write_failure() {
1220        let cipher = Arc::new(make_cipher(b"0123456789abcdef"));
1221        // Fail on the second write (first write is the header, second is block data)
1222        let inner: Box<dyn FileWrite> = Box::new(FailingFileWrite {
1223            writes_before_failure: 1,
1224            write_count: 0,
1225        });
1226        let mut writer = AesGcmFileWrite::new(inner, cipher, b"aad-prefix-here!".to_vec());
1227
1228        // First write triggers header (succeeds) + block encrypt+write (fails)
1229        let data = vec![0u8; PLAIN_BLOCK_SIZE as usize];
1230        let result = writer.write(Bytes::from(data)).await;
1231        assert!(result.is_err());
1232
1233        // Subsequent write should be rejected as poisoned
1234        let result = writer.write(Bytes::from(b"more data".to_vec())).await;
1235        assert!(result.is_err());
1236        assert!(
1237            result.unwrap_err().to_string().contains("poisoned"),
1238            "expected poisoned error"
1239        );
1240
1241        // Close should also be rejected
1242        let result = writer.close().await;
1243        assert!(result.is_err());
1244        assert!(
1245            result.unwrap_err().to_string().contains("poisoned"),
1246            "expected poisoned error on close"
1247        );
1248    }
1249}