mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-13 23:32:57 +00:00
feat(rust): allow to create execution plan on queries (#1350)
This commit is contained in:
10
Cargo.toml
10
Cargo.toml
@@ -1,5 +1,11 @@
|
||||
[workspace]
|
||||
members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
|
||||
members = [
|
||||
"rust/ffi/node",
|
||||
"rust/lancedb",
|
||||
"nodejs",
|
||||
"python",
|
||||
"java/core/lancedb-jni"
|
||||
]
|
||||
# Python package needs to be built by maturin.
|
||||
exclude = ["python"]
|
||||
resolver = "2"
|
||||
@@ -18,6 +24,7 @@ lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.11.1" }
|
||||
lance-linalg = { "version" = "=0.11.1" }
|
||||
lance-testing = { "version" = "=0.11.1" }
|
||||
lance-datafusion = { "version" = "=0.11.1" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "51.0", optional = false }
|
||||
arrow-array = "51.0"
|
||||
@@ -29,6 +36,7 @@ arrow-arith = "51.0"
|
||||
arrow-cast = "51.0"
|
||||
async-trait = "0"
|
||||
chrono = "0.4.35"
|
||||
datafusion-physical-plan = "37.1"
|
||||
half = { "version" = "=2.4.1", default-features = false, features = [
|
||||
"num-traits",
|
||||
] }
|
||||
|
||||
@@ -19,11 +19,13 @@ arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
arrow-ipc.workspace = true
|
||||
chrono = { workspace = true }
|
||||
datafusion-physical-plan.workspace = true
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lazy_static.workspace = true
|
||||
lance = { workspace = true }
|
||||
lance-datafusion.workspace = true
|
||||
lance-index = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
lance-testing = { workspace = true }
|
||||
@@ -43,7 +45,7 @@ serde_with = { version = "3.8.1" }
|
||||
# For remote feature
|
||||
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true}
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.5.0"
|
||||
|
||||
@@ -17,7 +17,10 @@ use std::sync::Arc;
|
||||
|
||||
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
|
||||
use arrow_schema::DataType;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use half::f16;
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance_datafusion::exec::execute_plan;
|
||||
|
||||
use crate::arrow::SendableRecordBatchStream;
|
||||
use crate::error::{Error, Result};
|
||||
@@ -425,6 +428,15 @@ impl Default for QueryExecutionOptions {
|
||||
/// There are various kinds of queries but they all return results
|
||||
/// in the same way.
|
||||
pub trait ExecutableQuery {
|
||||
/// Return the DataFusion [`ExecutionPlan`] for this query.
|
||||
///
|
||||
/// The caller can further optimize the plan or execute it.
|
||||
///
/// # Errors
///
/// The returned future resolves to an error if the plan cannot be created.
|
||||
fn create_plan(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
|
||||
|
||||
/// Execute the query with default options and return results
|
||||
///
|
||||
/// See [`ExecutableQuery::execute_with_options`] for more details.
|
||||
@@ -545,6 +557,13 @@ impl HasQuery for Query {
|
||||
}
|
||||
|
||||
impl ExecutableQuery for Query {
|
||||
/// Build the execution plan for this query.
///
/// A copy of the query is converted with `into_vector` and handed to the
/// parent table's `create_plan`, together with `options`.
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
    let table = self.parent.clone();
    let as_vector_query = self.clone().into_vector();
    table.create_plan(&as_vector_query, options).await
}
|
||||
|
||||
async fn execute_with_options(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
@@ -718,12 +737,19 @@ impl VectorQuery {
|
||||
}
|
||||
|
||||
impl ExecutableQuery for VectorQuery {
|
||||
/// Build the execution plan for this vector query by delegating to the
/// parent table of the underlying base query.
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
    let table = self.base.parent.clone();
    table.create_plan(self, options).await
}
|
||||
|
||||
async fn execute_with_options(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
self.base.parent.clone().vector_query(self, options).await?,
|
||||
DatasetRecordBatchStream::new(execute_plan(
|
||||
self.create_plan(options).await?,
|
||||
Default::default(),
|
||||
)?),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -972,6 +998,30 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
|
||||
if plan.name() == name {
|
||||
return true;
|
||||
}
|
||||
plan.children()
|
||||
.iter()
|
||||
.any(|child| assert_plan_exists(child, name))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_execute_plan() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let plan = table
|
||||
.query()
|
||||
.nearest_to(vec![0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.create_plan(QueryExecutionOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_plan_exists(&plan, "KNNFlatSearch");
|
||||
assert_plan_exists(&plan, "ProjectionExec");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn query_base_methods_on_vector_query() {
|
||||
// Make sure VectorQuery can be used as a QueryBase
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::SchemaRef;
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
|
||||
|
||||
use crate::{
|
||||
@@ -71,6 +74,13 @@ impl TableInternal for RemoteTable {
|
||||
) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
/// Plan creation is not implemented for `RemoteTable`.
///
/// # Panics
///
/// Always panics via `unimplemented!()`; both arguments are ignored.
async fn create_plan(
|
||||
&self,
|
||||
_query: &VectorQuery,
|
||||
_options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn plain_query(
|
||||
&self,
|
||||
_query: &Query,
|
||||
@@ -78,13 +88,6 @@ impl TableInternal for RemoteTable {
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
todo!()
|
||||
}
|
||||
async fn vector_query(
|
||||
&self,
|
||||
_query: &VectorQuery,
|
||||
_options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
todo!()
|
||||
}
|
||||
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use arrow::datatypes::Float32Type;
|
||||
use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use lance::dataset::builder::DatasetBuilder;
|
||||
use lance::dataset::cleanup::RemovalStats;
|
||||
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
|
||||
@@ -35,6 +36,7 @@ use lance::dataset::{
|
||||
};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::io::WrappingObjectStore;
|
||||
use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::vector::hnsw::builder::HnswBuildParams;
|
||||
use lance_index::vector::ivf::IvfBuildParams;
|
||||
use lance_index::vector::pq::PQBuildParams;
|
||||
@@ -366,16 +368,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
/// Count the number of rows in this table.
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
|
||||
/// Create the DataFusion execution plan for `query` using the given
/// execution `options`.
async fn create_plan(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>>;
|
||||
async fn plain_query(
|
||||
&self,
|
||||
query: &Query,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream>;
|
||||
async fn vector_query(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream>;
|
||||
async fn add(
|
||||
&self,
|
||||
add: AddDataBuilder<NoData>,
|
||||
@@ -1479,79 +1481,11 @@ impl NativeTable {
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
let mut scanner: Scanner = ds_ref.scan();
|
||||
|
||||
if let Some(query_vector) = query.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
let column = if let Some(col) = query.column.as_ref() {
|
||||
col.clone()
|
||||
} else {
|
||||
// Infer a vector column with the same dimension of the query vector.
|
||||
let arrow_schema = Schema::from(ds_ref.schema());
|
||||
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
|
||||
};
|
||||
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found in dataset schema", column),
|
||||
})?;
|
||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
||||
if !f.data_type().is_floating() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The data type of the vector column '{}' is not a floating point type",
|
||||
column
|
||||
),
|
||||
});
|
||||
}
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
dim,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
|
||||
}
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.prefilter);
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
|
||||
match &query.base.select {
|
||||
Select::Columns(select) => {
|
||||
scanner.project(select.as_slice())?;
|
||||
}
|
||||
Select::Dynamic(select_with_transform) => {
|
||||
scanner.project_with_transform(select_with_transform.as_slice())?;
|
||||
}
|
||||
Select::All => { /* Do nothing */ }
|
||||
}
|
||||
|
||||
if let Some(filter) = &query.base.filter {
|
||||
scanner.filter(filter)?;
|
||||
}
|
||||
|
||||
if let Some(refine_factor) = query.refine_factor {
|
||||
scanner.refine(refine_factor);
|
||||
}
|
||||
|
||||
if let Some(distance_type) = query.distance_type {
|
||||
scanner.distance_metric(distance_type.into());
|
||||
}
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
let plan = self.create_plan(query, options).await?;
|
||||
Ok(DatasetRecordBatchStream::new(execute_plan(
|
||||
plan,
|
||||
Default::default(),
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1703,6 +1637,86 @@ impl TableInternal for NativeTable {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
let mut scanner: Scanner = ds_ref.scan();
|
||||
|
||||
if let Some(query_vector) = query.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
let column = if let Some(col) = query.column.as_ref() {
|
||||
col.clone()
|
||||
} else {
|
||||
// Infer a vector column with the same dimension of the query vector.
|
||||
let arrow_schema = Schema::from(ds_ref.schema());
|
||||
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
|
||||
};
|
||||
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found in dataset schema", column),
|
||||
})?;
|
||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
||||
if !f.data_type().is_floating() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The data type of the vector column '{}' is not a floating point type",
|
||||
column
|
||||
),
|
||||
});
|
||||
}
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
dim,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
|
||||
}
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.prefilter);
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
|
||||
match &query.base.select {
|
||||
Select::Columns(select) => {
|
||||
scanner.project(select.as_slice())?;
|
||||
}
|
||||
Select::Dynamic(select_with_transform) => {
|
||||
scanner.project_with_transform(select_with_transform.as_slice())?;
|
||||
}
|
||||
Select::All => { /* Do nothing */ }
|
||||
}
|
||||
|
||||
if let Some(filter) = &query.base.filter {
|
||||
scanner.filter(filter)?;
|
||||
}
|
||||
|
||||
if let Some(refine_factor) = query.refine_factor {
|
||||
scanner.refine(refine_factor);
|
||||
}
|
||||
|
||||
if let Some(distance_type) = query.distance_type {
|
||||
scanner.distance_metric(distance_type.into());
|
||||
}
|
||||
Ok(scanner.create_plan().await?)
|
||||
}
|
||||
|
||||
async fn plain_query(
|
||||
&self,
|
||||
query: &Query,
|
||||
@@ -1712,14 +1726,6 @@ impl TableInternal for NativeTable {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn vector_query(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
self.generic_query(query, options).await
|
||||
}
|
||||
|
||||
async fn merge_insert(
|
||||
&self,
|
||||
params: MergeInsertBuilder,
|
||||
|
||||
Reference in New Issue
Block a user