iceberg_datafusion/physical_plan/
project.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Partition value projection for Iceberg tables.
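//!
//! The entry point is [`project_with_partition`], which wraps an input
//! [`ExecutionPlan`] in a [`ProjectionExec`] that appends a struct column of
//! partition values computed by a [`PartitionValueCalculator`].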

use std::sync::Arc;

use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
use datafusion::common::{DataFusionError, Result as DFResult};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
use iceberg::arrow::{
    PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, schema_to_arrow_schema,
    strip_metadata_from_schema,
};
use iceberg::spec::PartitionSpec;
use iceberg::table::Table;

use crate::to_datafusion_error;

/// Extends an `ExecutionPlan` with partition value calculation for an Iceberg table.
///
/// Wraps the input plan in a projection that appends a struct column of partition
/// values computed from the table's default partition specification. For
/// unpartitioned tables, the original plan is returned unchanged.
///
/// # Arguments
/// * `input` - The input `ExecutionPlan` to extend
/// * `table` - The Iceberg table providing the partition specification
///
/// # Returns
/// * `Ok(Arc<dyn ExecutionPlan>)` - The extended plan with a partition values column
/// * `Err` - If the input schema does not match the table schema, or the partition
///   value calculator cannot be constructed
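///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); assumes `scan` is an existing
/// `Arc<dyn ExecutionPlan>` over the table's data and `table` is a loaded,
/// partitioned [`Table`]:
///
/// ```ignore
/// let plan = project_with_partition(scan, &table)?;
/// // The output schema now ends with the projected partition value column.
/// ```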
pub fn project_with_partition(
    input: Arc<dyn ExecutionPlan>,
    table: &Table,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    let metadata = table.metadata();
    let partition_spec = metadata.default_partition_spec();
    let table_schema = metadata.current_schema();

    if partition_spec.is_unpartitioned() {
        return Ok(input);
    }

    let input_schema = input.schema();

    // Validate that input_schema matches the Iceberg table schema.
    // Strip metadata from both schemas before comparison to ignore metadata differences.
    let expected_arrow_schema =
        schema_to_arrow_schema(table_schema.as_ref()).map_err(to_datafusion_error)?;
    let input_schema_cleaned =
        strip_metadata_from_schema(&input_schema).map_err(to_datafusion_error)?;
    let expected_schema_cleaned =
        strip_metadata_from_schema(&expected_arrow_schema).map_err(to_datafusion_error)?;

    if input_schema_cleaned != expected_schema_cleaned {
        return Err(DataFusionError::Plan(format!(
            "Input schema does not match Iceberg table schema.\n\
             Expected schema: {expected_schema_cleaned}\n\
             Input schema: {input_schema_cleaned}"
        )));
    }

    let calculator =
        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
            .map_err(to_datafusion_error)?;

    let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
        Vec::with_capacity(input_schema.fields().len() + 1);

    // Pass every input column through unchanged.
    for (index, field) in input_schema.fields().iter().enumerate() {
        let column_expr = Arc::new(Column::new(field.name(), index));
        projection_exprs.push((column_expr, field.name().clone()));
    }

    // Append the computed partition values as the final column.
    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

    let projection = ProjectionExec::try_new(projection_exprs, input)?;
    Ok(Arc::new(projection))
}

/// `PhysicalExpr` implementation that computes partition values for each input batch.
#[derive(Debug, Clone)]
struct PartitionExpr {
    calculator: Arc<PartitionValueCalculator>,
    partition_spec: Arc<PartitionSpec>,
}

impl PartitionExpr {
    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
        Self {
            calculator: Arc::new(calculator),
            partition_spec,
        }
    }
}

// Manual PartialEq/Eq implementations for pointer-based equality
// (two PartitionExpr are equal if they share the same calculator and partition_spec instances).
impl PartialEq for PartitionExpr {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.calculator, &other.calculator)
            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
    }
}

impl Eq for PartitionExpr {}

impl PhysicalExpr for PartitionExpr {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
        Ok(self.calculator.partition_arrow_type().clone())
    }

    fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
        Ok(false)
    }

    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
        let array = self
            .calculator
            .calculate(batch)
            .map_err(to_datafusion_error)?;
        Ok(ColumnarValue::Array(array))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> DFResult<Arc<dyn PhysicalExpr>> {
        Ok(self)
    }

    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let field_names: Vec<String> = self
            .partition_spec
            .fields()
            .iter()
            .map(|pf| format!("{}({})", pf.transform, pf.name))
            .collect();
        write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
    }
}

impl std::fmt::Display for PartitionExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let field_names: Vec<&str> = self
            .partition_spec
            .fields()
            .iter()
            .map(|pf| pf.name.as_str())
            .collect();
        write!(f, "iceberg_partition_values({})", field_names.join(", "))
    }
}

impl std::hash::Hash for PartitionExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the Arc pointers so hashing stays consistent with the
        // pointer-based PartialEq above.
        Arc::as_ptr(&self.calculator).hash(state);
        Arc::as_ptr(&self.partition_spec).hash(state);
    }
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};

    use super::*;

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // Verify partition type
        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = Arc::new(
            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
                .add_partition_field("id", "id_partition", Transform::Identity)
                .unwrap()
                .build()
                .unwrap(),
        );

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
            Vec::with_capacity(arrow_schema.fields().len() + 1);
        for (i, field) in arrow_schema.fields().iter().enumerate() {
            let column_expr = Arc::new(Column::new(field.name(), i));
            projection_exprs.push((column_expr, field.name().clone()));
        }

        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
        let result = Arc::new(projection);

        assert_eq!(result.schema().fields().len(), 3);
        assert_eq!(result.schema().field(0).name(), "id");
        assert_eq!(result.schema().field(1).name(), "name");
        assert_eq!(result.schema().field(2).name(), "_partition");
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "data", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![10, 20, 30])),
            Arc::new(datafusion::arrow::array::StringArray::from(vec![
                "a", "b", "c",
            ])),
        ])
        .unwrap();

        let partition_spec = Arc::new(partition_spec);
        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let partition_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let result = expr.evaluate(&batch).unwrap();
        match result {
            ColumnarValue::Array(array) => {
                let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
                let id_partition = struct_array
                    .column_by_name("id_partition")
                    .unwrap()
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                assert_eq!(id_partition.value(0), 10);
                assert_eq!(id_partition.value(1), 20);
                assert_eq!(id_partition.value(2), 30);
            }
            _ => panic!("Expected array result"),
        }
    }

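    // Illustrative test (a sketch mirroring the construction pattern of the
    // tests above): exercises the Display impl and the pointer-based
    // PartialEq semantics documented on PartitionExpr.
    #[test]
    fn test_partition_expr_display_and_equality() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = Arc::new(
            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
                .add_partition_field("id", "id_partition", Transform::Identity)
                .unwrap()
                .build()
                .unwrap(),
        );

        let calculator =
            PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let expr = PartitionExpr::new(calculator, partition_spec.clone());

        // Display lists the partition field names.
        assert_eq!(expr.to_string(), "iceberg_partition_values(id_partition)");

        // A clone shares both Arcs, so it compares equal; an independently
        // constructed expression with its own calculator does not.
        let cloned = expr.clone();
        assert_eq!(expr, cloned);

        let other_calculator =
            PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let other = PartitionExpr::new(other_calculator, partition_spec);
        assert_ne!(expr, other);
    }
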
    #[test]
    fn test_nested_partition() {
        let address_struct = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);

        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address_struct)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("address.city", "city_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let struct_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(struct_fields), false),
        ]));

        let street_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
            "123 Main St",
            "456 Oak Ave",
        ]));
        let city_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
            "New York",
            "Los Angeles",
        ]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("street", DataType::Utf8, false)),
                street_array as ArrayRef,
            ),
            (
                Arc::new(Field::new("city", DataType::Utf8, false)),
                city_array as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(struct_array),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = struct_array
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<datafusion::arrow::array::StringArray>()
            .unwrap();

        assert_eq!(city_partition.value(0), "New York");
        assert_eq!(city_partition.value(1), "Los Angeles");
    }

    #[test]
    fn test_schema_validation_matching_schemas() {
        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        // Create Arrow schema matching the table schema
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(result.is_ok(), "Schema validation should pass");
    }

    #[test]
    fn test_schema_validation_mismatched_schemas() {
        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        // Create Arrow schema with different field name (mismatched)
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("different_name", DataType::Utf8, false), // Wrong field name
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(
            result.is_err(),
            "Schema validation should fail for mismatched schemas"
        );
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Input schema does not match Iceberg table schema")
        );
    }

    #[test]
    fn test_schema_validation_with_metadata_differences() {
        use std::collections::HashMap;

        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        // Create Arrow schema with metadata (should be ignored in comparison)
        let mut metadata = HashMap::new();
        metadata.insert("extra".to_string(), "metadata".to_string());

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(metadata.clone()),
            Field::new("name", DataType::Utf8, false).with_metadata(metadata),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(
            result.is_ok(),
            "Schema validation should pass even with metadata differences"
        );
    }
}