diff --git a/python/lancedb/embeddings/functions.py b/python/lancedb/embeddings/functions.py index 4fb763c5..e2a70898 100644 --- a/python/lancedb/embeddings/functions.py +++ b/python/lancedb/embeddings/functions.py @@ -26,6 +26,7 @@ import numpy as np import pyarrow as pa from cachetools import cached from pydantic import BaseModel, Field, PrivateAttr +from tqdm import tqdm class EmbeddingFunctionRegistry: @@ -514,7 +515,7 @@ class OpenClipEmbeddings(EmbeddingFunction): executor.submit(self.generate_image_embedding, image) for image in images ] - return [future.result() for future in futures] + return [future.result() for future in tqdm(futures)] def generate_image_embedding( self, image: Union[str, bytes, "PIL.Image.Image"] @@ -557,7 +558,7 @@ class OpenClipEmbeddings(EmbeddingFunction): """ encode a single image tensor and optionally normalize the output """ - image_features = self._model.encode_image(image_tensor) + image_features = self._model.encode_image(image_tensor.to(self.device)) if self.normalize: image_features /= image_features.norm(dim=-1, keepdim=True) return image_features.cpu().numpy().squeeze() diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 10925ac7..eedf8079 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -38,6 +38,9 @@ class Query(pydantic.BaseModel): # sql filter to refine the query with filter: Optional[str] = None + # if True then apply the filter before vector search + prefilter: bool = False + # top k results to return k: int @@ -162,7 +165,7 @@ class LanceQueryBuilder(ABC): for row in self.to_arrow().to_pylist() ] - def limit(self, limit: int) -> LanceVectorQueryBuilder: + def limit(self, limit: int) -> LanceQueryBuilder: """Set the maximum number of results to return. Parameters @@ -172,13 +175,13 @@ class LanceQueryBuilder(ABC): Returns ------- - LanceVectorQueryBuilder + LanceQueryBuilder The LanceQueryBuilder object. """ self._limit = limit return self - def select(self, columns: list) -> LanceVectorQueryBuilder: + def select(self, columns: list) -> LanceQueryBuilder: """Set the columns to return. Parameters @@ -188,13 +191,13 @@ class LanceQueryBuilder(ABC): Returns ------- - LanceVectorQueryBuilder + LanceQueryBuilder The LanceQueryBuilder object. """ self._columns = columns return self - def where(self, where: str) -> LanceVectorQueryBuilder: + def where(self, where) -> LanceQueryBuilder: """Set the where clause. Parameters @@ -204,7 +207,7 @@ class LanceQueryBuilder(ABC): Returns ------- - LanceVectorQueryBuilder + LanceQueryBuilder The LanceQueryBuilder object. """ self._where = where @@ -246,6 +249,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._nprobes = 20 self._refine_factor = None self._vector_column = vector_column + self._prefilter = False def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. @@ -320,6 +324,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): query = Query( vector=vector, filter=self._where, + prefilter=self._prefilter, k=self._limit, metric=self._metric, columns=self._columns, @@ -329,6 +334,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): ) return self._table._execute_query(query) + def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder: + """Set the where clause. + + Parameters + ---------- + where: str + The where clause. + prefilter: bool, default False + If True, apply the filter before vector search, otherwise the + filter is applied on the result of vector search. + This feature is **EXPERIMENTAL** and may be removed and modified + without warning in the future. Currently this is only supported + in OSS and can only be used with a table that does not have an ANN + index. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + self._where = where + self._prefilter = prefilter + return self + class LanceFtsQueryBuilder(LanceQueryBuilder): def __init__(self, table: "lancedb.table.Table", query: str): diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index e80868a6..981a1696 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -98,6 +98,8 @@ class RemoteTable(Table): return LanceVectorQueryBuilder(self, query, vector_column_name) def _execute_query(self, query: Query) -> pa.Table: + if query.prefilter: + raise NotImplementedError("Cloud support for prefiltering is coming soon") result = self._conn._client.query(self._name, query) return self._conn._loop.run_until_complete(result).to_arrow() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index ae19d83c..60aba685 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -844,9 +844,16 @@ class LanceTable(Table): def _execute_query(self, query: Query) -> pa.Table: ds = self.to_lance() + if query.prefilter: + for idx in ds.list_indices(): + if query.vector_column in idx["fields"]: + raise NotImplementedError( + "Prefiltering for indexed vector column is coming soon." + ) return ds.to_table( columns=query.columns, filter=query.filter, + prefilter=query.prefilter, nearest={ "column": query.vector_column, "q": query.vector, diff --git a/python/pyproject.toml b/python/pyproject.toml index e785322f..befc5b53 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -2,7 +2,7 @@ name = "lancedb" version = "0.2.5" dependencies = [ - "pylance==0.7.4", + "pylance==0.8.0", "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.1.0", diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 56e039c9..6784f439 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -38,6 +38,7 @@ class MockTable: return ds.to_table( columns=query.columns, filter=query.filter, + prefilter=query.prefilter, nearest={ "column": query.vector_column, "q": query.vector, @@ -97,6 +98,25 @@ def test_query_builder_with_filter(table): assert all(df["vector"].values[0] == [3, 4]) +def test_query_builder_with_prefilter(table): + df = ( + LanceVectorQueryBuilder(table, [0, 0], "vector") + .where("id = 2") + .limit(1) + .to_df() + ) + assert len(df) == 0 + + df = ( + LanceVectorQueryBuilder(table, [0, 0], "vector") + .where("id = 2", prefilter=True) + .limit(1) + .to_df() + ) + assert df["id"].values[0] == 2 + assert all(df["vector"].values[0] == [3, 4]) + + def test_query_builder_with_metric(table): query = [4, 8] vector_column_name = "vector"