mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 06:19:57 +00:00
143 lines
5.4 KiB
Python
143 lines
5.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
from ..util import attempt_import_or_raise
|
|
from .base import TextEmbeddingFunction
|
|
from .registry import register
|
|
from .utils import TEXT, weak_lru
|
|
|
|
|
|
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
    """
    Text embedding function backed by the InstructorEmbedding library.

    Instructor models are instruction-tuned: every text is embedded together
    with a short natural-language instruction describing the task, which lets
    one model serve text classification, sentence similarity, document
    retrieval and similar tasks. Instructions follow the unified template:

        "Represent the `domain` `text_type` for `task_objective`":

        * domain is optional, and it specifies the domain of the text, e.g.,
          science, finance, medicine, etc.
        * text_type is required, and it specifies the encoding unit, e.g.,
          sentence, document, paragraph, etc.
        * task_objective is optional, and it specifies the objective of
          embedding, e.g., retrieve a document, classify the sentence, etc.

    For example, to embed a document for retrieval, the instruction could be:

        "Represent the document for retrieval"

    Parameters
    ----------
    name: str
        The name of the model to use. Available models are listed at
        https://github.com/xlang-ai/instructor-embedding#model-list;
        The default model is hkunlp/instructor-base
    batch_size: int, default 32
        The batch size to use when generating embeddings
    device: str, default "cpu"
        The device to use when generating embeddings
    show_progress_bar: bool, default True
        Whether to show a progress bar when generating embeddings
    normalize_embeddings: bool, default True
        Whether to normalize the embeddings
    quantize: bool, default False
        Whether to quantize the model (dynamic int8 quantization of the
        model's ``torch.nn.Linear`` layers)
    source_instruction: str, default "represent the document for retrieval"
        The instruction paired with every text of the source column
    query_instruction: str, default "represent the document for retrieving the most
        similar documents"
        The instruction paired with the query

    Examples
    --------

    import lancedb
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import get_registry, InstructorEmbeddingFunction

    instructor = get_registry().get("instructor").create(
        source_instruction="represent the document for retrieval",
        query_instruction="represent the document for retrieving the most "
                          "similar documents"
    )

    class Schema(LanceModel):
        vector: Vector(instructor.ndims()) = instructor.VectorField()
        text: str = instructor.SourceField()

    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=Schema, mode="overwrite")

    texts = [{"text": "Capitalism has been dominant in the Western world since the "
                      "end of feudalism, but most feel[who?] that..."},
             {"text": "The disparate impact theory is especially controversial under "
                      "the Fair Housing Act because the Act..."},
             {"text": "Disparate impact in United States labor law refers to practices "
                      "in employment, housing, and other areas that.."}]

    tbl.add(texts)

    """

    name: str = "hkunlp/instructor-base"
    batch_size: int = 32
    device: str = "cpu"
    show_progress_bar: bool = True
    normalize_embeddings: bool = True
    quantize: bool = False
    # NOTE: `convert_to_numpy` is deliberately not exposed as a field; it is
    # hardcoded because numpy output can be ingested directly.

    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )

    @weak_lru(maxsize=1)
    def ndims(self):
        """Return the embedding dimensionality, found by encoding a probe string."""
        probe_vector = self.get_model().encode("foo")
        return probe_vector.shape[0]

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        """Embed a single query, paired with ``query_instruction``."""
        instruction_pair = [self.query_instruction, query]
        return self.generate_embeddings([instruction_pair])

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        """Embed source texts, each paired with ``source_instruction``."""
        sanitized = self.sanitize_input(texts)
        instruction_pairs = [[self.source_instruction, item] for item in sanitized]
        return self.generate_embeddings(instruction_pairs)

    def generate_embeddings(self, texts: List) -> List:
        """
        Encode ``[instruction, text]`` pairs with the loaded model.

        The encoder output is converted with ``.tolist()`` before being
        returned.
        """
        encoder = self.get_model()
        encoded = encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            device=self.device,
        )
        return encoded.tolist()

    @weak_lru(maxsize=1)
    def get_model(self):
        """
        Load the INSTRUCTOR model named by ``name``, quantized if requested.

        Both ``InstructorEmbedding`` and ``torch`` are imported lazily through
        ``attempt_import_or_raise``.
        """
        instructor_embedding = attempt_import_or_raise(
            "InstructorEmbedding", "InstructorEmbedding"
        )
        torch = attempt_import_or_raise("torch", "torch")

        loaded = instructor_embedding.INSTRUCTOR(self.name)
        if not self.quantize:
            return loaded

        # fix for https://github.com/pytorch/pytorch/issues/29327
        available_engines = torch.backends.quantized.supported_engines
        if "qnnpack" in available_engines:
            torch.backends.quantized.engine = "qnnpack"

        return torch.quantization.quantize_dynamic(
            loaded, {torch.nn.Linear}, dtype=torch.qint8
        )
|