iceberg_datafusion/
schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use dashmap::DashMap;
23use datafusion::catalog::SchemaProvider;
24use datafusion::datasource::TableProvider;
25use datafusion::error::{DataFusionError, Result as DFResult};
26use datafusion::execution::TaskContext;
27use datafusion::prelude::SessionContext;
28use futures::StreamExt;
29use futures::future::try_join_all;
30use iceberg::arrow::arrow_schema_to_schema_auto_assign_ids;
31use iceberg::inspect::MetadataTableType;
32use iceberg::{Catalog, Error, ErrorKind, NamespaceIdent, Result, TableCreation, TableIdent};
33
34use crate::table::IcebergTableProvider;
35use crate::to_datafusion_error;
36
/// Represents a [`SchemaProvider`] for the Iceberg [`Catalog`], managing
/// access to table providers within a specific namespace.
#[derive(Debug)]
pub(crate) struct IcebergSchemaProvider {
    /// Reference to the Iceberg catalog; used to create/drop tables when
    /// DataFusion calls `register_table` / `deregister_table`.
    catalog: Arc<dyn Catalog>,
    /// The namespace this schema represents; all table idents are resolved
    /// relative to it.
    namespace: NamespaceIdent,
    /// A concurrent map where keys are table names
    /// and values are dynamic references to objects implementing the
    /// [`TableProvider`] trait.
    /// Wrapped in Arc to allow sharing across async boundaries in register_table.
    /// NOTE: populated once in `try_new`; entries can become stale if tables
    /// are created/dropped in the catalog by another process (see TODO in `try_new`).
    tables: Arc<DashMap<String, Arc<IcebergTableProvider>>>,
}
51
52impl IcebergSchemaProvider {
53    /// Asynchronously tries to construct a new [`IcebergSchemaProvider`]
54    /// using the given client to fetch and initialize table providers for
55    /// the provided namespace in the Iceberg [`Catalog`].
56    ///
57    /// This method retrieves a list of table names
58    /// attempts to create a table provider for each table name, and
59    /// collects these providers into a `HashMap`.
60    pub(crate) async fn try_new(
61        client: Arc<dyn Catalog>,
62        namespace: NamespaceIdent,
63    ) -> Result<Self> {
64        // TODO:
65        // Tables and providers should be cached based on table_name
66        // if we have a cache miss; we update our internal cache & check again
67        // As of right now; tables might become stale.
68        let table_names: Vec<_> = client
69            .list_tables(&namespace)
70            .await?
71            .iter()
72            .map(|tbl| tbl.name().to_string())
73            .collect();
74
75        let providers = try_join_all(
76            table_names
77                .iter()
78                .map(|name| IcebergTableProvider::try_new(client.clone(), namespace.clone(), name))
79                .collect::<Vec<_>>(),
80        )
81        .await?;
82
83        let tables = Arc::new(DashMap::new());
84        for (name, provider) in table_names.into_iter().zip(providers.into_iter()) {
85            tables.insert(name, Arc::new(provider));
86        }
87
88        Ok(IcebergSchemaProvider {
89            catalog: client,
90            namespace,
91            tables,
92        })
93    }
94}
95
#[async_trait]
impl SchemaProvider for IcebergSchemaProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Lists every cached base table plus one synthetic `<table>$<type>`
    /// entry per metadata table type (e.g. `orders$snapshots`).
    fn table_names(&self) -> Vec<String> {
        self.tables
            .iter()
            .flat_map(|entry| {
                let table_name = entry.key().clone();
                // For each base table, emit the table name itself followed by
                // all of its `$`-suffixed metadata table names.
                [table_name.clone()]
                    .into_iter()
                    .chain(
                        MetadataTableType::all_types().map(move |metadata_table_name| {
                            format!("{}${}", table_name, metadata_table_name.as_str())
                        }),
                    )
            })
            .collect()
    }

    /// Checks the local cache only (no catalog round-trip). A name containing
    /// `$` is treated as `<table>$<metadata_table>` and exists only if both
    /// the base table is cached and the suffix is a valid metadata table type.
    fn table_exist(&self, name: &str) -> bool {
        if let Some((table_name, metadata_table_name)) = name.split_once('$') {
            self.tables.contains_key(table_name)
                && MetadataTableType::try_from(metadata_table_name).is_ok()
        } else {
            self.tables.contains_key(name)
        }
    }

    /// Resolves a table provider by name. `<table>$<metadata_table>` names
    /// yield a metadata-table provider; an invalid metadata suffix is a plan
    /// error, while an unknown base table resolves to `Ok(None)`.
    async fn table(&self, name: &str) -> DFResult<Option<Arc<dyn TableProvider>>> {
        if let Some((table_name, metadata_table_name)) = name.split_once('$') {
            let metadata_table_type =
                MetadataTableType::try_from(metadata_table_name).map_err(DataFusionError::Plan)?;
            if let Some(table) = self.tables.get(table_name) {
                let metadata_table = table
                    .metadata_table(metadata_table_type)
                    .await
                    .map_err(to_datafusion_error)?;
                return Ok(Some(Arc::new(metadata_table)));
            } else {
                return Ok(None);
            }
        }

        Ok(self
            .tables
            .get(name)
            .map(|entry| entry.value().clone() as Arc<dyn TableProvider>))
    }

    /// Creates a new Iceberg table from the schema of `table` and caches a
    /// provider for it. The input `table` is only a schema carrier: it must
    /// contain zero rows (checked via `ensure_table_is_empty`), and its own
    /// provider is discarded in favor of a fresh `IcebergTableProvider`.
    ///
    /// On success always returns `Ok(None)` (no previous provider), because
    /// overwriting an existing name is rejected up front.
    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> DFResult<Option<Arc<dyn TableProvider>>> {
        // Check if table already exists
        if self.table_exist(name.as_str()) {
            return Err(DataFusionError::Execution(format!(
                "Table {name} already exists"
            )));
        }

        // Convert DataFusion schema to Iceberg schema
        // DataFusion schemas don't have field IDs, so we use the function that assigns them automatically
        let df_schema = table.schema();
        let iceberg_schema = arrow_schema_to_schema_auto_assign_ids(df_schema.as_ref())
            .map_err(to_datafusion_error)?;

        // Create the table in the Iceberg catalog
        let table_creation = TableCreation::builder()
            .name(name.clone())
            .schema(iceberg_schema)
            .build();

        let catalog = self.catalog.clone();
        let namespace = self.namespace.clone();
        let tables = self.tables.clone();
        let name_clone = name.clone();

        // Use tokio's spawn_blocking to handle the async work on a blocking thread pool
        // NOTE(review): this sync->async bridge (`spawn_blocking` +
        // `Handle::current().block_on` + `futures::executor::block_on` below)
        // appears to assume a multi-threaded Tokio runtime; on a
        // current-thread runtime, blocking the caller while another thread
        // block_on's the same runtime's Handle looks deadlock-prone — confirm.
        let result = tokio::task::spawn_blocking(move || {
            // Create a new runtime handle to execute the async work
            // (spawn_blocking threads inherit the runtime context, so
            // `Handle::current()` is expected to succeed here).
            let rt = tokio::runtime::Handle::current();
            rt.block_on(async move {
                // Verify the input table is empty - CREATE TABLE only accepts schema definition
                ensure_table_is_empty(&table)
                    .await
                    .map_err(to_datafusion_error)?;

                catalog
                    .create_table(&namespace, table_creation)
                    .await
                    .map_err(to_datafusion_error)?;

                // Create a new table provider using the catalog reference
                let table_provider = IcebergTableProvider::try_new(
                    catalog.clone(),
                    namespace.clone(),
                    name_clone.clone(),
                )
                .await
                .map_err(to_datafusion_error)?;

                // Store the new table provider
                tables.insert(name_clone, Arc::new(table_provider));

                Ok(None)
            })
        });

        // Block on the spawned task to get the result
        // This is safe because spawn_blocking moves the blocking to a dedicated thread pool
        // (the outer map_err converts a JoinError — panic/cancellation — into
        // an execution error; the inner `?` propagates the task's own result).
        futures::executor::block_on(result).map_err(|e| {
            DataFusionError::Execution(format!("Failed to create Iceberg table: {e}"))
        })?
    }

    /// Drops the named table from the Iceberg catalog and removes it from the
    /// local cache, returning the removed provider. Unknown names return
    /// `Ok(None)` without touching the catalog.
    fn deregister_table(&self, name: &str) -> DFResult<Option<Arc<dyn TableProvider>>> {
        // Check if table exists
        if !self.table_exist(name) {
            return Ok(None);
        }

        let catalog = self.catalog.clone();
        let namespace = self.namespace.clone();
        let tables = self.tables.clone();
        let table_name = name.to_string();

        // Use tokio's spawn_blocking to handle the async work on a blocking thread pool
        // NOTE(review): same sync->async bridging caveat as register_table —
        // verify behavior on a current-thread runtime.
        let result = tokio::task::spawn_blocking(move || {
            let rt = tokio::runtime::Handle::current();
            rt.block_on(async move {
                let table_ident = TableIdent::new(namespace, table_name.clone());

                // Drop the table from the Iceberg catalog
                catalog
                    .drop_table(&table_ident)
                    .await
                    .map_err(to_datafusion_error)?;

                // Remove from local cache and return the removed provider
                let removed = tables
                    .remove(&table_name)
                    .map(|(_, table)| table as Arc<dyn TableProvider>);

                Ok(removed)
            })
        });

        futures::executor::block_on(result)
            .map_err(|e| DataFusionError::Execution(format!("Failed to drop Iceberg table: {e}")))?
    }
}
251
252/// Verifies that a table provider contains no data by scanning with LIMIT 1.
253/// Returns an error if the table has any rows.
254async fn ensure_table_is_empty(table: &Arc<dyn TableProvider>) -> Result<()> {
255    let session_ctx = SessionContext::new();
256    let exec_plan = table
257        .scan(&session_ctx.state(), None, &[], Some(1))
258        .await
259        .map_err(|e| Error::new(ErrorKind::Unexpected, format!("Failed to scan table: {e}")))?;
260
261    let task_ctx = Arc::new(TaskContext::default());
262    let stream = exec_plan.execute(0, task_ctx).map_err(|e| {
263        Error::new(
264            ErrorKind::Unexpected,
265            format!("Failed to execute scan: {e}"),
266        )
267    })?;
268
269    let batches: Vec<_> = stream.collect().await;
270    let has_data = batches
271        .into_iter()
272        .filter_map(|r| r.ok())
273        .any(|batch| batch.num_rows() > 0);
274
275    if has_data {
276        return Err(Error::new(
277            ErrorKind::Unexpected,
278            "register_table does not support tables with data.",
279        ));
280    }
281
282    Ok(())
283}
284
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use iceberg::memory::{MEMORY_CATALOG_WAREHOUSE, MemoryCatalogBuilder};
    use iceberg::{Catalog, CatalogBuilder, NamespaceIdent};
    use tempfile::TempDir;

    use super::*;

    /// Builds a schema provider backed by an in-memory catalog with one
    /// pre-created namespace. The returned `TempDir` keeps the warehouse
    /// directory alive for the duration of the test.
    async fn create_test_schema_provider() -> (IcebergSchemaProvider, TempDir) {
        let warehouse_dir = TempDir::new().unwrap();
        let warehouse_location = warehouse_dir.path().to_str().unwrap().to_string();

        let catalog = MemoryCatalogBuilder::default()
            .load(
                "memory",
                HashMap::from([(
                    MEMORY_CATALOG_WAREHOUSE.to_string(),
                    warehouse_location.clone(),
                )]),
            )
            .await
            .unwrap();

        let ns = NamespaceIdent::new("test_ns".to_string());
        catalog.create_namespace(&ns, HashMap::new()).await.unwrap();

        let schema_provider = IcebergSchemaProvider::try_new(Arc::new(catalog), ns)
            .await
            .unwrap();

        (schema_provider, warehouse_dir)
    }

    /// Convenience: a single-column (`id: Int32`, non-nullable) Arrow schema.
    fn id_only_schema() -> Arc<ArrowSchema> {
        Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]))
    }

    /// Convenience: a `MemTable` holding exactly one empty (0-row) batch —
    /// MemTable requires at least one partition even when there is no data.
    fn empty_mem_table(schema: Arc<ArrowSchema>) -> MemTable {
        let batch = RecordBatch::new_empty(schema.clone());
        MemTable::try_new(schema, vec![vec![batch]]).unwrap()
    }

    #[tokio::test]
    async fn test_register_table_with_data_fails() {
        let (provider, _warehouse_dir) = create_test_schema_provider().await;

        // A MemTable that actually holds rows.
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])
        .unwrap();
        let populated = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        // Registering a non-empty table must be rejected.
        let outcome = provider.register_table("test_table".to_string(), Arc::new(populated));

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            err.to_string()
                .contains("register_table does not support tables with data."),
            "Expected error about tables with data, got: {err}",
        );
    }

    #[tokio::test]
    async fn test_register_empty_table_succeeds() {
        let (provider, _warehouse_dir) = create_test_schema_provider().await;

        // Schema-only table: two columns, zero rows.
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let table = empty_mem_table(schema);

        let outcome = provider.register_table("empty_table".to_string(), Arc::new(table));
        assert!(outcome.is_ok(), "Expected success, got: {outcome:?}");

        // The new table must now be visible through the provider.
        assert!(provider.table_exist("empty_table"));
    }

    #[tokio::test]
    async fn test_register_duplicate_table_fails() {
        let (provider, _warehouse_dir) = create_test_schema_provider().await;

        let first = empty_mem_table(id_only_schema());
        let second = empty_mem_table(id_only_schema());

        // First registration succeeds.
        let first_outcome = provider.register_table("dup_table".to_string(), Arc::new(first));
        assert!(first_outcome.is_ok());

        // Re-using the same name must fail.
        let second_outcome = provider.register_table("dup_table".to_string(), Arc::new(second));
        assert!(second_outcome.is_err());
        let err = second_outcome.unwrap_err();
        assert!(
            err.to_string().contains("already exists"),
            "Expected error about table already existing, got: {err}",
        );
    }

    #[tokio::test]
    async fn test_deregister_table_succeeds() {
        let (provider, _warehouse_dir) = create_test_schema_provider().await;

        // Register an empty table so there is something to drop.
        let table = empty_mem_table(id_only_schema());
        let registered = provider.register_table("drop_me".to_string(), Arc::new(table));
        assert!(registered.is_ok());
        assert!(provider.table_exist("drop_me"));

        // Dropping returns the removed provider...
        let removed = provider.deregister_table("drop_me");
        assert!(removed.is_ok());
        assert!(removed.unwrap().is_some());

        // ...and the table disappears from the namespace.
        assert!(!provider.table_exist("drop_me"));
    }

    #[tokio::test]
    async fn test_deregister_nonexistent_table_returns_none() {
        let (provider, _warehouse_dir) = create_test_schema_provider().await;

        // Unknown names are a no-op that returns Ok(None), not an error.
        let outcome = provider.deregister_table("nonexistent");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }
}