iceberg/encryption/
file_decryptor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! File-level decryption helper for AGS1 stream-encrypted files.
19
20use std::fmt;
21use std::sync::Arc;
22
23use super::crypto::{AesGcmCipher, SecureKey};
24use super::stream::AesGcmFileRead;
25use crate::Result;
26use crate::io::FileRead;
27
28/// Holds the decryption material for a single encrypted file.
29///
30/// Created from a plaintext DEK and AAD prefix, then used to wrap
31/// an encrypted file reader for transparent decryption on read.
32pub struct AesGcmFileDecryptor {
33    cipher: Arc<AesGcmCipher>,
34    aad_prefix: Box<[u8]>,
35}
36
37impl AesGcmFileDecryptor {
38    /// Creates a new `AesGcmFileDecryptor` from a plaintext DEK and AAD prefix.
39    pub fn new(dek: &[u8], aad_prefix: impl Into<Box<[u8]>>) -> Result<Self> {
40        let key = SecureKey::new(dek)?;
41        let cipher = Arc::new(AesGcmCipher::new(key));
42        Ok(Self {
43            cipher,
44            aad_prefix: aad_prefix.into(),
45        })
46    }
47
48    /// Wraps a raw encrypted-file reader in a decrypting [`AesGcmFileRead`].
49    pub fn wrap_reader(
50        &self,
51        reader: Box<dyn FileRead>,
52        encrypted_file_length: u64,
53    ) -> Result<Box<dyn FileRead>> {
54        let decrypting = AesGcmFileRead::new(
55            reader,
56            Arc::clone(&self.cipher),
57            self.aad_prefix.clone(),
58            encrypted_file_length,
59        )?;
60        Ok(Box::new(decrypting))
61    }
62
63    /// Calculates the plaintext length from an encrypted file's total length.
64    pub fn plaintext_length(&self, encrypted_file_length: u64) -> Result<u64> {
65        AesGcmFileRead::calculate_plaintext_length(encrypted_file_length)
66    }
67}
68
69impl fmt::Debug for AesGcmFileDecryptor {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("AesGcmFileDecryptor")
72            .field("aad_prefix_len", &self.aad_prefix.len())
73            .finish_non_exhaustive()
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;

    use super::*;
    use crate::encryption::AesGcmFileEncryptor;
    use crate::io::FileWrite;

    /// 16-byte DEK shared by the tests.
    const KEY: &[u8] = b"0123456789abcdef";
    /// AAD prefix shared by the tests.
    const AAD_PREFIX: &[u8] = b"test-aad-prefix!";

    /// In-memory `FileRead` backed by an immutable byte buffer.
    struct MemoryFileRead(Bytes);

    #[async_trait::async_trait]
    impl FileRead for MemoryFileRead {
        async fn read(&self, range: Range<u64>) -> Result<Bytes> {
            Ok(self.0.slice(range.start as usize..range.end as usize))
        }
    }

    /// In-memory `FileWrite` that appends into a shared buffer, so the test
    /// can inspect the written bytes afterwards.
    struct MemoryFileWrite {
        buffer: Arc<Mutex<Vec<u8>>>,
    }

    #[async_trait::async_trait]
    impl FileWrite for MemoryFileWrite {
        async fn write(&mut self, bs: Bytes) -> Result<()> {
            self.buffer.lock().unwrap().extend_from_slice(&bs);
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Encrypts `plaintext` through `AesGcmFileEncryptor` and returns the raw
    /// encrypted file bytes.
    async fn encrypt(key: &[u8], aad_prefix: &[u8], plaintext: &[u8]) -> Vec<u8> {
        let encryptor = AesGcmFileEncryptor::new(key, aad_prefix).unwrap();
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let mut writer = encryptor.wrap_writer(Box::new(MemoryFileWrite {
            buffer: Arc::clone(&buffer),
        }));
        writer.write(Bytes::copy_from_slice(plaintext)).await.unwrap();
        writer.close().await.unwrap();
        // Bind first so the mutex guard temporary is dropped before `buffer`.
        let encrypted = buffer.lock().unwrap().clone();
        encrypted
    }

    #[tokio::test]
    async fn test_wrap_reader_roundtrip() {
        let plaintext = b"Hello from file decryptor!";
        let encrypted = encrypt(KEY, AAD_PREFIX, plaintext).await;
        let encrypted_len = encrypted.len() as u64;

        let decryptor = AesGcmFileDecryptor::new(KEY, AAD_PREFIX).unwrap();
        // The length helper must agree with what the real encryptor produced.
        assert_eq!(
            decryptor.plaintext_length(encrypted_len).unwrap(),
            plaintext.len() as u64
        );

        let reader = decryptor
            .wrap_reader(
                Box::new(MemoryFileRead(Bytes::from(encrypted))),
                encrypted_len,
            )
            .unwrap();

        let result = reader.read(0..plaintext.len() as u64).await.unwrap();
        assert_eq!(&result[..], plaintext);
    }

    #[tokio::test]
    async fn test_wrong_key_fails_to_decrypt() {
        let plaintext = b"secret payload";
        let encrypted = encrypt(KEY, AAD_PREFIX, plaintext).await;
        let encrypted_len = encrypted.len() as u64;

        // Same AAD prefix, different (but validly sized) key: GCM
        // authentication must reject the ciphertext.
        let decryptor =
            AesGcmFileDecryptor::new(b"fedcba9876543210".as_slice(), AAD_PREFIX).unwrap();
        let reader = decryptor
            .wrap_reader(
                Box::new(MemoryFileRead(Bytes::from(encrypted))),
                encrypted_len,
            )
            .unwrap();

        assert!(reader.read(0..plaintext.len() as u64).await.is_err());
    }

    #[test]
    fn test_invalid_key_length() {
        let result = AesGcmFileDecryptor::new(b"too-short", b"aad".as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn test_plaintext_length() {
        let decryptor = AesGcmFileDecryptor::new(KEY, b"aad".as_slice()).unwrap();
        // header(8) + nonce(12) + 10 bytes ciphertext + tag(16) = 46
        let encrypted_len = 8 + 12 + 10 + 16;
        let plain_len = decryptor.plaintext_length(encrypted_len).unwrap();
        assert_eq!(plain_len, 10);
    }
}