From 0293bbe1428fd7b10c44fca25a623232b61f4dff Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 18 Oct 2023 11:02:19 +0530 Subject: [PATCH] [Python]Embeddings API refactor (#580) Sets things up for this -> https://github.com/lancedb/lancedb/issues/579 - Just separates out the registry/ingestion code from the function implementation code - adds a `get_registry` util - package name "open-clip" -> "open-clip-torch" --- python/lancedb/embeddings/__init__.py | 16 +- python/lancedb/embeddings/base.py | 138 +++++ python/lancedb/embeddings/cohere.py | 3 +- python/lancedb/embeddings/functions.py | 578 ------------------ python/lancedb/embeddings/open_clip.py | 163 +++++ python/lancedb/embeddings/openai.py | 37 ++ python/lancedb/embeddings/registry.py | 186 ++++++ .../embeddings/sentence_transformers.py | 77 +++ python/lancedb/embeddings/utils.py | 23 +- python/lancedb/table.py | 3 +- python/pyproject.toml | 2 +- python/tests/test_embeddings_slow.py | 12 +- 12 files changed, 636 insertions(+), 602 deletions(-) create mode 100644 python/lancedb/embeddings/base.py delete mode 100644 python/lancedb/embeddings/functions.py create mode 100644 python/lancedb/embeddings/open_clip.py create mode 100644 python/lancedb/embeddings/openai.py create mode 100644 python/lancedb/embeddings/registry.py create mode 100644 python/lancedb/embeddings/sentence_transformers.py diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index 55f2fdb5..2977f0b4 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -11,16 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .cohere import CohereEmbeddingFunction -from .functions import ( - EmbeddingFunction, - EmbeddingFunctionConfig, - EmbeddingFunctionRegistry, - OpenAIEmbeddings, - OpenClipEmbeddings, - SentenceTransformerEmbeddings, - TextEmbeddingFunction, - register, -) +from .open_clip import OpenClipEmbeddings +from .openai import OpenAIEmbeddings +from .registry import EmbeddingFunctionRegistry, get_registry +from .sentence_transformers import SentenceTransformerEmbeddings from .utils import with_embeddings diff --git a/python/lancedb/embeddings/base.py b/python/lancedb/embeddings/base.py new file mode 100644 index 00000000..a1d1aa05 --- /dev/null +++ b/python/lancedb/embeddings/base.py @@ -0,0 +1,138 @@ +import importlib +from abc import ABC, abstractmethod +from typing import List, Union + +import numpy as np +import pyarrow as pa +from pydantic import BaseModel, Field, PrivateAttr + +from .utils import TEXT + + +class EmbeddingFunction(BaseModel, ABC): + """ + An ABC for embedding functions. + + All concrete embedding functions must implement the following: + 1. compute_query_embeddings() which takes a query and returns a list of embeddings + 2. compute_source_embeddings() which returns a list of embeddings for the source column + For text data, the two will be the same. For multi-modal data, the source column + might be images and the vector column might be text. + 3. 
ndims method which returns the number of dimensions of the vector column + """ + + _ndims: int = PrivateAttr() + + @classmethod + def create(cls, **kwargs): + """ + Create an instance of the embedding function + """ + return cls(**kwargs) + + @abstractmethod + def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for a given user query + """ + pass + + @abstractmethod + def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for the source column in the database + """ + pass + + def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: + """ + Sanitize the input to the embedding function. + """ + if isinstance(texts, str): + texts = [texts] + elif isinstance(texts, pa.Array): + texts = texts.to_pylist() + elif isinstance(texts, pa.ChunkedArray): + texts = texts.combine_chunks().to_pylist() + return texts + + @classmethod + def safe_import(cls, module: str, mitigation=None): + """ + Import the specified module. If the module is not installed, + raise an ImportError with a helpful message. + + Parameters + ---------- + module : str + The name of the module to import + mitigation : Optional[str] + The package(s) to install to mitigate the error. + If not provided then the module name will be used. 
+ """ + try: + return importlib.import_module(module) + except ImportError: + raise ImportError(f"Please install {mitigation or module}") + + def safe_model_dump(self): + from ..pydantic import PYDANTIC_VERSION + + if PYDANTIC_VERSION.major < 2: + return dict(self) + return self.model_dump() + + @abstractmethod + def ndims(self): + """ + Return the dimensions of the vector column + """ + pass + + def SourceField(self, **kwargs): + """ + Creates a pydantic Field that can automatically annotate + the source column for this embedding function + """ + return Field(json_schema_extra={"source_column_for": self}, **kwargs) + + def VectorField(self, **kwargs): + """ + Creates a pydantic Field that can automatically annotate + the target vector column for this embedding function + """ + return Field(json_schema_extra={"vector_column_for": self}, **kwargs) + + +class EmbeddingFunctionConfig(BaseModel): + """ + This model encapsulates the configuration for an embedding function + in a lancedb table. 
It holds the embedding function, the source column, + and the vector column + """ + + vector_column: str + source_column: str + function: EmbeddingFunction + + +class TextEmbeddingFunction(EmbeddingFunction): + """ + A callable ABC for embedding functions that take text as input + """ + + def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + return self.compute_source_embeddings(query, *args, **kwargs) + + def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + texts = self.sanitize_input(texts) + return self.generate_embeddings(texts) + + @abstractmethod + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Generate the embeddings for the given texts + """ + pass diff --git a/python/lancedb/embeddings/cohere.py b/python/lancedb/embeddings/cohere.py index d9733dbc..07881f69 100644 --- a/python/lancedb/embeddings/cohere.py +++ b/python/lancedb/embeddings/cohere.py @@ -16,7 +16,8 @@ from typing import ClassVar, List, Union import numpy as np -from .functions import TextEmbeddingFunction, register +from .base import TextEmbeddingFunction +from .registry import register from .utils import api_key_not_found_help diff --git a/python/lancedb/embeddings/functions.py b/python/lancedb/embeddings/functions.py deleted file mode 100644 index e2a70898..00000000 --- a/python/lancedb/embeddings/functions.py +++ /dev/null @@ -1,578 +0,0 @@ -# Copyright (c) 2023. LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import concurrent.futures -import importlib -import io -import json -import os -import socket -import urllib.error -import urllib.parse as urlparse -import urllib.request -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Union - -import numpy as np -import pyarrow as pa -from cachetools import cached -from pydantic import BaseModel, Field, PrivateAttr -from tqdm import tqdm - - -class EmbeddingFunctionRegistry: - """ - This is a singleton class used to register embedding functions - and fetch them by name. It also handles serializing and deserializing. - You can implement your own embedding function by subclassing EmbeddingFunction - or TextEmbeddingFunction and registering it with the registry. - - Examples - -------- - >>> registry = EmbeddingFunctionRegistry.get_instance() - >>> @registry.register("my-embedding-function") - ... class MyEmbeddingFunction(EmbeddingFunction): - ... def ndims(self) -> int: - ... return 128 - ... - ... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: - ... return self.compute_source_embeddings(query, *args, **kwargs) - ... - ... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: - ... return [np.random.rand(self.ndims()) for _ in range(len(texts))] - ... - >>> registry.get("my-embedding-function") - - """ - - @classmethod - def get_instance(cls): - return __REGISTRY__ - - def __init__(self): - self._functions = {} - - def register(self, alias: str = None): - """ - This creates a decorator that can be used to register - an EmbeddingFunction. - - Parameters - ---------- - alias : Optional[str] - a human friendly name for the embedding function. If not - provided, the class name will be used. 
- """ - - # This is a decorator for a class that inherits from BaseModel - # It adds the class to the registry - def decorator(cls): - if not issubclass(cls, EmbeddingFunction): - raise TypeError("Must be a subclass of EmbeddingFunction") - if cls.__name__ in self._functions: - raise KeyError(f"{cls.__name__} was already registered") - key = alias or cls.__name__ - self._functions[key] = cls - cls.__embedding_function_registry_alias__ = alias - return cls - - return decorator - - def reset(self): - """ - Reset the registry to its initial state - """ - self._functions = {} - - def get(self, name: str): - """ - Fetch an embedding function class by name - - Parameters - ---------- - name : str - The name of the embedding function to fetch - Either the alias or the class name if no alias was provided - during registration - """ - return self._functions[name] - - def parse_functions( - self, metadata: Optional[Dict[bytes, bytes]] - ) -> Dict[str, "EmbeddingFunctionConfig"]: - """ - Parse the metadata from an arrow table and - return a mapping of the vector column to the - embedding function and source column - - Parameters - ---------- - metadata : Optional[Dict[bytes, bytes]] - The metadata from an arrow table. Note that - the keys and values are bytes (pyarrow api) - - Returns - ------- - functions : dict - A mapping of vector column name to embedding function. - An empty dict is returned if input is None or does not - contain b"embedding_functions". 
- """ - if metadata is None or b"embedding_functions" not in metadata: - return {} - serialized = metadata[b"embedding_functions"] - raw_list = json.loads(serialized.decode("utf-8")) - return { - obj["vector_column"]: EmbeddingFunctionConfig( - vector_column=obj["vector_column"], - source_column=obj["source_column"], - function=self.get(obj["name"])(**obj["model"]), - ) - for obj in raw_list - } - - def function_to_metadata(self, conf: "EmbeddingFunctionConfig"): - """ - Convert the given embedding function and source / vector column configs - into a config dictionary that can be serialized into arrow metadata - """ - func = conf.function - name = getattr( - func, "__embedding_function_registry_alias__", func.__class__.__name__ - ) - json_data = func.safe_model_dump() - return { - "name": name, - "model": json_data, - "source_column": conf.source_column, - "vector_column": conf.vector_column, - } - - def get_table_metadata(self, func_list): - """ - Convert a list of embedding functions and source / vector configs - into a config dictionary that can be serialized into arrow metadata - """ - if func_list is None or len(func_list) == 0: - return None - json_data = [self.function_to_metadata(func) for func in func_list] - # Note that metadata dictionary values must be bytes - # so we need to json dump then utf8 encode - metadata = json.dumps(json_data, indent=2).encode("utf-8") - return {"embedding_functions": metadata} - - -# Global instance -__REGISTRY__ = EmbeddingFunctionRegistry() - - -TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] -IMAGES = Union[ - str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray -] - - -class EmbeddingFunction(BaseModel, ABC): - """ - An ABC for embedding functions. - - All concrete embedding functions must implement the following: - 1. compute_query_embeddings() which takes a query and returns a list of embeddings - 2. 
get_source_embeddings() which returns a list of embeddings for the source column - For text data, the two will be the same. For multi-modal data, the source column - might be images and the vector column might be text. - 3. ndims method which returns the number of dimensions of the vector column - """ - - _ndims: int = PrivateAttr() - - @classmethod - def create(cls, **kwargs): - """ - Create an instance of the embedding function - """ - return cls(**kwargs) - - @abstractmethod - def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]: - """ - Compute the embeddings for a given user query - """ - pass - - @abstractmethod - def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]: - """ - Compute the embeddings for the source column in the database - """ - pass - - def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: - """ - Sanitize the input to the embedding function. - """ - if isinstance(texts, str): - texts = [texts] - elif isinstance(texts, pa.Array): - texts = texts.to_pylist() - elif isinstance(texts, pa.ChunkedArray): - texts = texts.combine_chunks().to_pylist() - return texts - - @classmethod - def safe_import(cls, module: str, mitigation=None): - """ - Import the specified module. If the module is not installed, - raise an ImportError with a helpful message. - - Parameters - ---------- - module : str - The name of the module to import - mitigation : Optional[str] - The package(s) to install to mitigate the error. - If not provided then the module name will be used. 
- """ - try: - return importlib.import_module(module) - except ImportError: - raise ImportError(f"Please install {mitigation or module}") - - def safe_model_dump(self): - from ..pydantic import PYDANTIC_VERSION - - if PYDANTIC_VERSION.major < 2: - return dict(self) - return self.model_dump() - - @abstractmethod - def ndims(self): - """ - Return the dimensions of the vector column - """ - pass - - def SourceField(self, **kwargs): - """ - Creates a pydantic Field that can automatically annotate - the source column for this embedding function - """ - return Field(json_schema_extra={"source_column_for": self}, **kwargs) - - def VectorField(self, **kwargs): - """ - Creates a pydantic Field that can automatically annotate - the target vector column for this embedding function - """ - return Field(json_schema_extra={"vector_column_for": self}, **kwargs) - - -class EmbeddingFunctionConfig(BaseModel): - """ - This model encapsulates the configuration for a embedding function - in a lancedb table. 
It holds the embedding function, the source column, - and the vector column - """ - - vector_column: str - source_column: str - function: EmbeddingFunction - - -class TextEmbeddingFunction(EmbeddingFunction): - """ - A callable ABC for embedding functions that take text as input - """ - - def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: - return self.compute_source_embeddings(query, *args, **kwargs) - - def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: - texts = self.sanitize_input(texts) - return self.generate_embeddings(texts) - - @abstractmethod - def generate_embeddings( - self, texts: Union[List[str], np.ndarray] - ) -> List[np.array]: - """ - Generate the embeddings for the given texts - """ - pass - - -# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8 -register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) - - -@register("sentence-transformers") -class SentenceTransformerEmbeddings(TextEmbeddingFunction): - """ - An embedding function that uses the sentence-transformers library - - https://huggingface.co/sentence-transformers - """ - - name: str = "all-MiniLM-L6-v2" - device: str = "cpu" - normalize: bool = True - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._ndims = None - - @property - def embedding_model(self): - """ - Get the sentence-transformers embedding model specified by the - name and device. This is cached so that the model is only loaded - once per process. 
- """ - return self.__class__.get_embedding_model(self.name, self.device) - - def ndims(self): - if self._ndims is None: - self._ndims = len(self.generate_embeddings("foo")[0]) - return self._ndims - - def generate_embeddings( - self, texts: Union[List[str], np.ndarray] - ) -> List[np.array]: - """ - Get the embeddings for the given texts - - Parameters - ---------- - texts: list[str] or np.ndarray (of str) - The texts to embed - """ - return self.embedding_model.encode( - list(texts), - convert_to_numpy=True, - normalize_embeddings=self.normalize, - ).tolist() - - @classmethod - @cached(cache={}) - def get_embedding_model(cls, name, device): - """ - Get the sentence-transformers embedding model specified by the - name and device. This is cached so that the model is only loaded - once per process. - - Parameters - ---------- - name : str - The name of the model to load - device : str - The device to load the model on - - TODO: use lru_cache instead with a reasonable/configurable maxsize - """ - sentence_transformers = cls.safe_import( - "sentence_transformers", "sentence-transformers" - ) - return sentence_transformers.SentenceTransformer(name, device=device) - - -@register("openai") -class OpenAIEmbeddings(TextEmbeddingFunction): - """ - An embedding function that uses the OpenAI API - - https://platform.openai.com/docs/guides/embeddings - """ - - name: str = "text-embedding-ada-002" - - def ndims(self): - # TODO don't hardcode this - return 1536 - - def generate_embeddings( - self, texts: Union[List[str], np.ndarray] - ) -> List[np.array]: - """ - Get the embeddings for the given texts - - Parameters - ---------- - texts: list[str] or np.ndarray (of str) - The texts to embed - """ - # TODO retry, rate limit, token limit - openai = self.safe_import("openai") - rs = openai.Embedding.create(input=texts, model=self.name)["data"] - return [v["embedding"] for v in rs] - - -@register("open-clip") -class OpenClipEmbeddings(EmbeddingFunction): - """ - An embedding 
function that uses the OpenClip API - For multi-modal text-to-image search - - https://github.com/mlfoundations/open_clip - """ - - name: str = "ViT-B-32" - pretrained: str = "laion2b_s34b_b79k" - device: str = "cpu" - batch_size: int = 64 - normalize: bool = True - _model = PrivateAttr() - _preprocess = PrivateAttr() - _tokenizer = PrivateAttr() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - open_clip = self.safe_import("open_clip", "open-clip") - model, _, preprocess = open_clip.create_model_and_transforms( - self.name, pretrained=self.pretrained - ) - model.to(self.device) - self._model, self._preprocess = model, preprocess - self._tokenizer = open_clip.get_tokenizer(self.name) - self._ndims = None - - def ndims(self): - if self._ndims is None: - self._ndims = self.generate_text_embeddings("foo").shape[0] - return self._ndims - - def compute_query_embeddings( - self, query: Union[str, "PIL.Image.Image"], *args, **kwargs - ) -> List[np.ndarray]: - """ - Compute the embeddings for a given user query - - Parameters - ---------- - query : Union[str, PIL.Image.Image] - The query to embed. A query can be either text or an image. 
- """ - if isinstance(query, str): - return [self.generate_text_embeddings(query)] - else: - PIL = self.safe_import("PIL", "pillow") - if isinstance(query, PIL.Image.Image): - return [self.generate_image_embedding(query)] - else: - raise TypeError("OpenClip supports str or PIL Image as query") - - def generate_text_embeddings(self, text: str) -> np.ndarray: - torch = self.safe_import("torch") - text = self.sanitize_input(text) - text = self._tokenizer(text) - text.to(self.device) - with torch.no_grad(): - text_features = self._model.encode_text(text.to(self.device)) - if self.normalize: - text_features /= text_features.norm(dim=-1, keepdim=True) - return text_features.cpu().numpy().squeeze() - - def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: - """ - Sanitize the input to the embedding function. - """ - if isinstance(images, (str, bytes)): - images = [images] - elif isinstance(images, pa.Array): - images = images.to_pylist() - elif isinstance(images, pa.ChunkedArray): - images = images.combine_chunks().to_pylist() - return images - - def compute_source_embeddings( - self, images: IMAGES, *args, **kwargs - ) -> List[np.array]: - """ - Get the embeddings for the given images - """ - images = self.sanitize_input(images) - embeddings = [] - for i in range(0, len(images), self.batch_size): - j = min(i + self.batch_size, len(images)) - batch = images[i:j] - embeddings.extend(self._parallel_get(batch)) - return embeddings - - def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]: - """ - Issue concurrent requests to retrieve the image data - """ - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self.generate_image_embedding, image) - for image in images - ] - return [future.result() for future in tqdm(futures)] - - def generate_image_embedding( - self, image: Union[str, bytes, "PIL.Image.Image"] - ) -> np.ndarray: - """ - Generate the embedding for a single image - - 
Parameters - ---------- - image : Union[str, bytes, PIL.Image.Image] - The image to embed. If the image is a str, it is treated as a uri. - If the image is bytes, it is treated as the raw image bytes. - """ - torch = self.safe_import("torch") - # TODO handle retry and errors for https - image = self._to_pil(image) - image = self._preprocess(image).unsqueeze(0) - with torch.no_grad(): - return self._encode_and_normalize_image(image) - - def _to_pil(self, image: Union[str, bytes]): - PIL = self.safe_import("PIL", "pillow") - if isinstance(image, bytes): - return PIL.Image.open(io.BytesIO(image)) - if isinstance(image, PIL.Image.Image): - return image - elif isinstance(image, str): - parsed = urlparse.urlparse(image) - # TODO handle drive letter on windows. - if parsed.scheme == "file": - return PIL.Image.open(parsed.path) - elif parsed.scheme == "": - return PIL.Image.open(image if os.name == "nt" else parsed.path) - elif parsed.scheme.startswith("http"): - return PIL.Image.open(io.BytesIO(url_retrieve(image))) - else: - raise NotImplementedError("Only local and http(s) urls are supported") - - def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"): - """ - encode a single image tensor and optionally normalize the output - """ - image_features = self._model.encode_image(image_tensor.to(self.device)) - if self.normalize: - image_features /= image_features.norm(dim=-1, keepdim=True) - return image_features.cpu().numpy().squeeze() - - -def url_retrieve(url: str): - """ - Parameters - ---------- - url: str - URL to download from - """ - try: - with urllib.request.urlopen(url) as conn: - return conn.read() - except (socket.gaierror, urllib.error.URLError) as err: - raise ConnectionError("could not download {} due to {}".format(url, err)) diff --git a/python/lancedb/embeddings/open_clip.py b/python/lancedb/embeddings/open_clip.py new file mode 100644 index 00000000..37023377 --- /dev/null +++ b/python/lancedb/embeddings/open_clip.py @@ -0,0 +1,163 @@ +import 
concurrent.futures +import io +import os +import urllib.parse as urlparse +from typing import List, Union + +import numpy as np +import pyarrow as pa +from pydantic import PrivateAttr +from tqdm import tqdm + +from .base import EmbeddingFunction +from .registry import register +from .utils import IMAGES, url_retrieve + + +@register("open-clip") +class OpenClipEmbeddings(EmbeddingFunction): + """ + An embedding function that uses the OpenClip API + For multi-modal text-to-image search + + https://github.com/mlfoundations/open_clip + """ + + name: str = "ViT-B-32" + pretrained: str = "laion2b_s34b_b79k" + device: str = "cpu" + batch_size: int = 64 + normalize: bool = True + _model = PrivateAttr() + _preprocess = PrivateAttr() + _tokenizer = PrivateAttr() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + open_clip = self.safe_import("open_clip", "open-clip") + model, _, preprocess = open_clip.create_model_and_transforms( + self.name, pretrained=self.pretrained + ) + model.to(self.device) + self._model, self._preprocess = model, preprocess + self._tokenizer = open_clip.get_tokenizer(self.name) + self._ndims = None + + def ndims(self): + if self._ndims is None: + self._ndims = self.generate_text_embeddings("foo").shape[0] + return self._ndims + + def compute_query_embeddings( + self, query: Union[str, "PIL.Image.Image"], *args, **kwargs + ) -> List[np.ndarray]: + """ + Compute the embeddings for a given user query + + Parameters + ---------- + query : Union[str, PIL.Image.Image] + The query to embed. A query can be either text or an image. 
+ """ + if isinstance(query, str): + return [self.generate_text_embeddings(query)] + else: + PIL = self.safe_import("PIL", "pillow") + if isinstance(query, PIL.Image.Image): + return [self.generate_image_embedding(query)] + else: + raise TypeError("OpenClip supports str or PIL Image as query") + + def generate_text_embeddings(self, text: str) -> np.ndarray: + torch = self.safe_import("torch") + text = self.sanitize_input(text) + text = self._tokenizer(text) + text.to(self.device) + with torch.no_grad(): + text_features = self._model.encode_text(text.to(self.device)) + if self.normalize: + text_features /= text_features.norm(dim=-1, keepdim=True) + return text_features.cpu().numpy().squeeze() + + def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: + """ + Sanitize the input to the embedding function. + """ + if isinstance(images, (str, bytes)): + images = [images] + elif isinstance(images, pa.Array): + images = images.to_pylist() + elif isinstance(images, pa.ChunkedArray): + images = images.combine_chunks().to_pylist() + return images + + def compute_source_embeddings( + self, images: IMAGES, *args, **kwargs + ) -> List[np.array]: + """ + Get the embeddings for the given images + """ + images = self.sanitize_input(images) + embeddings = [] + for i in range(0, len(images), self.batch_size): + j = min(i + self.batch_size, len(images)) + batch = images[i:j] + embeddings.extend(self._parallel_get(batch)) + return embeddings + + def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]: + """ + Issue concurrent requests to retrieve the image data + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self.generate_image_embedding, image) + for image in images + ] + return [future.result() for future in tqdm(futures)] + + def generate_image_embedding( + self, image: Union[str, bytes, "PIL.Image.Image"] + ) -> np.ndarray: + """ + Generate the embedding for a single image + + 
Parameters + ---------- + image : Union[str, bytes, PIL.Image.Image] + The image to embed. If the image is a str, it is treated as a uri. + If the image is bytes, it is treated as the raw image bytes. + """ + torch = self.safe_import("torch") + # TODO handle retry and errors for https + image = self._to_pil(image) + image = self._preprocess(image).unsqueeze(0) + with torch.no_grad(): + return self._encode_and_normalize_image(image) + + def _to_pil(self, image: Union[str, bytes]): + PIL = self.safe_import("PIL", "pillow") + if isinstance(image, bytes): + return PIL.Image.open(io.BytesIO(image)) + if isinstance(image, PIL.Image.Image): + return image + elif isinstance(image, str): + parsed = urlparse.urlparse(image) + # TODO handle drive letter on windows. + if parsed.scheme == "file": + return PIL.Image.open(parsed.path) + elif parsed.scheme == "": + return PIL.Image.open(image if os.name == "nt" else parsed.path) + elif parsed.scheme.startswith("http"): + return PIL.Image.open(io.BytesIO(url_retrieve(image))) + else: + raise NotImplementedError("Only local and http(s) urls are supported") + + def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"): + """ + encode a single image tensor and optionally normalize the output + """ + image_features = self._model.encode_image(image_tensor.to(self.device)) + if self.normalize: + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features.cpu().numpy().squeeze() diff --git a/python/lancedb/embeddings/openai.py b/python/lancedb/embeddings/openai.py new file mode 100644 index 00000000..25459743 --- /dev/null +++ b/python/lancedb/embeddings/openai.py @@ -0,0 +1,37 @@ +from typing import List, Union + +import numpy as np + +from .base import TextEmbeddingFunction +from .registry import register + + +@register("openai") +class OpenAIEmbeddings(TextEmbeddingFunction): + """ + An embedding function that uses the OpenAI API + + https://platform.openai.com/docs/guides/embeddings + """ + + name: 
str = "text-embedding-ada-002" + + def ndims(self): + # TODO don't hardcode this + return 1536 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Get the embeddings for the given texts + + Parameters + ---------- + texts: list[str] or np.ndarray (of str) + The texts to embed + """ + # TODO retry, rate limit, token limit + openai = self.safe_import("openai") + rs = openai.Embedding.create(input=texts, model=self.name)["data"] + return [v["embedding"] for v in rs] diff --git a/python/lancedb/embeddings/registry.py b/python/lancedb/embeddings/registry.py new file mode 100644 index 00000000..af7600dc --- /dev/null +++ b/python/lancedb/embeddings/registry.py @@ -0,0 +1,186 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Dict, Optional + +from .base import EmbeddingFunction, EmbeddingFunctionConfig + + +class EmbeddingFunctionRegistry: + """ + This is a singleton class used to register embedding functions + and fetch them by name. It also handles serializing and deserializing. + You can implement your own embedding function by subclassing EmbeddingFunction + or TextEmbeddingFunction and registering it with the registry. 
+ + NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] + Examples + -------- + >>> registry = EmbeddingFunctionRegistry.get_instance() + >>> @registry.register("my-embedding-function") + ... class MyEmbeddingFunction(EmbeddingFunction): + ... def ndims(self) -> int: + ... return 128 + ... + ... def compute_query_embeddings(self, query: str, *args, **kwargs): + ... return self.compute_source_embeddings(query, *args, **kwargs) + ... + ... def compute_source_embeddings(self, texts, *args, **kwargs): + ... return [np.random.rand(self.ndims()) for _ in range(len(texts))] + ... + >>> registry.get("my-embedding-function") + + """ + + @classmethod + def get_instance(cls): + return __REGISTRY__ + + def __init__(self): + self._functions = {} + + def register(self, alias: str = None): + """ + This creates a decorator that can be used to register + an EmbeddingFunction. + + Parameters + ---------- + alias : Optional[str] + a human friendly name for the embedding function. If not + provided, the class name will be used. 
+ """ + + # This is a decorator for a class that inherits from BaseModel + # It adds the class to the registry + def decorator(cls): + if not issubclass(cls, EmbeddingFunction): + raise TypeError("Must be a subclass of EmbeddingFunction") + if cls.__name__ in self._functions: + raise KeyError(f"{cls.__name__} was already registered") + key = alias or cls.__name__ + self._functions[key] = cls + cls.__embedding_function_registry_alias__ = alias + return cls + + return decorator + + def reset(self): + """ + Reset the registry to its initial state + """ + self._functions = {} + + def get(self, name: str): + """ + Fetch an embedding function class by name + + Parameters + ---------- + name : str + The name of the embedding function to fetch + Either the alias or the class name if no alias was provided + during registration + """ + return self._functions[name] + + def parse_functions( + self, metadata: Optional[Dict[bytes, bytes]] + ) -> Dict[str, "EmbeddingFunctionConfig"]: + """ + Parse the metadata from an arrow table and + return a mapping of the vector column to the + embedding function and source column + + Parameters + ---------- + metadata : Optional[Dict[bytes, bytes]] + The metadata from an arrow table. Note that + the keys and values are bytes (pyarrow api) + + Returns + ------- + functions : dict + A mapping of vector column name to embedding function. + An empty dict is returned if input is None or does not + contain b"embedding_functions". 
+ """ + if metadata is None or b"embedding_functions" not in metadata: + return {} + serialized = metadata[b"embedding_functions"] + raw_list = json.loads(serialized.decode("utf-8")) + return { + obj["vector_column"]: EmbeddingFunctionConfig( + vector_column=obj["vector_column"], + source_column=obj["source_column"], + function=self.get(obj["name"])(**obj["model"]), + ) + for obj in raw_list + } + + def function_to_metadata(self, conf: "EmbeddingFunctionConfig"): + """ + Convert the given embedding function and source / vector column configs + into a config dictionary that can be serialized into arrow metadata + """ + func = conf.function + name = getattr( + func, "__embedding_function_registry_alias__", func.__class__.__name__ + ) + json_data = func.safe_model_dump() + return { + "name": name, + "model": json_data, + "source_column": conf.source_column, + "vector_column": conf.vector_column, + } + + def get_table_metadata(self, func_list): + """ + Convert a list of embedding functions and source / vector configs + into a config dictionary that can be serialized into arrow metadata + """ + if func_list is None or len(func_list) == 0: + return None + json_data = [self.function_to_metadata(func) for func in func_list] + # Note that metadata dictionary values must be bytes + # so we need to json dump then utf8 encode + metadata = json.dumps(json_data, indent=2).encode("utf-8") + return {"embedding_functions": metadata} + + +# Global instance +__REGISTRY__ = EmbeddingFunctionRegistry() + + +# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8 +register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) + + +def get_registry(): + """ + Utility function to get the global instance of the registry + + Returns + ------- + EmbeddingFunctionRegistry + The global registry instance + + Examples + -------- + from lancedb.embeddings import get_registry + + registry = get_registry() + openai = registry.get("openai").create() + """ + 
return __REGISTRY__.get_instance() diff --git a/python/lancedb/embeddings/sentence_transformers.py b/python/lancedb/embeddings/sentence_transformers.py new file mode 100644 index 00000000..5e40a51d --- /dev/null +++ b/python/lancedb/embeddings/sentence_transformers.py @@ -0,0 +1,77 @@ +from typing import List, Union + +import numpy as np +from cachetools import cached + +from .base import TextEmbeddingFunction +from .registry import register + + +@register("sentence-transformers") +class SentenceTransformerEmbeddings(TextEmbeddingFunction): + """ + An embedding function that uses the sentence-transformers library + + https://huggingface.co/sentence-transformers + """ + + name: str = "all-MiniLM-L6-v2" + device: str = "cpu" + normalize: bool = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._ndims = None + + @property + def embedding_model(self): + """ + Get the sentence-transformers embedding model specified by the + name and device. This is cached so that the model is only loaded + once per process. + """ + return self.__class__.get_embedding_model(self.name, self.device) + + def ndims(self): + if self._ndims is None: + self._ndims = len(self.generate_embeddings("foo")[0]) + return self._ndims + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Get the embeddings for the given texts + + Parameters + ---------- + texts: list[str] or np.ndarray (of str) + The texts to embed + """ + return self.embedding_model.encode( + list(texts), + convert_to_numpy=True, + normalize_embeddings=self.normalize, + ).tolist() + + @classmethod + @cached(cache={}) + def get_embedding_model(cls, name, device): + """ + Get the sentence-transformers embedding model specified by the + name and device. This is cached so that the model is only loaded + once per process. 
+ + Parameters + ---------- + name : str + The name of the model to load + device : str + The device to load the model on + + TODO: use lru_cache instead with a reasonable/configurable maxsize + """ + sentence_transformers = cls.safe_import( + "sentence_transformers", "sentence-transformers" + ) + return sentence_transformers.SentenceTransformer(name, device=device) diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index c70e6b18..e33bf4d3 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -12,8 +12,10 @@ # limitations under the License. import math +import socket import sys -from typing import Callable, Union +import urllib.error +from typing import Callable, List, Union import numpy as np import pyarrow as pa @@ -24,7 +26,12 @@ from ..util import safe_import_pandas from ..utils.general import LOGGER pd = safe_import_pandas() + DATA = Union[pa.Table, "pd.DataFrame"] +TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] +IMAGES = Union[ + str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray +] def with_embeddings( @@ -155,6 +162,20 @@ class FunctionWrapper: yield from _chunker(arr) +def url_retrieve(url: str): + """ + Parameters + ---------- + url: str + URL to download from + """ + try: + with urllib.request.urlopen(url) as conn: + return conn.read() + except (socket.gaierror, urllib.error.URLError) as err: + raise ConnectionError("could not download {} due to {}".format(url, err)) + + def api_key_not_found_help(provider): LOGGER.error(f"Could not find API key for {provider}.") raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.") diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 909c70e5..b4f881a1 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -29,8 +29,7 @@ from lance.dataset import CleanupStats, ReaderLike from lance.vector import vec_to_table from .common import DATA, VEC, 
VECTOR_COLUMN_NAME -from .embeddings import EmbeddingFunctionRegistry -from .embeddings.functions import EmbeddingFunctionConfig +from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .pydantic import LanceModel from .query import LanceQueryBuilder, Query from .util import fs_from_uri, safe_import_pandas diff --git a/python/pyproject.toml b/python/pyproject.toml index 34950110..e0045b64 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -52,7 +52,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] -embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"] +embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"] [project.scripts] lancedb = "lancedb.cli.cli:cli" diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index 1ca0b78d..607a346d 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -19,7 +19,7 @@ import pytest import requests import lancedb -from lancedb.embeddings import EmbeddingFunctionRegistry +from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector # These are integration tests for embedding functions. 
@@ -31,7 +31,7 @@ from lancedb.pydantic import LanceModel, Vector @pytest.mark.parametrize("alias", ["sentence-transformers", "openai"]) def test_sentence_transformer(alias, tmp_path): db = lancedb.connect(tmp_path) - registry = EmbeddingFunctionRegistry.get_instance() + registry = get_registry() func = registry.get(alias).create() class Words(LanceModel): @@ -69,7 +69,7 @@ def test_openclip(tmp_path): from PIL import Image db = lancedb.connect(tmp_path) - registry = EmbeddingFunctionRegistry.get_instance() + registry = get_registry() func = registry.get("open-clip").create() class Images(LanceModel): @@ -131,11 +131,7 @@ def test_openclip(tmp_path): os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" ) # also skip if cohere not installed def test_cohere_embedding_function(): - cohere = ( - EmbeddingFunctionRegistry.get_instance() - .get("cohere") - .create(name="embed-multilingual-v2.0") - ) + cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0") class TextModel(LanceModel): text: str = cohere.SourceField()