feat(python): hybrid search updates, examples, & latency benchmarks (#964)

- Rename safe_import -> attempt_import_or_raise (closes
https://github.com/lancedb/lancedb/pull/923)
- Update docs
- Add a notebook example (@changhiskhan you can use it for the talk; it comes
with an "Open in Colab" button)
- Add a latency benchmark & results comparison, plus a sanity check on
real-world data (see the usage sketch below)
- Update the default OpenAI reranker model to gpt-4-turbo-preview
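
For reference, a minimal sketch of the hybrid search + reranking path these
changes touch. The schema, data, and registry key are illustrative; it assumes
the query_type="hybrid" / .rerank() query API from this release, the tantivy
FTS dependency, and OPENAI_API_KEY / COHERE_API_KEY in the environment.

import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker

registry = EmbeddingFunctionRegistry.get_instance()
embed_fn = registry.get("openai").create()  # imports openai via attempt_import_or_raise

class Docs(LanceModel):
    text: str = embed_fn.SourceField()
    vector: Vector(embed_fn.ndims()) = embed_fn.VectorField()

db = lancedb.connect("/tmp/lancedb")
table = db.create_table("docs", schema=Docs, mode="overwrite")
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
table.create_fts_index("text")  # full-text index backs the keyword half of hybrid search

results = (
    table.search("greeting", query_type="hybrid")  # vector + full-text in one query
    .rerank(reranker=CohereReranker())             # any Reranker from lancedb.rerankers
    .limit(5)
    .to_pandas()
)
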
Author: Ayush Chaurasia
Date: 2024-02-13 17:58:39 +05:30
Committed by: GitHub
Parent: 3169c36525
Commit: eb31d95fef
20 changed files with 1209 additions and 80 deletions

View File

@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from abc import ABC, abstractmethod
from typing import List, Union
@@ -91,25 +90,6 @@ class EmbeddingFunction(BaseModel, ABC):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
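
After this removal, call sites switch from the classmethod to the shared helper
in lancedb.util, as the hunks below show. A representative consumer (sketch;
the argument pairs are the ones used in this PR):

from lancedb.util import attempt_import_or_raise

# first argument is the import name, second the pip package suggested on failure
sentence_transformers = attempt_import_or_raise(
    "sentence_transformers", "sentence-transformers"
)
PIL = attempt_import_or_raise("PIL", "pillow")  # ImportError("Please install pillow") if missing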

View File

@@ -19,6 +19,7 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT
@@ -183,8 +184,8 @@ class BedRockText(TextEmbeddingFunction):
boto3.client
The boto3 client for Amazon Bedrock service
"""
botocore = self.safe_import("botocore")
boto3 = self.safe_import("boto3")
botocore = attempt_import_or_raise("botocore")
boto3 = attempt_import_or_raise("boto3")
session_kwargs = {"region_name": self.region}
client_kwargs = {**session_kwargs}

View File

@@ -16,6 +16,7 @@ from typing import ClassVar, List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -84,7 +85,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
return [emb for emb in rs.embeddings]
def _init_client(self):
cohere = self.safe_import("cohere")
cohere = attempt_import_or_raise("cohere")
if CohereEmbeddingFunction.client is None:
if os.environ.get("COHERE_API_KEY") is None:
api_key_not_found_help("cohere")

View File

@@ -19,6 +19,7 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, api_key_not_found_help
@@ -134,7 +135,7 @@ class GeminiText(TextEmbeddingFunction):
@cached_property
def client(self):
genai = self.safe_import("google.generativeai", "google.generativeai")
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
if not os.environ.get("GOOGLE_API_KEY"):
api_key_not_found_help("google")

View File

@@ -14,6 +14,7 @@ from typing import List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@@ -122,7 +123,7 @@ class GteEmbeddings(TextEmbeddingFunction):
return Model()
else:
sentence_transformers = self.safe_import(
sentence_transformers = attempt_import_or_raise(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(

View File

@@ -14,6 +14,7 @@ from typing import List
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, weak_lru
@@ -131,10 +132,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
@weak_lru(maxsize=1)
def get_model(self):
instructor_embedding = self.safe_import(
instructor_embedding = attempt_import_or_raise(
"InstructorEmbedding", "InstructorEmbedding"
)
torch = self.safe_import("torch", "torch")
torch = attempt_import_or_raise("torch", "torch")
model = instructor_embedding.INSTRUCTOR(self.name)
if self.quantize:

View File

@@ -21,6 +21,7 @@ import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve
@@ -50,7 +51,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
open_clip = attempt_import_or_raise("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
@@ -78,14 +79,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
@@ -144,7 +145,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
@@ -152,7 +153,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
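
For context, the query-embedding method shown above (compute_query_embeddings)
accepts either a text query or a PIL image. A small sketch of both branches,
assuming the registry key "open-clip" and an on-disk image path:

from lancedb.embeddings import EmbeddingFunctionRegistry
import PIL.Image

clip = EmbeddingFunctionRegistry.get_instance().get("open-clip").create()

text_vecs = clip.compute_query_embeddings("a photo of a dog")        # str branch
img_vecs = clip.compute_query_embeddings(PIL.Image.open("dog.png"))  # PIL.Image branch
# anything else raises TypeError("OpenClip supports str or PIL Image as query")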

View File

@@ -16,6 +16,7 @@ from typing import List, Optional, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -68,7 +69,7 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
@cached_property
def _openai_client(self):
openai = self.safe_import("openai")
openai = attempt_import_or_raise("openai")
if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai")

View File

@@ -14,6 +14,7 @@ from typing import List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@@ -75,7 +76,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = self.safe_import(
sentence_transformers = attempt_import_or_raise(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
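
A fully local alternative with no API key, using the sentence-transformers
function above; the registry key, model name, and compute_source_embeddings
call are assumptions based on the embedding-function API:

from lancedb.embeddings import EmbeddingFunctionRegistry

st = EmbeddingFunctionRegistry.get_instance().get("sentence-transformers").create(
    name="BAAI/bge-small-en-v1.5", device="cpu"
)
vectors = st.compute_source_embeddings(["hello world", "goodbye world"])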

View File

@@ -4,7 +4,7 @@ from typing import Union
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -41,7 +41,7 @@ class CohereReranker(Reranker):
@cached_property
def _client(self):
cohere = safe_import("cohere")
cohere = attempt_import_or_raise("cohere")
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \

View File

@@ -2,7 +2,7 @@ from functools import cached_property
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -29,7 +29,9 @@ class ColbertReranker(Reranker):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = safe_import("torch") # import here for faster ops later
self.torch = attempt_import_or_raise(
"torch"
) # import here for faster ops later
def rerank_hybrid(
self,
@@ -80,7 +82,7 @@ class ColbertReranker(Reranker):
@cached_property
def _model(self):
transformers = safe_import("transformers")
transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)

View File

@@ -3,7 +3,7 @@ from typing import Union
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -32,7 +32,7 @@ class CrossEncoderReranker(Reranker):
return_score="relevance",
):
super().__init__(return_score)
torch = safe_import("torch")
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
@@ -41,7 +41,7 @@ class CrossEncoderReranker(Reranker):
@cached_property
def model(self):
sbert = safe_import("sentence_transformers")
sbert = attempt_import_or_raise("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
return cross_encoder
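
CrossEncoderReranker (like ColbertReranker above) runs locally, which is the
comparison the latency benchmark in this PR exercises against the API-based
rerankers. A sketch, reusing the table and FTS index from the example under
the commit message above:

from lancedb.rerankers import CrossEncoderReranker

hits = (
    table.search("greeting", query_type="hybrid")
    .rerank(reranker=CrossEncoderReranker())  # default local cross-encoder model
    .limit(5)
    .to_pandas()
)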

View File

@@ -5,7 +5,7 @@ from typing import Optional
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -17,7 +17,7 @@ class OpenaiReranker(Reranker):
Parameters
----------
model_name : str, default "gpt-3.5-turbo-1106 "
model_name : str, default "gpt-4-turbo-preview"
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
@@ -29,7 +29,7 @@ class OpenaiReranker(Reranker):
def __init__(
self,
model_name: str = "gpt-3.5-turbo-1106",
model_name: str = "gpt-4-turbo-preview",
column: str = "text",
return_score="relevance",
api_key: Optional[str] = None,
@@ -93,7 +93,9 @@ class OpenaiReranker(Reranker):
@cached_property
def _client(self):
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
openai = attempt_import_or_raise(
"openai"
) # TODO: force version or handle versions < 1.0
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
raise ValueError(
"OPENAI_API_KEY not set. Either set it in your environment or \

View File

@@ -116,7 +116,7 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import(module: str, mitigation=None):
def attempt_import_or_raise(module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
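
The hunk is truncated here; judging from the classmethod removed from the
EmbeddingFunction base class above, the renamed module-level helper presumably
keeps the same body (sketch; needs "import importlib" at the top of the module):

def attempt_import_or_raise(module: str, mitigation=None):
    """
    Import the specified module. If the module is not installed,
    raise an ImportError with a helpful message.

    Parameters
    ----------
    module : str
        The name of the module to import
    mitigation : Optional[str]
        The package(s) to install to mitigate the error.
        If not provided then the module name will be used.
    """
    try:
        return importlib.import_module(module)
    except ImportError:
        raise ImportError(f"Please install {mitigation or module}")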