mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
feat(python): hybrid search updates, examples, & latency benchmarks (#964)
- Rename safe_import -> attempt_import_or_raise (closes https://github.com/lancedb/lancedb/pull/923) - Update docs - Add Notebook example (@changhiskhan you can use it for the talk. Comes with "open in colab" button) - Latency benchmark & results comparison, sanity check on real-world data - Updates the default openai model to gpt-4
This commit is contained in:
committed by
Weston Pace
parent
1045af6c09
commit
510e8378bc
@@ -90,7 +90,9 @@ nav:
|
||||
- Building an ANN index: ann_indexes.md
|
||||
- Vector Search: search.md
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search: hybrid_search.md
|
||||
- Hybrid search:
|
||||
- hybrid_search/hybrid_search.md
|
||||
- AirBNB financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- Configuring Storage: guides/storage.md
|
||||
@@ -151,7 +153,9 @@ nav:
|
||||
- Building an ANN index: ann_indexes.md
|
||||
- Vector Search: search.md
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search: hybrid_search.md
|
||||
- Hybrid search:
|
||||
- hybrid_search/hybrid_search.md
|
||||
- AirBNB financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- Configuring Storage: guides/storage.md
|
||||
|
||||
@@ -17,6 +17,7 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
|
||||
|
||||
```python
|
||||
from lancedb.embeddings import register
|
||||
from lancedb.util import attempt_import_or_raise
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
@@ -81,7 +82,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
||||
open_clip = attempt_import_or_raise("open_clip", "open-clip") # lancedb.util helper that imports external libs and raises if not found
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
@@ -109,14 +110,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
@@ -175,7 +176,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
@@ -183,7 +184,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
|
||||
@@ -9,6 +9,9 @@ Contains the text embedding functions registered by default.
|
||||
### Sentence transformers
|
||||
Allows you to set parameters when registering a `sentence-transformers` object.
|
||||
|
||||
!!! info
|
||||
Sentence transformer embeddings are normalized by default. It is recommended to use normalized embeddings for similarity search.
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|---|---|
|
||||
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
||||
|
||||
@@ -69,7 +69,7 @@ reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vec
|
||||
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `weight`: `float`, default `0.7`:
|
||||
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weight`.
|
||||
@@ -91,9 +91,9 @@ reranker = CohereReranker()
|
||||
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `model_name`` : str, default `"rerank-english-v2.0"``
|
||||
* `model_name` : str, default `"rerank-english-v2.0"`
|
||||
The name of the cross encoder model to use. Available cohere models are:
|
||||
- rerank-english-v2.0
|
||||
- rerank-multilingual-v2.0
|
||||
@@ -117,7 +117,7 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
|
||||
```
|
||||
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
|
||||
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
|
||||
@@ -143,7 +143,7 @@ reranker = ColbertReranker()
|
||||
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
|
||||
The name of the cross encoder model to use.
|
||||
@@ -162,7 +162,8 @@ This reranker uses the OpenAI API to combine the results of semantic and full-te
|
||||
This prompts a chat model to rerank the results; it is not a dedicated reranker model, so it should be treated as experimental.
|
||||
|
||||
!!! Tip
|
||||
You might run out of token limit so set the search `limits` based on your token limit.
|
||||
- You might exceed the model's token limit, so set the search `limits` based on your token limit.
|
||||
- It is recommended to use gpt-4-turbo-preview, the default model; older models might lead to undesired behaviour.
|
||||
|
||||
```python
|
||||
from lancedb.rerankers import OpenaiReranker
|
||||
@@ -172,15 +173,15 @@ reranker = OpenaiReranker()
|
||||
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
|
||||
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
|
||||
The name of the cross encoder model to use.
|
||||
`column` : `str`, default `"text"`
|
||||
* `column` : `str`, default `"text"`
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
`return_score` : `str`, default `"relevance"`
|
||||
* `return_score` : `str`, default `"relevance"`
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
`api_key` : `str`, default `None`
|
||||
* `api_key` : `str`, default `None`
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
|
||||
|
||||
@@ -212,24 +213,30 @@ class MyReranker(Reranker):
|
||||
|
||||
```
|
||||
|
||||
You can also accept additional arguments like a filter along with fts and vector search results
|
||||
### Example of a Custom Reranker
|
||||
For the sake of simplicity, let's build a custom reranker that just enhances the Cohere Reranker by accepting a filter query, and accepts other CohereReranker params as kwargs.
|
||||
|
||||
```python
|
||||
|
||||
from lancedb.rerankers import Reranker
|
||||
import pyarrow as pa
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
from lancedb.rerankers import CohereReranker
|
||||
|
||||
class MyReranker(Reranker):
|
||||
...
|
||||
|
||||
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
|
||||
# Use the built-in merging function
|
||||
combined_result = self.merge_results(vector_results, fts_results)
|
||||
|
||||
# Do something with the combined results & filter
|
||||
# ...
|
||||
class MofidifiedCohereReranker(CohereReranker):
|
||||
def __init__(self, filters: Union[str, List[str]], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
filters = filters if isinstance(filters, list) else [filters]
|
||||
self.filters = filters
|
||||
|
||||
# Return the combined results
|
||||
return combined_result
|
||||
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table)-> pa.Table:
|
||||
combined_result = super().rerank_hybrid(query, vector_results, fts_results)
|
||||
df = combined_result.to_pandas()
|
||||
for filter in self.filters:
|
||||
df = df.query("not text.str.contains(@filter)")
|
||||
|
||||
return pa.Table.from_pandas(df)
|
||||
|
||||
```
|
||||
|
||||
!!! tip
|
||||
The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using the `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to a pyarrow table using the `pa.Table.from_pandas()` method and return it.
|
||||
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,7 @@ excluded_globs = [
|
||||
"../src/concepts/*.md",
|
||||
"../src/ann_indexes.md",
|
||||
"../src/basic.md",
|
||||
"../src/hybrid_search.md",
|
||||
"../src/hybrid_search/hybrid_search.md",
|
||||
]
|
||||
|
||||
python_prefix = "py"
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
@@ -91,25 +90,6 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT
|
||||
@@ -183,8 +184,8 @@ class BedRockText(TextEmbeddingFunction):
|
||||
boto3.client
|
||||
The boto3 client for Amazon Bedrock service
|
||||
"""
|
||||
botocore = self.safe_import("botocore")
|
||||
boto3 = self.safe_import("boto3")
|
||||
botocore = attempt_import_or_raise("botocore")
|
||||
boto3 = attempt_import_or_raise("boto3")
|
||||
|
||||
session_kwargs = {"region_name": self.region}
|
||||
client_kwargs = {**session_kwargs}
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import ClassVar, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
@@ -84,7 +85,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
return [emb for emb in rs.embeddings]
|
||||
|
||||
def _init_client(self):
|
||||
cohere = self.safe_import("cohere")
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
if CohereEmbeddingFunction.client is None:
|
||||
if os.environ.get("COHERE_API_KEY") is None:
|
||||
api_key_not_found_help("cohere")
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, api_key_not_found_help
|
||||
@@ -134,7 +135,7 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
genai = self.safe_import("google.generativeai", "google.generativeai")
|
||||
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
||||
|
||||
if not os.environ.get("GOOGLE_API_KEY"):
|
||||
api_key_not_found_help("google")
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
@@ -122,7 +123,7 @@ class GteEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
return Model()
|
||||
else:
|
||||
sentence_transformers = self.safe_import(
|
||||
sentence_transformers = attempt_import_or_raise(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, weak_lru
|
||||
@@ -131,10 +132,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def get_model(self):
|
||||
instructor_embedding = self.safe_import(
|
||||
instructor_embedding = attempt_import_or_raise(
|
||||
"InstructorEmbedding", "InstructorEmbedding"
|
||||
)
|
||||
torch = self.safe_import("torch", "torch")
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
|
||||
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||
if self.quantize:
|
||||
|
||||
@@ -21,6 +21,7 @@ import pyarrow as pa
|
||||
from pydantic import PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import IMAGES, url_retrieve
|
||||
@@ -50,7 +51,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
open_clip = attempt_import_or_raise("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
@@ -78,14 +79,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
@@ -144,7 +145,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
@@ -152,7 +153,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
@@ -68,7 +69,7 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
@cached_property
|
||||
def _openai_client(self):
|
||||
openai = self.safe_import("openai")
|
||||
openai = attempt_import_or_raise("openai")
|
||||
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
api_key_not_found_help("openai")
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
@@ -75,7 +76,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = self.safe_import(
|
||||
sentence_transformers = attempt_import_or_raise(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class CohereReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
cohere = safe_import("cohere")
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"COHERE_API_KEY not set. Either set it in your environment or \
|
||||
|
||||
@@ -2,7 +2,7 @@ from functools import cached_property
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
@@ -29,7 +29,9 @@ class ColbertReranker(Reranker):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.torch = safe_import("torch") # import here for faster ops later
|
||||
self.torch = attempt_import_or_raise(
|
||||
"torch"
|
||||
) # import here for faster ops later
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
@@ -80,7 +82,7 @@ class ColbertReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _model(self):
|
||||
transformers = safe_import("transformers")
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
||||
model = transformers.AutoModel.from_pretrained(self.model_name)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class CrossEncoderReranker(Reranker):
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
@@ -41,7 +41,7 @@ class CrossEncoderReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
sbert = safe_import("sentence_transformers")
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name)
|
||||
|
||||
return cross_encoder
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class OpenaiReranker(Reranker):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "gpt-3.5-turbo-1106 "
|
||||
model_name : str, default "gpt-4-turbo-preview"
|
||||
The name of the cross encoder model to use.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
@@ -29,7 +29,7 @@ class OpenaiReranker(Reranker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-3.5-turbo-1106",
|
||||
model_name: str = "gpt-4-turbo-preview",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
api_key: Optional[str] = None,
|
||||
@@ -93,7 +93,9 @@ class OpenaiReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
|
||||
openai = attempt_import_or_raise(
|
||||
"openai"
|
||||
) # TODO: force version or handle versions < 1.0
|
||||
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY not set. Either set it in your environment or \
|
||||
|
||||
@@ -116,7 +116,7 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
||||
|
||||
|
||||
def safe_import(module: str, mitigation=None):
|
||||
def attempt_import_or_raise(module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Reference in New Issue
Block a user