From bc49c4db82efb8912c8d5625e84ff0e65cd8f114 Mon Sep 17 00:00:00 2001
From: Weston Pace
Date: Fri, 7 Mar 2025 05:53:36 -0800
Subject: [PATCH] feat: respect datafusion's batch size when running as a table provider (#2187)

DataFusion makes the batch size available as part of the `SessionState`.
We should use that to set the `max_batch_length` property in
`QueryExecutionOptions`.
---
 Cargo.lock                           |  2 +-
 rust/lancedb/src/table/datafusion.rs | 50 +++++++++++++++++++++++++---
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 9f350116..ca66afe3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4044,7 +4044,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-python"
-version = "0.21.0-beta.0"
+version = "0.21.0-beta.1"
 dependencies = [
  "arrow",
  "env_logger",
diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs
index c91d945b..5a24d9ac 100644
--- a/rust/lancedb/src/table/datafusion.rs
+++ b/rust/lancedb/src/table/datafusion.rs
@@ -151,7 +151,7 @@ impl TableProvider for BaseTableAdapter {
 
     async fn scan(
         &self,
-        _state: &dyn Session,
+        state: &dyn Session,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
         limit: Option<usize>,
@@ -177,9 +177,15 @@ impl TableProvider for BaseTableAdapter {
             // Need to override the default of 10
            query.limit = None;
         }
+
+        let options = QueryExecutionOptions {
+            max_batch_length: state.config().batch_size() as u32,
+            ..Default::default()
+        };
+
         let plan = self
             .table
-            .create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default())
+            .create_plan(&AnyQuery::Query(query), options)
             .map_err(|err| DataFusionError::External(err.into()))
             .await?;
         Ok(Arc::new(MetadataEraserExec::new(plan)))
@@ -208,11 +214,14 @@ pub mod tests {
         RecordBatchReader, StringArray, UInt32Array,
     };
     use arrow_schema::{DataType, Field, Schema};
-    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
+    use datafusion::{
+        datasource::provider_as_source,
+        prelude::{SessionConfig, SessionContext},
+    };
     use datafusion_catalog::TableProvider;
     use datafusion_execution::SendableRecordBatchStream;
     use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
-    use futures::TryStreamExt;
+    use futures::{StreamExt, TryStreamExt};
     use tempfile::tempdir;
 
     use crate::{
@@ -332,7 +341,14 @@ pub mod tests {
     }
 
     async fn plan_to_stream(plan: LogicalPlan) -> SendableRecordBatchStream {
-        SessionContext::new()
+        Self::plan_to_stream_with_config(plan, SessionConfig::default()).await
+    }
+
+    async fn plan_to_stream_with_config(
+        plan: LogicalPlan,
+        config: SessionConfig,
+    ) -> SendableRecordBatchStream {
+        SessionContext::new_with_config(config)
             .execute_logical_plan(plan)
             .await
             .unwrap()
@@ -382,6 +398,30 @@ pub mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_batch_size() {
+        let fixture = TestFixture::new().await;
+
+        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter2), None)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let config = SessionConfig::default().with_batch_size(100);
+
+        let stream = TestFixture::plan_to_stream_with_config(plan.clone(), config).await;
+
+        let batch_count = stream.count().await;
+        assert_eq!(batch_count, 10);
+
+        let config = SessionConfig::default().with_batch_size(250);
+
+        let stream = TestFixture::plan_to_stream_with_config(plan, config).await;
+
+        let batch_count = stream.count().await;
+        assert_eq!(batch_count, 4);
+    }
+
     #[tokio::test]
     async fn test_metadata_erased() {
         let fixture = TestFixture::new().await;
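
Note (not part of the patch): for callers on the DataFusion side, the practical effect of this change is that batches emitted by a LanceDB-backed scan follow the session's `batch_size` setting instead of LanceDB's internal default. The sketch below shows one way a caller might exercise that; `scan_with_batch_size` and the table name `lance_table` are hypothetical, and `provider` is assumed to be an `Arc<dyn TableProvider>` wrapping a LanceDB table (for example the adapter exercised by the tests above). Only the DataFusion APIs shown in the diff (`SessionConfig::with_batch_size`, `SessionContext::new_with_config`) are relied on.

```rust
use std::sync::Arc;

use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};
use futures::TryStreamExt;

/// Hypothetical helper: scan `provider` with a caller-chosen DataFusion
/// batch size and check the size of the batches that come back.
async fn scan_with_batch_size(
    provider: Arc<dyn TableProvider>,
    batch_size: usize,
) -> Result<()> {
    // The batch size lives on SessionConfig; with this patch the table
    // provider reads it from the Session state passed to `scan`.
    let config = SessionConfig::default().with_batch_size(batch_size);
    let ctx = SessionContext::new_with_config(config);

    ctx.register_table("lance_table", provider)?;

    let batches: Vec<_> = ctx
        .sql("SELECT * FROM lance_table")
        .await?
        .execute_stream()
        .await?
        .try_collect()
        .await?;

    // Each batch should respect the configured size; only an upper bound is
    // checked because the final batch of a scan may be shorter.
    for batch in &batches {
        assert!(batch.num_rows() <= batch_size);
    }
    Ok(())
}
```

The new `test_batch_size` test in the patch verifies the same behavior indirectly, by counting how many batches a fixed-size table produces at batch sizes of 100 and 250.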