mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
fix: .nprobes method in python bindings, improve error messages (#2556)
`nprobes` with a value greater than 20 fails with the minimum error:
```
self = <lancedb.query.AsyncVectorQuery object at 0x10b749720>, minimum_nprobes = 30
def minimum_nprobes(self, minimum_nprobes: int) -> Self:
"""Set the minimum number of probes to use.
See `nprobes` for more details.
These partitions will be searched on every indexed vector query and will
increase recall at the expense of latency.
"""
> self._inner.minimum_nprobes(minimum_nprobes)
E ValueError: Invalid input, minimum_nprobes must be less than or equal to maximum_nprobes
python/lancedb/query.py:2744: ValueError
```
Putting the max set before the min seems reasonable but it causes this
reasonable case to fail:
```
def test_nprobes_min_max_works_sync(table):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(2).maximum_nprobes(4).to_list()
```
with
```
self = <lancedb.query.AsyncVectorQuery object at 0x1203f1c90>, maximum_nprobes = 4
def maximum_nprobes(self, maximum_nprobes: int) -> Self:
"""Set the maximum number of probes to use.
See `nprobes` for more details.
If this value is greater than `minimum_nprobes` then the excess partitions
will be searched only if we have not found enough results.
This can be useful when there is a narrow filter to allow these queries to
spend more time searching and avoid potential false negatives.
If this value is 0 then no limit will be applied and all partitions could be
searched if needed to satisfy the limit.
"""
> self._inner.maximum_nprobes(maximum_nprobes)
E ValueError: Invalid input, maximum_nprobes must be greater than or equal to minimum_nprobes
python/lancedb/query.py:2761: ValueError
```.
The case I care about is where min == max, but this solution handles it
even if they're not. If both min and max exist, we set both to the
minimum and then set the max. This isn't 100% the same as the minimum
setter checks for 0 on the min and `.nprobes` does not do any sanity
checking at all. But I figured this was the most reasonable and general
solution without touching more of this code.
As part of this I noticed the error messages were a bit ambiguous so I
made them symmetric and clarified them while I was here.
This commit is contained in:
@@ -582,7 +582,7 @@ describe("When creating an index", () => {
|
||||
"Invalid input, minimum_nprobes must be greater than 0",
|
||||
);
|
||||
expect(() => tbl.query().nearestTo(queryVec).maximumNprobes(5)).toThrow(
|
||||
"Invalid input, maximum_nprobes must be greater than minimum_nprobes",
|
||||
"Invalid input, maximum_nprobes must be greater than or equal to minimum_nprobes",
|
||||
);
|
||||
|
||||
await tbl.dropIndex("vec_idx");
|
||||
|
||||
@@ -3673,9 +3673,14 @@ class AsyncTable:
|
||||
)
|
||||
if query.distance_type is not None:
|
||||
async_query = async_query.distance_type(query.distance_type)
|
||||
if query.minimum_nprobes is not None:
|
||||
if query.minimum_nprobes is not None and query.maximum_nprobes is not None:
|
||||
# Set both to the minimum first to avoid min > max error.
|
||||
async_query = async_query.nprobes(
|
||||
query.minimum_nprobes
|
||||
).maximum_nprobes(query.maximum_nprobes)
|
||||
elif query.minimum_nprobes is not None:
|
||||
async_query = async_query.minimum_nprobes(query.minimum_nprobes)
|
||||
if query.maximum_nprobes is not None:
|
||||
elif query.maximum_nprobes is not None:
|
||||
async_query = async_query.maximum_nprobes(query.maximum_nprobes)
|
||||
if query.refine_factor is not None:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
|
||||
@@ -445,25 +445,45 @@ def test_invalid_nprobes_sync(table):
|
||||
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
|
||||
ValueError,
|
||||
match="maximum_nprobes must be greater than or equal to minimum_nprobes",
|
||||
):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
|
||||
ValueError,
|
||||
match="minimum_nprobes must be less than or equal to maximum_nprobes",
|
||||
):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list()
|
||||
|
||||
|
||||
def test_nprobes_works_sync(table):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).to_list()
|
||||
|
||||
|
||||
def test_nprobes_min_max_works_sync(table):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(2).maximum_nprobes(
|
||||
4
|
||||
).to_list()
|
||||
|
||||
|
||||
def test_multiple_nprobes_calls_works_sync(table):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).maximum_nprobes(
|
||||
20
|
||||
).minimum_nprobes(20).to_list()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_nprobes_async(table_async: AsyncTable):
|
||||
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
|
||||
await table_async.vector_search([0, 0]).minimum_nprobes(0).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
|
||||
ValueError,
|
||||
match="maximum_nprobes must be greater than or equal to minimum_nprobes",
|
||||
):
|
||||
await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
|
||||
ValueError,
|
||||
match="minimum_nprobes must be less than or equal to maximum_nprobes",
|
||||
):
|
||||
await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list()
|
||||
|
||||
|
||||
@@ -958,7 +958,8 @@ impl VectorQuery {
|
||||
if let Some(maximum_nprobes) = self.request.maximum_nprobes {
|
||||
if minimum_nprobes > maximum_nprobes {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "minimum_nprobes must be less or equal to maximum_nprobes".to_string(),
|
||||
message: "minimum_nprobes must be less than or equal to maximum_nprobes"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -989,7 +990,8 @@ impl VectorQuery {
|
||||
}
|
||||
if maximum_nprobes < self.request.minimum_nprobes {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "maximum_nprobes must be greater than minimum_nprobes".to_string(),
|
||||
message: "maximum_nprobes must be greater than or equal to minimum_nprobes"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user