diff --git a/node/src/index.ts b/node/src/index.ts index 195b079c..def416d4 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -488,6 +488,16 @@ export class LocalTable implements Table { return new Query(query, this._tbl, this._embeddings) } + /** + * Creates a filter query to find all rows matching the specified criteria + * @param value The filter criteria (like SQL where clause syntax) + */ + filter (value: string): Query { + return new Query(undefined, this._tbl, this._embeddings).filter(value) + } + + where = this.filter + /** * Insert records into this Table. * diff --git a/node/src/query.ts b/node/src/query.ts index c1156c38..932adddb 100644 --- a/node/src/query.ts +++ b/node/src/query.ts @@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js') * A builder for nearest neighbor queries for LanceDB. */ export class Query { - private readonly _query: T + private readonly _query?: T private readonly _tbl?: any private _queryVector?: number[] - private _limit: number + private _limit?: number private _refineFactor?: number private _nprobes: number private _select?: string[] @@ -35,10 +35,10 @@ export class Query { private _prefilter: boolean protected readonly _embeddings?: EmbeddingFunction - constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction) { + constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction) { this._tbl = tbl this._query = query - this._limit = 10 + this._limit = undefined this._nprobes = 20 this._refineFactor = undefined this._select = undefined @@ -113,10 +113,12 @@ export class Query { * Execute the query and return the results as an Array of Objects */ async execute> (): Promise { - if (this._embeddings !== undefined) { - this._queryVector = (await this._embeddings.embed([this._query]))[0] - } else { - this._queryVector = this._query as number[] + if (this._query !== undefined) { + if (this._embeddings !== undefined) { + this._queryVector = (await this._embeddings.embed([this._query]))[0] + } else { + this._queryVector = this._query as number[] + } } const isElectron = this.isElectron() diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 3e89f90d..89a44c6c 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -78,12 +78,31 @@ describe('LanceDB client', function () { }) it('limits # of results', async function () { - const uri = await createTestDB() + const uri = await createTestDB(2, 100) const con = await lancedb.connect(uri) const table = await con.openTable('vectors') - const results = await table.search([0.1, 0.3]).limit(1).execute() + let results = await table.search([0.1, 0.3]).limit(1).execute() assert.equal(results.length, 1) assert.equal(results[0].id, 1) + + // there is a default limit if unspecified + results = await table.search([0.1, 0.3]).execute() + assert.equal(results.length, 10) + }) + + it('uses a filter / where clause without vector search', async function () { + // eslint-disable-next-line @typescript-eslint/explicit-function-return-type + const assertResults = (results: Array>) => { + assert.equal(results.length, 50) + } + + const uri = await createTestDB(2, 100) + const con = await lancedb.connect(uri) + const table = (await con.openTable('vectors')) as LocalTable + let results = await table.filter('id % 2 = 0').execute() + assertResults(results) + results = await table.where('id % 2 = 0').execute() + assertResults(results) }) it('uses a filter / where clause', async function () { diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index cd2f18a7..1c8e2327 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -23,8 +23,14 @@ impl JsQuery { let query_obj = cx.argument::(0)?; let limit = query_obj - .get::(&mut cx, "_limit")? - .value(&mut cx); + .get_opt::(&mut cx, "_limit")? + .map(|value| { + let limit = value.value(&mut cx) as u64; + if limit <= 0 { + panic!("Limit must be a positive integer"); + } + limit + }); let select = query_obj .get_opt::(&mut cx, "_select")? .map(|arr| { @@ -61,20 +67,23 @@ impl JsQuery { let (deferred, promise) = cx.promise(); let channel = cx.channel(); - let query_vector = query_obj.get::(&mut cx, "_queryVector")?; - let query = convert::js_array_to_vec(query_vector.deref(), &mut cx); + let query_vector = query_obj.get_opt::(&mut cx, "_queryVector")?; let table = js_table.table.clone(); + let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)); rt.spawn(async move { - let builder = table - .search(Float32Array::from(query)) - .limit(limit as usize) + let mut builder = table + .search(query.map(|q| Float32Array::from(q))) .refine_factor(refine_factor) .nprobes(nprobes) .filter(filter) .metric_type(metric_type) .select(select) .prefilter(prefilter); + if let Some(limit) = limit { + builder = builder.limit(limit as usize); + }; + let record_batch_stream = builder.execute(); let results = record_batch_stream .and_then(|stream| { diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs index 9643091a..01c85d87 100644 --- a/rust/vectordb/src/io/object_store.rs +++ b/rust/vectordb/src/io/object_store.rs @@ -359,7 +359,7 @@ mod test { assert_eq!(t.count_rows().await.unwrap(), 100); let q = t - .search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])) + .search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))) .limit(10) .execute() .await diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index e6580936..f28a96c2 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -24,8 +24,8 @@ use crate::error::Result; /// A builder for nearest neighbor queries for LanceDB. pub struct Query { pub dataset: Arc, - pub query_vector: Float32Array, - pub limit: usize, + pub query_vector: Option, + pub limit: Option, pub filter: Option, pub select: Option>, pub nprobes: usize, @@ -46,11 +46,11 @@ impl Query { /// # Returns /// /// * A [Query] object. - pub(crate) fn new(dataset: Arc, vector: Float32Array) -> Self { + pub(crate) fn new(dataset: Arc, vector: Option) -> Self { Query { dataset, query_vector: vector, - limit: 10, + limit: None, nprobes: 20, refine_factor: None, metric_type: None, @@ -69,11 +69,13 @@ impl Query { pub async fn execute(&self) -> Result { let mut scanner: Scanner = self.dataset.scan(); - scanner.nearest( - crate::table::VECTOR_COLUMN_NAME, - &self.query_vector, - self.limit, - )?; + if let Some(query) = self.query_vector.as_ref() { + // If there is a vector query, default to limit=10 if unspecified + scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?; + } else { + // If there is no vector query, it's ok to not have a limit + scanner.limit(self.limit.map(|limit| limit as i64), None)?; + } scanner.nprobs(self.nprobes); scanner.use_index(self.use_index); scanner.prefilter(self.prefilter); @@ -91,7 +93,7 @@ impl Query { /// /// * `limit` - The maximum number of results to return. pub fn limit(mut self, limit: usize) -> Query { - self.limit = limit; + self.limit = Some(limit); self } @@ -101,7 +103,7 @@ impl Query { /// /// * `vector` - The vector that will be used for search. pub fn query_vector(mut self, query_vector: Float32Array) -> Query { - self.query_vector = query_vector; + self.query_vector = Some(query_vector); self } @@ -174,7 +176,7 @@ mod tests { use std::sync::Arc; use super::*; - use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; + use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::StreamExt; use lance::dataset::Dataset; @@ -187,7 +189,7 @@ mod tests { let batches = make_test_batches(); let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); - let vector = Float32Array::from_iter_values([0.1, 0.2]); + let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); let query = Query::new(Arc::new(ds), vector.clone()); assert_eq!(query.query_vector, vector); @@ -201,8 +203,8 @@ mod tests { .metric_type(Some(MetricType::Cosine)) .refine_factor(Some(999)); - assert_eq!(query.query_vector, new_vector); - assert_eq!(query.limit, 100); + assert_eq!(query.query_vector.unwrap(), new_vector); + assert_eq!(query.limit.unwrap(), 100); assert_eq!(query.nprobes, 1000); assert_eq!(query.use_index, true); assert_eq!(query.metric_type, Some(MetricType::Cosine)); @@ -214,7 +216,7 @@ mod tests { let batches = make_non_empty_batches(); let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); - let vector = Float32Array::from_iter_values([0.1; 4]); + let vector = Some(Float32Array::from_iter_values([0.1; 4])); let query = Query::new(ds.clone(), vector.clone()); let result = query @@ -244,6 +246,27 @@ mod tests { } } + #[tokio::test] + async fn test_execute_no_vector() { + // test that it's ok to not specify a query vector (just filter / limit) + let batches = make_non_empty_batches(); + let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); + + let query = Query::new(ds.clone(), None); + let result = query + .filter(Some("id % 2 == 0".to_string())) + .execute() + .await; + let mut stream = result.expect("should have result"); + // should only have one batch + while let Some(batch) = stream.next().await { + let b = batch.expect("should be Ok"); + // cast arr into Int32Array + let arr: &Int32Array = b["id"].as_primitive(); + assert!(arr.iter().all(|x| x.unwrap() % 2 == 0)); + } + } + fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { let vec = Box::new(RandomVector::new().named("vector".to_string())); let id = Box::new(IncrementingInt32::new().named("id".to_string())); diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index be2aa5f3..9e216559 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -308,10 +308,14 @@ impl Table { /// # Returns /// /// * A [Query] object. - pub fn search(&self, query_vector: Float32Array) -> Query { + pub fn search(&self, query_vector: Option) -> Query { Query::new(self.dataset.clone(), query_vector) } + pub fn filter(&self, expr: String) -> Query { + Query::new(self.dataset.clone(), None).filter(Some(expr)) + } + /// Returns the number of rows in this Table pub async fn count_rows(&self) -> Result { Ok(self.dataset.count_rows().await?) @@ -844,8 +848,8 @@ mod tests { let table = Table::open(uri).await.unwrap(); let vector = Float32Array::from_iter_values([0.1, 0.2]); - let query = table.search(vector.clone()); - assert_eq!(vector, query.query_vector); + let query = table.search(Some(vector.clone())); + assert_eq!(vector, query.query_vector.unwrap()); } #[derive(Default, Debug)]