chore: choose appropriate args for concat_tables based on pyarrow version & refactor reranker tests (#1455)

This commit is contained in:
Ayush Chaurasia
2024-07-18 21:04:59 +05:30
committed by GitHub
parent dc609a337d
commit ed7bd45c17
3 changed files with 82 additions and 222 deletions

View File

@@ -35,7 +35,7 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
def _compute_one_embedding(self, row): def _compute_one_embedding(self, row):
emb = np.array([float(hash(c)) for c in row[:10]]) emb = np.array([float(hash(c)) for c in row[:10]])
emb /= np.linalg.norm(emb) emb /= np.linalg.norm(emb)
return emb return emb if len(emb) == 10 else [0] * 10
def ndims(self): def ndims(self):
return 10 return 10

View File

@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from packaging.version import Version
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
ARROW_VERSION = Version(pa.__version__)
class Reranker(ABC): class Reranker(ABC):
def __init__(self, return_score: str = "relevance"): def __init__(self, return_score: str = "relevance"):
@@ -23,6 +26,11 @@ class Reranker(ABC):
if return_score not in ["relevance", "all"]: if return_score not in ["relevance", "all"]:
raise ValueError("score must be either 'relevance' or 'all'") raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score self.score = return_score
# Set the merge args based on the arrow version here to avoid checking it at
# each query
self._concat_tables_args = {"promote_options": "default"}
if ARROW_VERSION.major <= 13:
self._concat_tables_args = {"promote": True}
def rerank_vector( def rerank_vector(
self, self,
@@ -119,7 +127,9 @@ class Reranker(ABC):
fts_results : pa.Table fts_results : pa.Table
The results from the FTS search The results from the FTS search
""" """
combined = pa.concat_tables([vector_results, fts_results], promote=True) combined = pa.concat_tables(
[vector_results, fts_results], **self._concat_tables_args
)
row_id = combined.column("_rowid") row_id = combined.column("_rowid")
# deduplicate # deduplicate

View File

@@ -11,6 +11,7 @@ from lancedb.rerankers import (
ColbertReranker, ColbertReranker,
CrossEncoderReranker, CrossEncoderReranker,
OpenaiReranker, OpenaiReranker,
JinaReranker,
) )
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -82,6 +83,63 @@ def get_test_table(tmp_path):
return table, MyTable return table, MyTable
def _run_test_reranker(reranker, table, query, query_vector, schema):
    """Run the shared reranker checks for hybrid, vector and FTS search.

    Parameters
    ----------
    reranker
        The reranker instance under test.
    table
        The table to search. Its pandas view must expose a "vector" column,
        which supplies the fallback query vector.
    query
        The text query used for every search setting.
    query_vector
        The vector used for the explicit hybrid and vector queries. When
        None, the "vector" value of the table's first row is used.
    schema
        The model passed to ``to_pydantic`` for the hybrid results.
    """
    # Hybrid search setting: passing normalize="score" explicitly must give
    # the same results as leaving normalize unset.
    result1 = (
        table.search(query, query_type="hybrid")
        .rerank(normalize="score", reranker=reranker)
        .to_pydantic(schema)
    )
    result2 = (
        table.search(query, query_type="hybrid")
        .rerank(reranker=reranker)
        .to_pydantic(schema)
    )
    assert result1 == result2

    # Honor a caller-supplied vector; only fall back to the table's first
    # vector when none was given.
    if query_vector is None:
        query_vector = table.to_pandas()["vector"][0]

    # Explicit hybrid query: a (vector, text) tuple.
    result = (
        table.search((query_vector, query))
        .limit(30)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    assert len(result) == 30
    err = (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
    # Consecutive differences <= 0 means the scores never increase.
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err

    # Vector search setting
    result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
    assert len(result) == 30
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err

    result_explicit = (
        table.search(query_vector)
        .rerank(reranker=reranker, query_string=query)
        .limit(30)
        .to_arrow()
    )
    assert len(result_explicit) == 30

    with pytest.raises(
        ValueError
    ):  # This raises an error because vector query is provided without reranking query
        table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()

    # FTS search setting
    result = (
        table.search(query, query_type="fts")
        .rerank(reranker=reranker)
        .limit(30)
        .to_arrow()
    )
    assert len(result) > 0
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_linear_combination(tmp_path): def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
# The default reranker # The default reranker
@@ -126,185 +184,21 @@ def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere") pytest.importorskip("cohere")
reranker = CohereReranker() reranker = CohereReranker()
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
# Hybrid search setting _run_test_reranker(reranker, table, "single player experience", None, schema)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
query = "Our father who art in heaven"
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_cross_encoder_reranker(tmp_path): def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers") pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker() reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
result1 = ( _run_test_reranker(reranker, table, "single player experience", None, schema)
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_colbert_reranker(tmp_path): def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers") pytest.importorskip("transformers")
reranker = ColbertReranker() reranker = ColbertReranker()
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
result1 = ( _run_test_reranker(reranker, table, "single player experience", None, schema)
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
@pytest.mark.skipif( @pytest.mark.skipif(
@@ -314,58 +208,14 @@ def test_openai_reranker(tmp_path):
pytest.importorskip("openai") pytest.importorskip("openai")
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
reranker = OpenaiReranker() reranker = OpenaiReranker()
result1 = ( _run_test_reranker(reranker, table, "single player experience", None, schema)
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30 @pytest.mark.skipif(
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
err = ( )
"The _relevance_score column of the results returned by the reranker " def test_jina_reranker(tmp_path):
"represents the relevance of the result to the query & should " pytest.importorskip("jina")
"be descending." table, schema = get_test_table(tmp_path)
) reranker = JinaReranker()
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err _run_test_reranker(reranker, table, "single player experience", None, schema)
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err