iceberg/arrow/
partition_value_calculator.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Partition value calculation for Iceberg tables.
//!
//! This module provides utilities for calculating partition values from record batches
//! based on a partition specification.

23use std::sync::Arc;
24
25use arrow_array::{ArrayRef, RecordBatch, StructArray};
26use arrow_schema::DataType;
27
28use super::record_batch_projector::RecordBatchProjector;
29use super::type_to_arrow_type;
30use crate::spec::{PartitionSpec, Schema, StructType, Type};
31use crate::transform::{BoxedTransformFunction, create_transform_function};
32use crate::{Error, ErrorKind, Result};
33
34/// Calculator for partition values in Iceberg tables.
35///
36/// This struct handles the projection of source columns and application of
37/// partition transforms to compute partition values for a given record batch.
38#[derive(Debug)]
39pub struct PartitionValueCalculator {
40    projector: RecordBatchProjector,
41    transform_functions: Vec<BoxedTransformFunction>,
42    partition_type: StructType,
43    partition_arrow_type: DataType,
44}
45
46impl PartitionValueCalculator {
47    /// Create a new PartitionValueCalculator.
48    ///
49    /// # Arguments
50    ///
51    /// * `partition_spec` - The partition specification
52    /// * `table_schema` - The Iceberg table schema
53    ///
54    /// # Returns
55    ///
56    /// Returns a new `PartitionValueCalculator` instance or an error if initialization fails.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if:
61    /// - The partition spec is unpartitioned
62    /// - Transform function creation fails
63    /// - Projector initialization fails
64    pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> Result<Self> {
65        if partition_spec.is_unpartitioned() {
66            return Err(Error::new(
67                ErrorKind::DataInvalid,
68                "Cannot create partition calculator for unpartitioned table",
69            ));
70        }
71
72        // Create transform functions for each partition field
73        let transform_functions: Vec<BoxedTransformFunction> = partition_spec
74            .fields()
75            .iter()
76            .map(|pf| create_transform_function(&pf.transform))
77            .collect::<Result<Vec<_>>>()?;
78
79        // Extract source field IDs for projection
80        let source_field_ids: Vec<i32> = partition_spec
81            .fields()
82            .iter()
83            .map(|pf| pf.source_id)
84            .collect();
85
86        // Create projector for extracting source columns
87        let projector = RecordBatchProjector::from_iceberg_schema(
88            Arc::new(table_schema.clone()),
89            &source_field_ids,
90        )?;
91
92        // Get partition type information
93        let partition_type = partition_spec.partition_type(table_schema)?;
94        let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
95
96        Ok(Self {
97            projector,
98            transform_functions,
99            partition_type,
100            partition_arrow_type,
101        })
102    }
103
104    /// Get the partition type as an Iceberg StructType.
105    pub fn partition_type(&self) -> &StructType {
106        &self.partition_type
107    }
108
109    /// Get the partition type as an Arrow DataType.
110    pub fn partition_arrow_type(&self) -> &DataType {
111        &self.partition_arrow_type
112    }
113
114    /// Calculate partition values for a record batch.
115    ///
116    /// This method:
117    /// 1. Projects the source columns from the batch
118    /// 2. Applies partition transforms to each source column
119    /// 3. Constructs a StructArray containing the partition values
120    ///
121    /// # Arguments
122    ///
123    /// * `batch` - The record batch to calculate partition values for
124    ///
125    /// # Returns
126    ///
127    /// Returns an ArrayRef containing a StructArray of partition values, or an error if calculation fails.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if:
132    /// - Column projection fails
133    /// - Transform application fails
134    /// - StructArray construction fails
135    pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
136        // Project source columns from the batch
137        let source_columns = self.projector.project_column(batch.columns())?;
138
139        // Get expected struct fields for the result
140        let expected_struct_fields = match &self.partition_arrow_type {
141            DataType::Struct(fields) => fields.clone(),
142            _ => {
143                return Err(Error::new(
144                    ErrorKind::DataInvalid,
145                    "Expected partition type must be a struct",
146                ));
147            }
148        };
149
150        // Apply transforms to each source column
151        let mut partition_values = Vec::with_capacity(self.transform_functions.len());
152        for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
153            let partition_value = transform_fn.transform(source_column.clone())?;
154            partition_values.push(partition_value);
155        }
156
157        // Construct the StructArray
158        let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
159            .map_err(|e| {
160                Error::new(
161                    ErrorKind::DataInvalid,
162                    format!("Failed to create partition struct array: {e}"),
163                )
164            })?;
165
166        Ok(Arc::new(struct_array))
167    }
168}
169
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{Field, Schema as ArrowSchema};

    use super::*;
    use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Transform};

    /// An identity transform on `id` must expose the source values unchanged
    /// under the partition field name.
    #[test]
    fn test_partition_calculator_identity_transform() {
        let id_field = NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int));
        let name_field = NestedField::required(2, "name", Type::Primitive(PrimitiveType::String));
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![id_field.into(), name_field.into()])
            .build()
            .unwrap();

        let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // The partition type has exactly the one field declared in the spec.
        let partition_fields = calculator.partition_type().fields();
        assert_eq!(partition_fields.len(), 1);
        assert_eq!(partition_fields[0].name, "id_partition");

        // Build a three-row batch matching the table schema.
        let arrow_fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
        let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let batch =
            RecordBatch::try_new(Arc::new(ArrowSchema::new(arrow_fields)), vec![ids, names])
                .unwrap();

        // Compute the partition values and drill down to the Int32 column.
        let partition_values = calculator.calculate(&batch).unwrap();
        let partition_struct = partition_values
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        let id_partition = partition_struct
            .column_by_name("id_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        for (row, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(id_partition.value(row), expected);
        }
    }

    /// A spec without partition fields must be rejected at construction time.
    #[test]
    fn test_partition_calculator_unpartitioned_error() {
        let id_field = NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int));
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![id_field.into()])
            .build()
            .unwrap();

        let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
            .build()
            .unwrap();

        let err = PartitionValueCalculator::try_new(&partition_spec, &table_schema)
            .expect_err("an unpartitioned spec must be rejected");
        assert!(err.to_string().contains("unpartitioned table"));
    }
}