mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: add answerdotai rerankers support and minor improvements (#1560)
This PR: - Adds missing license headers - Integrates with answerdotai Rerankers package - Updates ColbertReranker to subclass answerdotai package. This is done to keep backwards compatibility as some users might be used to importing ColbertReranker directly - Set `trust_remote_code` to ` True` by default in CrossEncoder and sentence-transformer based rerankers
This commit is contained in:
@@ -26,12 +26,23 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "all-MiniLM-L6-v2"
|
||||
The name of the model to use.
|
||||
device: str, default "cpu"
|
||||
The device to use for the model
|
||||
normalize: bool, default True
|
||||
Whether to normalize the embeddings
|
||||
trust_remote_code: bool, default True
|
||||
Whether to trust the remote code
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
trust_remote_code: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -36,6 +36,10 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
The name of the model to use. This should be a model name that can be loaded
|
||||
by transformers.AutoModel.from_pretrained. For example, "bert-base-uncased".
|
||||
default: "colbert-ir/colbertv2.0""
|
||||
device : str
|
||||
The device to use for the model. Default is "cpu".
|
||||
show_progress_bar : bool
|
||||
Whether to show a progress bar when loading the model. Default is True.
|
||||
|
||||
to download package, run :
|
||||
`pip install transformers`
|
||||
|
||||
@@ -6,6 +6,7 @@ from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
"ColbertReranker",
|
||||
"JinaReranker",
|
||||
"RRFReranker",
|
||||
"AnswerdotaiRerankers",
|
||||
]
|
||||
|
||||
99
python/python/lancedb/rerankers/answerdotai.py
Normal file
99
python/python/lancedb/rerankers/answerdotai.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pyarrow as pa
|
||||
from .base import Reranker
|
||||
from ..util import attempt_import_or_raise
|
||||
|
||||
|
||||
class AnswerdotaiRerankers(Reranker):
|
||||
"""
|
||||
Reranks the results using the Answerdotai Rerank API.
|
||||
All supported reranker model types can be found here:
|
||||
- https://github.com/AnswerDotAI/rerankers
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_type : str, default "colbert"
|
||||
The type of the model to use.
|
||||
model_name : str, default "rerank-english-v2.0"
|
||||
The name of the model to use from the given model type.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="colbert",
|
||||
model_name: str = "answerdotai/answerai-colbert-small-v1",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.column = column
|
||||
rerankers = attempt_import_or_raise(
|
||||
"rerankers"
|
||||
) # import here for faster ops later
|
||||
self.reranker = rerankers.Reranker(model_name, model_type)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
doc_ids = list(range(len(docs)))
|
||||
result = self.reranker.rank(query, docs, doc_ids=doc_ids)
|
||||
|
||||
# get the scores of each document in the same order as the input
|
||||
scores = [result.get_result_by_docid(i).score for i in doc_ids]
|
||||
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"Answerdotai Reranker does not support score='all' yet"
|
||||
)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return fts_results
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from packaging.version import Version
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from packaging.version import Version
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import pyarrow as pa
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
|
||||
|
||||
class ColbertReranker(Reranker):
|
||||
class ColbertReranker(AnswerdotaiRerankers):
|
||||
"""
|
||||
Reranks the results using the ColBERT model.
|
||||
|
||||
@@ -20,76 +30,13 @@ class ColbertReranker(Reranker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "colbert",
|
||||
model_name: str = "colbert-ir/colbertv2.0",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
rerankers = attempt_import_or_raise(
|
||||
"rerankers"
|
||||
) # import here for faster ops later
|
||||
self.colbert = rerankers.Reranker(self.model_name, model_type="colbert")
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
doc_ids = list(range(len(docs)))
|
||||
result = self.colbert.rank(query, docs, doc_ids=doc_ids)
|
||||
|
||||
# get the scores of each document in the same order as the input
|
||||
scores = [result.get_result_by_docid(i).score for i in doc_ids]
|
||||
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
super().__init__(
|
||||
model_type="colbert",
|
||||
model_name=model_name,
|
||||
column=column,
|
||||
return_score=return_score,
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -22,6 +35,11 @@ class CrossEncoderReranker(Reranker):
|
||||
device : str, default None
|
||||
The device to use for the cross encoder model. If None, will use "cuda"
|
||||
if available, otherwise "cpu".
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
trust_remote_code : bool, default True
|
||||
If True, will trust the remote code to be safe. If False, will not trust
|
||||
the remote code and will not run it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -30,12 +48,14 @@ class CrossEncoderReranker(Reranker):
|
||||
column: str = "text",
|
||||
device: Union[str, None] = None,
|
||||
return_score="relevance",
|
||||
trust_remote_code: bool = True,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
self.trust_remote_code = trust_remote_code
|
||||
if self.device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -43,7 +63,11 @@ class CrossEncoderReranker(Reranker):
|
||||
def model(self):
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
# Allows overriding the automatically selected device
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device)
|
||||
cross_encoder = sbert.CrossEncoder(
|
||||
self.model_name,
|
||||
device=self.device,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
return cross_encoder
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import requests
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from lancedb.rerankers import (
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
AnswerdotaiRerankers,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -254,12 +255,20 @@ def test_cross_encoder_reranker(tmp_path, use_tantivy):
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_colbert_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("transformers")
|
||||
pytest.importorskip("rerankers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_answerdotai_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("rerankers")
|
||||
reranker = AnswerdotaiRerankers()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user