mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
test: query with dist range and new rows (#2126)
we found a bug that flat KNN plan node's stats is not in right order as fields in schema, it would cause an error if querying with distance range and new unindexed rows. we've fixed this in lance so add this test for verifying it works Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -232,6 +232,71 @@ async def test_distance_range_async(table_async: AsyncTable):
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distance_range_with_new_rows_async():
|
||||
conn = await lancedb.connect_async(
|
||||
"memory://", read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": pa.FixedShapeTensorArray.from_numpy_ndarray(
|
||||
np.random.rand(256, 2)
|
||||
),
|
||||
}
|
||||
)
|
||||
table = await conn.create_table("test", data)
|
||||
table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
|
||||
|
||||
q = [0, 0]
|
||||
rs = await table.query().nearest_to(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
# append more rows so that execution plan would be mixed with ANN & Flat KNN
|
||||
new_data = pa.table(
|
||||
{
|
||||
"vector": pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(4, 2)),
|
||||
}
|
||||
)
|
||||
await table.add(new_data)
|
||||
|
||||
res = (
|
||||
await table.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 0
|
||||
|
||||
res = (
|
||||
await table.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
for dist in res["_distance"].to_pylist():
|
||||
assert dist >= max_dist
|
||||
|
||||
res = (
|
||||
await table.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
for dist in res["_distance"].to_pylist():
|
||||
assert dist < max_dist
|
||||
|
||||
res = (
|
||||
await table.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
for dist in res["_distance"].to_pylist():
|
||||
assert dist >= min_dist
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user