1use std::sync::Arc;
19
20use arrow_array::{ArrayRef, RecordBatch, StructArray, make_array};
21use arrow_buffer::NullBuffer;
22use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
23use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
24
25use crate::arrow::schema::schema_to_arrow_schema;
26use crate::error::Result;
27use crate::spec::Schema as IcebergSchema;
28use crate::{Error, ErrorKind};
29
/// Projects a subset of columns — addressed by field id, and possibly nested
/// inside struct columns — out of Arrow `RecordBatch`es.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RecordBatchProjector {
    /// Index path of each projected field in the original schema, in request
    /// order. Each inner vector is stored leaf-first: the last element is the
    /// top-level column index and the preceding elements index into nested
    /// struct children (see `fetch_field_index`, which pushes indices while
    /// unwinding the recursion).
    field_indices: Vec<Vec<usize>>,
    /// Schema of the batches produced by `project_batch`: one field per
    /// requested field id, in request order.
    projected_schema: SchemaRef,
}
40
41impl RecordBatchProjector {
42 pub(crate) fn new<F1, F2>(
49 original_schema: SchemaRef,
50 field_ids: &[i32],
51 field_id_fetch_func: F1,
52 searchable_field_func: F2,
53 ) -> Result<Self>
54 where
55 F1: Fn(&Field) -> Result<Option<i64>>,
56 F2: Fn(&Field) -> bool,
57 {
58 let mut field_indices = Vec::with_capacity(field_ids.len());
59 let mut fields = Vec::with_capacity(field_ids.len());
60 for &id in field_ids {
61 let mut field_index = vec![];
62 let field = Self::fetch_field_index(
63 original_schema.fields(),
64 &mut field_index,
65 id as i64,
66 &field_id_fetch_func,
67 &searchable_field_func,
68 )?
69 .ok_or_else(|| {
70 Error::new(ErrorKind::Unexpected, "Field not found")
71 .with_context("field_id", id.to_string())
72 })?;
73 fields.push(field.clone());
74 field_indices.push(field_index);
75 }
76 let delete_arrow_schema = Arc::new(Schema::new(fields));
77 Ok(Self {
78 field_indices,
79 projected_schema: delete_arrow_schema,
80 })
81 }
82
83 pub fn from_iceberg_schema(
92 iceberg_schema: Arc<IcebergSchema>,
93 target_field_ids: &[i32],
94 ) -> Result<Self> {
95 let arrow_schema_with_ids = Arc::new(schema_to_arrow_schema(&iceberg_schema)?);
96
97 let field_id_fetch_func = |field: &Field| -> Result<Option<i64>> {
98 if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
99 let field_id = value.parse::<i32>().map_err(|e| {
100 Error::new(
101 ErrorKind::DataInvalid,
102 "Failed to parse field id".to_string(),
103 )
104 .with_context("value", value)
105 .with_source(e)
106 })?;
107 Ok(Some(field_id as i64))
108 } else {
109 Ok(None)
110 }
111 };
112
113 let searchable_field_func = |_field: &Field| -> bool { true };
114
115 Self::new(
116 arrow_schema_with_ids,
117 target_field_ids,
118 field_id_fetch_func,
119 searchable_field_func,
120 )
121 }
122
123 fn fetch_field_index<F1, F2>(
124 fields: &Fields,
125 index_vec: &mut Vec<usize>,
126 target_field_id: i64,
127 field_id_fetch_func: &F1,
128 searchable_field_func: &F2,
129 ) -> Result<Option<FieldRef>>
130 where
131 F1: Fn(&Field) -> Result<Option<i64>>,
132 F2: Fn(&Field) -> bool,
133 {
134 for (pos, field) in fields.iter().enumerate() {
135 let id = field_id_fetch_func(field)?;
136 if let Some(id) = id {
137 if target_field_id == id {
138 index_vec.push(pos);
139 return Ok(Some(field.clone()));
140 }
141 }
142 if let DataType::Struct(inner) = field.data_type() {
143 if searchable_field_func(field) {
144 if let Some(res) = Self::fetch_field_index(
145 inner,
146 index_vec,
147 target_field_id,
148 field_id_fetch_func,
149 searchable_field_func,
150 )? {
151 index_vec.push(pos);
152 return Ok(Some(res));
153 }
154 }
155 }
156 }
157 Ok(None)
158 }
159
160 pub(crate) fn projected_schema_ref(&self) -> &SchemaRef {
162 &self.projected_schema
163 }
164
165 pub(crate) fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
167 RecordBatch::try_new(
168 self.projected_schema.clone(),
169 self.project_column(batch.columns())?,
170 )
171 .map_err(|err| Error::new(ErrorKind::DataInvalid, format!("{err}")))
172 }
173
174 pub fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
176 self.field_indices
177 .iter()
178 .map(|index_vec| Self::get_column_by_field_index(batch, index_vec))
179 .collect::<Result<Vec<_>>>()
180 }
181
182 fn get_column_by_field_index(batch: &[ArrayRef], field_index: &[usize]) -> Result<ArrayRef> {
183 let mut rev_iterator = field_index.iter().rev();
184 let mut array = batch[*rev_iterator.next().unwrap()].clone();
185 let mut null_buffer = array.logical_nulls();
186 for idx in rev_iterator {
187 array = array
188 .as_any()
189 .downcast_ref::<StructArray>()
190 .ok_or(Error::new(
191 ErrorKind::Unexpected,
192 "Cannot convert Array to StructArray",
193 ))?
194 .column(*idx)
195 .clone();
196 null_buffer = NullBuffer::union(null_buffer.as_ref(), array.logical_nulls().as_ref());
197 }
198 Ok(make_array(
199 array.to_data().into_builder().nulls(null_buffer).build()?,
200 ))
201 }
202}
203
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
    use arrow_schema::{DataType, Field, Fields, Schema};

    use crate::arrow::record_batch_projector::RecordBatchProjector;
    use crate::spec::{NestedField, PrimitiveType, Schema as IcebergSchema, Type};
    use crate::{Error, ErrorKind};

    /// Builds the shared two-level test schema:
    /// `field1: i32`, `field2: struct<inner_field1: i32, inner_field2: utf8>`.
    /// Returns the inner fields (needed to build a `StructArray`) and the schema.
    fn nested_test_schema() -> (Vec<Field>, Arc<Schema>) {
        let inner = vec![
            Field::new("inner_field1", DataType::Int32, false),
            Field::new("inner_field2", DataType::Utf8, false),
        ];
        let outer = vec![
            Field::new("field1", DataType::Int32, false),
            Field::new(
                "field2",
                DataType::Struct(Fields::from(inner.clone())),
                false,
            ),
        ];
        (inner, Arc::new(Schema::new(outer)))
    }

    /// Maps each named test field to its synthetic field id.
    fn lookup_field_id(field: &Field) -> Result<Option<i64>, Error> {
        match field.name().as_str() {
            "field1" => Ok(Some(1)),
            "field2" => Ok(Some(2)),
            "inner_field1" => Ok(Some(3)),
            "inner_field2" => Ok(Some(4)),
            _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
        }
    }

    #[test]
    fn test_record_batch_projector_nested_level() {
        let (inner_fields, schema) = nested_test_schema();

        let projector =
            RecordBatchProjector::new(schema.clone(), &[1, 3], lookup_field_id, |_| true)
                .unwrap();

        // field 1 is top-level column 0; field 3 is child 0 of column 1.
        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projector.field_indices[0], vec![0]);
        assert_eq!(projector.field_indices[1], vec![0, 1]);

        let top_ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let inner_ints: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6]));
        let inner_strings: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
        let nested: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(inner_fields[0].clone()), inner_ints),
            (Arc::new(inner_fields[1].clone()), inner_strings),
        ]));
        let batch = RecordBatch::try_new(schema, vec![top_ints, nested]).unwrap();

        let projected = projector.project_batch(batch).unwrap();
        assert_eq!(projected.num_columns(), 2);

        let col0 = projected
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let col1 = projected
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(col0.values(), &[1, 2, 3]);
        assert_eq!(col1.values(), &[4, 5, 6]);
    }

    #[test]
    fn test_field_not_found() {
        // Id 5 does not exist anywhere in the schema.
        let (_, schema) = nested_test_schema();
        let result = RecordBatchProjector::new(schema, &[1, 5], lookup_field_id, |_| true);
        assert!(result.is_err());
    }

    #[test]
    fn test_field_not_reachable() {
        let (_, schema) = nested_test_schema();

        // Structs declared non-searchable: nested id 3 cannot be reached.
        assert!(
            RecordBatchProjector::new(schema.clone(), &[3], lookup_field_id, |_| false).is_err()
        );
        // Structs searchable: nested id 3 is found.
        assert!(RecordBatchProjector::new(schema, &[3], lookup_field_id, |_| true).is_ok());
    }

    #[test]
    fn test_from_iceberg_schema() {
        let iceberg_schema = IcebergSchema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                NestedField::optional(3, "age", Type::Primitive(PrimitiveType::Int)).into(),
            ])
            .build()
            .unwrap();

        let projector =
            RecordBatchProjector::from_iceberg_schema(Arc::new(iceberg_schema), &[1, 3]).unwrap();

        let projected = projector.projected_schema_ref();
        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projected.fields().len(), 2);
        assert_eq!(projected.field(0).name(), "id");
        assert_eq!(projected.field(1).name(), "age");
    }
}
361}