diff --git a/Cargo.lock b/Cargo.lock
index f96d3494..72f9074e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4135,6 +4135,7 @@ dependencies = [
  "candle-transformers",
  "chrono",
  "crunchy",
+ "datafusion",
  "datafusion-catalog",
  "datafusion-common",
  "datafusion-execution",
diff --git a/Cargo.toml b/Cargo.toml
index 62a2d486..7b403a1a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,6 +42,7 @@ arrow-arith = "53.2"
 arrow-cast = "53.2"
 async-trait = "0"
 chrono = "0.4.35"
+datafusion = { version = "44.0", default-features = false }
 datafusion-catalog = "44.0"
 datafusion-common = { version = "44.0", default-features = false }
 datafusion-execution = "44.0"
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index 28dafa0b..fe2aa096 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -85,6 +85,7 @@ aws-sdk-s3 = { version = "1.38.0" }
 aws-sdk-kms = { version = "1.37" }
 aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
+datafusion.workspace = true
 http-body = "1"
 
 # Matching reqwest
diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs
index 4564786b..a27a8d92 100644
--- a/rust/lancedb/src/table/datafusion.rs
+++ b/rust/lancedb/src/table/datafusion.rs
@@ -120,7 +120,14 @@ pub struct BaseTableAdapter {
 
 impl BaseTableAdapter {
     pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
-        let schema = table.schema().await?;
+        let schema = Arc::new(
+            table
+                .schema()
+                .await?
+                .as_ref()
+                .clone()
+                .with_metadata(HashMap::default()),
+        );
         Ok(Self { table, schema })
     }
 }
@@ -185,3 +192,72 @@ impl TableProvider for BaseTableAdapter {
         None
     }
 }
+
+#[cfg(test)]
+pub mod tests {
+    use std::{collections::HashMap, sync::Arc};
+
+    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
+    use datafusion_catalog::TableProvider;
+    use datafusion_expr::LogicalPlanBuilder;
+    use futures::TryStreamExt;
+    use tempfile::tempdir;
+
+    use crate::{connect, table::datafusion::BaseTableAdapter};
+
+    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
+        let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
+        let schema = Arc::new(
+            Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata),
+        );
+        RecordBatchIterator::new(
+            vec![RecordBatch::try_new(
+                schema.clone(),
+                vec![Arc::new(Int32Array::from_iter_values(0..10))],
+            )],
+            schema,
+        )
+    }
+
+    #[tokio::test]
+    async fn test_metadata_erased() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path().join("test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let db = connect(uri).execute().await.unwrap();
+
+        let tbl = db
+            .create_table("foo", make_test_batches())
+            .execute()
+            .await
+            .unwrap();
+
+        let provider = Arc::new(
+            BaseTableAdapter::try_new(tbl.base_table().clone())
+                .await
+                .unwrap(),
+        );
+
+        assert!(provider.schema().metadata().is_empty());
+
+        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let mut stream = SessionContext::new()
+            .execute_logical_plan(plan)
+            .await
+            .unwrap()
+            .execute_stream()
+            .await
+            .unwrap();
+
+        while let Some(batch) = stream.try_next().await.unwrap() {
+            assert!(batch.schema().metadata().is_empty());
+        }
+    }
+}