mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 21:02:58 +00:00
[python] Use pydantic for embedding function persistence (#467)
1. Support persistent embedding function so users can just search using query string 2. Add fixed size list conversion for multiple vector columns 3. Add support for empty query (just apply select/where/limit). 4. Refactor and simplify some of the data prep code --------- Co-authored-by: Chang She <chang@lancedb.com> Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
|
||||
# import lancedb so we don't have to in every example
|
||||
|
||||
|
||||
@@ -14,3 +17,22 @@ def doctest_setup(monkeypatch, tmpdir):
|
||||
monkeypatch.setitem(os.environ, "COLUMNS", "80")
|
||||
# Work in a temporary directory
|
||||
monkeypatch.chdir(tmpdir)
|
||||
|
||||
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
|
||||
@registry.register()
|
||||
class MockEmbeddingFunction(EmbeddingFunctionModel):
|
||||
def __call__(self, data):
|
||||
if isinstance(data, str):
|
||||
data = [data]
|
||||
elif isinstance(data, pa.ChunkedArray):
|
||||
data = data.combine_chunks().to_pylist()
|
||||
elif isinstance(data, pa.Array):
|
||||
data = data.to_pylist()
|
||||
|
||||
return [self.embed(row) for row in data]
|
||||
|
||||
def embed(self, row):
|
||||
return [float(hash(c)) for c in row[:10]]
|
||||
|
||||
@@ -16,12 +16,13 @@ from __future__ import annotations
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionModel
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
@@ -289,6 +290,7 @@ class LanceDBConnection(DBConnection):
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
@@ -307,6 +309,7 @@ class LanceDBConnection(DBConnection):
|
||||
mode=mode,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
)
|
||||
return tbl
|
||||
|
||||
|
||||
21
python/lancedb/embeddings/__init__.py
Normal file
21
python/lancedb/embeddings/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .functions import (
|
||||
REGISTRY,
|
||||
EmbeddingFunctionModel,
|
||||
EmbeddingFunctionRegistry,
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from .utils import with_embeddings
|
||||
224
python/lancedb/embeddings/functions.py
Normal file
224
python/lancedb/embeddings/functions.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return REGISTRY
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunctionModel.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunctionModel):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
self._functions[cls.__name__] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the registry to its initial state
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def load(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(self, metadata: Optional[dict]) -> dict:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
embedding function and source column
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[dict]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
functions : dict
|
||||
A mapping of vector column name to embedding function.
|
||||
An empty dict is returned if input is None or does not
|
||||
contain b"embedding_functions".
|
||||
"""
|
||||
if metadata is None or b"embedding_functions" not in metadata:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
functions = {}
|
||||
for obj in raw_list:
|
||||
model = self.load(obj["schema"]["title"])
|
||||
functions[obj["model"]["vector_column"]] = model(**obj["model"])
|
||||
return functions
|
||||
|
||||
def function_to_metadata(self, func):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
schema = func.model_json_schema()
|
||||
json_data = func.model_dump()
|
||||
return {
|
||||
"schema": schema,
|
||||
"model": json_data,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
REGISTRY = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
class EmbeddingFunctionModel(BaseModel, ABC):
|
||||
"""
|
||||
A callable ABC for embedding functions
|
||||
"""
|
||||
|
||||
source_column: Optional[str]
|
||||
vector_column: str
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> List[np.array]:
|
||||
pass
|
||||
|
||||
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
|
||||
|
||||
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function. This is called
|
||||
before generate_embeddings() and is useful for stripping
|
||||
whitespace, lowercasing, etc.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
elif isinstance(texts, pa.Array):
|
||||
texts = texts.to_pylist()
|
||||
elif isinstance(texts, pa.ChunkedArray):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@REGISTRY.register()
|
||||
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the model to load
|
||||
device : str
|
||||
The device to load the model on
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
return SentenceTransformer(name, device=device)
|
||||
except ImportError:
|
||||
raise ValueError("Please install sentence_transformers")
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from .util import safe_import_pandas
|
||||
from ..util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
@@ -58,7 +58,7 @@ def with_embeddings(
|
||||
pa.Table
|
||||
The input table with a new column called "vector" containing the embeddings.
|
||||
"""
|
||||
func = EmbeddingFunction(func)
|
||||
func = FunctionWrapper(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit()
|
||||
func = func.batch_size(batch_size)
|
||||
@@ -71,7 +71,11 @@ def with_embeddings(
|
||||
return data.append_column("vector", table["vector"])
|
||||
|
||||
|
||||
class EmbeddingFunction:
|
||||
class FunctionWrapper:
|
||||
"""
|
||||
A wrapper for embedding functions that adds rate limiting, retries, and batching.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable):
|
||||
self.func = func
|
||||
self.rate_limiter_kwargs = {}
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -54,7 +55,164 @@ class Query(pydantic.BaseModel):
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder:
|
||||
class LanceQueryBuilder(ABC):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
table: "lancedb.table.Table",
|
||||
query: Optional[Union[np.ndarray, str]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(table, query)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query, dtype=np.float32)
|
||||
elif isinstance(query, np.ndarray):
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||
# If query_type is fts, then query must be a string.
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(
|
||||
f"Query type is 'fts' but query is not a string: {type(query)}"
|
||||
)
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
# If query_type is vector, then query must be a list or np.ndarray.
|
||||
# otherwise raise TypeError
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
|
||||
)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
elif isinstance(query, str):
|
||||
func = table.embedding_functions.get(vector_column_name, None)
|
||||
if func is not None:
|
||||
query = func(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
else:
|
||||
raise TypeError("Query must be a list, np.ndarray, or str")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
)
|
||||
|
||||
def __init__(self, table: "lancedb.table.Table"):
|
||||
self._table = table
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def limit(self, limit: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceVectorQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
@@ -80,68 +238,17 @@ class LanceQueryBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, str],
|
||||
query: Union[np.ndarray, list],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
self._table = table
|
||||
self._query = query
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column = vector_column
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
Parameters
|
||||
@@ -151,13 +258,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
|
||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of probes to use.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
@@ -173,13 +280,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
As an example, a refine factor of 2 will sample 2x as many vectors as
|
||||
@@ -195,22 +302,12 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
@@ -233,25 +330,12 @@ class LanceQueryBuilder:
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
import tantivy
|
||||
@@ -275,3 +359,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
return output_tbl
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
|
||||
from ..query import LanceQueryBuilder
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
@@ -73,7 +73,11 @@ class RemoteTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
@@ -89,9 +93,9 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
return LanceQueryBuilder(self, query, vector_column)
|
||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceVectorQueryBuilder:
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -28,46 +28,78 @@ from lance.dataset import ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
|
||||
def _sanitize_data(
|
||||
data,
|
||||
schema: Optional[pa.Schema],
|
||||
metadata: Optional[dict],
|
||||
on_bad_vectors: str,
|
||||
fill_value: Any,
|
||||
):
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
data = [dict(d) for d in data]
|
||||
data = pa.Table.from_pylist(data)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
elif isinstance(data, dict):
|
||||
data = vec_to_table(data)
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
elif pd is not None and isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data, preserve_index=False)
|
||||
# Do not serialize Pandas metadata
|
||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
data = data.replace_schema_metadata(meta)
|
||||
|
||||
if isinstance(data, pa.Table):
|
||||
if metadata:
|
||||
data = _append_vector_col(data, metadata, schema)
|
||||
metadata.update(data.schema.metadata or {})
|
||||
data = data.replace_schema_metadata(metadata)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
# Do not serialize Pandas metadata
|
||||
metadata = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
metadata = {k: v for k, v in metadata.items() if k != b"pandas"}
|
||||
schema = data.schema.with_metadata(metadata)
|
||||
data = pa.Table.from_arrays(data.columns, schema=schema)
|
||||
if isinstance(data, Iterable):
|
||||
data = _to_record_batch_generator(data, schema, on_bad_vectors, fill_value)
|
||||
if not isinstance(data, (pa.Table, Iterable)):
|
||||
elif isinstance(data, Iterable):
|
||||
data = _to_record_batch_generator(
|
||||
data, schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||
return data
|
||||
|
||||
|
||||
def _to_record_batch_generator(data: Iterable, schema, on_bad_vectors, fill_value):
|
||||
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
|
||||
"""
|
||||
Use the embedding function to automatically embed the source column and add the
|
||||
vector column to the table.
|
||||
"""
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_col, func in functions.items():
|
||||
if vector_col not in data.column_names:
|
||||
col_data = func(data[func.source_column])
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_col).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _to_record_batch_generator(
|
||||
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
||||
):
|
||||
for batch in data:
|
||||
if not isinstance(batch, pa.RecordBatch):
|
||||
table = _sanitize_data(batch, schema, on_bad_vectors, fill_value)
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
yield batch
|
||||
yield batch
|
||||
@@ -196,17 +228,27 @@ class Table(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
vector_column: str, default "vector"
|
||||
query: str, list, np.ndarray, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -325,14 +367,14 @@ class LanceTable(Table):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
1
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.checkout(1)
|
||||
3
|
||||
>>> table.checkout(2)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
@@ -361,19 +403,19 @@ class LanceTable(Table):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
1
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.restore(1)
|
||||
3
|
||||
>>> table.restore(2)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> len(table.list_versions())
|
||||
3
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
if version is None:
|
||||
@@ -501,7 +543,11 @@ class LanceTable(Table):
|
||||
"""
|
||||
# TODO: manage table listing and metadata separately
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
@@ -569,18 +615,50 @@ class LanceTable(Table):
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
def _get_embedding_function_for_source_col(self, column_name: str):
|
||||
for k, v in self.embedding_functions.items():
|
||||
if v.source_column == column_name:
|
||||
return v
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
"""
|
||||
Get the embedding functions for the table
|
||||
|
||||
Returns
|
||||
-------
|
||||
funcs: dict
|
||||
A mapping of the vector column to the embedding function
|
||||
or empty dict if not configured.
|
||||
"""
|
||||
return EmbeddingFunctionRegistry.get_instance().parse_functions(
|
||||
self.schema.metadata
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
query: str, list, np.ndarray, or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If the query is a list/np.ndarray then the query type is "vector";
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -590,17 +668,9 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query)
|
||||
if isinstance(query, np.ndarray):
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
return LanceQueryBuilder(self, query, vector_column_name)
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -612,6 +682,7 @@ class LanceTable(Table):
|
||||
mode="create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: List[EmbeddingFunctionModel] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -649,20 +720,52 @@ class LanceTable(Table):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
embedding_functions: list of EmbeddingFunctionModel, default None
|
||||
The embedding functions to use when creating the table.
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
else:
|
||||
if schema is None:
|
||||
|
||||
if schema is None:
|
||||
if data is None:
|
||||
raise ValueError("Either data or schema must be provided")
|
||||
data = pa.Table.from_pylist([], schema=schema)
|
||||
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode)
|
||||
return LanceTable(db, name)
|
||||
elif hasattr(data, "schema"):
|
||||
schema = data.schema
|
||||
elif isinstance(data, Iterable):
|
||||
if metadata:
|
||||
raise TypeError(
|
||||
(
|
||||
"Persistent embedding functions not yet "
|
||||
"supported for generator data input"
|
||||
)
|
||||
)
|
||||
|
||||
if metadata:
|
||||
schema = schema.with_metadata(metadata)
|
||||
|
||||
empty = pa.Table.from_pylist([], schema=schema)
|
||||
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
|
||||
table = LanceTable(db, name)
|
||||
|
||||
if data is not None:
|
||||
table.add(data)
|
||||
|
||||
return table
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
@@ -770,22 +873,38 @@ def _sanitize_schema(
|
||||
return data
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
data = _sanitize_vector_column(
|
||||
for field in schema:
|
||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
if field.name in data.column_names and (
|
||||
likely_vector_col or is_default_vector_col
|
||||
):
|
||||
data = _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=field.name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
|
||||
# just check the vector column
|
||||
if VECTOR_COLUMN_NAME in data.column_names:
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
# just check the vector column
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_vector_column(
|
||||
@@ -809,8 +928,6 @@ def _sanitize_vector_column(
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if vector_column_name not in data.column_names:
|
||||
raise ValueError(f"Missing vector column: {vector_column_name}")
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
|
||||
@@ -9,7 +9,8 @@ dependencies = [
|
||||
"aiohttp",
|
||||
"pydantic",
|
||||
"attr",
|
||||
"semver>=3.0"
|
||||
"semver>=3.0",
|
||||
"cachetools"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
|
||||
@@ -144,7 +144,7 @@ def test_ingest_iterator(tmp_path):
|
||||
tbl_len = len(tbl)
|
||||
tbl.add(make_batches())
|
||||
assert len(tbl) == tbl_len * 2
|
||||
assert len(tbl.list_versions()) == 2
|
||||
assert len(tbl.list_versions()) == 3
|
||||
db.drop_database()
|
||||
|
||||
run_tests(arrow_schema)
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.embeddings import with_embeddings
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -40,3 +42,37 @@ def test_with_embeddings():
|
||||
assert data.column_names == ["text", "price", "vector"]
|
||||
assert data.column("text").to_pylist() == ["foo", "bar"]
|
||||
assert data.column("price").to_pylist() == [10.0, 20.0]
|
||||
|
||||
|
||||
def test_embedding_function(tmp_path):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
# let's create a table
|
||||
table = pa.table(
|
||||
{
|
||||
"text": pa.array(["hello world", "goodbye world"]),
|
||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||
}
|
||||
)
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
metadata = registry.get_table_metadata([func])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
|
||||
# Write it to disk
|
||||
lance.write_dataset(table, tmp_path / "test.lance")
|
||||
|
||||
# Load this back
|
||||
ds = lance.dataset(tmp_path / "test.lance")
|
||||
|
||||
# can we get the serialized version back out?
|
||||
functions = registry.parse_functions(ds.schema.metadata)
|
||||
|
||||
func = functions["vector"]
|
||||
actual = func("hello world")
|
||||
|
||||
# We create an instance
|
||||
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
# And we make sure we can call it
|
||||
expected = expected_func("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
@@ -21,7 +21,7 @@ import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.query import LanceQueryBuilder, Query
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def test_cast(table):
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
results = q.to_pydantic(TestModel)
|
||||
assert len(results) == 1
|
||||
r0 = results[0]
|
||||
@@ -84,13 +84,15 @@ def test_cast(table):
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
)
|
||||
assert df["id"].values[0] == 1
|
||||
assert all(df["vector"].values[0] == [1, 2])
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
vector_column_name = "vector"
|
||||
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
)
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
|
||||
query = [4, 8]
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
|
||||
@@ -22,6 +22,7 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.table import LanceTable
|
||||
@@ -178,16 +179,16 @@ def test_versioning(db):
|
||||
],
|
||||
)
|
||||
|
||||
assert len(table.list_versions()) == 1
|
||||
assert table.version == 1
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 3
|
||||
|
||||
table.checkout(1)
|
||||
assert table.version == 1
|
||||
table.checkout(2)
|
||||
assert table.version == 2
|
||||
assert len(table) == 2
|
||||
|
||||
|
||||
@@ -278,21 +279,21 @@ def test_restore(db):
|
||||
data=[{"vector": [1.1, 0.9], "type": "vector"}],
|
||||
)
|
||||
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
table.restore(1)
|
||||
assert len(table.list_versions()) == 3
|
||||
table.restore(2)
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table) == 1
|
||||
|
||||
expected = table.to_arrow()
|
||||
table.checkout(1)
|
||||
table.checkout(2)
|
||||
table.restore()
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table.list_versions()) == 5
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
table.restore(4) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 4
|
||||
table.restore(5) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 5
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(5)
|
||||
table.restore(6)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(0)
|
||||
@@ -306,7 +307,7 @@ def test_merge(db, tmp_path):
|
||||
)
|
||||
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
|
||||
table.merge(other_table, left_on="id")
|
||||
assert len(table.list_versions()) == 2
|
||||
assert len(table.list_versions()) == 3
|
||||
expected = pa.table(
|
||||
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
|
||||
schema=table.schema,
|
||||
@@ -325,10 +326,10 @@ def test_delete(db):
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
table.delete("id=0")
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
table.delete("id=0")
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 1
|
||||
assert table.to_pandas()["id"].tolist() == [1]
|
||||
|
||||
@@ -340,11 +341,103 @@ def test_update(db):
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
assert len(table.list_versions()) == 2
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table.list_versions()) == 4
|
||||
assert table.version == 4
|
||||
assert len(table) == 2
|
||||
v = table.to_arrow()["vector"].combine_chunks()
|
||||
v = v.values.to_numpy().reshape(2, 2)
|
||||
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
table.add(df)
|
||||
|
||||
texts = ["the quick brown fox", "jumped over the lazy dog"]
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: vector(10)
|
||||
vector2: vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
|
||||
|
||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_empty_query(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
|
||||
val = df.id.iloc[0]
|
||||
assert val == 1
|
||||
|
||||
Reference in New Issue
Block a user