use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use dashmap::DashMap;
use datafusion::catalog::SchemaProvider;
use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result as DFResult};
use datafusion::execution::TaskContext;
use datafusion::prelude::SessionContext;
use futures::StreamExt;
use futures::future::try_join_all;
use iceberg::arrow::arrow_schema_to_schema_auto_assign_ids;
use iceberg::inspect::MetadataTableType;
use iceberg::{Catalog, Error, ErrorKind, NamespaceIdent, Result, TableCreation, TableIdent};

use crate::table::IcebergTableProvider;
use crate::to_datafusion_error;

/// A [`SchemaProvider`] over a single Iceberg namespace: it exposes the
/// namespace's tables (and their `$`-suffixed metadata tables) to
/// DataFusion as [`TableProvider`]s.
#[derive(Debug)]
pub(crate) struct IcebergSchemaProvider {
    /// The Iceberg catalog that owns the namespace.
    catalog: Arc<dyn Catalog>,
    /// The namespace this schema provider exposes.
    namespace: NamespaceIdent,
    /// Cached table providers, keyed by table name. `DashMap` provides the
    /// concurrent interior mutability that the synchronous
    /// `register_table`/`deregister_table` methods need.
    tables: Arc<DashMap<String, Arc<IcebergTableProvider>>>,
}

impl IcebergSchemaProvider {
    /// Builds a schema provider for `namespace`, eagerly creating an
    /// [`IcebergTableProvider`] for every table currently listed in the
    /// catalog. The providers are constructed concurrently via
    /// `try_join_all`.
    pub(crate) async fn try_new(
        client: Arc<dyn Catalog>,
        namespace: NamespaceIdent,
    ) -> Result<Self> {
        let table_names: Vec<_> = client
            .list_tables(&namespace)
            .await?
            .iter()
            .map(|tbl| tbl.name().to_string())
            .collect();

        let providers = try_join_all(
            table_names
                .iter()
                .map(|name| IcebergTableProvider::try_new(client.clone(), namespace.clone(), name))
                .collect::<Vec<_>>(),
        )
        .await?;

        let tables = Arc::new(DashMap::new());
        for (name, provider) in table_names.into_iter().zip(providers.into_iter()) {
            tables.insert(name, Arc::new(provider));
        }

        Ok(IcebergSchemaProvider {
            catalog: client,
            namespace,
            tables,
        })
    }
}

#[async_trait]
impl SchemaProvider for IcebergSchemaProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn table_names(&self) -> Vec<String> {
        self.tables
            .iter()
            .flat_map(|entry| {
                let table_name = entry.key().clone();
                // Advertise the table itself plus one `<table>$<type>` entry
                // per metadata table type (e.g. `t$snapshots`).
                [table_name.clone()]
                    .into_iter()
                    .chain(
                        MetadataTableType::all_types().map(move |metadata_table_name| {
                            format!("{}${}", table_name, metadata_table_name.as_str())
                        }),
                    )
            })
            .collect()
    }

    fn table_exist(&self, name: &str) -> bool {
        if let Some((table_name, metadata_table_name)) = name.split_once('$') {
            self.tables.contains_key(table_name)
                && MetadataTableType::try_from(metadata_table_name).is_ok()
        } else {
            self.tables.contains_key(name)
        }
    }

    async fn table(&self, name: &str) -> DFResult<Option<Arc<dyn TableProvider>>> {
        // `<table>$<type>` resolves to a metadata table; an unknown type is
        // a planning error, while an unknown base table resolves to `None`.
        if let Some((table_name, metadata_table_name)) = name.split_once('$') {
            let metadata_table_type =
                MetadataTableType::try_from(metadata_table_name).map_err(DataFusionError::Plan)?;
            if let Some(table) = self.tables.get(table_name) {
                let metadata_table = table
                    .metadata_table(metadata_table_type)
                    .await
                    .map_err(to_datafusion_error)?;
                return Ok(Some(Arc::new(metadata_table)));
            } else {
                return Ok(None);
            }
        }

        Ok(self
            .tables
            .get(name)
            .map(|entry| entry.value().clone() as Arc<dyn TableProvider>))
    }
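
    // Illustrative only: with this provider registered as schema `iceberg`
    // under DataFusion's default `datafusion` catalog (a wiring sketch
    // follows this impl), data and metadata tables become addressable from
    // SQL. `my_table` is a hypothetical name, and the `$` name needs double
    // quotes to parse as an identifier:
    //
    //   SELECT * FROM datafusion.iceberg.my_table;              -- data
    //   SELECT * FROM datafusion.iceberg."my_table$snapshots";  -- metadata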

    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> DFResult<Option<Arc<dyn TableProvider>>> {
        if self.table_exist(name.as_str()) {
            return Err(DataFusionError::Execution(format!(
                "Table {name} already exists"
            )));
        }

        // Derive an Iceberg schema from the provider's Arrow schema,
        // auto-assigning Iceberg field IDs.
        let df_schema = table.schema();
        let iceberg_schema = arrow_schema_to_schema_auto_assign_ids(df_schema.as_ref())
            .map_err(to_datafusion_error)?;

        let table_creation = TableCreation::builder()
            .name(name.clone())
            .schema(iceberg_schema)
            .build();

        let catalog = self.catalog.clone();
        let namespace = self.namespace.clone();
        let tables = self.tables.clone();
        let name_clone = name.clone();

        // `SchemaProvider::register_table` is synchronous, but the catalog
        // calls are async: run them on a blocking thread that drives the
        // futures with the current runtime's handle.
        let result = tokio::task::spawn_blocking(move || {
            let rt = tokio::runtime::Handle::current();
            rt.block_on(async move {
                // Only empty tables can be registered: no data is written
                // here, so rows in the source provider would be lost.
                ensure_table_is_empty(&table)
                    .await
                    .map_err(to_datafusion_error)?;

                catalog
                    .create_table(&namespace, table_creation)
                    .await
                    .map_err(to_datafusion_error)?;

                let table_provider = IcebergTableProvider::try_new(
                    catalog.clone(),
                    namespace.clone(),
                    name_clone.clone(),
                )
                .await
                .map_err(to_datafusion_error)?;

                tables.insert(name_clone, Arc::new(table_provider));

                // Nothing was replaced: existence was checked above.
                Ok(None)
            })
        });

        // Wait on the join handle without entering the runtime;
        // `Handle::block_on` would panic inside an async context, while
        // `futures::executor::block_on` only parks this thread.
        futures::executor::block_on(result).map_err(|e| {
            DataFusionError::Execution(format!("Failed to create Iceberg table: {e}"))
        })?
    }

    fn deregister_table(&self, name: &str) -> DFResult<Option<Arc<dyn TableProvider>>> {
        if !self.table_exist(name) {
            return Ok(None);
        }

        let catalog = self.catalog.clone();
        let namespace = self.namespace.clone();
        let tables = self.tables.clone();
        let table_name = name.to_string();

        // Same sync-over-async bridge as `register_table` above.
        let result = tokio::task::spawn_blocking(move || {
            let rt = tokio::runtime::Handle::current();
            rt.block_on(async move {
                let table_ident = TableIdent::new(namespace, table_name.clone());

                catalog
                    .drop_table(&table_ident)
                    .await
                    .map_err(to_datafusion_error)?;

                // Return the evicted provider, per the `SchemaProvider`
                // contract for deregistration.
                let removed = tables
                    .remove(&table_name)
                    .map(|(_, table)| table as Arc<dyn TableProvider>);

                Ok(removed)
            })
        });

        futures::executor::block_on(result)
            .map_err(|e| DataFusionError::Execution(format!("Failed to drop Iceberg table: {e}")))?
    }
}
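
// A minimal wiring sketch, not exercised elsewhere in this module: it
// assumes DataFusion's default in-memory catalog (named "datafusion") and
// registers the namespace under the hypothetical schema name "iceberg".
#[allow(dead_code)]
async fn attach_to_session(
    ctx: &SessionContext,
    catalog: Arc<dyn Catalog>,
    namespace: NamespaceIdent,
) -> Result<()> {
    // Local import keeps the sketch self-contained; `register_schema` is a
    // `CatalogProvider` trait method.
    use datafusion::catalog::CatalogProvider;

    let provider = IcebergSchemaProvider::try_new(catalog, namespace).await?;
    ctx.catalog("datafusion")
        .expect("default catalog is created by SessionContext::new")
        .register_schema("iceberg", Arc::new(provider))
        .map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))?;
    Ok(())
}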

/// Verifies that `table` contains no rows by scanning it with a limit of
/// one row. `execute_stream` coalesces all partitions of the plan, so rows
/// living in a partition other than the first are still detected.
async fn ensure_table_is_empty(table: &Arc<dyn TableProvider>) -> Result<()> {
    let session_ctx = SessionContext::new();
    let exec_plan = table
        .scan(&session_ctx.state(), None, &[], Some(1))
        .await
        .map_err(|e| Error::new(ErrorKind::Unexpected, format!("Failed to scan table: {e}")))?;

    let task_ctx = Arc::new(TaskContext::default());
    let stream = datafusion::physical_plan::execute_stream(exec_plan, task_ctx).map_err(|e| {
        Error::new(
            ErrorKind::Unexpected,
            format!("Failed to execute scan: {e}"),
        )
    })?;

    // Stream errors are ignored here; only successfully produced batches
    // are checked for rows.
    let batches: Vec<_> = stream.collect().await;
    let has_data = batches
        .into_iter()
        .filter_map(|r| r.ok())
        .any(|batch| batch.num_rows() > 0);

    if has_data {
        return Err(Error::new(
            ErrorKind::Unexpected,
            "register_table does not support tables with data.",
        ));
    }

    Ok(())
}
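
// Why the emptiness guard above exists: `register_table` only creates a new
// Iceberg table from the provider's schema and never copies rows, so any
// data in the incoming provider would be silently dropped. Rejecting
// non-empty tables keeps that failure mode explicit.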

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use iceberg::memory::{MEMORY_CATALOG_WAREHOUSE, MemoryCatalogBuilder};
    use iceberg::{Catalog, CatalogBuilder, NamespaceIdent};
    use tempfile::TempDir;

    use super::*;

    /// Spins up an in-memory catalog over a temporary warehouse directory
    /// with one empty namespace, and wraps it in a schema provider.
    async fn create_test_schema_provider() -> (IcebergSchemaProvider, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let warehouse_path = temp_dir.path().to_str().unwrap().to_string();

        let catalog = MemoryCatalogBuilder::default()
            .load(
                "memory",
                HashMap::from([(MEMORY_CATALOG_WAREHOUSE.to_string(), warehouse_path.clone())]),
            )
            .await
            .unwrap();

        let namespace = NamespaceIdent::new("test_ns".to_string());
        catalog
            .create_namespace(&namespace, HashMap::new())
            .await
            .unwrap();

        let provider = IcebergSchemaProvider::try_new(Arc::new(catalog), namespace)
            .await
            .unwrap();

        (provider, temp_dir)
    }

    #[tokio::test]
    async fn test_register_table_with_data_fails() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        // Build a MemTable that already holds rows.
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])
        .unwrap();

        let mem_table = MemTable::try_new(arrow_schema, vec![vec![batch]]).unwrap();

        let result =
            schema_provider.register_table("test_table".to_string(), Arc::new(mem_table));

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("register_table does not support tables with data."),
            "Expected error about tables with data, got: {err}",
        );
    }

    #[tokio::test]
    async fn test_register_empty_table_succeeds() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let empty_batch = RecordBatch::new_empty(arrow_schema.clone());
        let mem_table = MemTable::try_new(arrow_schema, vec![vec![empty_batch]]).unwrap();

        let result =
            schema_provider.register_table("empty_table".to_string(), Arc::new(mem_table));

        assert!(result.is_ok(), "Expected success, got: {result:?}");

        assert!(schema_provider.table_exist("empty_table"));
    }

    #[tokio::test]
    async fn test_register_duplicate_table_fails() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]));

        let empty_batch1 = RecordBatch::new_empty(arrow_schema.clone());
        let empty_batch2 = RecordBatch::new_empty(arrow_schema.clone());
        let mem_table1 = MemTable::try_new(arrow_schema.clone(), vec![vec![empty_batch1]]).unwrap();
        let mem_table2 = MemTable::try_new(arrow_schema, vec![vec![empty_batch2]]).unwrap();

        let result1 =
            schema_provider.register_table("dup_table".to_string(), Arc::new(mem_table1));
        assert!(result1.is_ok());

        let result2 =
            schema_provider.register_table("dup_table".to_string(), Arc::new(mem_table2));
        assert!(result2.is_err());
        let err = result2.unwrap_err();
        assert!(
            err.to_string().contains("already exists"),
            "Expected error about table already existing, got: {err}",
        );
    }

    #[tokio::test]
    async fn test_deregister_table_succeeds() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]));

        let empty_batch = RecordBatch::new_empty(arrow_schema.clone());
        let mem_table = MemTable::try_new(arrow_schema, vec![vec![empty_batch]]).unwrap();

        let result = schema_provider.register_table("drop_me".to_string(), Arc::new(mem_table));
        assert!(result.is_ok());
        assert!(schema_provider.table_exist("drop_me"));

        // Deregistering returns the evicted provider.
        let result = schema_provider.deregister_table("drop_me");
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());

        assert!(!schema_provider.table_exist("drop_me"));
    }

    #[tokio::test]
    async fn test_deregister_nonexistent_table_returns_none() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        let result = schema_provider.deregister_table("nonexistent");
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }
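
    // A hedged extra check (not from the original suite): exercises the
    // `$`-suffix resolution in `table_exist` for every metadata table type
    // the iceberg crate advertises, plus one bogus suffix.
    #[tokio::test]
    async fn test_metadata_table_names_resolve() {
        let (schema_provider, _temp_dir) = create_test_schema_provider().await;

        let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]));
        let empty_batch = RecordBatch::new_empty(arrow_schema.clone());
        let mem_table = MemTable::try_new(arrow_schema, vec![vec![empty_batch]]).unwrap();
        schema_provider
            .register_table("base_table".to_string(), Arc::new(mem_table))
            .unwrap();

        // Every advertised metadata table name should both exist and appear
        // in `table_names` with the same spelling.
        let names = schema_provider.table_names();
        for metadata_type in MetadataTableType::all_types() {
            let name = format!("base_table${}", metadata_type.as_str());
            assert!(schema_provider.table_exist(&name));
            assert!(names.contains(&name));
        }
        assert!(!schema_provider.table_exist("base_table$no_such_type"));
    }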
}