1use std::collections::HashMap;
19use std::sync::Arc;
20
21use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray};
22use arrow_select::filter::filter_record_batch;
23
24use super::arrow_struct_to_literal;
25use super::partition_value_calculator::PartitionValueCalculator;
26use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, StructType};
27use crate::{Error, ErrorKind, Result};
28
/// Name of the struct column from which a splitter created without a calculator
/// (see [`RecordBatchPartitionSplitter::try_new_with_precomputed_values`]) reads the
/// precomputed partition value of each row.
pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
31
32pub struct RecordBatchPartitionSplitter {
42 schema: SchemaRef,
43 partition_spec: PartitionSpecRef,
44 calculator: Option<PartitionValueCalculator>,
45 partition_type: StructType,
46}
47
48impl RecordBatchPartitionSplitter {
49 pub fn try_new(
63 iceberg_schema: SchemaRef,
64 partition_spec: PartitionSpecRef,
65 calculator: Option<PartitionValueCalculator>,
66 ) -> Result<Self> {
67 let partition_type = partition_spec.partition_type(&iceberg_schema)?;
68
69 Ok(Self {
70 schema: iceberg_schema,
71 partition_spec,
72 calculator,
73 partition_type,
74 })
75 }
76
77 pub fn try_new_with_computed_values(
91 iceberg_schema: SchemaRef,
92 partition_spec: PartitionSpecRef,
93 ) -> Result<Self> {
94 let calculator = PartitionValueCalculator::try_new(&partition_spec, &iceberg_schema)?;
95 Self::try_new(iceberg_schema, partition_spec, Some(calculator))
96 }
97
98 pub fn try_new_with_precomputed_values(
112 iceberg_schema: SchemaRef,
113 partition_spec: PartitionSpecRef,
114 ) -> Result<Self> {
115 Self::try_new(iceberg_schema, partition_spec, None)
116 }
117
118 pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey, RecordBatch)>> {
120 let partition_structs = if let Some(calculator) = &self.calculator {
121 let partition_array = calculator.calculate(batch)?;
123 let struct_array = arrow_struct_to_literal(&partition_array, &self.partition_type)?;
124
125 struct_array
126 .into_iter()
127 .map(|s| {
128 if let Some(Literal::Struct(s)) = s {
129 Ok(s)
130 } else {
131 Err(Error::new(
132 ErrorKind::DataInvalid,
133 "Partition value is not a struct literal or is null",
134 ))
135 }
136 })
137 .collect::<Result<Vec<_>>>()?
138 } else {
139 let partition_column = batch
141 .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
142 .ok_or_else(|| {
143 Error::new(
144 ErrorKind::DataInvalid,
145 format!(
146 "Partition column '{PROJECTED_PARTITION_VALUE_COLUMN}' not found in batch"
147 ),
148 )
149 })?;
150
151 let partition_struct_array = partition_column
152 .as_any()
153 .downcast_ref::<StructArray>()
154 .ok_or_else(|| {
155 Error::new(
156 ErrorKind::DataInvalid,
157 "Partition column is not a StructArray",
158 )
159 })?;
160
161 let arrow_struct_array = Arc::new(partition_struct_array.clone()) as ArrayRef;
162 let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?;
163
164 struct_array
165 .into_iter()
166 .map(|s| {
167 if let Some(Literal::Struct(s)) = s {
168 Ok(s)
169 } else {
170 Err(Error::new(
171 ErrorKind::DataInvalid,
172 "Partition value is not a struct literal or is null",
173 ))
174 }
175 })
176 .collect::<Result<Vec<_>>>()?
177 };
178
179 let mut group_ids = HashMap::new();
181 partition_structs
182 .iter()
183 .enumerate()
184 .for_each(|(row_id, row)| {
185 group_ids.entry(row.clone()).or_insert(vec![]).push(row_id);
186 });
187
188 let mut partition_batches = Vec::with_capacity(group_ids.len());
190 for (row, row_ids) in group_ids.into_iter() {
191 let filter_array: BooleanArray = {
193 let mut filter = vec![false; batch.num_rows()];
194 row_ids.into_iter().for_each(|row_id| {
195 filter[row_id] = true;
196 });
197 filter.into()
198 };
199
200 let partition_key = PartitionKey::new(
202 self.partition_spec.as_ref().clone(),
203 self.schema.clone(),
204 row,
205 );
206
207 partition_batches.push((partition_key, filter_record_batch(batch, &filter_array)?));
209 }
210
211 Ok(partition_batches)
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::DataType;
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

    use super::*;
    use crate::arrow::schema_to_arrow_schema;
    use crate::spec::{
        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct, Transform, Type,
        UnboundPartitionField,
    };

    /// Splits a batch whose partition values are computed from the `id` column. Checks both the
    /// rows in each partition batch and the resulting partition keys.
    #[test]
    fn test_record_batch_partition_split() {
        let schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(
                        1,
                        "id",
                        Type::Primitive(crate::spec::PrimitiveType::Int),
                    )
                    .into(),
                    NestedField::required(
                        2,
                        "name",
                        Type::Primitive(crate::spec::PrimitiveType::String),
                    )
                    .into(),
                ])
                .build()
                .unwrap(),
        );
        // Despite the "id_bucket" name, the partition field uses the Identity transform, so
        // each distinct `id` value is its own partition.
        let partition_spec = Arc::new(
            PartitionSpecBuilder::new(schema.clone())
                .with_spec_id(1)
                .add_unbound_field(UnboundPartitionField {
                    source_id: 1,
                    field_id: None,
                    name: "id_bucket".to_string(),
                    transform: Transform::Identity,
                })
                .unwrap()
                .build()
                .unwrap(),
        );
        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_computed_values(
            schema.clone(),
            partition_spec,
        )
        .expect("Failed to create splitter");

        let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
        let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
        let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(id_array),
            Arc::new(data_array),
        ])
        .expect("Failed to create RecordBatch");

        let mut partitioned_batches = partition_splitter
            .split(&batch)
            .expect("Failed to split RecordBatch");
        // Sort by the single int partition value so the assertions below do not depend on the
        // order in which `split` returns partitions.
        partitioned_batches.sort_by_key(|(partition_key, _)| {
            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
                .as_ref()
                .unwrap()
                .as_primitive_literal()
                .unwrap()
            {
                i
            } else {
                panic!("The partition value is not a int");
            }
        });
        assert_eq!(partitioned_batches.len(), 3);
        // Partition id = 1: rows 0, 2, 6 in their original order.
        {
            let expected_id_array = Int32Array::from(vec![1, 1, 1]);
            let expected_data_array = StringArray::from(vec!["a", "c", "g"]);
            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
                Arc::new(expected_id_array),
                Arc::new(expected_data_array),
            ])
            .expect("Failed to create expected RecordBatch");
            assert_eq!(partitioned_batches[0].1, expected_batch);
        }
        // Partition id = 2: rows 1, 4.
        {
            let expected_id_array = Int32Array::from(vec![2, 2]);
            let expected_data_array = StringArray::from(vec!["b", "e"]);
            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
                Arc::new(expected_id_array),
                Arc::new(expected_data_array),
            ])
            .expect("Failed to create expected RecordBatch");
            assert_eq!(partitioned_batches[1].1, expected_batch);
        }
        // Partition id = 3: rows 3, 5.
        {
            let expected_id_array = Int32Array::from(vec![3, 3]);
            let expected_data_array = StringArray::from(vec!["d", "f"]);
            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), vec![
                Arc::new(expected_id_array),
                Arc::new(expected_data_array),
            ])
            .expect("Failed to create expected RecordBatch");
            assert_eq!(partitioned_batches[2].1, expected_batch);
        }

        // The partition keys must carry the identity-transformed values.
        let partition_values = partitioned_batches
            .iter()
            .map(|(partition_key, _)| partition_key.data().clone())
            .collect::<Vec<_>>();
        assert_eq!(partition_values, vec![
            Struct::from_iter(vec![Some(Literal::int(1))]),
            Struct::from_iter(vec![Some(Literal::int(2))]),
            Struct::from_iter(vec![Some(Literal::int(3))]),
        ]);
    }

    /// Splits a batch whose partition values are supplied precomputed in the
    /// `PROJECTED_PARTITION_VALUE_COLUMN` struct column rather than calculated.
    #[test]
    fn test_record_batch_partition_split_with_partition_column() {
        use arrow_array::StructArray;
        use arrow_schema::{Field, Schema as ArrowSchema};

        let schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(
                        1,
                        "id",
                        Type::Primitive(crate::spec::PrimitiveType::Int),
                    )
                    .into(),
                    NestedField::required(
                        2,
                        "name",
                        Type::Primitive(crate::spec::PrimitiveType::String),
                    )
                    .into(),
                ])
                .build()
                .unwrap(),
        );
        let partition_spec = Arc::new(
            PartitionSpecBuilder::new(schema.clone())
                .with_spec_id(1)
                .add_unbound_field(UnboundPartitionField {
                    source_id: 1,
                    field_id: None,
                    name: "id_bucket".to_string(),
                    transform: Transform::Identity,
                })
                .unwrap()
                .build()
                .unwrap(),
        );

        // Child field of the partition struct column. The field id 1000 is presumably the id
        // the builder assigns to the first partition field (field_id: None above) — TODO
        // confirm.
        let partition_field = Field::new("id_bucket", DataType::Int32, false).with_metadata(
            HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1000".to_string())]),
        );
        let partition_struct_field = Field::new(
            PROJECTED_PARTITION_VALUE_COLUMN,
            DataType::Struct(vec![partition_field.clone()].into()),
            false,
        );

        // Input batch = data columns plus the precomputed partition struct column.
        let input_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            partition_struct_field,
        ]));

        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_precomputed_values(
            schema.clone(),
            partition_spec,
        )
        .expect("Failed to create splitter");

        let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
        let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);

        // Precomputed partition values; they mirror `id` because the transform is Identity.
        let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
        let partition_struct = StructArray::from(vec![(
            Arc::new(partition_field),
            Arc::new(partition_values) as ArrayRef,
        )]);

        let batch = RecordBatch::try_new(input_schema.clone(), vec![
            Arc::new(id_array),
            Arc::new(data_array),
            Arc::new(partition_struct),
        ])
        .expect("Failed to create RecordBatch");

        let mut partitioned_batches = partition_splitter
            .split(&batch)
            .expect("Failed to split RecordBatch");

        // Sort by the single int partition value so indexing below is order-independent.
        partitioned_batches.sort_by_key(|(partition_key, _)| {
            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
                .as_ref()
                .unwrap()
                .as_primitive_literal()
                .unwrap()
            {
                i
            } else {
                panic!("The partition value is not a int");
            }
        });

        assert_eq!(partitioned_batches.len(), 3);

        // Reads the `id` (column 0) and `name` (column 1) values of a partition batch as plain
        // vectors; assumes `name` has no nulls.
        let extract_values = |batch: &RecordBatch| -> (Vec<i32>, Vec<String>) {
            let id_col = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            let name_col = batch
                .column(1)
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            (
                id_col.values().to_vec(),
                name_col.iter().map(|s| s.unwrap().to_string()).collect(),
            )
        };

        let (key, batch) = &partitioned_batches[0];
        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(1))]));
        let (ids, names) = extract_values(batch);
        assert_eq!(ids, vec![1, 1, 1]);
        assert_eq!(names, vec!["a", "c", "g"]);

        let (key, batch) = &partitioned_batches[1];
        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(2))]));
        let (ids, names) = extract_values(batch);
        assert_eq!(ids, vec![2, 2]);
        assert_eq!(names, vec!["b", "e"]);

        let (key, batch) = &partitioned_batches[2];
        assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(3))]));
        let (ids, names) = extract_values(batch);
        assert_eq!(ids, vec![3, 3]);
        assert_eq!(names, vec!["d", "f"]);
    }
}