//! DataFusion `TableProvider` implementation backed by Apache Iceberg tables.

pub mod table_provider_factory;

use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use iceberg::arrow::schema_to_arrow_schema;
use iceberg::table::Table;
use iceberg::{Catalog, Error, ErrorKind, NamespaceIdent, Result, TableIdent};

use crate::physical_plan::scan::IcebergTableScan;
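
/// A DataFusion [`TableProvider`] backed by an Iceberg [`Table`]. By default
/// the provider scans the table's current snapshot; a specific snapshot can
/// be pinned via [`IcebergTableProvider::try_new_from_table_snapshot`].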
#[derive(Debug, Clone)]
pub struct IcebergTableProvider {
    /// The wrapped Iceberg table.
    table: Table,
    /// The snapshot to scan; `None` means the table's current snapshot.
    snapshot_id: Option<i64>,
    /// The Arrow schema derived from the Iceberg schema being scanned.
    schema: ArrowSchemaRef,
}

impl IcebergTableProvider {
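    /// Creates a provider from an already loaded table and a precomputed
    /// Arrow schema, scanning the table's current snapshot.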
    pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self {
        IcebergTableProvider {
            table,
            snapshot_id: None,
            schema,
        }
    }

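    /// Loads the table named `name` in `namespace` through the given catalog
    /// and wraps it, deriving the Arrow schema from the table's current
    /// schema.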
    pub(crate) async fn try_new(
        client: Arc<dyn Catalog>,
        namespace: NamespaceIdent,
        name: impl Into<String>,
    ) -> Result<Self> {
        let ident = TableIdent::new(namespace, name.into());
        let table = client.load_table(&ident).await?;
        let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
        Ok(IcebergTableProvider {
            table,
            snapshot_id: None,
            schema,
        })
    }

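    /// Creates a provider from an already loaded [`Table`], scanning its
    /// current snapshot.
    ///
    /// A minimal usage sketch (not compiled as a doctest; it assumes a
    /// `table: Table` that has already been loaded from a catalog or a
    /// metadata file, as in the tests below):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use datafusion::prelude::SessionContext;
    ///
    /// let provider = IcebergTableProvider::try_new_from_table(table).await?;
    /// let ctx = SessionContext::new();
    /// ctx.register_table("my_table", Arc::new(provider))?;
    /// let df = ctx.sql("SELECT * FROM my_table").await?;
    /// ```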
    pub async fn try_new_from_table(table: Table) -> Result<Self> {
        let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
        Ok(IcebergTableProvider {
            table,
            snapshot_id: None,
            schema,
        })
    }

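    /// Creates a provider pinned to the given snapshot, using that snapshot's
    /// schema. Returns an error if `snapshot_id` is not present in the table
    /// metadata.
    ///
    /// A minimal usage sketch (not compiled as a doctest; `table` is an
    /// already loaded [`Table`], and any snapshot id known to its metadata
    /// works):
    ///
    /// ```ignore
    /// let snapshot_id = table.metadata().snapshots().next().unwrap().snapshot_id();
    /// let provider =
    ///     IcebergTableProvider::try_new_from_table_snapshot(table, snapshot_id).await?;
    /// ```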
    pub async fn try_new_from_table_snapshot(table: Table, snapshot_id: i64) -> Result<Self> {
        let snapshot = table
            .metadata()
            .snapshot_by_id(snapshot_id)
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::Unexpected,
                    format!(
                        "snapshot id {snapshot_id} not found in table {}",
                        table.identifier().name()
                    ),
                )
            })?;
        // Use the schema of the pinned snapshot, not the table's current schema.
        let schema = snapshot.schema(table.metadata())?;
        let schema = Arc::new(schema_to_arrow_schema(&schema)?);
        Ok(IcebergTableProvider {
            table,
            snapshot_id: Some(snapshot_id),
            schema,
        })
    }
}
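
// DataFusion integration: scans are delegated to `IcebergTableScan`, and all
// filters are reported as inexactly pushed down so DataFusion re-applies them
// on top of the scan.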
#[async_trait]
impl TableProvider for IcebergTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> ArrowSchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        _limit: Option<usize>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Build a physical scan over the (optionally pinned) snapshot,
        // forwarding the projection and filter expressions.
        Ok(Arc::new(IcebergTableScan::new(
            self.table.clone(),
            self.snapshot_id,
            self.schema.clone(),
            projection,
            filters,
        )))
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> std::result::Result<Vec<TableProviderFilterPushDown>, datafusion::error::DataFusionError>
    {
        // Report every filter as `Inexact`: the scan may use them to prune
        // data, but DataFusion still evaluates them after the scan.
        Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
    }
}

#[cfg(test)]
mod tests {
    use datafusion::common::Column;
    use datafusion::prelude::SessionContext;
    use iceberg::io::FileIO;
    use iceberg::table::{StaticTable, Table};
    use iceberg::TableIdent;

    use super::*;

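    /// Builds a [`Table`] from a checked-in metadata JSON fixture so the
    /// tests run without a live catalog.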
    async fn get_test_table_from_metadata_file() -> Table {
        let metadata_file_name = "TableMetadataV2Valid.json";
        let metadata_file_path = format!(
            "{}/tests/test_data/{}",
            env!("CARGO_MANIFEST_DIR"),
            metadata_file_name
        );
        let file_io = FileIO::from_path(&metadata_file_path)
            .unwrap()
            .build()
            .unwrap();
        let static_identifier = TableIdent::from_strs(["static_ns", "static_table"]).unwrap();
        let static_table =
            StaticTable::from_metadata_file(&metadata_file_path, static_identifier, file_io)
                .await
                .unwrap();
        static_table.into_table()
    }

    #[tokio::test]
    async fn test_try_new_from_table() {
        let table = get_test_table_from_metadata_file().await;
        let table_provider = IcebergTableProvider::try_new_from_table(table.clone())
            .await
            .unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("mytable", Arc::new(table_provider))
            .unwrap();
        let df = ctx.sql("SELECT * FROM mytable").await.unwrap();
        let df_schema = df.schema();
        let df_columns = df_schema.fields();
        assert_eq!(df_columns.len(), 3);
        let x_column = df_columns.first().unwrap();
        let column_data = format!(
            "{:?}:{:?}",
            x_column.name(),
            x_column.data_type().to_string()
        );
        assert_eq!(column_data, "\"x\":\"Int64\"");
        let has_column = df_schema.has_column(&Column::from_name("z"));
        assert!(has_column);
    }

    #[tokio::test]
    async fn test_try_new_from_table_snapshot() {
        let table = get_test_table_from_metadata_file().await;
        let snapshot_id = table.metadata().snapshots().next().unwrap().snapshot_id();
        let table_provider =
            IcebergTableProvider::try_new_from_table_snapshot(table.clone(), snapshot_id)
                .await
                .unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("mytable", Arc::new(table_provider))
            .unwrap();
        let df = ctx.sql("SELECT * FROM mytable").await.unwrap();
        let df_schema = df.schema();
        let df_columns = df_schema.fields();
        assert_eq!(df_columns.len(), 3);
        let x_column = df_columns.first().unwrap();
        let column_data = format!(
            "{:?}:{:?}",
            x_column.name(),
            x_column.data_type().to_string()
        );
        assert_eq!(column_data, "\"x\":\"Int64\"");
        let has_column = df_schema.has_column(&Column::from_name("z"));
        assert!(has_column);
    }

    #[tokio::test]
    async fn test_physical_input_schema_consistent_with_logical_input_schema() {
        let table = get_test_table_from_metadata_file().await;
        let table_provider = IcebergTableProvider::try_new_from_table(table.clone())
            .await
            .unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("mytable", Arc::new(table_provider))
            .unwrap();
        let df = ctx.sql("SELECT count(*) FROM mytable").await.unwrap();
        let physical_plan = df.create_physical_plan().await;
        assert!(physical_plan.is_ok())
    }
}