iceberg_datafusion/physical_plan/
project.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Partition value projection for Iceberg tables.
19
20use std::sync::Arc;
21
22use datafusion::arrow::array::RecordBatch;
23use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
24use datafusion::common::Result as DFResult;
25use datafusion::physical_expr::PhysicalExpr;
26use datafusion::physical_expr::expressions::Column;
27use datafusion::physical_plan::projection::ProjectionExec;
28use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
29use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator};
30use iceberg::spec::PartitionSpec;
31use iceberg::table::Table;
32
33use crate::to_datafusion_error;
34
35/// Extends an ExecutionPlan with partition value calculations for Iceberg tables.
36///
37/// This function takes an input ExecutionPlan and extends it with an additional column
38/// containing calculated partition values based on the table's partition specification.
39/// For unpartitioned tables, returns the original plan unchanged.
40///
41/// # Arguments
42/// * `input` - The input ExecutionPlan to extend
43/// * `table` - The Iceberg table with partition specification
44///
45/// # Returns
46/// * `Ok(Arc<dyn ExecutionPlan>)` - Extended plan with partition values column
47/// * `Err` - If partition spec is not found or transformation fails
48pub fn project_with_partition(
49    input: Arc<dyn ExecutionPlan>,
50    table: &Table,
51) -> DFResult<Arc<dyn ExecutionPlan>> {
52    let metadata = table.metadata();
53    let partition_spec = metadata.default_partition_spec();
54    let table_schema = metadata.current_schema();
55
56    if partition_spec.is_unpartitioned() {
57        return Ok(input);
58    }
59
60    let input_schema = input.schema();
61    // TODO: Validate that input_schema matches the Iceberg table schema.
62    // See: https://github.com/apache/iceberg-rust/issues/1752
63    let calculator =
64        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
65            .map_err(to_datafusion_error)?;
66
67    let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
68        Vec::with_capacity(input_schema.fields().len() + 1);
69
70    for (index, field) in input_schema.fields().iter().enumerate() {
71        let column_expr = Arc::new(Column::new(field.name(), index));
72        projection_exprs.push((column_expr, field.name().clone()));
73    }
74
75    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
76    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
77
78    let projection = ProjectionExec::try_new(projection_exprs, input)?;
79    Ok(Arc::new(projection))
80}
81
82/// PhysicalExpr implementation for partition value calculation
83#[derive(Debug, Clone)]
84struct PartitionExpr {
85    calculator: Arc<PartitionValueCalculator>,
86    partition_spec: Arc<PartitionSpec>,
87}
88
89impl PartitionExpr {
90    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
91        Self {
92            calculator: Arc::new(calculator),
93            partition_spec,
94        }
95    }
96}
97
98// Manual PartialEq/Eq implementations for pointer-based equality
99// (two PartitionExpr are equal if they share the same calculator and partition_spec instances)
100impl PartialEq for PartitionExpr {
101    fn eq(&self, other: &Self) -> bool {
102        Arc::ptr_eq(&self.calculator, &other.calculator)
103            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
104    }
105}
106
107impl Eq for PartitionExpr {}
108
109impl PhysicalExpr for PartitionExpr {
110    fn as_any(&self) -> &dyn std::any::Any {
111        self
112    }
113
114    fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
115        Ok(self.calculator.partition_arrow_type().clone())
116    }
117
118    fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
119        Ok(false)
120    }
121
122    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
123        let array = self
124            .calculator
125            .calculate(batch)
126            .map_err(to_datafusion_error)?;
127        Ok(ColumnarValue::Array(array))
128    }
129
130    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
131        vec![]
132    }
133
134    fn with_new_children(
135        self: Arc<Self>,
136        _children: Vec<Arc<dyn PhysicalExpr>>,
137    ) -> DFResult<Arc<dyn PhysicalExpr>> {
138        Ok(self)
139    }
140
141    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        let field_names: Vec<String> = self
143            .partition_spec
144            .fields()
145            .iter()
146            .map(|pf| format!("{}({})", pf.transform, pf.name))
147            .collect();
148        write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
149    }
150}
151
152impl std::fmt::Display for PartitionExpr {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        let field_names: Vec<&str> = self
155            .partition_spec
156            .fields()
157            .iter()
158            .map(|pf| pf.name.as_str())
159            .collect();
160        write!(f, "iceberg_partition_values({})", field_names.join(", "))
161    }
162}
163
164impl std::hash::Hash for PartitionExpr {
165    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
166        // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs
167        Arc::as_ptr(&self.calculator).hash(state);
168        Arc::as_ptr(&self.partition_spec).hash(state);
169    }
170}
171
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
    use datafusion::arrow::datatypes::{Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};

    use super::*;

    /// Builds a two-column Iceberg schema with two required fields:
    /// `id: int` (field id 1) and a string column called `string_column` (field id 2).
    fn id_and_string_schema(string_column: &str) -> Schema {
        Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, string_column, Type::Primitive(PrimitiveType::String))
                    .into(),
            ])
            .build()
            .unwrap()
    }

    /// Builds a partition spec with a single identity partition on `id`, named `id_partition`.
    fn identity_id_spec(table_schema: &Schema) -> PartitionSpec {
        iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Hashes `expr` with the standard library's default hasher.
    fn hash_of(expr: &PartitionExpr) -> u64 {
        let mut hasher = DefaultHasher::new();
        Hash::hash(expr, &mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = id_and_string_schema("name");
        let partition_spec = identity_id_spec(&table_schema);

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // Verify partition type
        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = id_and_string_schema("name");
        let partition_spec = Arc::new(identity_id_spec(&table_schema));

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
            Vec::with_capacity(arrow_schema.fields().len() + 1);
        for (i, field) in arrow_schema.fields().iter().enumerate() {
            let column_expr = Arc::new(Column::new(field.name(), i));
            projection_exprs.push((column_expr, field.name().clone()));
        }

        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
        let result = Arc::new(projection);

        assert_eq!(result.schema().fields().len(), 3);
        assert_eq!(result.schema().field(0).name(), "id");
        assert_eq!(result.schema().field(1).name(), "name");
        // Assert against the same constant the production code uses, not a copied literal.
        assert_eq!(
            result.schema().field(2).name(),
            PROJECTED_PARTITION_VALUE_COLUMN
        );
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = id_and_string_schema("data");
        let partition_spec = Arc::new(identity_id_spec(&table_schema));

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![10, 20, 30])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let partition_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let result = expr.evaluate(&batch).unwrap();
        match result {
            ColumnarValue::Array(array) => {
                let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
                let id_partition = struct_array
                    .column_by_name("id_partition")
                    .unwrap()
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                assert_eq!(id_partition.value(0), 10);
                assert_eq!(id_partition.value(1), 20);
                assert_eq!(id_partition.value(2), 30);
            }
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_partition_expr_pointer_equality_and_hash() {
        let table_schema = id_and_string_schema("name");
        let partition_spec = Arc::new(identity_id_spec(&table_schema));
        let build_expr = || {
            let calculator =
                PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
            PartitionExpr::new(calculator, partition_spec.clone())
        };

        let expr = build_expr();
        let cloned = expr.clone();
        // A clone shares both Arcs, so it compares equal and hashes identically.
        assert_eq!(expr, cloned);
        assert_eq!(hash_of(&expr), hash_of(&cloned));

        // An independently built calculator is a different allocation. The expressions are
        // therefore not equal, even though they share the same partition spec Arc.
        assert_ne!(expr, build_expr());
    }

    #[test]
    fn test_partition_expr_display() {
        let table_schema = id_and_string_schema("name");
        let partition_spec = Arc::new(identity_id_spec(&table_schema));
        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.to_string(), "iceberg_partition_values(id_partition)");
    }

    #[test]
    fn test_nested_partition() {
        let address_struct = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);

        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address_struct)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("address.city", "city_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let struct_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(struct_fields), false),
        ]));

        let street_array = Arc::new(StringArray::from(vec!["123 Main St", "456 Oak Ave"]));
        let city_array = Arc::new(StringArray::from(vec!["New York", "Los Angeles"]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("street", DataType::Utf8, false)),
                street_array as ArrayRef,
            ),
            (
                Arc::new(Field::new("city", DataType::Utf8, false)),
                city_array as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(struct_array),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = struct_array
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(city_partition.value(0), "New York");
        assert_eq!(city_partition.value(1), "Los Angeles");
    }
}