# --- python/python/lancedb/embeddings/gte.py (new file) ---
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors


from typing import List, Union

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru


@register("gte-text")
class GteEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses GTE-LARGE in MLX format (for Apple
    silicon devices only) as well as the standard cpu/gpu version from:
    https://huggingface.co/thenlper/gte-large.

    For Apple users, you will need the mlx package installed, which can be
    done with: pip install mlx

    Parameters
    ----------
    name: str, default "thenlper/gte-large"
        The name of the model to use.
    device: str, default "cpu"
        Sets the device type for the model.
    normalize: bool, default True
        Controls the normalize param in the encode function of the transformer.
    mlx: bool, default False
        Controls which model to use. False for gte-large, True for the MLX
        version.

    Examples
    --------
    import lancedb
    import lancedb.embeddings.gte
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector
    import pandas as pd

    model = get_registry().get("gte-text").create()  # mlx=True for Apple silicon
    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "thenlper/gte-large"
    device: str = "cpu"
    normalize: bool = True
    mlx: bool = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Cached embedding dimensionality; resolved lazily in ndims().
        self._ndims = None
        # pydantic has already populated self.mlx from kwargs in
        # super().__init__; no need to re-read kwargs here.
        if self.mlx is True:
            self.name = "gte-mlx"

    @property
    def embedding_model(self):
        """
        Get the embedding model specified by the mlx flag, name and device.
        This is cached so that the model is only loaded once per process.
        """
        return self.get_embedding_model()

    def ndims(self):
        """Return the embedding dimensionality, computing it once and caching it."""
        if self._ndims is None:
            if self.mlx is True:
                self._ndims = self.embedding_model.dims
            else:
                # Embed a one-element list: passing the bare string "foo"
                # would be split by list() into ['f', 'o', 'o'] and embed
                # three single-character documents instead of one.
                self._ndims = len(self.generate_embeddings(["foo"])[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given texts.

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        A list of embedding vectors, one per input text.
        """
        if self.mlx is True:
            return self.embedding_model.run(list(texts)).tolist()

        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @weak_lru(maxsize=1)
    def get_embedding_model(self):
        """
        Load the embedding model specified by the mlx flag, name and device.
        This is cached so that the model is only loaded once per process.
        """
        if self.mlx is True:
            from .gte_mlx_model import Model

            return Model()
        else:
            sentence_transformers = attempt_import_or_raise(
                "sentence_transformers", "sentence-transformers"
            )
            return sentence_transformers.SentenceTransformer(
                self.name, device=self.device
            )


# --- python/python/lancedb/embeddings/siglip.py (new file) ---
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import concurrent.futures
import io
import os
from typing import TYPE_CHECKING, List, Union
import urllib.parse as urlparse

import numpy as np
import pyarrow as pa
from tqdm import tqdm
from pydantic import PrivateAttr

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve

if TYPE_CHECKING:
    import PIL
    import torch


@register("siglip")
class SigLipEmbeddings(EmbeddingFunction):
    """
    An embedding function based on Google's SigLIP model
    (https://huggingface.co/google/siglip-base-patch16-224) that embeds both
    text queries and images (local paths, file/http(s) URLs, raw bytes, or
    PIL images) into a shared vector space.
    """

    model_name: str = "google/siglip-base-patch16-224"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True

    _model = PrivateAttr()
    _processor = PrivateAttr()
    _tokenizer = PrivateAttr()
    _torch = PrivateAttr()
    # Must be declared as a PrivateAttr like the attributes above: pydantic
    # rejects assignment to undeclared underscore attributes, so the bare
    # `self._ndims = None` in __init__ would raise at construction time.
    _ndims = PrivateAttr(default=None)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        transformers = attempt_import_or_raise("transformers")
        self._torch = attempt_import_or_raise("torch")

        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
        self._model.to(self.device)
        self._model.eval()

    def ndims(self):
        """Return the embedding dimensionality, computing it once and caching it."""
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Embed a search query, which may be either a text string or a PIL image.

        Raises
        ------
        TypeError
            If the query is neither a str nor a PIL Image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
            if isinstance(query, PIL_Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("SigLIP supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        """Embed a single text string, optionally L2-normalized."""
        torch = self._torch
        # SigLIP was trained with padding="max_length"; keep the same
        # preprocessing at inference time.
        text_inputs = self._processor(
            text=text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        ).to(self.device)

        with torch.no_grad():
            text_features = self._model.get_text_features(**text_inputs)
            if self.normalize:
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().detach().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """Normalize the input to a plain list of images (paths/urls/bytes)."""
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.ndarray]:
        """Embed a batch of source images, batch_size at a time."""
        images = self.sanitize_input(images)
        embeddings = []

        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """Fetch/decode/embed a batch of images concurrently, preserving order."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [f.result() for f in tqdm(futures, desc="SigLIP Embedding")]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """Embed one image given as a path/url, raw bytes, or PIL image."""
        image = self._to_pil(image)
        image = self._processor(images=image, return_tensors="pt")["pixel_values"]
        return self._encode_and_normalize_image(image)

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor") -> np.ndarray:
        """Run the vision tower on a preprocessed tensor, optionally L2-normalized."""
        torch = self._torch
        with torch.no_grad():
            image_features = self._model.get_image_features(
                image_tensor.to(self.device)
            )
            if self.normalize:
                image_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
            return image_features.cpu().detach().numpy().squeeze()

    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
        """
        Convert path/url/bytes/PIL input into an RGB PIL image.

        Raises
        ------
        NotImplementedError
            For URL schemes other than file, local path, or http(s).
        ValueError
            For unsupported input types.
        """
        PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
        if isinstance(image, PIL_Image.Image):
            return image.convert("RGB") if image.mode != "RGB" else image
        elif isinstance(image, bytes):
            return PIL_Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            if parsed.scheme == "file":
                return PIL_Image.open(parsed.path).convert("RGB")
            elif parsed.scheme == "":
                # On Windows a bare drive path like "C:\x" parses with a
                # scheme, so use the raw string there.
                path = image if os.name == "nt" else parsed.path
                return PIL_Image.open(path).convert("RGB")
            elif parsed.scheme.startswith("http"):
                image_bytes = url_retrieve(image)
                return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")