@pytest.mark.asyncio
async def test_distance_range_with_new_rows_async():
    """Verify distance-range queries when unindexed rows follow an index build.

    Regression test for a bug where the flat KNN plan node reported its
    stats in a different order than the schema fields, which raised an
    error when a query combined a distance range with newly added,
    unindexed rows. The fix lives in lance; this test only verifies it.

    The test builds an IVF_PQ index over 256 random 2-d vectors, records
    the smallest and largest distance of a plain nearest-neighbour query,
    appends 4 unindexed rows, and then checks that every row returned by a
    ``distance_range`` query respects the requested bound.
    """
    conn = await lancedb.connect_async(
        "memory://", read_consistency_interval=timedelta(seconds=0)
    )
    data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(
                np.random.rand(256, 2)
            ),
        }
    )
    table = await conn.create_table("test", data)
    # The async table's create_index must be awaited; otherwise the
    # coroutine never runs, no index is built, and the query below never
    # exercises the ANN half of the mixed plan.
    await table.create_index(
        "vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
    )

    q = [0, 0]
    rs = await table.query().nearest_to(q).to_arrow()
    dists = rs["_distance"].to_pylist()
    min_dist = dists[0]
    max_dist = dists[-1]

    # append more rows so that execution plan would be mixed with ANN & Flat KNN
    new_data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(4, 2)),
        }
    )
    await table.add(new_data)

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=min_dist)
        .to_arrow()
    )
    # The appended rows are random, so one of them may be closer to the
    # query than min_dist. Check the bound on each returned row instead of
    # requiring an empty result, which would make the test flaky.
    for dist in res["_distance"].to_pylist():
        assert dist < min_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist < max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=min_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= min_dist