mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat(python): expose prefilter to lancedb (#522)
We have experimental support for prefiltering (without ANN) in pylance. This means that we can now apply a filter BEFORE vector search is performed. This can be done via the `.where(filter_string, prefilter=True)` kwargs of the query. Limitations: - When connecting to LanceDB cloud, `prefilter=True` will raise NotImplemented - When an ANN index is present, `prefilter=True` will raise NotImplemented - This option is not available for full text search query - This option is not available for empty search query (just filter/project) Additional changes in this PR: - Bump pylance version to v0.8.0 which supports the experimental prefiltering. --------- Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -26,6 +26,7 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
@@ -514,7 +515,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in futures]
|
||||
return [future.result() for future in tqdm(futures)]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
@@ -557,7 +558,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor)
|
||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
@@ -38,6 +38,9 @@ class Query(pydantic.BaseModel):
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# if True then apply the filter before vector search
|
||||
prefilter: bool = False
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
@@ -162,7 +165,7 @@ class LanceQueryBuilder(ABC):
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def limit(self, limit: int) -> LanceVectorQueryBuilder:
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
Parameters
|
||||
@@ -172,13 +175,13 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceVectorQueryBuilder:
|
||||
def select(self, columns: list) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
@@ -188,13 +191,13 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceVectorQueryBuilder:
|
||||
def where(self, where) -> LanceQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
@@ -204,7 +207,7 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
@@ -246,6 +249,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
@@ -320,6 +324,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
query = Query(
|
||||
vector=vector,
|
||||
filter=self._where,
|
||||
prefilter=self._prefilter,
|
||||
k=self._limit,
|
||||
metric=self._metric,
|
||||
columns=self._columns,
|
||||
@@ -329,6 +334,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||
without warning in the future. Currently this is only supported
|
||||
in OSS and can only be used with a table that does not have an ANN
|
||||
index.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
self._prefilter = prefilter
|
||||
return self
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||
|
||||
@@ -98,6 +98,8 @@ class RemoteTable(Table):
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
if query.prefilter:
|
||||
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
|
||||
|
||||
@@ -844,9 +844,16 @@ class LanceTable(Table):
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
if query.prefilter:
|
||||
for idx in ds.list_indices():
|
||||
if query.vector_column in idx["fields"]:
|
||||
raise NotImplementedError(
|
||||
"Prefiltering for indexed vector column is coming soon."
|
||||
)
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
name = "lancedb"
|
||||
version = "0.2.5"
|
||||
dependencies = [
|
||||
"pylance==0.7.4",
|
||||
"pylance==0.8.0",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.1.0",
|
||||
|
||||
@@ -38,6 +38,7 @@ class MockTable:
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
@@ -97,6 +98,25 @@ def test_query_builder_with_filter(table):
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
|
||||
def test_query_builder_with_prefilter(table):
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.where("id = 2")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert len(df) == 0
|
||||
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.where("id = 2", prefilter=True)
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
vector_column_name = "vector"
|
||||
|
||||
Reference in New Issue
Block a user