From 549ca51a8aa7040ff555dc67228f05f82ebc8202 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 26 Aug 2024 13:25:10 +0530 Subject: [PATCH] feat: add answerdotai rerankers support and minor improvements (#1560) This PR: - Adds missing license headers - Integrates with answerdotai Rerankers package - Updates ColbertReranker to subclass answerdotai package. This is done to keep backwards compatibility as some users might be used to importing ColbertReranker directly - Set `trust_remote_code` to `True` by default in CrossEncoder and sentence-transformer based rerankers --- docs/src/reranking/answerdotai.md | 74 ++++++++++++++ .../embeddings/sentence_transformers.py | 13 ++- .../python/lancedb/embeddings/transformers.py | 4 + python/python/lancedb/rerankers/__init__.py | 2 + .../python/lancedb/rerankers/answerdotai.py | 99 +++++++++++++++++++ python/python/lancedb/rerankers/base.py | 13 +++ python/python/lancedb/rerankers/cohere.py | 13 +++ python/python/lancedb/rerankers/colbert.py | 93 ++++------------- .../python/lancedb/rerankers/cross_encoder.py | 26 ++++- python/python/lancedb/rerankers/jinaai.py | 13 +++ .../lancedb/rerankers/linear_combination.py | 13 +++ python/python/lancedb/rerankers/openai.py | 13 +++ python/python/lancedb/rerankers/rrf.py | 13 +++ python/python/tests/test_rerankers.py | 11 ++- 14 files changed, 324 insertions(+), 76 deletions(-) create mode 100644 docs/src/reranking/answerdotai.md create mode 100644 python/python/lancedb/rerankers/answerdotai.py diff --git a/docs/src/reranking/answerdotai.md b/docs/src/reranking/answerdotai.md new file mode 100644 index 00000000..b19f24ff --- /dev/null +++ b/docs/src/reranking/answerdotai.md @@ -0,0 +1,74 @@ +# AnswerDotAI Rerankers + +This integration allows using AnswerDotAI's rerankers to rerank the search results. [Rerankers](https://github.com/AnswerDotAI/rerankers) is +a lightweight, low-dependency, unified API to use all common reranking and cross-encoder models. + +!!! 
note + Supported Query Types: Hybrid, Vector, FTS + + +```python +import numpy +import lancedb +from lancedb.embeddings import get_registry +from lancedb.pydantic import LanceModel, Vector +from lancedb.rerankers import AnswerdotaiRerankers + +embedder = get_registry().get("sentence-transformers").create() +db = lancedb.connect("~/.lancedb") + +class Schema(LanceModel): + text: str = embedder.SourceField() + vector: Vector(embedder.ndims()) = embedder.VectorField() + +data = [ + {"text": "hello world"}, + {"text": "goodbye world"} + ] +tbl = db.create_table("test", schema=Schema, mode="overwrite") +tbl.add(data) +reranker = AnswerdotaiRerankers() + +# Run vector search with a reranker +result = tbl.search("hello").rerank(reranker=reranker).to_list() + +# Run FTS search with a reranker +result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list() + +# Run hybrid search with a reranker +tbl.create_fts_index("text", replace=True) +result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list() + +``` + +Accepted Arguments +---------------- +| Argument | Type | Default | Description | +| --- | --- | --- | --- | +| `model_type` | `str` | `"colbert"` | The type of model to use. Supported model types can be found here - https://github.com/AnswerDotAI/rerankers | +| `model_name` | `str` | `"answerdotai/answerai-colbert-small-v1"` | The name of the reranker model to use. | +| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. | +| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type | + + + +## Supported Scores for each query type +You can specify the type of scores you want the reranker to return. 
The following are the supported scores for each query type: + +### Hybrid Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ❌ Not Supported | Results have vector(`_distance`) and FTS(`_score`) along with Hybrid Search score(`_relevance_score`) | + +### Vector Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ✅ Supported | Results have vector(`_distance`) along with the reranker score(`_relevance_score`) | + +### FTS Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Results only have the `_relevance_score` column | +| `all` | ✅ Supported | Results have FTS(`_score`) along with the reranker score(`_relevance_score`) | \ No newline at end of file diff --git a/python/python/lancedb/embeddings/sentence_transformers.py b/python/python/lancedb/embeddings/sentence_transformers.py index fe8e997d..b0ef1d50 100644 --- a/python/python/lancedb/embeddings/sentence_transformers.py +++ b/python/python/lancedb/embeddings/sentence_transformers.py @@ -26,12 +26,23 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): An embedding function that uses the sentence-transformers library https://huggingface.co/sentence-transformers + + Parameters + ---------- + name: str, default "all-MiniLM-L6-v2" + The name of the model to use. 
+ device: str, default "cpu" + The device to use for the model + normalize: bool, default True + Whether to normalize the embeddings + trust_remote_code: bool, default True + Whether to trust the remote code """ name: str = "all-MiniLM-L6-v2" device: str = "cpu" normalize: bool = True - trust_remote_code: bool = False + trust_remote_code: bool = True def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/python/python/lancedb/embeddings/transformers.py b/python/python/lancedb/embeddings/transformers.py index dba5b161..f532f7c9 100644 --- a/python/python/lancedb/embeddings/transformers.py +++ b/python/python/lancedb/embeddings/transformers.py @@ -36,6 +36,10 @@ class TransformersEmbeddingFunction(EmbeddingFunction): The name of the model to use. This should be a model name that can be loaded by transformers.AutoModel.from_pretrained. For example, "bert-base-uncased". default: "colbert-ir/colbertv2.0"" + device : str + The device to use for the model. Default is "cpu". + show_progress_bar : bool + Whether to show a progress bar when loading the model. Default is True. to download package, run : `pip install transformers` diff --git a/python/python/lancedb/rerankers/__init__.py b/python/python/lancedb/rerankers/__init__.py index 0b767a67..93903a16 100644 --- a/python/python/lancedb/rerankers/__init__.py +++ b/python/python/lancedb/rerankers/__init__.py @@ -6,6 +6,7 @@ from .linear_combination import LinearCombinationReranker from .openai import OpenaiReranker from .jinaai import JinaReranker from .rrf import RRFReranker +from .answerdotai import AnswerdotaiRerankers __all__ = [ "Reranker", @@ -16,4 +17,5 @@ __all__ = [ "ColbertReranker", "JinaReranker", "RRFReranker", + "AnswerdotaiRerankers", ] diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py new file mode 100644 index 00000000..3c2fcb2d --- /dev/null +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023. 
LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pyarrow as pa +from .base import Reranker +from ..util import attempt_import_or_raise + + +class AnswerdotaiRerankers(Reranker): + """ + Reranks the results using the AnswerDotAI rerankers library. + All supported reranker model types can be found here: + - https://github.com/AnswerDotAI/rerankers + + + Parameters + ---------- + model_type : str, default "colbert" + The type of the model to use. + model_name : str, default "answerdotai/answerai-colbert-small-v1" + The name of the model to use from the given model type. + column : str, default "text" + The name of the column to use as input to the cross encoder model. + return_score : str, default "relevance" + options are "relevance" or "all". "all" is not supported for hybrid search. 
+ """ + + def __init__( + self, + model_type="colbert", + model_name: str = "answerdotai/answerai-colbert-small-v1", + column: str = "text", + return_score="relevance", + ): + super().__init__(return_score) + self.column = column + rerankers = attempt_import_or_raise( + "rerankers" + ) # import here for faster ops later + self.reranker = rerankers.Reranker(model_name, model_type) + + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() + doc_ids = list(range(len(docs))) + result = self.reranker.rank(query, docs, doc_ids=doc_ids) + + # get the scores of each document in the same order as the input + scores = [result.get_result_by_docid(i).score for i in doc_ids] + + # add the scores + result_set = result_set.append_column( + "_relevance_score", pa.array(scores, type=pa.float32()) + ) + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) + if self.score == "relevance": + combined_results = self._keep_relevance_score(combined_results) + elif self.score == "all": + raise NotImplementedError( + "Answerdotai Reranker does not support score='all' yet" + ) + combined_results = combined_results.sort_by( + [("_relevance_score", "descending")] + ) + return combined_results + + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + return vector_results + + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["_score"]) + + fts_results = fts_results.sort_by([("_relevance_score", 
"descending")]) + + return fts_results diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 8667ca9c..65ed43e7 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod from packaging.version import Version from typing import Union, List, TYPE_CHECKING diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index e4a12dbf..5cf7e8f0 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os from packaging.version import Version from functools import cached_property diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index 5e8701b3..cffdd0ba 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -1,10 +1,20 @@ -import pyarrow as pa +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from ..util import attempt_import_or_raise -from .base import Reranker +from .answerdotai import AnswerdotaiRerankers -class ColbertReranker(Reranker): +class ColbertReranker(AnswerdotaiRerankers): """ Reranks the results using the ColBERT model. 
@@ -20,76 +30,13 @@ class ColbertReranker(Reranker): def __init__( self, - model_name: str = "colbert", + model_name: str = "colbert-ir/colbertv2.0", column: str = "text", return_score="relevance", ): - super().__init__(return_score) - self.model_name = model_name - self.column = column - rerankers = attempt_import_or_raise( - "rerankers" - ) # import here for faster ops later - self.colbert = rerankers.Reranker(self.model_name, model_type="colbert") - - def _rerank(self, result_set: pa.Table, query: str): - docs = result_set[self.column].to_pylist() - doc_ids = list(range(len(docs))) - result = self.colbert.rank(query, docs, doc_ids=doc_ids) - - # get the scores of each document in the same order as the input - scores = [result.get_result_by_docid(i).score for i in doc_ids] - - # add the scores - result_set = result_set.append_column( - "_relevance_score", pa.array(scores, type=pa.float32()) + super().__init__( + model_type="colbert", + model_name=model_name, + column=column, + return_score=return_score, ) - - return result_set - - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - combined_results = self._rerank(combined_results, query) - if self.score == "relevance": - combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "OpenAI Reranker does not support score='all' yet" - ) - - combined_results = combined_results.sort_by( - [("_relevance_score", "descending")] - ) - - return combined_results - - def rerank_vector( - self, - query: str, - vector_results: pa.Table, - ): - result_set = self._rerank(vector_results, query) - if self.score == "relevance": - result_set = result_set.drop_columns(["_distance"]) - - result_set = result_set.sort_by([("_relevance_score", "descending")]) - - return result_set - - def rerank_fts( - self, - query: str, - fts_results: pa.Table, - ): - 
result_set = self._rerank(fts_results, query) - if self.score == "relevance": - result_set = result_set.drop_columns(["_score"]) - - result_set = result_set.sort_by([("_relevance_score", "descending")]) - - return result_set diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index 05673673..6a6cb2bd 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import cached_property from typing import Union @@ -22,6 +35,11 @@ class CrossEncoderReranker(Reranker): device : str, default None The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu". + return_score : str, default "relevance" + options are "relevance" or "all". Only "relevance" is supported for now. + trust_remote_code : bool, default True + If True, will trust the remote code to be safe. 
If False, will not trust + the remote code and will not run it """ def __init__( @@ -30,12 +48,14 @@ class CrossEncoderReranker(Reranker): column: str = "text", device: Union[str, None] = None, return_score="relevance", + trust_remote_code: bool = True, ): super().__init__(return_score) torch = attempt_import_or_raise("torch") self.model_name = model_name self.column = column self.device = device + self.trust_remote_code = trust_remote_code if self.device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -43,7 +63,11 @@ class CrossEncoderReranker(Reranker): def model(self): sbert = attempt_import_or_raise("sentence_transformers") # Allows overriding the automatically selected device - cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device) + cross_encoder = sbert.CrossEncoder( + self.model_name, + device=self.device, + trust_remote_code=self.trust_remote_code, + ) return cross_encoder diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 4d4edcfb..6be646bd 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import requests from functools import cached_property diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 3d7dcc25..6ab18427 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pyarrow as pa from .base import Reranker diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 7e6c19b2..76fe8e4c 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import json import os from functools import cached_property diff --git a/python/python/lancedb/rerankers/rrf.py b/python/python/lancedb/rerankers/rrf.py index 23ed1dc1..e0c95b48 100644 --- a/python/python/lancedb/rerankers/rrf.py +++ b/python/python/lancedb/rerankers/rrf.py @@ -1,3 +1,16 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Union, List, TYPE_CHECKING import pyarrow as pa diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 442328d9..fca0850c 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -15,6 +15,7 @@ from lancedb.rerankers import ( CrossEncoderReranker, OpenaiReranker, JinaReranker, + AnswerdotaiRerankers, ) from lancedb.table import LanceTable @@ -254,12 +255,20 @@ def test_cross_encoder_reranker(tmp_path, use_tantivy): @pytest.mark.parametrize("use_tantivy", [True, False]) def test_colbert_reranker(tmp_path, use_tantivy): - pytest.importorskip("transformers") + pytest.importorskip("rerankers") reranker = ColbertReranker() table, schema = get_test_table(tmp_path, use_tantivy) _run_test_reranker(reranker, table, "single player experience", None, schema) +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_answerdotai_reranker(tmp_path, use_tantivy): + pytest.importorskip("rerankers") + reranker = AnswerdotaiRerankers() + table, schema = get_test_table(tmp_path, use_tantivy) + 
_run_test_reranker(reranker, table, "single player experience", None, schema) + + @pytest.mark.skipif( os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" )