From 098e397cf0c06e760087b9eb0bcb359da122cdec Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 13 Dec 2023 22:59:01 -0800 Subject: [PATCH] feat: LocalTable for vectordb now supports filters without vector search (#693) Note that filter/where is currently only implemented for LocalTable, so it requires an explicit cast to "enable" (see new unit test). The alternative is to add it to the Table interface, but since it's not available on RemoteTable this may cause some user experience issues. --- node/src/index.ts | 10 +++++ node/src/query.ts | 18 +++++---- node/src/test/test.ts | 23 +++++++++++- rust/ffi/node/src/query.rs | 23 ++++++++---- rust/vectordb/src/io/object_store.rs | 2 +- rust/vectordb/src/query.rs | 55 ++++++++++++++++++++-------- rust/vectordb/src/table.rs | 10 +++-- 7 files changed, 104 insertions(+), 37 deletions(-) diff --git a/node/src/index.ts b/node/src/index.ts index 195b079c..def416d4 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -488,6 +488,16 @@ export class LocalTable implements Table { return new Query(query, this._tbl, this._embeddings) } + /** + * Creates a filter query to find all rows matching the specified criteria + * @param value The filter criteria (like SQL where clause syntax) + */ + filter (value: string): Query { + return new Query(undefined, this._tbl, this._embeddings).filter(value) + } + + where = this.filter + /** * Insert records into this Table. * diff --git a/node/src/query.ts b/node/src/query.ts index c1156c38..932adddb 100644 --- a/node/src/query.ts +++ b/node/src/query.ts @@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js') * A builder for nearest neighbor queries for LanceDB. 
*/ export class Query { - private readonly _query: T + private readonly _query?: T private readonly _tbl?: any private _queryVector?: number[] - private _limit: number + private _limit?: number private _refineFactor?: number private _nprobes: number private _select?: string[] @@ -35,10 +35,10 @@ export class Query { private _prefilter: boolean protected readonly _embeddings?: EmbeddingFunction - constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction) { + constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction) { this._tbl = tbl this._query = query - this._limit = 10 + this._limit = undefined this._nprobes = 20 this._refineFactor = undefined this._select = undefined @@ -113,10 +113,12 @@ export class Query { * Execute the query and return the results as an Array of Objects */ async execute> (): Promise { - if (this._embeddings !== undefined) { - this._queryVector = (await this._embeddings.embed([this._query]))[0] - } else { - this._queryVector = this._query as number[] + if (this._query !== undefined) { + if (this._embeddings !== undefined) { + this._queryVector = (await this._embeddings.embed([this._query]))[0] + } else { + this._queryVector = this._query as number[] + } } const isElectron = this.isElectron() diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 3e89f90d..89a44c6c 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -78,12 +78,31 @@ describe('LanceDB client', function () { }) it('limits # of results', async function () { - const uri = await createTestDB() + const uri = await createTestDB(2, 100) const con = await lancedb.connect(uri) const table = await con.openTable('vectors') - const results = await table.search([0.1, 0.3]).limit(1).execute() + let results = await table.search([0.1, 0.3]).limit(1).execute() assert.equal(results.length, 1) assert.equal(results[0].id, 1) + + // there is a default limit if unspecified + results = await table.search([0.1, 0.3]).execute() + assert.equal(results.length, 10) + 
}) + + it('uses a filter / where clause without vector search', async function () { + // eslint-disable-next-line @typescript-eslint/explicit-function-return-type + const assertResults = (results: Array>) => { + assert.equal(results.length, 50) + } + + const uri = await createTestDB(2, 100) + const con = await lancedb.connect(uri) + const table = (await con.openTable('vectors')) as LocalTable + let results = await table.filter('id % 2 = 0').execute() + assertResults(results) + results = await table.where('id % 2 = 0').execute() + assertResults(results) }) it('uses a filter / where clause', async function () { diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index cd2f18a7..1c8e2327 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -23,8 +23,14 @@ impl JsQuery { let query_obj = cx.argument::(0)?; let limit = query_obj - .get::(&mut cx, "_limit")? - .value(&mut cx); + .get_opt::(&mut cx, "_limit")? + .map(|value| { + let limit = value.value(&mut cx) as u64; + if limit <= 0 { + panic!("Limit must be a positive integer"); + } + limit + }); let select = query_obj .get_opt::(&mut cx, "_select")? 
.map(|arr| { @@ -61,20 +67,23 @@ impl JsQuery { let (deferred, promise) = cx.promise(); let channel = cx.channel(); - let query_vector = query_obj.get::(&mut cx, "_queryVector")?; - let query = convert::js_array_to_vec(query_vector.deref(), &mut cx); + let query_vector = query_obj.get_opt::(&mut cx, "_queryVector")?; let table = js_table.table.clone(); + let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)); rt.spawn(async move { - let builder = table - .search(Float32Array::from(query)) - .limit(limit as usize) + let mut builder = table + .search(query.map(|q| Float32Array::from(q))) .refine_factor(refine_factor) .nprobes(nprobes) .filter(filter) .metric_type(metric_type) .select(select) .prefilter(prefilter); + if let Some(limit) = limit { + builder = builder.limit(limit as usize); + }; + let record_batch_stream = builder.execute(); let results = record_batch_stream .and_then(|stream| { diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs index 9643091a..01c85d87 100644 --- a/rust/vectordb/src/io/object_store.rs +++ b/rust/vectordb/src/io/object_store.rs @@ -359,7 +359,7 @@ mod test { assert_eq!(t.count_rows().await.unwrap(), 100); let q = t - .search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])) + .search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))) .limit(10) .execute() .await diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index e6580936..f28a96c2 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -24,8 +24,8 @@ use crate::error::Result; /// A builder for nearest neighbor queries for LanceDB. pub struct Query { pub dataset: Arc, - pub query_vector: Float32Array, - pub limit: usize, + pub query_vector: Option, + pub limit: Option, pub filter: Option, pub select: Option>, pub nprobes: usize, @@ -46,11 +46,11 @@ impl Query { /// # Returns /// /// * A [Query] object. 
- pub(crate) fn new(dataset: Arc, vector: Float32Array) -> Self { + pub(crate) fn new(dataset: Arc, vector: Option) -> Self { Query { dataset, query_vector: vector, - limit: 10, + limit: None, nprobes: 20, refine_factor: None, metric_type: None, @@ -69,11 +69,13 @@ impl Query { pub async fn execute(&self) -> Result { let mut scanner: Scanner = self.dataset.scan(); - scanner.nearest( - crate::table::VECTOR_COLUMN_NAME, - &self.query_vector, - self.limit, - )?; + if let Some(query) = self.query_vector.as_ref() { + // If there is a vector query, default to limit=10 if unspecified + scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?; + } else { + // If there is no vector query, it's ok to not have a limit + scanner.limit(self.limit.map(|limit| limit as i64), None)?; + } scanner.nprobs(self.nprobes); scanner.use_index(self.use_index); scanner.prefilter(self.prefilter); @@ -91,7 +93,7 @@ impl Query { /// /// * `limit` - The maximum number of results to return. pub fn limit(mut self, limit: usize) -> Query { - self.limit = limit; + self.limit = Some(limit); self } @@ -101,7 +103,7 @@ impl Query { /// /// * `vector` - The vector that will be used for search. 
pub fn query_vector(mut self, query_vector: Float32Array) -> Query { - self.query_vector = query_vector; + self.query_vector = Some(query_vector); self } @@ -174,7 +176,7 @@ mod tests { use std::sync::Arc; use super::*; - use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; + use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::StreamExt; use lance::dataset::Dataset; @@ -187,7 +189,7 @@ mod tests { let batches = make_test_batches(); let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); - let vector = Float32Array::from_iter_values([0.1, 0.2]); + let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); let query = Query::new(Arc::new(ds), vector.clone()); assert_eq!(query.query_vector, vector); @@ -201,8 +203,8 @@ mod tests { .metric_type(Some(MetricType::Cosine)) .refine_factor(Some(999)); - assert_eq!(query.query_vector, new_vector); - assert_eq!(query.limit, 100); + assert_eq!(query.query_vector.unwrap(), new_vector); + assert_eq!(query.limit.unwrap(), 100); assert_eq!(query.nprobes, 1000); assert_eq!(query.use_index, true); assert_eq!(query.metric_type, Some(MetricType::Cosine)); @@ -214,7 +216,7 @@ mod tests { let batches = make_non_empty_batches(); let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); - let vector = Float32Array::from_iter_values([0.1; 4]); + let vector = Some(Float32Array::from_iter_values([0.1; 4])); let query = Query::new(ds.clone(), vector.clone()); let result = query @@ -244,6 +246,27 @@ mod tests { } } + #[tokio::test] + async fn test_execute_no_vector() { + // test that it's ok to not specify a query vector (just filter / limit) + let batches = make_non_empty_batches(); + let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); + + let query = Query::new(ds.clone(), None); + let result = 
query + .filter(Some("id % 2 == 0".to_string())) + .execute() + .await; + let mut stream = result.expect("should have result"); + // should only have one batch + while let Some(batch) = stream.next().await { + let b = batch.expect("should be Ok"); + // cast arr into Int32Array + let arr: &Int32Array = b["id"].as_primitive(); + assert!(arr.iter().all(|x| x.unwrap() % 2 == 0)); + } + } + fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { let vec = Box::new(RandomVector::new().named("vector".to_string())); let id = Box::new(IncrementingInt32::new().named("id".to_string())); diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index be2aa5f3..9e216559 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -308,10 +308,14 @@ impl Table { /// # Returns /// /// * A [Query] object. - pub fn search(&self, query_vector: Float32Array) -> Query { + pub fn search(&self, query_vector: Option) -> Query { Query::new(self.dataset.clone(), query_vector) } + pub fn filter(&self, expr: String) -> Query { + Query::new(self.dataset.clone(), None).filter(Some(expr)) + } + /// Returns the number of rows in this Table pub async fn count_rows(&self) -> Result { Ok(self.dataset.count_rows().await?) @@ -844,8 +848,8 @@ mod tests { let table = Table::open(uri).await.unwrap(); let vector = Float32Array::from_iter_values([0.1, 0.2]); - let query = table.search(vector.clone()); - assert_eq!(vector, query.query_vector); + let query = table.search(Some(vector.clone())); + assert_eq!(vector, query.query_vector.unwrap()); } #[derive(Default, Debug)]