diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index 99e6d314..d1944106 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -14,6 +14,7 @@ # ruff: noqa: F401 from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .cohere import CohereEmbeddingFunction +from .instructor import InstructorEmbeddingFunction from .open_clip import OpenClipEmbeddings from .openai import OpenAIEmbeddings from .registry import EmbeddingFunctionRegistry, get_registry diff --git a/python/lancedb/embeddings/base.py b/python/lancedb/embeddings/base.py index e3e34608..04d1360d 100644 --- a/python/lancedb/embeddings/base.py +++ b/python/lancedb/embeddings/base.py @@ -1,3 +1,15 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib from abc import ABC, abstractmethod from typing import List, Union @@ -125,6 +137,14 @@ class EmbeddingFunction(BaseModel, ABC): """ return Field(json_schema_extra={"vector_column_for": self}, **kwargs) + def __eq__(self, __value: object) -> bool: + if not hasattr(__value, "__dict__"): + return False + return vars(self) == vars(__value) + + def __hash__(self) -> int: + return hash(frozenset(vars(self).items())) + class EmbeddingFunctionConfig(BaseModel): """ diff --git a/python/lancedb/embeddings/instructor.py b/python/lancedb/embeddings/instructor.py new file mode 100644 index 00000000..53be8ccb --- /dev/null +++ b/python/lancedb/embeddings/instructor.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import numpy as np + +from .base import TextEmbeddingFunction +from .registry import register +from .utils import TEXT, weak_lru + + +@register("instructor") +class InstructorEmbeddingFunction(TextEmbeddingFunction): + """ + An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a + variety of tasks, including text classification, sentence similarity, and document retrieval. + If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions: + "Represent the `domain` `text_type` for `task_objective`": + + * domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc. + * text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc. + * task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc. + + For example, if you want to calculate embeddings for a document, you may write the instruction as follows: + "Represent the document for retreival" + + Parameters + ---------- + name: str + The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list; + The default model is hkunlp/instructor-base + batch_size: int, default 32 + The batch size to use when generating embeddings + device: str, default "cpu" + The device to use when generating embeddings + show_progress_bar: bool, default True + Whether to show a progress bar when generating embeddings + normalize_embeddings: bool, default True + Whether to normalize the embeddings + quantize: bool, default False + Whether to quantize the model + source_instruction: str, default "represent the docuement for retreival" + The instruction for the source column + query_instruction: str, default "represent the document for retreiving the most similar documents" + The instruction for the query + + Examples + -------- + import lancedb + from lancedb.pydantic import LanceModel, Vector + from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction + + instructor = get_registry().get("instructor").create( + source_instruction="represent the docuement for retreival", + query_instruction="represent the document for retreiving the most similar documents" + ) + + class Schema(LanceModel): + vector: Vector(instructor.ndims()) = instructor.VectorField() + text: str = instructor.SourceField() + + db = lancedb.connect("~/.lancedb") + tbl = db.create_table("test", schema=Schema, mode="overwrite") + + texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."}, + {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."}, + {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}] + + tbl.add(texts) + + """ + + name: str = "hkunlp/instructor-base" + batch_size: int = 32 + device: str = "cpu" + show_progress_bar: bool = True + normalize_embeddings: bool = True + quantize: bool = False + # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly + + source_instruction: str = "represent the document for retrieval" + query_instruction: str = ( + "represent the document for retrieving the most similar documents" + ) + + @weak_lru(maxsize=1) + def ndims(self): + model = self.get_model() + return model.encode("foo").shape[0] + + def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + return self.generate_embeddings([[self.query_instruction, query]]) + + def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + texts = self.sanitize_input(texts) + texts_formatted = [] + for text in texts: + texts_formatted.append([self.source_instruction, text]) + return self.generate_embeddings(texts_formatted) + + def generate_embeddings(self, texts: List) -> List: + model = self.get_model() + res = model.encode( + texts, + batch_size=self.batch_size, + show_progress_bar=self.show_progress_bar, + normalize_embeddings=self.normalize_embeddings, + ).tolist() + return res + + @weak_lru(maxsize=1) + def get_model(self): + instructor_embedding = self.safe_import( + "InstructorEmbedding", "InstructorEmbedding" + ) + torch = self.safe_import("torch", "torch") + + model = instructor_embedding.INSTRUCTOR(self.name) + if self.quantize: + if ( + "qnnpack" in torch.backends.quantized.supported_engines + ): # fix for https://github.com/pytorch/pytorch/issues/29327 + torch.backends.quantized.engine = "qnnpack" + model = torch.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8 + ) + return model diff --git a/python/lancedb/embeddings/open_clip.py b/python/lancedb/embeddings/open_clip.py index 37023377..1e9aeb6e 100644 --- a/python/lancedb/embeddings/open_clip.py +++ b/python/lancedb/embeddings/open_clip.py @@ -1,3 +1,15 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import concurrent.futures import io import os diff --git a/python/lancedb/embeddings/openai.py b/python/lancedb/embeddings/openai.py index 25459743..406ed40f 100644 --- a/python/lancedb/embeddings/openai.py +++ b/python/lancedb/embeddings/openai.py @@ -1,3 +1,15 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Union import numpy as np diff --git a/python/lancedb/embeddings/sentence_transformers.py b/python/lancedb/embeddings/sentence_transformers.py index 5e40a51d..995ab7bd 100644 --- a/python/lancedb/embeddings/sentence_transformers.py +++ b/python/lancedb/embeddings/sentence_transformers.py @@ -1,3 +1,15 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Union import numpy as np @@ -5,6 +17,7 @@ from cachetools import cached from .base import TextEmbeddingFunction from .registry import register +from .utils import weak_lru @register("sentence-transformers") @@ -30,7 +43,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): name and device. This is cached so that the model is only loaded once per process. """ - return self.__class__.get_embedding_model(self.name, self.device) + return self.get_embedding_model() def ndims(self): if self._ndims is None: @@ -54,9 +67,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): normalize_embeddings=self.normalize, ).tolist() - @classmethod - @cached(cache={}) - def get_embedding_model(cls, name, device): + @weak_lru(maxsize=1) + def get_embedding_model(self): """ Get the sentence-transformers embedding model specified by the name and device. This is cached so that the model is only loaded @@ -71,7 +83,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): TODO: use lru_cache instead with a reasonable/configurable maxsize """ - sentence_transformers = cls.safe_import( + sentence_transformers = self.safe_import( "sentence_transformers", "sentence-transformers" ) - return sentence_transformers.SentenceTransformer(name, device=device) + return sentence_transformers.SentenceTransformer(self.name, device=self.device) diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 1308b358..59ed0460 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -11,12 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import math import random import socket import sys import time import urllib.error +import weakref from typing import Callable, List, Union import numpy as np @@ -164,6 +166,50 @@ class FunctionWrapper: yield from _chunker(arr) +def weak_lru(maxsize=128): + """ + LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage + is bounded. + + Parameters + ---------- + maxsize : int, default 128 + The maximum number of objects to cache. + + Returns + ------- + Callable + A decorator that can be applied to a method. + + Examples + -------- + >>> class Foo: + ... @weak_lru() + ... def bar(self, x): + ... return x + >>> foo = Foo() + >>> foo.bar(1) + 1 + >>> foo.bar(2) + 2 + >>> foo.bar(1) + 1 + """ + + def wrapper(func): + @functools.lru_cache(maxsize) + def _func(_self, *args, **kwargs): + return func(_self(), *args, **kwargs) + + @functools.wraps(func) + def inner(self, *args, **kwargs): + return _func(weakref.ref(self), *args, **kwargs) + + return inner + + return wrapper + + def retry_with_exponential_backoff( func, initial_delay: float = 1, diff --git a/python/pyproject.toml b/python/pyproject.toml index fcec9887..ad03960b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,7 +50,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] -embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"] +embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"] [build-system] requires = ["setuptools", "wheel"] diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index b0078397..2e116827 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -32,8 +32,8 @@ from lancedb.pydantic import LanceModel, Vector def test_sentence_transformer(alias, tmp_path): db = lancedb.connect(tmp_path) registry = get_registry() - func = registry.get(alias).create() - func2 = registry.get(alias).create() + func = registry.get(alias).create(max_retries=0) + func2 = registry.get(alias).create(max_retries=0) class Words(LanceModel): text: str = func.SourceField() @@ -150,7 +150,11 @@ def test_openclip(tmp_path): os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" ) # also skip if cohere not installed def test_cohere_embedding_function(): - cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0") + cohere = ( + get_registry() + .get("cohere") + .create(name="embed-multilingual-v2.0", max_retries=0) + ) class TextModel(LanceModel): text: str = cohere.SourceField() @@ -162,3 +166,19 @@ def test_cohere_embedding_function(): tbl.add(df) assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims() + + +@pytest.mark.slow +def test_instructor_embedding(tmp_path): + model = get_registry().get("instructor").create() + + class TextModel(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + df = pd.DataFrame({"text": ["hello world", "goodbye world"]}) + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == model.ndims()