diff --git a/Cargo.toml b/Cargo.toml index 81b1ab10..b3a22a0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,14 +23,14 @@ rust-version = "1.78.0" [workspace.dependencies] lance = { "version" = "=0.21.1", "features" = [ "dynamodb", -], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } -lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" } +], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } +lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" } # Note that this one does not include pyarrow arrow = { version = "53.2", optional = false } arrow-array = "53.2" diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 1e8468bc..ac33d4f5 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -115,6 +115,9 @@ class Query(pydantic.BaseModel): # e.g. `{"nprobes": "10", "refine_factor": "10"}` nprobes: int = 10 + lower_bound: Optional[float] = None + upper_bound: Optional[float] = None + # Refine factor. refine_factor: Optional[int] = None @@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._query = query self._metric = "L2" self._nprobes = 20 + self._lower_bound = None + self._upper_bound = None self._refine_factor = None self._vector_column = vector_column self._prefilter = False @@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def distance_range( + self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None + ) -> LanceVectorQueryBuilder: + """Set the distance range to use. + + Only rows with distances within range [lower_bound, upper_bound) + will be returned. + + Parameters + ---------- + lower: Optional[float] + The lower bound of the distance range. + upper_bound: Optional[float] + The upper bound of the distance range. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._lower_bound = lower_bound + self._upper_bound = upper_bound + return self + def ef(self, ef: int) -> LanceVectorQueryBuilder: """Set the number of candidates to consider during search. @@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): metric=self._metric, columns=self._columns, nprobes=self._nprobes, + lower_bound=self._lower_bound, + upper_bound=self._upper_bound, refine_factor=self._refine_factor, vector_column=self._vector_column, with_row_id=self._with_row_id, @@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def distance_range( + self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None + ) -> LanceHybridQueryBuilder: + """ + Set the distance range to use. + + Only rows with distances within range [lower_bound, upper_bound) + will be returned. + + Parameters + ---------- + lower: Optional[float] + The lower bound of the distance range. + upper_bound: Optional[float] + The upper bound of the distance range. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._lower_bound = lower_bound + self._upper_bound = upper_bound + return self + def ef(self, ef: int) -> LanceHybridQueryBuilder: """ Set the number of candidates to consider during search. @@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.nprobes(nprobes) return self + def distance_range( + self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None + ) -> AsyncVectorQuery: + """Set the distance range to use. + + Only rows with distances within range [lower_bound, upper_bound) + will be returned. + + Parameters + ---------- + lower: Optional[float] + The lower bound of the distance range. + upper_bound: Optional[float] + The upper bound of the distance range. + + Returns + ------- + AsyncVectorQuery + The AsyncVectorQuery object. + """ + self._inner.distance_range(lower_bound, upper_bound) + return self + def ef(self, ef: int) -> AsyncVectorQuery: """ Set the number of candidates to consider during search diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 59f10107..e15280b0 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2786,6 +2786,7 @@ class AsyncTable: async_query.nearest_to(query.vector) .distance_type(query.metric) .nprobes(query.nprobes) + .distance_range(query.lower_bound, query.upper_bound) ) if query.refine_factor: async_query = async_query.refine_factor(query.refine_factor) diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index d1f4bf3e..910f406a 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table): assert rs["_rowid"].to_pylist() == [0, 1] +def test_distance_range(table: lancedb.table.Table): + q = [0, 0] + rs = table.search(q).to_arrow() + dists = rs["_distance"].to_pylist() + min_dist = dists[0] + max_dist = dists[-1] + + res = table.search(q).distance_range(upper_bound=min_dist).to_arrow() + assert len(res) == 0 + + res = table.search(q).distance_range(lower_bound=max_dist).to_arrow() + assert len(res) == 1 + assert res["_distance"].to_pylist() == [max_dist] + + res = table.search(q).distance_range(upper_bound=max_dist).to_arrow() + assert len(res) == 1 + assert res["_distance"].to_pylist() == [min_dist] + + res = table.search(q).distance_range(lower_bound=min_dist).to_arrow() + assert len(res) == 2 + assert res["_distance"].to_pylist() == [min_dist, max_dist] + + +@pytest.mark.asyncio +async def test_distance_range_async(table_async: AsyncTable): + q = [0, 0] + rs = await table_async.query().nearest_to(q).to_arrow() + dists = rs["_distance"].to_pylist() + min_dist = dists[0] + max_dist = dists[-1] + + res = ( + await table_async.query() + .nearest_to(q) + .distance_range(upper_bound=min_dist) + .to_arrow() + ) + assert len(res) == 0 + + res = ( + await table_async.query() + .nearest_to(q) + .distance_range(lower_bound=max_dist) + .to_arrow() + ) + assert len(res) == 1 + assert res["_distance"].to_pylist() == [max_dist] + + res = ( + await table_async.query() + .nearest_to(q) + .distance_range(upper_bound=max_dist) + .to_arrow() + ) + assert len(res) == 1 + assert res["_distance"].to_pylist() == [min_dist] + + res = ( + await table_async.query() + .nearest_to(q) + .distance_range(lower_bound=min_dist) + .to_arrow() + ) + assert len(res) == 2 + assert res["_distance"].to_pylist() == [min_dist, max_dist] + + def test_vector_query_with_no_limit(table): with pytest.raises(ValueError): LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select( diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index fb31b539..0f94948c 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -306,6 +306,8 @@ def test_query_sync_minimal(): "k": 10, "prefilter": False, "refine_factor": None, + "lower_bound": None, + "upper_bound": None, "ef": None, "vector": [1.0, 2.0, 3.0], "nprobes": 20, @@ -348,6 +350,8 @@ def test_query_sync_maximal(): "refine_factor": 10, "vector": [1.0, 2.0, 3.0], "nprobes": 5, + "lower_bound": None, + "upper_bound": None, "ef": None, "filter": "id > 0", "columns": ["id", "name"], @@ -449,6 +453,8 @@ def test_query_sync_hybrid(): "refine_factor": None, "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "nprobes": 20, + "lower_bound": None, + "upper_bound": None, "ef": None, "with_row_id": True, "version": None, diff --git a/python/src/query.rs b/python/src/query.rs index ff2057e9..42e6adb4 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -284,6 +284,11 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + #[pyo3(signature = (lower_bound=None, upper_bound=None))] + pub fn distance_range(&mut self, lower_bound: Option, upper_bound: Option) { + self.inner = self.inner.clone().distance_range(lower_bound, upper_bound); + } + pub fn ef(&mut self, ef: u32) { self.inner = self.inner.clone().ef(ef as usize); } diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 86d39392..afbbb543 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -755,6 +755,10 @@ pub struct VectorQuery { // IVF PQ - ANN search. pub(crate) query_vector: Vec>, pub(crate) nprobes: usize, + // The lower bound (inclusive) of the distance to search for. + pub(crate) lower_bound: Option, + // The upper bound (exclusive) of the distance to search for. + pub(crate) upper_bound: Option, // The number of candidates to return during the refine step for HNSW, // defaults to 1.5 * limit. pub(crate) ef: Option, @@ -771,6 +775,8 @@ impl VectorQuery { column: None, query_vector: Vec::new(), nprobes: 20, + lower_bound: None, + upper_bound: None, ef: None, refine_factor: None, distance_type: None, @@ -831,6 +837,14 @@ impl VectorQuery { self } + /// Set the distance range for vector search, + /// only rows with distances in the range [lower_bound, upper_bound) will be returned + pub fn distance_range(mut self, lower_bound: Option, upper_bound: Option) -> Self { + self.lower_bound = lower_bound; + self.upper_bound = upper_bound; + self + } + /// Set the number of candidates to return during the refine step for HNSW /// /// This argument is only used when the vector column has an HNSW index. @@ -1350,6 +1364,30 @@ mod tests { } } + #[tokio::test] + async fn test_distance_range() { + let tmp_dir = tempdir().unwrap(); + let table = make_test_table(&tmp_dir).await; + let results = table + .vector_search(&[0.1, 0.2, 0.3, 0.4]) + .unwrap() + .distance_range(Some(0.0), Some(1.0)) + .limit(10) + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + for batch in results { + let distances = batch["_distance"].as_primitive::(); + assert!(distances.iter().all(|d| { + let d = d.unwrap(); + (0.0..1.0).contains(&d) + })); + } + } + #[tokio::test] async fn test_multiple_query_vectors() { let tmp_dir = tempdir().unwrap(); diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 2e57352b..a9daa1cb 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -210,6 +210,8 @@ impl RemoteTable { body["prefilter"] = query.base.prefilter.into(); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["nprobes"] = query.nprobes.into(); + body["lower_bound"] = query.lower_bound.into(); + body["upper_bound"] = query.upper_bound.into(); body["ef"] = query.ef.into(); body["refine_factor"] = query.refine_factor.into(); if let Some(vector_column) = query.column.as_ref() { @@ -1304,6 +1306,8 @@ mod tests { "prefilter": true, "distance_type": "l2", "nprobes": 20, + "lower_bound": Option::::None, + "upper_bound": Option::::None, "k": 10, "ef": Option::::None, "refine_factor": null, @@ -1353,6 +1357,8 @@ mod tests { "bypass_vector_index": true, "columns": ["a", "b"], "nprobes": 12, + "lower_bound": Option::::None, + "upper_bound": Option::::None, "ef": Option::::None, "refine_factor": 2, "version": null, diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 4385270e..79b1ce4d 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable { if let Some(ef) = query.ef { scanner.ef(ef); } + scanner.distance_range(query.lower_bound, query.upper_bound); scanner.use_index(query.use_index); scanner.prefilter(query.base.prefilter); match query.base.select {