From 31dad71c9405be1b6da0f361adf8f0b9ea8ce9f8 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Sat, 16 Sep 2023 21:23:51 -0400 Subject: [PATCH] multi-modal embedding-function (#484) --- .github/workflows/python.yml | 4 +- python/lancedb/conftest.py | 31 +- python/lancedb/db.py | 4 +- python/lancedb/embeddings/__init__.py | 10 +- python/lancedb/embeddings/functions.py | 415 +++++++++++++++++++++---- python/lancedb/pydantic.py | 44 ++- python/lancedb/query.py | 31 +- python/lancedb/remote/db.py | 7 +- python/lancedb/table.py | 47 +-- python/pyproject.toml | 11 +- python/tests/test_embeddings.py | 27 +- python/tests/test_embeddings_slow.py | 125 ++++++++ python/tests/test_table.py | 32 +- 13 files changed, 645 insertions(+), 143 deletions(-) create mode 100644 python/tests/test_embeddings_slow.py diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 1465875d..6b0b6b6e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -38,7 +38,7 @@ jobs: - name: isort run: isort --check --diff --quiet . - name: Run tests - run: pytest -x -v --durations=30 tests + run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest run: pytest --doctest-modules lancedb mac: @@ -65,4 +65,4 @@ jobs: - name: Black run: black --check --diff --no-color --quiet . - name: Run tests - run: pytest -x -v --durations=30 tests + run: pytest -m "not slow" -x -v --durations=30 tests diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index f91b1b75..716c6de8 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -1,9 +1,9 @@ import os -import pyarrow as pa +import numpy as np import pytest -from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry +from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # import lancedb so we don't have to in every example @@ -22,17 +22,20 @@ def doctest_setup(monkeypatch, tmpdir): registry = EmbeddingFunctionRegistry.get_instance() -@registry.register() -class MockEmbeddingFunction(EmbeddingFunctionModel): - def __call__(self, data): - if isinstance(data, str): - data = [data] - elif isinstance(data, pa.ChunkedArray): - data = data.combine_chunks().to_pylist() - elif isinstance(data, pa.Array): - data = data.to_pylist() +@registry.register("test") +class MockTextEmbeddingFunction(TextEmbeddingFunction): + """ + Return the hash of the first 10 characters + """ - return [self.embed(row) for row in data] + def generate_embeddings(self, texts): + return [self._compute_one_embedding(row) for row in texts] - def embed(self, row): - return [float(hash(c)) for c in row[:10]] + def _compute_one_embedding(self, row): + emb = np.array([float(hash(c)) for c in row[:10]]) + emb /= np.linalg.norm(emb) + return emb + + @property + def ndims(self): + return 10 diff --git a/python/lancedb/db.py b/python/lancedb/db.py index df163efa..6cdbce33 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -22,7 +22,7 @@ import pyarrow as pa from pyarrow import fs from .common import DATA, URI -from .embeddings import EmbeddingFunctionModel +from .embeddings import EmbeddingFunctionConfig from .pydantic import LanceModel from .table import LanceTable, Table from .util import fs_from_uri, get_uri_location, get_uri_scheme @@ -290,7 +290,7 @@ class LanceDBConnection(DBConnection): mode: str = "create", on_bad_vectors: str = "error", fill_value: float = 0.0, - embedding_functions: Optional[List[EmbeddingFunctionModel]] = None, + embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, ) -> LanceTable: """Create a table in the database. diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index 68abdd3e..454157cb 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -13,10 +13,12 @@ from .functions import ( - REGISTRY, - EmbeddingFunctionModel, + EmbeddingFunction, + EmbeddingFunctionConfig, EmbeddingFunctionRegistry, - SentenceTransformerEmbeddingFunction, - TextEmbeddingFunctionModel, + OpenAIEmbeddings, + OpenClipEmbeddings, + SentenceTransformerEmbeddings, + TextEmbeddingFunction, ) from .utils import with_embeddings diff --git a/python/lancedb/embeddings/functions.py b/python/lancedb/embeddings/functions.py index d1ae2cf7..e8683695 100644 --- a/python/lancedb/embeddings/functions.py +++ b/python/lancedb/embeddings/functions.py @@ -10,14 +10,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures +import importlib +import io import json +import os +import socket +import urllib.error +import urllib.parse as urlparse +import urllib.request from abc import ABC, abstractmethod -from typing import List, Optional, Union +from functools import cached_property +from typing import Dict, List, Optional, Union import numpy as np import pyarrow as pa from cachetools import cached -from pydantic import BaseModel +from pydantic import BaseModel, Field class EmbeddingFunctionRegistry: @@ -28,25 +37,33 @@ class EmbeddingFunctionRegistry: @classmethod def get_instance(cls): - return REGISTRY + return __REGISTRY__ def __init__(self): self._functions = {} - def register(self): + def register(self, alias: str = None): """ This creates a decorator that can be used to register - an EmbeddingFunctionModel. + an EmbeddingFunction. + + Parameters + ---------- + alias : Optional[str] + a human friendly name for the embedding function. If not + provided, the class name will be used. """ # This is a decorator for a class that inherits from BaseModel # It adds the class to the registry def decorator(cls): - if not issubclass(cls, EmbeddingFunctionModel): - raise TypeError("Must be a subclass of EmbeddingFunctionModel") + if not issubclass(cls, EmbeddingFunction): + raise TypeError("Must be a subclass of EmbeddingFunction") if cls.__name__ in self._functions: raise KeyError(f"{cls.__name__} was already registered") - self._functions[cls.__name__] = cls + key = alias or cls.__name__ + self._functions[key] = cls + cls.__embedding_function_registry_alias__ = alias return cls return decorator @@ -57,13 +74,22 @@ class EmbeddingFunctionRegistry: """ self._functions = {} - def load(self, name: str): + def get(self, name: str): """ Fetch an embedding function class by name + + Parameters + ---------- + name : str + The name of the embedding function to fetch + Either the alias or the class name if no alias was provided + during registration """ return self._functions[name] - def parse_functions(self, metadata: Optional[dict]) -> dict: + def parse_functions( + self, metadata: Optional[Dict[bytes, bytes]] + ) -> Dict[str, "EmbeddingFunctionConfig"]: """ Parse the metadata from an arrow table and return a mapping of the vector column to the @@ -71,9 +97,9 @@ class EmbeddingFunctionRegistry: Parameters ---------- - metadata : Optional[dict] + metadata : Optional[Dict[bytes, bytes]] The metadata from an arrow table. Note that - the keys and values are bytes. + the keys and values are bytes (pyarrow api) Returns ------- @@ -86,68 +112,91 @@ class EmbeddingFunctionRegistry: return {} serialized = metadata[b"embedding_functions"] raw_list = json.loads(serialized.decode("utf-8")) - functions = {} - for obj in raw_list: - model = self.load(obj["schema"]["title"]) - functions[obj["model"]["vector_column"]] = model(**obj["model"]) - return functions + return { + obj["vector_column"]: EmbeddingFunctionConfig( + vector_column=obj["vector_column"], + source_column=obj["source_column"], + function=self.get(obj["name"])(**obj["model"]), + ) + for obj in raw_list + } - def function_to_metadata(self, func): + def function_to_metadata(self, conf: "EmbeddingFunctionConfig"): """ Convert the given embedding function and source / vector column configs into a config dictionary that can be serialized into arrow metadata """ - schema = func.model_json_schema() + func = conf.function + name = getattr( + func, "__embedding_function_registry_alias__", func.__class__.__name__ + ) json_data = func.model_dump() return { - "schema": schema, + "name": name, "model": json_data, + "source_column": conf.source_column, + "vector_column": conf.vector_column, } def get_table_metadata(self, func_list): """ - Convert a list of embedding functions and source / vector column configs + Convert a list of embedding functions and source / vector configs into a config dictionary that can be serialized into arrow metadata """ + if func_list is None or len(func_list) == 0: + return None json_data = [self.function_to_metadata(func) for func in func_list] - # Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode + # Note that metadata dictionary values must be bytes + # so we need to json dump then utf8 encode metadata = json.dumps(json_data, indent=2).encode("utf-8") return {"embedding_functions": metadata} -REGISTRY = EmbeddingFunctionRegistry() - - -class EmbeddingFunctionModel(BaseModel, ABC): - """ - A callable ABC for embedding functions - """ - - source_column: Optional[str] - vector_column: str - - @abstractmethod - def __call__(self, *args, **kwargs) -> List[np.array]: - pass +# Global instance +__REGISTRY__ = EmbeddingFunctionRegistry() TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] +IMAGES = Union[ + str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray +] -class TextEmbeddingFunctionModel(EmbeddingFunctionModel): +class EmbeddingFunction(BaseModel, ABC): """ - A callable ABC for embedding functions that take text as input + An ABC for embedding functions. + + The API has two methods: + 1. compute_query_embeddings() which takes a query and returns a list of embeddings + 2. get_source_embeddings() which returns a list of embeddings for the source column + For text data, the two will be the same. For multi-modal data, the source column + might be images and the vector column might be text. """ - def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]: - texts = self.sanitize_input(texts) - return self.generate_embeddings(texts) + @classmethod + def create(cls, **kwargs): + """ + Create an instance of the embedding function + """ + return cls(**kwargs) + + @abstractmethod + def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for a given user query + """ + pass + + @abstractmethod + def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]: + """ + Compute the embeddings for the source column in the database + """ + pass def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: """ - Sanitize the input to the embedding function. This is called - before generate_embeddings() and is useful for stripping - whitespace, lowercasing, etc. + Sanitize the input to the embedding function. """ if isinstance(texts, str): texts = [texts] @@ -157,6 +206,71 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel): texts = texts.combine_chunks().to_pylist() return texts + @classmethod + def safe_import(cls, module: str, mitigation=None): + """ + Import the specified module. If the module is not installed, + raise an ImportError with a helpful message. + + Parameters + ---------- + module : str + The name of the module to import + mitigation : Optional[str] + The package(s) to install to mitigate the error. + If not provided then the module name will be used. + """ + try: + return importlib.import_module(module) + except ImportError: + raise ImportError(f"Please install {mitigation or module}") + + @property + @abstractmethod + def ndims(self): + """ + Return the dimensions of the vector column + """ + pass + + def SourceField(self, **kwargs): + """ + Return a pydantic Field that can automatically indicate + the source column for this embedding function + """ + return Field(json_schema_extra={"source_column_for": self}, **kwargs) + + def VectorField(self, **kwargs): + """ + Return a pydantic Field that can automatically indicate + the target vector column for this embedding function + """ + return Field(json_schema_extra={"vector_column_for": self}, **kwargs) + + +class EmbeddingFunctionConfig(BaseModel): + """ + This is a dataclass that holds the embedding function + and source column for a vector column + """ + + vector_column: str + source_column: str + function: EmbeddingFunction + + +class TextEmbeddingFunction(EmbeddingFunction): + """ + A callable ABC for embedding functions that take text as input + """ + + def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + return self.compute_source_embeddings(query, *args, **kwargs) + + def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + texts = self.sanitize_input(texts) + return self.generate_embeddings(texts) + @abstractmethod def generate_embeddings( self, texts: Union[List[str], np.ndarray] @@ -167,15 +281,20 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel): pass -@REGISTRY.register() -class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel): +register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) + + +@register("sentence-transformers") +class SentenceTransformerEmbeddings(TextEmbeddingFunction): """ An embedding function that uses the sentence-transformers library + + https://huggingface.co/sentence-transformers """ name: str = "all-MiniLM-L6-v2" device: str = "cpu" - normalize: bool = False + normalize: bool = True @property def embedding_model(self): @@ -186,6 +305,10 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel): """ return self.__class__.get_embedding_model(self.name, self.device) + @cached_property + def ndims(self): + return len(self.generate_embeddings(["foo"])[0]) + def generate_embeddings( self, texts: Union[List[str], np.ndarray] ) -> List[np.array]: @@ -220,9 +343,197 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel): TODO: use lru_cache instead with a reasonable/configurable maxsize """ - try: - from sentence_transformers import SentenceTransformer + sentence_transformers = cls.safe_import( + "sentence_transformers", "sentence-transformers" + ) + return sentence_transformers.SentenceTransformer(name, device=device) - return SentenceTransformer(name, device=device) - except ImportError: - raise ValueError("Please install sentence_transformers") + +@register("openai") +class OpenAIEmbeddings(TextEmbeddingFunction): + """ + An embedding function that uses the OpenAI API + + https://platform.openai.com/docs/guides/embeddings + """ + + name: str = "text-embedding-ada-002" + + @property + def ndims(self): + # TODO don't hardcode this + return 1536 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Get the embeddings for the given texts + + Parameters + ---------- + texts: list[str] or np.ndarray (of str) + The texts to embed + """ + # TODO retry, rate limit, token limit + openai = self.safe_import("openai") + rs = openai.Embedding.create(input=texts, model=self.name)["data"] + return [v["embedding"] for v in rs] + + +@register("open-clip") +class OpenClipEmbeddings(EmbeddingFunction): + """ + An embedding function that uses the OpenClip API + For multi-modal text-to-image search + + https://github.com/mlfoundations/open_clip + """ + + name: str = "ViT-B-32" + pretrained: str = "laion2b_s34b_b79k" + device: str = "cpu" + batch_size: int = 64 + normalize: bool = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + open_clip = self.safe_import("open_clip", "open-clip") + model, _, preprocess = open_clip.create_model_and_transforms( + self.name, pretrained=self.pretrained + ) + model.to(self.device) + self._model, self._preprocess = model, preprocess + self._tokenizer = open_clip.get_tokenizer(self.name) + + @cached_property + def ndims(self): + return self.generate_text_embeddings("foo").shape[0] + + def compute_query_embeddings( + self, query: Union[str, "PIL.Image.Image"], *args, **kwargs + ) -> List[np.ndarray]: + """ + Compute the embeddings for a given user query + + Parameters + ---------- + query : Union[str, PIL.Image.Image] + The query to embed. A query can be either text or an image. + """ + if isinstance(query, str): + return [self.generate_text_embeddings(query)] + else: + PIL = self.safe_import("PIL", "pillow") + if isinstance(query, PIL.Image.Image): + return [self.generate_image_embedding(query)] + else: + raise TypeError("OpenClip supports str or PIL Image as query") + + def generate_text_embeddings(self, text: str) -> np.ndarray: + torch = self.safe_import("torch") + text = self.sanitize_input(text) + text = self._tokenizer(text) + text.to(self.device) + with torch.no_grad(): + text_features = self._model.encode_text(text.to(self.device)) + if self.normalize: + text_features /= text_features.norm(dim=-1, keepdim=True) + return text_features.cpu().numpy().squeeze() + + def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: + """ + Sanitize the input to the embedding function. + """ + if isinstance(images, (str, bytes)): + images = [images] + elif isinstance(images, pa.Array): + images = images.to_pylist() + elif isinstance(images, pa.ChunkedArray): + images = images.combine_chunks().to_pylist() + return images + + def compute_source_embeddings( + self, images: IMAGES, *args, **kwargs + ) -> List[np.array]: + """ + Get the embeddings for the given images + """ + images = self.sanitize_input(images) + embeddings = [] + for i in range(0, len(images), self.batch_size): + j = min(i + self.batch_size, len(images)) + batch = images[i:j] + embeddings.extend(self._parallel_get(batch)) + return embeddings + + def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]: + """ + Issue concurrent requests to retrieve the image data + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self.generate_image_embedding, image) + for image in images + ] + return [future.result() for future in futures] + + def generate_image_embedding( + self, image: Union[str, bytes, "PIL.Image.Image"] + ) -> np.ndarray: + """ + Generate the embedding for a single image + + Parameters + ---------- + image : Union[str, bytes, PIL.Image.Image] + The image to embed. If the image is a str, it is treated as a uri. + If the image is bytes, it is treated as the raw image bytes. + """ + torch = self.safe_import("torch") + # TODO handle retry and errors for https + image = self._to_pil(image) + image = self._preprocess(image).unsqueeze(0) + with torch.no_grad(): + return self._encode_and_normalize_image(image) + + def _to_pil(self, image: Union[str, bytes]): + PIL = self.safe_import("PIL", "pillow") + if isinstance(image, bytes): + return PIL.Image.open(io.BytesIO(image)) + if isinstance(image, PIL.Image.Image): + return image + elif isinstance(image, str): + parsed = urlparse.urlparse(image) + # TODO handle drive letter on windows. + if parsed.scheme == "file": + return PIL.Image.open(parsed.path) + elif parsed.scheme == "": + return PIL.Image.open(image if os.name == "nt" else parsed.path) + elif parsed.scheme.startswith("http"): + return PIL.Image.open(io.BytesIO(url_retrieve(image))) + else: + raise NotImplementedError("Only local and http(s) urls are supported") + + def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"): + """ + encode a single image tensor and optionally normalize the output + """ + image_features = self._model.encode_image(image_tensor) + if self.normalize: + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features.cpu().numpy().squeeze() + + +def url_retrieve(url: str): + """ + Parameters + ---------- + url: str + URL to download from + """ + try: + with urllib.request.urlopen(url) as conn: + return conn.read() + except (socket.gaierror, urllib.error.URLError) as err: + raise ConnectionError("could not download {} due to {}".format(url, err)) diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 6b4c9dc1..46b882e7 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -26,6 +26,8 @@ import pyarrow as pa import pydantic import semver +from .embeddings import EmbeddingFunctionRegistry + PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) try: from pydantic_core import CoreSchema, core_schema @@ -290,13 +292,49 @@ class LanceModel(pydantic.BaseModel): """ Get the Arrow Schema for this model. """ - return pydantic_to_schema(cls) + schema = pydantic_to_schema(cls) + functions = cls.parse_embedding_functions() + if len(functions) > 0: + metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata( + functions + ) + schema = schema.with_metadata(metadata) + return schema @classmethod def field_names(cls) -> List[str]: """ Get the field names of this model. """ + return list(cls.safe_get_fields().keys()) + + @classmethod + def safe_get_fields(cls): if PYDANTIC_VERSION.major < 2: - return list(cls.__fields__.keys()) - return list(cls.model_fields.keys()) + return cls.__fields__ + return cls.model_fields + + @classmethod + def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]: + """ + Parse the embedding functions from this model. + """ + from .embeddings import EmbeddingFunctionConfig + + vec_and_function = [] + for name, field_info in cls.safe_get_fields().items(): + func = (field_info.json_schema_extra or {}).get("vector_column_for") + if func is not None: + vec_and_function.append([name, func]) + + configs = [] + for vec, func in vec_and_function: + for source, field_info in cls.safe_get_fields().items(): + src_func = (field_info.json_schema_extra or {}).get("source_column_for") + if src_func == func: + configs.append( + EmbeddingFunctionConfig( + source_column=source, vector_column=vec, function=func + ) + ) + return configs diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 87a856b4..10925ac7 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -60,13 +60,15 @@ class LanceQueryBuilder(ABC): def create( cls, table: "lancedb.table.Table", - query: Optional[Union[np.ndarray, str]], + query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]], query_type: str, vector_column_name: str, ) -> LanceQueryBuilder: if query is None: return LanceEmptyQueryBuilder(table) + # convert "auto" query_type to "vector" or "fts" + # and convert the query to vector if needed query, query_type = cls._resolve_query( table, query, query_type, vector_column_name ) @@ -90,30 +92,27 @@ class LanceQueryBuilder(ABC): # otherwise raise TypeError if query_type == "fts": if not isinstance(query, str): - raise TypeError( - f"Query type is 'fts' but query is not a string: {type(query)}" - ) + raise TypeError(f"'fts' queries must be a string: {type(query)}") return query, query_type elif query_type == "vector": - # If query_type is vector, then query must be a list or np.ndarray. - # otherwise raise TypeError if not isinstance(query, (list, np.ndarray)): - raise TypeError( - f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}" - ) + conf = table.embedding_functions.get(vector_column_name) + if conf is not None: + query = conf.function.compute_query_embeddings(query)[0] + else: + msg = f"No embedding function for {vector_column_name}" + raise ValueError(msg) return query, query_type elif query_type == "auto": if isinstance(query, (list, np.ndarray)): return query, "vector" - elif isinstance(query, str): - func = table.embedding_functions.get(vector_column_name, None) - if func is not None: - query = func(query)[0] + else: + conf = table.embedding_functions.get(vector_column_name) + if conf is not None: + query = conf.function.compute_query_embeddings(query)[0] return query, "vector" else: return query, "fts" - else: - raise TypeError("Query must be a list, np.ndarray, or str") else: raise ValueError( f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}" @@ -238,7 +237,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): def __init__( self, table: "lancedb.table.Table", - query: Union[np.ndarray, list], + query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str = VECTOR_COLUMN_NAME, ): super().__init__(table) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 3568c64d..89d7a3f5 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -18,10 +18,9 @@ from urllib.parse import urlparse import pyarrow as pa -from lancedb.common import DATA -from lancedb.db import DBConnection -from lancedb.table import Table, _sanitize_data - +from ..common import DATA +from ..db import DBConnection +from ..table import Table, _sanitize_data from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 2fe3ce9c..ae19d83c 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -28,7 +28,8 @@ from lance.dataset import ReaderLike from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME -from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry +from .embeddings import EmbeddingFunctionRegistry +from .embeddings.functions import EmbeddingFunctionConfig from .pydantic import LanceModel from .query import LanceQueryBuilder, Query from .util import fs_from_uri, safe_import_pandas @@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem vector column to the table. """ functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) - for vector_col, func in functions.items(): - if vector_col not in data.column_names: - col_data = func(data[func.source_column]) + for vector_column, conf in functions.items(): + func = conf.function + if vector_column not in data.column_names: + col_data = func.compute_source_embeddings(data[conf.source_column]) if schema is not None: - dtype = schema.field(vector_col).type + dtype = schema.field(vector_column).type else: dtype = pa.list_(pa.float32(), len(col_data[0])) data = data.append_column( - pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype) + pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) ) return data @@ -230,7 +232,7 @@ class Table(ABC): @abstractmethod def search( self, - query: Optional[Union[VEC, str]] = None, + query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, vector_column_name: str = VECTOR_COLUMN_NAME, query_type: str = "auto", ) -> LanceQueryBuilder: @@ -239,7 +241,7 @@ class Table(ABC): Parameters ---------- - query: str, list, np.ndarray, default None + query: str, list, np.ndarray, PIL.Image.Image, default None The query to search for. If None then the select/where/limit clauses are applied to filter the table @@ -249,6 +251,8 @@ class Table(ABC): "vector", "fts", or "auto" If "auto" then the query type is inferred from the query; If `query` is a list/np.ndarray then the query type is "vector"; + If `query` is a PIL.Image.Image then either do vector search + or raise an error if no corresponding embedding function is found. If `query` is a string, then the query type is "vector" if the table has embedding functions else the query type is "fts" @@ -524,6 +528,9 @@ class LanceTable(Table): fill_value: float = 0.0, ): """Add data to the table. + If vector columns are missing and the table + has embedding functions, then the vector columns + are automatically computed and added. Parameters ---------- @@ -617,12 +624,6 @@ class LanceTable(Table): ) self._reset_dataset() - def _get_embedding_function_for_source_col(self, column_name: str): - for k, v in self.embedding_functions.items(): - if v.source_column == column_name: - return v - return None - @cached_property def embedding_functions(self) -> dict: """ @@ -640,7 +641,7 @@ class LanceTable(Table): def search( self, - query: Optional[Union[VEC, str]] = None, + query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, vector_column_name: str = VECTOR_COLUMN_NAME, query_type: str = "auto", ) -> LanceQueryBuilder: @@ -649,7 +650,7 @@ class LanceTable(Table): Parameters ---------- - query: str, list, np.ndarray, or None + query: str, list, np.ndarray, a PIL Image or None The query to search for. If None then the select/where/limit clauses are applied to filter the table @@ -658,9 +659,11 @@ class LanceTable(Table): query_type: str, default "auto" "vector", "fts", or "auto" If "auto" then the query type is inferred from the query; - If the query is a list/np.ndarray then the query type is "vector"; + If `query` is a list/np.ndarray then the query type is "vector"; + If `query` is a PIL.Image.Image then either do vector search + or raise an error if no corresponding embedding function is found. If the query is a string, then the query type is "vector" if the - table has embedding functions else the query type is "fts" + table has embedding functions, else the query type is "fts" Returns ------- @@ -684,7 +687,7 @@ class LanceTable(Table): mode="create", on_bad_vectors: str = "error", fill_value: float = 0.0, - embedding_functions: List[EmbeddingFunctionModel] = None, + embedding_functions: List[EmbeddingFunctionConfig] = None, ): """ Create a new table. @@ -727,10 +730,16 @@ class LanceTable(Table): """ tbl = LanceTable(db, name) if inspect.isclass(schema) and issubclass(schema, LanceModel): + # convert LanceModel to pyarrow schema + # note that it's possible this contains + # embedding function metadata already schema = schema.to_arrow_schema() metadata = None if embedding_functions is not None: + # If we passed in embedding functions explicitly + # then we'll override any schema metadata that + # may was implicitly specified by the LanceModel schema registry = EmbeddingFunctionRegistry.get_instance() metadata = registry.get_table_metadata(embedding_functions) diff --git a/python/pyproject.toml b/python/pyproject.toml index d29d476c..f4955daa 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -44,9 +44,11 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"] +tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +clip = ["torch", "pillow", "open-clip"] +embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"] [build-system] requires = ["setuptools", "wheel"] @@ -54,3 +56,10 @@ build-backend = "setuptools.build_meta" [tool.isort] profile = "black" + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "asyncio" +] \ No newline at end of file diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index 4b7c0c49..3eb4839b 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -16,8 +16,12 @@ import lance import numpy as np import pyarrow as pa -from lancedb.conftest import MockEmbeddingFunction -from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings +from lancedb.conftest import MockTextEmbeddingFunction +from lancedb.embeddings import ( + EmbeddingFunctionConfig, + EmbeddingFunctionRegistry, + with_embeddings, +) def mock_embed_func(input_data): @@ -54,8 +58,12 @@ def test_embedding_function(tmp_path): "vector": [np.random.randn(10), np.random.randn(10)], } ) - func = MockEmbeddingFunction(source_column="text", vector_column="vector") - metadata = registry.get_table_metadata([func]) + conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="vector", + function=MockTextEmbeddingFunction(), + ) + metadata = registry.get_table_metadata([conf]) table = table.replace_schema_metadata(metadata) # Write it to disk @@ -65,14 +73,13 @@ def test_embedding_function(tmp_path): ds = lance.dataset(tmp_path / "test.lance") # can we get the serialized version back out? - functions = registry.parse_functions(ds.schema.metadata) + configs = registry.parse_functions(ds.schema.metadata) - func = functions["vector"] - actual = func("hello world") + conf = configs["vector"] + func = conf.function + actual = func.compute_query_embeddings("hello world") - # We create an instance - expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector") # And we make sure we can call it - expected = expected_func("hello world") + expected = func.compute_query_embeddings("hello world") assert np.allclose(actual, expected) diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py new file mode 100644 index 00000000..8960f587 --- /dev/null +++ b/python/tests/test_embeddings_slow.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io + +import numpy as np +import pandas as pd +import pytest +import requests + +import lancedb +from lancedb.embeddings import EmbeddingFunctionRegistry +from lancedb.pydantic import LanceModel, Vector + +# These are integration tests for embedding functions. +# They are slow because they require downloading models +# or connection to external api + + +@pytest.mark.slow +@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"]) +def test_sentence_transformer(alias, tmp_path): + db = lancedb.connect(tmp_path) + registry = EmbeddingFunctionRegistry.get_instance() + func = registry.get(alias).create() + + class Words(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims) = func.VectorField() + + table = db.create_table("words", schema=Words) + table.add( + pd.DataFrame( + { + "text": [ + "hello world", + "goodbye world", + "fizz", + "buzz", + "foo", + "bar", + "baz", + ] + } + ) + ) + + query = "greetings" + actual = table.search(query).limit(1).to_pydantic(Words)[0] + + vec = func.compute_query_embeddings(query)[0] + expected = table.search(vec).limit(1).to_pydantic(Words)[0] + assert actual.text == expected.text + assert actual.text == "hello world" + + +@pytest.mark.slow +def test_openclip(tmp_path): + from PIL import Image + + db = lancedb.connect(tmp_path) + registry = EmbeddingFunctionRegistry.get_instance() + func = registry.get("open-clip").create() + + class Images(LanceModel): + label: str + image_uri: str = func.SourceField() + image_bytes: bytes = func.SourceField() + vector: Vector(func.ndims) = func.VectorField() + vec_from_bytes: Vector(func.ndims) = func.VectorField() + + table = db.create_table("images", schema=Images) + labels = ["cat", "cat", "dog", "dog", "horse", "horse"] + uris = [ + "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg", + "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg", + "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg", + "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg", + "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg", + "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg", + ] + # get each uri as bytes + image_bytes = [requests.get(uri).content for uri in uris] + table.add( + pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes}) + ) + + # text search + actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0] + assert actual.label == "dog" + frombytes = ( + table.search("man's best friend", vector_column_name="vec_from_bytes") + .limit(1) + .to_pydantic(Images)[0] + ) + assert actual.label == frombytes.label + assert np.allclose(actual.vector, frombytes.vector) + + # image search + query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg" + image_bytes = requests.get(query_image_uri).content + query_image = Image.open(io.BytesIO(image_bytes)) + actual = table.search(query_image).limit(1).to_pydantic(Images)[0] + assert actual.label == "dog" + other = ( + table.search(query_image, vector_column_name="vec_from_bytes") + .limit(1) + .to_pydantic(Images)[0] + ) + assert actual.label == other.label + + arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow() + assert np.allclose( + arrow_table["vector"].combine_chunks().values.to_numpy(), + arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(), + ) diff --git a/python/tests/test_table.py b/python/tests/test_table.py index b1b000eb..92129e9c 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -22,8 +22,9 @@ import pandas as pd import pyarrow as pa import pytest -from lancedb.conftest import MockEmbeddingFunction +from lancedb.conftest import MockTextEmbeddingFunction from lancedb.db import LanceDBConnection +from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.pydantic import LanceModel, Vector from lancedb.table import LanceTable @@ -356,20 +357,23 @@ def test_create_with_embedding_function(db): text: str vector: Vector(10) - func = MockEmbeddingFunction(source_column="text", vector_column="vector") + func = MockTextEmbeddingFunction() texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] - df = pd.DataFrame({"text": texts, "vector": func(texts)}) + df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)}) + conf = EmbeddingFunctionConfig( + source_column="text", vector_column="vector", function=func + ) table = LanceTable.create( db, "my_table", schema=MyTable, - embedding_functions=[func], + embedding_functions=[conf], ) table.add(df) query_str = "hi how are you?" - query_vector = func(query_str)[0] + query_vector = func.compute_query_embeddings(query_str)[0] expected = table.search(query_vector).limit(2).to_arrow() actual = table.search(query_str).limit(2).to_arrow() @@ -377,17 +381,13 @@ def test_create_with_embedding_function(db): def test_add_with_embedding_function(db): - class MyTable(LanceModel): - text: str - vector: Vector(10) + emb = EmbeddingFunctionRegistry.get_instance().get("test")() - func = MockEmbeddingFunction(source_column="text", vector_column="vector") - table = LanceTable.create( - db, - "my_table", - schema=MyTable, - embedding_functions=[func], - ) + class MyTable(LanceModel): + text: str = emb.SourceField() + vector: Vector(emb.ndims) = emb.VectorField() + + table = LanceTable.create(db, "my_table", schema=MyTable) texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pd.DataFrame({"text": texts}) @@ -397,7 +397,7 @@ def test_add_with_embedding_function(db): table.add([{"text": t} for t in texts]) query_str = "hi how are you?" - query_vector = func(query_str)[0] + query_vector = emb.compute_query_embeddings(query_str)[0] expected = table.search(query_vector).limit(2).to_arrow() actual = table.search(query_str).limit(2).to_arrow()