iceberg/arrow/
record_batch_partition_splitter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray};
use arrow_select::filter::filter_record_batch;

use super::arrow_struct_to_literal;
use super::partition_value_calculator::PartitionValueCalculator;
use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, Struct, StructType};
use crate::{Error, ErrorKind, Result};
28
/// Name of the struct column that carries pre-computed partition values.
///
/// In pre-computed mode, the splitter looks this column up by name in every input batch.
pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
31
32/// The splitter used to split the record batch into multiple record batches by the partition spec.
33/// 1. It will project and transform the input record batch based on the partition spec, get the partitioned record batch.
34/// 2. Split the input record batch into multiple record batches based on the partitioned record batch.
35///
36/// # Partition Value Modes
37///
38/// The splitter supports two modes for obtaining partition values:
39/// - **Computed mode** (`calculator` is `Some`): Computes partition values from source columns using transforms
40/// - **Pre-computed mode** (`calculator` is `None`): Expects a `_partition` column in the input batch
41pub struct RecordBatchPartitionSplitter {
42    schema: SchemaRef,
43    partition_spec: PartitionSpecRef,
44    calculator: Option<PartitionValueCalculator>,
45    partition_type: StructType,
46}
47
48impl RecordBatchPartitionSplitter {
49    /// Create a new RecordBatchPartitionSplitter.
50    ///
51    /// # Arguments
52    ///
53    /// * `iceberg_schema` - The Iceberg schema reference
54    /// * `partition_spec` - The partition specification reference
55    /// * `calculator` - Optional calculator for computing partition values from source columns.
56    ///   - `Some(calculator)`: Compute partition values from source columns using transforms
57    ///   - `None`: Expect a pre-computed `_partition` column in the input batch
58    ///
59    /// # Returns
60    ///
61    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
62    pub fn try_new(
63        iceberg_schema: SchemaRef,
64        partition_spec: PartitionSpecRef,
65        calculator: Option<PartitionValueCalculator>,
66    ) -> Result<Self> {
67        let partition_type = partition_spec.partition_type(&iceberg_schema)?;
68
69        Ok(Self {
70            schema: iceberg_schema,
71            partition_spec,
72            calculator,
73            partition_type,
74        })
75    }
76
77    /// Create a new RecordBatchPartitionSplitter with computed partition values.
78    ///
79    /// This is a convenience method that creates a calculator and initializes the splitter
80    /// to compute partition values from source columns.
81    ///
82    /// # Arguments
83    ///
84    /// * `iceberg_schema` - The Iceberg schema reference
85    /// * `partition_spec` - The partition specification reference
86    ///
87    /// # Returns
88    ///
89    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
90    pub fn try_new_with_computed_values(
91        iceberg_schema: SchemaRef,
92        partition_spec: PartitionSpecRef,
93    ) -> Result<Self> {
94        let calculator = PartitionValueCalculator::try_new(&partition_spec, &iceberg_schema)?;
95        Self::try_new(iceberg_schema, partition_spec, Some(calculator))
96    }
97
98    /// Create a new RecordBatchPartitionSplitter expecting pre-computed partition values.
99    ///
100    /// This is a convenience method that initializes the splitter to expect a `_partition`
101    /// column in the input batches.
102    ///
103    /// # Arguments
104    ///
105    /// * `iceberg_schema` - The Iceberg schema reference
106    /// * `partition_spec` - The partition specification reference
107    ///
108    /// # Returns
109    ///
110    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails.
111    pub fn try_new_with_precomputed_values(
112        iceberg_schema: SchemaRef,
113        partition_spec: PartitionSpecRef,
114    ) -> Result<Self> {
115        Self::try_new(iceberg_schema, partition_spec, None)
116    }
117
118    /// Split the record batch into multiple record batches based on the partition spec.
119    pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey, RecordBatch)>> {
120        let partition_structs = if let Some(calculator) = &self.calculator {
121            // Compute partition values from source columns using calculator
122            let partition_array = calculator.calculate(batch)?;
123            let struct_array = arrow_struct_to_literal(&partition_array, &self.partition_type)?;
124
125            struct_array
126                .into_iter()
127                .map(|s| {
128                    if let Some(Literal::Struct(s)) = s {
129                        Ok(s)
130                    } else {
131                        Err(Error::new(
132                            ErrorKind::DataInvalid,
133                            "Partition value is not a struct literal or is null",
134                        ))
135                    }
136                })
137                .collect::<Result<Vec<_>>>()?
138        } else {
139            // Extract partition values from pre-computed partition column
140            let partition_column = batch
141                .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
142                .ok_or_else(|| {
143                    Error::new(
144                        ErrorKind::DataInvalid,
145                        format!(
146                            "Partition column '{PROJECTED_PARTITION_VALUE_COLUMN}' not found in batch"
147                        ),
148                    )
149                })?;
150
151            let partition_struct_array = partition_column
152                .as_any()
153                .downcast_ref::<StructArray>()
154                .ok_or_else(|| {
155                    Error::new(
156                        ErrorKind::DataInvalid,
157                        "Partition column is not a StructArray",
158                    )
159                })?;
160
161            let arrow_struct_array = Arc::new(partition_struct_array.clone()) as ArrayRef;
162            let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?;
163
164            struct_array
165                .into_iter()
166                .map(|s| {
167                    if let Some(Literal::Struct(s)) = s {
168                        Ok(s)
169                    } else {
170                        Err(Error::new(
171                            ErrorKind::DataInvalid,
172                            "Partition value is not a struct literal or is null",
173                        ))
174                    }
175                })
176                .collect::<Result<Vec<_>>>()?
177        };
178
179        // Group the batch by row value.
180        let mut group_ids = HashMap::new();
181        partition_structs
182            .iter()
183            .enumerate()
184            .for_each(|(row_id, row)| {
185                group_ids.entry(row.clone()).or_insert(vec![]).push(row_id);
186            });
187
188        // Partition the batch with same partition partition_values
189        let mut partition_batches = Vec::with_capacity(group_ids.len());
190        for (row, row_ids) in group_ids.into_iter() {
191            // generate the bool filter array from column_ids
192            let filter_array: BooleanArray = {
193                let mut filter = vec![false; batch.num_rows()];
194                row_ids.into_iter().for_each(|row_id| {
195                    filter[row_id] = true;
196                });
197                filter.into()
198            };
199
200            // Create PartitionKey from the partition struct
201            let partition_key = PartitionKey::new(
202                self.partition_spec.as_ref().clone(),
203                self.schema.clone(),
204                row,
205            );
206
207            // filter the RecordBatch
208            partition_batches.push((partition_key, filter_record_batch(batch, &filter_array)?));
209        }
210
211        Ok(partition_batches)
212    }
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

    use super::*;
    use crate::arrow::schema_to_arrow_schema;
    use crate::spec::{
        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct, Transform, Type,
        UnboundPartitionField,
    };

    /// Table schema shared by the tests: required `id: int` (field 1) and `name: string` (field 2).
    fn test_schema() -> SchemaRef {
        Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(
                        1,
                        "id",
                        Type::Primitive(crate::spec::PrimitiveType::Int),
                    )
                    .into(),
                    NestedField::required(
                        2,
                        "name",
                        Type::Primitive(crate::spec::PrimitiveType::String),
                    )
                    .into(),
                ])
                .build()
                .unwrap(),
        )
    }

    /// Partition spec with a single identity partition field on `id`.
    fn test_partition_spec(schema: &SchemaRef) -> PartitionSpecRef {
        Arc::new(
            PartitionSpecBuilder::new(schema.clone())
                .with_spec_id(1)
                .add_unbound_field(UnboundPartitionField {
                    source_id: 1,
                    field_id: None,
                    name: "id_bucket".to_string(),
                    transform: Transform::Identity,
                })
                .unwrap()
                .build()
                .unwrap(),
        )
    }

    /// Builds an `(id, name)` batch with the given arrow schema.
    fn id_name_batch(arrow_schema: &Arc<ArrowSchema>, ids: Vec<i32>, names: Vec<&str>) -> RecordBatch {
        RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(ids)),
            Arc::new(StringArray::from(names)),
        ])
        .expect("Failed to create RecordBatch")
    }

    /// Returns the integer value of the first (and only) partition field of `key`.
    fn partition_int(key: &PartitionKey) -> i32 {
        match key.data().fields()[0]
            .as_ref()
            .unwrap()
            .as_primitive_literal()
            .unwrap()
        {
            PrimitiveLiteral::Int(i) => i,
            _ => panic!("The partition value is not a int"),
        }
    }

    /// Extracts the `id` (column 0) and `name` (column 1) values from a batch.
    fn id_name_values(batch: &RecordBatch) -> (Vec<i32>, Vec<String>) {
        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let name_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        (
            id_col.values().to_vec(),
            name_col.iter().map(|s| s.unwrap().to_string()).collect(),
        )
    }

    #[test]
    fn test_record_batch_partition_split() {
        let schema = test_schema();
        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_computed_values(
            schema.clone(),
            test_partition_spec(&schema),
        )
        .expect("Failed to create splitter");

        let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
        let batch = id_name_batch(&arrow_schema, vec![1, 2, 1, 3, 2, 3, 1], vec![
            "a", "b", "c", "d", "e", "f", "g",
        ]);

        let mut partitioned_batches = partition_splitter
            .split(&batch)
            .expect("Failed to split RecordBatch");
        partitioned_batches.sort_by_key(|(partition_key, _)| partition_int(partition_key));
        assert_eq!(partitioned_batches.len(), 3);

        let expected = [
            (1, vec![1, 1, 1], vec!["a", "c", "g"]),
            (2, vec![2, 2], vec!["b", "e"]),
            (3, vec![3, 3], vec!["d", "f"]),
        ];
        for ((partition_key, partitioned_batch), (value, ids, names)) in
            partitioned_batches.iter().zip(expected)
        {
            assert_eq!(
                partition_key.data(),
                &Struct::from_iter(vec![Some(Literal::int(value))])
            );
            assert_eq!(partitioned_batch, &id_name_batch(&arrow_schema, ids, names));
        }
    }

    #[test]
    fn test_record_batch_partition_split_with_partition_column() {
        let schema = test_schema();

        // Create input schema with _partition column
        // Note: partition field IDs start from 1000 by default
        let partition_field = Field::new("id_bucket", DataType::Int32, false).with_metadata(
            HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1000".to_string())]),
        );
        let partition_struct_field = Field::new(
            PROJECTED_PARTITION_VALUE_COLUMN,
            DataType::Struct(vec![partition_field.clone()].into()),
            false,
        );
        let input_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            partition_struct_field,
        ]));

        // Create splitter expecting pre-computed partition column
        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_precomputed_values(
            schema.clone(),
            test_partition_spec(&schema),
        )
        .expect("Failed to create splitter");

        // Partition column holds the same values as `id` (Identity transform)
        let partition_struct = StructArray::from(vec![(
            Arc::new(partition_field),
            Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1])) as ArrayRef,
        )]);
        let batch = RecordBatch::try_new(input_schema, vec![
            Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1])),
            Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"])),
            Arc::new(partition_struct),
        ])
        .expect("Failed to create RecordBatch");

        let mut partitioned_batches = partition_splitter
            .split(&batch)
            .expect("Failed to split RecordBatch");
        partitioned_batches.sort_by_key(|(partition_key, _)| partition_int(partition_key));
        assert_eq!(partitioned_batches.len(), 3);

        let expected = [
            (1, vec![1, 1, 1], vec!["a", "c", "g"]),
            (2, vec![2, 2], vec!["b", "e"]),
            (3, vec![3, 3], vec!["d", "f"]),
        ];
        for ((key, partitioned_batch), (value, expected_ids, expected_names)) in
            partitioned_batches.iter().zip(expected)
        {
            assert_eq!(key.data(), &Struct::from_iter(vec![Some(Literal::int(value))]));
            let (ids, names) = id_name_values(partitioned_batch);
            assert_eq!(ids, expected_ids);
            assert_eq!(names, expected_names);
        }
    }

    #[test]
    fn test_record_batch_partition_split_missing_partition_column() {
        let schema = test_schema();
        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_precomputed_values(
            schema.clone(),
            test_partition_spec(&schema),
        )
        .expect("Failed to create splitter");

        // The batch only carries the table columns, not `_partition`.
        let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
        let batch = id_name_batch(&arrow_schema, vec![1, 2], vec!["a", "b"]);

        let Err(err) = partition_splitter.split(&batch) else {
            panic!("split must fail when the partition column is missing");
        };
        assert!(matches!(err.kind(), ErrorKind::DataInvalid));
        assert!(err.to_string().contains(PROJECTED_PARTITION_VALUE_COLUMN));
    }

    #[test]
    fn test_record_batch_partition_split_non_struct_partition_column() {
        let schema = test_schema();
        let partition_splitter = RecordBatchPartitionSplitter::try_new_with_precomputed_values(
            schema.clone(),
            test_partition_spec(&schema),
        )
        .expect("Failed to create splitter");

        // `_partition` is present but is a plain Int32 column instead of a struct.
        let input_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(PROJECTED_PARTITION_VALUE_COLUMN, DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(input_schema, vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
            Arc::new(Int32Array::from(vec![1, 2])),
        ])
        .expect("Failed to create RecordBatch");

        let Err(err) = partition_splitter.split(&batch) else {
            panic!("split must fail when the partition column is not a struct");
        };
        assert!(matches!(err.kind(), ErrorKind::DataInvalid));
    }
}