iceberg/arrow/
nan_val_cnt_visitor.rs1use std::collections::HashMap;
21use std::collections::hash_map::Entry;
22use std::sync::Arc;
23
24use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, StructArray};
25use arrow_schema::DataType;
26
27use crate::Result;
28use crate::arrow::{ArrowArrayAccessor, FieldMatchMode};
29use crate::spec::{
30 ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef, SchemaWithPartnerVisitor,
31 StructType, visit_struct_with_partner,
32};
33
/// Counts the NaN entries of a column and adds that count to the running
/// total stored under the given field id.
///
/// Arguments:
/// * `$t` - concrete array type the column is downcast to. The downcast
///   result is unwrapped, so passing a type that does not match the column
///   panics.
/// * `$col` - column to scan; it must provide `as_any()`.
/// * `$self` - value that owns the `nan_value_counts` map being updated.
/// * `$field_id` - key under which the count is accumulated.
macro_rules! cast_and_update_cnt_map {
    ($t:ty, $col:ident, $self:ident, $field_id:ident) => {
        let typed_col = $col.as_any().downcast_ref::<$t>().unwrap();

        // The iterator yields `Option`s; `flatten` drops the `None` slots so
        // only present values are tested for NaN.
        let nan_val_cnt = typed_col.iter().flatten().filter(|v| v.is_nan()).count() as u64;

        match $self.nan_value_counts.entry($field_id) {
            // A total already exists for this field: add to it in place.
            Entry::Occupied(mut slot) => {
                *slot.get_mut() += nan_val_cnt;
            }
            // First time this field is seen: start from this column's count.
            Entry::Vacant(slot) => {
                slot.insert(nan_val_cnt);
            }
        }
    };
}
55
/// Updates the NaN totals for a column when its data type is `Float32` or
/// `Float64`; columns of any other data type are left uncounted.
///
/// Arguments:
/// * `$col` - column to inspect; it must provide `data_type()`.
/// * `$self` - value that owns the `nan_value_counts` map.
/// * `$field_id` - key under which the column's NaN count is accumulated.
macro_rules! count_float_nans {
    ($col:ident, $self:ident, $field_id:ident) => {
        let column_type = $col.data_type();

        // Pick the concrete array type that matches the column's data type.
        if matches!(column_type, DataType::Float32) {
            cast_and_update_cnt_map!(Float32Array, $col, $self, $field_id);
        } else if matches!(column_type, DataType::Float64) {
            cast_and_update_cnt_map!(Float64Array, $col, $self, $field_id);
        }
    };
}
69
/// Schema visitor that accumulates, per field id, the number of NaN values
/// found in `Float32` and `Float64` columns.
pub struct NanValueCountVisitor {
    /// Running NaN total for each visited float field, keyed by the field's id.
    pub nan_value_counts: HashMap<i32, u64>,
    /// Match mode passed to `ArrowArrayAccessor::new_with_match_mode` in `compute`.
    match_mode: FieldMatchMode,
}
76
77impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor {
78 type T = ();
79
80 fn schema(
81 &mut self,
82 _schema: &Schema,
83 _partner: &ArrayRef,
84 _value: Self::T,
85 ) -> Result<Self::T> {
86 Ok(())
87 }
88
89 fn field(
90 &mut self,
91 _field: &NestedFieldRef,
92 _partner: &ArrayRef,
93 _value: Self::T,
94 ) -> Result<Self::T> {
95 Ok(())
96 }
97
98 fn r#struct(
99 &mut self,
100 _struct: &StructType,
101 _partner: &ArrayRef,
102 _results: Vec<Self::T>,
103 ) -> Result<Self::T> {
104 Ok(())
105 }
106
107 fn list(&mut self, _list: &ListType, _list_arr: &ArrayRef, _value: Self::T) -> Result<Self::T> {
108 Ok(())
109 }
110
111 fn map(
112 &mut self,
113 _map: &MapType,
114 _partner: &ArrayRef,
115 _key_value: Self::T,
116 _value: Self::T,
117 ) -> Result<Self::T> {
118 Ok(())
119 }
120
121 fn primitive(&mut self, _p: &PrimitiveType, _col: &ArrayRef) -> Result<Self::T> {
122 Ok(())
123 }
124
125 fn after_struct_field(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
126 let field_id = field.id;
127 count_float_nans!(partner, self, field_id);
128 Ok(())
129 }
130
131 fn after_list_element(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
132 let field_id = field.id;
133 count_float_nans!(partner, self, field_id);
134 Ok(())
135 }
136
137 fn after_map_key(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
138 let field_id = field.id;
139 count_float_nans!(partner, self, field_id);
140 Ok(())
141 }
142
143 fn after_map_value(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
144 let field_id = field.id;
145 count_float_nans!(partner, self, field_id);
146 Ok(())
147 }
148}
149
150impl NanValueCountVisitor {
151 pub fn new() -> Self {
153 Self::new_with_match_mode(FieldMatchMode::Id)
154 }
155
156 pub fn new_with_match_mode(match_mode: FieldMatchMode) -> Self {
158 Self {
159 nan_value_counts: HashMap::new(),
160 match_mode,
161 }
162 }
163
164 pub fn compute(&mut self, schema: SchemaRef, batch: RecordBatch) -> Result<()> {
166 let arrow_arr_partner_accessor = ArrowArrayAccessor::new_with_match_mode(self.match_mode);
167
168 let struct_arr = Arc::new(StructArray::from(batch)) as ArrayRef;
169 visit_struct_with_partner(
170 schema.as_struct(),
171 &struct_arr,
172 self,
173 &arrow_arr_partner_accessor,
174 )?;
175
176 Ok(())
177 }
178}
179
180impl Default for NanValueCountVisitor {
181 fn default() -> Self {
182 Self::new()
183 }
184}