fix: sync hybrid search ignores the distance range params (#2356)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for distance range filtering in hybrid vector queries,
allowing users to specify lower and upper bounds for search results.

- **Tests**
- Introduced new tests to validate distance range filtering and
reranking in both synchronous and asynchronous hybrid query scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-04-25 13:01:22 +08:00
committed by GitHub
parent 8c0622fa2c
commit 9b902272f1
3 changed files with 67 additions and 1 deletions

View File

@@ -1592,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._distance_type = None
self._phrase_query = None
self._lower_bound = None
self._upper_bound = None
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
@@ -1671,6 +1673,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.ef(self._ef)
if self._bypass_vector_index:
self._vector_query.bypass_vector_index()
if self._lower_bound or self._upper_bound:
self._vector_query.distance_range(
lower_bound=self._lower_bound, upper_bound=self._upper_bound
)
if self._reranker is None:
self._reranker = RRFReranker()

View File

@@ -4,13 +4,32 @@
import lancedb
from lancedb.query import LanceHybridQueryBuilder
from lancedb.rerankers.rrf import RRFReranker
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.index import FTS
from lancedb.table import AsyncTable
from lancedb.table import AsyncTable, Table
@pytest.fixture
def sync_table(tmpdir_factory) -> Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": pa.array(["a", "b", "cat", "dog"]),
"vector": pa.array(
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", with_position=False, use_tantivy=False)
return table
@pytest_asyncio.fixture
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
assert texts.count("a") == 1
def test_hybrid_query_distance_range(sync_table: Table):
reranker = RRFReranker(return_score="all")
result = (
sync_table.search(query_type="hybrid")
.vector([0.0, 0.4])
.text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
print(result)
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_hybrid_query_distance_range_async(table: AsyncTable):
reranker = RRFReranker(return_score="all")
result = await (
table.query()
.nearest_to([0.0, 0.4])
.nearest_to_text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_explain_plan(table: AsyncTable):
plan = await (

View File

@@ -652,6 +652,11 @@ impl HybridQuery {
self.inner_vec.bypass_vector_index();
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner_vec.distance_range(lower_bound, upper_bound);
}
pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
Ok(VectorQuery {
inner: self.inner_vec.inner.clone(),