fix: .nprobes method in python bindings, improve error messages (#2556)

`nprobes` with a value greater than 20 fails with the minimum error:

```
self = <lancedb.query.AsyncVectorQuery object at 0x10b749720>, minimum_nprobes = 30

    def minimum_nprobes(self, minimum_nprobes: int) -> Self:
        """Set the minimum number of probes to use.

        See `nprobes` for more details.

        These partitions will be searched on every indexed vector query and will
        increase recall at the expense of latency.
        """
>       self._inner.minimum_nprobes(minimum_nprobes)
E       ValueError: Invalid input, minimum_nprobes must be less than or equal to maximum_nprobes

python/lancedb/query.py:2744: ValueError
```

Putting the max set before the min seems reasonable but it causes this
reasonable case to fail:
```
def test_nprobes_min_max_works_sync(table):
    LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(2).maximum_nprobes(4).to_list()
```

with

```
self = <lancedb.query.AsyncVectorQuery object at 0x1203f1c90>, maximum_nprobes = 4

    def maximum_nprobes(self, maximum_nprobes: int) -> Self:
        """Set the maximum number of probes to use.

        See `nprobes` for more details.

        If this value is greater than `minimum_nprobes` then the excess partitions
        will be searched only if we have not found enough results.

        This can be useful when there is a narrow filter to allow these queries to
        spend more time searching and avoid potential false negatives.

        If this value is 0 then no limit will be applied and all partitions could be
        searched if needed to satisfy the limit.
        """
>       self._inner.maximum_nprobes(maximum_nprobes)
E       ValueError: Invalid input, maximum_nprobes must be greater than or equal to minimum_nprobes

python/lancedb/query.py:2761: ValueError
```.

The case I care about is where min == max, but this solution handles it
even if they're not. If both min and max exist, we set both to the
minimum and then set the max. This isn't 100% the same as the minimum
setter checks for 0 on the min and `.nprobes` does not do any sanity
checking at all. But I figured this was the most reasonable and general
solution without touching more of this code.

As part of this I noticed the error messages were a bit ambiguous so I
made them symmetric and clarified them while I was here.
This commit is contained in:
Mark McCaskey
2025-07-31 00:23:25 +08:00
committed by GitHub
parent 67ec1fe75c
commit fe76496a59
4 changed files with 36 additions and 9 deletions

View File

@@ -582,7 +582,7 @@ describe("When creating an index", () => {
"Invalid input, minimum_nprobes must be greater than 0",
);
expect(() => tbl.query().nearestTo(queryVec).maximumNprobes(5)).toThrow(
"Invalid input, maximum_nprobes must be greater than minimum_nprobes",
"Invalid input, maximum_nprobes must be greater than or equal to minimum_nprobes",
);
await tbl.dropIndex("vec_idx");

View File

@@ -3673,9 +3673,14 @@ class AsyncTable:
)
if query.distance_type is not None:
async_query = async_query.distance_type(query.distance_type)
if query.minimum_nprobes is not None:
if query.minimum_nprobes is not None and query.maximum_nprobes is not None:
# Set both to the minimum first to avoid min > max error.
async_query = async_query.nprobes(
query.minimum_nprobes
).maximum_nprobes(query.maximum_nprobes)
elif query.minimum_nprobes is not None:
async_query = async_query.minimum_nprobes(query.minimum_nprobes)
if query.maximum_nprobes is not None:
elif query.maximum_nprobes is not None:
async_query = async_query.maximum_nprobes(query.maximum_nprobes)
if query.refine_factor is not None:
async_query = async_query.refine_factor(query.refine_factor)

View File

@@ -445,25 +445,45 @@ def test_invalid_nprobes_sync(table):
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list()
with pytest.raises(
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
ValueError,
match="maximum_nprobes must be greater than or equal to minimum_nprobes",
):
LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list()
with pytest.raises(
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
ValueError,
match="minimum_nprobes must be less than or equal to maximum_nprobes",
):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list()
def test_nprobes_works_sync(table):
LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).to_list()
def test_nprobes_min_max_works_sync(table):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(2).maximum_nprobes(
4
).to_list()
def test_multiple_nprobes_calls_works_sync(table):
LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).maximum_nprobes(
20
).minimum_nprobes(20).to_list()
@pytest.mark.asyncio
async def test_invalid_nprobes_async(table_async: AsyncTable):
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
await table_async.vector_search([0, 0]).minimum_nprobes(0).to_list()
with pytest.raises(
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
ValueError,
match="maximum_nprobes must be greater than or equal to minimum_nprobes",
):
await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list()
with pytest.raises(
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
ValueError,
match="minimum_nprobes must be less than or equal to maximum_nprobes",
):
await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list()

View File

@@ -958,7 +958,8 @@ impl VectorQuery {
if let Some(maximum_nprobes) = self.request.maximum_nprobes {
if minimum_nprobes > maximum_nprobes {
return Err(Error::InvalidInput {
message: "minimum_nprobes must be less or equal to maximum_nprobes".to_string(),
message: "minimum_nprobes must be less than or equal to maximum_nprobes"
.to_string(),
});
}
}
@@ -989,7 +990,8 @@ impl VectorQuery {
}
if maximum_nprobes < self.request.minimum_nprobes {
return Err(Error::InvalidInput {
message: "maximum_nprobes must be greater than minimum_nprobes".to_string(),
message: "maximum_nprobes must be greater than or equal to minimum_nprobes"
.to_string(),
});
}
}