use std::collections::{HashMap, HashSet};
use std::ops::Range;
use std::str::FromStr;
use std::sync::Arc;
use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
use arrow_string::like::starts_with;
use bytes::Bytes;
use fnv::FnvHashSet;
use futures::channel::mpsc::{channel, Sender};
use futures::future::BoxFuture;
use futures::{try_join, FutureExt, SinkExt, StreamExt, TryFutureExt, TryStreamExt};
use parquet::arrow::arrow_reader::{ArrowPredicateFn, ArrowReaderOptions, RowFilter, RowSelection};
use parquet::arrow::async_reader::AsyncFileReader;
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY};
use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader};
use parquet::schema::types::{SchemaDescriptor, Type as ParquetType};
use crate::arrow::record_batch_transformer::RecordBatchTransformer;
use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
use crate::error::Result;
use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
use crate::expr::visitors::page_index_evaluator::PageIndexEvaluator;
use crate::expr::visitors::row_group_metrics_evaluator::RowGroupMetricsEvaluator;
use crate::expr::{BoundPredicate, BoundReference};
use crate::io::{FileIO, FileMetadata, FileRead};
use crate::runtime::spawn;
use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream};
use crate::spec::{Datum, Schema};
use crate::utils::available_parallelism;
use crate::{Error, ErrorKind};
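/// Builder for [`ArrowReader`].
///
/// A minimal usage sketch (crate-internal, since [`ArrowReaderBuilder::new`]
/// is `pub(crate)`; assumes a `FileIO` named `file_io` is in scope):
///
/// ```ignore
/// let reader = ArrowReaderBuilder::new(file_io)
///     .with_batch_size(1024)
///     .with_row_group_filtering_enabled(true)
///     .build();
/// ```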
pub struct ArrowReaderBuilder {
batch_size: Option<usize>,
file_io: FileIO,
concurrency_limit_data_files: usize,
row_group_filtering_enabled: bool,
row_selection_enabled: bool,
}
impl ArrowReaderBuilder {
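/// Create a new ArrowReaderBuilder. The data file concurrency limit
/// defaults to the host's available parallelism.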
pub(crate) fn new(file_io: FileIO) -> Self {
let num_cpus = available_parallelism().get();
ArrowReaderBuilder {
batch_size: None,
file_io,
concurrency_limit_data_files: num_cpus,
row_group_filtering_enabled: true,
row_selection_enabled: false,
}
}
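/// Sets the maximum number of data files that are read concurrently.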
pub fn with_data_file_concurrency_limit(mut self, val: usize) -> Self {
self.concurrency_limit_data_files = val;
self
}
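/// Sets the maximum number of rows per [`RecordBatch`] read from Parquet.
/// If unset, the Parquet reader's default batch size is used.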
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}
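/// Enables or disables row group filtering (enabled by default). When
/// enabled, row groups whose statistics prove that no row can match the
/// scan predicate are skipped entirely.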
pub fn with_row_group_filtering_enabled(mut self, row_group_filtering_enabled: bool) -> Self {
self.row_group_filtering_enabled = row_group_filtering_enabled;
self
}
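/// Enables or disables row selection (disabled by default). When enabled,
/// the Parquet page index is used to skip ranges of rows that cannot match
/// the scan predicate within the selected row groups.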
pub fn with_row_selection_enabled(mut self, row_selection_enabled: bool) -> Self {
self.row_selection_enabled = row_selection_enabled;
self
}
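/// Build the [`ArrowReader`] with the configured options.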
pub fn build(self) -> ArrowReader {
ArrowReader {
batch_size: self.batch_size,
file_io: self.file_io,
concurrency_limit_data_files: self.concurrency_limit_data_files,
row_group_filtering_enabled: self.row_group_filtering_enabled,
row_selection_enabled: self.row_selection_enabled,
}
}
}
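/// Reads Iceberg data files (Parquet) into a stream of Arrow
/// [`RecordBatch`]es, applying projection and predicate pushdown.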
#[derive(Clone)]
pub struct ArrowReader {
batch_size: Option<usize>,
file_io: FileIO,
concurrency_limit_data_files: usize,
row_group_filtering_enabled: bool,
row_selection_enabled: bool,
}
impl ArrowReader {
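/// Take a stream of [`FileScanTask`]s and read all the files. Returns a
/// stream of Arrow [`RecordBatch`]es containing the data from the files.
///
/// Files are processed concurrently, up to the configured data file
/// concurrency limit. A consumption sketch, assuming `reader` and `tasks`
/// are already in scope:
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut batches = reader.read(tasks)?;
/// while let Some(batch) = batches.try_next().await? {
///     // each batch is already projected and transformed to the task schema
/// }
/// ```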
pub fn read(self, tasks: FileScanTaskStream) -> Result<ArrowRecordBatchStream> {
let file_io = self.file_io.clone();
let batch_size = self.batch_size;
let concurrency_limit_data_files = self.concurrency_limit_data_files;
let row_group_filtering_enabled = self.row_group_filtering_enabled;
let row_selection_enabled = self.row_selection_enabled;
let (tx, rx) = channel(concurrency_limit_data_files);
let mut channel_for_error = tx.clone();
spawn(async move {
let result = tasks
.map(|task| Ok((task, file_io.clone(), tx.clone())))
.try_for_each_concurrent(
concurrency_limit_data_files,
|(file_scan_task, file_io, tx)| async move {
match file_scan_task {
Ok(task) => {
let file_path = task.data_file_path.to_string();
spawn(async move {
Self::process_file_scan_task(
task,
batch_size,
file_io,
tx,
row_group_filtering_enabled,
row_selection_enabled,
)
.await
})
.await
.map_err(|e| e.with_context("file_path", file_path))
}
Err(err) => Err(err),
}
},
)
.await;
if let Err(error) = result {
let _ = channel_for_error.send(Err(error)).await;
}
});
Ok(rx.boxed())
}
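/// Processes a single [`FileScanTask`]: opens the Parquet file, applies the
/// projection, row filter, row-group filter and row selection derived from
/// the task's predicate, and sends the resulting record batches to `tx`.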
async fn process_file_scan_task(
task: FileScanTask,
batch_size: Option<usize>,
file_io: FileIO,
mut tx: Sender<Result<RecordBatch>>,
row_group_filtering_enabled: bool,
row_selection_enabled: bool,
) -> Result<()> {
let parquet_file = file_io.new_input(&task.data_file_path)?;
let (parquet_metadata, parquet_reader) =
try_join!(parquet_file.metadata(), parquet_file.reader())?;
let parquet_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader);
let should_load_page_index = row_selection_enabled && task.predicate.is_some();
let mut record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options(
parquet_file_reader,
ArrowReaderOptions::new().with_page_index(should_load_page_index),
)
.await?;
let projection_mask = Self::get_arrow_projection_mask(
&task.project_field_ids,
&task.schema,
record_batch_stream_builder.parquet_schema(),
record_batch_stream_builder.schema(),
)?;
record_batch_stream_builder = record_batch_stream_builder.with_projection(projection_mask);
let mut record_batch_transformer =
RecordBatchTransformer::build(task.schema_ref(), task.project_field_ids());
if let Some(batch_size) = batch_size {
record_batch_stream_builder = record_batch_stream_builder.with_batch_size(batch_size);
}
if let Some(predicate) = &task.predicate {
let (iceberg_field_ids, field_id_map) = Self::build_field_id_set_and_map(
record_batch_stream_builder.parquet_schema(),
predicate,
)?;
let row_filter = Self::get_row_filter(
predicate,
record_batch_stream_builder.parquet_schema(),
&iceberg_field_ids,
&field_id_map,
)?;
record_batch_stream_builder = record_batch_stream_builder.with_row_filter(row_filter);
let mut selected_row_groups = None;
if row_group_filtering_enabled {
let result = Self::get_selected_row_group_indices(
predicate,
record_batch_stream_builder.metadata(),
&field_id_map,
&task.schema,
)?;
selected_row_groups = Some(result);
}
if row_selection_enabled {
let row_selection = Self::get_row_selection(
predicate,
record_batch_stream_builder.metadata(),
&selected_row_groups,
&field_id_map,
&task.schema,
)?;
record_batch_stream_builder =
record_batch_stream_builder.with_row_selection(row_selection);
}
if let Some(selected_row_groups) = selected_row_groups {
record_batch_stream_builder =
record_batch_stream_builder.with_row_groups(selected_row_groups);
}
}
let mut record_batch_stream = record_batch_stream_builder.build()?;
while let Some(batch) = record_batch_stream.try_next().await? {
tx.send(record_batch_transformer.process_record_batch(batch))
.await?
}
Ok(())
}
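/// Collects the Iceberg field ids referenced by `predicate` and builds the
/// map from field id to leaf column index in the Parquet schema.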
fn build_field_id_set_and_map(
parquet_schema: &SchemaDescriptor,
predicate: &BoundPredicate,
) -> Result<(HashSet<i32>, HashMap<i32, usize>)> {
let mut collector = CollectFieldIdVisitor {
field_ids: HashSet::default(),
};
visit(&mut collector, predicate)?;
let iceberg_field_ids = collector.field_ids();
let field_id_map = build_field_id_map(parquet_schema)?;
Ok((iceberg_field_ids, field_id_map))
}
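/// Computes the Parquet [`ProjectionMask`] for the projected field ids,
/// verifying that each requested field exists in the file with a matching
/// type. An empty `field_ids` selects all columns.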
fn get_arrow_projection_mask(
field_ids: &[i32],
iceberg_schema_of_task: &Schema,
parquet_schema: &SchemaDescriptor,
arrow_schema: &ArrowSchemaRef,
) -> Result<ProjectionMask> {
if field_ids.is_empty() {
Ok(ProjectionMask::all())
} else {
let mut column_map = HashMap::new();
let fields = arrow_schema.fields();
let iceberg_schema = arrow_schema_to_schema(arrow_schema)?;
fields.filter_leaves(|idx, field| {
let Some(field_id) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) else {
return false;
};
let Ok(field_id) = i32::from_str(field_id) else {
return false;
};
if !field_ids.contains(&field_id) {
return false;
}
let (Some(iceberg_field), Some(parquet_iceberg_field)) = (
iceberg_schema_of_task.field_by_id(field_id),
iceberg_schema.field_by_id(field_id),
) else {
return false;
};
if iceberg_field.field_type != parquet_iceberg_field.field_type {
return false;
}
column_map.insert(field_id, idx);
true
});
if column_map.len() != field_ids.len() {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Parquet schema {} and Iceberg schema {} do not match.",
iceberg_schema, iceberg_schema_of_task
),
));
}
let mut indices = vec![];
for field_id in field_ids {
if let Some(col_idx) = column_map.get(field_id) {
indices.push(*col_idx);
} else {
return Err(Error::new(
ErrorKind::DataInvalid,
format!("Field {} is not found in Parquet schema.", field_id),
));
}
}
Ok(ProjectionMask::leaves(parquet_schema, indices))
}
}
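/// Converts the bound predicate into a Parquet [`RowFilter`] that evaluates
/// the predicate against each decoded [`RecordBatch`].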
fn get_row_filter(
predicates: &BoundPredicate,
parquet_schema: &SchemaDescriptor,
iceberg_field_ids: &HashSet<i32>,
field_id_map: &HashMap<i32, usize>,
) -> Result<RowFilter> {
let mut column_indices = iceberg_field_ids
.iter()
.filter_map(|field_id| field_id_map.get(field_id).cloned())
.collect::<Vec<_>>();
column_indices.sort();
let mut converter = PredicateConverter {
parquet_schema,
column_map: field_id_map,
column_indices: &column_indices,
};
let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone());
let predicate_func = visit(&mut converter, predicates)?;
let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func);
Ok(RowFilter::new(vec![Box::new(arrow_predicate)]))
}
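/// Evaluates the predicate against row-group-level statistics and returns
/// the indices of the row groups that might contain matching rows.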
fn get_selected_row_group_indices(
predicate: &BoundPredicate,
parquet_metadata: &Arc<ParquetMetaData>,
field_id_map: &HashMap<i32, usize>,
snapshot_schema: &Schema,
) -> Result<Vec<usize>> {
let row_groups_metadata = parquet_metadata.row_groups();
let mut results = Vec::with_capacity(row_groups_metadata.len());
for (idx, row_group_metadata) in row_groups_metadata.iter().enumerate() {
if RowGroupMetricsEvaluator::eval(
predicate,
row_group_metadata,
field_id_map,
snapshot_schema,
)? {
results.push(idx);
}
}
Ok(results)
}
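/// Uses the Parquet page index to compute a [`RowSelection`] that skips
/// pages which cannot contain matching rows. Requires the column and offset
/// indexes to have been loaded into the file metadata.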
fn get_row_selection(
predicate: &BoundPredicate,
parquet_metadata: &Arc<ParquetMetaData>,
selected_row_groups: &Option<Vec<usize>>,
field_id_map: &HashMap<i32, usize>,
snapshot_schema: &Schema,
) -> Result<RowSelection> {
let Some(column_index) = parquet_metadata.column_index() else {
return Err(Error::new(
ErrorKind::Unexpected,
"Parquet file metadata does not contain a column index",
));
};
let Some(offset_index) = parquet_metadata.offset_index() else {
return Err(Error::new(
ErrorKind::Unexpected,
"Parquet file metadata does not contain an offset index",
));
};
let mut selected_row_groups_idx = 0;
let page_index = column_index
.iter()
.enumerate()
.zip(offset_index)
.zip(parquet_metadata.row_groups());
let mut results = Vec::new();
for (((idx, column_index), offset_index), row_group_metadata) in page_index {
if let Some(selected_row_groups) = selected_row_groups {
// Guard against an empty or exhausted selection list; indexing
// past the end of `selected_row_groups` would panic.
if selected_row_groups_idx < selected_row_groups.len()
&& idx == selected_row_groups[selected_row_groups_idx]
{
selected_row_groups_idx += 1;
} else {
continue;
}
}
let selections_for_page = PageIndexEvaluator::eval(
predicate,
column_index,
offset_index,
row_group_metadata,
field_id_map,
snapshot_schema,
)?;
results.push(selections_for_page);
if let Some(selected_row_groups) = selected_row_groups {
if selected_row_groups_idx == selected_row_groups.len() {
break;
}
}
}
Ok(results.into_iter().flatten().collect::<Vec<_>>().into())
}
}
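/// Build the map from Iceberg field id to Parquet column index for the
/// file's schema. Every leaf column must be a primitive type carrying a
/// field id.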
fn build_field_id_map(parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> {
let mut column_map = HashMap::new();
for (idx, field) in parquet_schema.columns().iter().enumerate() {
let field_type = field.self_type();
match field_type {
ParquetType::PrimitiveType { basic_info, .. } => {
if !basic_info.has_id() {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Leave column idx: {}, name: {}, type {:?} in schema doesn't have field id",
idx,
basic_info.name(),
field_type
),
));
}
column_map.insert(basic_info.id(), idx);
}
ParquetType::GroupType { .. } => {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Leave column in schema should be primitive type but got {:?}",
field_type
),
));
}
};
}
Ok(column_map)
}
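/// A visitor that collects the Iceberg field ids referenced by a bound
/// predicate.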
struct CollectFieldIdVisitor {
field_ids: HashSet<i32>,
}
impl CollectFieldIdVisitor {
fn field_ids(self) -> HashSet<i32> {
self.field_ids
}
}
impl BoundPredicateVisitor for CollectFieldIdVisitor {
type T = ();
fn always_true(&mut self) -> Result<()> {
Ok(())
}
fn always_false(&mut self) -> Result<()> {
Ok(())
}
fn and(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
Ok(())
}
fn or(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
Ok(())
}
fn not(&mut self, _inner: ()) -> Result<()> {
Ok(())
}
fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn not_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn is_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn not_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn less_than(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn less_than_or_eq(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn greater_than(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn greater_than_or_eq(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn eq(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn not_eq(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn starts_with(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn not_starts_with(
&mut self,
reference: &BoundReference,
_literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn r#in(
&mut self,
reference: &BoundReference,
_literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
fn not_in(
&mut self,
reference: &BoundReference,
_literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<()> {
self.field_ids.insert(reference.field().id);
Ok(())
}
}
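/// Converts a bound predicate into a tree of closures, one per predicate
/// node, each evaluating the predicate against a [`RecordBatch`]. Columns
/// missing from the Parquet file collapse to constant true/false results.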
struct PredicateConverter<'a> {
pub parquet_schema: &'a SchemaDescriptor,
pub column_map: &'a HashMap<i32, usize>,
pub column_indices: &'a [usize],
}
impl PredicateConverter<'_> {
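/// Finds the index of `reference` within `column_indices`, i.e. its column
/// index in the projected predicate batch. Returns `Ok(None)` when the
/// column is missing from the Parquet file, in which case the predicate
/// degenerates to a constant (e.g. `is_null` becomes always true and `eq`
/// always false).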
fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<usize>> {
if let Some(column_idx) = self.column_map.get(&reference.field().id) {
if self.parquet_schema.get_column_root_idx(*column_idx) != *column_idx {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Leave column `{}` in predicates isn't a root column in Parquet schema.",
reference.field().name
),
));
}
let index = self
.column_indices
.iter()
.position(|&idx| idx == *column_idx)
.ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!(
"Leaf column `{}` in predicates cannot be found in the required column indices.",
reference.field().name
),
)
})?;
Ok(Some(index))
} else {
Ok(None)
}
}
fn build_always_true(&self) -> Result<Box<PredicateResult>> {
Ok(Box::new(|batch| {
Ok(BooleanArray::from(vec![true; batch.num_rows()]))
}))
}
fn build_always_false(&self) -> Result<Box<PredicateResult>> {
Ok(Box::new(|batch| {
Ok(BooleanArray::from(vec![false; batch.num_rows()]))
}))
}
}
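/// Extracts the column at `column_idx` from the batch. Struct columns are
/// not supported yet.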
fn project_column(
batch: &RecordBatch,
column_idx: usize,
) -> std::result::Result<ArrayRef, ArrowError> {
let column = batch.column(column_idx);
match column.data_type() {
DataType::Struct(_) => Err(ArrowError::SchemaError(
"Struct columns are not supported yet.".to_string(),
)),
_ => Ok(column.clone()),
}
}
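/// A predicate evaluated against a [`RecordBatch`], yielding one boolean
/// per row.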
type PredicateResult =
dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + Send + 'static;
impl BoundPredicateVisitor for PredicateConverter<'_> {
type T = Box<PredicateResult>;
fn always_true(&mut self) -> Result<Box<PredicateResult>> {
self.build_always_true()
}
fn always_false(&mut self) -> Result<Box<PredicateResult>> {
self.build_always_false()
}
fn and(
&mut self,
mut lhs: Box<PredicateResult>,
mut rhs: Box<PredicateResult>,
) -> Result<Box<PredicateResult>> {
Ok(Box::new(move |batch| {
let left = lhs(batch.clone())?;
let right = rhs(batch)?;
and(&left, &right)
}))
}
fn or(
&mut self,
mut lhs: Box<PredicateResult>,
mut rhs: Box<PredicateResult>,
) -> Result<Box<PredicateResult>> {
Ok(Box::new(move |batch| {
let left = lhs(batch.clone())?;
let right = rhs(batch)?;
or(&left, &right)
}))
}
fn not(&mut self, mut inner: Box<PredicateResult>) -> Result<Box<PredicateResult>> {
Ok(Box::new(move |batch| {
let pred_ret = inner(batch)?;
not(&pred_ret)
}))
}
fn is_null(
&mut self,
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
is_null(&column)
}))
} else {
self.build_always_true()
}
}
fn not_null(
&mut self,
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
is_not_null(&column)
}))
} else {
self.build_always_false()
}
}
fn is_nan(
&mut self,
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if self.bound_reference(reference)?.is_some() {
self.build_always_true()
} else {
self.build_always_false()
}
}
fn not_nan(
&mut self,
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if self.bound_reference(reference)?.is_some() {
self.build_always_false()
} else {
self.build_always_true()
}
}
fn less_than(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
lt(&left, literal.as_ref())
}))
} else {
self.build_always_true()
}
}
fn less_than_or_eq(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
lt_eq(&left, literal.as_ref())
}))
} else {
self.build_always_true()
}
}
fn greater_than(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
gt(&left, literal.as_ref())
}))
} else {
self.build_always_false()
}
}
fn greater_than_or_eq(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
gt_eq(&left, literal.as_ref())
}))
} else {
self.build_always_false()
}
}
fn eq(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
eq(&left, literal.as_ref())
}))
} else {
self.build_always_false()
}
}
fn not_eq(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
neq(&left, literal.as_ref())
}))
} else {
self.build_always_false()
}
}
fn starts_with(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
starts_with(&left, literal.as_ref())
}))
} else {
self.build_always_false()
}
}
fn not_starts_with(
&mut self,
reference: &BoundReference,
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
not(&starts_with(&left, literal.as_ref())?)
}))
} else {
self.build_always_true()
}
}
fn r#in(
&mut self,
reference: &BoundReference,
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
// Propagate conversion failures instead of panicking on `unwrap`.
let literals = literals
.iter()
.map(get_arrow_datum)
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
for literal in &literals {
acc = or(&acc, &eq(&left, literal.as_ref())?)?
}
Ok(acc)
}))
} else {
self.build_always_false()
}
}
fn not_in(
&mut self,
reference: &BoundReference,
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
// Propagate conversion failures instead of panicking on `unwrap`.
let literals = literals
.iter()
.map(get_arrow_datum)
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
for literal in &literals {
acc = and(&acc, &neq(&left, literal.as_ref())?)?
}
Ok(acc)
}))
} else {
self.build_always_true()
}
}
}
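/// A Parquet [`AsyncFileReader`] backed by an Iceberg [`FileRead`], so that
/// the Parquet reader can fetch byte ranges and footer metadata through
/// Iceberg's `FileIO` abstraction.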
struct ArrowFileReader<R: FileRead> {
meta: FileMetadata,
r: R,
}
impl<R: FileRead> ArrowFileReader<R> {
fn new(meta: FileMetadata, r: R) -> Self {
Self { meta, r }
}
}
impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
fn get_bytes(&mut self, range: Range<usize>) -> BoxFuture<'_, parquet::errors::Result<Bytes>> {
Box::pin(
self.r
.read(range.start as _..range.end as _)
.map_err(|err| parquet::errors::ParquetError::External(Box::new(err))),
)
}
fn get_metadata(&mut self) -> BoxFuture<'_, parquet::errors::Result<Arc<ParquetMetaData>>> {
async move {
let reader = ParquetMetaDataReader::new();
let size = self.meta.size as usize;
let meta = reader.load_and_finish(self, size).await?;
Ok(Arc::new(meta))
}
.boxed()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::sync::Arc;
use crate::arrow::reader::CollectFieldIdVisitor;
use crate::expr::visitors::bound_predicate_visitor::visit;
use crate::expr::{Bind, Reference};
use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
fn table_schema_simple() -> SchemaRef {
Arc::new(
Schema::builder()
.with_schema_id(1)
.with_identifier_field_ids(vec![2])
.with_fields(vec![
NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
NestedField::optional(4, "qux", Type::Primitive(PrimitiveType::Float)).into(),
])
.build()
.unwrap(),
)
}
#[test]
fn test_collect_field_id() {
let schema = table_schema_simple();
let expr = Reference::new("qux").is_null();
let bound_expr = expr.bind(schema, true).unwrap();
let mut visitor = CollectFieldIdVisitor {
field_ids: HashSet::default(),
};
visit(&mut visitor, &bound_expr).unwrap();
let mut expected = HashSet::default();
expected.insert(4_i32);
assert_eq!(visitor.field_ids, expected);
}
#[test]
fn test_collect_field_id_with_and() {
let schema = table_schema_simple();
let expr = Reference::new("qux")
.is_null()
.and(Reference::new("baz").is_null());
let bound_expr = expr.bind(schema, true).unwrap();
let mut visitor = CollectFieldIdVisitor {
field_ids: HashSet::default(),
};
visit(&mut visitor, &bound_expr).unwrap();
let mut expected = HashSet::default();
expected.insert(4_i32);
expected.insert(3);
assert_eq!(visitor.field_ids, expected);
}
#[test]
fn test_collect_field_id_with_or() {
let schema = table_schema_simple();
let expr = Reference::new("qux")
.is_null()
.or(Reference::new("baz").is_null());
let bound_expr = expr.bind(schema, true).unwrap();
let mut visitor = CollectFieldIdVisitor {
field_ids: HashSet::default(),
};
visit(&mut visitor, &bound_expr).unwrap();
let mut expected = HashSet::default();
expected.insert(4_i32);
expected.insert(3);
assert_eq!(visitor.field_ids, expected);
}
}