From 9b902272f1dad6b6c758e0fe26f6ed7e29eb53cf Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Fri, 25 Apr 2025 13:01:22 +0800
Subject: [PATCH] fix: sync hybrid search ignores the distance range params (#2356)

## Summary by CodeRabbit

- **New Features**
  - Added support for distance range filtering in hybrid vector queries, allowing users to specify lower and upper bounds for search results.
- **Tests**
  - Introduced new tests to validate distance range filtering and reranking in both synchronous and asynchronous hybrid query scenarios.

---------

Signed-off-by: BubbleCal
---
 python/python/lancedb/query.py           |  6 +++
 python/python/tests/test_hybrid_query.py | 57 +++++++++++++++++++++++-
 python/src/query.rs                      |  5 +++
 3 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 235e0b16..233af240 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -1592,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._refine_factor = None
         self._distance_type = None
         self._phrase_query = None
+        self._lower_bound = None
+        self._upper_bound = None
 
     def _validate_query(self, query, vector=None, text=None):
         if query is not None and (vector is not None or text is not None):
@@ -1671,6 +1673,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
             self._vector_query.ef(self._ef)
         if self._bypass_vector_index:
             self._vector_query.bypass_vector_index()
+        if self._lower_bound is not None or self._upper_bound is not None:
+            self._vector_query.distance_range(
+                lower_bound=self._lower_bound, upper_bound=self._upper_bound
+            )
 
         if self._reranker is None:
             self._reranker = RRFReranker()
diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py
index e0d8633d..33245f7f 100644
--- a/python/python/tests/test_hybrid_query.py
+++ b/python/python/tests/test_hybrid_query.py
@@ -4,13 +4,32 @@
 import lancedb
 
 from lancedb.query import LanceHybridQueryBuilder
+from lancedb.rerankers.rrf import RRFReranker
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
 import pytest_asyncio
 
 from lancedb.index import FTS
-from lancedb.table import AsyncTable
+from lancedb.table import AsyncTable, Table
+
+
+@pytest.fixture
+def sync_table(tmpdir_factory) -> Table:
+    tmp_path = str(tmpdir_factory.mktemp("data"))
+    db = lancedb.connect(tmp_path)
+    data = pa.table(
+        {
+            "text": pa.array(["a", "b", "cat", "dog"]),
+            "vector": pa.array(
+                [[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
+                type=pa.list_(pa.float32(), list_size=2),
+            ),
+        }
+    )
+    table = db.create_table("test", data)
+    table.create_fts_index("text", with_position=False, use_tantivy=False)
+    return table
 
 
 @pytest_asyncio.fixture
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
     assert texts.count("a") == 1
 
 
+def test_hybrid_query_distance_range(sync_table: Table):
+    reranker = RRFReranker(return_score="all")
+    result = (
+        sync_table.search(query_type="hybrid")
+        .vector([0.0, 0.4])
+        .text("cat and dog")
+        .distance_range(lower_bound=0.2, upper_bound=0.5)
+        .rerank(reranker)
+        .limit(2)
+        .to_arrow()
+    )
+    assert len(result) == 2
+    print(result)
+    for dist in result["_distance"]:
+        if dist.is_valid:
+            assert 0.2 <= dist.as_py() <= 0.5
+
+
+@pytest.mark.asyncio
+async def test_hybrid_query_distance_range_async(table: AsyncTable):
+    reranker = RRFReranker(return_score="all")
+    result = await (
+        table.query()
+        .nearest_to([0.0, 0.4])
+        .nearest_to_text("cat and dog")
+        .distance_range(lower_bound=0.2, upper_bound=0.5)
+        .rerank(reranker)
+        .limit(2)
+        .to_arrow()
+    )
+    assert len(result) == 2
+    for dist in result["_distance"]:
+        if dist.is_valid:
+            assert 0.2 <= dist.as_py() <= 0.5
+
+
 @pytest.mark.asyncio
 async def test_explain_plan(table: AsyncTable):
     plan = await (
diff --git a/python/src/query.rs b/python/src/query.rs
index e3f3bb6d..f2cb2d62 100644
--- a/python/src/query.rs
+++ b/python/src/query.rs
@@ -652,6 +652,11 @@ impl HybridQuery {
         self.inner_vec.bypass_vector_index();
     }
 
+    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
+    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
+        self.inner_vec.distance_range(lower_bound, upper_bound);
+    }
+
     pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
         Ok(VectorQuery {
             inner: self.inner_vec.inner.clone(),