mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
Multi-task instructor model with quantization support & weak_lru cache for embedding function models (#612)
resolves #608
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# ruff: noqa: F401
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .instructor import InstructorEmbeddingFunction
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
@@ -125,6 +137,14 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
if not hasattr(__value, "__dict__"):
|
||||
return False
|
||||
return vars(self) == vars(__value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(frozenset(vars(self).items()))
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
|
||||
137
python/lancedb/embeddings/instructor.py
Normal file
137
python/lancedb/embeddings/instructor.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, weak_lru
|
||||
|
||||
|
||||
@register("instructor")
|
||||
class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
|
||||
variety of tasks, including text classification, sentence similarity, and document retrieval.
|
||||
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
|
||||
"Represent the `domain` `text_type` for `task_objective`":
|
||||
|
||||
* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
|
||||
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
|
||||
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
|
||||
|
||||
For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
|
||||
"Represent the document for retreival"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
|
||||
The default model is hkunlp/instructor-base
|
||||
batch_size: int, default 32
|
||||
The batch size to use when generating embeddings
|
||||
device: str, default "cpu"
|
||||
The device to use when generating embeddings
|
||||
show_progress_bar: bool, default True
|
||||
Whether to show a progress bar when generating embeddings
|
||||
normalize_embeddings: bool, default True
|
||||
Whether to normalize the embeddings
|
||||
quantize: bool, default False
|
||||
Whether to quantize the model
|
||||
source_instruction: str, default "represent the docuement for retreival"
|
||||
The instruction for the source column
|
||||
query_instruction: str, default "represent the document for retreiving the most similar documents"
|
||||
The instruction for the query
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction
|
||||
|
||||
instructor = get_registry().get("instructor").create(
|
||||
source_instruction="represent the docuement for retreival",
|
||||
query_instruction="represent the document for retreiving the most similar documents"
|
||||
)
|
||||
|
||||
class Schema(LanceModel):
|
||||
vector: Vector(instructor.ndims()) = instructor.VectorField()
|
||||
text: str = instructor.SourceField()
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
|
||||
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
|
||||
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
|
||||
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
|
||||
|
||||
tbl.add(texts)
|
||||
|
||||
"""
|
||||
|
||||
name: str = "hkunlp/instructor-base"
|
||||
batch_size: int = 32
|
||||
device: str = "cpu"
|
||||
show_progress_bar: bool = True
|
||||
normalize_embeddings: bool = True
|
||||
quantize: bool = False
|
||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||
|
||||
source_instruction: str = "represent the document for retrieval"
|
||||
query_instruction: str = (
|
||||
"represent the document for retrieving the most similar documents"
|
||||
)
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def ndims(self):
|
||||
model = self.get_model()
|
||||
return model.encode("foo").shape[0]
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.generate_embeddings([[self.query_instruction, query]])
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
texts_formatted = []
|
||||
for text in texts:
|
||||
texts_formatted.append([self.source_instruction, text])
|
||||
return self.generate_embeddings(texts_formatted)
|
||||
|
||||
def generate_embeddings(self, texts: List) -> List:
|
||||
model = self.get_model()
|
||||
res = model.encode(
|
||||
texts,
|
||||
batch_size=self.batch_size,
|
||||
show_progress_bar=self.show_progress_bar,
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
).tolist()
|
||||
return res
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def get_model(self):
|
||||
instructor_embedding = self.safe_import(
|
||||
"InstructorEmbedding", "InstructorEmbedding"
|
||||
)
|
||||
torch = self.safe_import("torch", "torch")
|
||||
|
||||
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||
if self.quantize:
|
||||
if (
|
||||
"qnnpack" in torch.backends.quantized.supported_engines
|
||||
): # fix for https://github.com/pytorch/pytorch/issues/29327
|
||||
torch.backends.quantized.engine = "qnnpack"
|
||||
model = torch.quantization.quantize_dynamic(
|
||||
model, {torch.nn.Linear}, dtype=torch.qint8
|
||||
)
|
||||
return model
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import io
|
||||
import os
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -5,6 +17,7 @@ from cachetools import cached
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
@@ -30,7 +43,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
return self.get_embedding_model()
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
@@ -54,9 +67,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
@weak_lru(maxsize=1)
|
||||
def get_embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
@@ -71,7 +83,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
sentence_transformers = self.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||
|
||||
@@ -11,12 +11,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import weakref
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -164,6 +166,50 @@ class FunctionWrapper:
|
||||
yield from _chunker(arr)
|
||||
|
||||
|
||||
def weak_lru(maxsize=128):
|
||||
"""
|
||||
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
|
||||
is bounded.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maxsize : int, default 128
|
||||
The maximum number of objects to cache.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
A decorator that can be applied to a method.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> class Foo:
|
||||
... @weak_lru()
|
||||
... def bar(self, x):
|
||||
... return x
|
||||
>>> foo = Foo()
|
||||
>>> foo.bar(1)
|
||||
1
|
||||
>>> foo.bar(2)
|
||||
2
|
||||
>>> foo.bar(1)
|
||||
1
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
@functools.lru_cache(maxsize)
|
||||
def _func(_self, *args, **kwargs):
|
||||
return func(_self(), *args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(self, *args, **kwargs):
|
||||
return _func(weakref.ref(self), *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def retry_with_exponential_backoff(
|
||||
func,
|
||||
initial_delay: float = 1,
|
||||
|
||||
@@ -53,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
|
||||
@@ -32,8 +32,8 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get(alias).create()
|
||||
func2 = registry.get(alias).create()
|
||||
func = registry.get(alias).create(max_retries=0)
|
||||
func2 = registry.get(alias).create(max_retries=0)
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
@@ -150,7 +150,11 @@ def test_openclip(tmp_path):
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
) # also skip if cohere not installed
|
||||
def test_cohere_embedding_function():
|
||||
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
cohere = (
|
||||
get_registry()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0", max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
@@ -162,3 +166,19 @@ def test_cohere_embedding_function():
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_instructor_embedding(tmp_path):
|
||||
model = get_registry().get("instructor").create()
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
|
||||
Reference in New Issue
Block a user