Compare commits

...

5 Commits

Author SHA1 Message Date
Lance Release
d326146a40 [python] Bump version: 0.2.5 → 0.2.6 2023-10-01 17:48:59 +00:00
Chang She
693bca1eba feat(python): expose prefilter to lancedb (#522)
We have experimental support for prefiltering (without ANN) in pylance.
This means that we can now apply a filter BEFORE vector search is
performed. This can be done via the `.where(filter_string,
prefilter=True)` kwargs of the query.

Limitations:
- When connecting to LanceDB cloud, `prefilter=True` will raise
NotImplemented
- When an ANN index is present, `prefilter=True` will raise
NotImplemented
- This option is not available for full text search query
- This option is not available for empty search query (just
filter/project)

Additional changes in this PR:
- Bump pylance version to v0.8.0 which supports the experimental
prefiltering.

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-10-01 10:34:12 -07:00
Will Jones
343e274ea5 fix: define minimum dependency versions (#515)
Closes #513

For each of these, I found the minimum version that would allow the unit
tests to pass.
2023-09-29 09:04:49 -07:00
Rob Meng
a695fb8030 fix import attr to use import attrs (#510)
Thanks to #508, I used `attr` instead of the correct package `attrs`

s/attr/attrs
2023-09-23 00:30:56 -04:00
Hynek Schlawack
bc8670d7af [Python] Fix attrs dependency (#508)
The `attr` project is unrelated to `attrs` that also provides the `attr`
namespace (see also <https://hynek.me/articles/import-attrs/>).

It used to _usually_ work, because attrs is a dependency of aiohttp and
somehow took precedence over `attr`'s `attr`.

Yes, sorry, it's a mess.
2023-09-21 12:35:34 -04:00
10 changed files with 86 additions and 27 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.2.5 current_version = 0.2.6
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -26,6 +26,7 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
from cachetools import cached from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm
class EmbeddingFunctionRegistry: class EmbeddingFunctionRegistry:
@@ -514,7 +515,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
executor.submit(self.generate_image_embedding, image) executor.submit(self.generate_image_embedding, image)
for image in images for image in images
] ]
return [future.result() for future in futures] return [future.result() for future in tqdm(futures)]
def generate_image_embedding( def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"] self, image: Union[str, bytes, "PIL.Image.Image"]
@@ -557,7 +558,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
""" """
encode a single image tensor and optionally normalize the output encode a single image tensor and optionally normalize the output
""" """
image_features = self._model.encode_image(image_tensor) image_features = self._model.encode_image(image_tensor.to(self.device))
if self.normalize: if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze() return image_features.cpu().numpy().squeeze()

View File

@@ -38,6 +38,9 @@ class Query(pydantic.BaseModel):
# sql filter to refine the query with # sql filter to refine the query with
filter: Optional[str] = None filter: Optional[str] = None
# if True then apply the filter before vector search
prefilter: bool = False
# top k results to return # top k results to return
k: int k: int
@@ -162,7 +165,7 @@ class LanceQueryBuilder(ABC):
for row in self.to_arrow().to_pylist() for row in self.to_arrow().to_pylist()
] ]
def limit(self, limit: int) -> LanceVectorQueryBuilder: def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return. """Set the maximum number of results to return.
Parameters Parameters
@@ -172,13 +175,13 @@ class LanceQueryBuilder(ABC):
Returns Returns
------- -------
LanceVectorQueryBuilder LanceQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._limit = limit self._limit = limit
return self return self
def select(self, columns: list) -> LanceVectorQueryBuilder: def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return. """Set the columns to return.
Parameters Parameters
@@ -188,13 +191,13 @@ class LanceQueryBuilder(ABC):
Returns Returns
------- -------
LanceVectorQueryBuilder LanceQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._columns = columns self._columns = columns
return self return self
def where(self, where: str) -> LanceVectorQueryBuilder: def where(self, where) -> LanceQueryBuilder:
"""Set the where clause. """Set the where clause.
Parameters Parameters
@@ -204,7 +207,7 @@ class LanceQueryBuilder(ABC):
Returns Returns
------- -------
LanceVectorQueryBuilder LanceQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._where = where self._where = where
@@ -246,6 +249,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = 20 self._nprobes = 20
self._refine_factor = None self._refine_factor = None
self._vector_column = vector_column self._vector_column = vector_column
self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
@@ -320,6 +324,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
query = Query( query = Query(
vector=vector, vector=vector,
filter=self._where, filter=self._where,
prefilter=self._prefilter,
k=self._limit, k=self._limit,
metric=self._metric, metric=self._metric,
columns=self._columns, columns=self._columns,
@@ -329,6 +334,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
) )
return self._table._execute_query(query) return self._table._execute_query(query)
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future. Currently this is only supported
in OSS and can only be used with a table that does not have an ANN
index.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
return self
class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "lancedb.table.Table", query: str): def __init__(self, table: "lancedb.table.Table", query: str):

View File

@@ -14,7 +14,7 @@
import abc import abc
from typing import List, Optional from typing import List, Optional
import attr import attrs
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel from pydantic import BaseModel
@@ -44,7 +44,7 @@ class VectorQuery(BaseModel):
refine_factor: Optional[int] = None refine_factor: Optional[int] = None
@attr.define @attrs.define
class VectorQueryResult: class VectorQueryResult:
# for now the response is directly seralized into a pandas dataframe # for now the response is directly seralized into a pandas dataframe
tbl: pa.Table tbl: pa.Table

View File

@@ -16,7 +16,7 @@ import functools
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union
import aiohttp import aiohttp
import attr import attrs
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel from pydantic import BaseModel
@@ -43,14 +43,14 @@ async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
return reader.read_all() return reader.read_all()
@attr.define(slots=False) @attrs.define(slots=False)
class RestfulLanceDBClient: class RestfulLanceDBClient:
db_name: str db_name: str
region: str region: str
api_key: Credential api_key: Credential
host_override: Optional[str] = attr.field(default=None) host_override: Optional[str] = attrs.field(default=None)
closed: bool = attr.field(default=False, init=False) closed: bool = attrs.field(default=False, init=False)
@functools.cached_property @functools.cached_property
def session(self) -> aiohttp.ClientSession: def session(self) -> aiohttp.ClientSession:

View File

@@ -98,6 +98,8 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon")
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow() return self._conn._loop.run_until_complete(result).to_arrow()

View File

@@ -844,9 +844,16 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance() ds = self.to_lance()
if query.prefilter:
for idx in ds.list_indices():
if query.vector_column in idx["fields"]:
raise NotImplementedError(
"Prefiltering for indexed vector column is coming soon."
)
return ds.to_table( return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,
prefilter=query.prefilter,
nearest={ nearest={
"column": query.vector_column, "column": query.vector_column,
"q": query.vector, "q": query.vector,

View File

@@ -1,14 +1,14 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.2.5" version = "0.2.6"
dependencies = [ dependencies = [
"pylance==0.7.4", "pylance==0.8.0",
"ratelimiter", "ratelimiter~=1.0",
"retry", "retry>=0.9.2",
"tqdm", "tqdm>=4.1.0",
"aiohttp", "aiohttp",
"pydantic", "pydantic>=1.10",
"attr", "attrs>=21.3.0",
"semver>=3.0", "semver>=3.0",
"cachetools" "cachetools"
] ]
@@ -62,4 +62,4 @@ addopts = "--strict-markers"
markers = [ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio" "asyncio"
] ]

View File

@@ -38,6 +38,7 @@ class MockTable:
return ds.to_table( return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,
prefilter=query.prefilter,
nearest={ nearest={
"column": query.vector_column, "column": query.vector_column,
"q": query.vector, "q": query.vector,
@@ -97,6 +98,25 @@ def test_query_builder_with_filter(table):
assert all(df["vector"].values[0] == [3, 4]) assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_prefilter(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2")
.limit(1)
.to_df()
)
assert len(df) == 0
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2", prefilter=True)
.limit(1)
.to_df()
)
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_metric(table): def test_query_builder_with_metric(table):
query = [4, 8] query = [4, 8]
vector_column_name = "vector" vector_column_name = "vector"

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import attr import attrs
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
@@ -21,10 +21,10 @@ from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attr.define @attrs.define
class MockLanceDBServer: class MockLanceDBServer:
runner: web.AppRunner = attr.field(init=False) runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attr.field(init=False) site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response: async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"] table_name = request.match_info["table_name"]