iceberg/arrow/
record_batch_projector.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow_array::{ArrayRef, RecordBatch, StructArray, make_array};
21use arrow_buffer::NullBuffer;
22use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
23use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
24
25use crate::arrow::schema::schema_to_arrow_schema;
26use crate::error::Result;
27use crate::spec::Schema as IcebergSchema;
28use crate::{Error, ErrorKind};
29
30/// Help to project specific field from `RecordBatch`` according to the fields id.
31#[derive(Clone, Debug, PartialEq, Eq)]
32pub struct RecordBatchProjector {
33    // A vector of vectors, where each inner vector represents the index path to access a specific field in a nested structure.
34    // E.g. [[0], [1, 2]] means the first field is accessed directly from the first column,
35    // while the second field is accessed from the second column and then from its third subcolumn (second column must be a struct column).
36    field_indices: Vec<Vec<usize>>,
37    // The schema reference after projection. This schema is derived from the original schema based on the given field IDs.
38    projected_schema: SchemaRef,
39}
40
41impl RecordBatchProjector {
42    /// Init ArrowFieldProjector
43    ///
44    /// This function will iterate through the field and fetch the field from the original schema according to the field ids.
45    /// The function to fetch the field id from the field is provided by `field_id_fetch_func`, return None if the field need to be skipped.
46    /// This function will iterate through the nested fields if the field is a struct, `searchable_field_func` can be used to control whether
47    /// iterate into the nested fields.
48    pub(crate) fn new<F1, F2>(
49        original_schema: SchemaRef,
50        field_ids: &[i32],
51        field_id_fetch_func: F1,
52        searchable_field_func: F2,
53    ) -> Result<Self>
54    where
55        F1: Fn(&Field) -> Result<Option<i64>>,
56        F2: Fn(&Field) -> bool,
57    {
58        let mut field_indices = Vec::with_capacity(field_ids.len());
59        let mut fields = Vec::with_capacity(field_ids.len());
60        for &id in field_ids {
61            let mut field_index = vec![];
62            let field = Self::fetch_field_index(
63                original_schema.fields(),
64                &mut field_index,
65                id as i64,
66                &field_id_fetch_func,
67                &searchable_field_func,
68            )?
69            .ok_or_else(|| {
70                Error::new(ErrorKind::Unexpected, "Field not found")
71                    .with_context("field_id", id.to_string())
72            })?;
73            fields.push(field.clone());
74            field_indices.push(field_index);
75        }
76        let delete_arrow_schema = Arc::new(Schema::new(fields));
77        Ok(Self {
78            field_indices,
79            projected_schema: delete_arrow_schema,
80        })
81    }
82
83    /// Create RecordBatchProjector using Iceberg schema.
84    ///
85    /// This constructor converts the Iceberg schema to Arrow schema with field ID metadata,
86    /// then uses the standard field ID lookup for projection.
87    ///
88    /// # Arguments
89    /// * `iceberg_schema` - The Iceberg schema for field ID mapping  
90    /// * `target_field_ids` - The field IDs to project
91    pub fn from_iceberg_schema(
92        iceberg_schema: Arc<IcebergSchema>,
93        target_field_ids: &[i32],
94    ) -> Result<Self> {
95        let arrow_schema_with_ids = Arc::new(schema_to_arrow_schema(&iceberg_schema)?);
96
97        let field_id_fetch_func = |field: &Field| -> Result<Option<i64>> {
98            if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
99                let field_id = value.parse::<i32>().map_err(|e| {
100                    Error::new(
101                        ErrorKind::DataInvalid,
102                        "Failed to parse field id".to_string(),
103                    )
104                    .with_context("value", value)
105                    .with_source(e)
106                })?;
107                Ok(Some(field_id as i64))
108            } else {
109                Ok(None)
110            }
111        };
112
113        let searchable_field_func = |_field: &Field| -> bool { true };
114
115        Self::new(
116            arrow_schema_with_ids,
117            target_field_ids,
118            field_id_fetch_func,
119            searchable_field_func,
120        )
121    }
122
123    fn fetch_field_index<F1, F2>(
124        fields: &Fields,
125        index_vec: &mut Vec<usize>,
126        target_field_id: i64,
127        field_id_fetch_func: &F1,
128        searchable_field_func: &F2,
129    ) -> Result<Option<FieldRef>>
130    where
131        F1: Fn(&Field) -> Result<Option<i64>>,
132        F2: Fn(&Field) -> bool,
133    {
134        for (pos, field) in fields.iter().enumerate() {
135            let id = field_id_fetch_func(field)?;
136            if let Some(id) = id
137                && target_field_id == id
138            {
139                index_vec.push(pos);
140                return Ok(Some(field.clone()));
141            }
142            if let DataType::Struct(inner) = field.data_type()
143                && searchable_field_func(field)
144                && let Some(res) = Self::fetch_field_index(
145                    inner,
146                    index_vec,
147                    target_field_id,
148                    field_id_fetch_func,
149                    searchable_field_func,
150                )?
151            {
152                index_vec.push(pos);
153                return Ok(Some(res));
154            }
155        }
156        Ok(None)
157    }
158
159    /// Return the reference of projected schema
160    pub(crate) fn projected_schema_ref(&self) -> &SchemaRef {
161        &self.projected_schema
162    }
163
164    /// Do projection with record batch
165    pub(crate) fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
166        RecordBatch::try_new(
167            self.projected_schema.clone(),
168            self.project_column(batch.columns())?,
169        )
170        .map_err(|err| Error::new(ErrorKind::DataInvalid, format!("{err}")))
171    }
172
173    /// Do projection with columns
174    pub fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
175        self.field_indices
176            .iter()
177            .map(|index_vec| Self::get_column_by_field_index(batch, index_vec))
178            .collect::<Result<Vec<_>>>()
179    }
180
181    fn get_column_by_field_index(batch: &[ArrayRef], field_index: &[usize]) -> Result<ArrayRef> {
182        let mut rev_iterator = field_index.iter().rev();
183        let mut array = batch[*rev_iterator.next().unwrap()].clone();
184        let mut null_buffer = array.logical_nulls();
185        for idx in rev_iterator {
186            array = array
187                .as_any()
188                .downcast_ref::<StructArray>()
189                .ok_or(Error::new(
190                    ErrorKind::Unexpected,
191                    "Cannot convert Array to StructArray",
192                ))?
193                .column(*idx)
194                .clone();
195            null_buffer = NullBuffer::union(null_buffer.as_ref(), array.logical_nulls().as_ref());
196        }
197        Ok(make_array(
198            array.to_data().into_builder().nulls(null_buffer).build()?,
199        ))
200    }
201}
202
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
    use arrow_schema::{DataType, Field, Fields, Schema};

    use crate::arrow::record_batch_projector::RecordBatchProjector;
    use crate::spec::{NestedField, PrimitiveType, Schema as IcebergSchema, Type};
    use crate::{Error, ErrorKind};

    /// Children of the struct column used by the nested-schema tests.
    fn struct_child_fields() -> Vec<Field> {
        vec![
            Field::new("inner_field1", DataType::Int32, false),
            Field::new("inner_field2", DataType::Utf8, false),
        ]
    }

    /// Schema `{ field1: Int32, field2: Struct<inner_field1: Int32, inner_field2: Utf8> }`.
    fn nested_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("field1", DataType::Int32, false),
            Field::new(
                "field2",
                DataType::Struct(Fields::from(struct_child_fields())),
                false,
            ),
        ]))
    }

    /// Assigns the ids 1..=4 to the fields of `nested_schema` by name; any other name is an
    /// error.
    fn field_id_by_name(field: &Field) -> crate::error::Result<Option<i64>> {
        match field.name().as_str() {
            "field1" => Ok(Some(1)),
            "field2" => Ok(Some(2)),
            "inner_field1" => Ok(Some(3)),
            "inner_field2" => Ok(Some(4)),
            _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
        }
    }

    #[test]
    fn test_record_batch_projector_nested_level() {
        let inner_fields = struct_child_fields();
        let schema = nested_schema();

        let projector =
            RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_by_name, |_| true)
                .unwrap();

        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projector.field_indices[0], vec![0]);
        // Innermost-first: child 0 of top-level column 1.
        assert_eq!(projector.field_indices[1], vec![0, 1]);

        let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        let inner_int_array = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef;
        let inner_string_array = Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef;
        let struct_array = Arc::new(StructArray::from(vec![
            (Arc::new(inner_fields[0].clone()), inner_int_array),
            (Arc::new(inner_fields[1].clone()), inner_string_array),
        ])) as ArrayRef;
        let batch = RecordBatch::try_new(schema, vec![int_array, struct_array]).unwrap();

        let projected_batch = projector.project_batch(batch).unwrap();
        assert_eq!(projected_batch.num_columns(), 2);
        let projected_int_array = projected_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let projected_inner_int_array = projected_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(projected_int_array.values(), &[1, 2, 3]);
        assert_eq!(projected_inner_int_array.values(), &[4, 5, 6]);
    }

    #[test]
    fn test_field_not_found() {
        // Id 5 does not exist anywhere in the schema.
        let projector =
            RecordBatchProjector::new(nested_schema(), &[1, 5], field_id_by_name, |_| true);

        assert!(projector.is_err());
    }

    #[test]
    fn test_field_not_reachable() {
        let schema = nested_schema();

        // Id 3 lives inside the struct, so it is unreachable when structs are not searched.
        let projector =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| false);
        assert!(projector.is_err());

        let projector =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| true);
        assert!(projector.is_ok());
    }

    #[test]
    fn test_from_iceberg_schema() {
        let iceberg_schema = IcebergSchema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                NestedField::optional(3, "age", Type::Primitive(PrimitiveType::Int)).into(),
            ])
            .build()
            .unwrap();

        let projector =
            RecordBatchProjector::from_iceberg_schema(Arc::new(iceberg_schema), &[1, 3]).unwrap();

        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projector.projected_schema_ref().fields().len(), 2);
        assert_eq!(projector.projected_schema_ref().field(0).name(), "id");
        assert_eq!(projector.projected_schema_ref().field(1).name(), "age");
    }
}