feat: add answerdotai rerankers support and minor improvements (#1560)

This PR:
- Adds missing license headers
- Integrates with answerdotai Rerankers package
- Updates ColbertReranker to subclass answerdotai package. This is done
to keep backwards compatibility as some users might be used to importing
ColbertReranker directly
- Set `trust_remote_code` to ` True` by default in CrossEncoder and
sentence-transformer based rerankers
This commit is contained in:
Ayush Chaurasia
2024-08-26 13:25:10 +05:30
committed by GitHub
parent 632007d0e2
commit 549ca51a8a
14 changed files with 324 additions and 76 deletions

View File

@@ -26,12 +26,23 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
Parameters
----------
name: str, default "all-MiniLM-L6-v2"
The name of the model to use.
device: str, default "cpu"
The device to use for the model
normalize: bool, default True
Whether to normalize the embeddings
trust_remote_code: bool, default True
Whether to trust the remote code
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
trust_remote_code: bool = False
trust_remote_code: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -36,6 +36,10 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
The name of the model to use. This should be a model name that can be loaded
by transformers.AutoModel.from_pretrained. For example, "bert-base-uncased".
default: "colbert-ir/colbertv2.0""
device : str
The device to use for the model. Default is "cpu".
show_progress_bar : bool
Whether to show a progress bar when loading the model. Default is True.
to download package, run :
`pip install transformers`

View File

@@ -6,6 +6,7 @@ from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jinaai import JinaReranker
from .rrf import RRFReranker
from .answerdotai import AnswerdotaiRerankers
__all__ = [
"Reranker",
@@ -16,4 +17,5 @@ __all__ = [
"ColbertReranker",
"JinaReranker",
"RRFReranker",
"AnswerdotaiRerankers",
]

View File

@@ -0,0 +1,99 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
from .base import Reranker
from ..util import attempt_import_or_raise
class AnswerdotaiRerankers(Reranker):
"""
Reranks the results using the Answerdotai Rerank API.
All supported reranker model types can be found here:
- https://github.com/AnswerDotAI/rerankers
Parameters
----------
model_type : str, default "colbert"
The type of the model to use.
model_name : str, default "rerank-english-v2.0"
The name of the model to use from the given model type.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
"""
def __init__(
self,
model_type="colbert",
model_name: str = "answerdotai/answerai-colbert-small-v1",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.column = column
rerankers = attempt_import_or_raise(
"rerankers"
) # import here for faster ops later
self.reranker = rerankers.Reranker(model_name, model_type)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
doc_ids = list(range(len(docs)))
result = self.reranker.rank(query, docs, doc_ids=doc_ids)
# get the scores of each document in the same order as the input
scores = [result.get_result_by_docid(i).score for i in doc_ids]
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"Answerdotai Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from packaging.version import Version
from typing import Union, List, TYPE_CHECKING

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from packaging.version import Version
from functools import cached_property

View File

@@ -1,10 +1,20 @@
import pyarrow as pa
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..util import attempt_import_or_raise
from .base import Reranker
from .answerdotai import AnswerdotaiRerankers
class ColbertReranker(Reranker):
class ColbertReranker(AnswerdotaiRerankers):
"""
Reranks the results using the ColBERT model.
@@ -20,76 +30,13 @@ class ColbertReranker(Reranker):
def __init__(
self,
model_name: str = "colbert",
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
rerankers = attempt_import_or_raise(
"rerankers"
) # import here for faster ops later
self.colbert = rerankers.Reranker(self.model_name, model_type="colbert")
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
doc_ids = list(range(len(docs)))
result = self.colbert.rank(query, docs, doc_ids=doc_ids)
# get the scores of each document in the same order as the input
scores = [result.get_result_by_docid(i).score for i in doc_ids]
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
super().__init__(
model_type="colbert",
model_name=model_name,
column=column,
return_score=return_score,
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_score"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Union
@@ -22,6 +35,11 @@ class CrossEncoderReranker(Reranker):
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
trust_remote_code : bool, default True
If True, will trust the remote code to be safe. If False, will not trust
the remote code and will not run it
"""
def __init__(
@@ -30,12 +48,14 @@ class CrossEncoderReranker(Reranker):
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
trust_remote_code: bool = True,
):
super().__init__(return_score)
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
self.trust_remote_code = trust_remote_code
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -43,7 +63,11 @@ class CrossEncoderReranker(Reranker):
def model(self):
sbert = attempt_import_or_raise("sentence_transformers")
# Allows overriding the automatically selected device
cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device)
cross_encoder = sbert.CrossEncoder(
self.model_name,
device=self.device,
trust_remote_code=self.trust_remote_code,
)
return cross_encoder

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import requests
from functools import cached_property

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
from .base import Reranker

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from functools import cached_property

View File

@@ -1,3 +1,16 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, TYPE_CHECKING
import pyarrow as pa

View File

@@ -15,6 +15,7 @@ from lancedb.rerankers import (
CrossEncoderReranker,
OpenaiReranker,
JinaReranker,
AnswerdotaiRerankers,
)
from lancedb.table import LanceTable
@@ -254,12 +255,20 @@ def test_cross_encoder_reranker(tmp_path, use_tantivy):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_colbert_reranker(tmp_path, use_tantivy):
pytest.importorskip("transformers")
pytest.importorskip("rerankers")
reranker = ColbertReranker()
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_answerdotai_reranker(tmp_path, use_tantivy):
pytest.importorskip("rerankers")
reranker = AnswerdotaiRerankers()
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)