iceberg_datafusion/physical_plan/
project.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Partition value projection for Iceberg tables.
19
20use std::sync::Arc;
21
22use datafusion::arrow::array::RecordBatch;
23use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
24use datafusion::common::{DataFusionError, Result as DFResult};
25use datafusion::physical_expr::PhysicalExpr;
26use datafusion::physical_expr::expressions::Column;
27use datafusion::physical_plan::projection::ProjectionExec;
28use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
29use iceberg::arrow::{
30    PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, schema_to_arrow_schema,
31    strip_metadata_from_schema,
32};
33use iceberg::spec::PartitionSpec;
34use iceberg::table::Table;
35
36use crate::to_datafusion_error;
37
38/// Extends an ExecutionPlan with partition value calculations for Iceberg tables.
39///
40/// This function takes an input ExecutionPlan and extends it with an additional column
41/// containing calculated partition values based on the table's partition specification.
42/// For unpartitioned tables, returns the original plan unchanged.
43///
44/// # Arguments
45/// * `input` - The input ExecutionPlan to extend
46/// * `table` - The Iceberg table with partition specification
47///
48/// # Returns
49/// * `Ok(Arc<dyn ExecutionPlan>)` - Extended plan with partition values column
50/// * `Err` - If partition spec is not found or transformation fails
51pub fn project_with_partition(
52    input: Arc<dyn ExecutionPlan>,
53    table: &Table,
54) -> DFResult<Arc<dyn ExecutionPlan>> {
55    let metadata = table.metadata();
56    let partition_spec = metadata.default_partition_spec();
57    let table_schema = metadata.current_schema();
58
59    if partition_spec.is_unpartitioned() {
60        return Ok(input);
61    }
62
63    let input_schema = input.schema();
64
65    // Validate that input_schema matches the Iceberg table schema
66    // Strip metadata from both schemas before comparison to ignore metadata differences
67    let expected_arrow_schema =
68        schema_to_arrow_schema(table_schema.as_ref()).map_err(to_datafusion_error)?;
69    let input_schema_cleaned =
70        strip_metadata_from_schema(&input_schema).map_err(to_datafusion_error)?;
71    let expected_schema_cleaned =
72        strip_metadata_from_schema(&expected_arrow_schema).map_err(to_datafusion_error)?;
73
74    if input_schema_cleaned != expected_schema_cleaned {
75        return Err(DataFusionError::Plan(format!(
76            "Input schema does not match Iceberg table schema.\n\
77             Expected schema: {expected_schema_cleaned}\n\
78             Input schema: {input_schema_cleaned}"
79        )));
80    }
81
82    let calculator =
83        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
84            .map_err(to_datafusion_error)?;
85
86    let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
87        Vec::with_capacity(input_schema.fields().len() + 1);
88
89    for (index, field) in input_schema.fields().iter().enumerate() {
90        let column_expr = Arc::new(Column::new(field.name(), index));
91        projection_exprs.push((column_expr, field.name().clone()));
92    }
93
94    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
95    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
96
97    let projection = ProjectionExec::try_new(projection_exprs, input)?;
98    Ok(Arc::new(projection))
99}
100
101/// PhysicalExpr implementation for partition value calculation
102#[derive(Debug, Clone)]
103struct PartitionExpr {
104    calculator: Arc<PartitionValueCalculator>,
105    partition_spec: Arc<PartitionSpec>,
106}
107
108impl PartitionExpr {
109    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
110        Self {
111            calculator: Arc::new(calculator),
112            partition_spec,
113        }
114    }
115}
116
117// Manual PartialEq/Eq implementations for pointer-based equality
118// (two PartitionExpr are equal if they share the same calculator and partition_spec instances)
119impl PartialEq for PartitionExpr {
120    fn eq(&self, other: &Self) -> bool {
121        Arc::ptr_eq(&self.calculator, &other.calculator)
122            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
123    }
124}
125
126impl Eq for PartitionExpr {}
127
128impl PhysicalExpr for PartitionExpr {
129    fn as_any(&self) -> &dyn std::any::Any {
130        self
131    }
132
133    fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
134        Ok(self.calculator.partition_arrow_type().clone())
135    }
136
137    fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
138        Ok(false)
139    }
140
141    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
142        let array = self
143            .calculator
144            .calculate(batch)
145            .map_err(to_datafusion_error)?;
146        Ok(ColumnarValue::Array(array))
147    }
148
149    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
150        vec![]
151    }
152
153    fn with_new_children(
154        self: Arc<Self>,
155        _children: Vec<Arc<dyn PhysicalExpr>>,
156    ) -> DFResult<Arc<dyn PhysicalExpr>> {
157        Ok(self)
158    }
159
160    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        let field_names: Vec<String> = self
162            .partition_spec
163            .fields()
164            .iter()
165            .map(|pf| format!("{}({})", pf.transform, pf.name))
166            .collect();
167        write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
168    }
169}
170
171impl std::fmt::Display for PartitionExpr {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        let field_names: Vec<&str> = self
174            .partition_spec
175            .fields()
176            .iter()
177            .map(|pf| pf.name.as_str())
178            .collect();
179        write!(f, "iceberg_partition_values({})", field_names.join(", "))
180    }
181}
182
183impl std::hash::Hash for PartitionExpr {
184    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185        // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs
186        Arc::as_ptr(&self.calculator).hash(state);
187        Arc::as_ptr(&self.partition_spec).hash(state);
188    }
189}
190
191#[cfg(test)]
192mod tests {
193    use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
194    use datafusion::arrow::datatypes::{DataType, Field, Fields};
195    use datafusion::physical_plan::empty::EmptyExec;
196    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};
197    use iceberg::test_utils::test_runtime;
198
199    use super::*;
200
201    #[test]
202    fn test_partition_calculator_basic() {
203        let table_schema = Schema::builder()
204            .with_schema_id(0)
205            .with_fields(vec![
206                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
207                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
208            ])
209            .build()
210            .unwrap();
211
212        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
213            .add_partition_field("id", "id_partition", Transform::Identity)
214            .unwrap()
215            .build()
216            .unwrap();
217
218        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
219
220        // Verify partition type
221        assert_eq!(calculator.partition_type().fields().len(), 1);
222        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
223    }
224
225    #[test]
226    fn test_partition_expr_with_projection() {
227        let table_schema = Schema::builder()
228            .with_schema_id(0)
229            .with_fields(vec![
230                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
231                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
232            ])
233            .build()
234            .unwrap();
235
236        let partition_spec = Arc::new(
237            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
238                .add_partition_field("id", "id_partition", Transform::Identity)
239                .unwrap()
240                .build()
241                .unwrap(),
242        );
243
244        let arrow_schema = Arc::new(ArrowSchema::new(vec![
245            Field::new("id", DataType::Int32, false),
246            Field::new("name", DataType::Utf8, false),
247        ]));
248
249        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));
250
251        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
252
253        let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
254            Vec::with_capacity(arrow_schema.fields().len() + 1);
255        for (i, field) in arrow_schema.fields().iter().enumerate() {
256            let column_expr = Arc::new(Column::new(field.name(), i));
257            projection_exprs.push((column_expr, field.name().clone()));
258        }
259
260        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
261        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
262
263        let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
264        let result = Arc::new(projection);
265
266        assert_eq!(result.schema().fields().len(), 3);
267        assert_eq!(result.schema().field(0).name(), "id");
268        assert_eq!(result.schema().field(1).name(), "name");
269        assert_eq!(result.schema().field(2).name(), "_partition");
270    }
271
272    #[test]
273    fn test_partition_expr_evaluate() {
274        let table_schema = Schema::builder()
275            .with_schema_id(0)
276            .with_fields(vec![
277                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
278                NestedField::required(2, "data", Type::Primitive(PrimitiveType::String)).into(),
279            ])
280            .build()
281            .unwrap();
282
283        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
284            .add_partition_field("id", "id_partition", Transform::Identity)
285            .unwrap()
286            .build()
287            .unwrap();
288
289        let arrow_schema = Arc::new(ArrowSchema::new(vec![
290            Field::new("id", DataType::Int32, false),
291            Field::new("data", DataType::Utf8, false),
292        ]));
293
294        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
295            Arc::new(Int32Array::from(vec![10, 20, 30])),
296            Arc::new(datafusion::arrow::array::StringArray::from(vec![
297                "a", "b", "c",
298            ])),
299        ])
300        .unwrap();
301
302        let partition_spec = Arc::new(partition_spec);
303        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
304        let partition_type = calculator.partition_arrow_type().clone();
305        let expr = PartitionExpr::new(calculator, partition_spec);
306
307        assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
308        assert!(!expr.nullable(&arrow_schema).unwrap());
309
310        let result = expr.evaluate(&batch).unwrap();
311        match result {
312            ColumnarValue::Array(array) => {
313                let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
314                let id_partition = struct_array
315                    .column_by_name("id_partition")
316                    .unwrap()
317                    .as_any()
318                    .downcast_ref::<Int32Array>()
319                    .unwrap();
320                assert_eq!(id_partition.value(0), 10);
321                assert_eq!(id_partition.value(1), 20);
322                assert_eq!(id_partition.value(2), 30);
323            }
324            _ => panic!("Expected array result"),
325        }
326    }
327
328    #[test]
329    fn test_nested_partition() {
330        let address_struct = StructType::new(vec![
331            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
332            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
333        ]);
334
335        let table_schema = Schema::builder()
336            .with_schema_id(0)
337            .with_fields(vec![
338                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
339                NestedField::required(2, "address", Type::Struct(address_struct)).into(),
340            ])
341            .build()
342            .unwrap();
343
344        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
345            .add_partition_field("address.city", "city_partition", Transform::Identity)
346            .unwrap()
347            .build()
348            .unwrap();
349
350        let struct_fields = Fields::from(vec![
351            Field::new("street", DataType::Utf8, false),
352            Field::new("city", DataType::Utf8, false),
353        ]);
354
355        let arrow_schema = Arc::new(ArrowSchema::new(vec![
356            Field::new("id", DataType::Int32, false),
357            Field::new("address", DataType::Struct(struct_fields), false),
358        ]));
359
360        let street_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
361            "123 Main St",
362            "456 Oak Ave",
363        ]));
364        let city_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
365            "New York",
366            "Los Angeles",
367        ]));
368
369        let struct_array = StructArray::from(vec![
370            (
371                Arc::new(Field::new("street", DataType::Utf8, false)),
372                street_array as ArrayRef,
373            ),
374            (
375                Arc::new(Field::new("city", DataType::Utf8, false)),
376                city_array as ArrayRef,
377            ),
378        ]);
379
380        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
381            Arc::new(Int32Array::from(vec![1, 2])),
382            Arc::new(struct_array),
383        ])
384        .unwrap();
385
386        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
387        let array = calculator.calculate(&batch).unwrap();
388
389        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
390        let city_partition = struct_array
391            .column_by_name("city_partition")
392            .unwrap()
393            .as_any()
394            .downcast_ref::<datafusion::arrow::array::StringArray>()
395            .unwrap();
396
397        assert_eq!(city_partition.value(0), "New York");
398        assert_eq!(city_partition.value(1), "Los Angeles");
399    }
400
401    #[test]
402    fn test_schema_validation_matching_schemas() {
403        use iceberg::TableIdent;
404        use iceberg::io::FileIO;
405        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};
406
407        let table_schema = Arc::new(
408            Schema::builder()
409                .with_fields(vec![
410                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
411                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
412                ])
413                .build()
414                .unwrap(),
415        );
416
417        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
418            .add_partition_field("id", "id_partition", Transform::Identity)
419            .unwrap()
420            .build()
421            .unwrap();
422
423        let sort_order = iceberg::spec::SortOrder::builder()
424            .build(&table_schema)
425            .unwrap();
426
427        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
428            (*table_schema).clone(),
429            partition_spec,
430            sort_order,
431            "/test/table".to_string(),
432            FormatVersion::V2,
433            std::collections::HashMap::new(),
434        )
435        .unwrap();
436
437        let table_metadata = table_metadata_builder.build().unwrap();
438
439        // Create Arrow schema matching the table schema
440        let arrow_schema = Arc::new(ArrowSchema::new(vec![
441            Field::new("id", DataType::Int32, false),
442            Field::new("name", DataType::Utf8, false),
443        ]));
444
445        let input = Arc::new(EmptyExec::new(arrow_schema));
446
447        let table = iceberg::table::Table::builder()
448            .metadata(table_metadata.metadata)
449            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
450            .file_io(FileIO::new_with_fs())
451            .metadata_location("/test/metadata.json")
452            .runtime(test_runtime())
453            .build()
454            .unwrap();
455
456        let result = project_with_partition(input, &table);
457        assert!(result.is_ok(), "Schema validation should pass");
458    }
459
460    #[test]
461    fn test_schema_validation_mismatched_schemas() {
462        use iceberg::TableIdent;
463        use iceberg::io::FileIO;
464        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};
465
466        let table_schema = Arc::new(
467            Schema::builder()
468                .with_fields(vec![
469                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
470                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
471                ])
472                .build()
473                .unwrap(),
474        );
475
476        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
477            .add_partition_field("id", "id_partition", Transform::Identity)
478            .unwrap()
479            .build()
480            .unwrap();
481
482        let sort_order = iceberg::spec::SortOrder::builder()
483            .build(&table_schema)
484            .unwrap();
485
486        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
487            (*table_schema).clone(),
488            partition_spec,
489            sort_order,
490            "/test/table".to_string(),
491            FormatVersion::V2,
492            std::collections::HashMap::new(),
493        )
494        .unwrap();
495
496        let table_metadata = table_metadata_builder.build().unwrap();
497
498        // Create Arrow schema with different field name (mismatched)
499        let arrow_schema = Arc::new(ArrowSchema::new(vec![
500            Field::new("id", DataType::Int32, false),
501            Field::new("different_name", DataType::Utf8, false), // Wrong field name
502        ]));
503
504        let input = Arc::new(EmptyExec::new(arrow_schema));
505
506        let table = iceberg::table::Table::builder()
507            .metadata(table_metadata.metadata)
508            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
509            .file_io(FileIO::new_with_fs())
510            .metadata_location("/test/metadata.json")
511            .runtime(test_runtime())
512            .build()
513            .unwrap();
514
515        let result = project_with_partition(input, &table);
516        assert!(
517            result.is_err(),
518            "Schema validation should fail for mismatched schemas"
519        );
520        assert!(
521            result
522                .unwrap_err()
523                .to_string()
524                .contains("Input schema does not match Iceberg table schema")
525        );
526    }
527
528    #[test]
529    fn test_schema_validation_with_metadata_differences() {
530        use std::collections::HashMap;
531
532        use iceberg::TableIdent;
533        use iceberg::io::FileIO;
534        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};
535
536        let table_schema = Arc::new(
537            Schema::builder()
538                .with_fields(vec![
539                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
540                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
541                ])
542                .build()
543                .unwrap(),
544        );
545
546        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
547            .add_partition_field("id", "id_partition", Transform::Identity)
548            .unwrap()
549            .build()
550            .unwrap();
551
552        let sort_order = iceberg::spec::SortOrder::builder()
553            .build(&table_schema)
554            .unwrap();
555
556        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
557            (*table_schema).clone(),
558            partition_spec,
559            sort_order,
560            "/test/table".to_string(),
561            FormatVersion::V2,
562            std::collections::HashMap::new(),
563        )
564        .unwrap();
565
566        let table_metadata = table_metadata_builder.build().unwrap();
567
568        // Create Arrow schema with metadata (should be ignored in comparison)
569        let mut metadata = HashMap::new();
570        metadata.insert("extra".to_string(), "metadata".to_string());
571
572        let arrow_schema = Arc::new(ArrowSchema::new(vec![
573            Field::new("id", DataType::Int32, false).with_metadata(metadata.clone()),
574            Field::new("name", DataType::Utf8, false).with_metadata(metadata),
575        ]));
576
577        let input = Arc::new(EmptyExec::new(arrow_schema));
578
579        let table = iceberg::table::Table::builder()
580            .metadata(table_metadata.metadata)
581            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
582            .file_io(FileIO::new_with_fs())
583            .metadata_location("/test/metadata.json")
584            .runtime(test_runtime())
585            .build()
586            .unwrap();
587
588        let result = project_with_partition(input, &table);
589        assert!(
590            result.is_ok(),
591            "Schema validation should pass even with metadata differences"
592        );
593    }
594}