1use std::collections::HashMap;
26use std::ops::Range;
27use std::sync::{Arc, RwLock};
28
29use async_trait::async_trait;
30use bytes::Bytes;
31use serde::{Deserialize, Serialize};
32
33use crate::io::{
34 FileMetadata, FileRead, FileWrite, InputFile, OutputFile, Storage, StorageConfig,
35 StorageFactory,
36};
37use crate::{Error, ErrorKind, Result};
38
39#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct MemoryStorage {
62 #[serde(skip, default = "default_memory_data")]
63 data: Arc<RwLock<HashMap<String, Bytes>>>,
64}
65
66fn default_memory_data() -> Arc<RwLock<HashMap<String, Bytes>>> {
67 Arc::new(RwLock::new(HashMap::new()))
68}
69
70impl MemoryStorage {
71 pub fn new() -> Self {
73 Self {
74 data: Arc::new(RwLock::new(HashMap::new())),
75 }
76 }
77
78 pub(crate) fn normalize_path(path: &str) -> String {
86 let path = path.strip_prefix("memory://").unwrap_or(path);
88 let path = path.strip_prefix("memory:/").unwrap_or(path);
90 path.trim_start_matches('/').to_string()
92 }
93}
94
95#[async_trait]
96#[typetag::serde]
97impl Storage for MemoryStorage {
98 async fn exists(&self, path: &str) -> Result<bool> {
99 let normalized = Self::normalize_path(path);
100 let data = self.data.read().map_err(|e| {
101 Error::new(
102 ErrorKind::Unexpected,
103 format!("Failed to acquire read lock: {e}"),
104 )
105 })?;
106 Ok(data.contains_key(&normalized))
107 }
108
109 async fn metadata(&self, path: &str) -> Result<FileMetadata> {
110 let normalized = Self::normalize_path(path);
111 let data = self.data.read().map_err(|e| {
112 Error::new(
113 ErrorKind::Unexpected,
114 format!("Failed to acquire read lock: {e}"),
115 )
116 })?;
117 match data.get(&normalized) {
118 Some(bytes) => Ok(FileMetadata {
119 size: bytes.len() as u64,
120 }),
121 None => Err(Error::new(
122 ErrorKind::DataInvalid,
123 format!("File not found: {path}"),
124 )),
125 }
126 }
127
128 async fn read(&self, path: &str) -> Result<Bytes> {
129 let normalized = Self::normalize_path(path);
130 let data = self.data.read().map_err(|e| {
131 Error::new(
132 ErrorKind::Unexpected,
133 format!("Failed to acquire read lock: {e}"),
134 )
135 })?;
136 match data.get(&normalized) {
137 Some(bytes) => Ok(bytes.clone()),
138 None => Err(Error::new(
139 ErrorKind::DataInvalid,
140 format!("File not found: {path}"),
141 )),
142 }
143 }
144
145 async fn reader(&self, path: &str) -> Result<Box<dyn FileRead>> {
146 let normalized = Self::normalize_path(path);
147 let data = self.data.read().map_err(|e| {
148 Error::new(
149 ErrorKind::Unexpected,
150 format!("Failed to acquire read lock: {e}"),
151 )
152 })?;
153 match data.get(&normalized) {
154 Some(bytes) => Ok(Box::new(MemoryFileRead::new(bytes.clone()))),
155 None => Err(Error::new(
156 ErrorKind::DataInvalid,
157 format!("File not found: {path}"),
158 )),
159 }
160 }
161
162 async fn write(&self, path: &str, bs: Bytes) -> Result<()> {
163 let normalized = Self::normalize_path(path);
164 let mut data = self.data.write().map_err(|e| {
165 Error::new(
166 ErrorKind::Unexpected,
167 format!("Failed to acquire write lock: {e}"),
168 )
169 })?;
170 data.insert(normalized, bs);
171 Ok(())
172 }
173
174 async fn writer(&self, path: &str) -> Result<Box<dyn FileWrite>> {
175 let normalized = Self::normalize_path(path);
176 Ok(Box::new(MemoryFileWrite::new(
177 self.data.clone(),
178 normalized,
179 )))
180 }
181
182 async fn delete(&self, path: &str) -> Result<()> {
183 let normalized = Self::normalize_path(path);
184 let mut data = self.data.write().map_err(|e| {
185 Error::new(
186 ErrorKind::Unexpected,
187 format!("Failed to acquire write lock: {e}"),
188 )
189 })?;
190 data.remove(&normalized);
191 Ok(())
192 }
193
194 async fn delete_prefix(&self, path: &str) -> Result<()> {
195 let normalized = Self::normalize_path(path);
196 let prefix = if normalized.ends_with('/') {
197 normalized
198 } else {
199 format!("{normalized}/")
200 };
201
202 let mut data = self.data.write().map_err(|e| {
203 Error::new(
204 ErrorKind::Unexpected,
205 format!("Failed to acquire write lock: {e}"),
206 )
207 })?;
208
209 let keys_to_remove: Vec<String> = data
211 .keys()
212 .filter(|k| k.starts_with(&prefix))
213 .cloned()
214 .collect();
215
216 for key in keys_to_remove {
217 data.remove(&key);
218 }
219
220 Ok(())
221 }
222
223 fn new_input(&self, path: &str) -> Result<InputFile> {
224 Ok(InputFile::new(Arc::new(self.clone()), path.to_string()))
225 }
226
227 fn new_output(&self, path: &str) -> Result<OutputFile> {
228 Ok(OutputFile::new(Arc::new(self.clone()), path.to_string()))
229 }
230}
231
232#[derive(Clone, Debug, Default, Serialize, Deserialize)]
238pub struct MemoryStorageFactory;
239
240#[typetag::serde]
241impl StorageFactory for MemoryStorageFactory {
242 fn build(&self, _config: &StorageConfig) -> Result<Arc<dyn Storage>> {
243 Ok(Arc::new(MemoryStorage::new()))
244 }
245}
246
247#[derive(Debug)]
249pub struct MemoryFileRead {
250 data: Bytes,
251}
252
253impl MemoryFileRead {
254 pub fn new(data: Bytes) -> Self {
256 Self { data }
257 }
258}
259
260#[async_trait]
261impl FileRead for MemoryFileRead {
262 async fn read(&self, range: Range<u64>) -> Result<Bytes> {
263 let start = range.start as usize;
264 let end = range.end as usize;
265
266 if start > self.data.len() || end > self.data.len() {
267 return Err(Error::new(
268 ErrorKind::DataInvalid,
269 format!(
270 "Range {}..{} is out of bounds for data of length {}",
271 start,
272 end,
273 self.data.len()
274 ),
275 ));
276 }
277
278 Ok(self.data.slice(start..end))
279 }
280}
281
282#[derive(Debug)]
288pub struct MemoryFileWrite {
289 data: Arc<RwLock<HashMap<String, Bytes>>>,
290 path: String,
291 buffer: Vec<u8>,
292 closed: bool,
293}
294
295impl MemoryFileWrite {
296 pub fn new(data: Arc<RwLock<HashMap<String, Bytes>>>, path: String) -> Self {
298 Self {
299 data,
300 path,
301 buffer: Vec::new(),
302 closed: false,
303 }
304 }
305}
306
307#[async_trait]
308impl FileWrite for MemoryFileWrite {
309 async fn write(&mut self, bs: Bytes) -> Result<()> {
310 if self.closed {
311 return Err(Error::new(
312 ErrorKind::DataInvalid,
313 "Cannot write to closed file",
314 ));
315 }
316 self.buffer.extend_from_slice(&bs);
317 Ok(())
318 }
319
320 async fn close(&mut self) -> Result<()> {
321 if self.closed {
322 return Err(Error::new(ErrorKind::DataInvalid, "File already closed"));
323 }
324
325 let mut data = self.data.write().map_err(|e| {
326 Error::new(
327 ErrorKind::Unexpected,
328 format!("Failed to acquire write lock: {e}"),
329 )
330 })?;
331
332 data.insert(
333 self.path.clone(),
334 Bytes::from(std::mem::take(&mut self.buffer)),
335 );
336 self.closed = true;
337 Ok(())
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_path() {
        assert_eq!(
            MemoryStorage::normalize_path("memory://path/to/file"),
            "path/to/file"
        );
        assert_eq!(
            MemoryStorage::normalize_path("memory:/path/to/file"),
            "path/to/file"
        );
        assert_eq!(
            MemoryStorage::normalize_path("/path/to/file"),
            "path/to/file"
        );
        assert_eq!(
            MemoryStorage::normalize_path("path/to/file"),
            "path/to/file"
        );
        assert_eq!(
            MemoryStorage::normalize_path("///path/to/file"),
            "path/to/file"
        );
        assert_eq!(
            MemoryStorage::normalize_path("memory:///path/to/file"),
            "path/to/file"
        );
    }

    #[tokio::test]
    async fn test_memory_storage_write_read() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let read_content = storage.read(path).await.unwrap();
        assert_eq!(read_content, content);
    }

    #[tokio::test]
    async fn test_memory_storage_exists() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        assert!(!storage.exists(path).await.unwrap());

        storage.write(path, Bytes::from("test")).await.unwrap();

        assert!(storage.exists(path).await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_metadata() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let metadata = storage.metadata(path).await.unwrap();
        assert_eq!(metadata.size, content.len() as u64);
    }

    #[tokio::test]
    async fn test_memory_storage_delete() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("test")).await.unwrap();
        assert!(storage.exists(path).await.unwrap());

        storage.delete(path).await.unwrap();
        assert!(!storage.exists(path).await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_delete_prefix() {
        let storage = MemoryStorage::new();

        storage
            .write("memory://dir/file1.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://dir/file2.txt", Bytes::from("2"))
            .await
            .unwrap();
        storage
            .write("memory://other/file.txt", Bytes::from("3"))
            .await
            .unwrap();

        storage.delete_prefix("memory://dir").await.unwrap();

        assert!(!storage.exists("memory://dir/file1.txt").await.unwrap());
        assert!(!storage.exists("memory://dir/file2.txt").await.unwrap());
        assert!(storage.exists("memory://other/file.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_reader() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        let read_content = reader.read(0..content.len() as u64).await.unwrap();
        assert_eq!(read_content, content);

        let partial = reader.read(0..5).await.unwrap();
        assert_eq!(partial, Bytes::from("Hello"));
    }

    #[tokio::test]
    async fn test_memory_storage_writer() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.write(Bytes::from("Hello, ")).await.unwrap();
        writer.write(Bytes::from("World!")).await.unwrap();
        writer.close().await.unwrap();

        let content = storage.read(path).await.unwrap();
        assert_eq!(content, Bytes::from("Hello, World!"));
    }

    #[tokio::test]
    async fn test_memory_file_write_double_close() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.write(Bytes::from("test")).await.unwrap();
        writer.close().await.unwrap();

        let result = writer.close().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_memory_file_write_after_close() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.close().await.unwrap();

        let result = writer.write(Bytes::from("test")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_memory_file_read_out_of_bounds() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello");

        storage.write(path, content).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        let result = reader.read(0..100).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_memory_storage_serialization() {
        let storage = MemoryStorage::new();

        let serialized = serde_json::to_string(&storage).unwrap();

        let deserialized: MemoryStorage = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.data.read().unwrap().is_empty());
    }

    #[test]
    fn test_memory_storage_factory() {
        let factory = MemoryStorageFactory;
        let config = StorageConfig::new();
        let storage = factory.build(&config).unwrap();

        assert!(format!("{storage:?}").contains("MemoryStorage"));
    }

    #[test]
    fn test_memory_storage_factory_serialization() {
        let factory = MemoryStorageFactory;

        let serialized = serde_json::to_string(&factory).unwrap();

        let deserialized: MemoryStorageFactory = serde_json::from_str(&serialized).unwrap();

        let config = StorageConfig::new();
        let storage = deserialized.build(&config).unwrap();
        assert!(format!("{storage:?}").contains("MemoryStorage"));
    }

    #[tokio::test]
    async fn test_path_normalization_consistency() {
        let storage = MemoryStorage::new();
        let content = Bytes::from("test content");

        storage
            .write("memory://path/to/file", content.clone())
            .await
            .unwrap();

        assert_eq!(
            storage.read("memory://path/to/file").await.unwrap(),
            content
        );
        assert_eq!(storage.read("memory:/path/to/file").await.unwrap(), content);
        assert_eq!(storage.read("/path/to/file").await.unwrap(), content);
        assert_eq!(storage.read("path/to/file").await.unwrap(), content);
    }

    // Lookups of a path that was never written must fail rather than
    // return empty data.
    #[tokio::test]
    async fn test_memory_storage_missing_file() {
        let storage = MemoryStorage::new();
        let path = "memory://missing/file.txt";

        assert!(storage.read(path).await.is_err());
        assert!(storage.metadata(path).await.is_err());
        assert!(storage.reader(path).await.is_err());
    }

    #[tokio::test]
    async fn test_memory_storage_overwrite() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("first")).await.unwrap();
        storage.write(path, Bytes::from("second")).await.unwrap();

        assert_eq!(storage.read(path).await.unwrap(), Bytes::from("second"));
    }

    // Clones share the same file table through the inner `Arc`.
    #[tokio::test]
    async fn test_memory_storage_clone_shares_data() {
        let storage = MemoryStorage::new();
        let handle = storage.clone();
        let path = "memory://test/file.txt";

        handle.write(path, Bytes::from("shared")).await.unwrap();

        assert_eq!(storage.read(path).await.unwrap(), Bytes::from("shared"));
    }

    // `delete_prefix("dir")` must behave like a directory delete and not
    // remove siblings that merely share the string prefix.
    #[tokio::test]
    async fn test_memory_storage_delete_prefix_keeps_similar_names() {
        let storage = MemoryStorage::new();

        storage
            .write("memory://dir/file.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://dir2/file.txt", Bytes::from("2"))
            .await
            .unwrap();

        storage.delete_prefix("memory://dir").await.unwrap();

        assert!(!storage.exists("memory://dir/file.txt").await.unwrap());
        assert!(storage.exists("memory://dir2/file.txt").await.unwrap());
    }

    // Writer output is buffered and only published on `close`.
    #[tokio::test]
    async fn test_memory_file_write_visible_only_after_close() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.write(Bytes::from("pending")).await.unwrap();
        assert!(!storage.exists(path).await.unwrap());

        writer.close().await.unwrap();
        assert_eq!(storage.read(path).await.unwrap(), Bytes::from("pending"));
    }

    #[tokio::test]
    async fn test_memory_file_read_empty_range_at_end() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("Hello")).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        assert!(reader.read(5..5).await.unwrap().is_empty());
    }
}