diff --git a/Cargo.toml b/Cargo.toml
index 27a40ae3..7f336d10 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,11 @@
 [workspace]
-members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
+members = [
+    "rust/ffi/node",
+    "rust/lancedb",
+    "nodejs",
+    "python",
+    "java/core/lancedb-jni"
+]
 # Python package needs to be built by maturin.
 exclude = ["python"]
 resolver = "2"
@@ -18,6 +24,7 @@ lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
 lance-index = { "version" = "=0.11.1" }
 lance-linalg = { "version" = "=0.11.1" }
 lance-testing = { "version" = "=0.11.1" }
+lance-datafusion = { "version" = "=0.11.1" }
 # Note that this one does not include pyarrow
 arrow = { version = "51.0", optional = false }
 arrow-array = "51.0"
@@ -29,6 +36,7 @@ arrow-arith = "51.0"
 arrow-cast = "51.0"
 async-trait = "0"
 chrono = "0.4.35"
+datafusion-physical-plan = "37.1"
 half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",
 ] }
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index ea606727..a36baed5 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -19,11 +19,13 @@ arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
 chrono = { workspace = true }
+datafusion-physical-plan.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }
 half = { workspace = true }
 lazy_static.workspace = true
 lance = { workspace = true }
+lance-datafusion.workspace = true
 lance-index = { workspace = true }
 lance-linalg = { workspace = true }
 lance-testing = { workspace = true }
@@ -43,7 +45,7 @@ serde_with = { version = "3.8.1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
-polars = { version = ">=0.37,<0.40.0", optional = true}
+polars = { version = ">=0.37,<0.40.0", optional = true }
 
 [dev-dependencies]
 tempfile = "3.5.0"
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index f64de7ba..cc915556 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -17,7 +17,10 @@ use std::sync::Arc;
 
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
+use datafusion_physical_plan::ExecutionPlan;
 use half::f16;
+use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance_datafusion::exec::execute_plan;
 
 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
@@ -425,6 +428,15 @@
 /// There are various kinds of queries but they all return results
 /// in the same way.
 pub trait ExecutableQuery {
+    /// Return the Datafusion [ExecutionPlan].
+    ///
+    /// The caller can further optimize the plan or execute it.
+    ///
+    fn create_plan(
+        &self,
+        options: QueryExecutionOptions,
+    ) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
+
     /// Execute the query with default options and return results
     ///
     /// See [`ExecutableQuery::execute_with_options`] for more details.
@@ -545,6 +557,13 @@ impl HasQuery for Query {
 }
 
 impl ExecutableQuery for Query {
+    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
+        self.parent
+            .clone()
+            .create_plan(&self.clone().into_vector(), options)
+            .await
+    }
+
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
@@ -718,12 +737,19 @@ impl VectorQuery {
 }
 
 impl ExecutableQuery for VectorQuery {
+    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
+        self.base.parent.clone().create_plan(self, options).await
+    }
+
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
         Ok(SendableRecordBatchStream::from(
-            self.base.parent.clone().vector_query(self, options).await?,
+            DatasetRecordBatchStream::new(execute_plan(
+                self.create_plan(options).await?,
+                Default::default(),
+            )?),
         ))
     }
 }
@@ -972,6 +998,30 @@ mod tests {
         }
     }
 
+    fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
+        if plan.name() == name {
+            return true;
+        }
+        plan.children()
+            .iter()
+            .any(|child| assert_plan_exists(child, name))
+    }
+
+    #[tokio::test]
+    async fn test_create_execute_plan() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_test_table(&tmp_dir).await;
+        let plan = table
+            .query()
+            .nearest_to(vec![0.1, 0.2, 0.3, 0.4])
+            .unwrap()
+            .create_plan(QueryExecutionOptions::default())
+            .await
+            .unwrap();
+        assert_plan_exists(&plan, "KNNFlatSearch");
+        assert_plan_exists(&plan, "ProjectionExec");
+    }
+
     #[tokio::test]
     async fn query_base_methods_on_vector_query() {
         // Make sure VectorQuery can be used as a QueryBase
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 84b2c247..1b7c6d20 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -1,6 +1,9 @@
+use std::sync::Arc;
+
 use arrow_array::RecordBatchReader;
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
+use datafusion_physical_plan::ExecutionPlan;
 use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
 
 use crate::{
@@ -71,6 +74,13 @@ impl TableInternal for RemoteTable {
     ) -> Result<()> {
         todo!()
     }
+    async fn create_plan(
+        &self,
+        _query: &VectorQuery,
+        _options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
     async fn plain_query(
         &self,
         _query: &Query,
@@ -78,13 +88,6 @@ impl TableInternal for RemoteTable {
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
         todo!()
     }
-    async fn vector_query(
-        &self,
-        _query: &VectorQuery,
-        _options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream> {
-        todo!()
-    }
     async fn update(&self, _update: UpdateBuilder) -> Result<()> {
         todo!()
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index 7a181a22..3bd0fc49 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -23,6 +23,7 @@ use arrow::datatypes::Float32Type;
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_physical_plan::ExecutionPlan;
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
@@ -35,6 +36,7 @@ use lance::dataset::{
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
+use lance_datafusion::exec::execute_plan;
 use lance_index::vector::hnsw::builder::HnswBuildParams;
 use lance_index::vector::ivf::IvfBuildParams;
 use lance_index::vector::pq::PQBuildParams;
@@ -366,16 +368,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
     async fn schema(&self) -> Result<SchemaRef>;
     /// Count the number of rows in this table.
     async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
+    async fn create_plan(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>>;
     async fn plain_query(
         &self,
         query: &Query,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream>;
-    async fn vector_query(
-        &self,
-        query: &VectorQuery,
-        options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream>;
     async fn add(
         &self,
         add: AddDataBuilder,
@@ -1479,79 +1481,11 @@ impl NativeTable {
         query: &VectorQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        let ds_ref = self.dataset.get().await?;
-        let mut scanner: Scanner = ds_ref.scan();
-
-        if let Some(query_vector) = query.query_vector.as_ref() {
-            // If there is a vector query, default to limit=10 if unspecified
-            let column = if let Some(col) = query.column.as_ref() {
-                col.clone()
-            } else {
-                // Infer a vector column with the same dimension of the query vector.
-                let arrow_schema = Schema::from(ds_ref.schema());
-                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
-            };
-            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
-                message: format!("Column {} not found in dataset schema", column),
-            })?;
-            if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
-                if !f.data_type().is_floating() {
-                    return Err(Error::InvalidInput {
-                        message: format!(
-                            "The data type of the vector column '{}' is not a floating point type",
-                            column
-                        ),
-                    });
-                }
-                if dim != query_vector.len() as i32 {
-                    return Err(Error::InvalidInput {
-                        message: format!(
-                            "The dimension of the query vector does not match with the dimension of the vector column '{}': \
-                        query dim={}, expected vector dim={}",
-                            column,
-                            query_vector.len(),
-                            dim,
-                        ),
-                    });
-                }
-            }
-            let query_vector = query_vector.as_primitive::<Float32Type>();
-            scanner.nearest(
-                &column,
-                query_vector,
-                query.base.limit.unwrap_or(DEFAULT_TOP_K),
-            )?;
-        } else {
-            // If there is no vector query, it's ok to not have a limit
-            scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
-        }
-        scanner.nprobs(query.nprobes);
-        scanner.use_index(query.use_index);
-        scanner.prefilter(query.prefilter);
-        scanner.batch_size(options.max_batch_length as usize);
-
-        match &query.base.select {
-            Select::Columns(select) => {
-                scanner.project(select.as_slice())?;
-            }
-            Select::Dynamic(select_with_transform) => {
-                scanner.project_with_transform(select_with_transform.as_slice())?;
-            }
-            Select::All => { /* Do nothing */ }
-        }
-
-        if let Some(filter) = &query.base.filter {
-            scanner.filter(filter)?;
-        }
-
-        if let Some(refine_factor) = query.refine_factor {
-            scanner.refine(refine_factor);
-        }
-
-        if let Some(distance_type) = query.distance_type {
-            scanner.distance_metric(distance_type.into());
-        }
-        Ok(scanner.try_into_stream().await?)
+        let plan = self.create_plan(query, options).await?;
+        Ok(DatasetRecordBatchStream::new(execute_plan(
+            plan,
+            Default::default(),
+        )?))
     }
 }
 
@@ -1703,6 +1637,86 @@ impl TableInternal for NativeTable {
         Ok(())
     }
 
+    async fn create_plan(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let ds_ref = self.dataset.get().await?;
+        let mut scanner: Scanner = ds_ref.scan();
+
+        if let Some(query_vector) = query.query_vector.as_ref() {
+            // If there is a vector query, default to limit=10 if unspecified
+            let column = if let Some(col) = query.column.as_ref() {
+                col.clone()
+            } else {
+                // Infer a vector column with the same dimension of the query vector.
+                let arrow_schema = Schema::from(ds_ref.schema());
+                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
+            };
+            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
+                message: format!("Column {} not found in dataset schema", column),
+            })?;
+            if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
+                if !f.data_type().is_floating() {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The data type of the vector column '{}' is not a floating point type",
+                            column
+                        ),
+                    });
+                }
+                if dim != query_vector.len() as i32 {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The dimension of the query vector does not match with the dimension of the vector column '{}': \
+                        query dim={}, expected vector dim={}",
+                            column,
+                            query_vector.len(),
+                            dim,
+                        ),
+                    });
+                }
+            }
+            let query_vector = query_vector.as_primitive::<Float32Type>();
+            scanner.nearest(
+                &column,
+                query_vector,
+                query.base.limit.unwrap_or(DEFAULT_TOP_K),
+            )?;
+        } else {
+            // If there is no vector query, it's ok to not have a limit
+            scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
+        }
+        scanner.nprobs(query.nprobes);
+        scanner.use_index(query.use_index);
+        scanner.prefilter(query.prefilter);
+        scanner.batch_size(options.max_batch_length as usize);
+
+        match &query.base.select {
+            Select::Columns(select) => {
+                scanner.project(select.as_slice())?;
+            }
+            Select::Dynamic(select_with_transform) => {
+                scanner.project_with_transform(select_with_transform.as_slice())?;
+            }
+            Select::All => { /* Do nothing */ }
+        }
+
+        if let Some(filter) = &query.base.filter {
+            scanner.filter(filter)?;
+        }
+
+        if let Some(refine_factor) = query.refine_factor {
+            scanner.refine(refine_factor);
+        }
+
+        if let Some(distance_type) = query.distance_type {
+            scanner.distance_metric(distance_type.into());
+        }
+        Ok(scanner.create_plan().await?)
+    }
+
     async fn plain_query(
         &self,
         query: &Query,
@@ -1712,14 +1726,6 @@ impl TableInternal for NativeTable {
             .await
     }
 
-    async fn vector_query(
-        &self,
-        query: &VectorQuery,
-        options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream> {
-        self.generic_query(query, options).await
-    }
-
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,