iceberg/encryption/kms/
memory.rs1use std::collections::HashMap;
24use std::fmt;
25use std::sync::{Arc, RwLock};
26
27use async_trait::async_trait;
28
29use super::KeyManagementClient;
30use crate::encryption::{AesGcmCipher, AesKeySize, SecureKey, SensitiveBytes};
31use crate::error::lock_error;
32use crate::{Error, ErrorKind, Result};
33
34#[derive(Clone, Default)]
52pub struct MemoryKeyManagementClient {
53 master_keys: Arc<RwLock<HashMap<String, SensitiveBytes>>>,
54 master_key_size: AesKeySize,
55}
56
57impl fmt::Debug for MemoryKeyManagementClient {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.debug_struct("MemoryKeyManagementClient")
60 .field("master_key_size", &self.master_key_size)
61 .field("key_count", &self.key_count())
62 .finish()
63 }
64}
65
66impl MemoryKeyManagementClient {
67 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn with_master_key_size(master_key_size: AesKeySize) -> Self {
74 Self {
75 master_keys: Arc::new(RwLock::new(HashMap::new())),
76 master_key_size,
77 }
78 }
79
80 pub fn add_master_key(&self, key_id: impl Into<String>) -> Result<()> {
82 let key = SecureKey::generate(self.master_key_size);
83 self.insert_key(key_id.into(), SensitiveBytes::new(key.as_bytes()))
84 }
85
86 pub fn add_master_key_bytes(
92 &self,
93 key_id: impl Into<String>,
94 key_bytes: SensitiveBytes,
95 ) -> Result<()> {
96 Self::check_key_length(&key_bytes)?;
97 self.insert_key(key_id.into(), key_bytes)
98 }
99
100 fn check_key_length(key_bytes: &SensitiveBytes) -> Result<()> {
102 SecureKey::new(key_bytes.as_bytes())?;
103 Ok(())
104 }
105
106 fn insert_key(&self, key_id: String, key: SensitiveBytes) -> Result<()> {
107 let mut keys = self.master_keys.write().map_err(lock_error)?;
108
109 if keys.contains_key(&key_id) {
110 return Err(Error::new(
111 ErrorKind::DataInvalid,
112 format!("Master key already exists: {key_id}"),
113 ));
114 }
115
116 keys.insert(key_id, key);
117 Ok(())
118 }
119
120 fn get_master_key(&self, key_id: &str) -> Result<SensitiveBytes> {
121 let keys = self.master_keys.read().map_err(lock_error)?;
122
123 keys.get(key_id).cloned().ok_or_else(|| {
124 Error::new(
125 ErrorKind::DataInvalid,
126 format!("Master key not found: {key_id}"),
127 )
128 })
129 }
130
131 pub fn key_count(&self) -> usize {
133 self.master_keys.read().map(|keys| keys.len()).unwrap_or(0)
134 }
135
136 pub fn has_key(&self, key_id: &str) -> bool {
138 self.master_keys
139 .read()
140 .map(|keys| keys.contains_key(key_id))
141 .unwrap_or(false)
142 }
143}
144
145#[async_trait]
146impl KeyManagementClient for MemoryKeyManagementClient {
147 async fn wrap_key(&self, key: &[u8], wrapping_key_id: &str) -> Result<Vec<u8>> {
148 let master_key_bytes = self.get_master_key(wrapping_key_id)?;
149 let master_key = SecureKey::new(master_key_bytes.as_bytes())?;
150 let cipher = AesGcmCipher::new(master_key);
151
152 cipher.encrypt(key, None)
153 }
154
155 async fn unwrap_key(
156 &self,
157 wrapped_key: &[u8],
158 wrapping_key_id: &str,
159 ) -> Result<SensitiveBytes> {
160 let master_key_bytes = self.get_master_key(wrapping_key_id)?;
161 let master_key = SecureKey::new(master_key_bytes.as_bytes())?;
162 let cipher = AesGcmCipher::new(master_key);
163
164 Ok(SensitiveBytes::new(cipher.decrypt(wrapped_key, None)?))
165 }
166
167 fn supports_key_generation(&self) -> bool {
168 false
169 }
170
171 async fn generate_key(&self, _wrapping_key_id: &str) -> Result<super::GeneratedKey> {
172 Err(Error::new(
173 ErrorKind::FeatureUnsupported,
174 "MemoryKeyManagementClient does not support server-side key generation",
175 ))
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_wrap_unwrap_roundtrip() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("master-1").unwrap();
        let dek = vec![0u8; 16];

        let wrapped = kms.wrap_key(&dek, "master-1").await.unwrap();
        let unwrapped = kms.unwrap_key(&wrapped, "master-1").await.unwrap();
        assert_eq!(unwrapped.as_bytes(), dek.as_slice());
    }

    #[tokio::test]
    async fn test_wrap_unknown_key_fails() {
        let kms = MemoryKeyManagementClient::new();
        let dek = vec![0u8; 16];

        let result = kms.wrap_key(&dek, "nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unwrap_unknown_key_fails() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("master-1").unwrap();
        let dek = vec![0u8; 16];

        let wrapped = kms.wrap_key(&dek, "master-1").await.unwrap();
        assert!(kms.unwrap_key(&wrapped, "nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_wrong_master_key_fails_unwrap() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("master-1").unwrap();
        kms.add_master_key("master-2").unwrap();
        let dek = vec![0u8; 16];

        let wrapped = kms.wrap_key(&dek, "master-1").await.unwrap();

        let result = kms.unwrap_key(&wrapped, "master-2").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unwrap_tampered_ciphertext_fails() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("master-1").unwrap();
        let dek = vec![3u8; 16];

        let mut wrapped = kms.wrap_key(&dek, "master-1").await.unwrap();
        // Flip a single bit; authenticated encryption must reject the modified blob.
        let last = wrapped.last_mut().expect("wrapped key is never empty");
        *last ^= 0x01;

        assert!(kms.unwrap_key(&wrapped, "master-1").await.is_err());
    }

    #[tokio::test]
    async fn test_does_not_support_key_generation() {
        let kms = MemoryKeyManagementClient::new();
        assert!(!kms.supports_key_generation());

        let err = kms.generate_key("master-1").await.unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::FeatureUnsupported));
    }

    #[tokio::test]
    async fn test_multiple_master_keys() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("master-1").unwrap();
        kms.add_master_key("master-2").unwrap();
        let dek1 = vec![1u8; 16];
        let dek2 = vec![2u8; 16];

        let wrapped1 = kms.wrap_key(&dek1, "master-1").await.unwrap();
        let wrapped2 = kms.wrap_key(&dek2, "master-2").await.unwrap();

        let unwrapped1 = kms.unwrap_key(&wrapped1, "master-1").await.unwrap();
        let unwrapped2 = kms.unwrap_key(&wrapped2, "master-2").await.unwrap();

        assert_eq!(unwrapped1.as_bytes(), dek1.as_slice());
        assert_eq!(unwrapped2.as_bytes(), dek2.as_slice());
    }

    #[tokio::test]
    async fn test_add_master_key() {
        let kms = MemoryKeyManagementClient::new();

        kms.add_master_key("my-key").unwrap();
        assert!(kms.has_key("my-key"));
        assert_eq!(kms.key_count(), 1);

        let err = kms.add_master_key("my-key").unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::DataInvalid));
    }

    #[tokio::test]
    async fn test_add_master_key_bytes() {
        let kms = MemoryKeyManagementClient::new();
        let key_bytes = SensitiveBytes::new([42u8; 16]);

        kms.add_master_key_bytes("my-key", key_bytes).unwrap();
        assert!(kms.has_key("my-key"));

        let dek = vec![7u8; 16];
        let wrapped = kms.wrap_key(&dek, "my-key").await.unwrap();
        let unwrapped = kms.unwrap_key(&wrapped, "my-key").await.unwrap();
        assert_eq!(unwrapped.as_bytes(), dek.as_slice());
    }

    #[tokio::test]
    async fn test_add_master_key_bytes_duplicate_keeps_original() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key("my-key").unwrap();
        let dek = vec![9u8; 16];
        let wrapped = kms.wrap_key(&dek, "my-key").await.unwrap();

        let result = kms.add_master_key_bytes("my-key", SensitiveBytes::new([1u8; 16]));
        assert!(result.is_err());
        assert_eq!(kms.key_count(), 1);

        // The rejected insert must not have replaced the original key material.
        let unwrapped = kms.unwrap_key(&wrapped, "my-key").await.unwrap();
        assert_eq!(unwrapped.as_bytes(), dek.as_slice());
    }

    #[tokio::test]
    async fn test_add_master_key_bytes_invalid_length() {
        let kms = MemoryKeyManagementClient::new();

        let result = kms.add_master_key_bytes("my-key", SensitiveBytes::new([0u8; 7]));
        assert!(result.is_err());
        assert!(!kms.has_key("my-key"));
    }

    #[tokio::test]
    async fn test_with_master_key_size() {
        let kms = MemoryKeyManagementClient::with_master_key_size(AesKeySize::Bits256);
        kms.add_master_key("master-256").unwrap();

        let dek = vec![0u8; 16];
        let wrapped = kms.wrap_key(&dek, "master-256").await.unwrap();
        let unwrapped = kms.unwrap_key(&wrapped, "master-256").await.unwrap();
        assert_eq!(unwrapped.as_bytes(), dek.as_slice());
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let kms1 = MemoryKeyManagementClient::new();
        let kms2 = kms1.clone();

        kms1.add_master_key("shared-key").unwrap();
        assert!(kms2.has_key("shared-key"));
    }

    #[test]
    fn test_debug_does_not_expose_key_material() {
        let kms = MemoryKeyManagementClient::new();
        kms.add_master_key_bytes("secret-id", SensitiveBytes::new([42u8; 16]))
            .unwrap();

        let rendered = format!("{kms:?}");
        assert!(rendered.contains("key_count: 1"));
        assert!(!rendered.contains("secret-id"));
        assert!(!rendered.contains("42, 42"));
    }
}