use std::sync::Arc;

use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
use datafusion::common::{DataFusionError, Result as DFResult};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
use iceberg::arrow::{
    PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, schema_to_arrow_schema,
    strip_metadata_from_schema,
};
use iceberg::spec::PartitionSpec;
use iceberg::table::Table;

use crate::to_datafusion_error;

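/// Wraps `input` in a [`ProjectionExec`] that passes every input column through
/// unchanged and appends one trailing struct column, named
/// [`PROJECTED_PARTITION_VALUE_COLUMN`], holding partition values computed from
/// the table's default partition spec. Unpartitioned tables are returned as-is,
/// and the input schema must match the table schema (field metadata ignored).
///
/// A minimal usage sketch (not a compiled doctest; `build_input_plan` is a
/// hypothetical helper and `table` an already-loaded Iceberg table):
///
/// ```ignore
/// let input: Arc<dyn ExecutionPlan> = build_input_plan()?;
/// let plan = project_with_partition(input, &table)?;
/// // `plan` now yields the original columns plus the partition-value column.
/// ```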
pub fn project_with_partition(
    input: Arc<dyn ExecutionPlan>,
    table: &Table,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    let metadata = table.metadata();
    let partition_spec = metadata.default_partition_spec();
    let table_schema = metadata.current_schema();

    if partition_spec.is_unpartitioned() {
        return Ok(input);
    }

    let input_schema = input.schema();

    // Compare the input schema against the table schema with field metadata
    // stripped from both sides, so metadata-only differences do not fail
    // validation.
    let expected_arrow_schema =
        schema_to_arrow_schema(table_schema.as_ref()).map_err(to_datafusion_error)?;
    let input_schema_cleaned =
        strip_metadata_from_schema(&input_schema).map_err(to_datafusion_error)?;
    let expected_schema_cleaned =
        strip_metadata_from_schema(&expected_arrow_schema).map_err(to_datafusion_error)?;

    if input_schema_cleaned != expected_schema_cleaned {
        return Err(DataFusionError::Plan(format!(
            "Input schema does not match Iceberg table schema.\n\
             Expected schema: {expected_schema_cleaned}\n\
             Input schema: {input_schema_cleaned}"
        )));
    }

    let calculator =
        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
            .map_err(to_datafusion_error)?;

    // Pass every input column through unchanged, then append the computed
    // partition values as one trailing struct column.
    let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
        Vec::with_capacity(input_schema.fields().len() + 1);

    for (index, field) in input_schema.fields().iter().enumerate() {
        let column_expr = Arc::new(Column::new(field.name(), index));
        projection_exprs.push((column_expr, field.name().clone()));
    }

    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

    let projection = ProjectionExec::try_new(projection_exprs, input)?;
    Ok(Arc::new(projection))
}

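/// Physical expression that computes the partition values for each input batch
/// as a struct array with one field per partition field of the spec.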
#[derive(Debug, Clone)]
struct PartitionExpr {
    calculator: Arc<PartitionValueCalculator>,
    partition_spec: Arc<PartitionSpec>,
}

impl PartitionExpr {
    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
        Self {
            calculator: Arc::new(calculator),
            partition_spec,
        }
    }
}

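// Equality is identity-based: two `PartitionExpr`s are equal only when they
// share the same calculator and partition-spec `Arc`s. This keeps `PartialEq`,
// `Eq`, and the `Hash` impl below mutually consistent.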
impl PartialEq for PartitionExpr {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.calculator, &other.calculator)
            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
    }
}

impl Eq for PartitionExpr {}

impl PhysicalExpr for PartitionExpr {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
        Ok(self.calculator.partition_arrow_type().clone())
    }

    fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
        Ok(false)
    }

    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
        let array = self
            .calculator
            .calculate(batch)
            .map_err(to_datafusion_error)?;
        Ok(ColumnarValue::Array(array))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

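    // `PartitionExpr` is a leaf expression with no children, so there is
    // nothing to rewrite and the expression is returned unchanged.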
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> DFResult<Arc<dyn PhysicalExpr>> {
        Ok(self)
    }

    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let field_names: Vec<String> = self
            .partition_spec
            .fields()
            .iter()
            .map(|pf| format!("{}({})", pf.transform, pf.name))
            .collect();
        write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
    }
}

impl std::fmt::Display for PartitionExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let field_names: Vec<&str> = self
            .partition_spec
            .fields()
            .iter()
            .map(|pf| pf.name.as_str())
            .collect();
        write!(f, "iceberg_partition_values({})", field_names.join(", "))
    }
}

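// Hash the `Arc` pointers so hashing stays consistent with the
// pointer-identity `PartialEq` above.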
impl std::hash::Hash for PartitionExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Arc::as_ptr(&self.calculator).hash(state);
        Arc::as_ptr(&self.partition_spec).hash(state);
    }
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};

    use super::*;

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = Arc::new(
            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
                .add_partition_field("id", "id_partition", Transform::Identity)
                .unwrap()
                .build()
                .unwrap(),
        );

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
            Vec::with_capacity(arrow_schema.fields().len() + 1);
        for (i, field) in arrow_schema.fields().iter().enumerate() {
            let column_expr = Arc::new(Column::new(field.name(), i));
            projection_exprs.push((column_expr, field.name().clone()));
        }

        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
        let result = Arc::new(projection);

        assert_eq!(result.schema().fields().len(), 3);
        assert_eq!(result.schema().field(0).name(), "id");
        assert_eq!(result.schema().field(1).name(), "name");
        assert_eq!(result.schema().field(2).name(), "_partition");
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "data", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![10, 20, 30])),
            Arc::new(datafusion::arrow::array::StringArray::from(vec![
                "a", "b", "c",
            ])),
        ])
        .unwrap();

        let partition_spec = Arc::new(partition_spec);
        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let partition_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let result = expr.evaluate(&batch).unwrap();
        match result {
            ColumnarValue::Array(array) => {
                let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
                let id_partition = struct_array
                    .column_by_name("id_partition")
                    .unwrap()
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                assert_eq!(id_partition.value(0), 10);
                assert_eq!(id_partition.value(1), 20);
                assert_eq!(id_partition.value(2), 30);
            }
            _ => panic!("Expected array result"),
        }
    }

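    // Partition source columns may be nested; the calculator resolves them by
    // dotted path (here `address.city`).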
    #[test]
    fn test_nested_partition() {
        let address_struct = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);

        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address_struct)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("address.city", "city_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let struct_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(struct_fields), false),
        ]));

        let street_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
            "123 Main St",
            "456 Oak Ave",
        ]));
        let city_array = Arc::new(datafusion::arrow::array::StringArray::from(vec![
            "New York",
            "Los Angeles",
        ]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("street", DataType::Utf8, false)),
                street_array as ArrayRef,
            ),
            (
                Arc::new(Field::new("city", DataType::Utf8, false)),
                city_array as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(struct_array),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = struct_array
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<datafusion::arrow::array::StringArray>()
            .unwrap();

        assert_eq!(city_partition.value(0), "New York");
        assert_eq!(city_partition.value(1), "Los Angeles");
    }

    #[test]
    fn test_schema_validation_matching_schemas() {
        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(result.is_ok(), "Schema validation should pass");
    }

    #[test]
    fn test_schema_validation_mismatched_schemas() {
        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        // The second field name deliberately differs from the table schema.
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("different_name", DataType::Utf8, false),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(
            result.is_err(),
            "Schema validation should fail for mismatched schemas"
        );
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Input schema does not match Iceberg table schema")
        );
    }

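    // `project_with_partition` strips field metadata from both schemas before
    // comparing, so extra Arrow field metadata must not fail validation.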
    #[test]
    fn test_schema_validation_with_metadata_differences() {
        use std::collections::HashMap;

        use iceberg::TableIdent;
        use iceberg::io::FileIO;
        use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};

        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()
                .unwrap(),
        );

        let partition_spec = iceberg::spec::PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let sort_order = iceberg::spec::SortOrder::builder()
            .build(&table_schema)
            .unwrap();

        let table_metadata_builder = iceberg::spec::TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            std::collections::HashMap::new(),
        )
        .unwrap();

        let table_metadata = table_metadata_builder.build().unwrap();

        let mut metadata = HashMap::new();
        metadata.insert("extra".to_string(), "metadata".to_string());

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(metadata.clone()),
            Field::new("name", DataType::Utf8, false).with_metadata(metadata),
        ]));

        let input = Arc::new(EmptyExec::new(arrow_schema));

        let table = iceberg::table::Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::from_path("/tmp").unwrap().build().unwrap())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap();

        let result = project_with_partition(input, &table);
        assert!(
            result.is_ok(),
            "Schema validation should pass even with metadata differences"
        );
    }
}