1use super::*;
19
/// Schema visitor state used by `prune_columns` to keep only selected fields.
struct PruneColumn {
    /// Ids of the fields that were requested; looked up with `contains` while visiting.
    selected: HashSet<i32>,
    /// When `true`, a selected field is returned with its complete type instead of
    /// being projected down to its selected children.
    select_full_types: bool,
}
24
25pub fn prune_columns(
27 schema: &Schema,
28 selected: impl IntoIterator<Item = i32>,
29 select_full_types: bool,
30) -> Result<Type> {
31 let mut visitor = PruneColumn::new(HashSet::from_iter(selected), select_full_types);
32 let result = visit_schema(schema, &mut visitor);
33
34 match result {
35 Ok(s) => {
36 if let Some(struct_type) = s {
37 Ok(struct_type)
38 } else {
39 Ok(Type::Struct(StructType::default()))
40 }
41 }
42 Err(e) => Err(e),
43 }
44}
45
46impl PruneColumn {
47 fn new(selected: HashSet<i32>, select_full_types: bool) -> Self {
48 Self {
49 selected,
50 select_full_types,
51 }
52 }
53
54 fn project_selected_struct(projected_field: Option<Type>) -> Result<StructType> {
55 match projected_field {
56 Some(Type::Struct(s)) => Ok(s),
58 Some(_) => Err(Error::new(
59 ErrorKind::Unexpected,
60 "Projected field with struct type must be struct".to_string(),
61 )),
62 None => Ok(StructType::default()),
64 }
65 }
66 fn project_list(list: &ListType, element_result: Type) -> Result<ListType> {
67 if *list.element_field.field_type == element_result {
68 return Ok(list.clone());
69 }
70 Ok(ListType {
71 element_field: Arc::new(NestedField {
72 id: list.element_field.id,
73 name: list.element_field.name.clone(),
74 required: list.element_field.required,
75 field_type: Box::new(element_result),
76 doc: list.element_field.doc.clone(),
77 initial_default: list.element_field.initial_default.clone(),
78 write_default: list.element_field.write_default.clone(),
79 }),
80 })
81 }
82 fn project_map(map: &MapType, value_result: Type) -> Result<MapType> {
83 if *map.value_field.field_type == value_result {
84 return Ok(map.clone());
85 }
86 Ok(MapType {
87 key_field: map.key_field.clone(),
88 value_field: Arc::new(NestedField {
89 id: map.value_field.id,
90 name: map.value_field.name.clone(),
91 required: map.value_field.required,
92 field_type: Box::new(value_result),
93 doc: map.value_field.doc.clone(),
94 initial_default: map.value_field.initial_default.clone(),
95 write_default: map.value_field.write_default.clone(),
96 }),
97 })
98 }
99}
100
101impl SchemaVisitor for PruneColumn {
102 type T = Option<Type>;
103
104 fn schema(&mut self, _schema: &Schema, value: Option<Type>) -> Result<Option<Type>> {
105 Ok(Some(value.unwrap()))
106 }
107
108 fn field(&mut self, field: &NestedFieldRef, value: Option<Type>) -> Result<Option<Type>> {
109 if self.selected.contains(&field.id) {
110 if self.select_full_types {
111 Ok(Some(*field.field_type.clone()))
112 } else if field.field_type.is_struct() {
113 return Ok(Some(Type::Struct(PruneColumn::project_selected_struct(
114 value,
115 )?)));
116 } else if !field.field_type.is_nested() {
117 return Ok(Some(*field.field_type.clone()));
118 } else {
119 return Err(Error::new(
120 ErrorKind::DataInvalid,
121 "Can't project list or map field directly when not selecting full type."
122 .to_string(),
123 )
124 .with_context("field_id", field.id.to_string())
125 .with_context("field_type", field.field_type.to_string()));
126 }
127 } else {
128 Ok(value)
129 }
130 }
131
132 fn r#struct(
133 &mut self,
134 r#struct: &StructType,
135 results: Vec<Option<Type>>,
136 ) -> Result<Option<Type>> {
137 let fields = r#struct.fields();
138 let mut selected_field = Vec::with_capacity(fields.len());
139 let mut same_type = true;
140
141 for (field, projected_type) in zip_eq(fields.iter(), results.iter()) {
142 if let Some(projected_type) = projected_type {
143 if *field.field_type == *projected_type {
144 selected_field.push(field.clone());
145 } else {
146 same_type = false;
147 let new_field = NestedField {
148 id: field.id,
149 name: field.name.clone(),
150 required: field.required,
151 field_type: Box::new(projected_type.clone()),
152 doc: field.doc.clone(),
153 initial_default: field.initial_default.clone(),
154 write_default: field.write_default.clone(),
155 };
156 selected_field.push(Arc::new(new_field));
157 }
158 }
159 }
160
161 if !selected_field.is_empty() {
162 if selected_field.len() == fields.len() && same_type {
163 return Ok(Some(Type::Struct(r#struct.clone())));
164 } else {
165 return Ok(Some(Type::Struct(StructType::new(selected_field))));
166 }
167 }
168 Ok(None)
169 }
170
171 fn list(&mut self, list: &ListType, value: Option<Type>) -> Result<Option<Type>> {
172 if self.selected.contains(&list.element_field.id) {
173 if self.select_full_types {
174 Ok(Some(Type::List(list.clone())))
175 } else if list.element_field.field_type.is_struct() {
176 let projected_struct = PruneColumn::project_selected_struct(value).unwrap();
177 return Ok(Some(Type::List(PruneColumn::project_list(
178 list,
179 Type::Struct(projected_struct),
180 )?)));
181 } else if list.element_field.field_type.is_primitive() {
182 return Ok(Some(Type::List(list.clone())));
183 } else {
184 return Err(Error::new(
185 ErrorKind::DataInvalid,
186 format!(
187 "Cannot explicitly project List or Map types, List element {} of type {} was selected",
188 list.element_field.id, list.element_field.field_type
189 ),
190 ));
191 }
192 } else if let Some(result) = value {
193 Ok(Some(Type::List(PruneColumn::project_list(list, result)?)))
194 } else {
195 Ok(None)
196 }
197 }
198
199 fn map(
200 &mut self,
201 map: &MapType,
202 _key_value: Option<Type>,
203 value: Option<Type>,
204 ) -> Result<Option<Type>> {
205 if self.selected.contains(&map.value_field.id) {
206 if self.select_full_types {
207 Ok(Some(Type::Map(map.clone())))
208 } else if map.value_field.field_type.is_struct() {
209 let projected_struct =
210 PruneColumn::project_selected_struct(Some(value.unwrap())).unwrap();
211 return Ok(Some(Type::Map(PruneColumn::project_map(
212 map,
213 Type::Struct(projected_struct),
214 )?)));
215 } else if map.value_field.field_type.is_primitive() {
216 return Ok(Some(Type::Map(map.clone())));
217 } else {
218 return Err(Error::new(
219 ErrorKind::DataInvalid,
220 format!(
221 "Cannot explicitly project List or Map types, Map value {} of type {} was selected",
222 map.value_field.id, map.value_field.field_type
223 ),
224 ));
225 }
226 } else if let Some(value_result) = value {
227 return Ok(Some(Type::Map(PruneColumn::project_map(
228 map,
229 value_result,
230 )?)));
231 } else if self.selected.contains(&map.key_field.id) {
232 Ok(Some(Type::Map(map.clone())))
233 } else {
234 Ok(None)
235 }
236 }
237
238 fn primitive(&mut self, _p: &PrimitiveType) -> Result<Option<Type>> {
239 Ok(None)
240 }
241}
242
#[cfg(test)]
mod tests {
    use Type::Primitive;

    use super::*;
    use crate::spec::schema::tests::table_schema_nested;

    /// Builds a schema from `fields` and returns its root struct as a `Type`.
    fn root_type(fields: Vec<NestedFieldRef>) -> Type {
        Type::from(
            Schema::builder()
                .with_fields(fields)
                .build()
                .unwrap()
                .as_struct()
                .clone(),
        )
    }

    /// Expected result when only the string column `foo` (id 1) survives.
    fn expected_foo() -> Type {
        root_type(vec![
            NestedField::optional(1, "foo", Primitive(PrimitiveType::String)).into(),
        ])
    }

    /// Expected result when only the string list `qux` (ids 4/5) survives.
    fn expected_qux_list() -> Type {
        root_type(vec![
            NestedField::required(
                4,
                "qux",
                Type::List(ListType {
                    element_field: NestedField::list_element(
                        5,
                        Primitive(PrimitiveType::String),
                        true,
                    )
                    .into(),
                }),
            )
            .into(),
        ])
    }

    /// Expected result when only the nested map `quux` (ids 6-10) survives.
    fn expected_quux_map() -> Type {
        let inner_map = Type::Map(MapType {
            key_field: NestedField::map_key_element(9, Primitive(PrimitiveType::String)).into(),
            value_field: NestedField::map_value_element(10, Primitive(PrimitiveType::Int), true)
                .into(),
        });
        root_type(vec![
            NestedField::required(
                6,
                "quux",
                Type::Map(MapType {
                    key_field: NestedField::map_key_element(7, Primitive(PrimitiveType::String))
                        .into(),
                    value_field: NestedField::map_value_element(8, inner_map, true).into(),
                }),
            )
            .into(),
        ])
    }

    /// Expected result when only `person.name` (ids 15/16) survives.
    fn expected_person_name() -> Type {
        root_type(vec![
            NestedField::optional(
                15,
                "person",
                Type::Struct(StructType::new(vec![
                    NestedField::optional(16, "name", Primitive(PrimitiveType::String)).into(),
                ])),
            )
            .into(),
        ])
    }

    /// Top-level fields of a schema holding a single empty struct `person` (id 15).
    fn empty_person_fields() -> Vec<NestedFieldRef> {
        vec![NestedField::optional(15, "person", Type::Struct(StructType::new(vec![]))).into()]
    }

    /// Top-level fields of a schema holding an int-keyed map `id_to_person`
    /// (ids 6/7/8) whose value is a struct made of `person_fields`.
    fn id_to_person_fields(person_fields: Vec<NestedFieldRef>) -> Vec<NestedFieldRef> {
        vec![
            NestedField::required(
                6,
                "id_to_person",
                Type::Map(MapType {
                    key_field: NestedField::map_key_element(7, Primitive(PrimitiveType::Int))
                        .into(),
                    value_field: NestedField::map_value_element(
                        8,
                        Type::Struct(StructType::new(person_fields)),
                        true,
                    )
                    .into(),
                }),
            )
            .into(),
        ]
    }

    /// Schema with a map whose struct value has both `name` (10) and `age` (11).
    fn struct_in_map_schema() -> Schema {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(id_to_person_fields(vec![
                NestedField::optional(10, "name", Primitive(PrimitiveType::String)).into(),
                NestedField::required(11, "age", Primitive(PrimitiveType::Int)).into(),
            ]))
            .build()
            .unwrap()
    }

    /// Expected result when only `age` (11) survives inside the map value struct.
    fn expected_age_in_map() -> Type {
        root_type(id_to_person_fields(vec![
            NestedField::required(11, "age", Primitive(PrimitiveType::Int)).into(),
        ]))
    }

    /// Prunes `schema` to `selected` and asserts the outcome equals `expected`.
    #[track_caller]
    fn assert_pruned(
        schema: &Schema,
        selected: impl IntoIterator<Item = i32>,
        select_full_types: bool,
        expected: Type,
    ) {
        let selected: HashSet<i32> = selected.into_iter().collect();
        let pruned = prune_columns(schema, selected, select_full_types)
            .expect("pruning should succeed");
        assert_eq!(pruned, expected);
    }

    #[test]
    fn test_schema_prune_columns_string() {
        assert_pruned(&table_schema_nested(), [1], false, expected_foo());
    }

    #[test]
    fn test_schema_prune_columns_string_full() {
        assert_pruned(&table_schema_nested(), [1], true, expected_foo());
    }

    #[test]
    fn test_schema_prune_columns_list() {
        assert_pruned(&table_schema_nested(), [5], false, expected_qux_list());
    }

    #[test]
    fn test_prune_columns_list_itself() {
        let schema = table_schema_nested();
        let selected: HashSet<i32> = HashSet::from([4]);
        assert!(prune_columns(&schema, selected, false).is_err());
    }

    #[test]
    fn test_schema_prune_columns_list_full() {
        assert_pruned(&table_schema_nested(), [5], true, expected_qux_list());
    }

    #[test]
    fn test_prune_columns_map() {
        assert_pruned(&table_schema_nested(), [9], false, expected_quux_map());
    }

    #[test]
    fn test_prune_columns_map_itself() {
        let schema = table_schema_nested();
        let selected: HashSet<i32> = HashSet::from([6]);
        assert!(prune_columns(&schema, selected, false).is_err());
    }

    #[test]
    fn test_prune_columns_map_full() {
        assert_pruned(&table_schema_nested(), [9], true, expected_quux_map());
    }

    #[test]
    fn test_prune_columns_map_key() {
        assert_pruned(&table_schema_nested(), [10], false, expected_quux_map());
    }

    #[test]
    fn test_prune_columns_struct() {
        assert_pruned(&table_schema_nested(), [16], false, expected_person_name());
    }

    #[test]
    fn test_prune_columns_struct_full() {
        assert_pruned(&table_schema_nested(), [16], true, expected_person_name());
    }

    #[test]
    fn test_prune_columns_empty_struct() {
        let schema_with_empty_struct_field = Schema::builder()
            .with_fields(empty_person_fields())
            .build()
            .unwrap();
        assert_pruned(
            &schema_with_empty_struct_field,
            [15],
            false,
            root_type(empty_person_fields()),
        );
    }

    #[test]
    fn test_prune_columns_empty_struct_full() {
        let schema_with_empty_struct_field = Schema::builder()
            .with_fields(empty_person_fields())
            .build()
            .unwrap();
        assert_pruned(
            &schema_with_empty_struct_field,
            [15],
            true,
            root_type(empty_person_fields()),
        );
    }

    #[test]
    fn test_prune_columns_struct_in_map() {
        assert_pruned(&struct_in_map_schema(), [11], false, expected_age_in_map());
    }

    #[test]
    fn test_prune_columns_struct_in_map_full() {
        assert_pruned(&struct_in_map_schema(), [11], true, expected_age_in_map());
    }

    #[test]
    fn test_prune_columns_select_original_schema() {
        let schema = table_schema_nested();
        let every_id = 0..schema.highest_field_id() + 1;
        let expected = Type::Struct(schema.as_struct().clone());
        assert_pruned(&schema, every_id, true, expected);
    }
}