add jina reranker

This commit is contained in:
ayush chaurasia
2024-07-02 15:01:37 +05:30
parent 020a437230
commit 2bb7b2d2e7
3 changed files with 117 additions and 195 deletions

View File

@@ -4,6 +4,7 @@ from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jina import JinaReranker
__all__ = [
"Reranker",
@@ -12,4 +13,5 @@ __all__ = [
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
"JinaReranker",
]

View File

@@ -0,0 +1,103 @@
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class JinaReranker(Reranker):
"""
Reranks the results using Jina reranker model.
Parameters
----------
model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
The name of the reranker to use. For all models, see
https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
column : str, default "text"
The name of the column to use as input to the cross encoder model.
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
"""
def __init__(
self,
model_name: str = "jinaai/jina-reranker-v1-turbo-en",
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
):
super().__init__(return_score)
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@cached_property
def model(self):
transformers = attempt_import_or_raise("transformers")
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self.model_name, num_labels=1, trust_remote_code=True
)
return model
def _rerank(self, result_set: pa.Table, query: str):
passages = result_set[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.compute_score(cross_inp)
result_set = result_set.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for CrossEncoderReranker"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -1,5 +1,3 @@
import os
import lancedb
import numpy as np
import pytest
@@ -11,6 +9,7 @@ from lancedb.rerankers import (
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
JinaReranker,
)
from lancedb.table import LanceTable
@@ -119,136 +118,18 @@ def test_linear_combination(tmp_path):
)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@pytest.mark.slow
@pytest.mark.parametrize(
"reranker",
[
ColbertReranker(),
OpenaiReranker(),
CohereReranker(),
CrossEncoderReranker(),
JinaReranker(),
],
)
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
reranker = CohereReranker()
table, schema = get_test_table(tmp_path)
# Hybrid search setting
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
query = "Our father who art in heaven"
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
reranker = ColbertReranker()
def test_colbert_reranker(tmp_path, reranker):
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
@@ -305,67 +186,3 @@ def test_colbert_reranker(tmp_path):
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
reranker = OpenaiReranker()
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err