feat: refactor the query API and add query support to the python async API (#1113)

In addition, there are a number of changes in nodejs to the
docstrings of existing methods because this PR adds a jsdoc linter.
This commit is contained in:
Weston Pace
2024-03-18 12:36:49 -07:00
parent 2db257ca29
commit 4180b44472
38 changed files with 2609 additions and 754 deletions

View File

@@ -3,6 +3,7 @@ use std::ops::Deref;
use futures::{TryFutureExt, TryStreamExt};
use lance_linalg::distance::MetricType;
use lancedb::query::{ExecutableQuery, QueryBase, Select};
use neon::context::FunctionContext;
use neon::handle::Handle;
use neon::prelude::*;
@@ -56,53 +57,72 @@ impl JsQuery {
let channel = cx.channel();
let table = js_table.table.clone();
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
let mut builder = table.query();
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
builder = builder.nearest_to(&query);
if let Some(metric_type) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap())
{
builder = builder.metric_type(metric_type);
}
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
builder = builder.nprobes(nprobes);
};
if let Some(filter) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx))
{
builder = builder.filter(filter);
builder = builder.only_if(filter);
}
if let Some(select) = select {
builder = builder.select(select.as_slice());
builder = builder.select(Select::columns(select.as_slice()));
}
if let Some(limit) = limit {
builder = builder.limit(limit as usize);
};
builder = builder.prefilter(prefilter);
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
let mut vector_builder = builder.nearest_to(query).unwrap();
if let Some(metric_type) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap())
{
vector_builder = vector_builder.distance_type(metric_type);
}
rt.spawn(async move {
let record_batch_stream = builder.execute_stream();
let results = record_batch_stream
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
vector_builder = vector_builder.nprobes(nprobes);
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
if !prefilter {
vector_builder = vector_builder.postfilter();
}
rt.spawn(async move {
let results = vector_builder
.execute()
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
});
});
});
} else {
rt.spawn(async move {
let results = builder
.execute()
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(lancedb::error::Error::from)
})
.await;
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_throw(&mut cx)?;
let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
convert::new_js_buffer(buffer, &mut cx, is_electron)
});
});
};
Ok(promise)
}
}

View File

@@ -21,6 +21,7 @@ use futures::TryStreamExt;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{connect, Result, Table as LanceDbTable};
#[tokio::main]
@@ -150,9 +151,10 @@ async fn create_index(table: &LanceDbTable) -> Result<()> {
async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
// --8<-- [start:search]
table
.search(&[1.0; 128])
.query()
.limit(2)
.execute_stream()
.nearest_to(&[1.0; 128])?
.execute()
.await?
.try_collect::<Vec<_>>()
.await

View File

@@ -342,7 +342,11 @@ mod test {
use object_store::local::LocalFileSystem;
use tempfile;
use crate::{connect, table::WriteOptions};
use crate::{
connect,
query::{ExecutableQuery, QueryBase},
table::WriteOptions,
};
#[tokio::test]
async fn test_e2e() {
@@ -381,9 +385,11 @@ mod test {
assert_eq!(t.count_rows(None).await.unwrap(), 100);
let q = t
.search(&[0.1, 0.1, 0.1, 0.1])
.query()
.limit(10)
.execute_stream()
.nearest_to(&[0.1, 0.1, 0.1, 0.1])
.unwrap()
.execute()
.await
.unwrap();

View File

@@ -150,6 +150,7 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use lancedb::query::{ExecutableQuery, QueryBase};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
@@ -170,8 +171,10 @@
//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
//! # let table = db.open_table("my_table").execute().await.unwrap();
//! let results = table
//! .search(&[1.0; 128])
//! .execute_stream()
//! .query()
//! .nearest_to(&[1.0; 128])
//! .unwrap()
//! .execute()
//! .await
//! .unwrap()
//! .try_collect::<Vec<_>>()

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,7 @@ use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewCol
use crate::{
error::Result,
index::{IndexBuilder, IndexConfig},
query::Query,
query::{Query, QueryExecutionOptions, VectorQuery},
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableInternal, UpdateBuilder,
@@ -66,7 +66,18 @@ impl TableInternal for RemoteTable {
// Adding data is not implemented for `RemoteTable` yet; calling this panics
// through `todo!()`. The builder argument is unused, hence the `_` prefix.
async fn add(&self, _add: AddDataBuilder) -> Result<()> {
todo!()
}
async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
// Plain (non-vector) queries are not implemented for `RemoteTable` yet;
// calling this panics through `todo!()`. Both arguments are unused for now.
async fn plain_query(
&self,
_query: &Query,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
// Vector queries are not implemented for `RemoteTable` yet; calling this
// panics through `todo!()`. Both arguments are unused for now.
async fn vector_query(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {

View File

@@ -17,6 +17,8 @@
use std::path::Path;
use std::sync::Arc;
use arrow::array::AsArray;
use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
@@ -47,7 +49,9 @@ use crate::index::{
vector::{suggested_num_partitions, suggested_num_sub_vectors},
Index, IndexBuilder,
};
use crate::query::{Query, Select, DEFAULT_TOP_K};
use crate::query::{
Query, QueryExecutionOptions, Select, ToQueryVector, VectorQuery, DEFAULT_TOP_K,
};
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
use self::dataset::DatasetConsistencyWrapper;
@@ -230,7 +234,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn add(&self, add: AddDataBuilder) -> Result<()>;
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream>;
/// Execute a [`Query`] (one built without `nearest_to`) and return the
/// result as a [`DatasetRecordBatchStream`].
///
/// `options` carries per-execution settings; the native implementation uses
/// its `max_batch_length` as the scanner batch size.
async fn plain_query(
&self,
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
/// Execute a [`VectorQuery`] and return the result as a
/// [`DatasetRecordBatchStream`].
///
/// `options` carries per-execution settings; the native implementation uses
/// its `max_batch_length` as the scanner batch size.
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn delete(&self, predicate: &str) -> Result<()>;
async fn update(&self, update: UpdateBuilder) -> Result<()>;
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
@@ -528,21 +541,30 @@ impl Table {
)
}
/// Search the table with a given query vector.
/// Create a [`Query`] Builder.
///
/// This is a convenience method for preparing an ANN query.
pub fn search(&self, query: &[f32]) -> Query {
self.query().nearest_to(query)
}
/// Create a generic [`Query`] Builder.
/// Queries allow you to search your existing data. By default the query will
/// return all the data in the table in no particular order. The builder
/// returned by this method can be used to control the query using filtering,
/// vector similarity, sorting, and more.
///
/// When appropriate, various indices and statistics based pruning will be used to
/// accelerate the query.
/// Note: By default, all columns are returned. For best performance, you should
/// only fetch the columns you need. See [`Query::select_with_projection`] for
/// more details.
///
/// When appropriate, various indices and statistics will be used to accelerate
/// the query.
///
/// # Examples
///
/// ## Run a vector search (ANN) query.
/// ## Vector search
///
/// This example will find the 10 rows whose values in the "vector" column are
/// closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
/// on the "vector" column then this will perform an ANN search.
///
/// The [`Query::refine_factor`] and [`Query::nprobes`] methods are used to
/// control the recall / latency tradeoff of the search.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
@@ -551,19 +573,25 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// use crate::lancedb::query::ExecutableQuery;
/// let stream = tbl
/// .query()
/// .nearest_to(&[1.0, 2.0, 3.0])
/// .unwrap()
/// .refine_factor(5)
/// .nprobes(10)
/// .execute_stream()
/// .execute()
/// .await
/// .unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
///
/// ## Run a SQL-style filter
/// ## SQL-style filter
///
/// This query will return up to 1000 rows whose value in the `id` column
/// is greater than 5. LanceDB supports a broad set of filtering functions.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
/// # use futures::TryStreamExt;
@@ -571,18 +599,23 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// use crate::lancedb::query::{ExecutableQuery, QueryBase};
/// let stream = tbl
/// .query()
/// .filter("id > 5")
/// .only_if("id > 5")
/// .limit(1000)
/// .execute_stream()
/// .execute()
/// .await
/// .unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
///
/// ## Run a full scan query.
/// ## Full scan
///
/// This query will return everything in the table in no particular
/// order.
///
/// ```no_run
/// # use arrow_array::RecordBatch;
/// # use futures::TryStreamExt;
@@ -590,7 +623,8 @@ impl Table {
/// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
/// # let tbl = conn.open_table("tbl").execute().await.unwrap();
/// use crate::lancedb::Table;
/// let stream = tbl.query().execute_stream().await.unwrap();
/// use crate::lancedb::query::ExecutableQuery;
/// let stream = tbl.query().execute().await.unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
@@ -598,6 +632,15 @@ impl Table {
Query::new(self.inner.clone())
}
/// Start a vector search on this table using `query` as the query vector.
///
/// Shorthand for `self.query().nearest_to(query)`: a default query builder
/// is created and immediately turned into a [`VectorQuery`]. Whatever
/// `nearest_to` returns, including its error, is passed back unchanged.
/// See [`Query::nearest_to`] for more details.
pub fn vector_search(&self, query: impl ToQueryVector) -> Result<VectorQuery> {
    let builder = self.query();
    builder.nearest_to(query)
}
/// Optimize the on-disk data and indices for better performance.
///
/// <section class="warning">Experimental API</section>
@@ -1107,6 +1150,75 @@ impl NativeTable {
.await?;
Ok(())
}
async fn generic_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
{
return Err(Error::Schema {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query_vector.len(),
),
});
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type);
}
Ok(scanner.try_into_stream().await?)
}
}
#[async_trait::async_trait]
@@ -1232,63 +1344,21 @@ impl TableInternal for NativeTable {
Ok(())
}
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
/// Run a plain [`Query`] against the native table.
///
/// The query is cloned, converted with `into_vector()`, and then executed
/// through `generic_query`, the same path `vector_query` uses.
async fn plain_query(
    &self,
    query: &Query,
    options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
    let as_vector_query = query.clone().into_vector();
    self.generic_query(&as_vector_query, options).await
}
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
{
return Err(Error::Schema {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query_vector.len(),
),
});
}
scanner.nearest(&column, query_vector, query.limit.unwrap_or(DEFAULT_TOP_K))?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
match &query.select {
Select::Simple(select) => {
scanner.project(select.as_slice())?;
}
Select::Projection(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(metric_type) = query.metric_type {
scanner.distance_metric(metric_type);
}
Ok(scanner.try_into_stream().await?)
// Thin wrapper: a `VectorQuery` is already in the form `generic_query`
// expects, so it is forwarded unchanged together with the execution options.
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(query, options).await
}
async fn merge_insert(
@@ -1450,6 +1520,7 @@ mod tests {
use crate::connect;
use crate::connection::ConnectBuilder;
use crate::index::scalar::BTreeIndexBuilder;
use crate::query::{ExecutableQuery, QueryBase};
use super::*;
@@ -1689,8 +1760,8 @@ mod tests {
let mut batches = table
.query()
.select(&["id", "name"])
.execute_stream()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
@@ -1841,7 +1912,7 @@ mod tests {
let mut batches = table
.query()
.select(&[
.select(Select::columns(&[
"string",
"large_string",
"int32",
@@ -1855,8 +1926,8 @@ mod tests {
"timestamp_ms",
"vec_f32",
"vec_f64",
])
.execute_stream()
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()