iceberg/arrow/
partition_value_calculator.rs1use std::sync::Arc;
24
25use arrow_array::{ArrayRef, RecordBatch, StructArray};
26use arrow_schema::DataType;
27
28use super::record_batch_projector::RecordBatchProjector;
29use super::type_to_arrow_type;
30use crate::spec::{PartitionSpec, Schema, StructType, Type};
31use crate::transform::{BoxedTransformFunction, create_transform_function};
32use crate::{Error, ErrorKind, Result};
33
34#[derive(Debug)]
39pub struct PartitionValueCalculator {
40 projector: RecordBatchProjector,
41 transform_functions: Vec<BoxedTransformFunction>,
42 partition_type: StructType,
43 partition_arrow_type: DataType,
44}
45
46impl PartitionValueCalculator {
47 pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> Result<Self> {
65 if partition_spec.is_unpartitioned() {
66 return Err(Error::new(
67 ErrorKind::DataInvalid,
68 "Cannot create partition calculator for unpartitioned table",
69 ));
70 }
71
72 let transform_functions: Vec<BoxedTransformFunction> = partition_spec
74 .fields()
75 .iter()
76 .map(|pf| create_transform_function(&pf.transform))
77 .collect::<Result<Vec<_>>>()?;
78
79 let source_field_ids: Vec<i32> = partition_spec
81 .fields()
82 .iter()
83 .map(|pf| pf.source_id)
84 .collect();
85
86 let projector = RecordBatchProjector::from_iceberg_schema(
88 Arc::new(table_schema.clone()),
89 &source_field_ids,
90 )?;
91
92 let partition_type = partition_spec.partition_type(table_schema)?;
94 let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
95
96 Ok(Self {
97 projector,
98 transform_functions,
99 partition_type,
100 partition_arrow_type,
101 })
102 }
103
104 pub fn partition_type(&self) -> &StructType {
106 &self.partition_type
107 }
108
109 pub fn partition_arrow_type(&self) -> &DataType {
111 &self.partition_arrow_type
112 }
113
114 pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
136 let source_columns = self.projector.project_column(batch.columns())?;
138
139 let expected_struct_fields = match &self.partition_arrow_type {
141 DataType::Struct(fields) => fields.clone(),
142 _ => {
143 return Err(Error::new(
144 ErrorKind::DataInvalid,
145 "Expected partition type must be a struct",
146 ));
147 }
148 };
149
150 let mut partition_values = Vec::with_capacity(self.transform_functions.len());
152 for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
153 let partition_value = transform_fn.transform(source_column.clone())?;
154 partition_values.push(partition_value);
155 }
156
157 let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
159 .map_err(|e| {
160 Error::new(
161 ErrorKind::DataInvalid,
162 format!("Failed to create partition struct array: {e}"),
163 )
164 })?;
165
166 Ok(Arc::new(struct_array))
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{Field, Schema as ArrowSchema};

    use super::*;
    use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Transform};

    /// Builds an Iceberg schema with schema id 0 from the given top-level fields.
    fn schema_from(fields: Vec<NestedField>) -> Schema {
        Schema::builder()
            .with_schema_id(0)
            .with_fields(fields.into_iter().map(Arc::new).collect::<Vec<_>>())
            .build()
            .unwrap()
    }

    /// An identity transform on `id` must pass the column values through
    /// unchanged under the partition field name.
    #[test]
    fn test_partition_calculator_identity_transform() {
        let table_schema = schema_from(vec![
            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)),
            NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)),
        ]);

        let spec_builder = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap();
        let partition_spec = spec_builder.build().unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // The partition type exposes exactly the one configured field.
        let partition_fields = calculator.partition_type().fields();
        assert_eq!(partition_fields.len(), 1);
        assert_eq!(partition_fields[0].name, "id_partition");

        // Three-row batch matching the table schema.
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
        let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let arrow_schema = ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![ids, names]).unwrap();

        let partition_array = calculator.calculate(&batch).unwrap();
        let partition_struct = partition_array
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        let id_column = partition_struct.column_by_name("id_partition").unwrap();
        let id_partition = id_column.as_any().downcast_ref::<Int32Array>().unwrap();

        for (row, expected) in [10, 20, 30].iter().enumerate() {
            assert_eq!(id_partition.value(row), *expected);
        }
    }

    /// Constructing a calculator from a spec without partition fields must fail.
    #[test]
    fn test_partition_calculator_unpartitioned_error() {
        let table_schema = schema_from(vec![NestedField::required(
            1,
            "id",
            Type::Primitive(PrimitiveType::Int),
        )]);

        let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
            .build()
            .unwrap();

        let result = PartitionValueCalculator::try_new(&partition_spec, &table_schema);
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("unpartitioned table"));
    }
}