fix: ensure metadata erased from schema call in table provider (#2099)

This also adds a basic unit test for the table provider
This commit is contained in:
Weston Pace
2025-02-06 15:30:20 -08:00
committed by GitHub
parent 1a449fa49e
commit 4e5fbe6c99
4 changed files with 80 additions and 1 deletion

1
Cargo.lock generated
View File

@@ -4135,6 +4135,7 @@ dependencies = [
"candle-transformers",
"chrono",
"crunchy",
"datafusion",
"datafusion-catalog",
"datafusion-common",
"datafusion-execution",

View File

@@ -42,6 +42,7 @@ arrow-arith = "53.2"
arrow-cast = "53.2"
async-trait = "0"
chrono = "0.4.35"
datafusion = { version = "44.0", default-features = false }
datafusion-catalog = "44.0"
datafusion-common = { version = "44.0", default-features = false }
datafusion-execution = "44.0"

View File

@@ -85,6 +85,7 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
datafusion.workspace = true
http-body = "1" # Matching reqwest

View File

@@ -120,7 +120,14 @@ pub struct BaseTableAdapter {
impl BaseTableAdapter {
    /// Creates an adapter around `table`, exposing it as a DataFusion table provider.
    ///
    /// The schema's metadata is erased (replaced with an empty map) so that the
    /// provider reports a metadata-free schema; the accompanying unit test
    /// verifies both the provider schema and scanned batches carry no metadata.
    ///
    /// # Errors
    /// Returns an error if fetching the table's schema fails.
    pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
        // Fetch the schema exactly once. (A leftover duplicate
        // `table.schema().await?` call previously performed a second,
        // redundant async round-trip whose result was immediately shadowed.)
        let schema = Arc::new(
            table
                .schema()
                .await?
                .as_ref()
                .clone()
                // Strip metadata so downstream consumers see a bare schema.
                .with_metadata(HashMap::default()),
        );
        Ok(Self { table, schema })
    }
}
@@ -185,3 +192,72 @@ impl TableProvider for BaseTableAdapter {
None
}
}
#[cfg(test)]
pub mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
    use datafusion_catalog::TableProvider;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::TryStreamExt;
    use tempfile::tempdir;

    use crate::{connect, table::datafusion::BaseTableAdapter};

    /// Builds a one-batch reader whose schema deliberately carries a
    /// metadata entry, so tests can check that the adapter strips it.
    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
        let tagged_schema = Arc::new(
            Schema::new(vec![Field::new("i", DataType::Int32, false)])
                .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())])),
        );
        let batch = RecordBatch::try_new(
            tagged_schema.clone(),
            vec![Arc::new(Int32Array::from_iter_values(0..10))],
        );
        RecordBatchIterator::new(vec![batch], tagged_schema)
    }

    /// The adapter must report a metadata-free schema, both on the provider
    /// itself and on every batch produced by scanning it.
    #[tokio::test]
    async fn test_metadata_erased() {
        let tmp = tempdir().unwrap();
        let dataset_path = tmp.path().join("test.lance");
        let db = connect(dataset_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();
        let table = db
            .create_table("foo", make_test_batches())
            .execute()
            .await
            .unwrap();

        let provider = Arc::new(
            BaseTableAdapter::try_new(table.base_table().clone())
                .await
                .unwrap(),
        );
        // The provider-level schema must already be stripped of metadata.
        assert!(provider.schema().metadata().is_empty());

        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None)
            .unwrap()
            .build()
            .unwrap();

        let ctx = SessionContext::new();
        let df = ctx.execute_logical_plan(plan).await.unwrap();
        let mut batches = df.execute_stream().await.unwrap();
        // Every batch flowing out of the scan must also be metadata-free.
        while let Some(batch) = batches.try_next().await.unwrap() {
            assert!(batch.schema().metadata().is_empty());
        }
    }
}