mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-02 09:52:57 +00:00
feat!: enforce all rerankers always return relevance score & deprecate linear combination fixes (#1687)
- Enforce all rerankers always return _relevance_score. This was already loosely done in tests before but based on user feedback its better to always have _relevance_score present in all reranked results - Deprecate LinearCombinationReranker in docs. And also fix a case where it would not return _relevance_score if one result set was missing
This commit is contained in:
@@ -36,6 +36,7 @@ from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
if self._reranker is not None:
|
||||
rs_table = result_set.read_all()
|
||||
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
|
||||
check_reranker_result(result_set)
|
||||
# convert result_set back to RecordBatchReader
|
||||
result_set = pa.RecordBatchReader.from_batches(
|
||||
result_set.schema, result_set.to_batches()
|
||||
@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
results = results.read_all()
|
||||
if self._reranker is not None:
|
||||
results = self._reranker.rerank_fts(self._query, results)
|
||||
check_reranker_result(results)
|
||||
return results
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str = None,
|
||||
vector_column: str = None,
|
||||
query: Optional[str] = None,
|
||||
vector_column: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_query._query, vector_results, fts_results
|
||||
)
|
||||
|
||||
if not isinstance(results, pa.Table): # Enforce type
|
||||
raise TypeError(
|
||||
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
|
||||
)
|
||||
check_reranker_result(results)
|
||||
|
||||
# apply limit after reranking
|
||||
results = results.slice(length=self._limit)
|
||||
@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
normalize="score",
|
||||
reranker: Reranker = RRFReranker(),
|
||||
normalize: str = "score",
|
||||
) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Rerank the hybrid search results using the specified reranker. The reranker
|
||||
@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker, default RRFReranker()
|
||||
The reranker to use. Must be an instance of Reranker class.
|
||||
normalize: str, default "score"
|
||||
The method to normalize the scores. Can be "rank" or "score". If "rank",
|
||||
the scores are converted to ranks and then normalized. If "score", the
|
||||
scores are normalized directly.
|
||||
reranker: Reranker, default RRFReranker()
|
||||
The reranker to use. Must be an instance of Reranker class.
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
|
||||
@@ -105,7 +105,7 @@ class Reranker(ABC):
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Rerank function receives the individual results from the vector and FTS search
|
||||
results. You can choose to use any of the results to generate the final results,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from numpy import NaN
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
|
||||
def merge_results(
|
||||
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
|
||||
):
|
||||
# If both are empty then just return an empty table
|
||||
if len(vector_results) == 0 and len(fts_results) == 0:
|
||||
return vector_results
|
||||
# If one is empty then return the other
|
||||
# If one is empty then return the other and add _relevance_score
|
||||
# column equal the existing vector or fts score
|
||||
if len(vector_results) == 0:
|
||||
return fts_results
|
||||
results = fts_results.append_column(
|
||||
"_relevance_score",
|
||||
pa.array(fts_results["_score"], type=pa.float32()),
|
||||
)
|
||||
if self.score == "relevance":
|
||||
results = self._keep_relevance_score(results)
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_distance",
|
||||
pa.array([NaN] * len(fts_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
if len(fts_results) == 0:
|
||||
return vector_results
|
||||
# invert the distance to relevance score
|
||||
results = vector_results.append_column(
|
||||
"_relevance_score",
|
||||
pa.array(
|
||||
[
|
||||
self._invert_score(distance)
|
||||
for distance in vector_results["_distance"].to_pylist()
|
||||
],
|
||||
type=pa.float32(),
|
||||
),
|
||||
)
|
||||
if self.score == "relevance":
|
||||
results = self._keep_relevance_score(results)
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_score",
|
||||
pa.array([NaN] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
# sort both input tables on _rowid
|
||||
combined_list = []
|
||||
|
||||
19
python/python/lancedb/rerankers/util.py
Normal file
19
python/python/lancedb/rerankers/util.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The Lance Authors
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def check_reranker_result(result):
|
||||
if not isinstance(result, pa.Table): # Enforce type
|
||||
raise TypeError(
|
||||
f"rerank_hybrid must return a pyarrow.Table, got {type(result)}"
|
||||
)
|
||||
|
||||
# Enforce that `_relevance_score` column is present in the result of every
|
||||
# rerank_hybrid method
|
||||
if "_relevance_score" not in result.column_names:
|
||||
raise ValueError(
|
||||
"rerank_hybrid must return a pyarrow.Table with a column"
|
||||
"named `_relevance_score`"
|
||||
)
|
||||
@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Vector search setting
|
||||
result = (
|
||||
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
result_explicit = (
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
assert len(result_deduped) <= 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.rerank(reranker, normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Test with empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user