mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
fix: sync hybrid search ignores the distance range params (#2356)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Added support for distance range filtering in hybrid vector queries, allowing users to specify lower and upper bounds for search results.
- **Tests**
  - Introduced new tests to validate distance range filtering and reranking in both synchronous and asynchronous hybrid query scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -1592,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._refine_factor = None
|
||||
self._distance_type = None
|
||||
self._phrase_query = None
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
if query is not None and (vector is not None or text is not None):
|
||||
@@ -1671,6 +1673,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.ef(self._ef)
|
||||
if self._bypass_vector_index:
|
||||
self._vector_query.bypass_vector_index()
|
||||
if self._lower_bound or self._upper_bound:
|
||||
self._vector_query.distance_range(
|
||||
lower_bound=self._lower_bound, upper_bound=self._upper_bound
|
||||
)
|
||||
|
||||
if self._reranker is None:
|
||||
self._reranker = RRFReranker()
|
||||
|
||||
@@ -4,13 +4,32 @@
|
||||
import lancedb
|
||||
|
||||
from lancedb.query import LanceHybridQueryBuilder
|
||||
from lancedb.rerankers.rrf import RRFReranker
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from lancedb.index import FTS
|
||||
from lancedb.table import AsyncTable
|
||||
from lancedb.table import AsyncTable, Table
|
||||
|
||||
|
||||
@pytest.fixture
def sync_table(tmpdir_factory) -> Table:
    """Build a small synchronous table for the hybrid-query tests.

    The table is named ``"test"``, lives in a fresh temporary directory, and
    holds four rows with a ``text`` column and a two-dimensional float32
    ``vector`` column. A full-text-search index is created on ``text``
    before the table is returned.
    """
    texts = pa.array(["a", "b", "cat", "dog"])
    vectors = pa.array(
        [[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
        type=pa.list_(pa.float32(), list_size=2),
    )

    db = lancedb.connect(str(tmpdir_factory.mktemp("data")))
    table = db.create_table("test", pa.table({"text": texts, "vector": vectors}))

    # Hybrid search needs an FTS index on the text column.
    table.create_fts_index("text", with_position=False, use_tantivy=False)
    return table
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
|
||||
assert texts.count("a") == 1
|
||||
|
||||
|
||||
def test_hybrid_query_distance_range(sync_table: Table):
    """Check that a sync hybrid query honours ``distance_range``.

    Runs a hybrid (vector + text) search with a lower bound of 0.2 and an
    upper bound of 0.5, reranked with ``RRFReranker``, and asserts that two
    rows come back and that every non-null ``_distance`` lies in [0.2, 0.5].
    """
    reranker = RRFReranker(return_score="all")
    result = (
        sync_table.search(query_type="hybrid")
        .vector([0.0, 0.4])
        .text("cat and dog")
        .distance_range(lower_bound=0.2, upper_bound=0.5)
        .rerank(reranker)
        .limit(2)
        .to_arrow()
    )
    assert len(result) == 2
    for dist in result["_distance"]:
        # Null distances are skipped; presumably these are rows matched only
        # by the text side of the hybrid query -- TODO confirm.
        if dist.is_valid:
            assert 0.2 <= dist.as_py() <= 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_hybrid_query_distance_range_async(table: AsyncTable):
    """Check that an async hybrid query honours ``distance_range``.

    Combines ``nearest_to`` and ``nearest_to_text`` with a distance range of
    [0.2, 0.5], reranks with ``RRFReranker``, and asserts that two rows come
    back and that every non-null ``_distance`` falls inside the range.
    """
    lower, upper = 0.2, 0.5
    query = (
        table.query()
        .nearest_to([0.0, 0.4])
        .nearest_to_text("cat and dog")
        .distance_range(lower_bound=lower, upper_bound=upper)
        .rerank(RRFReranker(return_score="all"))
        .limit(2)
    )
    result = await query.to_arrow()

    assert len(result) == 2
    # Null distances are excluded from the range check.
    distances = [d.as_py() for d in result["_distance"] if d.is_valid]
    assert all(lower <= d <= upper for d in distances)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_plan(table: AsyncTable):
|
||||
plan = await (
|
||||
|
||||
@@ -652,6 +652,11 @@ impl HybridQuery {
|
||||
self.inner_vec.bypass_vector_index();
|
||||
}
|
||||
|
||||
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
|
||||
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
|
||||
self.inner_vec.distance_range(lower_bound, upper_bound);
|
||||
}
|
||||
|
||||
pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
|
||||
Ok(VectorQuery {
|
||||
inner: self.inner_vec.inner.clone(),
|
||||
|
||||
Reference in New Issue
Block a user