From ebac960571a5703c31bcda2a5103006ac74b7bff Mon Sep 17 00:00:00 2001 From: Hezi Zisman Date: Tue, 24 Dec 2024 20:33:26 +0200 Subject: [PATCH] feat(python): add `bypass_vector_index` to sync api (#1947) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hi lancedb team, This PR adds the `bypass_vector_index` logic to the sync API, as described in [Issue #535](https://github.com/lancedb/lancedb/issues/535). (Closes #535). I've implemented it only for the regular vector search. If you think it should also be supported for FTS, Hybrid, or Empty queries and for the cloud solution, please let me know, and I’ll be happy to extend it. Since there’s no `CONTRIBUTING.md` or contribution guidelines, I opted for the simplest implementation to get this started. Looking forward to your feedback! Thanks! --------- Co-authored-by: Will Jones --- python/python/lancedb/query.py | 44 ++++++++++++++++++++++++++++++++++ python/python/lancedb/table.py | 2 ++ python/python/tests/test_db.py | 22 +++++++++++++++++ 3 files changed, 68 insertions(+) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 2012b765..1e8468bc 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -126,6 +126,9 @@ class Query(pydantic.BaseModel): ef: Optional[int] = None + # Default is true. Set to false to enforce a brute force search. + use_index: bool = True + class LanceQueryBuilder(ABC): """An abstract query builder. 
Subclasses are defined for vector search, @@ -253,6 +256,7 @@ class LanceQueryBuilder(ABC): self._vector = None self._text = None self._ef = None + self._use_index = True @deprecation.deprecated( deprecated_in="0.3.1", @@ -511,6 +515,7 @@ class LanceQueryBuilder(ABC): "metric": self._metric, "nprobes": self._nprobes, "refine_factor": self._refine_factor, + "use_index": self._use_index, }, prefilter=self._prefilter, filter=self._str_query, @@ -729,6 +734,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): offset=self._offset, fast_search=self._fast_search, ef=self._ef, + use_index=self._use_index, ) result_set = self._table._execute_query(query, batch_size) if self._reranker is not None: @@ -802,6 +808,24 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._str_query = query_string if query_string is not None else self._str_query return self + def bypass_vector_index(self) -> LanceVectorQueryBuilder: + """ + If this is called then any vector index is skipped + + An exhaustive (flat) search will be performed. The query vector will + be compared to every vector in the table. At high scales this can be + expensive. However, this is often still useful. For example, skipping + the vector index can give you ground truth results which you can use to + calculate your recall to select an appropriate value for nprobes. + + Returns + ------- + LanceVectorQueryBuilder + The LanceVectorQueryBuilder object. 
+ """ + self._use_index = False + return self + class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" @@ -1108,6 +1132,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._vector_query.refine_factor(self._refine_factor) if self._ef: self._vector_query.ef(self._ef) + if not self._use_index: + self._vector_query.bypass_vector_index() with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) @@ -1323,6 +1349,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._text = text return self + def bypass_vector_index(self) -> LanceHybridQueryBuilder: + """ + If this is called then any vector index is skipped + + An exhaustive (flat) search will be performed. The query vector will + be compared to every vector in the table. At high scales this can be + expensive. However, this is often still useful. For example, skipping + the vector index can give you ground truth results which you can use to + calculate your recall to select an appropriate value for nprobes. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. 
+ """ + self._use_index = False + return self + class AsyncQueryBase(object): def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]): diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 34175844..ebc04320 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2812,6 +2812,8 @@ class AsyncTable: async_query = async_query.column(query.vector_column) if query.ef: async_query = async_query.ef(query.ef) + if not query.use_index: + async_query = async_query.bypass_vector_index() if not query.prefilter: async_query = async_query.postfilter() diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 893ae467..394956c8 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -681,3 +681,25 @@ def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection): with pytest.raises(ValueError): tmp_db.create_table("foo$$bar", data) tmp_db.create_table("foo.bar", data) + + +def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection): + data = [{"vector": np.random.rand(32)} for _ in range(512)] + sample_key = data[100]["vector"] + table = tmp_db.create_table( + "test", + data, + ) + + table.create_index( + num_partitions=2, + num_sub_vectors=2, + ) + + plan_with_index = table.search(sample_key).explain_plan(verbose=True) + assert "ANN" in plan_with_index + + plan_without_index = ( + table.search(sample_key).bypass_vector_index().explain_plan(verbose=True) + ) + assert "KNN" in plan_without_index