[python] Use pydantic for embedding function persistence (#467)

1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Chang She
2023-09-05 21:30:45 -07:00
committed by GitHub
parent 52fa7f5577
commit 9a9a73a65d
13 changed files with 815 additions and 192 deletions

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functions import (
REGISTRY,
EmbeddingFunctionModel,
EmbeddingFunctionRegistry,
SentenceTransformerEmbeddingFunction,
)
from .utils import with_embeddings

View File

@@ -0,0 +1,224 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing
"""
@classmethod
def get_instance(cls):
return REGISTRY
def __init__(self):
self._functions = {}
def register(self):
"""
This creates a decorator that can be used to register
an EmbeddingFunctionModel.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunctionModel):
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
self._functions[cls.__name__] = cls
return cls
return decorator
def reset(self):
"""
Reset the registry to its initial state
"""
self._functions = {}
def load(self, name: str):
"""
Fetch an embedding function class by name
"""
return self._functions[name]
def parse_functions(self, metadata: Optional[dict]) -> dict:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
embedding function and source column
Parameters
----------
metadata : Optional[dict]
The metadata from an arrow table. Note that
the keys and values are bytes.
Returns
-------
functions : dict
A mapping of vector column name to embedding function.
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
functions = {}
for obj in raw_list:
model = self.load(obj["schema"]["title"])
functions[obj["model"]["vector_column"]] = model(**obj["model"])
return functions
def function_to_metadata(self, func):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
schema = func.model_json_schema()
json_data = func.model_dump()
return {
"schema": schema,
"model": json_data,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
REGISTRY = EmbeddingFunctionRegistry()
class EmbeddingFunctionModel(BaseModel, ABC):
"""
A callable ABC for embedding functions
"""
source_column: Optional[str]
vector_column: str
@abstractmethod
def __call__(self, *args, **kwargs) -> List[np.array]:
pass
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
"""
A callable ABC for embedding functions that take text as input
"""
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function. This is called
before generate_embeddings() and is useful for stripping
whitespace, lowercasing, etc.
"""
if isinstance(texts, str):
texts = [texts]
elif isinstance(texts, pa.Array):
texts = texts.to_pylist()
elif isinstance(texts, pa.ChunkedArray):
texts = texts.combine_chunks().to_pylist()
return texts
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
pass
@REGISTRY.register()
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = False
@property
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.__class__.get_embedding_model(self.name, self.device)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@classmethod
@cached(cache={})
def get_embedding_model(cls, name, device):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
Parameters
----------
name : str
The name of the model to load
device : str
The device to load the model on
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
try:
from sentence_transformers import SentenceTransformer
return SentenceTransformer(name, device=device)
except ImportError:
raise ValueError("Please install sentence_transformers")

View File

@@ -0,0 +1,154 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from typing import Callable, Union
import numpy as np
import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from ..util import safe_import_pandas
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
def with_embeddings(
func: Callable,
data: DATA,
column: str = "text",
wrap_api: bool = True,
show_progress: bool = False,
batch_size: int = 1000,
) -> pa.Table:
"""Add a vector column to a table using the given embedding function.
The new columns will be called "vector".
Parameters
----------
func : Callable
A function that takes a list of strings and returns a list of vectors.
data : pa.Table or pd.DataFrame
The data to add an embedding column to.
column : str, default "text"
The name of the column to use as input to the embedding function.
wrap_api : bool, default True
Whether to wrap the embedding function in a retry and rate limiter.
show_progress : bool, default False
Whether to show a progress bar.
batch_size : int, default 1000
The number of row values to pass to each call of the embedding function.
Returns
-------
pa.Table
The input table with a new column called "vector" containing the embeddings.
"""
func = FunctionWrapper(func)
if wrap_api:
func = func.retry().rate_limit()
func = func.batch_size(batch_size)
if show_progress:
func = func.show_progress()
if pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False)
embeddings = func(data[column].to_numpy())
table = vec_to_table(np.array(embeddings))
return data.append_column("vector", table["vector"])
class FunctionWrapper:
"""
A wrapper for embedding functions that adds rate limiting, retries, and batching.
"""
def __init__(self, func: Callable):
self.func = func
self.rate_limiter_kwargs = {}
self.retry_kwargs = {}
self._batch_size = None
self._progress = False
def __call__(self, text):
# Get the embedding with retry
if len(self.retry_kwargs) > 0:
@retry(**self.retry_kwargs)
def embed_func(c):
return self.func(c.tolist())
else:
def embed_func(c):
return self.func(c.tolist())
if len(self.rate_limiter_kwargs) > 0:
v = int(sys.version_info.minor)
if v >= 11:
print(
"WARNING: rate limit only support up to 3.10, proceeding without rate limiter"
)
else:
import ratelimiter
max_calls = self.rate_limiter_kwargs["max_calls"]
limiter = ratelimiter.RateLimiter(
max_calls, period=self.rate_limiter_kwargs["period"]
)
embed_func = limiter(embed_func)
batches = self.to_batches(text)
embeds = [emb for c in batches for emb in embed_func(c)]
return embeds
def __repr__(self):
return f"EmbeddingFunction(func={self.func})"
def rate_limit(self, max_calls=0.9, period=1.0):
self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
return self
def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
self.retry_kwargs = dict(
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
)
return self
def batch_size(self, batch_size):
self._batch_size = batch_size
return self
def show_progress(self):
self._progress = True
return self
def to_batches(self, arr):
length = len(arr)
def _chunker(arr):
for start_i in range(0, len(arr), self._batch_size):
yield arr[start_i : start_i + self._batch_size]
if self._progress:
from tqdm.auto import tqdm
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
yield from _chunker(arr)