iceberg/arrow/
record_batch_projector.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow_array::{ArrayRef, RecordBatch, StructArray, make_array};
21use arrow_buffer::NullBuffer;
22use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
23use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
24
25use crate::arrow::schema::schema_to_arrow_schema;
26use crate::error::Result;
27use crate::spec::Schema as IcebergSchema;
28use crate::{Error, ErrorKind};
29
/// Helps to project specific fields out of a `RecordBatch` according to their field ids.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RecordBatchProjector {
    // One index path per projected field, in the same order as the requested field ids.
    // Each path is stored innermost-first: the position of the field inside its parent struct
    // comes first and the index of the top-level column comes last.
    // E.g. [[0], [2, 1]] means the first field is top-level column 0, while the second field is
    // child 2 of the struct stored in top-level column 1 (that column must be a struct column).
    field_indices: Vec<Vec<usize>>,
    // The schema reference after projection. This schema is derived from the original schema
    // based on the given field IDs and holds exactly the projected fields, in request order.
    projected_schema: SchemaRef,
}
40
41impl RecordBatchProjector {
42    /// Init ArrowFieldProjector
43    ///
44    /// This function will iterate through the field and fetch the field from the original schema according to the field ids.
45    /// The function to fetch the field id from the field is provided by `field_id_fetch_func`, return None if the field need to be skipped.
46    /// This function will iterate through the nested fields if the field is a struct, `searchable_field_func` can be used to control whether
47    /// iterate into the nested fields.
48    pub(crate) fn new<F1, F2>(
49        original_schema: SchemaRef,
50        field_ids: &[i32],
51        field_id_fetch_func: F1,
52        searchable_field_func: F2,
53    ) -> Result<Self>
54    where
55        F1: Fn(&Field) -> Result<Option<i64>>,
56        F2: Fn(&Field) -> bool,
57    {
58        let mut field_indices = Vec::with_capacity(field_ids.len());
59        let mut fields = Vec::with_capacity(field_ids.len());
60        for &id in field_ids {
61            let mut field_index = vec![];
62            let field = Self::fetch_field_index(
63                original_schema.fields(),
64                &mut field_index,
65                id as i64,
66                &field_id_fetch_func,
67                &searchable_field_func,
68            )?
69            .ok_or_else(|| {
70                Error::new(ErrorKind::Unexpected, "Field not found")
71                    .with_context("field_id", id.to_string())
72            })?;
73            fields.push(field.clone());
74            field_indices.push(field_index);
75        }
76        let delete_arrow_schema = Arc::new(Schema::new(fields));
77        Ok(Self {
78            field_indices,
79            projected_schema: delete_arrow_schema,
80        })
81    }
82
83    /// Create RecordBatchProjector using Iceberg schema.
84    ///
85    /// This constructor converts the Iceberg schema to Arrow schema with field ID metadata,
86    /// then uses the standard field ID lookup for projection.
87    ///
88    /// # Arguments
89    /// * `iceberg_schema` - The Iceberg schema for field ID mapping  
90    /// * `target_field_ids` - The field IDs to project
91    pub fn from_iceberg_schema(
92        iceberg_schema: Arc<IcebergSchema>,
93        target_field_ids: &[i32],
94    ) -> Result<Self> {
95        let arrow_schema_with_ids = Arc::new(schema_to_arrow_schema(&iceberg_schema)?);
96
97        let field_id_fetch_func = |field: &Field| -> Result<Option<i64>> {
98            if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
99                let field_id = value.parse::<i32>().map_err(|e| {
100                    Error::new(
101                        ErrorKind::DataInvalid,
102                        "Failed to parse field id".to_string(),
103                    )
104                    .with_context("value", value)
105                    .with_source(e)
106                })?;
107                Ok(Some(field_id as i64))
108            } else {
109                Ok(None)
110            }
111        };
112
113        let searchable_field_func = |_field: &Field| -> bool { true };
114
115        Self::new(
116            arrow_schema_with_ids,
117            target_field_ids,
118            field_id_fetch_func,
119            searchable_field_func,
120        )
121    }
122
123    fn fetch_field_index<F1, F2>(
124        fields: &Fields,
125        index_vec: &mut Vec<usize>,
126        target_field_id: i64,
127        field_id_fetch_func: &F1,
128        searchable_field_func: &F2,
129    ) -> Result<Option<FieldRef>>
130    where
131        F1: Fn(&Field) -> Result<Option<i64>>,
132        F2: Fn(&Field) -> bool,
133    {
134        for (pos, field) in fields.iter().enumerate() {
135            let id = field_id_fetch_func(field)?;
136            if let Some(id) = id {
137                if target_field_id == id {
138                    index_vec.push(pos);
139                    return Ok(Some(field.clone()));
140                }
141            }
142            if let DataType::Struct(inner) = field.data_type() {
143                if searchable_field_func(field) {
144                    if let Some(res) = Self::fetch_field_index(
145                        inner,
146                        index_vec,
147                        target_field_id,
148                        field_id_fetch_func,
149                        searchable_field_func,
150                    )? {
151                        index_vec.push(pos);
152                        return Ok(Some(res));
153                    }
154                }
155            }
156        }
157        Ok(None)
158    }
159
160    /// Return the reference of projected schema
161    pub(crate) fn projected_schema_ref(&self) -> &SchemaRef {
162        &self.projected_schema
163    }
164
165    /// Do projection with record batch
166    pub(crate) fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
167        RecordBatch::try_new(
168            self.projected_schema.clone(),
169            self.project_column(batch.columns())?,
170        )
171        .map_err(|err| Error::new(ErrorKind::DataInvalid, format!("{err}")))
172    }
173
174    /// Do projection with columns
175    pub fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
176        self.field_indices
177            .iter()
178            .map(|index_vec| Self::get_column_by_field_index(batch, index_vec))
179            .collect::<Result<Vec<_>>>()
180    }
181
182    fn get_column_by_field_index(batch: &[ArrayRef], field_index: &[usize]) -> Result<ArrayRef> {
183        let mut rev_iterator = field_index.iter().rev();
184        let mut array = batch[*rev_iterator.next().unwrap()].clone();
185        let mut null_buffer = array.logical_nulls();
186        for idx in rev_iterator {
187            array = array
188                .as_any()
189                .downcast_ref::<StructArray>()
190                .ok_or(Error::new(
191                    ErrorKind::Unexpected,
192                    "Cannot convert Array to StructArray",
193                ))?
194                .column(*idx)
195                .clone();
196            null_buffer = NullBuffer::union(null_buffer.as_ref(), array.logical_nulls().as_ref());
197        }
198        Ok(make_array(
199            array.to_data().into_builder().nulls(null_buffer).build()?,
200        ))
201    }
202}
203
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
    use arrow_schema::{DataType, Field, Fields, Schema};

    use crate::arrow::record_batch_projector::RecordBatchProjector;
    use crate::spec::{NestedField, PrimitiveType, Schema as IcebergSchema, Type};
    use crate::{Error, ErrorKind};

    /// Children of the `field2` struct column shared by the nested-schema tests.
    fn inner_fields() -> Vec<Field> {
        vec![
            Field::new("inner_field1", DataType::Int32, false),
            Field::new("inner_field2", DataType::Utf8, false),
        ]
    }

    /// Schema `{ field1: Int32, field2: Struct<inner> }` shared by the nested-schema tests.
    fn nested_schema(inner: &[Field]) -> Arc<Schema> {
        let top_level = vec![
            Field::new("field1", DataType::Int32, false),
            Field::new(
                "field2",
                DataType::Struct(Fields::from(inner.to_vec())),
                false,
            ),
        ];
        Arc::new(Schema::new(top_level))
    }

    /// Maps the well-known test field names to ids 1..=4 and rejects every other name.
    fn field_id_by_name(field: &Field) -> Result<Option<i64>, Error> {
        match field.name().as_str() {
            "field1" => Ok(Some(1)),
            "field2" => Ok(Some(2)),
            "inner_field1" => Ok(Some(3)),
            "inner_field2" => Ok(Some(4)),
            _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
        }
    }

    #[test]
    fn test_record_batch_projector_nested_level() {
        let inner = inner_fields();
        let schema = nested_schema(&inner);

        let projector =
            RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_by_name, |_| true)
                .unwrap();

        // Index paths are stored innermost-first: id 3 is child 0 of top-level column 1.
        assert_eq!(projector.field_indices, vec![vec![0], vec![0, 1]]);

        let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let inner_int_array: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6]));
        let inner_string_array: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
        let struct_array: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(inner[0].clone()), inner_int_array),
            (Arc::new(inner[1].clone()), inner_string_array),
        ]));
        let batch = RecordBatch::try_new(schema, vec![int_array, struct_array]).unwrap();

        let projected_batch = projector.project_batch(batch).unwrap();
        assert_eq!(projected_batch.num_columns(), 2);

        let projected_int_array = projected_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(projected_int_array.values(), &[1, 2, 3]);

        let projected_inner_int_array = projected_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(projected_inner_int_array.values(), &[4, 5, 6]);
    }

    #[test]
    fn test_field_not_found() {
        let schema = nested_schema(&inner_fields());

        // Id 5 does not exist anywhere in the schema.
        let projector = RecordBatchProjector::new(schema, &[1, 5], field_id_by_name, |_| true);

        assert!(projector.is_err());
    }

    #[test]
    fn test_field_not_reachable() {
        let schema = nested_schema(&inner_fields());

        // Id 3 lives inside the struct column, so it is only found when structs are searched.
        let without_descent =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| false);
        assert!(without_descent.is_err());

        let with_descent =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| true);
        assert!(with_descent.is_ok());
    }

    #[test]
    fn test_from_iceberg_schema() {
        let columns = vec![
            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
            NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::optional(3, "age", Type::Primitive(PrimitiveType::Int)).into(),
        ];
        let iceberg_schema = IcebergSchema::builder()
            .with_schema_id(0)
            .with_fields(columns)
            .build()
            .unwrap();

        let projector =
            RecordBatchProjector::from_iceberg_schema(Arc::new(iceberg_schema), &[1, 3]).unwrap();

        assert_eq!(projector.field_indices.len(), 2);

        let projected = projector.projected_schema_ref();
        assert_eq!(projected.fields().len(), 2);
        assert_eq!(projected.field(0).name(), "id");
        assert_eq!(projected.field(1).name(), "age");
    }
}