// iceberg/spec/schema/prune_columns.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use super::*;
19
/// State for the column-pruning schema visitor.
struct PruneColumn {
    /// Field ids that must survive the projection.
    selected: HashSet<i32>,
    /// When `true`, a selected field keeps its complete (unpruned) type;
    /// when `false`, only the selected parts of nested types are kept.
    select_full_types: bool,
}
24
25/// Visit a schema and returns only the fields selected by id set
26pub fn prune_columns(
27    schema: &Schema,
28    selected: impl IntoIterator<Item = i32>,
29    select_full_types: bool,
30) -> Result<Type> {
31    let mut visitor = PruneColumn::new(HashSet::from_iter(selected), select_full_types);
32    let result = visit_schema(schema, &mut visitor);
33
34    match result {
35        Ok(s) => {
36            if let Some(struct_type) = s {
37                Ok(struct_type)
38            } else {
39                Ok(Type::Struct(StructType::default()))
40            }
41        }
42        Err(e) => Err(e),
43    }
44}
45
46impl PruneColumn {
47    fn new(selected: HashSet<i32>, select_full_types: bool) -> Self {
48        Self {
49            selected,
50            select_full_types,
51        }
52    }
53
54    fn project_selected_struct(projected_field: Option<Type>) -> Result<StructType> {
55        match projected_field {
56            // If the field is a StructType, return it as such
57            Some(Type::Struct(s)) => Ok(s),
58            Some(_) => Err(Error::new(
59                ErrorKind::Unexpected,
60                "Projected field with struct type must be struct".to_string(),
61            )),
62            // If projected_field is None or not a StructType, return an empty StructType
63            None => Ok(StructType::default()),
64        }
65    }
66    fn project_list(list: &ListType, element_result: Type) -> Result<ListType> {
67        if *list.element_field.field_type == element_result {
68            return Ok(list.clone());
69        }
70        Ok(ListType {
71            element_field: Arc::new(NestedField {
72                id: list.element_field.id,
73                name: list.element_field.name.clone(),
74                required: list.element_field.required,
75                field_type: Box::new(element_result),
76                doc: list.element_field.doc.clone(),
77                initial_default: list.element_field.initial_default.clone(),
78                write_default: list.element_field.write_default.clone(),
79            }),
80        })
81    }
82    fn project_map(map: &MapType, value_result: Type) -> Result<MapType> {
83        if *map.value_field.field_type == value_result {
84            return Ok(map.clone());
85        }
86        Ok(MapType {
87            key_field: map.key_field.clone(),
88            value_field: Arc::new(NestedField {
89                id: map.value_field.id,
90                name: map.value_field.name.clone(),
91                required: map.value_field.required,
92                field_type: Box::new(value_result),
93                doc: map.value_field.doc.clone(),
94                initial_default: map.value_field.initial_default.clone(),
95                write_default: map.value_field.write_default.clone(),
96            }),
97        })
98    }
99}
100
101impl SchemaVisitor for PruneColumn {
102    type T = Option<Type>;
103
104    fn schema(&mut self, _schema: &Schema, value: Option<Type>) -> Result<Option<Type>> {
105        Ok(Some(value.unwrap()))
106    }
107
108    fn field(&mut self, field: &NestedFieldRef, value: Option<Type>) -> Result<Option<Type>> {
109        if self.selected.contains(&field.id) {
110            if self.select_full_types {
111                Ok(Some(*field.field_type.clone()))
112            } else if field.field_type.is_struct() {
113                return Ok(Some(Type::Struct(PruneColumn::project_selected_struct(
114                    value,
115                )?)));
116            } else if !field.field_type.is_nested() {
117                return Ok(Some(*field.field_type.clone()));
118            } else {
119                return Err(Error::new(
120                    ErrorKind::DataInvalid,
121                    "Can't project list or map field directly when not selecting full type."
122                        .to_string(),
123                )
124                .with_context("field_id", field.id.to_string())
125                .with_context("field_type", field.field_type.to_string()));
126            }
127        } else {
128            Ok(value)
129        }
130    }
131
132    fn r#struct(
133        &mut self,
134        r#struct: &StructType,
135        results: Vec<Option<Type>>,
136    ) -> Result<Option<Type>> {
137        let fields = r#struct.fields();
138        let mut selected_field = Vec::with_capacity(fields.len());
139        let mut same_type = true;
140
141        for (field, projected_type) in zip_eq(fields.iter(), results.iter()) {
142            if let Some(projected_type) = projected_type {
143                if *field.field_type == *projected_type {
144                    selected_field.push(field.clone());
145                } else {
146                    same_type = false;
147                    let new_field = NestedField {
148                        id: field.id,
149                        name: field.name.clone(),
150                        required: field.required,
151                        field_type: Box::new(projected_type.clone()),
152                        doc: field.doc.clone(),
153                        initial_default: field.initial_default.clone(),
154                        write_default: field.write_default.clone(),
155                    };
156                    selected_field.push(Arc::new(new_field));
157                }
158            }
159        }
160
161        if !selected_field.is_empty() {
162            if selected_field.len() == fields.len() && same_type {
163                return Ok(Some(Type::Struct(r#struct.clone())));
164            } else {
165                return Ok(Some(Type::Struct(StructType::new(selected_field))));
166            }
167        }
168        Ok(None)
169    }
170
171    fn list(&mut self, list: &ListType, value: Option<Type>) -> Result<Option<Type>> {
172        if self.selected.contains(&list.element_field.id) {
173            if self.select_full_types {
174                Ok(Some(Type::List(list.clone())))
175            } else if list.element_field.field_type.is_struct() {
176                let projected_struct = PruneColumn::project_selected_struct(value).unwrap();
177                return Ok(Some(Type::List(PruneColumn::project_list(
178                    list,
179                    Type::Struct(projected_struct),
180                )?)));
181            } else if list.element_field.field_type.is_primitive() {
182                return Ok(Some(Type::List(list.clone())));
183            } else {
184                return Err(Error::new(
185                    ErrorKind::DataInvalid,
186                    format!(
187                        "Cannot explicitly project List or Map types, List element {} of type {} was selected",
188                        list.element_field.id, list.element_field.field_type
189                    ),
190                ));
191            }
192        } else if let Some(result) = value {
193            Ok(Some(Type::List(PruneColumn::project_list(list, result)?)))
194        } else {
195            Ok(None)
196        }
197    }
198
199    fn map(
200        &mut self,
201        map: &MapType,
202        _key_value: Option<Type>,
203        value: Option<Type>,
204    ) -> Result<Option<Type>> {
205        if self.selected.contains(&map.value_field.id) {
206            if self.select_full_types {
207                Ok(Some(Type::Map(map.clone())))
208            } else if map.value_field.field_type.is_struct() {
209                let projected_struct =
210                    PruneColumn::project_selected_struct(Some(value.unwrap())).unwrap();
211                return Ok(Some(Type::Map(PruneColumn::project_map(
212                    map,
213                    Type::Struct(projected_struct),
214                )?)));
215            } else if map.value_field.field_type.is_primitive() {
216                return Ok(Some(Type::Map(map.clone())));
217            } else {
218                return Err(Error::new(
219                    ErrorKind::DataInvalid,
220                    format!(
221                        "Cannot explicitly project List or Map types, Map value {} of type {} was selected",
222                        map.value_field.id, map.value_field.field_type
223                    ),
224                ));
225            }
226        } else if let Some(value_result) = value {
227            return Ok(Some(Type::Map(PruneColumn::project_map(
228                map,
229                value_result,
230            )?)));
231        } else if self.selected.contains(&map.key_field.id) {
232            Ok(Some(Type::Map(map.clone())))
233        } else {
234            Ok(None)
235        }
236    }
237
238    fn primitive(&mut self, _p: &PrimitiveType) -> Result<Option<Type>> {
239        Ok(None)
240    }
241}
242
#[cfg(test)]
mod tests {
    use Type::Primitive;

    use super::*;
    use crate::spec::schema::tests::table_schema_nested;

    /// Builds a schema from `fields` and returns its root struct as a [`Type`],
    /// the same shape `prune_columns` returns.
    fn root_type(fields: Vec<NestedFieldRef>) -> Type {
        Type::from(
            Schema::builder()
                .with_fields(fields)
                .build()
                .unwrap()
                .as_struct()
                .clone(),
        )
    }

    /// Runs `prune_columns` and asserts that it succeeds with `expected`.
    fn assert_pruned_to(
        schema: &Schema,
        selected: impl IntoIterator<Item = i32>,
        select_full_types: bool,
        expected: Type,
    ) {
        let pruned = prune_columns(schema, selected, select_full_types);
        assert!(pruned.is_ok());
        assert_eq!(pruned.unwrap(), expected);
    }

    /// Optional string field `foo` (id 1).
    fn foo_field() -> NestedFieldRef {
        NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into()
    }

    /// Required list field `qux` (id 4) of required strings (element id 5).
    fn qux_field() -> NestedFieldRef {
        NestedField::required(
            4,
            "qux",
            Type::List(ListType {
                element_field: NestedField::list_element(
                    5,
                    Type::Primitive(PrimitiveType::String),
                    true,
                )
                .into(),
            }),
        )
        .into()
    }

    /// Required map field `quux` (id 6): string key (7) to a map (8) of
    /// string key (9) to int value (10).
    fn quux_field() -> NestedFieldRef {
        let inner_map = Type::Map(MapType {
            key_field: NestedField::map_key_element(9, Type::Primitive(PrimitiveType::String))
                .into(),
            value_field: NestedField::map_value_element(
                10,
                Type::Primitive(PrimitiveType::Int),
                true,
            )
            .into(),
        });
        NestedField::required(
            6,
            "quux",
            Type::Map(MapType {
                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::String))
                    .into(),
                value_field: NestedField::map_value_element(8, inner_map, true).into(),
            }),
        )
        .into()
    }

    /// Optional struct field `person` (id 15) holding only `name` (id 16).
    fn person_with_name_field() -> NestedFieldRef {
        let name: NestedFieldRef =
            NestedField::optional(16, "name", Type::Primitive(PrimitiveType::String)).into();
        NestedField::optional(15, "person", Type::Struct(StructType::new(vec![name]))).into()
    }

    /// Optional struct field `person` (id 15) with no children.
    fn empty_person_field() -> NestedFieldRef {
        NestedField::optional(15, "person", Type::Struct(StructType::new(vec![]))).into()
    }

    /// Required map field `id_to_person` (id 6): int key (7) to a struct
    /// value (8) made of `value_fields`.
    fn id_to_person_field(value_fields: Vec<NestedFieldRef>) -> NestedFieldRef {
        NestedField::required(
            6,
            "id_to_person",
            Type::Map(MapType {
                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int))
                    .into(),
                value_field: NestedField::map_value_element(
                    8,
                    Type::Struct(StructType::new(value_fields)),
                    true,
                )
                .into(),
            }),
        )
        .into()
    }

    /// The `age` child (id 11) of the `id_to_person` struct value.
    fn age_field() -> NestedFieldRef {
        NestedField::required(11, "age", Primitive(PrimitiveType::Int)).into()
    }

    /// Schema 1 with a single `id_to_person` map whose struct value has
    /// `name` (id 10) and `age` (id 11).
    fn id_to_person_schema() -> Schema {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(vec![id_to_person_field(vec![
                NestedField::optional(10, "name", Primitive(PrimitiveType::String)).into(),
                age_field(),
            ])])
            .build()
            .unwrap()
    }

    /// Schema whose only field is the empty `person` struct.
    fn empty_person_schema() -> Schema {
        Schema::builder()
            .with_fields(vec![empty_person_field()])
            .build()
            .unwrap()
    }

    #[test]
    fn test_schema_prune_columns_string() {
        assert_pruned_to(&table_schema_nested(), [1], false, root_type(vec![foo_field()]));
    }

    #[test]
    fn test_schema_prune_columns_string_full() {
        assert_pruned_to(&table_schema_nested(), [1], true, root_type(vec![foo_field()]));
    }

    #[test]
    fn test_schema_prune_columns_list() {
        assert_pruned_to(&table_schema_nested(), [5], false, root_type(vec![qux_field()]));
    }

    #[test]
    fn test_prune_columns_list_itself() {
        assert!(prune_columns(&table_schema_nested(), [4], false).is_err());
    }

    #[test]
    fn test_schema_prune_columns_list_full() {
        assert_pruned_to(&table_schema_nested(), [5], true, root_type(vec![qux_field()]));
    }

    #[test]
    fn test_prune_columns_map() {
        assert_pruned_to(&table_schema_nested(), [9], false, root_type(vec![quux_field()]));
    }

    #[test]
    fn test_prune_columns_map_itself() {
        assert!(prune_columns(&table_schema_nested(), [6], false).is_err());
    }

    #[test]
    fn test_prune_columns_map_full() {
        assert_pruned_to(&table_schema_nested(), [9], true, root_type(vec![quux_field()]));
    }

    #[test]
    fn test_prune_columns_map_key() {
        assert_pruned_to(&table_schema_nested(), [10], false, root_type(vec![quux_field()]));
    }

    #[test]
    fn test_prune_columns_struct() {
        assert_pruned_to(
            &table_schema_nested(),
            [16],
            false,
            root_type(vec![person_with_name_field()]),
        );
    }

    #[test]
    fn test_prune_columns_struct_full() {
        assert_pruned_to(
            &table_schema_nested(),
            [16],
            true,
            root_type(vec![person_with_name_field()]),
        );
    }

    #[test]
    fn test_prune_columns_empty_struct() {
        assert_pruned_to(
            &empty_person_schema(),
            [15],
            false,
            root_type(vec![empty_person_field()]),
        );
    }

    #[test]
    fn test_prune_columns_empty_struct_full() {
        assert_pruned_to(
            &empty_person_schema(),
            [15],
            true,
            root_type(vec![empty_person_field()]),
        );
    }

    #[test]
    fn test_prune_columns_struct_in_map() {
        assert_pruned_to(
            &id_to_person_schema(),
            [11],
            false,
            root_type(vec![id_to_person_field(vec![age_field()])]),
        );
    }

    #[test]
    fn test_prune_columns_struct_in_map_full() {
        assert_pruned_to(
            &id_to_person_schema(),
            [11],
            true,
            root_type(vec![id_to_person_field(vec![age_field()])]),
        );
    }

    #[test]
    fn test_prune_columns_select_original_schema() {
        let schema = table_schema_nested();
        // Selecting every id up to the highest one must reproduce the schema.
        assert_pruned_to(
            &schema,
            0..=schema.highest_field_id(),
            true,
            Type::Struct(schema.as_struct().clone()),
        );
    }
}