feat: add maximum and minimum nprobes properties (#2430)

This exposes the maximum_nprobes and minimum_nprobes feature that was
added in https://github.com/lancedb/lance/pull/3903

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for specifying minimum and maximum probe counts in
vector search queries, allowing finer control over search behavior.
- Users can now independently set minimum and maximum probes for vector
and hybrid queries via new methods and parameters in Python, Node.js,
and Rust APIs.

- **Bug Fixes**
- Improved parameter validation to ensure correct usage of minimum and
maximum probe values.

- **Tests**
- Expanded test coverage to validate correct handling, serialization,
and error cases for the new probe parameters.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Weston Pace
2025-06-13 15:18:29 -07:00
committed by GitHub
parent fec8d58f06
commit 59b57e30ed
12 changed files with 505 additions and 32 deletions

View File

@@ -143,6 +143,8 @@ class VectorQuery:
def postfilter(self): ...
def refine_factor(self, refine_factor: int): ...
def nprobes(self, nprobes: int): ...
def minimum_nprobes(self, minimum_nprobes: int): ...
def maximum_nprobes(self, maximum_nprobes: int): ...
def bypass_vector_index(self): ...
def nearest_to_text(self, query: dict) -> HybridQuery: ...
def to_query_request(self) -> PyQueryRequest: ...
@@ -158,6 +160,8 @@ class HybridQuery:
def distance_type(self, distance_type: str): ...
def refine_factor(self, refine_factor: int): ...
def nprobes(self, nprobes: int): ...
def minimum_nprobes(self, minimum_nprobes: int): ...
def maximum_nprobes(self, maximum_nprobes: int): ...
def bypass_vector_index(self): ...
def to_vector_query(self) -> VectorQuery: ...
def to_fts_query(self) -> FTSQuery: ...
@@ -178,7 +182,8 @@ class PyQueryRequest:
with_row_id: Optional[bool]
column: Optional[str]
query_vector: Optional[List[pa.Array]]
nprobes: Optional[int]
minimum_nprobes: Optional[int]
maximum_nprobes: Optional[int]
lower_bound: Optional[float]
upper_bound: Optional[float]
ef: Optional[int]

View File

@@ -437,8 +437,18 @@ class Query(pydantic.BaseModel):
# which columns to return in the results
columns: Optional[Union[List[str], Dict[str, str]]] = None
# number of IVF partitions to search
nprobes: Optional[int] = None
# minimum number of IVF partitions to search
#
# If None then a default value (20) will be used.
minimum_nprobes: Optional[int] = None
# maximum number of IVF partitions to search
#
# If None then a default value (20) will be used.
#
# If 0 then no limit will be applied and all partitions could be searched
# if needed to satisfy the limit.
maximum_nprobes: Optional[int] = None
# lower bound for distance search
lower_bound: Optional[float] = None
@@ -476,7 +486,8 @@ class Query(pydantic.BaseModel):
query.vector_column = req.column
query.vector = req.query_vector
query.distance_type = req.distance_type
query.nprobes = req.nprobes
query.minimum_nprobes = req.minimum_nprobes
query.maximum_nprobes = req.maximum_nprobes
query.lower_bound = req.lower_bound
query.upper_bound = req.upper_bound
query.ef = req.ef
@@ -1037,7 +1048,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
super().__init__(table)
self._query = query
self._distance_type = None
self._nprobes = None
self._minimum_nprobes = None
self._maximum_nprobes = None
self._lower_bound = None
self._upper_bound = None
self._refine_factor = None
@@ -1100,6 +1112,10 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
This method sets both the minimum and maximum number of probes to the same
value. See `minimum_nprobes` and `maximum_nprobes` for more fine-grained
control.
Parameters
----------
nprobes: int
@@ -1110,7 +1126,36 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._nprobes = nprobes
self._minimum_nprobes = nprobes
self._maximum_nprobes = nprobes
return self
def minimum_nprobes(self, minimum_nprobes: int) -> LanceVectorQueryBuilder:
"""Set the minimum number of probes to use.
See `nprobes` for more details.
These partitions will be searched on every vector query and will increase recall
at the expense of latency.
"""
self._minimum_nprobes = minimum_nprobes
return self
def maximum_nprobes(self, maximum_nprobes: int) -> LanceVectorQueryBuilder:
"""Set the maximum number of probes to use.
See `nprobes` for more details.
If this value is greater than `minimum_nprobes` then the excess partitions
will be searched only if we have not found enough results.
This can be useful when there is a narrow filter to allow these queries to
spend more time searching and avoid potential false negatives.
If this value is 0 then no limit will be applied and all partitions could be
searched if needed to satisfy the limit.
"""
self._maximum_nprobes = maximum_nprobes
return self
def distance_range(
@@ -1214,7 +1259,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
limit=self._limit,
distance_type=self._distance_type,
columns=self._columns,
nprobes=self._nprobes,
minimum_nprobes=self._minimum_nprobes,
maximum_nprobes=self._maximum_nprobes,
lower_bound=self._lower_bound,
upper_bound=self._upper_bound,
refine_factor=self._refine_factor,
@@ -1578,7 +1624,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_columns = fts_columns
self._norm = None
self._reranker = None
self._nprobes = None
self._minimum_nprobes = None
self._maximum_nprobes = None
self._refine_factor = None
self._distance_type = None
self._phrase_query = None
@@ -1810,7 +1857,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._nprobes = nprobes
self._minimum_nprobes = nprobes
self._maximum_nprobes = nprobes
return self
def minimum_nprobes(self, minimum_nprobes: int) -> LanceHybridQueryBuilder:
"""Set the minimum number of probes to use.
See `nprobes` for more details.
"""
self._minimum_nprobes = minimum_nprobes
return self
def maximum_nprobes(self, maximum_nprobes: int) -> LanceHybridQueryBuilder:
"""Set the maximum number of probes to use.
See `nprobes` for more details.
"""
self._maximum_nprobes = maximum_nprobes
return self
def distance_range(
@@ -2039,8 +2103,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_query.phrase_query(True)
if self._distance_type:
self._vector_query.metric(self._distance_type)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._minimum_nprobes:
self._vector_query.minimum_nprobes(self._minimum_nprobes)
if self._maximum_nprobes is not None:
self._vector_query.maximum_nprobes(self._maximum_nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
@@ -2651,6 +2717,34 @@ class AsyncVectorQueryBase:
self._inner.nprobes(nprobes)
return self
def minimum_nprobes(self, minimum_nprobes: int) -> Self:
"""Set the minimum number of probes to use.
See `nprobes` for more details.
These partitions will be searched on every indexed vector query and will
increase recall at the expense of latency.
"""
self._inner.minimum_nprobes(minimum_nprobes)
return self
def maximum_nprobes(self, maximum_nprobes: int) -> Self:
"""Set the maximum number of probes to use.
See `nprobes` for more details.
If this value is greater than `minimum_nprobes` then the excess partitions
will be searched only if we have not found enough results.
This can be useful when there is a narrow filter to allow these queries to
spend more time searching and avoid potential false negatives.
If this value is 0 then no limit will be applied and all partitions could be
searched if needed to satisfy the limit.
"""
self._inner.maximum_nprobes(maximum_nprobes)
return self
def distance_range(
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
) -> Self:

View File

@@ -3637,8 +3637,10 @@ class AsyncTable:
)
if query.distance_type is not None:
async_query = async_query.distance_type(query.distance_type)
if query.nprobes is not None:
async_query = async_query.nprobes(query.nprobes)
if query.minimum_nprobes is not None:
async_query = async_query.minimum_nprobes(query.minimum_nprobes)
if query.maximum_nprobes is not None:
async_query = async_query.maximum_nprobes(query.maximum_nprobes)
if query.refine_factor is not None:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:

View File

@@ -439,6 +439,33 @@ def test_query_builder_with_filter(table):
assert all(np.array(rs[0]["vector"]) == [3, 4])
def test_invalid_nprobes_sync(table):
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list()
with pytest.raises(
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
):
LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list()
with pytest.raises(
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
):
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list()
@pytest.mark.asyncio
async def test_invalid_nprobes_async(table_async: AsyncTable):
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
await table_async.vector_search([0, 0]).minimum_nprobes(0).to_list()
with pytest.raises(
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
):
await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list()
with pytest.raises(
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
):
await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list()
def test_query_builder_with_prefilter(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -585,6 +612,21 @@ async def test_query_async(table_async: AsyncTable):
table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).minimum_nprobes(10),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).maximum_nprobes(30),
expected_num_rows=2,
)
await check_query(
table_async.query()
.nearest_to(pa.array([1, 2]))
.minimum_nprobes(10)
.maximum_nprobes(20),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
expected_num_rows=2,
@@ -911,7 +953,39 @@ def test_query_serialization_sync(table: lancedb.table.Table):
q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object()
check_set_props(
q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5
q,
vector_column="vector",
vector=[5.0, 6.0],
minimum_nprobes=10,
maximum_nprobes=10,
refine_factor=5,
)
q = table.search([5.0, 6.0]).minimum_nprobes(10).to_query_object()
check_set_props(
q,
vector_column="vector",
vector=[5.0, 6.0],
minimum_nprobes=10,
maximum_nprobes=None,
)
q = table.search([5.0, 6.0]).nprobes(50).to_query_object()
check_set_props(
q,
vector_column="vector",
vector=[5.0, 6.0],
minimum_nprobes=50,
maximum_nprobes=50,
)
q = table.search([5.0, 6.0]).maximum_nprobes(10).to_query_object()
check_set_props(
q,
vector_column="vector",
vector=[5.0, 6.0],
maximum_nprobes=10,
minimum_nprobes=None,
)
q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object()
@@ -963,7 +1037,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
limit=10,
vector=sample_vector,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
)
@@ -973,7 +1048,20 @@ async def test_query_serialization_async(table_async: AsyncTable):
q,
vector=sample_vector,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
limit=10,
)
q = (await table_async.search([5.0, 6.0])).nprobes(50).to_query_object()
check_set_props(
q,
vector=sample_vector,
postfilter=False,
minimum_nprobes=50,
maximum_nprobes=50,
with_row_id=False,
bypass_vector_index=False,
limit=10,
@@ -992,7 +1080,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
filter="id = 1",
postfilter=True,
vector=sample_vector,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
)
@@ -1006,7 +1095,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
check_set_props(
q,
vector=sample_vector,
nprobes=10,
minimum_nprobes=10,
maximum_nprobes=10,
refine_factor=5,
postfilter=False,
with_row_id=False,
@@ -1014,6 +1104,18 @@ async def test_query_serialization_async(table_async: AsyncTable):
limit=10,
)
q = (await table_async.search([5.0, 6.0])).minimum_nprobes(5).to_query_object()
check_set_props(
q,
vector=sample_vector,
minimum_nprobes=5,
maximum_nprobes=20,
postfilter=False,
with_row_id=False,
bypass_vector_index=False,
limit=10,
)
q = (
(await table_async.search([5.0, 6.0]))
.distance_range(0.0, 1.0)
@@ -1025,7 +1127,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
lower_bound=0.0,
upper_bound=1.0,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
limit=10,
@@ -1037,7 +1140,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
distance_type="cosine",
vector=sample_vector,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
limit=10,
@@ -1049,7 +1153,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
ef=7,
vector=sample_vector,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
bypass_vector_index=False,
limit=10,
@@ -1061,7 +1166,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
bypass_vector_index=True,
vector=sample_vector,
postfilter=False,
nprobes=20,
minimum_nprobes=20,
maximum_nprobes=20,
with_row_id=False,
limit=10,
)

View File

@@ -496,6 +496,8 @@ def test_query_sync_minimal():
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
"minimum_nprobes": 20,
"maximum_nprobes": 20,
"version": None,
}
@@ -536,6 +538,8 @@ def test_query_sync_maximal():
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"minimum_nprobes": 5,
"maximum_nprobes": 5,
"lower_bound": None,
"upper_bound": None,
"ef": None,
@@ -564,6 +568,66 @@ def test_query_sync_maximal():
)
def test_query_sync_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
"vector_column": "vector2",
"refine_factor": None,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"minimum_nprobes": 5,
"maximum_nprobes": 15,
"version": None,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.minimum_nprobes(5)
.maximum_nprobes(15)
.to_list()
)
def test_query_sync_no_max_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
"vector_column": "vector2",
"refine_factor": None,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"minimum_nprobes": 5,
"maximum_nprobes": 0,
"version": None,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.minimum_nprobes(5)
.maximum_nprobes(0)
.to_list()
)
@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def test_query_sync_batch_queries(server_version):
def handler(body):
@@ -666,6 +730,8 @@ def test_query_sync_hybrid():
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"minimum_nprobes": 20,
"maximum_nprobes": 20,
"lower_bound": None,
"upper_bound": None,
"ef": None,

View File

@@ -235,7 +235,10 @@ pub struct PyQueryRequest {
pub with_row_id: Option<bool>,
pub column: Option<String>,
pub query_vector: Option<PyQueryVectors>,
pub nprobes: Option<usize>,
pub minimum_nprobes: Option<usize>,
// None means user did not set it and default shoud be used (currenty 20)
// Some(0) means user set it to None and there is no limit
pub maximum_nprobes: Option<usize>,
pub lower_bound: Option<f32>,
pub upper_bound: Option<f32>,
pub ef: Option<usize>,
@@ -261,7 +264,8 @@ impl From<AnyQuery> for PyQueryRequest {
with_row_id: Some(query_request.with_row_id),
column: None,
query_vector: None,
nprobes: None,
minimum_nprobes: None,
maximum_nprobes: None,
lower_bound: None,
upper_bound: None,
ef: None,
@@ -281,7 +285,11 @@ impl From<AnyQuery> for PyQueryRequest {
with_row_id: Some(vector_query.base.with_row_id),
column: vector_query.column,
query_vector: Some(PyQueryVectors(vector_query.query_vector)),
nprobes: Some(vector_query.nprobes),
minimum_nprobes: Some(vector_query.minimum_nprobes),
maximum_nprobes: match vector_query.maximum_nprobes {
None => Some(0),
Some(value) => Some(value),
},
lower_bound: vector_query.lower_bound,
upper_bound: vector_query.upper_bound,
ef: vector_query.ef,
@@ -655,6 +663,29 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
pub fn minimum_nprobes(&mut self, minimum_nprobes: u32) -> PyResult<()> {
self.inner = self
.inner
.clone()
.minimum_nprobes(minimum_nprobes as usize)
.infer_error()?;
Ok(())
}
pub fn maximum_nprobes(&mut self, maximum_nprobes: u32) -> PyResult<()> {
let maximum_nprobes = if maximum_nprobes == 0 {
None
} else {
Some(maximum_nprobes as usize)
};
self.inner = self
.inner
.clone()
.maximum_nprobes(maximum_nprobes)
.infer_error()?;
Ok(())
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);