1use std::collections::HashMap;
26use std::ops::Range;
27use std::sync::{Arc, RwLock};
28
29use async_trait::async_trait;
30use bytes::Bytes;
31use futures::StreamExt;
32use futures::stream::BoxStream;
33use serde::{Deserialize, Serialize};
34
35use crate::io::{
36 FileMetadata, FileRead, FileWrite, InputFile, OutputFile, Storage, StorageConfig,
37 StorageFactory,
38};
39use crate::{Error, ErrorKind, Result};
40
41#[derive(Debug, Clone, Default, Serialize, Deserialize)]
63pub struct MemoryStorage {
64 #[serde(skip, default = "default_memory_data")]
65 data: Arc<RwLock<HashMap<String, Bytes>>>,
66}
67
68fn default_memory_data() -> Arc<RwLock<HashMap<String, Bytes>>> {
69 Arc::new(RwLock::new(HashMap::new()))
70}
71
72impl MemoryStorage {
73 pub fn new() -> Self {
75 Self {
76 data: Arc::new(RwLock::new(HashMap::new())),
77 }
78 }
79
80 pub(crate) fn normalize_path(path: &str) -> String {
88 let path = path.strip_prefix("memory://").unwrap_or(path);
90 let path = path.strip_prefix("memory:/").unwrap_or(path);
92 path.trim_start_matches('/').to_string()
94 }
95}
96
97#[async_trait]
98#[typetag::serde]
99impl Storage for MemoryStorage {
100 async fn exists(&self, path: &str) -> Result<bool> {
101 let normalized = Self::normalize_path(path);
102 let data = self.data.read().map_err(|e| {
103 Error::new(
104 ErrorKind::Unexpected,
105 format!("Failed to acquire read lock: {e}"),
106 )
107 })?;
108 Ok(data.contains_key(&normalized))
109 }
110
111 async fn metadata(&self, path: &str) -> Result<FileMetadata> {
112 let normalized = Self::normalize_path(path);
113 let data = self.data.read().map_err(|e| {
114 Error::new(
115 ErrorKind::Unexpected,
116 format!("Failed to acquire read lock: {e}"),
117 )
118 })?;
119 match data.get(&normalized) {
120 Some(bytes) => Ok(FileMetadata {
121 size: bytes.len() as u64,
122 }),
123 None => Err(Error::new(
124 ErrorKind::DataInvalid,
125 format!("File not found: {path}"),
126 )),
127 }
128 }
129
130 async fn read(&self, path: &str) -> Result<Bytes> {
131 let normalized = Self::normalize_path(path);
132 let data = self.data.read().map_err(|e| {
133 Error::new(
134 ErrorKind::Unexpected,
135 format!("Failed to acquire read lock: {e}"),
136 )
137 })?;
138 match data.get(&normalized) {
139 Some(bytes) => Ok(bytes.clone()),
140 None => Err(Error::new(
141 ErrorKind::DataInvalid,
142 format!("File not found: {path}"),
143 )),
144 }
145 }
146
147 async fn reader(&self, path: &str) -> Result<Box<dyn FileRead>> {
148 let normalized = Self::normalize_path(path);
149 let data = self.data.read().map_err(|e| {
150 Error::new(
151 ErrorKind::Unexpected,
152 format!("Failed to acquire read lock: {e}"),
153 )
154 })?;
155 match data.get(&normalized) {
156 Some(bytes) => Ok(Box::new(MemoryFileRead::new(bytes.clone()))),
157 None => Err(Error::new(
158 ErrorKind::DataInvalid,
159 format!("File not found: {path}"),
160 )),
161 }
162 }
163
164 async fn write(&self, path: &str, bs: Bytes) -> Result<()> {
165 let normalized = Self::normalize_path(path);
166 let mut data = self.data.write().map_err(|e| {
167 Error::new(
168 ErrorKind::Unexpected,
169 format!("Failed to acquire write lock: {e}"),
170 )
171 })?;
172 data.insert(normalized, bs);
173 Ok(())
174 }
175
176 async fn writer(&self, path: &str) -> Result<Box<dyn FileWrite>> {
177 let normalized = Self::normalize_path(path);
178 Ok(Box::new(MemoryFileWrite::new(
179 self.data.clone(),
180 normalized,
181 )))
182 }
183
184 async fn delete(&self, path: &str) -> Result<()> {
185 let normalized = Self::normalize_path(path);
186 let mut data = self.data.write().map_err(|e| {
187 Error::new(
188 ErrorKind::Unexpected,
189 format!("Failed to acquire write lock: {e}"),
190 )
191 })?;
192 data.remove(&normalized);
193 Ok(())
194 }
195
196 async fn delete_prefix(&self, path: &str) -> Result<()> {
197 let normalized = Self::normalize_path(path);
198 let prefix = if normalized.ends_with('/') {
199 normalized
200 } else {
201 format!("{normalized}/")
202 };
203
204 let mut data = self.data.write().map_err(|e| {
205 Error::new(
206 ErrorKind::Unexpected,
207 format!("Failed to acquire write lock: {e}"),
208 )
209 })?;
210
211 let keys_to_remove: Vec<String> = data
213 .keys()
214 .filter(|k| k.starts_with(&prefix))
215 .cloned()
216 .collect();
217
218 for key in keys_to_remove {
219 data.remove(&key);
220 }
221
222 Ok(())
223 }
224
225 async fn delete_stream(&self, mut paths: BoxStream<'static, String>) -> Result<()> {
226 while let Some(path) = paths.next().await {
227 self.delete(&path).await?;
228 }
229 Ok(())
230 }
231
232 fn new_input(&self, path: &str) -> Result<InputFile> {
233 Ok(InputFile::new(Arc::new(self.clone()), path.to_string()))
234 }
235
236 fn new_output(&self, path: &str) -> Result<OutputFile> {
237 Ok(OutputFile::new(Arc::new(self.clone()), path.to_string()))
238 }
239}
240
241#[derive(Clone, Debug, Default, Serialize, Deserialize)]
247pub struct MemoryStorageFactory;
248
249#[typetag::serde]
250impl StorageFactory for MemoryStorageFactory {
251 fn build(&self, _config: &StorageConfig) -> Result<Arc<dyn Storage>> {
252 Ok(Arc::new(MemoryStorage::new()))
253 }
254}
255
/// `FileRead` implementation over an in-memory byte buffer.
#[derive(Debug)]
pub struct MemoryFileRead {
    // Full contents of the file; ranged reads return slices of this buffer.
    data: Bytes,
}
261
262impl MemoryFileRead {
263 pub fn new(data: Bytes) -> Self {
265 Self { data }
266 }
267}
268
269#[async_trait]
270impl FileRead for MemoryFileRead {
271 async fn read(&self, range: Range<u64>) -> Result<Bytes> {
272 let start = range.start as usize;
273 let end = range.end as usize;
274
275 if start > self.data.len() || end > self.data.len() {
276 return Err(Error::new(
277 ErrorKind::DataInvalid,
278 format!(
279 "Range {}..{} is out of bounds for data of length {}",
280 start,
281 end,
282 self.data.len()
283 ),
284 ));
285 }
286
287 Ok(self.data.slice(start..end))
288 }
289}
290
/// `FileWrite` implementation that accumulates written bytes in memory and
/// inserts them into a shared file map when closed.
#[derive(Debug)]
pub struct MemoryFileWrite {
    // File map the buffered bytes are published into on close
    // (`MemoryStorage::writer` passes its own shared map).
    data: Arc<RwLock<HashMap<String, Bytes>>>,
    // Key under which the buffered bytes are inserted on close.
    path: String,
    // Bytes written so far that have not yet been published to `data`.
    buffer: Vec<u8>,
    // Set once `close` has succeeded; later writes and closes are rejected.
    closed: bool,
}
303
304impl MemoryFileWrite {
305 pub fn new(data: Arc<RwLock<HashMap<String, Bytes>>>, path: String) -> Self {
307 Self {
308 data,
309 path,
310 buffer: Vec::new(),
311 closed: false,
312 }
313 }
314}
315
316#[async_trait]
317impl FileWrite for MemoryFileWrite {
318 async fn write(&mut self, bs: Bytes) -> Result<()> {
319 if self.closed {
320 return Err(Error::new(
321 ErrorKind::DataInvalid,
322 "Cannot write to closed file",
323 ));
324 }
325 self.buffer.extend_from_slice(&bs);
326 Ok(())
327 }
328
329 async fn close(&mut self) -> Result<()> {
330 if self.closed {
331 return Err(Error::new(ErrorKind::DataInvalid, "File already closed"));
332 }
333
334 let mut data = self.data.write().map_err(|e| {
335 Error::new(
336 ErrorKind::Unexpected,
337 format!("Failed to acquire write lock: {e}"),
338 )
339 })?;
340
341 data.insert(
342 self.path.clone(),
343 Bytes::from(std::mem::take(&mut self.buffer)),
344 );
345 self.closed = true;
346 Ok(())
347 }
348}
349
// Unit tests for the in-memory storage backend.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_path() {
        assert_eq!(
            MemoryStorage::normalize_path("memory://path/to/file"),
            "path/to/file"
        );

        assert_eq!(
            MemoryStorage::normalize_path("memory:/path/to/file"),
            "path/to/file"
        );

        assert_eq!(
            MemoryStorage::normalize_path("/path/to/file"),
            "path/to/file"
        );

        assert_eq!(
            MemoryStorage::normalize_path("path/to/file"),
            "path/to/file"
        );

        assert_eq!(
            MemoryStorage::normalize_path("///path/to/file"),
            "path/to/file"
        );

        assert_eq!(
            MemoryStorage::normalize_path("memory:///path/to/file"),
            "path/to/file"
        );

        // The bare scheme normalizes to the empty (root) key.
        assert_eq!(MemoryStorage::normalize_path("memory://"), "");
    }

    #[tokio::test]
    async fn test_memory_storage_write_read() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let read_content = storage.read(path).await.unwrap();
        assert_eq!(read_content, content);
    }

    #[tokio::test]
    async fn test_memory_storage_overwrite() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("first")).await.unwrap();
        storage.write(path, Bytes::from("second")).await.unwrap();

        // The second write replaces the first.
        assert_eq!(storage.read(path).await.unwrap(), Bytes::from("second"));
    }

    #[tokio::test]
    async fn test_memory_storage_missing_file() {
        let storage = MemoryStorage::new();
        let path = "memory://missing.txt";

        assert!(storage.read(path).await.is_err());
        assert!(storage.metadata(path).await.is_err());
        assert!(storage.reader(path).await.is_err());

        // Deleting a file that does not exist is not an error.
        storage.delete(path).await.unwrap();
    }

    #[tokio::test]
    async fn test_memory_storage_exists() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        assert!(!storage.exists(path).await.unwrap());

        storage.write(path, Bytes::from("test")).await.unwrap();

        assert!(storage.exists(path).await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_metadata() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let metadata = storage.metadata(path).await.unwrap();
        assert_eq!(metadata.size, content.len() as u64);
    }

    #[tokio::test]
    async fn test_memory_storage_delete() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("test")).await.unwrap();
        assert!(storage.exists(path).await.unwrap());

        storage.delete(path).await.unwrap();
        assert!(!storage.exists(path).await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_delete_prefix() {
        let storage = MemoryStorage::new();

        storage
            .write("memory://dir/file1.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://dir/file2.txt", Bytes::from("2"))
            .await
            .unwrap();
        storage
            .write("memory://other/file.txt", Bytes::from("3"))
            .await
            .unwrap();

        storage.delete_prefix("memory://dir").await.unwrap();

        assert!(!storage.exists("memory://dir/file1.txt").await.unwrap());
        assert!(!storage.exists("memory://dir/file2.txt").await.unwrap());

        assert!(storage.exists("memory://other/file.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_delete_prefix_sibling_not_matched() {
        let storage = MemoryStorage::new();

        storage
            .write("memory://dir/file.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://dir2/file.txt", Bytes::from("2"))
            .await
            .unwrap();

        storage.delete_prefix("memory://dir").await.unwrap();

        assert!(!storage.exists("memory://dir/file.txt").await.unwrap());
        // "dir" is treated as a directory, so "dir2/..." must survive.
        assert!(storage.exists("memory://dir2/file.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_delete_prefix_trailing_slash() {
        let storage = MemoryStorage::new();

        storage
            .write("memory://dir/file.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://other/file.txt", Bytes::from("2"))
            .await
            .unwrap();

        storage.delete_prefix("memory://dir/").await.unwrap();

        assert!(!storage.exists("memory://dir/file.txt").await.unwrap());
        assert!(storage.exists("memory://other/file.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_clone_shares_data() {
        let storage = MemoryStorage::new();
        let clone = storage.clone();

        clone
            .write("memory://shared.txt", Bytes::from("x"))
            .await
            .unwrap();

        // Clones share the same underlying file map.
        assert!(storage.exists("memory://shared.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_reader() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello, World!");

        storage.write(path, content.clone()).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        let read_content = reader.read(0..content.len() as u64).await.unwrap();
        assert_eq!(read_content, content);

        let partial = reader.read(0..5).await.unwrap();
        assert_eq!(partial, Bytes::from("Hello"));
    }

    #[tokio::test]
    async fn test_memory_file_read_empty_and_interior_ranges() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        storage.write(path, Bytes::from("Hello")).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        assert!(reader.read(0..0).await.unwrap().is_empty());
        // An empty range ending exactly at the data length is in bounds.
        assert!(reader.read(5..5).await.unwrap().is_empty());
        assert_eq!(reader.read(1..4).await.unwrap(), Bytes::from("ell"));
    }

    #[tokio::test]
    async fn test_memory_storage_writer() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.write(Bytes::from("Hello, ")).await.unwrap();
        writer.write(Bytes::from("World!")).await.unwrap();
        writer.close().await.unwrap();

        let content = storage.read(path).await.unwrap();
        assert_eq!(content, Bytes::from("Hello, World!"));
    }

    #[tokio::test]
    async fn test_memory_file_write_double_close() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.write(Bytes::from("test")).await.unwrap();
        writer.close().await.unwrap();

        let result = writer.close().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_memory_file_write_after_close() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";

        let mut writer = storage.writer(path).await.unwrap();
        writer.close().await.unwrap();

        let result = writer.write(Bytes::from("test")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_memory_file_read_out_of_bounds() {
        let storage = MemoryStorage::new();
        let path = "memory://test/file.txt";
        let content = Bytes::from("Hello");

        storage.write(path, content).await.unwrap();

        let reader = storage.reader(path).await.unwrap();
        let result = reader.read(0..100).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_memory_storage_serialization() {
        let storage = MemoryStorage::new();

        let serialized = serde_json::to_string(&storage).unwrap();

        let deserialized: MemoryStorage = serde_json::from_str(&serialized).unwrap();

        // File contents are not serialized; the round-tripped storage is empty.
        assert!(deserialized.data.read().unwrap().is_empty());
    }

    #[test]
    fn test_memory_storage_factory() {
        let factory = MemoryStorageFactory;
        let config = StorageConfig::new();
        let storage = factory.build(&config).unwrap();

        assert!(format!("{storage:?}").contains("MemoryStorage"));
    }

    #[test]
    fn test_memory_storage_factory_serialization() {
        let factory = MemoryStorageFactory;

        let serialized = serde_json::to_string(&factory).unwrap();

        let deserialized: MemoryStorageFactory = serde_json::from_str(&serialized).unwrap();

        let config = StorageConfig::new();
        let storage = deserialized.build(&config).unwrap();
        assert!(format!("{storage:?}").contains("MemoryStorage"));
    }

    #[tokio::test]
    async fn test_path_normalization_consistency() {
        let storage = MemoryStorage::new();
        let content = Bytes::from("test content");

        storage
            .write("memory://path/to/file", content.clone())
            .await
            .unwrap();

        // Every spelling of the same location resolves to the same entry.
        assert_eq!(
            storage.read("memory://path/to/file").await.unwrap(),
            content
        );
        assert_eq!(storage.read("memory:/path/to/file").await.unwrap(), content);
        assert_eq!(storage.read("/path/to/file").await.unwrap(), content);
        assert_eq!(storage.read("path/to/file").await.unwrap(), content);
    }

    #[tokio::test]
    async fn test_memory_storage_delete_stream() {
        use futures::stream;

        let storage = MemoryStorage::new();

        storage
            .write("memory://file1.txt", Bytes::from("1"))
            .await
            .unwrap();
        storage
            .write("memory://file2.txt", Bytes::from("2"))
            .await
            .unwrap();
        storage
            .write("memory://file3.txt", Bytes::from("3"))
            .await
            .unwrap();

        assert!(storage.exists("memory://file1.txt").await.unwrap());
        assert!(storage.exists("memory://file2.txt").await.unwrap());
        assert!(storage.exists("memory://file3.txt").await.unwrap());

        let paths = vec![
            "memory://file1.txt".to_string(),
            "memory://file2.txt".to_string(),
        ];
        let path_stream = stream::iter(paths).boxed();
        storage.delete_stream(path_stream).await.unwrap();

        assert!(!storage.exists("memory://file1.txt").await.unwrap());
        assert!(!storage.exists("memory://file2.txt").await.unwrap());

        assert!(storage.exists("memory://file3.txt").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_storage_delete_stream_empty() {
        use futures::stream;

        let storage = MemoryStorage::new();

        let path_stream = stream::iter(Vec::<String>::new()).boxed();
        storage.delete_stream(path_stream).await.unwrap();
    }
}