iceberg_datafusion/physical_plan/
project.rs1use std::sync::Arc;
21
22use datafusion::arrow::array::RecordBatch;
23use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
24use datafusion::common::Result as DFResult;
25use datafusion::physical_expr::PhysicalExpr;
26use datafusion::physical_expr::expressions::Column;
27use datafusion::physical_plan::projection::ProjectionExec;
28use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
29use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator};
30use iceberg::spec::PartitionSpec;
31use iceberg::table::Table;
32
33use crate::to_datafusion_error;
34
35pub fn project_with_partition(
49 input: Arc<dyn ExecutionPlan>,
50 table: &Table,
51) -> DFResult<Arc<dyn ExecutionPlan>> {
52 let metadata = table.metadata();
53 let partition_spec = metadata.default_partition_spec();
54 let table_schema = metadata.current_schema();
55
56 if partition_spec.is_unpartitioned() {
57 return Ok(input);
58 }
59
60 let input_schema = input.schema();
61 let calculator =
64 PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
65 .map_err(to_datafusion_error)?;
66
67 let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
68 Vec::with_capacity(input_schema.fields().len() + 1);
69
70 for (index, field) in input_schema.fields().iter().enumerate() {
71 let column_expr = Arc::new(Column::new(field.name(), index));
72 projection_exprs.push((column_expr, field.name().clone()));
73 }
74
75 let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
76 projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
77
78 let projection = ProjectionExec::try_new(projection_exprs, input)?;
79 Ok(Arc::new(projection))
80}
81
82#[derive(Debug, Clone)]
84struct PartitionExpr {
85 calculator: Arc<PartitionValueCalculator>,
86 partition_spec: Arc<PartitionSpec>,
87}
88
89impl PartitionExpr {
90 fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
91 Self {
92 calculator: Arc::new(calculator),
93 partition_spec,
94 }
95 }
96}
97
98impl PartialEq for PartitionExpr {
101 fn eq(&self, other: &Self) -> bool {
102 Arc::ptr_eq(&self.calculator, &other.calculator)
103 && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
104 }
105}
106
107impl Eq for PartitionExpr {}
108
109impl PhysicalExpr for PartitionExpr {
110 fn as_any(&self) -> &dyn std::any::Any {
111 self
112 }
113
114 fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
115 Ok(self.calculator.partition_arrow_type().clone())
116 }
117
118 fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
119 Ok(false)
120 }
121
122 fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
123 let array = self
124 .calculator
125 .calculate(batch)
126 .map_err(to_datafusion_error)?;
127 Ok(ColumnarValue::Array(array))
128 }
129
130 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
131 vec![]
132 }
133
134 fn with_new_children(
135 self: Arc<Self>,
136 _children: Vec<Arc<dyn PhysicalExpr>>,
137 ) -> DFResult<Arc<dyn PhysicalExpr>> {
138 Ok(self)
139 }
140
141 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 let field_names: Vec<String> = self
143 .partition_spec
144 .fields()
145 .iter()
146 .map(|pf| format!("{}({})", pf.transform, pf.name))
147 .collect();
148 write!(f, "iceberg_partition_values[{}]", field_names.join(", "))
149 }
150}
151
152impl std::fmt::Display for PartitionExpr {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 let field_names: Vec<&str> = self
155 .partition_spec
156 .fields()
157 .iter()
158 .map(|pf| pf.name.as_str())
159 .collect();
160 write!(f, "iceberg_partition_values({})", field_names.join(", "))
161 }
162}
163
164impl std::hash::Hash for PartitionExpr {
165 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
166 Arc::as_ptr(&self.calculator).hash(state);
168 Arc::as_ptr(&self.partition_spec).hash(state);
169 }
170}
171
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
    use datafusion::arrow::datatypes::{Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};

    use super::*;

    /// Iceberg schema with a required int `id` (field 1) and a required string
    /// column (field 2) named `second`.
    fn int_string_table_schema(second: &str) -> Schema {
        Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, second, Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap()
    }

    /// Single-field identity partition spec over `source`, exposed as `target`.
    fn identity_spec(schema: &Schema, source: &str, target: &str) -> PartitionSpec {
        PartitionSpec::builder(Arc::new(schema.clone()))
            .add_partition_field(source, target, Transform::Identity)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Arrow schema with a non-null Int32 `id` and a non-null Utf8 column named `second`.
    fn int_utf8_arrow_schema(second: &str) -> Arc<ArrowSchema> {
        Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(second, DataType::Utf8, false),
        ]))
    }

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = int_string_table_schema("name");
        let partition_spec = identity_spec(&table_schema, "id", "id_partition");

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // The partition struct exposes exactly the one identity field.
        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = int_string_table_schema("name");
        let partition_spec = Arc::new(identity_spec(&table_schema, "id", "id_partition"));
        let arrow_schema = int_utf8_arrow_schema("name");
        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // Forward every input column, then append the partition expression.
        let mut exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = arrow_schema
            .fields()
            .iter()
            .enumerate()
            .map(|(position, field)| {
                let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new(field.name(), position));
                (column, field.name().clone())
            })
            .collect();
        let partition_column: Arc<dyn PhysicalExpr> =
            Arc::new(PartitionExpr::new(calculator, partition_spec));
        exprs.push((partition_column, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = Arc::new(ProjectionExec::try_new(exprs, input).unwrap());
        let output_schema = projection.schema();

        assert_eq!(output_schema.fields().len(), 3);
        for (position, expected) in ["id", "name", "_partition"].into_iter().enumerate() {
            assert_eq!(output_schema.field(position).name(), expected);
        }
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = int_string_table_schema("data");
        let partition_spec = Arc::new(identity_spec(&table_schema, "id", "id_partition"));
        let arrow_schema = int_utf8_arrow_schema("data");

        let ids: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
        let data: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![ids, data]).unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let expected_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), expected_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let ColumnarValue::Array(array) = expr.evaluate(&batch).unwrap() else {
            panic!("Expected array result");
        };
        let partition_struct = array.as_any().downcast_ref::<StructArray>().unwrap();
        let id_partition = partition_struct
            .column_by_name("id_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        // Identity transform: partition values mirror the source `id` column.
        for (row, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(id_partition.value(row), expected);
        }
    }

    #[test]
    fn test_nested_partition() {
        let address_type = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);

        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address_type)).into(),
            ])
            .build()
            .unwrap();

        // Partition on a field nested inside the `address` struct.
        let partition_spec = identity_spec(&table_schema, "address.city", "city_partition");

        let address_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(address_fields), false),
        ]));

        let streets: ArrayRef = Arc::new(StringArray::from(vec!["123 Main St", "456 Oak Ave"]));
        let cities: ArrayRef = Arc::new(StringArray::from(vec!["New York", "Los Angeles"]));
        let addresses = StructArray::from(vec![
            (Arc::new(Field::new("street", DataType::Utf8, false)), streets),
            (Arc::new(Field::new("city", DataType::Utf8, false)), cities),
        ]);

        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let address_column: ArrayRef = Arc::new(addresses);
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![ids, address_column]).unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let partition_struct = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = partition_struct
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        for (row, expected) in ["New York", "Los Angeles"].into_iter().enumerate() {
            assert_eq!(city_partition.value(row), expected);
        }
    }
}