mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
feat: support vector search with distance thresholds (#1993)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
lower_bound: Optional[float] = None
|
||||
upper_bound: Optional[float] = None
|
||||
|
||||
# Refine factor.
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._lower_bound = lower_bound
|
||||
self._upper_bound = upper_bound
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of candidates to consider during search.
|
||||
|
||||
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
metric=self._metric,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
lower_bound=self._lower_bound,
|
||||
upper_bound=self._upper_bound,
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._lower_bound = lower_bound
|
||||
self._upper_bound = upper_bound
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of candidates to consider during search.
|
||||
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> AsyncVectorQuery:
|
||||
"""Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AsyncVectorQuery
|
||||
The AsyncVectorQuery object.
|
||||
"""
|
||||
self._inner.distance_range(lower_bound, upper_bound)
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
Set the number of candidates to consider during search
|
||||
|
||||
@@ -2786,6 +2786,7 @@ class AsyncTable:
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
.nprobes(query.nprobes)
|
||||
.distance_range(query.lower_bound, query.upper_bound)
|
||||
)
|
||||
if query.refine_factor:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
|
||||
@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
|
||||
assert rs["_rowid"].to_pylist() == [0, 1]
|
||||
|
||||
|
||||
def test_distance_range(table: lancedb.table.Table):
|
||||
q = [0, 0]
|
||||
rs = table.search(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
|
||||
assert len(res) == 0
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distance_range_async(table_async: AsyncTable):
|
||||
q = [0, 0]
|
||||
rs = await table_async.query().nearest_to(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 0
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
|
||||
@@ -306,6 +306,8 @@ def test_query_sync_minimal():
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"with_row_id": True,
|
||||
"version": None,
|
||||
|
||||
@@ -284,6 +284,11 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
|
||||
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
|
||||
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
|
||||
}
|
||||
|
||||
pub fn ef(&mut self, ef: u32) {
|
||||
self.inner = self.inner.clone().ef(ef as usize);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user