diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index d26a04f4..37f56e69 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -582,7 +582,7 @@ describe("When creating an index", () => { "Invalid input, minimum_nprobes must be greater than 0", ); expect(() => tbl.query().nearestTo(queryVec).maximumNprobes(5)).toThrow( - "Invalid input, maximum_nprobes must be greater than minimum_nprobes", + "Invalid input, maximum_nprobes must be greater than or equal to minimum_nprobes", ); await tbl.dropIndex("vec_idx"); diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index c60582cb..e6db7515 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3673,9 +3673,14 @@ class AsyncTable: ) if query.distance_type is not None: async_query = async_query.distance_type(query.distance_type) - if query.minimum_nprobes is not None: + if query.minimum_nprobes is not None and query.maximum_nprobes is not None: + # Set both to the minimum first to avoid min > max error. + async_query = async_query.nprobes( + query.minimum_nprobes + ).maximum_nprobes(query.maximum_nprobes) + elif query.minimum_nprobes is not None: async_query = async_query.minimum_nprobes(query.minimum_nprobes) - if query.maximum_nprobes is not None: + elif query.maximum_nprobes is not None: async_query = async_query.maximum_nprobes(query.maximum_nprobes) if query.refine_factor is not None: async_query = async_query.refine_factor(query.refine_factor) diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 56e987da..28f3d69e 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -445,25 +445,45 @@ def test_invalid_nprobes_sync(table): with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"): LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(0).to_list() with pytest.raises( - ValueError, match="maximum_nprobes must be greater than minimum_nprobes" + ValueError, + match="maximum_nprobes must be greater than or equal to minimum_nprobes", ): LanceVectorQueryBuilder(table, [0, 0], "vector").maximum_nprobes(5).to_list() with pytest.raises( - ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes" + ValueError, + match="minimum_nprobes must be less than or equal to maximum_nprobes", ): LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(100).to_list() +def test_nprobes_works_sync(table): + LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).to_list() + + +def test_nprobes_min_max_works_sync(table): + LanceVectorQueryBuilder(table, [0, 0], "vector").minimum_nprobes(2).maximum_nprobes( + 4 + ).to_list() + + +def test_multiple_nprobes_calls_works_sync(table): + LanceVectorQueryBuilder(table, [0, 0], "vector").nprobes(30).maximum_nprobes( + 20 + ).minimum_nprobes(20).to_list() + + @pytest.mark.asyncio async def test_invalid_nprobes_async(table_async: AsyncTable): with pytest.raises(ValueError, match="minimum_nprobes must be greater than 0"): await table_async.vector_search([0, 0]).minimum_nprobes(0).to_list() with pytest.raises( - ValueError, match="maximum_nprobes must be greater than minimum_nprobes" + ValueError, + match="maximum_nprobes must be greater than or equal to minimum_nprobes", ): await table_async.vector_search([0, 0]).maximum_nprobes(5).to_list() with pytest.raises( - ValueError, match="minimum_nprobes must be less or equal to maximum_nprobes" + ValueError, + match="minimum_nprobes must be less than or equal to maximum_nprobes", ): await table_async.vector_search([0, 0]).minimum_nprobes(100).to_list() diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 614fdcca..ba4424e3 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -958,7 +958,8 @@ impl VectorQuery { if let Some(maximum_nprobes) = self.request.maximum_nprobes { if minimum_nprobes > maximum_nprobes { return Err(Error::InvalidInput { - message: "minimum_nprobes must be less or equal to maximum_nprobes".to_string(), + message: "minimum_nprobes must be less than or equal to maximum_nprobes" + .to_string(), }); } } @@ -989,7 +990,8 @@ impl VectorQuery { } if maximum_nprobes < self.request.minimum_nprobes { return Err(Error::InvalidInput { - message: "maximum_nprobes must be greater than minimum_nprobes".to_string(), + message: "maximum_nprobes must be greater than or equal to minimum_nprobes" + .to_string(), }); } }