mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
Embeddings: HF model hub support added via transformers (#1154)
This commit is contained in:
committed by
Weston Pace
parent
ac63d4066b
commit
1c41a00d87
@@ -10,7 +10,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F401
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||
from .bedrock import BedRockText
|
||||
@@ -21,4 +20,7 @@ from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
from .sentence_transformers import SentenceTransformerEmbeddings
|
||||
from .gte import GteEmbeddings
|
||||
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
|
||||
from .imagebind import ImageBindEmbeddings
|
||||
from .utils import with_embeddings
|
||||
|
||||
@@ -38,6 +38,9 @@ class ImageBindEmbeddings(EmbeddingFunction):
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._ndims = 1024
|
||||
|
||||
100
python/python/lancedb/embeddings/transformers.py
Normal file
100
python/python/lancedb/embeddings/transformers.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import List, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT
|
||||
|
||||
|
||||
@register("huggingface")
class TransformersEmbeddingFunction(EmbeddingFunction):
    """
    An embedding function that can use any model from the transformers library.

    Parameters
    ----------
    name : str
        The name of the model to use. This should be a model name that can be
        loaded by transformers.AutoModel.from_pretrained. For example,
        "bert-base-uncased".
        default: "colbert-ir/colbertv2.0"

    To download the package, run:
        `pip install transformers`
    You may need to install pytorch as well -
    `https://pytorch.org/get-started/locally/`
    """

    name: str = "colbert-ir/colbertv2.0"
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(self, *args, **kwargs):
        """Load the tokenizer and model identified by ``self.name``."""
        super().__init__(*args, **kwargs)
        # Resolved lazily by ndims() once the model config is available.
        self._ndims = None
        transformers = attempt_import_or_raise("transformers")
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
        self._model = transformers.AutoModel.from_pretrained(self.name)

    class Config:
        keep_untouched = (cached_property,)

    def ndims(self):
        """Return the embedding width, read from the model's ``hidden_size``."""
        # The value cannot change for a loaded model, so read it only once.
        if self._ndims is None:
            self._ndims = self._model.config.hidden_size
        return self._ndims

    def compute_query_embeddings(
        self, query: str, *args, **kwargs
    ) -> List[np.ndarray]:
        """Embed a query using the same path as source texts."""
        return self.compute_source_embeddings(query)

    def compute_source_embeddings(
        self, texts: TEXT, *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Embed each text by averaging the model's last hidden state.

        Parameters
        ----------
        texts : TEXT
            The text(s) to embed; normalised through ``sanitize_input``.

        Returns
        -------
        List[np.ndarray]
            One array per input text, in input order.
        """
        texts = self.sanitize_input(texts)
        # The tokenizer is asked for "pt" tensors, so torch is already a hard
        # requirement of this code path.
        torch = attempt_import_or_raise("torch")

        embedding = []
        # Pure inference: disabling autograd avoids building a graph that the
        # original code immediately threw away with detach().
        with torch.no_grad():
            for text in texts:
                encoding = self._tokenizer(
                    text, return_tensors="pt", padding=True, truncation=True
                )
                # mean(dim=1) pools over what is presumably the token axis of
                # a (batch, tokens, hidden) tensor; squeeze drops the batch of 1.
                emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
                embedding.append(emb.detach().numpy())

        return embedding
|
||||
|
||||
|
||||
@register("colbert")
class ColbertEmbeddings(TransformersEmbeddingFunction):
    """
    Embedding function registered as "colbert"; it reuses the behaviour of
    TransformersEmbeddingFunction and defaults to the ColBERT v2 checkpoint.

    Parameters
    ----------
    name : str
        Model identifier, default "colbert-ir/colbertv2.0".

    Requires the `transformers` package (`pip install transformers`); pytorch
    may be needed as well - `https://pytorch.org/get-started/locally/`
    """

    name: str = "colbert-ir/colbertv2.0"

    def __init__(self, *args, **kwargs):
        # Nothing ColBERT-specific to set up: defer entirely to the parent.
        super().__init__(*args, **kwargs)
|
||||
@@ -45,7 +45,7 @@ except Exception:
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"])
|
||||
def test_basic_text_embeddings(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
@@ -84,7 +84,7 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
)
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
query = "greeting"
|
||||
actual = (
|
||||
table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
|
||||
)
|
||||
@@ -184,9 +184,9 @@ def test_imagebind(tmp_path):
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import lancedb.embeddings.imagebind
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -321,8 +321,6 @@ def test_gemini_embedding(tmp_path):
|
||||
)
|
||||
@pytest.mark.slow
|
||||
def test_gte_embedding(tmp_path):
|
||||
import lancedb.embeddings.gte
|
||||
|
||||
model = get_registry().get("gte-text").create()
|
||||
|
||||
class TextModel(LanceModel):
|
||||
|
||||
Reference in New Issue
Block a user