feat(python): expose prefilter to lancedb (#522)

We have experimental support for prefiltering (without ANN) in pylance.
This means that we can now apply a filter BEFORE vector search is
performed. This can be done via the `.where(filter_string,
prefilter=True)` kwargs of the query.

Limitations:
- When connecting to LanceDB cloud, `prefilter=True` will raise
NotImplemented
- When an ANN index is present, `prefilter=True` will raise
NotImplemented
- This option is not available for full text search query
- This option is not available for empty search query (just
filter/project)

Additional changes in this PR:
- Bump pylance version to v0.8.0 which supports the experimental
prefiltering.

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-10-01 10:34:12 -07:00
committed by GitHub
parent 343e274ea5
commit 693bca1eba
6 changed files with 68 additions and 9 deletions

View File

@@ -26,6 +26,7 @@ import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm
class EmbeddingFunctionRegistry:
@@ -514,7 +515,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in futures]
return [future.result() for future in tqdm(futures)]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
@@ -557,7 +558,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor)
image_features = self._model.encode_image(image_tensor.to(self.device))
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()

View File

@@ -38,6 +38,9 @@ class Query(pydantic.BaseModel):
# sql filter to refine the query with
filter: Optional[str] = None
# if True then apply the filter before vector search
prefilter: bool = False
# top k results to return
k: int
@@ -162,7 +165,7 @@ class LanceQueryBuilder(ABC):
for row in self.to_arrow().to_pylist()
]
def limit(self, limit: int) -> LanceVectorQueryBuilder:
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
@@ -172,13 +175,13 @@ class LanceQueryBuilder(ABC):
Returns
-------
LanceVectorQueryBuilder
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceVectorQueryBuilder:
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
@@ -188,13 +191,13 @@ class LanceQueryBuilder(ABC):
Returns
-------
LanceVectorQueryBuilder
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceVectorQueryBuilder:
def where(self, where) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
@@ -204,7 +207,7 @@ class LanceQueryBuilder(ABC):
Returns
-------
LanceVectorQueryBuilder
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
@@ -246,6 +249,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = 20
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
@@ -320,6 +324,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
query = Query(
vector=vector,
filter=self._where,
prefilter=self._prefilter,
k=self._limit,
metric=self._metric,
columns=self._columns,
@@ -329,6 +334,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
)
return self._table._execute_query(query)
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future. Currently this is only supported
in OSS and can only be used with a table that does not have an ANN
index.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "lancedb.table.Table", query: str):

View File

@@ -98,6 +98,8 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon")
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()

View File

@@ -844,9 +844,16 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
if query.prefilter:
for idx in ds.list_indices():
if query.vector_column in idx["fields"]:
raise NotImplementedError(
"Prefiltering for indexed vector column is coming soon."
)
return ds.to_table(
columns=query.columns,
filter=query.filter,
prefilter=query.prefilter,
nearest={
"column": query.vector_column,
"q": query.vector,

View File

@@ -2,7 +2,7 @@
name = "lancedb"
version = "0.2.5"
dependencies = [
"pylance==0.7.4",
"pylance==0.8.0",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.1.0",

View File

@@ -38,6 +38,7 @@ class MockTable:
return ds.to_table(
columns=query.columns,
filter=query.filter,
prefilter=query.prefilter,
nearest={
"column": query.vector_column,
"q": query.vector,
@@ -97,6 +98,25 @@ def test_query_builder_with_filter(table):
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_prefilter(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2")
.limit(1)
.to_df()
)
assert len(df) == 0
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2", prefilter=True)
.limit(1)
.to_df()
)
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"