use std::fmt::{Debug, Formatter};

use arrow_array::RecordBatch;

use crate::io::{FileIO, OutputFile};
use crate::spec::{DataFileBuilder, PartitionKey, TableProperties};
use crate::writer::CurrentFileStatus;
use crate::writer::file_writer::location_generator::{FileNameGenerator, LocationGenerator};
use crate::writer::file_writer::{FileWriter, FileWriterBuilder};
use crate::{Error, ErrorKind, Result};

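/// Builder for [`RollingFileWriter`].
///
/// Wraps an inner [`FileWriterBuilder`] and produces writers that roll over to a
/// new output file once the current file grows past `target_file_size` (in bytes).
///
/// A minimal sketch of the intended wiring (the inner builder, `FileIO`, and
/// generator values here are placeholders; see the tests in this module for a
/// runnable version):
///
/// ```ignore
/// let rolling_builder = RollingFileWriterBuilder::new(
///     parquet_writer_builder, // any FileWriterBuilder
///     512 * 1024 * 1024,      // roll to a new file after ~512 MiB
///     file_io,
///     location_generator,
///     file_name_generator,
/// );
/// let mut writer = rolling_builder.build();
/// ```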
#[derive(Clone, Debug)]
pub struct RollingFileWriterBuilder<
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
> {
    inner_builder: B,
    target_file_size: usize,
    file_io: FileIO,
    location_generator: L,
    file_name_generator: F,
}

impl<B, L, F> RollingFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
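    /// Creates a new `RollingFileWriterBuilder` with an explicit target file size,
    /// in bytes, at which the produced writer rolls over to a new file.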
    pub fn new(
        inner_builder: B,
        target_file_size: usize,
        file_io: FileIO,
        location_generator: L,
        file_name_generator: F,
    ) -> Self {
        Self {
            inner_builder,
            target_file_size,
            file_io,
            location_generator,
            file_name_generator,
        }
    }

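    /// Creates a new `RollingFileWriterBuilder` using the table default target
    /// file size, [`TableProperties::PROPERTY_WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT`].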
    pub fn new_with_default_file_size(
        inner_builder: B,
        file_io: FileIO,
        location_generator: L,
        file_name_generator: F,
    ) -> Self {
        Self {
            inner_builder,
            target_file_size: TableProperties::PROPERTY_WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
            file_io,
            location_generator,
            file_name_generator,
        }
    }

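    /// Builds a [`RollingFileWriter`] with no file open yet; the first call to
    /// `write` creates the initial output file.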
    pub fn build(&self) -> RollingFileWriter<B, L, F> {
        RollingFileWriter {
            inner: None,
            inner_builder: self.inner_builder.clone(),
            target_file_size: self.target_file_size,
            data_file_builders: vec![],
            file_io: self.file_io.clone(),
            location_generator: self.location_generator.clone(),
            file_name_generator: self.file_name_generator.clone(),
        }
    }
}

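/// A file writer that splits incoming record batches across multiple output
/// files, rolling over to a new file whenever the current one exceeds the
/// target size. Completed files are accumulated as [`DataFileBuilder`]s and
/// returned from [`RollingFileWriter::close`].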
pub struct RollingFileWriter<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    inner: Option<B::R>,
    inner_builder: B,
    target_file_size: usize,
    data_file_builders: Vec<DataFileBuilder>,
    file_io: FileIO,
    location_generator: L,
    file_name_generator: F,
}

impl<B, L, F> Debug for RollingFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RollingFileWriter")
            .field("target_file_size", &self.target_file_size)
            .field("file_io", &self.file_io)
            .finish()
    }
}

impl<B, L, F> RollingFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
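    /// Returns `true` when the current file's written size exceeds the target,
    /// meaning the writer should roll over to a new file.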
    fn should_roll(&self) -> bool {
        self.current_written_size() > self.target_file_size
    }

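    /// Creates the next output file, with its path derived from the location
    /// and file name generators (and the partition key, if present).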
    fn new_output_file(&self, partition_key: &Option<PartitionKey>) -> Result<OutputFile> {
        self.file_io
            .new_output(self.location_generator.generate_location(
                partition_key.as_ref(),
                &self.file_name_generator.generate_file_name(),
            ))
    }

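    /// Writes a record batch, lazily creating the inner writer on first use and
    /// rolling over to a new output file once the target size is exceeded.
    ///
    /// Note that the size check happens before `input` is written, so a file may
    /// overshoot the target by roughly one batch.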
    pub async fn write(
        &mut self,
        partition_key: &Option<PartitionKey>,
        input: &RecordBatch,
    ) -> Result<()> {
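        // lazily initialize the inner writer on the first write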
        if self.inner.is_none() {
            self.inner = Some(
                self.inner_builder
                    .build(self.new_output_file(partition_key)?)
                    .await?,
            );
        }

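        // if the current file has grown past the target size, close it and
        // collect its data file builder before starting a new file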
        if self.should_roll() {
            if let Some(inner) = self.inner.take() {
                self.data_file_builders.extend(inner.close().await?);

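                // open a fresh writer for the next output file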
                self.inner = Some(
                    self.inner_builder
                        .build(self.new_output_file(partition_key)?)
                        .await?,
                );
            }
        }

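        // delegate the write to the active inner writer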
        if let Some(writer) = self.inner.as_mut() {
            writer.write(input).await
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Writer is not initialized!",
            ))
        }
    }

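    /// Closes this writer, flushing the file currently in progress (if any),
    /// and returns the [`DataFileBuilder`]s for every file produced.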
    pub async fn close(mut self) -> Result<Vec<DataFileBuilder>> {
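        // close the active writer, if any, and collect its output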
        if let Some(current_writer) = self.inner {
            self.data_file_builders
                .extend(current_writer.close().await?);
        }

        Ok(self.data_file_builders)
    }
}

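// File status queries are delegated to the active inner writer; these methods
// panic if called before the first `write` has initialized it.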
impl<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> CurrentFileStatus
    for RollingFileWriter<B, L, F>
{
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner.as_ref().unwrap().current_row_num()
    }

    fn current_written_size(&self) -> usize {
        self.inner.as_ref().unwrap().current_written_size()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use rand::prelude::IteratorRandom;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::tests::check_parquet_data_file;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

    fn make_test_schema() -> Result<Schema> {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
    }

    fn make_test_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
        ])
    }

    #[tokio::test]
    async fn test_rolling_writer_basic() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = make_test_schema()?;

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), Arc::new(schema));

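        // 1 MiB target size: a single three-row batch stays far below this,
        // so the writer should never roll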
        let rolling_file_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            1024 * 1024,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_file_writer_builder);

        let mut writer = data_file_writer_builder.build(None).await?;

        let arrow_schema = make_test_arrow_schema();

        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;

        writer.write(batch.clone()).await?;

        let data_files = writer.close().await?;

        assert_eq!(
            data_files.len(),
            1,
            "Expected only one data file to be created"
        );

        check_parquet_data_file(&file_io, &data_files[0], &batch).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_rolling_writer_with_rolling() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = make_test_schema()?;

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), Arc::new(schema));

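        // a tiny 1 KiB target size forces the writer to roll frequently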
        let rolling_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            1024,
            file_io,
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = data_file_writer_builder.build(None).await?;

        let arrow_schema = make_test_arrow_schema();
        let arrow_schema_ref = Arc::new(arrow_schema.clone());

        let names = vec![
            "Alice", "Bob", "Charlie", "Dave", "Eve", "Frank", "Grace", "Heidi", "Ivan", "Judy",
            "Kelly", "Larry", "Mallory", "Shawn",
        ];

        let mut rng = rand::thread_rng();
        let batch_num = 10;
        let batch_rows = 100;
        let expected_rows = batch_num * batch_rows;

        for i in 0..batch_num {
            let int_values: Vec<i32> = (0..batch_rows).map(|row| i * batch_rows + row).collect();
            let str_values: Vec<&str> = (0..batch_rows)
                .map(|_| *names.iter().choose(&mut rng).unwrap())
                .collect();

            let int_array = Arc::new(Int32Array::from(int_values)) as ArrayRef;
            let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;

            let batch =
                RecordBatch::try_new(Arc::clone(&arrow_schema_ref), vec![int_array, str_array])
                    .expect("Failed to create RecordBatch");

            writer.write(batch).await?;
        }

        let data_files = writer.close().await?;

        assert!(
            data_files.len() > 4,
            "Expected more than 4 data files to be created, but got {}",
            data_files.len()
        );

        let total_records: u64 = data_files.iter().map(|file| file.record_count).sum();
        assert_eq!(
            total_records, expected_rows as u64,
            "Expected {expected_rows} total records across all files"
        );

        Ok(())
    }
}