mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
Compare commits
1 Commits
python-v0.
...
ayush/jina
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bb7b2d2e7 |
@@ -4,6 +4,7 @@ from .colbert import ColbertReranker
|
|||||||
from .cross_encoder import CrossEncoderReranker
|
from .cross_encoder import CrossEncoderReranker
|
||||||
from .linear_combination import LinearCombinationReranker
|
from .linear_combination import LinearCombinationReranker
|
||||||
from .openai import OpenaiReranker
|
from .openai import OpenaiReranker
|
||||||
|
from .jina import JinaReranker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Reranker",
|
"Reranker",
|
||||||
@@ -12,4 +13,5 @@ __all__ = [
|
|||||||
"LinearCombinationReranker",
|
"LinearCombinationReranker",
|
||||||
"OpenaiReranker",
|
"OpenaiReranker",
|
||||||
"ColbertReranker",
|
"ColbertReranker",
|
||||||
|
"JinaReranker",
|
||||||
]
|
]
|
||||||
|
|||||||
103
python/python/lancedb/rerankers/jina.py
Normal file
103
python/python/lancedb/rerankers/jina.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class JinaReranker(Reranker):
    """
    Reranks the results using a Jina reranker model.

    Parameters
    ----------
    model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
        The name of the reranker to use. For all models, see
        https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
    column : str, default "text"
        The name of the column whose values are paired with the query and
        passed to the reranker model.
    device : str, default None
        The device to run the reranker model on. If None, will use "cuda"
        if available, otherwise "cpu".
    return_score : str, default "relevance"
        Forwarded to the base ``Reranker``. The methods below read it back as
        ``self.score`` (presumably set by the base class -- confirm in
        ``base.py``). With "relevance" the original ``score`` / ``_distance``
        columns are dropped so that only ``_relevance_score`` remains; "all"
        is not implemented for hybrid search.
    """

    def __init__(
        self,
        model_name: str = "jinaai/jina-reranker-v1-turbo-en",
        column: str = "text",
        device: Union[str, None] = None,
        return_score="relevance",
    ):
        super().__init__(return_score)
        # torch is an optional dependency; this raises if it is missing.
        torch = attempt_import_or_raise("torch")
        self.model_name = model_name
        self.column = column
        self.device = device
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @cached_property
    def model(self):
        """Load the model on first access and place it on ``self.device``."""
        transformers = attempt_import_or_raise("transformers")
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository. Only use model names you trust.
        model = transformers.AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=1, trust_remote_code=True
        )
        # The device chosen in __init__ was previously never applied, so the
        # model always stayed where from_pretrained put it.
        model = model.to(self.device)

        return model

    def _rerank(self, result_set: pa.Table, query: str) -> pa.Table:
        """
        Append a float32 ``_relevance_score`` column to ``result_set``.

        Each value of ``self.column`` is paired with ``query`` and scored by
        the model's ``compute_score``. Rows are not reordered here.
        """
        passages = result_set[self.column].to_pylist()
        if passages:
            cross_inp = [[query, passage] for passage in passages]
            cross_scores = self.model.compute_score(cross_inp)
            # Guard: some compute_score implementations appear to return a
            # bare float when given a single pair -- TODO confirm for Jina.
            if isinstance(cross_scores, (int, float)):
                cross_scores = [cross_scores]
        else:
            # Nothing to score: skip loading/calling the model entirely.
            cross_scores = []

        result_set = result_set.append_column(
            "_relevance_score", pa.array(cross_scores, type=pa.float32())
        )

        return result_set

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Merge vector and FTS results, score them, and sort by relevance.

        Raises
        ------
        NotImplementedError
            If the reranker was created with ``return_score="all"``.
        """
        # Fail before running the (expensive) model on an unsupported mode.
        if self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for JinaReranker"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])

        # sort the results by _relevance_score, most relevant first
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    def rerank_vector(
        self,
        query: str,
        vector_results: pa.Table,
    ) -> pa.Table:
        """Score vector-search results and sort them by relevance."""
        vector_results = self._rerank(vector_results, query)
        if self.score == "relevance":
            vector_results = vector_results.drop_columns(["_distance"])

        vector_results = vector_results.sort_by([("_relevance_score", "descending")])
        return vector_results

    def rerank_fts(
        self,
        query: str,
        fts_results: pa.Table,
    ) -> pa.Table:
        """Score full-text-search results and sort them by relevance."""
        fts_results = self._rerank(fts_results, query)
        if self.score == "relevance":
            fts_results = fts_results.drop_columns(["score"])

        fts_results = fts_results.sort_by([("_relevance_score", "descending")])
        return fts_results
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,6 +9,7 @@ from lancedb.rerankers import (
|
|||||||
ColbertReranker,
|
ColbertReranker,
|
||||||
CrossEncoderReranker,
|
CrossEncoderReranker,
|
||||||
OpenaiReranker,
|
OpenaiReranker,
|
||||||
|
JinaReranker,
|
||||||
)
|
)
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
@@ -119,136 +118,18 @@ def test_linear_combination(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.slow
|
||||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
@pytest.mark.parametrize(
|
||||||
|
"reranker",
|
||||||
|
[
|
||||||
|
ColbertReranker(),
|
||||||
|
OpenaiReranker(),
|
||||||
|
CohereReranker(),
|
||||||
|
CrossEncoderReranker(),
|
||||||
|
JinaReranker(),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_cohere_reranker(tmp_path):
|
def test_colbert_reranker(tmp_path, reranker):
|
||||||
pytest.importorskip("cohere")
|
|
||||||
reranker = CohereReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
# Hybrid search setting
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=CohereReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reranking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_cross_encoder_reranker(tmp_path):
|
|
||||||
pytest.importorskip("sentence_transformers")
|
|
||||||
reranker = CrossEncoderReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query), query_type="hybrid")
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reranking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_colbert_reranker(tmp_path):
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
reranker = ColbertReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
result1 = (
|
result1 = (
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
@@ -305,67 +186,3 @@ def test_colbert_reranker(tmp_path):
|
|||||||
)
|
)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
|
||||||
)
|
|
||||||
def test_openai_reranker(tmp_path):
|
|
||||||
pytest.importorskip("openai")
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
reranker = OpenaiReranker()
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=OpenaiReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
# test explicit hybrid query
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reranking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|||||||
Reference in New Issue
Block a user