1use std::sync::Arc;
21
22use datafusion::arrow::array::RecordBatch;
23use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
24use datafusion::common::{DataFusionError, Result as DFResult};
25use datafusion::physical_expr::PhysicalExpr;
26use datafusion::physical_expr::expressions::Column;
27use datafusion::physical_plan::projection::ProjectionExec;
28use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
29use iceberg::arrow::{
30 PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, schema_to_arrow_schema,
31 strip_metadata_from_schema,
32};
33use iceberg::spec::PartitionSpec;
34use iceberg::table::Table;
35
36use crate::to_datafusion_error;
37
38pub fn project_with_partition(
52 input: Arc<dyn ExecutionPlan>,
53 table: &Table,
54) -> DFResult<Arc<dyn ExecutionPlan>> {
55 let metadata = table.metadata();
56 let partition_spec = metadata.default_partition_spec();
57 let table_schema = metadata.current_schema();
58
59 if partition_spec.is_unpartitioned() {
60 return Ok(input);
61 }
62
63 let input_schema = input.schema();
64
65 let expected_arrow_schema =
68 schema_to_arrow_schema(table_schema.as_ref()).map_err(to_datafusion_error)?;
69 let input_schema_cleaned =
70 strip_metadata_from_schema(&input_schema).map_err(to_datafusion_error)?;
71 let expected_schema_cleaned =
72 strip_metadata_from_schema(&expected_arrow_schema).map_err(to_datafusion_error)?;
73
74 if input_schema_cleaned != expected_schema_cleaned {
75 return Err(DataFusionError::Plan(format!(
76 "Input schema does not match Iceberg table schema.\n\
77 Expected schema: {expected_schema_cleaned}\n\
78 Input schema: {input_schema_cleaned}"
79 )));
80 }
81
82 let calculator =
83 PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
84 .map_err(to_datafusion_error)?;
85
86 let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
87 Vec::with_capacity(input_schema.fields().len() + 1);
88
89 for (index, field) in input_schema.fields().iter().enumerate() {
90 let column_expr = Arc::new(Column::new(field.name(), index));
91 projection_exprs.push((column_expr, field.name().clone()));
92 }
93
94 let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
95 projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
96
97 let projection = ProjectionExec::try_new(projection_exprs, input)?;
98 Ok(Arc::new(projection))
99}
100
101#[derive(Debug, Clone)]
103struct PartitionExpr {
104 calculator: Arc<PartitionValueCalculator>,
105 partition_spec: Arc<PartitionSpec>,
106}
107
108impl PartitionExpr {
109 fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
110 Self {
111 calculator: Arc::new(calculator),
112 partition_spec,
113 }
114 }
115}
116
117impl PartialEq for PartitionExpr {
120 fn eq(&self, other: &Self) -> bool {
121 Arc::ptr_eq(&self.calculator, &other.calculator)
122 && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
123 }
124}
125
126impl Eq for PartitionExpr {}
127
128impl PhysicalExpr for PartitionExpr {
129 fn as_any(&self) -> &dyn std::any::Any {
130 self
131 }
132
133 fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
134 Ok(self.calculator.partition_arrow_type().clone())
135 }
136
137 fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
138 Ok(false)
139 }
140
141 fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
142 let array = self
143 .calculator
144 .calculate(batch)
145 .map_err(to_datafusion_error)?;
146 Ok(ColumnarValue::Array(array))
147 }
148
149 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
150 vec![]
151 }
152
153 fn with_new_children(
154 self: Arc<Self>,
155 _children: Vec<Arc<dyn PhysicalExpr>>,
156 ) -> DFResult<Arc<dyn PhysicalExpr>> {
157 Ok(self)
158 }
159
160 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 let field_names: Vec<String> = self
162 .partition_spec
163 .fields()
164 .iter()
165 .map(|pf| format!("{}({})", pf.transform, pf.name))
166 .collect();
167 write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
168 }
169}
170
171impl std::fmt::Display for PartitionExpr {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 let field_names: Vec<&str> = self
174 .partition_spec
175 .fields()
176 .iter()
177 .map(|pf| pf.name.as_str())
178 .collect();
179 write!(f, "iceberg_partition_values({})", field_names.join(", "))
180 }
181}
182
183impl std::hash::Hash for PartitionExpr {
184 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185 Arc::as_ptr(&self.calculator).hash(state);
187 Arc::as_ptr(&self.partition_spec).hash(state);
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::TableIdent;
    use iceberg::io::FileIO;
    use iceberg::spec::{
        FormatVersion, NestedField, PrimitiveType, Schema, SortOrder, StructType,
        TableMetadataBuilder, Transform, Type,
    };
    use iceberg::test_utils::test_runtime;

    use super::*;

    /// Iceberg schema with a required int `id` (field 1) and a required string
    /// column named `string_column` (field 2).
    fn int_string_schema(string_column: &str) -> Schema {
        Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, string_column, Type::Primitive(PrimitiveType::String))
                    .into(),
            ])
            .build()
            .unwrap()
    }

    /// Partition spec with a single identity transform on `id`, named `id_partition`.
    fn identity_spec_on_id(schema: &Schema) -> PartitionSpec {
        PartitionSpec::builder(Arc::new(schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Arrow counterpart of [`int_string_schema`]: non-nullable `id: Int32` and a
    /// non-nullable Utf8 column named `string_column`.
    fn int_string_arrow_schema(string_column: &str) -> Arc<ArrowSchema> {
        Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(string_column, DataType::Utf8, false),
        ]))
    }

    /// Builds a V2 table over `(id: int, name: string)` that is identity-partitioned
    /// on `id`.
    fn identity_partitioned_table() -> Table {
        let table_schema = int_string_schema("name");
        let partition_spec = identity_spec_on_id(&table_schema);
        let sort_order = SortOrder::builder().build(&table_schema).unwrap();

        let table_metadata = TableMetadataBuilder::new(
            table_schema,
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            HashMap::new(),
        )
        .unwrap()
        .build()
        .unwrap();

        Table::builder()
            .metadata(table_metadata.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::new_with_fs())
            .metadata_location("/test/metadata.json")
            .runtime(test_runtime())
            .build()
            .unwrap()
    }

    /// Runs `project_with_partition` on an empty plan with the given Arrow fields
    /// against the table from [`identity_partitioned_table`].
    fn project_onto_test_table(fields: Vec<Field>) -> DFResult<Arc<dyn ExecutionPlan>> {
        let input = Arc::new(EmptyExec::new(Arc::new(ArrowSchema::new(fields))));
        project_with_partition(input, &identity_partitioned_table())
    }

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = int_string_schema("name");
        let partition_spec = identity_spec_on_id(&table_schema);

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = int_string_schema("name");
        let partition_spec = Arc::new(identity_spec_on_id(&table_schema));
        let arrow_schema = int_string_arrow_schema("name");

        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
            Vec::with_capacity(arrow_schema.fields().len() + 1);
        for (i, field) in arrow_schema.fields().iter().enumerate() {
            let column_expr = Arc::new(Column::new(field.name(), i));
            projection_exprs.push((column_expr, field.name().clone()));
        }

        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
        let result = Arc::new(projection);

        assert_eq!(result.schema().fields().len(), 3);
        assert_eq!(result.schema().field(0).name(), "id");
        assert_eq!(result.schema().field(1).name(), "name");
        assert_eq!(result.schema().field(2).name(), "_partition");
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = int_string_schema("data");
        let partition_spec = Arc::new(identity_spec_on_id(&table_schema));
        let arrow_schema = int_string_arrow_schema("data");

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![10, 20, 30])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let partition_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let ColumnarValue::Array(array) = expr.evaluate(&batch).unwrap() else {
            panic!("Expected array result");
        };

        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
        let id_partition = struct_array
            .column_by_name("id_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(id_partition.value(0), 10);
        assert_eq!(id_partition.value(1), 20);
        assert_eq!(id_partition.value(2), 30);
    }

    #[test]
    fn test_nested_partition() {
        let address_struct = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);

        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address_struct)).into(),
            ])
            .build()
            .unwrap();

        let partition_spec = PartitionSpec::builder(Arc::new(table_schema.clone()))
            .add_partition_field("address.city", "city_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let struct_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(struct_fields), false),
        ]));

        let street_array = Arc::new(StringArray::from(vec!["123 Main St", "456 Oak Ave"]));
        let city_array = Arc::new(StringArray::from(vec!["New York", "Los Angeles"]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("street", DataType::Utf8, false)),
                street_array as ArrayRef,
            ),
            (
                Arc::new(Field::new("city", DataType::Utf8, false)),
                city_array as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(struct_array),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = struct_array
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(city_partition.value(0), "New York");
        assert_eq!(city_partition.value(1), "Los Angeles");
    }

    #[test]
    fn test_schema_validation_matching_schemas() {
        let result = project_onto_test_table(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        assert!(result.is_ok(), "Schema validation should pass");
    }

    #[test]
    fn test_schema_validation_mismatched_schemas() {
        // The second column's name differs from the table's `name` column.
        let result = project_onto_test_table(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("different_name", DataType::Utf8, false),
        ]);
        assert!(
            result.is_err(),
            "Schema validation should fail for mismatched schemas"
        );
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Input schema does not match Iceberg table schema")
        );
    }

    #[test]
    fn test_schema_validation_with_metadata_differences() {
        // Extra field-level metadata on the input must not cause a mismatch.
        let mut metadata = HashMap::new();
        metadata.insert("extra".to_string(), "metadata".to_string());

        let result = project_onto_test_table(vec![
            Field::new("id", DataType::Int32, false).with_metadata(metadata.clone()),
            Field::new("name", DataType::Utf8, false).with_metadata(metadata),
        ]);
        assert!(
            result.is_ok(),
            "Schema validation should pass even with metadata differences"
        );
    }
}