mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
feat: add maximum and minimum nprobes properties (#2430)
This exposes the maximum_nprobes and minimum_nprobes feature that was added in https://github.com/lancedb/lance/pull/3903 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for specifying minimum and maximum probe counts in vector search queries, allowing finer control over search behavior. - Users can now independently set minimum and maximum probes for vector and hybrid queries via new methods and parameters in Python, Node.js, and Rust APIs. - **Bug Fixes** - Improved parameter validation to ensure correct usage of minimum and maximum probe values. - **Tests** - Expanded test coverage to validate correct handling, serialization, and error cases for the new probe parameters. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -143,6 +143,8 @@ class VectorQuery:
|
||||
def postfilter(self): ...
|
||||
def refine_factor(self, refine_factor: int): ...
|
||||
def nprobes(self, nprobes: int): ...
|
||||
def minimum_nprobes(self, minimum_nprobes: int): ...
|
||||
def maximum_nprobes(self, maximum_nprobes: int): ...
|
||||
def bypass_vector_index(self): ...
|
||||
def nearest_to_text(self, query: dict) -> HybridQuery: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
@@ -158,6 +160,8 @@ class HybridQuery:
|
||||
def distance_type(self, distance_type: str): ...
|
||||
def refine_factor(self, refine_factor: int): ...
|
||||
def nprobes(self, nprobes: int): ...
|
||||
def minimum_nprobes(self, minimum_nprobes: int): ...
|
||||
def maximum_nprobes(self, maximum_nprobes: int): ...
|
||||
def bypass_vector_index(self): ...
|
||||
def to_vector_query(self) -> VectorQuery: ...
|
||||
def to_fts_query(self) -> FTSQuery: ...
|
||||
@@ -178,7 +182,8 @@ class PyQueryRequest:
|
||||
with_row_id: Optional[bool]
|
||||
column: Optional[str]
|
||||
query_vector: Optional[List[pa.Array]]
|
||||
nprobes: Optional[int]
|
||||
minimum_nprobes: Optional[int]
|
||||
maximum_nprobes: Optional[int]
|
||||
lower_bound: Optional[float]
|
||||
upper_bound: Optional[float]
|
||||
ef: Optional[int]
|
||||
|
||||
@@ -437,8 +437,18 @@ class Query(pydantic.BaseModel):
|
||||
# which columns to return in the results
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
|
||||
# number of IVF partitions to search
|
||||
nprobes: Optional[int] = None
|
||||
# minimum number of IVF partitions to search
|
||||
#
|
||||
# If None then a default value (20) will be used.
|
||||
minimum_nprobes: Optional[int] = None
|
||||
|
||||
# maximum number of IVF partitions to search
|
||||
#
|
||||
# If None then a default value (20) will be used.
|
||||
#
|
||||
# If 0 then no limit will be applied and all partitions could be searched
|
||||
# if needed to satisfy the limit.
|
||||
maximum_nprobes: Optional[int] = None
|
||||
|
||||
# lower bound for distance search
|
||||
lower_bound: Optional[float] = None
|
||||
@@ -476,7 +486,8 @@ class Query(pydantic.BaseModel):
|
||||
query.vector_column = req.column
|
||||
query.vector = req.query_vector
|
||||
query.distance_type = req.distance_type
|
||||
query.nprobes = req.nprobes
|
||||
query.minimum_nprobes = req.minimum_nprobes
|
||||
query.maximum_nprobes = req.maximum_nprobes
|
||||
query.lower_bound = req.lower_bound
|
||||
query.upper_bound = req.upper_bound
|
||||
query.ef = req.ef
|
||||
@@ -1037,7 +1048,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._distance_type = None
|
||||
self._nprobes = None
|
||||
self._minimum_nprobes = None
|
||||
self._maximum_nprobes = None
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
self._refine_factor = None
|
||||
@@ -1100,6 +1112,10 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
|
||||
This method sets both the minimum and maximum number of probes to the same
|
||||
value. See `minimum_nprobes` and `maximum_nprobes` for more fine-grained
|
||||
control.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nprobes: int
|
||||
@@ -1110,7 +1126,36 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
self._minimum_nprobes = nprobes
|
||||
self._maximum_nprobes = nprobes
|
||||
return self
|
||||
|
||||
def minimum_nprobes(self, minimum_nprobes: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the minimum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
|
||||
These partitions will be searched on every vector query and will increase recall
|
||||
at the expense of latency.
|
||||
"""
|
||||
self._minimum_nprobes = minimum_nprobes
|
||||
return self
|
||||
|
||||
def maximum_nprobes(self, maximum_nprobes: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the maximum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
|
||||
If this value is greater than `minimum_nprobes` then the excess partitions
|
||||
will be searched only if we have not found enough results.
|
||||
|
||||
This can be useful when there is a narrow filter to allow these queries to
|
||||
spend more time searching and avoid potential false negatives.
|
||||
|
||||
If this value is 0 then no limit will be applied and all partitions could be
|
||||
searched if needed to satisfy the limit.
|
||||
"""
|
||||
self._maximum_nprobes = maximum_nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
@@ -1214,7 +1259,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
limit=self._limit,
|
||||
distance_type=self._distance_type,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
minimum_nprobes=self._minimum_nprobes,
|
||||
maximum_nprobes=self._maximum_nprobes,
|
||||
lower_bound=self._lower_bound,
|
||||
upper_bound=self._upper_bound,
|
||||
refine_factor=self._refine_factor,
|
||||
@@ -1578,7 +1624,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_columns = fts_columns
|
||||
self._norm = None
|
||||
self._reranker = None
|
||||
self._nprobes = None
|
||||
self._minimum_nprobes = None
|
||||
self._maximum_nprobes = None
|
||||
self._refine_factor = None
|
||||
self._distance_type = None
|
||||
self._phrase_query = None
|
||||
@@ -1810,7 +1857,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
self._minimum_nprobes = nprobes
|
||||
self._maximum_nprobes = nprobes
|
||||
return self
|
||||
|
||||
def minimum_nprobes(self, minimum_nprobes: int) -> LanceHybridQueryBuilder:
|
||||
"""Set the minimum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
"""
|
||||
self._minimum_nprobes = minimum_nprobes
|
||||
return self
|
||||
|
||||
def maximum_nprobes(self, maximum_nprobes: int) -> LanceHybridQueryBuilder:
|
||||
"""Set the maximum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
"""
|
||||
self._maximum_nprobes = maximum_nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
@@ -2039,8 +2103,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_query.phrase_query(True)
|
||||
if self._distance_type:
|
||||
self._vector_query.metric(self._distance_type)
|
||||
if self._nprobes:
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._minimum_nprobes:
|
||||
self._vector_query.minimum_nprobes(self._minimum_nprobes)
|
||||
if self._maximum_nprobes is not None:
|
||||
self._vector_query.maximum_nprobes(self._maximum_nprobes)
|
||||
if self._refine_factor:
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
if self._ef:
|
||||
@@ -2651,6 +2717,34 @@ class AsyncVectorQueryBase:
|
||||
self._inner.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def minimum_nprobes(self, minimum_nprobes: int) -> Self:
|
||||
"""Set the minimum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
|
||||
These partitions will be searched on every indexed vector query and will
|
||||
increase recall at the expense of latency.
|
||||
"""
|
||||
self._inner.minimum_nprobes(minimum_nprobes)
|
||||
return self
|
||||
|
||||
def maximum_nprobes(self, maximum_nprobes: int) -> Self:
|
||||
"""Set the maximum number of probes to use.
|
||||
|
||||
See `nprobes` for more details.
|
||||
|
||||
If this value is greater than `minimum_nprobes` then the excess partitions
|
||||
will be searched only if we have not found enough results.
|
||||
|
||||
This can be useful when there is a narrow filter to allow these queries to
|
||||
spend more time searching and avoid potential false negatives.
|
||||
|
||||
If this value is 0 then no limit will be applied and all partitions could be
|
||||
searched if needed to satisfy the limit.
|
||||
"""
|
||||
self._inner.maximum_nprobes(maximum_nprobes)
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> Self:
|
||||
|
||||
@@ -3637,8 +3637,10 @@ class AsyncTable:
|
||||
)
|
||||
if query.distance_type is not None:
|
||||
async_query = async_query.distance_type(query.distance_type)
|
||||
if query.nprobes is not None:
|
||||
async_query = async_query.nprobes(query.nprobes)
|
||||
if query.minimum_nprobes is not None:
|
||||
async_query = async_query.minimum_nprobes(query.minimum_nprobes)
|
||||
if query.maximum_nprobes is not None:
|
||||
async_query = async_query.maximum_nprobes(query.maximum_nprobes)
|
||||
if query.refine_factor is not None:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
|
||||
@@ -439,6 +439,33 @@ def test_query_builder_with_filter(table):
|
||||
assert all(np.array(rs[0]["vector"]) == [3, 4])
|
||||
|
||||
|
||||
def test_invalid_nprobes_sync(table):
|
||||
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
|
||||
):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
|
||||
):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_nprobes_async(table_async: AsyncTable):
|
||||
with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"):
|
||||
await table_async.vector_search([0, 0]).minimum_nprobes(0).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="maximum_nprobes must be greater than minimum_nprobes"
|
||||
):
|
||||
await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list()
|
||||
with pytest.raises(
|
||||
ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes"
|
||||
):
|
||||
await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list()
|
||||
|
||||
|
||||
def test_query_builder_with_prefilter(table):
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
@@ -585,6 +612,21 @@ async def test_query_async(table_async: AsyncTable):
|
||||
table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
|
||||
expected_num_rows=2,
|
||||
)
|
||||
await check_query(
|
||||
table_async.query().nearest_to(pa.array([1, 2])).minimum_nprobes(10),
|
||||
expected_num_rows=2,
|
||||
)
|
||||
await check_query(
|
||||
table_async.query().nearest_to(pa.array([1, 2])).maximum_nprobes(30),
|
||||
expected_num_rows=2,
|
||||
)
|
||||
await check_query(
|
||||
table_async.query()
|
||||
.nearest_to(pa.array([1, 2]))
|
||||
.minimum_nprobes(10)
|
||||
.maximum_nprobes(20),
|
||||
expected_num_rows=2,
|
||||
)
|
||||
await check_query(
|
||||
table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
|
||||
expected_num_rows=2,
|
||||
@@ -911,7 +953,39 @@ def test_query_serialization_sync(table: lancedb.table.Table):
|
||||
|
||||
q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object()
|
||||
check_set_props(
|
||||
q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5
|
||||
q,
|
||||
vector_column="vector",
|
||||
vector=[5.0, 6.0],
|
||||
minimum_nprobes=10,
|
||||
maximum_nprobes=10,
|
||||
refine_factor=5,
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).minimum_nprobes(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector_column="vector",
|
||||
vector=[5.0, 6.0],
|
||||
minimum_nprobes=10,
|
||||
maximum_nprobes=None,
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).nprobes(50).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector_column="vector",
|
||||
vector=[5.0, 6.0],
|
||||
minimum_nprobes=50,
|
||||
maximum_nprobes=50,
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).maximum_nprobes(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector_column="vector",
|
||||
vector=[5.0, 6.0],
|
||||
maximum_nprobes=10,
|
||||
minimum_nprobes=None,
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object()
|
||||
@@ -963,7 +1037,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
limit=10,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
@@ -973,7 +1048,20 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
q,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).nprobes(50).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
minimum_nprobes=50,
|
||||
maximum_nprobes=50,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
@@ -992,7 +1080,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
filter="id = 1",
|
||||
postfilter=True,
|
||||
vector=sample_vector,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
@@ -1006,7 +1095,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
nprobes=10,
|
||||
minimum_nprobes=10,
|
||||
maximum_nprobes=10,
|
||||
refine_factor=5,
|
||||
postfilter=False,
|
||||
with_row_id=False,
|
||||
@@ -1014,6 +1104,18 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).minimum_nprobes(5).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
minimum_nprobes=5,
|
||||
maximum_nprobes=20,
|
||||
postfilter=False,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.distance_range(0.0, 1.0)
|
||||
@@ -1025,7 +1127,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
lower_bound=0.0,
|
||||
upper_bound=1.0,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
@@ -1037,7 +1140,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
distance_type="cosine",
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
@@ -1049,7 +1153,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
ef=7,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
@@ -1061,7 +1166,8 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
bypass_vector_index=True,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
minimum_nprobes=20,
|
||||
maximum_nprobes=20,
|
||||
with_row_id=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
@@ -496,6 +496,8 @@ def test_query_sync_minimal():
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
"minimum_nprobes": 20,
|
||||
"maximum_nprobes": 20,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
@@ -536,6 +538,8 @@ def test_query_sync_maximal():
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"minimum_nprobes": 5,
|
||||
"maximum_nprobes": 5,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
@@ -564,6 +568,66 @@ def test_query_sync_maximal():
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_nprobes():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": True,
|
||||
"fast_search": True,
|
||||
"vector_column": "vector2",
|
||||
"refine_factor": None,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"minimum_nprobes": 5,
|
||||
"maximum_nprobes": 15,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||
.minimum_nprobes(5)
|
||||
.maximum_nprobes(15)
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_no_max_nprobes():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": True,
|
||||
"fast_search": True,
|
||||
"vector_column": "vector2",
|
||||
"refine_factor": None,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"minimum_nprobes": 5,
|
||||
"maximum_nprobes": 0,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||
.minimum_nprobes(5)
|
||||
.maximum_nprobes(0)
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
|
||||
def test_query_sync_batch_queries(server_version):
|
||||
def handler(body):
|
||||
@@ -666,6 +730,8 @@ def test_query_sync_hybrid():
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"minimum_nprobes": 20,
|
||||
"maximum_nprobes": 20,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
|
||||
@@ -235,7 +235,10 @@ pub struct PyQueryRequest {
|
||||
pub with_row_id: Option<bool>,
|
||||
pub column: Option<String>,
|
||||
pub query_vector: Option<PyQueryVectors>,
|
||||
pub nprobes: Option<usize>,
|
||||
pub minimum_nprobes: Option<usize>,
|
||||
// None means user did not set it and default shoud be used (currenty 20)
|
||||
// Some(0) means user set it to None and there is no limit
|
||||
pub maximum_nprobes: Option<usize>,
|
||||
pub lower_bound: Option<f32>,
|
||||
pub upper_bound: Option<f32>,
|
||||
pub ef: Option<usize>,
|
||||
@@ -261,7 +264,8 @@ impl From<AnyQuery> for PyQueryRequest {
|
||||
with_row_id: Some(query_request.with_row_id),
|
||||
column: None,
|
||||
query_vector: None,
|
||||
nprobes: None,
|
||||
minimum_nprobes: None,
|
||||
maximum_nprobes: None,
|
||||
lower_bound: None,
|
||||
upper_bound: None,
|
||||
ef: None,
|
||||
@@ -281,7 +285,11 @@ impl From<AnyQuery> for PyQueryRequest {
|
||||
with_row_id: Some(vector_query.base.with_row_id),
|
||||
column: vector_query.column,
|
||||
query_vector: Some(PyQueryVectors(vector_query.query_vector)),
|
||||
nprobes: Some(vector_query.nprobes),
|
||||
minimum_nprobes: Some(vector_query.minimum_nprobes),
|
||||
maximum_nprobes: match vector_query.maximum_nprobes {
|
||||
None => Some(0),
|
||||
Some(value) => Some(value),
|
||||
},
|
||||
lower_bound: vector_query.lower_bound,
|
||||
upper_bound: vector_query.upper_bound,
|
||||
ef: vector_query.ef,
|
||||
@@ -655,6 +663,29 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
pub fn minimum_nprobes(&mut self, minimum_nprobes: u32) -> PyResult<()> {
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.minimum_nprobes(minimum_nprobes as usize)
|
||||
.infer_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn maximum_nprobes(&mut self, maximum_nprobes: u32) -> PyResult<()> {
|
||||
let maximum_nprobes = if maximum_nprobes == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(maximum_nprobes as usize)
|
||||
};
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.maximum_nprobes(maximum_nprobes)
|
||||
.infer_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
|
||||
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
|
||||
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
|
||||
|
||||
Reference in New Issue
Block a user