mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat: support vector search with distance thresholds (#1993)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
|
||||
assert rs["_rowid"].to_pylist() == [0, 1]
|
||||
|
||||
|
||||
def test_distance_range(table: lancedb.table.Table):
|
||||
q = [0, 0]
|
||||
rs = table.search(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
|
||||
assert len(res) == 0
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distance_range_async(table_async: AsyncTable):
|
||||
q = [0, 0]
|
||||
rs = await table_async.query().nearest_to(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 0
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
|
||||
Reference in New Issue
Block a user