1use std::sync::Arc;
19
20use arrow_array::{ArrayRef, RecordBatch, StructArray, make_array};
21use arrow_buffer::NullBuffer;
22use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
23use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
24
25use crate::arrow::schema::schema_to_arrow_schema;
26use crate::error::Result;
27use crate::spec::Schema as IcebergSchema;
28use crate::{Error, ErrorKind};
29
/// Projects (possibly nested) fields out of Arrow record batches, selecting them by field id.
///
/// The projector is built once from a schema plus a list of field ids. Afterwards, batches laid
/// out like that schema can be projected with [`RecordBatchProjector::project_batch`] or
/// [`RecordBatchProjector::project_column`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RecordBatchProjector {
    // One entry per projected field, in the order the field ids were requested. Each entry is
    // the path of column positions leading to that field, stored innermost-first: the last
    // element is the top-level column position, earlier elements index into struct children.
    field_indices: Vec<Vec<usize>>,
    // Arrow schema of the projected output: one field per requested field id, in request order.
    projected_schema: SchemaRef,
}
40
41impl RecordBatchProjector {
42 pub(crate) fn new<F1, F2>(
49 original_schema: SchemaRef,
50 field_ids: &[i32],
51 field_id_fetch_func: F1,
52 searchable_field_func: F2,
53 ) -> Result<Self>
54 where
55 F1: Fn(&Field) -> Result<Option<i64>>,
56 F2: Fn(&Field) -> bool,
57 {
58 let mut field_indices = Vec::with_capacity(field_ids.len());
59 let mut fields = Vec::with_capacity(field_ids.len());
60 for &id in field_ids {
61 let mut field_index = vec![];
62 let field = Self::fetch_field_index(
63 original_schema.fields(),
64 &mut field_index,
65 id as i64,
66 &field_id_fetch_func,
67 &searchable_field_func,
68 )?
69 .ok_or_else(|| {
70 Error::new(ErrorKind::Unexpected, "Field not found")
71 .with_context("field_id", id.to_string())
72 })?;
73 fields.push(field.clone());
74 field_indices.push(field_index);
75 }
76 let delete_arrow_schema = Arc::new(Schema::new(fields));
77 Ok(Self {
78 field_indices,
79 projected_schema: delete_arrow_schema,
80 })
81 }
82
83 pub fn from_iceberg_schema(
92 iceberg_schema: Arc<IcebergSchema>,
93 target_field_ids: &[i32],
94 ) -> Result<Self> {
95 let arrow_schema_with_ids = Arc::new(schema_to_arrow_schema(&iceberg_schema)?);
96
97 let field_id_fetch_func = |field: &Field| -> Result<Option<i64>> {
98 if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
99 let field_id = value.parse::<i32>().map_err(|e| {
100 Error::new(
101 ErrorKind::DataInvalid,
102 "Failed to parse field id".to_string(),
103 )
104 .with_context("value", value)
105 .with_source(e)
106 })?;
107 Ok(Some(field_id as i64))
108 } else {
109 Ok(None)
110 }
111 };
112
113 let searchable_field_func = |_field: &Field| -> bool { true };
114
115 Self::new(
116 arrow_schema_with_ids,
117 target_field_ids,
118 field_id_fetch_func,
119 searchable_field_func,
120 )
121 }
122
123 fn fetch_field_index<F1, F2>(
124 fields: &Fields,
125 index_vec: &mut Vec<usize>,
126 target_field_id: i64,
127 field_id_fetch_func: &F1,
128 searchable_field_func: &F2,
129 ) -> Result<Option<FieldRef>>
130 where
131 F1: Fn(&Field) -> Result<Option<i64>>,
132 F2: Fn(&Field) -> bool,
133 {
134 for (pos, field) in fields.iter().enumerate() {
135 let id = field_id_fetch_func(field)?;
136 if let Some(id) = id
137 && target_field_id == id
138 {
139 index_vec.push(pos);
140 return Ok(Some(field.clone()));
141 }
142 if let DataType::Struct(inner) = field.data_type()
143 && searchable_field_func(field)
144 && let Some(res) = Self::fetch_field_index(
145 inner,
146 index_vec,
147 target_field_id,
148 field_id_fetch_func,
149 searchable_field_func,
150 )?
151 {
152 index_vec.push(pos);
153 return Ok(Some(res));
154 }
155 }
156 Ok(None)
157 }
158
159 pub(crate) fn projected_schema_ref(&self) -> &SchemaRef {
161 &self.projected_schema
162 }
163
164 pub(crate) fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
166 RecordBatch::try_new(
167 self.projected_schema.clone(),
168 self.project_column(batch.columns())?,
169 )
170 .map_err(|err| Error::new(ErrorKind::DataInvalid, format!("{err}")))
171 }
172
173 pub fn project_column(&self, batch: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
175 self.field_indices
176 .iter()
177 .map(|index_vec| Self::get_column_by_field_index(batch, index_vec))
178 .collect::<Result<Vec<_>>>()
179 }
180
181 fn get_column_by_field_index(batch: &[ArrayRef], field_index: &[usize]) -> Result<ArrayRef> {
182 let mut rev_iterator = field_index.iter().rev();
183 let mut array = batch[*rev_iterator.next().unwrap()].clone();
184 let mut null_buffer = array.logical_nulls();
185 for idx in rev_iterator {
186 array = array
187 .as_any()
188 .downcast_ref::<StructArray>()
189 .ok_or(Error::new(
190 ErrorKind::Unexpected,
191 "Cannot convert Array to StructArray",
192 ))?
193 .column(*idx)
194 .clone();
195 null_buffer = NullBuffer::union(null_buffer.as_ref(), array.logical_nulls().as_ref());
196 }
197 Ok(make_array(
198 array.to_data().into_builder().nulls(null_buffer).build()?,
199 ))
200 }
201}
202
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
    use arrow_buffer::NullBuffer;
    use arrow_schema::{DataType, Field, Fields, Schema};

    use crate::arrow::record_batch_projector::RecordBatchProjector;
    use crate::spec::{NestedField, PrimitiveType, Schema as IcebergSchema, Type};
    use crate::{Error, ErrorKind};

    /// Builds the two-level schema shared by the tests:
    /// * `field1: Int32`
    /// * `field2: Struct<inner_field1: Int32, inner_field2: Utf8>`
    ///
    /// `nullable` controls the nullability of `field2` and of both inner fields. The inner
    /// `Fields` are returned as well, so that matching struct arrays can be built.
    fn nested_schema(nullable: bool) -> (Fields, Arc<Schema>) {
        let inner_fields = Fields::from(vec![
            Field::new("inner_field1", DataType::Int32, nullable),
            Field::new("inner_field2", DataType::Utf8, nullable),
        ]);
        let schema = Arc::new(Schema::new(vec![
            Field::new("field1", DataType::Int32, false),
            Field::new("field2", DataType::Struct(inner_fields.clone()), nullable),
        ]));
        (inner_fields, schema)
    }

    /// Maps the field names used by [`nested_schema`] to field ids 1..=4.
    /// Any other name is an error.
    fn field_id_by_name(field: &Field) -> std::result::Result<Option<i64>, Error> {
        match field.name().as_str() {
            "field1" => Ok(Some(1)),
            "field2" => Ok(Some(2)),
            "inner_field1" => Ok(Some(3)),
            "inner_field2" => Ok(Some(4)),
            _ => Err(Error::new(ErrorKind::Unexpected, "Field id not found")),
        }
    }

    #[test]
    fn test_record_batch_projector_nested_level() {
        let (inner_fields, schema) = nested_schema(false);
        let projector =
            RecordBatchProjector::new(schema.clone(), &[1, 3], field_id_by_name, |_| true)
                .unwrap();

        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projector.field_indices[0], vec![0]);
        assert_eq!(projector.field_indices[1], vec![0, 1]);

        let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        let struct_array = Arc::new(StructArray::from(vec![
            (
                inner_fields[0].clone(),
                Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef,
            ),
            (
                inner_fields[1].clone(),
                Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef,
            ),
        ])) as ArrayRef;
        let batch = RecordBatch::try_new(schema, vec![int_array, struct_array]).unwrap();

        let projected_batch = projector.project_batch(batch).unwrap();
        assert_eq!(projected_batch.num_columns(), 2);
        let projected_int_array = projected_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let projected_inner_int_array = projected_batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(projected_int_array.values(), &[1, 2, 3]);
        assert_eq!(projected_inner_int_array.values(), &[4, 5, 6]);
    }

    #[test]
    fn test_nested_field_inherits_parent_struct_nulls() {
        let (inner_fields, schema) = nested_schema(true);
        let projector =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| true).unwrap();

        // The struct is null in row 1; its children have no nulls of their own.
        let struct_array = StructArray::try_new(
            inner_fields,
            vec![
                Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef,
                Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef,
            ],
            Some(NullBuffer::from(vec![true, false, true])),
        )
        .unwrap();
        let batch = RecordBatch::try_new(schema, vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(struct_array) as ArrayRef,
        ])
        .unwrap();

        let projected_batch = projector.project_batch(batch).unwrap();
        let projected = projected_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(projected.len(), 3);
        assert!(projected.is_valid(0));
        assert!(projected.is_null(1));
        assert!(projected.is_valid(2));
        assert_eq!(projected.value(0), 4);
        assert_eq!(projected.value(2), 6);
    }

    #[test]
    fn test_field_not_found() {
        let (_, schema) = nested_schema(false);
        let projector =
            RecordBatchProjector::new(schema.clone(), &[1, 5], field_id_by_name, |_| true);

        assert!(projector.is_err());
    }

    #[test]
    fn test_field_not_reachable() {
        let (_, schema) = nested_schema(false);

        // Field 3 lives inside `field2`, so it is only found when struct children are searched.
        let projector =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| false);
        assert!(projector.is_err());

        let projector =
            RecordBatchProjector::new(schema.clone(), &[3], field_id_by_name, |_| true);
        assert!(projector.is_ok());
    }

    #[test]
    fn test_from_iceberg_schema() {
        let iceberg_schema = IcebergSchema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                NestedField::optional(3, "age", Type::Primitive(PrimitiveType::Int)).into(),
            ])
            .build()
            .unwrap();

        let projector =
            RecordBatchProjector::from_iceberg_schema(Arc::new(iceberg_schema), &[1, 3]).unwrap();

        assert_eq!(projector.field_indices.len(), 2);
        assert_eq!(projector.projected_schema_ref().fields().len(), 2);
        assert_eq!(projector.projected_schema_ref().field(0).name(), "id");
        assert_eq!(projector.projected_schema_ref().field(1).name(), "age");
    }
}