diff --git a/python/python/lancedb/rerankers/__init__.py b/python/python/lancedb/rerankers/__init__.py index af833fd7..bbc5bbcf 100644 --- a/python/python/lancedb/rerankers/__init__.py +++ b/python/python/lancedb/rerankers/__init__.py @@ -4,6 +4,7 @@ from .colbert import ColbertReranker from .cross_encoder import CrossEncoderReranker from .linear_combination import LinearCombinationReranker from .openai import OpenaiReranker +from .jina import JinaReranker __all__ = [ "Reranker", @@ -12,4 +13,5 @@ __all__ = [ "LinearCombinationReranker", "OpenaiReranker", "ColbertReranker", + "JinaReranker", ] diff --git a/python/python/lancedb/rerankers/jina.py b/python/python/lancedb/rerankers/jina.py new file mode 100644 index 00000000..3af17a98 --- /dev/null +++ b/python/python/lancedb/rerankers/jina.py @@ -0,0 +1,103 @@ +from functools import cached_property +from typing import Union + +import pyarrow as pa + +from ..util import attempt_import_or_raise +from .base import Reranker + + +class JinaReranker(Reranker): + """ + Reranks the results using Jina reranker model. + + Parameters + ---------- + model_name : str, default "jinaai/jina-reranker-v1-turbo-en" + The name of the reranker to use. For all models, see + https://huggingface.co/jinaai/jina-reranker-v1-turbo-en + column : str, default "text" + The name of the column to use as input to the cross encoder model. + device : str, default None + The device to use for the cross encoder model. If None, will use "cuda" + if available, otherwise "cpu". + """ + + def __init__( + self, + model_name: str = "jinaai/jina-reranker-v1-turbo-en", + column: str = "text", + device: Union[str, None] = None, + return_score="relevance", + ): + super().__init__(return_score) + torch = attempt_import_or_raise("torch") + self.model_name = model_name + self.column = column + self.device = device + if self.device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + @cached_property + def model(self): + transformers = attempt_import_or_raise("transformers") + model = transformers.AutoModelForSequenceClassification.from_pretrained( + self.model_name, num_labels=1, trust_remote_code=True + ) + + return model + + def _rerank(self, result_set: pa.Table, query: str): + passages = result_set[self.column].to_pylist() + cross_inp = [[query, passage] for passage in passages] + cross_scores = self.model.compute_score(cross_inp) + result_set = result_set.append_column( + "_relevance_score", pa.array(cross_scores, type=pa.float32()) + ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) + # sort the results by _score + if self.score == "relevance": + combined_results = combined_results.drop_columns(["score", "_distance"]) + elif self.score == "all": + raise NotImplementedError( + "return_score='all' not implemented for CrossEncoderReranker" + ) + combined_results = combined_results.sort_by( + [("_relevance_score", "descending")] + ) + + return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + return vector_results + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["score"]) + + fts_results = fts_results.sort_by([("_relevance_score", "descending")]) + return fts_results diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 7775d598..645b15ec 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -1,5 +1,3 @@ -import os - import lancedb import numpy as np import pytest @@ -11,6 +9,7 @@ from lancedb.rerankers import ( ColbertReranker, CrossEncoderReranker, OpenaiReranker, + JinaReranker, ) from lancedb.table import LanceTable @@ -119,136 +118,18 @@ def test_linear_combination(tmp_path): ) -@pytest.mark.skipif( - os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" +@pytest.mark.slow +@pytest.mark.parametrize( + "reranker", + [ + ColbertReranker(), + OpenaiReranker(), + CohereReranker(), + CrossEncoderReranker(), + JinaReranker(), + ], ) -def test_cohere_reranker(tmp_path): - pytest.importorskip("cohere") - reranker = CohereReranker() - table, schema = get_test_table(tmp_path) - # Hybrid search setting - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=CohereReranker()) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=reranker) - .to_pydantic(schema) - ) - assert result1 == result2 - - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query)) - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - query = "Our father who art in heaven" - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - -def test_cross_encoder_reranker(tmp_path): - pytest.importorskip("sentence_transformers") - reranker = CrossEncoderReranker() - table, schema = get_test_table(tmp_path) - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=reranker) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=reranker) - .to_pydantic(schema) - ) - assert result1 == result2 - - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query), query_type="hybrid") - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - -def test_colbert_reranker(tmp_path): - pytest.importorskip("transformers") - reranker = ColbertReranker() +def test_colbert_reranker(tmp_path, reranker): table, schema = get_test_table(tmp_path) result1 = ( table.search("Our father who art in heaven", query_type="hybrid") @@ -305,67 +186,3 @@ def test_colbert_reranker(tmp_path): ) assert len(result) > 0 assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - -@pytest.mark.skipif( - os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" -) -def test_openai_reranker(tmp_path): - pytest.importorskip("openai") - table, schema = get_test_table(tmp_path) - reranker = OpenaiReranker() - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=reranker) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=OpenaiReranker()) - .to_pydantic(schema) - ) - assert result1 == result2 - - # test explicit hybrid query - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query)) - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err