chore: choose appropriate args for concat_table based on pyarrow version & refactor reranker tests (#1455)

This commit is contained in:
Ayush Chaurasia
2024-07-18 21:04:59 +05:30
committed by GitHub
parent dc609a337d
commit ed7bd45c17
3 changed files with 82 additions and 222 deletions

View File

@@ -35,7 +35,7 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
def _compute_one_embedding(self, row):
emb = np.array([float(hash(c)) for c in row[:10]])
emb /= np.linalg.norm(emb)
return emb
return emb if len(emb) == 10 else [0] * 10
def ndims(self):
return 10

View File

@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
from packaging.version import Version
import numpy as np
import pyarrow as pa
ARROW_VERSION = Version(pa.__version__)
class Reranker(ABC):
def __init__(self, return_score: str = "relevance"):
@@ -23,6 +26,11 @@ class Reranker(ABC):
if return_score not in ["relevance", "all"]:
raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score
# Set the merge args based on the arrow version here to avoid checking it at
# each query
self._concat_tables_args = {"promote_options": "default"}
if ARROW_VERSION.major <= 13:
self._concat_tables_args = {"promote": True}
def rerank_vector(
self,
@@ -119,7 +127,9 @@ class Reranker(ABC):
fts_results : pa.Table
The results from the FTS search
"""
combined = pa.concat_tables([vector_results, fts_results], promote=True)
combined = pa.concat_tables(
[vector_results, fts_results], **self._concat_tables_args
)
row_id = combined.column("_rowid")
# deduplicate

View File

@@ -11,6 +11,7 @@ from lancedb.rerankers import (
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
JinaReranker,
)
from lancedb.table import LanceTable
@@ -82,6 +83,63 @@ def get_test_table(tmp_path):
return table, MyTable
def _run_test_reranker(reranker, table, query, query_vector, schema):
# Hybrid search setting
result1 = (
table.search(query, query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search(query, query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
@@ -126,185 +184,21 @@ def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
reranker = CohereReranker()
table, schema = get_test_table(tmp_path)
# Hybrid search setting
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
query = "Our father who art in heaven"
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
_run_test_reranker(reranker, table, "single player experience", None, schema)
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
_run_test_reranker(reranker, table, "single player experience", None, schema)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
reranker = ColbertReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.skipif(
@@ -314,58 +208,14 @@ def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
reranker = OpenaiReranker()
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
_run_test_reranker(reranker, table, "single player experience", None, schema)
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
@pytest.mark.skipif(
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
)
def test_jina_reranker(tmp_path):
pytest.importorskip("jina")
table, schema = get_test_table(tmp_path)
reranker = JinaReranker()
_run_test_reranker(reranker, table, "single player experience", None, schema)