feat(python): hybrid search updates, examples, & latency benchmarks (#964)

- Rename safe_import -> attempt_import_or_raise (closes
https://github.com/lancedb/lancedb/pull/923)
- Update docs
- Add Notebook example (@changhiskhan you can use it for the talk. Comes
with "open in colab" button)
- Latency benchmark & results comparison, sanity check on real-world
data
- Updates the default OpenAI reranker model to gpt-4-turbo-preview
This commit is contained in:
Ayush Chaurasia
2024-02-13 17:58:39 +05:30
committed by Weston Pace
parent 1045af6c09
commit 510e8378bc
20 changed files with 1209 additions and 80 deletions

View File

@@ -90,7 +90,9 @@ nav:
- Building an ANN index: ann_indexes.md
- Vector Search: search.md
- Full-text search: fts.md
- Hybrid search: hybrid_search.md
- Hybrid search:
- hybrid_search/hybrid_search.md
- AirBNB financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md
@@ -151,7 +153,9 @@ nav:
- Building an ANN index: ann_indexes.md
- Vector Search: search.md
- Full-text search: fts.md
- Hybrid search: hybrid_search.md
- Hybrid search:
- hybrid_search/hybrid_search.md
- AirBNB financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md

View File

@@ -17,6 +17,7 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
```python
from lancedb.embeddings import register
from lancedb.util import attempt_import_or_raise
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
@@ -81,7 +82,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
open_clip = attempt_import_or_raise("open_clip", "open-clip") # lancedb.util helper that imports an external lib and raises if it is not found
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
@@ -109,14 +110,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
@@ -175,7 +176,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
@@ -183,7 +184,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):

View File

@@ -9,6 +9,9 @@ Contains the text embedding functions registered by default.
### Sentence transformers
Allows you to set parameters when registering a `sentence-transformers` object.
!!! info
Sentence transformer embeddings are normalized by default. It is recommended to use normalized embeddings for similarity search.
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |

View File

@@ -69,7 +69,7 @@ reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vec
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
### Arguments
----------------
* `weight`: `float`, default `0.7`:
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weight`.
@@ -91,9 +91,9 @@ reranker = CohereReranker()
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
### Arguments
----------------
* `model_name`` : str, default `"rerank-english-v2.0"``
* `model_name` : str, default `"rerank-english-v2.0"`
The name of the cross encoder model to use. Available cohere models are:
- rerank-english-v2.0
- rerank-multilingual-v2.0
@@ -117,7 +117,7 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
```
Arguments
### Arguments
----------------
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
@@ -143,7 +143,7 @@ reranker = ColbertReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
### Arguments
----------------
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
The name of the cross encoder model to use.
@@ -162,7 +162,8 @@ This reranker uses the OpenAI API to combine the results of semantic and full-te
This prompts a chat model, which is not a dedicated reranker model, to rerank the results. This should be treated as experimental.
!!! Tip
You might run out of token limit so set the search `limits` based on your token limit.
- You might exceed the token limit, so set the search `limits` based on your token limit.
- It is recommended to use gpt-4-turbo-preview (the default model); older models might lead to undesired behaviour.
```python
from lancedb.rerankers import OpenaiReranker
@@ -172,15 +173,15 @@ reranker = OpenaiReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
### Arguments
----------------
`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
The name of the cross encoder model to use.
`column` : `str`, default `"text"`
* `column` : `str`, default `"text"`
The name of the column to use as input to the cross encoder model.
`return_score` : `str`, default `"relevance"`
* `return_score` : `str`, default `"relevance"`
options are "relevance" or "all". Only "relevance" is supported for now.
`api_key` : `str`, default `None`
* `api_key` : `str`, default `None`
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
@@ -212,24 +213,30 @@ class MyReranker(Reranker):
```
You can also accept additional arguments like a filter along with fts and vector search results
### Example of a Custom Reranker
For the sake of simplicity, let's build a custom reranker that just enhances the Cohere Reranker by accepting a filter query, and accepts other CohereReranker params as kwargs.
```python
from lancedb.rerankers import Reranker
import pyarrow as pa
from typing import List, Union
import pandas as pd
from lancedb.rerankers import CohereReranker
class MyReranker(Reranker):
...
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
# Do something with the combined results & filter
# ...
class MofidifiedCohereReranker(CohereReranker):
def __init__(self, filters: Union[str, List[str]], **kwargs):
super().__init__(**kwargs)
filters = filters if isinstance(filters, list) else [filters]
self.filters = filters
# Return the combined results
return combined_result
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table)-> pa.Table:
combined_result = super().rerank_hybrid(query, vector_results, fts_results)
df = combined_result.to_pandas()
for filter in self.filters:
df = df.query("not text.str.contains(@filter)")
return pa.Table.from_pandas(df)
```
!!! tip
The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using the `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to a pyarrow table using the `pa.Table.from_pandas()` method and return it.

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,7 @@ excluded_globs = [
"../src/concepts/*.md",
"../src/ann_indexes.md",
"../src/basic.md",
"../src/hybrid_search.md",
"../src/hybrid_search/hybrid_search.md",
]
python_prefix = "py"

View File

@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from abc import ABC, abstractmethod
from typing import List, Union
@@ -91,25 +90,6 @@ class EmbeddingFunction(BaseModel, ABC):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION

View File

@@ -19,6 +19,7 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT
@@ -183,8 +184,8 @@ class BedRockText(TextEmbeddingFunction):
boto3.client
The boto3 client for Amazon Bedrock service
"""
botocore = self.safe_import("botocore")
boto3 = self.safe_import("boto3")
botocore = attempt_import_or_raise("botocore")
boto3 = attempt_import_or_raise("boto3")
session_kwargs = {"region_name": self.region}
client_kwargs = {**session_kwargs}

View File

@@ -16,6 +16,7 @@ from typing import ClassVar, List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -84,7 +85,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
return [emb for emb in rs.embeddings]
def _init_client(self):
cohere = self.safe_import("cohere")
cohere = attempt_import_or_raise("cohere")
if CohereEmbeddingFunction.client is None:
if os.environ.get("COHERE_API_KEY") is None:
api_key_not_found_help("cohere")

View File

@@ -19,6 +19,7 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, api_key_not_found_help
@@ -134,7 +135,7 @@ class GeminiText(TextEmbeddingFunction):
@cached_property
def client(self):
genai = self.safe_import("google.generativeai", "google.generativeai")
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
if not os.environ.get("GOOGLE_API_KEY"):
api_key_not_found_help("google")

View File

@@ -14,6 +14,7 @@ from typing import List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@@ -122,7 +123,7 @@ class GteEmbeddings(TextEmbeddingFunction):
return Model()
else:
sentence_transformers = self.safe_import(
sentence_transformers = attempt_import_or_raise(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(

View File

@@ -14,6 +14,7 @@ from typing import List
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, weak_lru
@@ -131,10 +132,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
@weak_lru(maxsize=1)
def get_model(self):
instructor_embedding = self.safe_import(
instructor_embedding = attempt_import_or_raise(
"InstructorEmbedding", "InstructorEmbedding"
)
torch = self.safe_import("torch", "torch")
torch = attempt_import_or_raise("torch", "torch")
model = instructor_embedding.INSTRUCTOR(self.name)
if self.quantize:

View File

@@ -21,6 +21,7 @@ import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve
@@ -50,7 +51,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
open_clip = attempt_import_or_raise("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
@@ -78,14 +79,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
@@ -144,7 +145,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
torch = attempt_import_or_raise("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
@@ -152,7 +153,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):

View File

@@ -16,6 +16,7 @@ from typing import List, Optional, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -68,7 +69,7 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
@cached_property
def _openai_client(self):
openai = self.safe_import("openai")
openai = attempt_import_or_raise("openai")
if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai")

View File

@@ -14,6 +14,7 @@ from typing import List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@@ -75,7 +76,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = self.safe_import(
sentence_transformers = attempt_import_or_raise(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(self.name, device=self.device)

View File

@@ -4,7 +4,7 @@ from typing import Union
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -41,7 +41,7 @@ class CohereReranker(Reranker):
@cached_property
def _client(self):
cohere = safe_import("cohere")
cohere = attempt_import_or_raise("cohere")
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \

View File

@@ -2,7 +2,7 @@ from functools import cached_property
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -29,7 +29,9 @@ class ColbertReranker(Reranker):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = safe_import("torch") # import here for faster ops later
self.torch = attempt_import_or_raise(
"torch"
) # import here for faster ops later
def rerank_hybrid(
self,
@@ -80,7 +82,7 @@ class ColbertReranker(Reranker):
@cached_property
def _model(self):
transformers = safe_import("transformers")
transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)

View File

@@ -3,7 +3,7 @@ from typing import Union
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -32,7 +32,7 @@ class CrossEncoderReranker(Reranker):
return_score="relevance",
):
super().__init__(return_score)
torch = safe_import("torch")
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
@@ -41,7 +41,7 @@ class CrossEncoderReranker(Reranker):
@cached_property
def model(self):
sbert = safe_import("sentence_transformers")
sbert = attempt_import_or_raise("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
return cross_encoder

View File

@@ -5,7 +5,7 @@ from typing import Optional
import pyarrow as pa
from ..util import safe_import
from ..util import attempt_import_or_raise
from .base import Reranker
@@ -17,7 +17,7 @@ class OpenaiReranker(Reranker):
Parameters
----------
model_name : str, default "gpt-3.5-turbo-1106 "
model_name : str, default "gpt-4-turbo-preview"
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
@@ -29,7 +29,7 @@ class OpenaiReranker(Reranker):
def __init__(
self,
model_name: str = "gpt-3.5-turbo-1106",
model_name: str = "gpt-4-turbo-preview",
column: str = "text",
return_score="relevance",
api_key: Optional[str] = None,
@@ -93,7 +93,9 @@ class OpenaiReranker(Reranker):
@cached_property
def _client(self):
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
openai = attempt_import_or_raise(
"openai"
) # TODO: force version or handle versions < 1.0
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
raise ValueError(
"OPENAI_API_KEY not set. Either set it in your environment or \

View File

@@ -116,7 +116,7 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import(module: str, mitigation=None):
def attempt_import_or_raise(module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.