From 59c25574f0e07aed2723b7c3d0e65df0f484ee50 Mon Sep 17 00:00:00 2001 From: Rob Meng Date: Fri, 1 Dec 2023 16:49:10 -0500 Subject: [PATCH] feat: enable prefilter in node js (#675) enable prefiltering in node js, both native and remote --- node/src/query.ts | 7 +++++++ node/src/remote/client.ts | 4 +++- node/src/remote/index.ts | 3 ++- node/src/test/test.ts | 14 ++++++++++++++ rust/ffi/node/src/query.rs | 5 ++++- 5 files changed, 30 insertions(+), 3 deletions(-) diff --git a/node/src/query.ts b/node/src/query.ts index 6827cbb7..c1156c38 100644 --- a/node/src/query.ts +++ b/node/src/query.ts @@ -32,6 +32,7 @@ export class Query { private _select?: string[] private _filter?: string private _metricType?: MetricType + private _prefilter: boolean protected readonly _embeddings?: EmbeddingFunction constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction) { @@ -44,6 +45,7 @@ export class Query { this._filter = undefined this._metricType = undefined this._embeddings = embeddings + this._prefilter = false } /*** @@ -102,6 +104,11 @@ export class Query { return this } + prefilter (value: boolean): Query { + this._prefilter = value + return this + } + /** * Execute the query and return the results as an Array of Objects */ diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index 68c52721..3d4d59a2 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -38,6 +38,7 @@ export class HttpLancedbClient { vector: number[], k: number, nprobes: number, + prefilter: boolean, refineFactor?: number, columns?: string[], filter?: string @@ -50,7 +51,8 @@ export class HttpLancedbClient { nprobes, refineFactor, columns, - filter + filter, + prefilter }, { headers: { diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 2870b1a8..1130e9ce 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -156,7 +156,8 @@ export class RemoteQuery extends Query { (this as any)._nprobes, (this as any)._refineFactor, (this as any)._select, - (this as any)._filter + (this as any)._filter, + (this as any)._prefilter ) return data.toArray().map((entry: Record) => { diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 9fe251ab..ce56c890 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -102,6 +102,20 @@ describe('LanceDB client', function () { assertResults(results) }) + it('should correctly process prefilter/postfilter', async function () { + const uri = await createTestDB(16, 300) + const con = await lancedb.connect(uri) + const table = await con.openTable('vectors') + await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 }) + // post filter should return less than the limit + let results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(false).execute() + assert.isTrue(results.length < 10) + + // pre filter should return exactly the limit + results = await table.search(new Array(16).fill(0.1)).limit(10).filter('id >= 10').prefilter(true).execute() + assert.isTrue(results.length === 10) + }) + it('select only a subset of columns', async function () { const uri = await createTestDB() const con = await lancedb.connect(uri) diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index 566ee6ee..64a9ba3b 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -48,6 +48,8 @@ impl JsQuery { .map(|s| s.value(&mut cx)) .map(|s| MetricType::try_from(s.as_str()).unwrap()); + let prefilter = query_obj.get::(&mut cx, "_prefilter")?.value(&mut cx); + let is_electron = cx .argument::(1) .or_throw(&mut cx)? @@ -69,7 +71,8 @@ impl JsQuery { .nprobes(nprobes) .filter(filter) .metric_type(metric_type) - .select(select); + .select(select) + .prefilter(prefilter); let record_batch_stream = builder.execute(); let results = record_batch_stream .and_then(|stream| {