mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-30 19:50:40 +00:00
feat!: enforce all rerankers always return relevance score & deprecate linear combination fixes (#1687)
- Enforce all rerankers always return _relevance_score. This was already loosely done in tests before but based on user feedback its better to always have _relevance_score present in all reranked results - Deprecate LinearCombinationReranker in docs. And also fix a case where it would not return _relevance_score if one result set was missing
This commit is contained in:
@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Vector search setting
|
||||
result = (
|
||||
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
result_explicit = (
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
assert len(result_deduped) <= 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.rerank(reranker, normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Test with empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user