Multi-task instructor model with quantization support & weak_lru cache for embedding function models (#612)

resolves #608
This commit is contained in:
Ayush Chaurasia
2023-11-09 12:34:18 +05:30
committed by GitHub
parent 662968559d
commit 1e8678f11a
9 changed files with 270 additions and 10 deletions

View File

@@ -14,6 +14,7 @@
# ruff: noqa: F401 # ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction from .cohere import CohereEmbeddingFunction
from .instructor import InstructorEmbeddingFunction
from .open_clip import OpenClipEmbeddings from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry from .registry import EmbeddingFunctionRegistry, get_registry

View File

@@ -1,3 +1,15 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import importlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Union from typing import List, Union
@@ -125,6 +137,14 @@ class EmbeddingFunction(BaseModel, ABC):
""" """
return Field(json_schema_extra={"vector_column_for": self}, **kwargs) return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
def __eq__(self, __value: object) -> bool:
    # Two objects compare equal when the other one has an instance
    # ``__dict__`` and both attribute dictionaries are equal. Objects without
    # a ``__dict__`` (e.g. ints, strings) are never equal to this one.
    return hasattr(__value, "__dict__") and vars(self) == vars(__value)
def __hash__(self) -> int:
    # Hash the set of (attribute name, value) pairs so that the result does
    # not depend on attribute insertion order. Every attribute value must
    # itself be hashable, otherwise ``frozenset`` raises ``TypeError``.
    attribute_pairs = frozenset(vars(self).items())
    return hash(attribute_pairs)
class EmbeddingFunctionConfig(BaseModel): class EmbeddingFunctionConfig(BaseModel):
""" """

View File

@@ -0,0 +1,137 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, weak_lru
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
    variety of tasks, including text classification, sentence similarity, and document retrieval.

    If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
        "Represent the `domain` `text_type` for `task_objective`":

        * domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
        * text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
        * task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.

    For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
        "Represent the document for retrieval"

    Parameters
    ----------
    name: str
        The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
        The default model is hkunlp/instructor-base
    batch_size: int, default 32
        The batch size to use when generating embeddings
    device: str, default "cpu"
        The device to use when generating embeddings
    show_progress_bar: bool, default True
        Whether to show a progress bar when generating embeddings
    normalize_embeddings: bool, default True
        Whether to normalize the embeddings
    quantize: bool, default False
        Whether to quantize the model
    source_instruction: str, default "represent the document for retrieval"
        The instruction for the source column
    query_instruction: str, default "represent the document for retrieving the most similar documents"
        The instruction for the query

    Examples
    --------
    import lancedb
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import get_registry, InstructorEmbeddingFunction

    instructor = get_registry().get("instructor").create(
                                source_instruction="represent the document for retrieval",
                                query_instruction="represent the document for retrieving the most similar documents"
                                )

    class Schema(LanceModel):
        vector: Vector(instructor.ndims()) = instructor.VectorField()
        text: str = instructor.SourceField()

    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=Schema, mode="overwrite")

    texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
            {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
            {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]

    tbl.add(texts)
    """

    name: str = "hkunlp/instructor-base"
    batch_size: int = 32
    device: str = "cpu"
    show_progress_bar: bool = True
    normalize_embeddings: bool = True
    quantize: bool = False
    # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )

    @weak_lru(maxsize=1)
    def ndims(self):
        """
        Return the dimensionality of the embeddings produced by the model.

        The value is obtained by encoding a throwaway string and reading the
        length of the resulting vector; the result is cached by `weak_lru`.
        """
        model = self.get_model()
        # Encode on the configured device so the probe does not run on a
        # different device than the real embedding calls.
        return model.encode("foo", device=self.device).shape[0]

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        """
        Embed a single query, prefixed with `query_instruction`.

        Returns a list containing one embedding (itself a list of floats).
        """
        return self.generate_embeddings([[self.query_instruction, query]])

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        """
        Embed source texts, each prefixed with `source_instruction`.

        The input is first passed through the inherited `sanitize_input`
        (presumably normalizing it to a sequence of strings -- defined in the
        base class, not visible here).
        """
        texts = self.sanitize_input(texts)
        # Instructor models take [instruction, text] pairs as input.
        texts_formatted = [[self.source_instruction, text] for text in texts]
        return self.generate_embeddings(texts_formatted)

    def generate_embeddings(self, texts: List) -> List:
        """
        Encode a list of [instruction, text] pairs.

        Parameters
        ----------
        texts: List
            The [instruction, text] pairs to embed.

        Returns
        -------
        List
            One embedding per input pair, converted to plain Python lists.
        """
        model = self.get_model()
        res = model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            # Honor the configured device; previously the `device` field was
            # declared but never forwarded to the model.
            device=self.device,
        ).tolist()
        return res

    @weak_lru(maxsize=1)
    def get_model(self):
        """
        Load the INSTRUCTOR model named by `name`, optionally quantized.

        When `quantize` is set, the model's `torch.nn.Linear` layers are
        dynamically quantized to int8. The result is cached by `weak_lru` so
        the model is not reloaded on every call.
        """
        instructor_embedding = self.safe_import(
            "InstructorEmbedding", "InstructorEmbedding"
        )
        torch = self.safe_import("torch", "torch")

        model = instructor_embedding.INSTRUCTOR(self.name)
        if self.quantize:
            if (
                "qnnpack" in torch.backends.quantized.supported_engines
            ):  # fix for https://github.com/pytorch/pytorch/issues/29327
                torch.backends.quantized.engine = "qnnpack"
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        return model

View File

@@ -1,3 +1,15 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures import concurrent.futures
import io import io
import os import os

View File

@@ -1,3 +1,15 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np

View File

@@ -1,3 +1,15 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np
@@ -5,6 +17,7 @@ from cachetools import cached
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import weak_lru
@register("sentence-transformers") @register("sentence-transformers")
@@ -30,7 +43,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
name and device. This is cached so that the model is only loaded name and device. This is cached so that the model is only loaded
once per process. once per process.
""" """
return self.__class__.get_embedding_model(self.name, self.device) return self.get_embedding_model()
def ndims(self): def ndims(self):
if self._ndims is None: if self._ndims is None:
@@ -54,9 +67,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
normalize_embeddings=self.normalize, normalize_embeddings=self.normalize,
).tolist() ).tolist()
@classmethod @weak_lru(maxsize=1)
@cached(cache={}) def get_embedding_model(self):
def get_embedding_model(cls, name, device):
""" """
Get the sentence-transformers embedding model specified by the Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded name and device. This is cached so that the model is only loaded
@@ -71,7 +83,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
TODO: use lru_cache instead with a reasonable/configurable maxsize TODO: use lru_cache instead with a reasonable/configurable maxsize
""" """
sentence_transformers = cls.safe_import( sentence_transformers = self.safe_import(
"sentence_transformers", "sentence-transformers" "sentence_transformers", "sentence-transformers"
) )
return sentence_transformers.SentenceTransformer(name, device=device) return sentence_transformers.SentenceTransformer(self.name, device=self.device)

View File

@@ -11,12 +11,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import math import math
import random import random
import socket import socket
import sys import sys
import time import time
import urllib.error import urllib.error
import weakref
from typing import Callable, List, Union from typing import Callable, List, Union
import numpy as np import numpy as np
@@ -164,6 +166,50 @@ class FunctionWrapper:
yield from _chunker(arr) yield from _chunker(arr)
def weak_lru(maxsize=128):
    """
    Method decorator providing an LRU cache keyed on a weak reference to
    the instance, so that the cache itself does not keep the instance alive.

    One cache of at most ``maxsize`` entries is created per decorated method
    and is shared by every instance that calls it; each entry is keyed by the
    weak reference to ``self`` together with the call arguments. Instances
    must therefore be hashable and weak-referenceable.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of entries kept in the cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def decorator(method):
        def call_through_ref(instance_ref, *args, **kwargs):
            # The cache key holds only the weak reference; resolve it back to
            # the live instance right before invoking the real method.
            return method(instance_ref(), *args, **kwargs)

        cached_call = functools.lru_cache(maxsize)(call_through_ref)

        @functools.wraps(method)
        def bound(self, *args, **kwargs):
            return cached_call(weakref.ref(self), *args, **kwargs)

        return bound

    return decorator
def retry_with_exponential_backoff( def retry_with_exponential_backoff(
func, func,
initial_delay: float = 1, initial_delay: float = 1,

View File

@@ -53,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
dev = ["ruff", "pre-commit", "black"] dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"] embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
[project.scripts] [project.scripts]
lancedb = "lancedb.cli.cli:cli" lancedb = "lancedb.cli.cli:cli"

View File

@@ -32,8 +32,8 @@ from lancedb.pydantic import LanceModel, Vector
def test_sentence_transformer(alias, tmp_path): def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
registry = get_registry() registry = get_registry()
func = registry.get(alias).create() func = registry.get(alias).create(max_retries=0)
func2 = registry.get(alias).create() func2 = registry.get(alias).create(max_retries=0)
class Words(LanceModel): class Words(LanceModel):
text: str = func.SourceField() text: str = func.SourceField()
@@ -150,7 +150,11 @@ def test_openclip(tmp_path):
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
) # also skip if cohere not installed ) # also skip if cohere not installed
def test_cohere_embedding_function(): def test_cohere_embedding_function():
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0") cohere = (
get_registry()
.get("cohere")
.create(name="embed-multilingual-v2.0", max_retries=0)
)
class TextModel(LanceModel): class TextModel(LanceModel):
text: str = cohere.SourceField() text: str = cohere.SourceField()
@@ -162,3 +166,19 @@ def test_cohere_embedding_function():
tbl.add(df) tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims() assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
@pytest.mark.slow
def test_instructor_embedding(tmp_path):
    """
    End-to-end check of the "instructor" embedding function.

    Creates the function from the registry with default settings, writes two
    rows into a fresh table whose schema uses it, and verifies that the stored
    vector length matches the dimensionality reported by ``ndims()``.
    """
    instructor = get_registry().get("instructor").create()
    frame = pd.DataFrame({"text": ["hello world", "goodbye world"]})

    class TextModel(LanceModel):
        text: str = instructor.SourceField()
        vector: Vector(instructor.ndims()) = instructor.VectorField()

    table = lancedb.connect(tmp_path).create_table(
        "test", schema=TextModel, mode="overwrite"
    )
    table.add(frame)

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == instructor.ndims()