iceberg/arrow/
nan_val_cnt_visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This module contains the visitor for counting NaN values in a given Arrow record batch.
19
20use std::collections::HashMap;
21use std::collections::hash_map::Entry;
22use std::sync::Arc;
23
24use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, StructArray};
25use arrow_schema::DataType;
26
27use crate::Result;
28use crate::arrow::{ArrowArrayAccessor, FieldMatchMode};
29use crate::spec::{
30    ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef, SchemaWithPartnerVisitor,
31    StructType, visit_struct_with_partner,
32};
33
/// Counts the NaN entries of `$col`, viewed as the concrete array type `$t`, and adds that
/// count to `$self.nan_value_counts` under the key `$field_id`.
///
/// Null slots are never counted as NaN. An entry is created even when the count is zero.
/// The downcast is unwrapped, so the expansion panics if `$col` does not actually hold a `$t`.
macro_rules! cast_and_update_cnt_map {
    ($t:ty, $col:ident, $self:ident, $field_id:ident) => {
        let typed_col = $col.as_any().downcast_ref::<$t>().unwrap();
        // `flatten` drops the null (`None`) slots so only present values are inspected.
        let batch_nan_cnt = typed_col.iter().flatten().filter(|v| v.is_nan()).count() as u64;

        match $self.nan_value_counts.entry($field_id) {
            Entry::Vacant(slot) => {
                slot.insert(batch_nan_cnt);
            }
            Entry::Occupied(mut slot) => {
                *slot.get_mut() += batch_nan_cnt;
            }
        };
    };
}
55
/// Adds the NaN count of `$col` to `$self.nan_value_counts` under `$field_id` when the column's
/// data type is `Float32` or `Float64`; columns of any other data type are left untouched.
macro_rules! count_float_nans {
    ($col:ident, $self:ident, $field_id:ident) => {
        let col_type = $col.data_type();
        if matches!(col_type, DataType::Float32) {
            cast_and_update_cnt_map!(Float32Array, $col, $self, $field_id);
        } else if matches!(col_type, DataType::Float64) {
            cast_and_update_cnt_map!(Float64Array, $col, $self, $field_id);
        }
    };
}
69
70/// Visitor which counts and keeps track of NaN value counts in given record batch(s)
71pub struct NanValueCountVisitor {
72    /// Stores field ID to NaN value count mapping
73    pub nan_value_counts: HashMap<i32, u64>,
74    match_mode: FieldMatchMode,
75}
76
77impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor {
78    type T = ();
79
80    fn schema(
81        &mut self,
82        _schema: &Schema,
83        _partner: &ArrayRef,
84        _value: Self::T,
85    ) -> Result<Self::T> {
86        Ok(())
87    }
88
89    fn field(
90        &mut self,
91        _field: &NestedFieldRef,
92        _partner: &ArrayRef,
93        _value: Self::T,
94    ) -> Result<Self::T> {
95        Ok(())
96    }
97
98    fn r#struct(
99        &mut self,
100        _struct: &StructType,
101        _partner: &ArrayRef,
102        _results: Vec<Self::T>,
103    ) -> Result<Self::T> {
104        Ok(())
105    }
106
107    fn list(&mut self, _list: &ListType, _list_arr: &ArrayRef, _value: Self::T) -> Result<Self::T> {
108        Ok(())
109    }
110
111    fn map(
112        &mut self,
113        _map: &MapType,
114        _partner: &ArrayRef,
115        _key_value: Self::T,
116        _value: Self::T,
117    ) -> Result<Self::T> {
118        Ok(())
119    }
120
121    fn primitive(&mut self, _p: &PrimitiveType, _col: &ArrayRef) -> Result<Self::T> {
122        Ok(())
123    }
124
125    fn after_struct_field(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
126        let field_id = field.id;
127        count_float_nans!(partner, self, field_id);
128        Ok(())
129    }
130
131    fn after_list_element(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
132        let field_id = field.id;
133        count_float_nans!(partner, self, field_id);
134        Ok(())
135    }
136
137    fn after_map_key(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
138        let field_id = field.id;
139        count_float_nans!(partner, self, field_id);
140        Ok(())
141    }
142
143    fn after_map_value(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> {
144        let field_id = field.id;
145        count_float_nans!(partner, self, field_id);
146        Ok(())
147    }
148}
149
150impl NanValueCountVisitor {
151    /// Creates new instance of NanValueCountVisitor
152    pub fn new() -> Self {
153        Self::new_with_match_mode(FieldMatchMode::Id)
154    }
155
156    /// Creates new instance of NanValueCountVisitor with explicit match mode
157    pub fn new_with_match_mode(match_mode: FieldMatchMode) -> Self {
158        Self {
159            nan_value_counts: HashMap::new(),
160            match_mode,
161        }
162    }
163
164    /// Compute nan value counts in given schema and record batch
165    pub fn compute(&mut self, schema: SchemaRef, batch: RecordBatch) -> Result<()> {
166        let arrow_arr_partner_accessor = ArrowArrayAccessor::new_with_match_mode(self.match_mode);
167
168        let struct_arr = Arc::new(StructArray::from(batch)) as ArrayRef;
169        visit_struct_with_partner(
170            schema.as_struct(),
171            &struct_arr,
172            self,
173            &arrow_arr_partner_accessor,
174        )?;
175
176        Ok(())
177    }
178}
179
180impl Default for NanValueCountVisitor {
181    fn default() -> Self {
182        Self::new()
183    }
184}