mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 21:39:57 +00:00
Compare commits
5 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d326146a40 | ||
|
|
693bca1eba | ||
|
|
343e274ea5 | ||
|
|
a695fb8030 | ||
|
|
bc8670d7af |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.2.5
|
current_version = 0.2.6
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import numpy as np
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from cachetools import cached
|
from cachetools import cached
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionRegistry:
|
class EmbeddingFunctionRegistry:
|
||||||
@@ -514,7 +515,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
executor.submit(self.generate_image_embedding, image)
|
executor.submit(self.generate_image_embedding, image)
|
||||||
for image in images
|
for image in images
|
||||||
]
|
]
|
||||||
return [future.result() for future in futures]
|
return [future.result() for future in tqdm(futures)]
|
||||||
|
|
||||||
def generate_image_embedding(
|
def generate_image_embedding(
|
||||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||||
@@ -557,7 +558,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
"""
|
"""
|
||||||
encode a single image tensor and optionally normalize the output
|
encode a single image tensor and optionally normalize the output
|
||||||
"""
|
"""
|
||||||
image_features = self._model.encode_image(image_tensor)
|
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
return image_features.cpu().numpy().squeeze()
|
return image_features.cpu().numpy().squeeze()
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ class Query(pydantic.BaseModel):
|
|||||||
# sql filter to refine the query with
|
# sql filter to refine the query with
|
||||||
filter: Optional[str] = None
|
filter: Optional[str] = None
|
||||||
|
|
||||||
|
# if True then apply the filter before vector search
|
||||||
|
prefilter: bool = False
|
||||||
|
|
||||||
# top k results to return
|
# top k results to return
|
||||||
k: int
|
k: int
|
||||||
|
|
||||||
@@ -162,7 +165,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
for row in self.to_arrow().to_pylist()
|
for row in self.to_arrow().to_pylist()
|
||||||
]
|
]
|
||||||
|
|
||||||
def limit(self, limit: int) -> LanceVectorQueryBuilder:
|
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||||
"""Set the maximum number of results to return.
|
"""Set the maximum number of results to return.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -172,13 +175,13 @@ class LanceQueryBuilder(ABC):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LanceVectorQueryBuilder
|
LanceQueryBuilder
|
||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
self._limit = limit
|
self._limit = limit
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def select(self, columns: list) -> LanceVectorQueryBuilder:
|
def select(self, columns: list) -> LanceQueryBuilder:
|
||||||
"""Set the columns to return.
|
"""Set the columns to return.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -188,13 +191,13 @@ class LanceQueryBuilder(ABC):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LanceVectorQueryBuilder
|
LanceQueryBuilder
|
||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
self._columns = columns
|
self._columns = columns
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def where(self, where: str) -> LanceVectorQueryBuilder:
|
def where(self, where) -> LanceQueryBuilder:
|
||||||
"""Set the where clause.
|
"""Set the where clause.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -204,7 +207,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LanceVectorQueryBuilder
|
LanceQueryBuilder
|
||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
self._where = where
|
self._where = where
|
||||||
@@ -246,6 +249,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self._nprobes = 20
|
self._nprobes = 20
|
||||||
self._refine_factor = None
|
self._refine_factor = None
|
||||||
self._vector_column = vector_column
|
self._vector_column = vector_column
|
||||||
|
self._prefilter = False
|
||||||
|
|
||||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||||
"""Set the distance metric to use.
|
"""Set the distance metric to use.
|
||||||
@@ -320,6 +324,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
query = Query(
|
query = Query(
|
||||||
vector=vector,
|
vector=vector,
|
||||||
filter=self._where,
|
filter=self._where,
|
||||||
|
prefilter=self._prefilter,
|
||||||
k=self._limit,
|
k=self._limit,
|
||||||
metric=self._metric,
|
metric=self._metric,
|
||||||
columns=self._columns,
|
columns=self._columns,
|
||||||
@@ -329,6 +334,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
)
|
)
|
||||||
return self._table._execute_query(query)
|
return self._table._execute_query(query)
|
||||||
|
|
||||||
|
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
|
||||||
|
"""Set the where clause.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
where: str
|
||||||
|
The where clause.
|
||||||
|
prefilter: bool, default False
|
||||||
|
If True, apply the filter before vector search, otherwise the
|
||||||
|
filter is applied on the result of vector search.
|
||||||
|
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||||
|
without warning in the future. Currently this is only supported
|
||||||
|
in OSS and can only be used with a table that does not have an ANN
|
||||||
|
index.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceQueryBuilder
|
||||||
|
The LanceQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._where = where
|
||||||
|
self._prefilter = prefilter
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
import abc
|
import abc
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import attr
|
import attrs
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ class VectorQuery(BaseModel):
|
|||||||
refine_factor: Optional[int] = None
|
refine_factor: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@attr.define
|
@attrs.define
|
||||||
class VectorQueryResult:
|
class VectorQueryResult:
|
||||||
# for now the response is directly seralized into a pandas dataframe
|
# for now the response is directly seralized into a pandas dataframe
|
||||||
tbl: pa.Table
|
tbl: pa.Table
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import functools
|
|||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import attr
|
import attrs
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -43,14 +43,14 @@ async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
|||||||
return reader.read_all()
|
return reader.read_all()
|
||||||
|
|
||||||
|
|
||||||
@attr.define(slots=False)
|
@attrs.define(slots=False)
|
||||||
class RestfulLanceDBClient:
|
class RestfulLanceDBClient:
|
||||||
db_name: str
|
db_name: str
|
||||||
region: str
|
region: str
|
||||||
api_key: Credential
|
api_key: Credential
|
||||||
host_override: Optional[str] = attr.field(default=None)
|
host_override: Optional[str] = attrs.field(default=None)
|
||||||
|
|
||||||
closed: bool = attr.field(default=False, init=False)
|
closed: bool = attrs.field(default=False, init=False)
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def session(self) -> aiohttp.ClientSession:
|
def session(self) -> aiohttp.ClientSession:
|
||||||
|
|||||||
@@ -98,6 +98,8 @@ class RemoteTable(Table):
|
|||||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
|
if query.prefilter:
|
||||||
|
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
||||||
result = self._conn._client.query(self._name, query)
|
result = self._conn._client.query(self._name, query)
|
||||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||||
|
|
||||||
|
|||||||
@@ -844,9 +844,16 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
ds = self.to_lance()
|
ds = self.to_lance()
|
||||||
|
if query.prefilter:
|
||||||
|
for idx in ds.list_indices():
|
||||||
|
if query.vector_column in idx["fields"]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Prefiltering for indexed vector column is coming soon."
|
||||||
|
)
|
||||||
return ds.to_table(
|
return ds.to_table(
|
||||||
columns=query.columns,
|
columns=query.columns,
|
||||||
filter=query.filter,
|
filter=query.filter,
|
||||||
|
prefilter=query.prefilter,
|
||||||
nearest={
|
nearest={
|
||||||
"column": query.vector_column,
|
"column": query.vector_column,
|
||||||
"q": query.vector,
|
"q": query.vector,
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.2.5"
|
version = "0.2.6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"pylance==0.7.4",
|
"pylance==0.8.0",
|
||||||
"ratelimiter",
|
"ratelimiter~=1.0",
|
||||||
"retry",
|
"retry>=0.9.2",
|
||||||
"tqdm",
|
"tqdm>=4.1.0",
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"pydantic",
|
"pydantic>=1.10",
|
||||||
"attr",
|
"attrs>=21.3.0",
|
||||||
"semver>=3.0",
|
"semver>=3.0",
|
||||||
"cachetools"
|
"cachetools"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class MockTable:
|
|||||||
return ds.to_table(
|
return ds.to_table(
|
||||||
columns=query.columns,
|
columns=query.columns,
|
||||||
filter=query.filter,
|
filter=query.filter,
|
||||||
|
prefilter=query.prefilter,
|
||||||
nearest={
|
nearest={
|
||||||
"column": query.vector_column,
|
"column": query.vector_column,
|
||||||
"q": query.vector,
|
"q": query.vector,
|
||||||
@@ -97,6 +98,25 @@ def test_query_builder_with_filter(table):
|
|||||||
assert all(df["vector"].values[0] == [3, 4])
|
assert all(df["vector"].values[0] == [3, 4])
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_builder_with_prefilter(table):
|
||||||
|
df = (
|
||||||
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
|
.where("id = 2")
|
||||||
|
.limit(1)
|
||||||
|
.to_df()
|
||||||
|
)
|
||||||
|
assert len(df) == 0
|
||||||
|
|
||||||
|
df = (
|
||||||
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
|
.where("id = 2", prefilter=True)
|
||||||
|
.limit(1)
|
||||||
|
.to_df()
|
||||||
|
)
|
||||||
|
assert df["id"].values[0] == 2
|
||||||
|
assert all(df["vector"].values[0] == [3, 4])
|
||||||
|
|
||||||
|
|
||||||
def test_query_builder_with_metric(table):
|
def test_query_builder_with_metric(table):
|
||||||
query = [4, 8]
|
query = [4, 8]
|
||||||
vector_column_name = "vector"
|
vector_column_name = "vector"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import attr
|
import attrs
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -21,10 +21,10 @@ from aiohttp import web
|
|||||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||||
|
|
||||||
|
|
||||||
@attr.define
|
@attrs.define
|
||||||
class MockLanceDBServer:
|
class MockLanceDBServer:
|
||||||
runner: web.AppRunner = attr.field(init=False)
|
runner: web.AppRunner = attrs.field(init=False)
|
||||||
site: web.TCPSite = attr.field(init=False)
|
site: web.TCPSite = attrs.field(init=False)
|
||||||
|
|
||||||
async def query_handler(self, request: web.Request) -> web.Response:
|
async def query_handler(self, request: web.Request) -> web.Response:
|
||||||
table_name = request.match_info["table_name"]
|
table_name = request.match_info["table_name"]
|
||||||
|
|||||||
Reference in New Issue
Block a user