From 59b57e30edcbce0c31b700c5faf9ba3c193002d0 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 13 Jun 2025 15:18:29 -0700 Subject: [PATCH] feat: add maximum and minimum nprobes properties (#2430) This exposes the maximum_nprobes and minimum_nprobes feature that was added in https://github.com/lancedb/lance/pull/3903 ## Summary by CodeRabbit - **New Features** - Added support for specifying minimum and maximum probe counts in vector search queries, allowing finer control over search behavior. - Users can now independently set minimum and maximum probes for vector and hybrid queries via new methods and parameters in Python, Node.js, and Rust APIs. - **Bug Fixes** - Improved parameter validation to ensure correct usage of minimum and maximum probe values. - **Tests** - Expanded test coverage to validate correct handling, serialization, and error cases for the new probe parameters. --- nodejs/__test__/table.test.ts | 26 ++++++ nodejs/lancedb/query.ts | 31 +++++++ nodejs/src/query.rs | 25 ++++++ python/python/lancedb/_lancedb.pyi | 7 +- python/python/lancedb/query.py | 114 ++++++++++++++++++++--- python/python/lancedb/table.py | 6 +- python/python/tests/test_query.py | 124 ++++++++++++++++++++++++-- python/python/tests/test_remote_db.py | 66 ++++++++++++++ python/src/query.rs | 37 +++++++- rust/lancedb/src/query.rs | 78 ++++++++++++++-- rust/lancedb/src/remote/table.rs | 18 +++- rust/lancedb/src/table.rs | 5 +- 12 files changed, 505 insertions(+), 32 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 5b4f9d86..23fe67dd 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -559,6 +559,32 @@ describe("When creating an index", () => { rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow(); expect(rst.numRows).toBe(1); + // test nprobes + rst = await tbl.query().nearestTo(queryVec).limit(2).nprobes(50).toArrow(); + expect(rst.numRows).toBe(2); + rst = await tbl + .query() + 
.nearestTo(queryVec) + .limit(2) + .minimumNprobes(15) + .toArrow(); + expect(rst.numRows).toBe(2); + rst = await tbl + .query() + .nearestTo(queryVec) + .limit(2) + .minimumNprobes(10) + .maximumNprobes(20) + .toArrow(); + expect(rst.numRows).toBe(2); + + expect(() => tbl.query().nearestTo(queryVec).minimumNprobes(0)).toThrow( + "Invalid input, minimum_nprobes must be greater than 0", + ); + expect(() => tbl.query().nearestTo(queryVec).maximumNprobes(5)).toThrow( + "Invalid input, maximum_nprobes must be greater than minimum_nprobes", + ); + await tbl.dropIndex("vec_idx"); const indices2 = await tbl.listIndices(); expect(indices2.length).toBe(0); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index b544faa1..c9fc97d9 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -448,6 +448,10 @@ export class VectorQuery extends QueryBase { * For best results we recommend tuning this parameter with a benchmark against * your actual data to find the smallest possible value that will still give * you the desired recall. + * + * For more fine grained control over behavior when you have a very narrow filter + * you can use `minimumNprobes` and `maximumNprobes`. This method sets both + * the minimum and maximum to the same value. */ nprobes(nprobes: number): VectorQuery { super.doCall((inner) => inner.nprobes(nprobes)); @@ -455,6 +459,33 @@ export class VectorQuery extends QueryBase { return this; } + /** + * Set the minimum number of probes used. + * + * This controls the minimum number of partitions that will be searched. This + * parameter will impact every query against a vector index, regardless of the + * filter. See `nprobes` for more details. Higher values will increase recall + * but will also increase latency. + */ + minimumNprobes(minimumNprobes: number): VectorQuery { + super.doCall((inner) => inner.minimumNprobes(minimumNprobes)); + return this; + } + + /** + * Set the maximum number of probes used. 
+ * + * This controls the maximum number of partitions that will be searched. If this + * number is greater than minimumNprobes then the excess partitions will _only_ be + * searched if we have not found enough results. This can be useful when there is + * a narrow filter to allow these queries to spend more time searching and avoid + * potential false negatives. + */ + maximumNprobes(maximumNprobes: number): VectorQuery { + super.doCall((inner) => inner.maximumNprobes(maximumNprobes)); + return this; + } + /* * Set the distance range to use * diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 3e3208cf..00630c3c 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -178,6 +178,31 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + #[napi] + pub fn minimum_nprobes(&mut self, minimum_nprobe: u32) -> napi::Result<()> { + self.inner = self + .inner + .clone() + .minimum_nprobes(minimum_nprobe as usize) + .default_error()?; + Ok(()) + } + + #[napi] + pub fn maximum_nprobes(&mut self, maximum_nprobes: u32) -> napi::Result<()> { + let maximum_nprobes = if maximum_nprobes == 0 { + None + } else { + Some(maximum_nprobes as usize) + }; + self.inner = self + .inner + .clone() + .maximum_nprobes(maximum_nprobes) + .default_error()?; + Ok(()) + } + #[napi] pub fn distance_range(&mut self, lower_bound: Option, upper_bound: Option) { // napi doesn't support f32, so we have to convert to f32 diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index c0fa4712..8f4c5a06 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -143,6 +143,8 @@ class VectorQuery: def postfilter(self): ... def refine_factor(self, refine_factor: int): ... def nprobes(self, nprobes: int): ... + def minimum_nprobes(self, minimum_nprobes: int): ... + def maximum_nprobes(self, maximum_nprobes: int): ... def bypass_vector_index(self): ... 
def nearest_to_text(self, query: dict) -> HybridQuery: ... def to_query_request(self) -> PyQueryRequest: ... @@ -158,6 +160,8 @@ class HybridQuery: def distance_type(self, distance_type: str): ... def refine_factor(self, refine_factor: int): ... def nprobes(self, nprobes: int): ... + def minimum_nprobes(self, minimum_nprobes: int): ... + def maximum_nprobes(self, maximum_nprobes: int): ... def bypass_vector_index(self): ... def to_vector_query(self) -> VectorQuery: ... def to_fts_query(self) -> FTSQuery: ... @@ -178,7 +182,8 @@ class PyQueryRequest: with_row_id: Optional[bool] column: Optional[str] query_vector: Optional[List[pa.Array]] - nprobes: Optional[int] + minimum_nprobes: Optional[int] + maximum_nprobes: Optional[int] lower_bound: Optional[float] upper_bound: Optional[float] ef: Optional[int] diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 2c248699..4025b1a2 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -437,8 +437,18 @@ class Query(pydantic.BaseModel): # which columns to return in the results columns: Optional[Union[List[str], Dict[str, str]]] = None - # number of IVF partitions to search - nprobes: Optional[int] = None + # minimum number of IVF partitions to search + # + # If None then a default value (20) will be used. + minimum_nprobes: Optional[int] = None + + # maximum number of IVF partitions to search + # + # If None then a default value (20) will be used. + # + # If 0 then no limit will be applied and all partitions could be searched + # if needed to satisfy the limit. 
+ maximum_nprobes: Optional[int] = None # lower bound for distance search lower_bound: Optional[float] = None @@ -476,7 +486,8 @@ class Query(pydantic.BaseModel): query.vector_column = req.column query.vector = req.query_vector query.distance_type = req.distance_type - query.nprobes = req.nprobes + query.minimum_nprobes = req.minimum_nprobes + query.maximum_nprobes = req.maximum_nprobes query.lower_bound = req.lower_bound query.upper_bound = req.upper_bound query.ef = req.ef @@ -1037,7 +1048,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): super().__init__(table) self._query = query self._distance_type = None - self._nprobes = None + self._minimum_nprobes = None + self._maximum_nprobes = None self._lower_bound = None self._upper_bound = None self._refine_factor = None @@ -1100,6 +1112,10 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. + This method sets both the minimum and maximum number of probes to the same + value. See `minimum_nprobes` and `maximum_nprobes` for more fine-grained + control. + Parameters ---------- nprobes: int @@ -1110,7 +1126,36 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceQueryBuilder object. """ - self._nprobes = nprobes + self._minimum_nprobes = nprobes + self._maximum_nprobes = nprobes + return self + + def minimum_nprobes(self, minimum_nprobes: int) -> LanceVectorQueryBuilder: + """Set the minimum number of probes to use. + + See `nprobes` for more details. + + These partitions will be searched on every vector query and will increase recall + at the expense of latency. + """ + self._minimum_nprobes = minimum_nprobes + return self + + def maximum_nprobes(self, maximum_nprobes: int) -> LanceVectorQueryBuilder: + """Set the maximum number of probes to use. + + See `nprobes` for more details. 
+ + If this value is greater than `minimum_nprobes` then the excess partitions + will be searched only if we have not found enough results. + + This can be useful when there is a narrow filter to allow these queries to + spend more time searching and avoid potential false negatives. + + If this value is 0 then no limit will be applied and all partitions could be + searched if needed to satisfy the limit. + """ + self._maximum_nprobes = maximum_nprobes return self def distance_range( @@ -1214,7 +1259,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): limit=self._limit, distance_type=self._distance_type, columns=self._columns, - nprobes=self._nprobes, + minimum_nprobes=self._minimum_nprobes, + maximum_nprobes=self._maximum_nprobes, lower_bound=self._lower_bound, upper_bound=self._upper_bound, refine_factor=self._refine_factor, @@ -1578,7 +1624,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._fts_columns = fts_columns self._norm = None self._reranker = None - self._nprobes = None + self._minimum_nprobes = None + self._maximum_nprobes = None self._refine_factor = None self._distance_type = None self._phrase_query = None @@ -1810,7 +1857,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): LanceHybridQueryBuilder The LanceHybridQueryBuilder object. """ - self._nprobes = nprobes + self._minimum_nprobes = nprobes + self._maximum_nprobes = nprobes + return self + + def minimum_nprobes(self, minimum_nprobes: int) -> LanceHybridQueryBuilder: + """Set the minimum number of probes to use. + + See `nprobes` for more details. + """ + self._minimum_nprobes = minimum_nprobes + return self + + def maximum_nprobes(self, maximum_nprobes: int) -> LanceHybridQueryBuilder: + """Set the maximum number of probes to use. + + See `nprobes` for more details. 
+ """ + self._maximum_nprobes = maximum_nprobes return self def distance_range( @@ -2039,8 +2103,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._fts_query.phrase_query(True) if self._distance_type: self._vector_query.metric(self._distance_type) - if self._nprobes: - self._vector_query.nprobes(self._nprobes) + if self._minimum_nprobes: + self._vector_query.minimum_nprobes(self._minimum_nprobes) + if self._maximum_nprobes is not None: + self._vector_query.maximum_nprobes(self._maximum_nprobes) if self._refine_factor: self._vector_query.refine_factor(self._refine_factor) if self._ef: @@ -2651,6 +2717,34 @@ class AsyncVectorQueryBase: self._inner.nprobes(nprobes) return self + def minimum_nprobes(self, minimum_nprobes: int) -> Self: + """Set the minimum number of probes to use. + + See `nprobes` for more details. + + These partitions will be searched on every indexed vector query and will + increase recall at the expense of latency. + """ + self._inner.minimum_nprobes(minimum_nprobes) + return self + + def maximum_nprobes(self, maximum_nprobes: int) -> Self: + """Set the maximum number of probes to use. + + See `nprobes` for more details. + + If this value is greater than `minimum_nprobes` then the excess partitions + will be searched only if we have not found enough results. + + This can be useful when there is a narrow filter to allow these queries to + spend more time searching and avoid potential false negatives. + + If this value is 0 then no limit will be applied and all partitions could be + searched if needed to satisfy the limit. 
+ """ + self._inner.maximum_nprobes(maximum_nprobes) + return self + def distance_range( self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None ) -> Self: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index f4d52663..5bc78f9f 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3637,8 +3637,10 @@ class AsyncTable: ) if query.distance_type is not None: async_query = async_query.distance_type(query.distance_type) - if query.nprobes is not None: - async_query = async_query.nprobes(query.nprobes) + if query.minimum_nprobes is not None: + async_query = async_query.minimum_nprobes(query.minimum_nprobes) + if query.maximum_nprobes is not None: + async_query = async_query.maximum_nprobes(query.maximum_nprobes) if query.refine_factor is not None: async_query = async_query.refine_factor(query.refine_factor) if query.vector_column: diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index c50642fb..3abdadc6 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -439,6 +439,33 @@ def test_query_builder_with_filter(table): assert all(np.array(rs[0]["vector"]) == [3, 4]) +def test_invalid_nprobes_sync(table): + with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"): + LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list() + with pytest.raises( + ValueError, match="maximum_nprobes must be greater than minimum_nprobes" + ): + LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list() + with pytest.raises( + ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes" + ): + LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list() + + +@pytest.mark.asyncio +async def test_invalid_nprobes_async(table_async: AsyncTable): + with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"): + await 
table_async.vector_search([0, 0]).minimum_nprobes(0).to_list() + with pytest.raises( + ValueError, match="maximum_nprobes must be greater than minimum_nprobes" + ): + await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list() + with pytest.raises( + ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes" + ): + await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list() + + def test_query_builder_with_prefilter(table): df = ( LanceVectorQueryBuilder(table, [0, 0], "vector") @@ -585,6 +612,21 @@ async def test_query_async(table_async: AsyncTable): table_async.query().nearest_to(pa.array([1, 2])).nprobes(10), expected_num_rows=2, ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).minimum_nprobes(10), + expected_num_rows=2, + ) + await check_query( + table_async.query().nearest_to(pa.array([1, 2])).maximum_nprobes(30), + expected_num_rows=2, + ) + await check_query( + table_async.query() + .nearest_to(pa.array([1, 2])) + .minimum_nprobes(10) + .maximum_nprobes(20), + expected_num_rows=2, + ) await check_query( table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(), expected_num_rows=2, @@ -911,7 +953,39 @@ def test_query_serialization_sync(table: lancedb.table.Table): q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object() check_set_props( - q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5 + q, + vector_column="vector", + vector=[5.0, 6.0], + minimum_nprobes=10, + maximum_nprobes=10, + refine_factor=5, + ) + + q = table.search([5.0, 6.0]).minimum_nprobes(10).to_query_object() + check_set_props( + q, + vector_column="vector", + vector=[5.0, 6.0], + minimum_nprobes=10, + maximum_nprobes=None, + ) + + q = table.search([5.0, 6.0]).nprobes(50).to_query_object() + check_set_props( + q, + vector_column="vector", + vector=[5.0, 6.0], + minimum_nprobes=50, + maximum_nprobes=50, + ) + + q = table.search([5.0, 6.0]).maximum_nprobes(10).to_query_object() + 
check_set_props( + q, + vector_column="vector", + vector=[5.0, 6.0], + maximum_nprobes=10, + minimum_nprobes=None, ) q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object() @@ -963,7 +1037,8 @@ async def test_query_serialization_async(table_async: AsyncTable): limit=10, vector=sample_vector, postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, bypass_vector_index=False, ) @@ -973,7 +1048,20 @@ async def test_query_serialization_async(table_async: AsyncTable): q, vector=sample_vector, postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = (await table_async.search([5.0, 6.0])).nprobes(50).to_query_object() + check_set_props( + q, + vector=sample_vector, + postfilter=False, + minimum_nprobes=50, + maximum_nprobes=50, with_row_id=False, bypass_vector_index=False, limit=10, @@ -992,7 +1080,8 @@ async def test_query_serialization_async(table_async: AsyncTable): filter="id = 1", postfilter=True, vector=sample_vector, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, bypass_vector_index=False, ) @@ -1006,7 +1095,8 @@ async def test_query_serialization_async(table_async: AsyncTable): check_set_props( q, vector=sample_vector, - nprobes=10, + minimum_nprobes=10, + maximum_nprobes=10, refine_factor=5, postfilter=False, with_row_id=False, @@ -1014,6 +1104,18 @@ async def test_query_serialization_async(table_async: AsyncTable): limit=10, ) + q = (await table_async.search([5.0, 6.0])).minimum_nprobes(5).to_query_object() + check_set_props( + q, + vector=sample_vector, + minimum_nprobes=5, + maximum_nprobes=20, + postfilter=False, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + q = ( (await table_async.search([5.0, 6.0])) .distance_range(0.0, 1.0) @@ -1025,7 +1127,8 @@ async def test_query_serialization_async(table_async: AsyncTable): lower_bound=0.0, upper_bound=1.0, 
postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, bypass_vector_index=False, limit=10, @@ -1037,7 +1140,8 @@ async def test_query_serialization_async(table_async: AsyncTable): distance_type="cosine", vector=sample_vector, postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, bypass_vector_index=False, limit=10, @@ -1049,7 +1153,8 @@ async def test_query_serialization_async(table_async: AsyncTable): ef=7, vector=sample_vector, postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, bypass_vector_index=False, limit=10, @@ -1061,7 +1166,8 @@ async def test_query_serialization_async(table_async: AsyncTable): bypass_vector_index=True, vector=sample_vector, postfilter=False, - nprobes=20, + minimum_nprobes=20, + maximum_nprobes=20, with_row_id=False, limit=10, ) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index a5f3feda..aa7c0c72 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -496,6 +496,8 @@ def test_query_sync_minimal(): "ef": None, "vector": [1.0, 2.0, 3.0], "nprobes": 20, + "minimum_nprobes": 20, + "maximum_nprobes": 20, "version": None, } @@ -536,6 +538,8 @@ def test_query_sync_maximal(): "refine_factor": 10, "vector": [1.0, 2.0, 3.0], "nprobes": 5, + "minimum_nprobes": 5, + "maximum_nprobes": 5, "lower_bound": None, "upper_bound": None, "ef": None, @@ -564,6 +568,66 @@ def test_query_sync_maximal(): ) +def test_query_sync_nprobes(): + def handler(body): + assert body == { + "distance_type": "l2", + "k": 10, + "prefilter": True, + "fast_search": True, + "vector_column": "vector2", + "refine_factor": None, + "lower_bound": None, + "upper_bound": None, + "ef": None, + "vector": [1.0, 2.0, 3.0], + "nprobes": 5, + "minimum_nprobes": 5, + "maximum_nprobes": 15, + "version": None, + } + + return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + + 
with query_test_table(handler) as table: + ( + table.search([1, 2, 3], vector_column_name="vector2", fast_search=True) + .minimum_nprobes(5) + .maximum_nprobes(15) + .to_list() + ) + + +def test_query_sync_no_max_nprobes(): + def handler(body): + assert body == { + "distance_type": "l2", + "k": 10, + "prefilter": True, + "fast_search": True, + "vector_column": "vector2", + "refine_factor": None, + "lower_bound": None, + "upper_bound": None, + "ef": None, + "vector": [1.0, 2.0, 3.0], + "nprobes": 5, + "minimum_nprobes": 5, + "maximum_nprobes": 0, + "version": None, + } + + return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + + with query_test_table(handler) as table: + ( + table.search([1, 2, 3], vector_column_name="vector2", fast_search=True) + .minimum_nprobes(5) + .maximum_nprobes(0) + .to_list() + ) + + @pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")]) def test_query_sync_batch_queries(server_version): def handler(body): @@ -666,6 +730,8 @@ def test_query_sync_hybrid(): "refine_factor": None, "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "nprobes": 20, + "minimum_nprobes": 20, + "maximum_nprobes": 20, "lower_bound": None, "upper_bound": None, "ef": None, diff --git a/python/src/query.rs b/python/src/query.rs index 49c2b392..c68149e3 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -235,7 +235,10 @@ pub struct PyQueryRequest { pub with_row_id: Option, pub column: Option, pub query_vector: Option, - pub nprobes: Option, + pub minimum_nprobes: Option, + // None means user did not set it and default should be used (currently 20) + // Some(0) means user set it to None and there is no limit + pub maximum_nprobes: Option, pub lower_bound: Option, pub upper_bound: Option, pub ef: Option, @@ -261,7 +264,8 @@ impl From for PyQueryRequest { with_row_id: Some(query_request.with_row_id), column: None, query_vector: None, - nprobes: None, + minimum_nprobes: None, + maximum_nprobes: None, lower_bound: None, 
upper_bound: None, ef: None, @@ -281,7 +285,11 @@ impl From for PyQueryRequest { with_row_id: Some(vector_query.base.with_row_id), column: vector_query.column, query_vector: Some(PyQueryVectors(vector_query.query_vector)), - nprobes: Some(vector_query.nprobes), + minimum_nprobes: Some(vector_query.minimum_nprobes), + maximum_nprobes: match vector_query.maximum_nprobes { + None => Some(0), + Some(value) => Some(value), + }, lower_bound: vector_query.lower_bound, upper_bound: vector_query.upper_bound, ef: vector_query.ef, @@ -655,6 +663,29 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + pub fn minimum_nprobes(&mut self, minimum_nprobes: u32) -> PyResult<()> { + self.inner = self + .inner + .clone() + .minimum_nprobes(minimum_nprobes as usize) + .infer_error()?; + Ok(()) + } + + pub fn maximum_nprobes(&mut self, maximum_nprobes: u32) -> PyResult<()> { + let maximum_nprobes = if maximum_nprobes == 0 { + None + } else { + Some(maximum_nprobes as usize) + }; + self.inner = self + .inner + .clone() + .maximum_nprobes(maximum_nprobes) + .infer_error()?; + Ok(()) + } + #[pyo3(signature = (lower_bound=None, upper_bound=None))] pub fn distance_range(&mut self, lower_bound: Option, upper_bound: Option) { self.inner = self.inner.clone().distance_range(lower_bound, upper_bound); diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 056c3aa6..614fdcca 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -796,8 +796,10 @@ pub struct VectorQueryRequest { pub column: Option, /// The vector(s) to search for pub query_vector: Vec>, - /// The number of partitions to search - pub nprobes: usize, + /// The minimum number of partitions to search + pub minimum_nprobes: usize, + /// The maximum number of partitions to search + pub maximum_nprobes: Option, /// The lower bound (inclusive) of the distance to search for. pub lower_bound: Option, /// The upper bound (exclusive) of the distance to search for. 
@@ -819,7 +821,8 @@ impl Default for VectorQueryRequest { base: QueryRequest::default(), column: None, query_vector: Vec::new(), - nprobes: 20, + minimum_nprobes: 20, + maximum_nprobes: Some(20), lower_bound: None, upper_bound: None, ef: None, @@ -925,11 +928,75 @@ impl VectorQuery { /// For best results we recommend tuning this parameter with a benchmark against /// your actual data to find the smallest possible value that will still give /// you the desired recall. + /// + /// This method sets both the minimum and maximum number of partitions to search. + /// For more fine-grained control see [`VectorQuery::minimum_nprobes`] and + /// [`VectorQuery::maximum_nprobes`]. pub fn nprobes(mut self, nprobes: usize) -> Self { - self.request.nprobes = nprobes; + self.request.minimum_nprobes = nprobes; + self.request.maximum_nprobes = Some(nprobes); self } + /// Set the minimum number of partitions to search + /// + /// This argument is only used when the vector column has an IVF PQ index. + /// If there is no index then this value is ignored. + /// + /// See [`VectorQuery::nprobes`] for more details. + /// + /// These partitions will be searched on every indexed vector query. + /// + /// Will return an error if the value is not greater than 0 or if maximum_nprobes + /// has been set and is less than the minimum_nprobes. + pub fn minimum_nprobes(mut self, minimum_nprobes: usize) -> Result { + if minimum_nprobes == 0 { + return Err(Error::InvalidInput { + message: "minimum_nprobes must be greater than 0".to_string(), + }); + } + if let Some(maximum_nprobes) = self.request.maximum_nprobes { + if minimum_nprobes > maximum_nprobes { + return Err(Error::InvalidInput { + message: "minimum_nprobes must be less or equal to maximum_nprobes".to_string(), + }); + } + } + self.request.minimum_nprobes = minimum_nprobes; + Ok(self) + } + + /// Set the maximum number of partitions to search + /// + /// This argument is only used when the vector column has an IVF PQ index. 
/// If there is no index then this value is ignored. + /// + /// See [`VectorQuery::nprobes`] for more details. + /// + /// If this value is greater than minimum_nprobes then the excess partitions will + /// only be searched if the initial search does not return enough results. + /// + /// This can be useful when there is a narrow filter to allow these queries to + /// spend more time searching and avoid potential false negatives. + /// + /// Set to None to search all partitions, if needed, to satisfy the limit + pub fn maximum_nprobes(mut self, maximum_nprobes: Option) -> Result { + if let Some(maximum_nprobes) = maximum_nprobes { + if maximum_nprobes == 0 { + return Err(Error::InvalidInput { + message: "maximum_nprobes must be greater than 0".to_string(), + }); + } + if maximum_nprobes < self.request.minimum_nprobes { + return Err(Error::InvalidInput { + message: "maximum_nprobes must be greater than minimum_nprobes".to_string(), + }); + } + } + self.request.maximum_nprobes = maximum_nprobes; + Ok(self) + } + /// Set the distance range for vector search, /// only rows with distances in the range [lower_bound, upper_bound) will be returned pub fn distance_range(mut self, lower_bound: Option, upper_bound: Option) -> Self { @@ -1208,7 +1275,8 @@ mod tests { ); assert_eq!(query.request.base.limit.unwrap(), 100); assert_eq!(query.request.base.offset.unwrap(), 1); - assert_eq!(query.request.nprobes, 1000); + assert_eq!(query.request.minimum_nprobes, 1000); + assert_eq!(query.request.maximum_nprobes, Some(1000)); assert!(query.request.use_index); assert_eq!(query.request.distance_type, Some(DistanceType::Cosine)); assert_eq!(query.request.refine_factor, Some(999)); diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 2f69fc27..df4f5acd 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -32,6 +32,7 @@ use lance::dataset::{ColumnAlteration, NewColumnTransform, Version}; use 
lance_datafusion::exec::{execute_plan, OneShotExec}; use reqwest::{RequestBuilder, Response}; use serde::{Deserialize, Serialize}; +use serde_json::Number; use std::collections::HashMap; use std::io::Cursor; use std::pin::Pin; @@ -438,7 +439,18 @@ impl RemoteTable { // Apply general parameters, before we dispatch based on number of query vectors. body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); - body["nprobes"] = query.nprobes.into(); + // In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`. + // Old client / new server: since minimum_nprobes is missing, fallback to nprobes + // New client / old server: old server will only see nprobes, make sure to set both + // nprobes and minimum_nprobes + // New client / new server: since minimum_nprobes is present, server can ignore nprobes + body["nprobes"] = query.minimum_nprobes.into(); + body["minimum_nprobes"] = query.minimum_nprobes.into(); + if let Some(maximum_nprobes) = query.maximum_nprobes { + body["maximum_nprobes"] = maximum_nprobes.into(); + } else { + body["maximum_nprobes"] = serde_json::Value::Number(Number::from_u128(0).unwrap()) + } body["lower_bound"] = query.lower_bound.into(); body["upper_bound"] = query.upper_bound.into(); body["ef"] = query.ef.into(); @@ -2075,6 +2087,8 @@ mod tests { "prefilter": true, "distance_type": "l2", "nprobes": 20, + "minimum_nprobes": 20, + "maximum_nprobes": 20, "lower_bound": Option::::None, "upper_bound": Option::::None, "k": 10, @@ -2175,6 +2189,8 @@ mod tests { "bypass_vector_index": true, "columns": ["a", "b"], "nprobes": 12, + "minimum_nprobes": 12, + "maximum_nprobes": 12, "lower_bound": Option::::None, "upper_bound": Option::::None, "ef": Option::::None, diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a534be4e..d21d9bda 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -2354,12 +2354,15 @@ impl BaseTable for NativeTable { 
query.base.limit.unwrap_or(DEFAULT_TOP_K), )?; } + scanner.minimum_nprobes(query.minimum_nprobes); + if let Some(maximum_nprobes) = query.maximum_nprobes { + scanner.maximum_nprobes(maximum_nprobes); + } } scanner.limit( query.base.limit.map(|limit| limit as i64), query.base.offset.map(|offset| offset as i64), )?; - scanner.nprobs(query.nprobes); if let Some(ef) = query.ef { scanner.ef(ef); }