mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 10:22:56 +00:00
feat: LocalTable for vectordb now supports filters without vector search (#693)
Note this currently the filter/where is only implemented for LocalTable so that it requires an explicit cast to "enable" (see new unit test). The alternative is to add it to the Table interface, but since it's not available on RemoteTable this may cause some user experience issues.
This commit is contained in:
@@ -488,6 +488,16 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
return new Query(query, this._tbl, this._embeddings)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a filter query to find all rows matching the specified criteria
|
||||
* @param value The filter criteria (like SQL where clause syntax)
|
||||
*/
|
||||
filter (value: string): Query<T> {
|
||||
return new Query(undefined, this._tbl, this._embeddings).filter(value)
|
||||
}
|
||||
|
||||
where = this.filter
|
||||
|
||||
/**
|
||||
* Insert records into this Table.
|
||||
*
|
||||
|
||||
@@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js')
|
||||
* A builder for nearest neighbor queries for LanceDB.
|
||||
*/
|
||||
export class Query<T = number[]> {
|
||||
private readonly _query: T
|
||||
private readonly _query?: T
|
||||
private readonly _tbl?: any
|
||||
private _queryVector?: number[]
|
||||
private _limit: number
|
||||
private _limit?: number
|
||||
private _refineFactor?: number
|
||||
private _nprobes: number
|
||||
private _select?: string[]
|
||||
@@ -35,10 +35,10 @@ export class Query<T = number[]> {
|
||||
private _prefilter: boolean
|
||||
protected readonly _embeddings?: EmbeddingFunction<T>
|
||||
|
||||
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||
constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||
this._tbl = tbl
|
||||
this._query = query
|
||||
this._limit = 10
|
||||
this._limit = undefined
|
||||
this._nprobes = 20
|
||||
this._refineFactor = undefined
|
||||
this._select = undefined
|
||||
@@ -113,10 +113,12 @@ export class Query<T = number[]> {
|
||||
* Execute the query and return the results as an Array of Objects
|
||||
*/
|
||||
async execute<T = Record<string, unknown>> (): Promise<T[]> {
|
||||
if (this._embeddings !== undefined) {
|
||||
this._queryVector = (await this._embeddings.embed([this._query]))[0]
|
||||
} else {
|
||||
this._queryVector = this._query as number[]
|
||||
if (this._query !== undefined) {
|
||||
if (this._embeddings !== undefined) {
|
||||
this._queryVector = (await this._embeddings.embed([this._query]))[0]
|
||||
} else {
|
||||
this._queryVector = this._query as number[]
|
||||
}
|
||||
}
|
||||
|
||||
const isElectron = this.isElectron()
|
||||
|
||||
@@ -78,12 +78,31 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
|
||||
it('limits # of results', async function () {
|
||||
const uri = await createTestDB()
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
let results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 1)
|
||||
|
||||
// there is a default limit if unspecified
|
||||
results = await table.search([0.1, 0.3]).execute()
|
||||
assert.equal(results.length, 10)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause without vector search', async function () {
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
|
||||
const assertResults = (results: Array<Record<string, unknown>>) => {
|
||||
assert.equal(results.length, 50)
|
||||
}
|
||||
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = (await con.openTable('vectors')) as LocalTable
|
||||
let results = await table.filter('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
results = await table.where('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause', async function () {
|
||||
|
||||
@@ -23,8 +23,14 @@ impl JsQuery {
|
||||
let query_obj = cx.argument::<JsObject>(0)?;
|
||||
|
||||
let limit = query_obj
|
||||
.get::<JsNumber, _, _>(&mut cx, "_limit")?
|
||||
.value(&mut cx);
|
||||
.get_opt::<JsNumber, _, _>(&mut cx, "_limit")?
|
||||
.map(|value| {
|
||||
let limit = value.value(&mut cx) as u64;
|
||||
if limit <= 0 {
|
||||
panic!("Limit must be a positive integer");
|
||||
}
|
||||
limit
|
||||
});
|
||||
let select = query_obj
|
||||
.get_opt::<JsArray, _, _>(&mut cx, "_select")?
|
||||
.map(|arr| {
|
||||
@@ -61,20 +67,23 @@ impl JsQuery {
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let table = js_table.table.clone();
|
||||
let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
|
||||
|
||||
rt.spawn(async move {
|
||||
let builder = table
|
||||
.search(Float32Array::from(query))
|
||||
.limit(limit as usize)
|
||||
let mut builder = table
|
||||
.search(query.map(|q| Float32Array::from(q)))
|
||||
.refine_factor(refine_factor)
|
||||
.nprobes(nprobes)
|
||||
.filter(filter)
|
||||
.metric_type(metric_type)
|
||||
.select(select)
|
||||
.prefilter(prefilter);
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
|
||||
let record_batch_stream = builder.execute();
|
||||
let results = record_batch_stream
|
||||
.and_then(|stream| {
|
||||
|
||||
@@ -359,7 +359,7 @@ mod test {
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))
|
||||
.search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])))
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
|
||||
@@ -24,8 +24,8 @@ use crate::error::Result;
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Float32Array,
|
||||
pub limit: usize,
|
||||
pub query_vector: Option<Float32Array>,
|
||||
pub limit: Option<usize>,
|
||||
pub filter: Option<String>,
|
||||
pub select: Option<Vec<String>>,
|
||||
pub nprobes: usize,
|
||||
@@ -46,11 +46,11 @@ impl Query {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self {
|
||||
Query {
|
||||
dataset,
|
||||
query_vector: vector,
|
||||
limit: 10,
|
||||
limit: None,
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
metric_type: None,
|
||||
@@ -69,11 +69,13 @@ impl Query {
|
||||
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
|
||||
let mut scanner: Scanner = self.dataset.scan();
|
||||
|
||||
scanner.nearest(
|
||||
crate::table::VECTOR_COLUMN_NAME,
|
||||
&self.query_vector,
|
||||
self.limit,
|
||||
)?;
|
||||
if let Some(query) = self.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
|
||||
}
|
||||
scanner.nprobs(self.nprobes);
|
||||
scanner.use_index(self.use_index);
|
||||
scanner.prefilter(self.prefilter);
|
||||
@@ -91,7 +93,7 @@ impl Query {
|
||||
///
|
||||
/// * `limit` - The maximum number of results to return.
|
||||
pub fn limit(mut self, limit: usize) -> Query {
|
||||
self.limit = limit;
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -101,7 +103,7 @@ impl Query {
|
||||
///
|
||||
/// * `vector` - The vector that will be used for search.
|
||||
pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
|
||||
self.query_vector = query_vector;
|
||||
self.query_vector = Some(query_vector);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -174,7 +176,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::StreamExt;
|
||||
use lance::dataset::Dataset;
|
||||
@@ -187,7 +189,7 @@ mod tests {
|
||||
let batches = make_test_batches();
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
assert_eq!(query.query_vector, vector);
|
||||
|
||||
@@ -201,8 +203,8 @@ mod tests {
|
||||
.metric_type(Some(MetricType::Cosine))
|
||||
.refine_factor(Some(999));
|
||||
|
||||
assert_eq!(query.query_vector, new_vector);
|
||||
assert_eq!(query.limit, 100);
|
||||
assert_eq!(query.query_vector.unwrap(), new_vector);
|
||||
assert_eq!(query.limit.unwrap(), 100);
|
||||
assert_eq!(query.nprobes, 1000);
|
||||
assert_eq!(query.use_index, true);
|
||||
assert_eq!(query.metric_type, Some(MetricType::Cosine));
|
||||
@@ -214,7 +216,7 @@ mod tests {
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1; 4]);
|
||||
let vector = Some(Float32Array::from_iter_values([0.1; 4]));
|
||||
|
||||
let query = Query::new(ds.clone(), vector.clone());
|
||||
let result = query
|
||||
@@ -244,6 +246,27 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_no_vector() {
|
||||
// test that it's ok to not specify a query vector (just filter / limit)
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let query = Query::new(ds.clone(), None);
|
||||
let result = query
|
||||
.filter(Some("id % 2 == 0".to_string()))
|
||||
.execute()
|
||||
.await;
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
let b = batch.expect("should be Ok");
|
||||
// cast arr into Int32Array
|
||||
let arr: &Int32Array = b["id"].as_primitive();
|
||||
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
|
||||
}
|
||||
}
|
||||
|
||||
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
|
||||
let vec = Box::new(RandomVector::new().named("vector".to_string()));
|
||||
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
|
||||
|
||||
@@ -308,10 +308,14 @@ impl Table {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub fn search(&self, query_vector: Float32Array) -> Query {
|
||||
pub fn search(&self, query_vector: Option<Float32Array>) -> Query {
|
||||
Query::new(self.dataset.clone(), query_vector)
|
||||
}
|
||||
|
||||
pub fn filter(&self, expr: String) -> Query {
|
||||
Query::new(self.dataset.clone(), None).filter(Some(expr))
|
||||
}
|
||||
|
||||
/// Returns the number of rows in this Table
|
||||
pub async fn count_rows(&self) -> Result<usize> {
|
||||
Ok(self.dataset.count_rows().await?)
|
||||
@@ -844,8 +848,8 @@ mod tests {
|
||||
let table = Table::open(uri).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = table.search(vector.clone());
|
||||
assert_eq!(vector, query.query_vector);
|
||||
let query = table.search(Some(vector.clone()));
|
||||
assert_eq!(vector, query.query_vector.unwrap());
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
|
||||
Reference in New Issue
Block a user