[python] Use pydantic for embedding function persistence (#467)

1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Chang She
2023-09-05 21:30:45 -07:00
committed by GitHub
parent 52fa7f5577
commit 9a9a73a65d
13 changed files with 815 additions and 192 deletions

View File

@@ -1,7 +1,10 @@
import os
import pyarrow as pa
import pytest
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
# import lancedb so we don't have to in every example
@@ -14,3 +17,22 @@ def doctest_setup(monkeypatch, tmpdir):
monkeypatch.setitem(os.environ, "COLUMNS", "80")
# Work in a temporary directory
monkeypatch.chdir(tmpdir)
registry = EmbeddingFunctionRegistry.get_instance()
@registry.register()
class MockEmbeddingFunction(EmbeddingFunctionModel):
def __call__(self, data):
if isinstance(data, str):
data = [data]
elif isinstance(data, pa.ChunkedArray):
data = data.combine_chunks().to_pylist()
elif isinstance(data, pa.Array):
data = data.to_pylist()
return [self.embed(row) for row in data]
def embed(self, row):
return [float(hash(c)) for c in row[:10]]

View File

@@ -16,12 +16,13 @@ from __future__ import annotations
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union
import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .embeddings import EmbeddingFunctionModel
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
@@ -289,6 +290,7 @@ class LanceDBConnection(DBConnection):
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
) -> LanceTable:
"""Create a table in the database.
@@ -307,6 +309,7 @@ class LanceDBConnection(DBConnection):
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
embedding_functions=embedding_functions,
)
return tbl

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functions import (
REGISTRY,
EmbeddingFunctionModel,
EmbeddingFunctionRegistry,
SentenceTransformerEmbeddingFunction,
)
from .utils import with_embeddings

View File

@@ -0,0 +1,224 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing
"""
@classmethod
def get_instance(cls):
return REGISTRY
def __init__(self):
self._functions = {}
def register(self):
"""
This creates a decorator that can be used to register
an EmbeddingFunctionModel.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunctionModel):
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
self._functions[cls.__name__] = cls
return cls
return decorator
def reset(self):
"""
Reset the registry to its initial state
"""
self._functions = {}
def load(self, name: str):
"""
Fetch an embedding function class by name
"""
return self._functions[name]
def parse_functions(self, metadata: Optional[dict]) -> dict:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
embedding function and source column
Parameters
----------
metadata : Optional[dict]
The metadata from an arrow table. Note that
the keys and values are bytes.
Returns
-------
functions : dict
A mapping of vector column name to embedding function.
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
functions = {}
for obj in raw_list:
model = self.load(obj["schema"]["title"])
functions[obj["model"]["vector_column"]] = model(**obj["model"])
return functions
def function_to_metadata(self, func):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
schema = func.model_json_schema()
json_data = func.model_dump()
return {
"schema": schema,
"model": json_data,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
REGISTRY = EmbeddingFunctionRegistry()
class EmbeddingFunctionModel(BaseModel, ABC):
"""
A callable ABC for embedding functions
"""
source_column: Optional[str]
vector_column: str
@abstractmethod
def __call__(self, *args, **kwargs) -> List[np.array]:
pass
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
"""
A callable ABC for embedding functions that take text as input
"""
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function. This is called
before generate_embeddings() and is useful for stripping
whitespace, lowercasing, etc.
"""
if isinstance(texts, str):
texts = [texts]
elif isinstance(texts, pa.Array):
texts = texts.to_pylist()
elif isinstance(texts, pa.ChunkedArray):
texts = texts.combine_chunks().to_pylist()
return texts
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
pass
@REGISTRY.register()
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = False
@property
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.__class__.get_embedding_model(self.name, self.device)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@classmethod
@cached(cache={})
def get_embedding_model(cls, name, device):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
Parameters
----------
name : str
The name of the model to load
device : str
The device to load the model on
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
try:
from sentence_transformers import SentenceTransformer
return SentenceTransformer(name, device=device)
except ImportError:
raise ValueError("Please install sentence_transformers")

View File

@@ -1,4 +1,4 @@
# Copyright 2023 LanceDB Developers
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from .util import safe_import_pandas
from ..util import safe_import_pandas
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
@@ -58,7 +58,7 @@ def with_embeddings(
pa.Table
The input table with a new column called "vector" containing the embeddings.
"""
func = EmbeddingFunction(func)
func = FunctionWrapper(func)
if wrap_api:
func = func.retry().rate_limit()
func = func.batch_size(batch_size)
@@ -71,7 +71,11 @@ def with_embeddings(
return data.append_column("vector", table["vector"])
class EmbeddingFunction:
class FunctionWrapper:
"""
A wrapper for embedding functions that adds rate limiting, retries, and batching.
"""
def __init__(self, func: Callable):
self.func = func
self.rate_limiter_kwargs = {}

View File

@@ -13,6 +13,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
import numpy as np
@@ -54,7 +55,164 @@ class Query(pydantic.BaseModel):
refine_factor: Optional[int] = None
class LanceQueryBuilder:
class LanceQueryBuilder(ABC):
@classmethod
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
if isinstance(query, list):
query = np.array(query, dtype=np.float32)
elif isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name)
@classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name):
# If query_type is fts, then query must be a string.
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(
f"Query type is 'fts' but query is not a string: {type(query)}"
)
return query, query_type
elif query_type == "vector":
# If query_type is vector, then query must be a list or np.ndarray.
# otherwise raise TypeError
if not isinstance(query, (list, np.ndarray)):
raise TypeError(
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
elif isinstance(query, str):
func = table.embedding_functions.get(vector_column_name, None)
if func is not None:
query = func(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise TypeError("Query must be a list, np.ndarray, or str")
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
def __init__(self, table: "lancedb.table.Table"):
self._table = table
self._limit = 10
self._columns = None
self._where = None
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
raise NotImplementedError
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
def limit(self, limit: int) -> LanceVectorQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceVectorQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
A builder for nearest neighbor queries for LanceDB.
@@ -80,68 +238,17 @@ class LanceQueryBuilder:
def __init__(
self,
table: "lancedb.table.Table",
query: Union[np.ndarray, str],
query: Union[np.ndarray, list],
vector_column: str = VECTOR_COLUMN_NAME,
):
super().__init__(table)
self._query = query
self._metric = "L2"
self._nprobes = 20
self._refine_factor = None
self._table = table
self._query = query
self._limit = 10
self._columns = None
self._where = None
self._vector_column = vector_column
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
Parameters
@@ -151,13 +258,13 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric
return self
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
"""Set the number of probes to use.
Higher values will yield better recall (more likely to find vectors if
@@ -173,13 +280,13 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled.
As an example, a refine factor of 2 will sample 2x as many vectors as
@@ -195,22 +302,12 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._refine_factor = refine_factor
return self
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
@@ -233,25 +330,12 @@ class LanceQueryBuilder:
)
return self._table._execute_query(query)
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "lancedb.table.Table", query: str):
super().__init__(table)
self._query = query
def to_arrow(self) -> pa.Table:
try:
import tantivy
@@ -275,3 +359,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
return output_tbl
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
)

View File

@@ -20,7 +20,7 @@ from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder
from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
@@ -73,7 +73,11 @@ class RemoteTable(Table):
fill_value: float = 0.0,
) -> int:
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
@@ -89,9 +93,9 @@ class RemoteTable(Table):
)
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
return LanceQueryBuilder(self, query, vector_column)
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
) -> LanceVectorQueryBuilder:
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:
result = self._conn._client.query(self._name, query)

View File

@@ -17,7 +17,7 @@ import inspect
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Iterable, List, Optional, Union
from typing import Any, Iterable, List, Optional, Union
import lance
import numpy as np
@@ -28,46 +28,78 @@ from lance.dataset import ReaderLike
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
pd = safe_import_pandas()
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
def _sanitize_data(
data,
schema: Optional[pa.Schema],
metadata: Optional[dict],
on_bad_vectors: str,
fill_value: Any,
):
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
data = [dict(d) for d in data]
data = pa.Table.from_pylist(data)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if isinstance(data, dict):
elif isinstance(data, dict):
data = vec_to_table(data)
if pd is not None and isinstance(data, pd.DataFrame):
elif pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
if isinstance(data, pa.Table):
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
# Do not serialize Pandas metadata
metadata = data.schema.metadata if data.schema.metadata is not None else {}
metadata = {k: v for k, v in metadata.items() if k != b"pandas"}
schema = data.schema.with_metadata(metadata)
data = pa.Table.from_arrays(data.columns, schema=schema)
if isinstance(data, Iterable):
data = _to_record_batch_generator(data, schema, on_bad_vectors, fill_value)
if not isinstance(data, (pa.Table, Iterable)):
elif isinstance(data, Iterable):
data = _to_record_batch_generator(
data, schema, metadata, on_bad_vectors, fill_value
)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
return data
def _to_record_batch_generator(data: Iterable, schema, on_bad_vectors, fill_value):
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
"""
Use the embedding function to automatically embed the source column and add the
vector column to the table.
"""
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_col, func in functions.items():
if vector_col not in data.column_names:
col_data = func(data[func.source_column])
if schema is not None:
dtype = schema.field(vector_col).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
)
return data
def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
for batch in data:
if not isinstance(batch, pa.RecordBatch):
table = _sanitize_data(batch, schema, on_bad_vectors, fill_value)
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches():
yield batch
yield batch
@@ -196,17 +228,27 @@ class Table(ABC):
@abstractmethod
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
self,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
vector_column: str, default "vector"
query: str, list, np.ndarray, default None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector"
The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
@@ -325,14 +367,14 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
1
2
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version
2
>>> table.checkout(1)
3
>>> table.checkout(2)
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
@@ -361,19 +403,19 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
1
2
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version
2
>>> table.restore(1)
3
>>> table.restore(2)
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> len(table.list_versions())
3
4
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
if version is None:
@@ -501,7 +543,11 @@ class LanceTable(Table):
"""
# TODO: manage table listing and metadata separately
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset()
@@ -569,18 +615,50 @@ class LanceTable(Table):
)
self._reset_dataset()
def _get_embedding_function_for_source_col(self, column_name: str):
for k, v in self.embedding_functions.items():
if v.source_column == column_name:
return v
return None
@cached_property
def embedding_functions(self) -> dict:
"""
Get the embedding functions for the table
Returns
-------
funcs: dict
A mapping of the vector column to the embedding function
or empty dict if not configured.
"""
return EmbeddingFunctionRegistry.get_instance().parse_functions(
self.schema.metadata
)
def search(
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
self,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
query: str, list, np.ndarray, or None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector"
The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If the query is a list/np.ndarray then the query type is "vector";
If the query is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
@@ -590,17 +668,9 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(self, query, vector_column_name)
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query, vector_column_name)
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name
)
@classmethod
def create(
@@ -612,6 +682,7 @@ class LanceTable(Table):
mode="create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionModel] = None,
):
"""
Create a new table.
@@ -649,20 +720,52 @@ class LanceTable(Table):
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
embedding_functions: list of EmbeddingFunctionModel, default None
The embedding functions to use when creating the table.
"""
tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
if data is not None:
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
else:
if schema is None:
if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode)
return LanceTable(db, name)
elif hasattr(data, "schema"):
schema = data.schema
elif isinstance(data, Iterable):
if metadata:
raise TypeError(
(
"Persistent embedding functions not yet "
"supported for generator data input"
)
)
if metadata:
schema = schema.with_metadata(metadata)
empty = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
table = LanceTable(db, name)
if data is not None:
table.add(data)
return table
@classmethod
def open(cls, db, name):
@@ -770,22 +873,38 @@ def _sanitize_schema(
return data
# cast the columns to the expected types
data = data.combine_chunks()
data = _sanitize_vector_column(
for field in schema:
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and field.type.list_size >= 10
)
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
if field.name in data.column_names and (
likely_vector_col or is_default_vector_col
):
data = _sanitize_vector_column(
data,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
if VECTOR_COLUMN_NAME in data.column_names:
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return data
def _sanitize_vector_column(
@@ -809,8 +928,6 @@ def _sanitize_vector_column(
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type):

View File

@@ -9,7 +9,8 @@ dependencies = [
"aiohttp",
"pydantic",
"attr",
"semver>=3.0"
"semver>=3.0",
"cachetools"
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]

View File

@@ -144,7 +144,7 @@ def test_ingest_iterator(tmp_path):
tbl_len = len(tbl)
tbl.add(make_batches())
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 2
assert len(tbl.list_versions()) == 3
db.drop_database()
run_tests(arrow_schema)

View File

@@ -12,10 +12,12 @@
# limitations under the License.
import sys
import lance
import numpy as np
import pyarrow as pa
from lancedb.embeddings import with_embeddings
from lancedb.conftest import MockEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
def mock_embed_func(input_data):
@@ -40,3 +42,37 @@ def test_with_embeddings():
assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()
# let's create a table
table = pa.table(
{
"text": pa.array(["hello world", "goodbye world"]),
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
metadata = registry.get_table_metadata([func])
table = table.replace_schema_metadata(metadata)
# Write it to disk
lance.write_dataset(table, tmp_path / "test.lance")
# Load this back
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
functions = registry.parse_functions(ds.schema.metadata)
func = functions["vector"]
actual = func("hello world")
# We create an instance
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
# And we make sure we can call it
expected = expected_func("hello world")
assert np.allclose(actual, expected)

View File

@@ -21,7 +21,7 @@ import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.query import LanceQueryBuilder, Query
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -72,7 +72,7 @@ def test_cast(table):
str_field: str
float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel)
assert len(results) == 1
r0 = results[0]
@@ -84,13 +84,15 @@ def test_cast(table):
def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
)
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
)
tm.assert_frame_equal(df_default, df_l2)
df_cosine = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
query = [4, 8]
vector_column_name = "foo_vector"
builder = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.where("b < 10")
.select(["b"])

View File

@@ -22,6 +22,7 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.table import LanceTable
@@ -178,16 +179,16 @@ def test_versioning(db):
],
)
assert len(table.list_versions()) == 1
assert table.version == 1
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 2
assert table.version == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 3
table.checkout(1)
assert table.version == 1
table.checkout(2)
assert table.version == 2
assert len(table) == 2
@@ -278,21 +279,21 @@ def test_restore(db):
data=[{"vector": [1.1, 0.9], "type": "vector"}],
)
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(1)
assert len(table.list_versions()) == 3
table.restore(2)
assert len(table.list_versions()) == 4
assert len(table) == 1
expected = table.to_arrow()
table.checkout(1)
table.checkout(2)
table.restore()
assert len(table.list_versions()) == 4
assert len(table.list_versions()) == 5
assert table.to_arrow() == expected
table.restore(4) # latest version should be no-op
assert len(table.list_versions()) == 4
table.restore(5) # latest version should be no-op
assert len(table.list_versions()) == 5
with pytest.raises(ValueError):
table.restore(5)
table.restore(6)
with pytest.raises(ValueError):
table.restore(0)
@@ -306,7 +307,7 @@ def test_merge(db, tmp_path):
)
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
table.merge(other_table, left_on="id")
assert len(table.list_versions()) == 2
assert len(table.list_versions()) == 3
expected = pa.table(
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
schema=table.schema,
@@ -325,10 +326,10 @@ def test_delete(db):
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 1
table.delete("id=0")
assert len(table.list_versions()) == 2
assert table.version == 2
table.delete("id=0")
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
@@ -340,11 +341,103 @@ def test_update(db):
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 1
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table.list_versions()) == 4
assert table.version == 4
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func(texts)})
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: vector(10)
vector2: vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_empty_query(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
val = df.id.iloc[0]
assert val == 1