feat(rust): allow creating an execution plan for queries (#1350)

Lei Xu
2024-05-31 17:33:58 -07:00
committed by GitHub
parent 7c133ec416
commit 56b4fd2bd9
5 changed files with 165 additions and 96 deletions

Cargo.toml

@@ -1,5 +1,11 @@
 [workspace]
-members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
+members = [
+    "rust/ffi/node",
+    "rust/lancedb",
+    "nodejs",
+    "python",
+    "java/core/lancedb-jni"
+]
 # Python package needs to be built by maturin.
 exclude = ["python"]
 resolver = "2"
@@ -18,6 +24,7 @@ lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
 lance-index = { "version" = "=0.11.1" }
 lance-linalg = { "version" = "=0.11.1" }
 lance-testing = { "version" = "=0.11.1" }
+lance-datafusion = { "version" = "=0.11.1" }
 # Note that this one does not include pyarrow
 arrow = { version = "51.0", optional = false }
 arrow-array = "51.0"
@@ -29,6 +36,7 @@ arrow-arith = "51.0"
 arrow-cast = "51.0"
 async-trait = "0"
 chrono = "0.4.35"
+datafusion-physical-plan = "37.1"
 half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",
 ] }

rust/lancedb/Cargo.toml

@@ -19,11 +19,13 @@ arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
 chrono = { workspace = true }
+datafusion-physical-plan.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }
 half = { workspace = true }
 lazy_static.workspace = true
 lance = { workspace = true }
+lance-datafusion.workspace = true
 lance-index = { workspace = true }
 lance-linalg = { workspace = true }
 lance-testing = { workspace = true }
@@ -43,7 +45,7 @@ serde_with = { version = "3.8.1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
-polars = { version = ">=0.37,<0.40.0", optional = true}
+polars = { version = ">=0.37,<0.40.0", optional = true }

 [dev-dependencies]
 tempfile = "3.5.0"

rust/lancedb/src/query.rs

@@ -17,7 +17,10 @@ use std::sync::Arc;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
+use datafusion_physical_plan::ExecutionPlan;
 use half::f16;
 use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance_datafusion::exec::execute_plan;

 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
@@ -425,6 +428,15 @@ impl Default for QueryExecutionOptions {
 /// There are various kinds of queries but they all return results
 /// in the same way.
 pub trait ExecutableQuery {
+    /// Return the DataFusion [ExecutionPlan].
+    ///
+    /// The caller can further optimize the plan or execute it.
+    ///
+    fn create_plan(
+        &self,
+        options: QueryExecutionOptions,
+    ) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
+
     /// Execute the query with default options and return results
     ///
     /// See [`ExecutableQuery::execute_with_options`] for more details.
@@ -545,6 +557,13 @@ impl HasQuery for Query {
 }

 impl ExecutableQuery for Query {
+    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
+        self.parent
+            .clone()
+            .create_plan(&self.clone().into_vector(), options)
+            .await
+    }
+
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
@@ -718,12 +737,19 @@ impl VectorQuery {
 }

 impl ExecutableQuery for VectorQuery {
+    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
+        self.base.parent.clone().create_plan(self, options).await
+    }
+
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
         Ok(SendableRecordBatchStream::from(
-            self.base.parent.clone().vector_query(self, options).await?,
+            DatasetRecordBatchStream::new(execute_plan(
+                self.create_plan(options).await?,
+                Default::default(),
+            )?),
         ))
     }
 }
@@ -972,6 +998,30 @@ mod tests {
         }
     }

+    fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
+        if plan.name() == name {
+            return true;
+        }
+        plan.children()
+            .iter()
+            .any(|child| assert_plan_exists(child, name))
+    }
+
+    #[tokio::test]
+    async fn test_create_execute_plan() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_test_table(&tmp_dir).await;
+
+        let plan = table
+            .query()
+            .nearest_to(vec![0.1, 0.2, 0.3, 0.4])
+            .unwrap()
+            .create_plan(QueryExecutionOptions::default())
+            .await
+            .unwrap();
+
+        // assert_plan_exists only reports whether the named node occurs
+        // anywhere in the tree; assert! makes a miss actually fail the test.
+        assert!(assert_plan_exists(&plan, "KNNFlatSearch"));
+        assert!(assert_plan_exists(&plan, "ProjectionExec"));
+    }
+
     #[tokio::test]
     async fn query_base_methods_on_vector_query() {
         // Make sure VectorQuery can be used as a QueryBase
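
The new `create_plan` method is the heart of this change: callers receive the DataFusion physical plan itself, can inspect or further optimize it, and only then execute it. Below is a minimal usage sketch of that two-step flow, not code from the commit. It assumes a `lancedb::Table` value is at hand and that `ExecutableQuery`, `QueryExecutionOptions`, and `Result` are importable from `lancedb::query` and `lancedb::error` as this diff suggests; the individual calls mirror the test and the new `execute_with_options` body above.

use std::sync::Arc;

use datafusion_physical_plan::ExecutionPlan;
use lance_datafusion::exec::execute_plan;
use lancedb::error::Result;
use lancedb::query::{ExecutableQuery, QueryExecutionOptions};

// Hypothetical helper: build the plan first, inspect it, then execute it
// manually instead of going through `execute_with_options`.
async fn inspect_then_run(table: &lancedb::Table) -> Result<()> {
    // Step 1: build the physical plan without executing anything.
    let plan: Arc<dyn ExecutionPlan> = table
        .query()
        .nearest_to(vec![0.1, 0.2, 0.3, 0.4])?
        .create_plan(QueryExecutionOptions::default())
        .await?;

    // Step 2: the caller can rewrite or inspect the plan here.
    println!("root operator: {}", plan.name());

    // Step 3: run it, exactly as the new execute_with_options body does.
    let _stream = execute_plan(plan, Default::default())?;
    Ok(())
}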

rust/lancedb/src/remote/table.rs

@@ -1,6 +1,9 @@
+use std::sync::Arc;
+
 use arrow_array::RecordBatchReader;
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
+use datafusion_physical_plan::ExecutionPlan;
 use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};

 use crate::{
use crate::{
@@ -71,6 +74,13 @@ impl TableInternal for RemoteTable {
     ) -> Result<()> {
         todo!()
     }
+    async fn create_plan(
+        &self,
+        _query: &VectorQuery,
+        _options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
     async fn plain_query(
         &self,
         _query: &Query,
@@ -78,13 +88,6 @@
     ) -> Result<DatasetRecordBatchStream> {
         todo!()
     }
-    async fn vector_query(
-        &self,
-        _query: &VectorQuery,
-        _options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream> {
-        todo!()
-    }
     async fn update(&self, _update: UpdateBuilder) -> Result<()> {
         todo!()
     }

rust/lancedb/src/table.rs

@@ -23,6 +23,7 @@ use arrow::datatypes::Float32Type;
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_physical_plan::ExecutionPlan;
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
@@ -35,6 +36,7 @@ use lance::dataset::{
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
+use lance_datafusion::exec::execute_plan;
 use lance_index::vector::hnsw::builder::HnswBuildParams;
 use lance_index::vector::ivf::IvfBuildParams;
 use lance_index::vector::pq::PQBuildParams;
@@ -366,16 +368,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
     async fn schema(&self) -> Result<SchemaRef>;
     /// Count the number of rows in this table.
     async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
+    async fn create_plan(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>>;
     async fn plain_query(
         &self,
         query: &Query,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream>;
-    async fn vector_query(
-        &self,
-        query: &VectorQuery,
-        options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream>;
     async fn add(
         &self,
         add: AddDataBuilder<NoData>,
@@ -1479,79 +1481,11 @@
         query: &VectorQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        let ds_ref = self.dataset.get().await?;
-        let mut scanner: Scanner = ds_ref.scan();
-
-        if let Some(query_vector) = query.query_vector.as_ref() {
-            // If there is a vector query, default to limit=10 if unspecified
-            let column = if let Some(col) = query.column.as_ref() {
-                col.clone()
-            } else {
-                // Infer a vector column with the same dimension of the query vector.
-                let arrow_schema = Schema::from(ds_ref.schema());
-                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
-            };
-            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
-                message: format!("Column {} not found in dataset schema", column),
-            })?;
-            if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
-                if !f.data_type().is_floating() {
-                    return Err(Error::InvalidInput {
-                        message: format!(
-                            "The data type of the vector column '{}' is not a floating point type",
-                            column
-                        ),
-                    });
-                }
-                if dim != query_vector.len() as i32 {
-                    return Err(Error::InvalidInput {
-                        message: format!(
-                            "The dimension of the query vector does not match with the dimension of the vector column '{}': \
-                            query dim={}, expected vector dim={}",
-                            column,
-                            query_vector.len(),
-                            dim,
-                        ),
-                    });
-                }
-            }
-            let query_vector = query_vector.as_primitive::<Float32Type>();
-            scanner.nearest(
-                &column,
-                query_vector,
-                query.base.limit.unwrap_or(DEFAULT_TOP_K),
-            )?;
-        } else {
-            // If there is no vector query, it's ok to not have a limit
-            scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
-        }
-        scanner.nprobs(query.nprobes);
-        scanner.use_index(query.use_index);
-        scanner.prefilter(query.prefilter);
-        scanner.batch_size(options.max_batch_length as usize);
-
-        match &query.base.select {
-            Select::Columns(select) => {
-                scanner.project(select.as_slice())?;
-            }
-            Select::Dynamic(select_with_transform) => {
-                scanner.project_with_transform(select_with_transform.as_slice())?;
-            }
-            Select::All => { /* Do nothing */ }
-        }
-
-        if let Some(filter) = &query.base.filter {
-            scanner.filter(filter)?;
-        }
-
-        if let Some(refine_factor) = query.refine_factor {
-            scanner.refine(refine_factor);
-        }
-
-        if let Some(distance_type) = query.distance_type {
-            scanner.distance_metric(distance_type.into());
-        }
-
-        Ok(scanner.try_into_stream().await?)
+        let plan = self.create_plan(query, options).await?;
+
+        Ok(DatasetRecordBatchStream::new(execute_plan(
+            plan,
+            Default::default(),
+        )?))
     }
 }
@@ -1703,6 +1637,86 @@ impl TableInternal for NativeTable {
         Ok(())
     }

+    async fn create_plan(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let ds_ref = self.dataset.get().await?;
+        let mut scanner: Scanner = ds_ref.scan();
+
+        if let Some(query_vector) = query.query_vector.as_ref() {
+            // If there is a vector query, default to limit=10 if unspecified
+            let column = if let Some(col) = query.column.as_ref() {
+                col.clone()
+            } else {
+                // Infer a vector column with the same dimension of the query vector.
+                let arrow_schema = Schema::from(ds_ref.schema());
+                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
+            };
+            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
+                message: format!("Column {} not found in dataset schema", column),
+            })?;
+            if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
+                if !f.data_type().is_floating() {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The data type of the vector column '{}' is not a floating point type",
+                            column
+                        ),
+                    });
+                }
+                if dim != query_vector.len() as i32 {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The dimension of the query vector does not match with the dimension of the vector column '{}': \
+                            query dim={}, expected vector dim={}",
+                            column,
+                            query_vector.len(),
+                            dim,
+                        ),
+                    });
+                }
+            }
+            let query_vector = query_vector.as_primitive::<Float32Type>();
+            scanner.nearest(
+                &column,
+                query_vector,
+                query.base.limit.unwrap_or(DEFAULT_TOP_K),
+            )?;
+        } else {
+            // If there is no vector query, it's ok to not have a limit
+            scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
+        }
+        scanner.nprobs(query.nprobes);
+        scanner.use_index(query.use_index);
+        scanner.prefilter(query.prefilter);
+        scanner.batch_size(options.max_batch_length as usize);
+
+        match &query.base.select {
+            Select::Columns(select) => {
+                scanner.project(select.as_slice())?;
+            }
+            Select::Dynamic(select_with_transform) => {
+                scanner.project_with_transform(select_with_transform.as_slice())?;
+            }
+            Select::All => { /* Do nothing */ }
+        }
+
+        if let Some(filter) = &query.base.filter {
+            scanner.filter(filter)?;
+        }
+
+        if let Some(refine_factor) = query.refine_factor {
+            scanner.refine(refine_factor);
+        }
+
+        if let Some(distance_type) = query.distance_type {
+            scanner.distance_metric(distance_type.into());
+        }
+
+        Ok(scanner.create_plan().await?)
+    }
+
     async fn plain_query(
         &self,
         query: &Query,
@@ -1712,14 +1726,6 @@
             .await
     }

-    async fn vector_query(
-        &self,
-        query: &VectorQuery,
-        options: QueryExecutionOptions,
-    ) -> Result<DatasetRecordBatchStream> {
-        self.generic_query(query, options).await
-    }
-
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,
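
Since `NativeTable::create_plan` now stops at `scanner.create_plan()` instead of opening a stream, the operator tree can also be rendered before execution, e.g. to confirm that the ANN node (`KNNFlatSearch` in the test above) was chosen rather than a flat scan. A small sketch, assuming DataFusion's standard `displayable` formatter is exported from the `datafusion-physical-plan` crate pinned at "37.1" above (treat the exact import path as an assumption):

use std::sync::Arc;

use datafusion_physical_plan::{displayable, ExecutionPlan};

// Print an indented operator tree; `false` skips per-operator metrics.
fn dump_plan(plan: &Arc<dyn ExecutionPlan>) {
    println!("{}", displayable(plan.as_ref()).indent(false));
}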