use arrow_array::RecordBatch;

use crate::spec::{DataContentType, DataFile, PartitionKey};
use crate::writer::file_writer::FileWriterBuilder;
use crate::writer::file_writer::location_generator::{FileNameGenerator, LocationGenerator};
use crate::writer::file_writer::rolling_writer::{RollingFileWriter, RollingFileWriterBuilder};
use crate::writer::{CurrentFileStatus, IcebergWriter, IcebergWriterBuilder};
use crate::{Error, ErrorKind, Result};

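/// Builder for [`DataFileWriter`].
///
/// Wraps a [`RollingFileWriterBuilder`] and produces writers that emit
/// [`DataFile`]s with [`DataContentType::Data`] content.
///
/// A minimal usage sketch (construction of the inner rolling writer builder
/// follows the test module below):
///
/// ```ignore
/// let mut writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
///     .build(None) // or `Some(partition_key)` for a partitioned table
///     .await?;
/// writer.write(batch).await?;
/// let data_files = writer.close().await?;
/// ```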
#[derive(Debug)]
pub struct DataFileWriterBuilder<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    inner: RollingFileWriterBuilder<B, L, F>,
}

impl<B, L, F> DataFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
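    /// Creates a new `DataFileWriterBuilder` from the given rolling file writer builder.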
    pub fn new(inner: RollingFileWriterBuilder<B, L, F>) -> Self {
        Self { inner }
    }
}

#[async_trait::async_trait]
impl<B, L, F> IcebergWriterBuilder for DataFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    type R = DataFileWriter<B, L, F>;

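    /// Builds a [`DataFileWriter`] bound to the given partition key, if any.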
    async fn build(&self, partition_key: Option<PartitionKey>) -> Result<Self::R> {
        Ok(DataFileWriter {
            inner: Some(self.inner.build()),
            partition_key,
        })
    }
}

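/// Writes [`RecordBatch`]es as Iceberg data files, delegating the underlying
/// file I/O and file rolling to a [`RollingFileWriter`].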
#[derive(Debug)]
pub struct DataFileWriter<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    /// The underlying rolling writer; set to `None` once the writer is closed.
    inner: Option<RollingFileWriter<B, L, F>>,
    /// Partition this writer is bound to, if the table is partitioned.
    partition_key: Option<PartitionKey>,
}

#[async_trait::async_trait]
impl<B, L, F> IcebergWriter for DataFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    async fn write(&mut self, batch: RecordBatch) -> Result<()> {
        if let Some(writer) = self.inner.as_mut() {
            writer.write(&self.partition_key, &batch).await
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Writer is not initialized!",
            ))
        }
    }

    async fn close(&mut self) -> Result<Vec<DataFile>> {
        // Take the inner writer so the data file writer cannot be used again.
        if let Some(writer) = self.inner.take() {
            writer
                .close()
                .await?
                .into_iter()
                .map(|mut res| {
                    // Mark every produced file as data content and attach the
                    // partition value and spec id, if this writer is partitioned.
                    res.content(DataContentType::Data);
                    if let Some(pk) = self.partition_key.as_ref() {
                        res.partition(pk.data().clone());
                        res.partition_spec_id(pk.spec().spec_id());
                    }
                    res.build().map_err(|e| {
                        Error::new(
                            ErrorKind::DataInvalid,
                            format!("Failed to build data file: {e}"),
                        )
                    })
                })
                .collect()
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Data file writer has been closed.",
            ))
        }
    }
}

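// Current-file status is delegated to the inner rolling writer; these accessors
// panic if the writer has already been closed.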
impl<B, L, F> CurrentFileStatus for DataFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner.as_ref().unwrap().current_row_num()
    }

    fn current_written_size(&self) -> usize {
        self.inner.as_ref().unwrap().current_written_size()
    }
}

#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
    use parquet::file::properties::WriterProperties;
    use tempfile::TempDir;

    use crate::Result;
    use crate::io::FileIOBuilder;
    use crate::spec::{
        DataContentType, DataFileFormat, Literal, NestedField, PartitionKey, PartitionSpec,
        PrimitiveType, Schema, Struct, Type,
    };
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

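    // Writes a single unpartitioned batch and verifies the returned data file
    // metadata as well as the Iceberg field ids recorded in the Parquet footer.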
    #[tokio::test]
    async fn test_parquet_writer() -> Result<()> {
        let temp_dir = TempDir::new().unwrap();
        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Schema::builder()
            .with_schema_id(3)
            .with_fields(vec![
                NestedField::required(3, "foo", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(4, "bar", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()?;

        let pw = ParquetWriterBuilder::new(WriterProperties::builder().build(), Arc::new(schema));

        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            pw,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
            .build(None)
            .await
            .unwrap();

        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("foo", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
            Field::new("bar", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                4.to_string(),
            )])),
        ]);
        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;
        data_file_writer.write(batch).await?;

        let data_files = data_file_writer.close().await.unwrap();
        assert_eq!(data_files.len(), 1);

        let data_file = &data_files[0];
        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
        assert_eq!(data_file.content, DataContentType::Data);
        assert_eq!(data_file.partition, Struct::empty());

        let input_file = file_io.new_input(data_file.file_path.clone())?;
        let input_content = input_file.read().await?;

        let parquet_reader =
            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())
                .expect("Failed to load Parquet metadata");

        let field_ids: Vec<i32> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.self_type().get_basic_info().id())
            .collect();

        assert_eq!(field_ids, vec![3, 4]);
        Ok(())
    }

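    // Same flow as above, but with a partition key: verifies that the partition
    // value is recorded on the data file and that column names and field ids
    // survive the round trip.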
    #[tokio::test]
    async fn test_parquet_writer_with_partition() -> Result<()> {
        let temp_dir = TempDir::new().unwrap();
        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen = DefaultFileNameGenerator::new(
            "test_partitioned".to_string(),
            None,
            DataFileFormat::Parquet,
        );

        let schema = Schema::builder()
            .with_schema_id(5)
            .with_fields(vec![
                NestedField::required(5, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(6, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()?;
        let schema_ref = Arc::new(schema);

        let partition_value = Struct::from_iter([Some(Literal::int(1))]);
        let partition_key = PartitionKey::new(
            PartitionSpec::builder(schema_ref.clone()).build()?,
            schema_ref.clone(),
            partition_value.clone(),
        );

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema_ref.clone());

        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
            .build(Some(partition_key))
            .await?;

        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                5.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                6.to_string(),
            )])),
        ]);
        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;
        data_file_writer.write(batch).await?;

        let data_files = data_file_writer.close().await.unwrap();
        assert_eq!(data_files.len(), 1);

        let data_file = &data_files[0];
        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
        assert_eq!(data_file.content, DataContentType::Data);
        assert_eq!(data_file.partition, partition_value);

        let input_file = file_io.new_input(data_file.file_path.clone())?;
        let input_content = input_file.read().await?;

        let parquet_reader =
            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())?;

        let field_ids: Vec<i32> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.self_type().get_basic_info().id())
            .collect();
        assert_eq!(field_ids, vec![5, 6]);

        let field_names: Vec<&str> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.name())
            .collect();
        assert_eq!(field_names, vec!["id", "name"]);

        Ok(())
    }
}