From 3c0a64be8fbe3fd448327fbbee8d402ac25d47d3 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 8 Jan 2025 11:03:27 +0800 Subject: [PATCH] feat: support distance range in queries (#1999) this also updates the docs --------- Signed-off-by: BubbleCal --- docs/src/search.md | 33 ++++ nodejs/__test__/table.test.ts | 56 +++++++ nodejs/examples/search.test.ts | 14 ++ nodejs/lancedb/query.ts | 13 ++ nodejs/package-lock.json | 147 ++++++++++++++++++ nodejs/src/query.rs | 9 ++ python/python/lancedb/query.py | 6 +- .../python/tests/docs/test_distance_range.py | 62 ++++++++ 8 files changed, 337 insertions(+), 3 deletions(-) create mode 100644 python/python/tests/docs/test_distance_range.py diff --git a/docs/src/search.md b/docs/src/search.md index 39d9db19..ad5c8d9a 100644 --- a/docs/src/search.md +++ b/docs/src/search.md @@ -138,6 +138,39 @@ LanceDB supports binary vectors as a data type, and has the ability to search bi --8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector" ``` +## Search with distance range + +You can also restrict a vector search to results whose distance from the query vector falls within a specific range. This is useful when you want results filtered by how far they are from the query, rather than just the top-k nearest neighbors. Use the `distance_range` method (`distanceRange` in TypeScript); the range is half-open, `[lower_bound, upper_bound)`, and either bound may be omitted to leave that side unbounded. + +=== "Python" + + === "sync API" + + ```python + --8<-- "python/python/tests/docs/test_distance_range.py:imports" + + --8<-- "python/python/tests/docs/test_distance_range.py:sync_distance_range" + ``` + + === "async API" + + ```python + --8<-- "python/python/tests/docs/test_distance_range.py:imports" + + --8<-- "python/python/tests/docs/test_distance_range.py:async_distance_range" + ``` + +=== "TypeScript" + + === "@lancedb/lancedb" + + ```ts + --8<-- "nodejs/examples/search.test.ts:import" + + --8<-- "nodejs/examples/search.test.ts:distance_range" + ``` + + ## Output search results LanceDB returns vector search results via different formats commonly used in python. 
diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 9fd2986a..2c8f2a78 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -475,6 +475,62 @@ describe("When creating an index", () => { expect(rst.numRows).toBe(1); }); + it("should search with distance range", async () => { + await tbl.createIndex("vec"); + + const rst = await tbl.query().limit(10).nearestTo(queryVec).toArrow(); + const distanceColumn = rst.getChild("_distance"); + let minDist = undefined; + let maxDist = undefined; + if (distanceColumn) { + minDist = distanceColumn.get(0); + maxDist = distanceColumn.get(9); + } + + const rst2 = await tbl + .query() + .limit(10) + .nearestTo(queryVec) + .distanceRange(minDist, maxDist) + .toArrow(); + const distanceColumn2 = rst2.getChild("_distance"); + expect(distanceColumn2).toBeDefined(); + if (distanceColumn2) { + for await (const d of distanceColumn2) { + expect(d).toBeGreaterThanOrEqual(minDist); + expect(d).toBeLessThan(maxDist); + } + } + + const rst3 = await tbl + .query() + .limit(10) + .nearestTo(queryVec) + .distanceRange(maxDist, undefined) + .toArrow(); + const distanceColumn3 = rst3.getChild("_distance"); + expect(distanceColumn3).toBeDefined(); + if (distanceColumn3) { + for await (const d of distanceColumn3) { + expect(d).toBeGreaterThanOrEqual(maxDist); + } + } + + const rst4 = await tbl + .query() + .limit(10) + .nearestTo(queryVec) + .distanceRange(undefined, minDist) + .toArrow(); + const distanceColumn4 = rst4.getChild("_distance"); + expect(distanceColumn4).toBeDefined(); + if (distanceColumn4) { + for await (const d of distanceColumn4) { + expect(d).toBeLessThan(minDist); + } + } + }); + it("should create and search IVF_HNSW indices", async () => { await tbl.createIndex("vec", { config: Index.hnswSq(), diff --git a/nodejs/examples/search.test.ts b/nodejs/examples/search.test.ts index ccca9b78..d188f7e3 100644 --- a/nodejs/examples/search.test.ts +++ b/nodejs/examples/search.test.ts @@ 
-38,5 +38,19 @@ test("full text search", async () => { .toArray(); // --8<-- [end:search2] expect(results2.length).toBe(10); + + // --8<-- [start:distance_range] + const results3 = await ( + tbl.search(Array(128).fill(1.2)) as lancedb.VectorQuery + ) + .distanceType("cosine") + .distanceRange(0.1, 0.2) + .limit(10) + .toArray(); + // --8<-- [end:distance_range] + for (const r of results3) { + expect(r.distance).toBeGreaterThanOrEqual(0.1); + expect(r.distance).toBeLessThan(0.2); + } }); }); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index aa4b560f..eb22d512 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -388,6 +388,19 @@ export class VectorQuery extends QueryBase { return this; } + /** + * Set the distance range to use. + * + * Only rows with distances within the range [lowerBound, upperBound) + * will be returned. + * + * Passing `undefined` for either argument leaves that side of the range unbounded. + */ + distanceRange(lowerBound?: number, upperBound?: number): VectorQuery { + super.doCall((inner) => inner.distanceRange(lowerBound, upperBound)); + return this; + } + /** * Set the number of candidates to consider during the search * diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 4e906d7e..27472928 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -18,6 +18,7 @@ "win32" ], "dependencies": { + "@lancedb/lancedb": "^0.14.1", "reflect-metadata": "^0.2.2" }, "devDependencies": { @@ -4149,6 +4150,152 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@lancedb/lancedb": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb/-/lancedb-0.14.1.tgz", + "integrity": "sha512-DfJ887t52n/2s8G1JnzE7gAR4i7UnfP1OjDYnJ4yTk0aIcn76CbVOUegYfURYlYjL+QFdI1MrAzUdMgYgsGGcA==", + "cpu": [ + "x64", + "arm64" + ], + "license": "Apache 2.0", + "os": [ + "darwin", + "linux", + "win32" + ], + "dependencies": { + "reflect-metadata": "^0.2.2" + }, + "engines": { + "node": ">= 18" + }, 
"optionalDependencies": { + "@lancedb/lancedb-darwin-arm64": "0.14.1", + "@lancedb/lancedb-darwin-x64": "0.14.1", + "@lancedb/lancedb-linux-arm64-gnu": "0.14.1", + "@lancedb/lancedb-linux-arm64-musl": "0.14.1", + "@lancedb/lancedb-linux-x64-gnu": "0.14.1", + "@lancedb/lancedb-linux-x64-musl": "0.14.1", + "@lancedb/lancedb-win32-arm64-msvc": "0.14.1", + "@lancedb/lancedb-win32-x64-msvc": "0.14.1" + }, + "peerDependencies": { + "apache-arrow": ">=15.0.0 <=18.1.0" + } + }, + "node_modules/@lancedb/lancedb-darwin-arm64": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-arm64/-/lancedb-darwin-arm64-0.14.1.tgz", + "integrity": "sha512-eSWV3GydXfyaptPXZ+S3BgXY1YI26oHQDekACaVevRW6/YQD7sS9UhhSZn1mYyDtLTfJu2kOK2XHA9UY8nyuTg==", + "cpu": [ + "arm64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-darwin-x64": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-x64/-/lancedb-darwin-x64-0.14.1.tgz", + "integrity": "sha512-ecf50ykF9WCWmpwAjs3Mk2mph7d+rMJ9EVJeX0UJ4KHDC874lnTDo6Tfd9iUcbExtNI1KZbu+CFnYsbQU+R0gw==", + "cpu": [ + "x64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-linux-arm64-gnu": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-arm64-gnu/-/lancedb-linux-arm64-gnu-0.14.1.tgz", + "integrity": "sha512-X7ub1fOm7jZ19KFW/u3nDyFvj5XzDPqEVrp9mmcOgSrst3NJEGGBz1JypkLnTWpg/7IpCBs1UO1G7R7LEsHYOA==", + "cpu": [ + "arm64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-linux-arm64-musl": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-arm64-musl/-/lancedb-linux-arm64-musl-0.14.1.tgz", + "integrity": 
"sha512-rkiWpsQCXwybwEjcdFXkAeGahiLcK/NQUjZc9WBY6CKk2Y9dICIafYzxZ6MDCY19jeJIgs3JS0mjleUWYr3JFw==", + "cpu": [ + "arm64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-linux-x64-gnu": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-x64-gnu/-/lancedb-linux-x64-gnu-0.14.1.tgz", + "integrity": "sha512-LGp4D58pQJ3+H3GncNxWHkvhIVOKpTzYUBtVfC8he1rwZ6+CiYDyK9Sim/j8o3UJlJ7cP0m3gNUzPfQchQF9WA==", + "cpu": [ + "x64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-linux-x64-musl": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-x64-musl/-/lancedb-linux-x64-musl-0.14.1.tgz", + "integrity": "sha512-V/TeoyKUESPL/8L1z4WLbMFe5ZEv4gtxc0AFK8ghiduFYN/Hckuj4oTo/Y0ysLiBx1At9FCa91hWDB301ibHBg==", + "cpu": [ + "x64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 18" + } + }, + "node_modules/@lancedb/lancedb-win32-x64-msvc": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/@lancedb/lancedb-win32-x64-msvc/-/lancedb-win32-x64-msvc-0.14.1.tgz", + "integrity": "sha512-4M8D0j8/3WZv4CKo+Z44sISKPCKWN5MWA0dcEEGw4sEXHF2RJLrMIOOgEpT5NF7VW+X4t2JJxUA6j2T3cXaD8w==", + "cpu": [ + "x64" + ], + "license": "Apache 2.0", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 18" + } + }, "node_modules/@napi-rs/cli": { "version": "2.18.3", "resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.3.tgz", diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 321e4052..8b2cb1a8 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -171,6 +171,15 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + #[napi] + pub fn distance_range(&mut self, lower_bound: Option, upper_bound: 
Option) { + // napi doesn't support f32, so the bounds arrive as a napi-supported float type and are cast to f32 here + self.inner = self + .inner + .clone() + .distance_range(lower_bound.map(|v| v as f32), upper_bound.map(|v| v as f32)); + } + #[napi] pub fn ef(&mut self, ef: u32) { self.inner = self.inner.clone().ef(ef as usize); diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 02f57fdc..0409e658 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -647,7 +647,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): Parameters ---------- - lower: Optional[float] + lower_bound: Optional[float] The lower bound of the distance range. upper_bound: Optional[float] The upper bound of the distance range. @@ -1309,7 +1309,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): Parameters ---------- - lower: Optional[float] + lower_bound: Optional[float] The lower bound of the distance range. upper_bound: Optional[float] The upper bound of the distance range. @@ -1940,7 +1940,7 @@ class AsyncVectorQuery(AsyncQueryBase): Parameters ---------- - lower: Optional[float] + lower_bound: Optional[float] The lower bound of the distance range. upper_bound: Optional[float] The upper bound of the distance range. 
diff --git a/python/python/tests/docs/test_distance_range.py b/python/python/tests/docs/test_distance_range.py new file mode 100644 index 00000000..41624d74 --- /dev/null +++ b/python/python/tests/docs/test_distance_range.py @@ -0,0 +1,62 @@ +import shutil +import pytest + +# --8<-- [start:imports] +import lancedb +import numpy as np +# --8<-- [end:imports] + +shutil.rmtree("data/distance_range_demo", ignore_errors=True) + + +def test_binary_vector(): + # --8<-- [start:sync_distance_range] + db = lancedb.connect("data/distance_range_demo") + data = [ + { + "id": i, + "vector": np.random.random(256), + } + for i in range(1024) + ] + tbl = db.create_table("my_table", data=data) + query = np.random.random(256) + + # Search for the vectors within the range of [0.1, 0.5) + tbl.search(query).distance_range(0.1, 0.5).to_arrow() + + # Search for the vectors with the distance less than 0.5 + tbl.search(query).distance_range(upper_bound=0.5).to_arrow() + + # Search for the vectors with the distance greater or equal to 0.1 + tbl.search(query).distance_range(lower_bound=0.1).to_arrow() + + # --8<-- [end:sync_distance_range] + db.drop_table("my_table") + + +@pytest.mark.asyncio +async def test_binary_vector_async(): + # --8<-- [start:async_distance_range] + db = await lancedb.connect_async("data/distance_range_demo") + data = [ + { + "id": i, + "vector": np.random.random(256), + } + for i in range(1024) + ] + tbl = await db.create_table("my_table", data=data) + query = np.random.random(256) + + # Search for the vectors within the range of [0.1, 0.5) + await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow() + + # Search for the vectors with the distance less than 0.5 + await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow() + + # Search for the vectors with the distance greater or equal to 0.1 + await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow() + + # --8<-- [end:async_distance_range] + await 
db.drop_table("my_table")