feat: expose use_index in LanceDB OSS

This commit is contained in:
Will Jones
2023-10-13 11:17:10 -07:00
parent 683824f1e9
commit f9ccefb032
8 changed files with 71 additions and 3 deletions

View File

@@ -59,6 +59,8 @@ class Query(pydantic.BaseModel):
# Refine factor.
refine_factor: Optional[int] = None
use_index: bool = True
class LanceQueryBuilder(ABC):
@classmethod
@@ -279,6 +281,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
self._use_index = True
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
@@ -340,6 +343,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def use_index(self, use_index: bool) -> LanceVectorQueryBuilder:
"""
Choose whether to use an ANN index or not. Default is True.
Setting this to False is not yet supported on LanceDB Cloud.
Parameters
----------
use_index: bool
If True, use an ANN index if one exists, otherwise perform exact KNN
on a full table scan.
"""
self._use_index = use_index
return self
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
@@ -360,6 +378,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
nprobes=self._nprobes,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
use_index=self._use_index,
)
return self._table._execute_query(query)

View File

@@ -101,6 +101,10 @@ class RemoteTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon")
if not query.use_index:
raise NotImplementedError(
"Cloud does not support non-indexed queries if the table has indices"
)
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()

View File

@@ -559,7 +559,8 @@ class LanceTable(Table):
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
"append", which inserts new rows, and "overwrite", which replaces
the entire content of the table with the new rows.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
@@ -887,6 +888,7 @@ class LanceTable(Table):
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"use_index": query.use_index,
},
)

View File

@@ -46,6 +46,7 @@ class MockTable:
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"use_index": query.use_index,
},
)
@@ -84,10 +85,12 @@ def test_cast(table):
assert r0.float_field == 1.0
def test_query_builder(table):
@pytest.mark.parametrize("use_index", [True, False])
def test_query_builder(table, use_index: bool):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.use_index(use_index)
.select(["id"])
.to_list()
)