mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00
Multi-task instructor model with quantization support & weak_lru cache for embedding function models (#612)
resolves #608
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# ruff: noqa: F401
|
# ruff: noqa: F401
|
||||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||||
from .cohere import CohereEmbeddingFunction
|
from .cohere import CohereEmbeddingFunction
|
||||||
|
from .instructor import InstructorEmbeddingFunction
|
||||||
from .open_clip import OpenClipEmbeddings
|
from .open_clip import OpenClipEmbeddings
|
||||||
from .openai import OpenAIEmbeddings
|
from .openai import OpenAIEmbeddings
|
||||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||||
|
|||||||
@@ -1,3 +1,15 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@@ -125,6 +137,14 @@ class EmbeddingFunction(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
    """Compare this embedding function with another object by attributes.

    Parameters
    ----------
    __value : object
        The object to compare against.

    Returns
    -------
    bool
        True only when ``__value`` has an instance ``__dict__`` and that
        dictionary compares equal to this instance's own; False otherwise.
    """
    # Objects without a __dict__ can never match; short-circuit before vars().
    return hasattr(__value, "__dict__") and vars(self) == vars(__value)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
    """Hash the instance from its attribute (name, value) pairs.

    The pairs are collected into a frozenset, so the result does not depend
    on attribute order. Every attribute value must itself be hashable,
    otherwise ``hash`` raises ``TypeError``.

    Returns
    -------
    int
        The hash of the frozenset of ``vars(self)`` items.
    """
    attribute_pairs = vars(self).items()
    return hash(frozenset(attribute_pairs))
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionConfig(BaseModel):
|
class EmbeddingFunctionConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
137
python/lancedb/embeddings/instructor.py
Normal file
137
python/lancedb/embeddings/instructor.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import TEXT, weak_lru
|
||||||
|
|
||||||
|
|
||||||
|
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the InstructorEmbedding library.

    Instructor models support multi-task learning, and can be used for a
    variety of tasks, including text classification, sentence similarity, and
    document retrieval. If you want to calculate customized embeddings for
    specific sentences, you may follow the unified template to write
    instructions:
        "Represent the `domain` `text_type` for `task_objective`":

        * domain is optional, and it specifies the domain of the text,
          e.g., science, finance, medicine, etc.
        * text_type is required, and it specifies the encoding unit,
          e.g., sentence, document, paragraph, etc.
        * task_objective is optional, and it specifies the objective of
          embedding, e.g., retrieve a document, classify the sentence, etc.

    For example, if you want to calculate embeddings for a document, you may
    write the instruction as follows:
        "Represent the document for retrieval"

    Parameters
    ----------
    name: str
        The name of the model to use. Available models are listed at
        https://github.com/xlang-ai/instructor-embedding#model-list;
        The default model is hkunlp/instructor-base
    batch_size: int, default 32
        The batch size to use when generating embeddings
    device: str, default "cpu"
        The device the model is loaded on and used to generate embeddings
    show_progress_bar: bool, default True
        Whether to show a progress bar when generating embeddings
    normalize_embeddings: bool, default True
        Whether to normalize the embeddings
    quantize: bool, default False
        Whether to quantize the model
    source_instruction: str, default "represent the document for retrieval"
        The instruction for the source column
    query_instruction: str, default "represent the document for retrieving the most similar documents"
        The instruction for the query

    Examples
    --------
        import lancedb
        from lancedb.pydantic import LanceModel, Vector
        from lancedb.embeddings import get_registry, InstructorEmbeddingFunction

        instructor = get_registry().get("instructor").create(
            source_instruction="represent the document for retrieval",
            query_instruction="represent the document for retrieving the most similar documents"
        )

        class Schema(LanceModel):
            vector: Vector(instructor.ndims()) = instructor.VectorField()
            text: str = instructor.SourceField()

        db = lancedb.connect("~/.lancedb")
        tbl = db.create_table("test", schema=Schema, mode="overwrite")

        texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
                 {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
                 {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]

        tbl.add(texts)
    """

    name: str = "hkunlp/instructor-base"
    batch_size: int = 32
    device: str = "cpu"
    show_progress_bar: bool = True
    normalize_embeddings: bool = True
    quantize: bool = False
    # convert_to_numpy: bool = True  # Hardcoding this as numpy can be ingested directly

    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )

    @weak_lru(maxsize=1)
    def ndims(self):
        """Return the embedding size, measured by encoding a probe string.

        The value is the length of the first axis of the array the model
        returns for the single string "foo"; the result is cached by
        ``weak_lru`` so the probe is not re-encoded on every call.
        """
        model = self.get_model()
        return model.encode("foo").shape[0]

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        """Embed a single query, paired with ``query_instruction``.

        Parameters
        ----------
        query : str
            The query text. Extra positional and keyword arguments are
            accepted but not used.

        Returns
        -------
        List
            The result of ``generate_embeddings`` for the one
            ``[instruction, query]`` pair.
        """
        return self.generate_embeddings([[self.query_instruction, query]])

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        """Embed source texts, each paired with ``source_instruction``.

        Parameters
        ----------
        texts : TEXT
            The input texts; they are first passed through
            ``sanitize_input``. Extra positional and keyword arguments are
            accepted but not used.

        Returns
        -------
        List
            The result of ``generate_embeddings`` for the
            ``[instruction, text]`` pairs, in input order.
        """
        texts = self.sanitize_input(texts)
        texts_formatted = [[self.source_instruction, text] for text in texts]
        return self.generate_embeddings(texts_formatted)

    def generate_embeddings(self, texts: List) -> List:
        """Encode ``[instruction, text]`` pairs with the cached model.

        Parameters
        ----------
        texts : List
            The items handed to the model's ``encode`` method.

        Returns
        -------
        List
            The encoded output converted with ``tolist()``.
        """
        model = self.get_model()
        res = model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        ).tolist()
        return res

    @weak_lru(maxsize=1)
    def get_model(self):
        """Load (and optionally quantize) the INSTRUCTOR model.

        The model is created for ``self.name`` on ``self.device`` and cached
        by ``weak_lru``. When ``quantize`` is set, the Linear layers are
        dynamically quantized to qint8.
        """
        instructor_embedding = self.safe_import(
            "InstructorEmbedding", "InstructorEmbedding"
        )
        torch = self.safe_import("torch", "torch")

        # Honor the configured device instead of silently ignoring the field.
        model = instructor_embedding.INSTRUCTOR(self.name, device=self.device)
        if self.quantize:
            if (
                "qnnpack" in torch.backends.quantized.supported_engines
            ):  # fix for https://github.com/pytorch/pytorch/issues/29327
                torch.backends.quantized.engine = "qnnpack"
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        return model
|
||||||
@@ -1,3 +1,15 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,3 +1,15 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -1,3 +1,15 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -5,6 +17,7 @@ from cachetools import cached
|
|||||||
|
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
|
from .utils import weak_lru
|
||||||
|
|
||||||
|
|
||||||
@register("sentence-transformers")
|
@register("sentence-transformers")
|
||||||
@@ -30,7 +43,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
name and device. This is cached so that the model is only loaded
|
name and device. This is cached so that the model is only loaded
|
||||||
once per process.
|
once per process.
|
||||||
"""
|
"""
|
||||||
return self.__class__.get_embedding_model(self.name, self.device)
|
return self.get_embedding_model()
|
||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
if self._ndims is None:
|
if self._ndims is None:
|
||||||
@@ -54,9 +67,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
normalize_embeddings=self.normalize,
|
normalize_embeddings=self.normalize,
|
||||||
).tolist()
|
).tolist()
|
||||||
|
|
||||||
@classmethod
|
@weak_lru(maxsize=1)
|
||||||
@cached(cache={})
|
def get_embedding_model(self):
|
||||||
def get_embedding_model(cls, name, device):
|
|
||||||
"""
|
"""
|
||||||
Get the sentence-transformers embedding model specified by the
|
Get the sentence-transformers embedding model specified by the
|
||||||
name and device. This is cached so that the model is only loaded
|
name and device. This is cached so that the model is only loaded
|
||||||
@@ -71,7 +83,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
|
|
||||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||||
"""
|
"""
|
||||||
sentence_transformers = cls.safe_import(
|
sentence_transformers = self.safe_import(
|
||||||
"sentence_transformers", "sentence-transformers"
|
"sentence_transformers", "sentence-transformers"
|
||||||
)
|
)
|
||||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||||
|
|||||||
@@ -11,12 +11,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import urllib.error
|
import urllib.error
|
||||||
|
import weakref
|
||||||
from typing import Callable, List, Union
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -164,6 +166,50 @@ class FunctionWrapper:
|
|||||||
yield from _chunker(arr)
|
yield from _chunker(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def weak_lru(maxsize=128):
    """
    LRU cache for methods that keys on a weak reference to the instance.

    The cache key is built from ``weakref.ref(self)`` plus the call
    arguments, so the cache itself never keeps an instance alive. Cached
    return values are held strongly until they are evicted or the cache is
    cleared. The instance and all arguments must be hashable.

    The decorator can be used with arguments (``@weak_lru(maxsize=1)``), with
    empty parentheses (``@weak_lru()``), or bare (``@weak_lru``), in which
    case the default ``maxsize`` of 128 is used. The decorated method exposes
    ``cache_info()`` and ``cache_clear()`` from the underlying
    ``functools.lru_cache``.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of objects to cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def wrapper(func):
        @functools.lru_cache(maxsize)
        def _func(_self, *args, **kwargs):
            # _self is the weak reference; dereference it for the real call.
            return func(_self(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _func(weakref.ref(self), *args, **kwargs)

        # Let callers inspect the cache and release cached values explicitly.
        inner.cache_info = _func.cache_info
        inner.cache_clear = _func.cache_clear
        return inner

    if callable(maxsize):
        # Bare ``@weak_lru``: the decorated function arrived as ``maxsize``.
        func, maxsize = maxsize, 128
        return wrapper(func)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
def retry_with_exponential_backoff(
|
def retry_with_exponential_backoff(
|
||||||
func,
|
func,
|
||||||
initial_delay: float = 1,
|
initial_delay: float = 1,
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
|||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]
|
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
lancedb = "lancedb.cli.cli:cli"
|
lancedb = "lancedb.cli.cli:cli"
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ from lancedb.pydantic import LanceModel, Vector
|
|||||||
def test_sentence_transformer(alias, tmp_path):
|
def test_sentence_transformer(alias, tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
func = registry.get(alias).create()
|
func = registry.get(alias).create(max_retries=0)
|
||||||
func2 = registry.get(alias).create()
|
func2 = registry.get(alias).create(max_retries=0)
|
||||||
|
|
||||||
class Words(LanceModel):
|
class Words(LanceModel):
|
||||||
text: str = func.SourceField()
|
text: str = func.SourceField()
|
||||||
@@ -150,7 +150,11 @@ def test_openclip(tmp_path):
|
|||||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||||
) # also skip if cohere not installed
|
) # also skip if cohere not installed
|
||||||
def test_cohere_embedding_function():
|
def test_cohere_embedding_function():
|
||||||
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
cohere = (
|
||||||
|
get_registry()
|
||||||
|
.get("cohere")
|
||||||
|
.create(name="embed-multilingual-v2.0", max_retries=0)
|
||||||
|
)
|
||||||
|
|
||||||
class TextModel(LanceModel):
|
class TextModel(LanceModel):
|
||||||
text: str = cohere.SourceField()
|
text: str = cohere.SourceField()
|
||||||
@@ -162,3 +166,19 @@ def test_cohere_embedding_function():
|
|||||||
|
|
||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
|
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
def test_instructor_embedding(tmp_path):
    """Check that stored vectors have the length the instructor model reports.

    Creates the registered "instructor" embedding function with default
    settings, adds two short texts to a fresh table under ``tmp_path``, and
    compares the first stored vector's length with ``ndims()``.
    """
    embedder = get_registry().get("instructor").create()

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    connection = lancedb.connect(tmp_path)
    table = connection.create_table("test", schema=TextModel, mode="overwrite")
    table.add(pd.DataFrame({"text": ["hello world", "goodbye world"]}))

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == embedder.ndims()
|
||||||
|
|||||||
Reference in New Issue
Block a user