feat!: enforce all rerankers always return relevance score & deprecate linear combination fixes (#1687)

- Enforce all rerankers always return _relevance_score. This was already
loosely done in tests before but based on user feedback its better to
always have _relevance_score present in all reranked results
- Deprecate LinearCombinationReranker in docs. And also fix a case where
it would not return _relevance_score if one result set was missing
This commit is contained in:
Ayush Chaurasia
2024-09-23 12:12:02 +05:30
committed by GitHub
parent 7c314d61cc
commit 86978e7588
7 changed files with 118 additions and 26 deletions

View File

@@ -36,6 +36,7 @@ from . import __version__
from .arrow import AsyncRecordBatchReader
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import safe_import_pandas
if TYPE_CHECKING:
@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
check_reranker_result(result_set)
# convert result_set back to RecordBatchReader
result_set = pa.RecordBatchReader.from_batches(
result_set.schema, result_set.to_batches()
@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
results = results.read_all()
if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results)
check_reranker_result(results)
return results
def tantivy_to_arrow(self) -> pa.Table:
@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "Table",
query: str = None,
vector_column: str = None,
query: Optional[str] = None,
vector_column: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_query._query, vector_results, fts_results
)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
check_reranker_result(results)
# apply limit after reranking
results = results.slice(length=self._limit)
@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def rerank(
self,
normalize="score",
reranker: Reranker = RRFReranker(),
normalize: str = "score",
) -> LanceHybridQueryBuilder:
"""
Rerank the hybrid search results using the specified reranker. The reranker
@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
Parameters
----------
reranker: Reranker, default RRFReranker()
The reranker to use. Must be an instance of Reranker class.
normalize: str, default "score"
The method to normalize the scores. Can be "rank" or "score". If "rank",
the scores are converted to ranks and then normalized. If "score", the
scores are normalized directly.
reranker: Reranker, default RRFReranker()
The reranker to use. Must be an instance of Reranker class.
Returns
-------
LanceHybridQueryBuilder

View File

@@ -105,7 +105,7 @@ class Reranker(ABC):
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
) -> pa.Table:
"""
Rerank function receives the individual results from the vector and FTS search
results. You can choose to use any of the results to generate the final results,

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from numpy import NaN
import pyarrow as pa
from .base import Reranker
@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
def merge_results(
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
):
# If both are empty then just return an empty table
if len(vector_results) == 0 and len(fts_results) == 0:
return vector_results
# If one is empty then return the other
# If one is empty then return the other and add _relevance_score
# column equal the existing vector or fts score
if len(vector_results) == 0:
return fts_results
results = fts_results.append_column(
"_relevance_score",
pa.array(fts_results["_score"], type=pa.float32()),
)
if self.score == "relevance":
results = self._keep_relevance_score(results)
elif self.score == "all":
results = results.append_column(
"_distance",
pa.array([NaN] * len(fts_results), type=pa.float32()),
)
return results
if len(fts_results) == 0:
return vector_results
# invert the distance to relevance score
results = vector_results.append_column(
"_relevance_score",
pa.array(
[
self._invert_score(distance)
for distance in vector_results["_distance"].to_pylist()
],
type=pa.float32(),
),
)
if self.score == "relevance":
results = self._keep_relevance_score(results)
elif self.score == "all":
results = results.append_column(
"_score",
pa.array([NaN] * len(vector_results), type=pa.float32()),
)
return results
# sort both input tables on _rowid
combined_list = []

View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
import pyarrow as pa
def check_reranker_result(result):
if not isinstance(result, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(result)}"
)
# Enforce that `_relevance_score` column is present in the result of every
# rerank_hybrid method
if "_relevance_score" not in result.column_names:
raise ValueError(
"rerank_hybrid must return a pyarrow.Table with a column"
"named `_relevance_score`"
)

View File

@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
)
assert len(result) == 30
err = (
ascending_relevance_err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Vector search setting
result = (
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
result_explicit = (
table.search(query_vector, vector_column_name="vector")
.rerank(reranker=reranker, query_string=query)
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# empty FTS results
query = "abcxyz" * 100
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
# should return _relevance_score column
assert "_relevance_score" in result.column_names
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Multi-vector search setting
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
result_deduped = reranker.rerank_multivector(
[rs1, rs2, rs1], query, deduplicate=True
)
assert len(result_deduped) < 20
assert len(result_deduped) <= 20
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
assert len(result) == 20 and result == result_arrow
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
.vector(query_vector)
.text(query)
.limit(30)
.rerank(normalize="score")
.rerank(reranker, normalize="score")
.to_arrow()
)
assert len(result) == 30
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
table.search(query, query_type="hybrid", vector_column_name="vector").text(
query
).to_arrow()
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
ascending_relevance_err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Test with empty FTS results
query = "abcxyz" * 100
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
# should return _relevance_score column
assert "_relevance_score" in result.column_names
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
@pytest.mark.parametrize("use_tantivy", [True, False])