Embeddings: HF model hub support added via transformers (#1154)

Raghav Dixit
2024-04-05 00:26:27 -04:00
committed by Weston Pace
parent ac63d4066b
commit 1c41a00d87
6 changed files with 136 additions and 7 deletions
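For context before the diffs: a minimal usage sketch of the `huggingface` registry alias this commit introduces. The connection path, table name, and sample rows are hypothetical; the `SourceField`/`VectorField` pattern mirrors the embedding tests changed below.

# Hypothetical usage of the new "huggingface" alias (not part of the commit).
# Requires `pip install transformers` plus PyTorch.
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

model = get_registry().get("huggingface").create(name="colbert-ir/colbertv2.0")

class Words(LanceModel):
    text: str = model.SourceField()  # source column, embedded on add()
    vector: Vector(model.ndims()) = model.VectorField()

db = lancedb.connect("/tmp/hf-demo")  # hypothetical path
table = db.create_table("words", schema=Words)
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
hit = table.search("greeting", vector_column_name="vector").limit(1).to_pydantic(Words)[0]
print(hit.text)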

pyproject.toml

@@ -90,7 +90,7 @@ requires = ["maturin>=1.4"]
build-backend = "maturin"

[tool.ruff.lint]
-select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
+select = ["F", "E", "W", "G", "TCH", "PERF"]

[tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"

lancedb/embeddings/__init__.py

@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .bedrock import BedRockText
@@ -21,4 +20,7 @@ from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
+from .gte import GteEmbeddings
+from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
+from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
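With these re-exports in place, the new classes become importable from the package root, e.g. (a sketch, assuming the package layout shown in this diff):

# Sketch: the new classes exposed at the package level by this commit.
from lancedb.embeddings import TransformersEmbeddingFunction, ColbertEmbeddings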

lancedb/embeddings/imagebind.py

@@ -38,6 +38,9 @@ class ImageBindEmbeddings(EmbeddingFunction):
    device: str = "cpu"
    normalize: bool = False

+    class Config:
+        keep_untouched = (cached_property,)
+
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ndims = 1024
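A side note on the `keep_untouched = (cached_property,)` lines added here and in the new module below: under pydantic v1 semantics, listing a descriptor type in `Config.keep_untouched` stops the model metaclass from treating it as a field, so `functools.cached_property` keeps working on the model class. A minimal sketch (class and property names are hypothetical, and pydantic v1 is assumed):

# Hypothetical sketch of the keep_untouched pattern; assumes pydantic v1
# (in pydantic v2 the equivalent config key is `ignored_types`).
from functools import cached_property
from pydantic import BaseModel

class Demo(BaseModel):
    class Config:
        keep_untouched = (cached_property,)  # leave the descriptor alone

    @cached_property
    def answer(self) -> int:
        print("computed once")
        return 42

d = Demo()
assert d.answer == 42  # prints "computed once"
assert d.answer == 42  # second access hits the cache, no print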

lancedb/embeddings/transformers.py

@@ -0,0 +1,100 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Any
import numpy as np
from pydantic import PrivateAttr
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import TEXT

@register("huggingface")
class TransformersEmbeddingFunction(EmbeddingFunction):
    """
    An embedding function that can use any model from the transformers library.

    Parameters
    ----------
    name : str
        The name of the model to use. This should be a model name that can be
        loaded by transformers.AutoModel.from_pretrained, e.g. "bert-base-uncased".
        default: "colbert-ir/colbertv2.0"

    To install the required package, run:
        `pip install transformers`
    You may also need to install PyTorch: https://pytorch.org/get-started/locally/
    """

    name: str = "colbert-ir/colbertv2.0"
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ndims = None
        transformers = attempt_import_or_raise("transformers")
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
        self._model = transformers.AutoModel.from_pretrained(self.name)

    class Config:
        keep_untouched = (cached_property,)

    def ndims(self):
        # The embedding dimension is the hidden size of the underlying model.
        self._ndims = self._model.config.hidden_size
        return self._ndims

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.ndarray]:
        return self.compute_source_embeddings(query)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.ndarray]:
        texts = self.sanitize_input(texts)
        embedding = []
        for text in texts:
            encoding = self._tokenizer(
                text, return_tensors="pt", padding=True, truncation=True
            )
            # Mean-pool the last hidden state over the sequence dimension.
            emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
            embedding.append(emb.detach().numpy())
        return embedding


@register("colbert")
class ColbertEmbeddings(TransformersEmbeddingFunction):
    """
    An embedding function that uses the ColBERT model from the Hugging Face hub.

    Parameters
    ----------
    name : str
        The name of the model to use. This should be a model name that can be
        loaded by transformers.AutoModel.from_pretrained, e.g. "bert-base-uncased".
        default: "colbert-ir/colbertv2.0"

    To install the required package, run:
        `pip install transformers`
    You may also need to install PyTorch: https://pytorch.org/get-started/locally/
    """

    name: str = "colbert-ir/colbertv2.0"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
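As a quick smoke test, a sketch of exercising the `colbert` alias registered above (the model is downloaded from the Hugging Face hub on first use; the input string is arbitrary):

# Hedged sketch, not part of the commit: round-trip through the new alias.
from lancedb.embeddings import get_registry

func = get_registry().get("colbert").create()  # defaults to colbert-ir/colbertv2.0
[vec] = func.compute_query_embeddings("hello world")
assert vec.shape[-1] == func.ndims()  # ndims() reads the model's hidden_size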

tests/test_embeddings.py

@@ -45,7 +45,7 @@ except Exception:
@pytest.mark.slow
-@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
+@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"])
def test_basic_text_embeddings(alias, tmp_path):
    db = lancedb.connect(tmp_path)
    registry = get_registry()
@@ -84,7 +84,7 @@ def test_basic_text_embeddings(alias, tmp_path):
        )
    )

-    query = "greetings"
+    query = "greeting"
    actual = (
        table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
    )
@@ -184,9 +184,9 @@ def test_imagebind(tmp_path):
    import shutil
    import tempfile

    import lancedb.embeddings.imagebind
    import pandas as pd
    import requests

    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector
@@ -321,8 +321,6 @@ def test_gemini_embedding(tmp_path):
    )


@pytest.mark.slow
def test_gte_embedding(tmp_path):
    import lancedb.embeddings.gte

    model = get_registry().get("gte-text").create()

    class TextModel(LanceModel):