fix: linear reranker applies wrong score to combine (#2035)

related to #2014 
this fixes:
- the linear reranker may lose some results if the merge exhausts the
vector results before the fts results (see the sketch after this list)
- the linear reranker inverts the fts score, but only the vector distance
should be inverted
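
For context, here is a minimal standalone sketch (an illustration, not the library code) of the rowid-keyed merge this patch switches to. The old two-pointer merge could drop the last unmatched fts row once the vector list ran out (its tail handling requires `j < len(fts_list) - 1`); keying everything by `_rowid` keeps the full union of both result sets:

```python
from collections import defaultdict

def merge_by_rowid(vector_rows, fts_rows):
    # index every row by its _rowid so the union of both result sets survives
    merged = defaultdict(dict)
    for row in vector_rows + fts_rows:
        merged[row["_rowid"]].update(row)
    return list(merged.values())

rows = merge_by_rowid(
    [{"_rowid": 0, "_distance": 0.1}, {"_rowid": 1, "_distance": 0.2}],
    [{"_rowid": 1, "_score": 0.3}, {"_rowid": 5, "_score": 0.4}],
)
assert len(rows) == 3  # rowids 0, 1 and 5 are all kept; rowid 1 carries both scores
```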

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
BubbleCal
2025-01-17 11:33:48 +08:00
committed by GitHub
parent 4703cc6894
commit d0501f65f1
2 changed files with 46 additions and 40 deletions


@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from numpy import nan
 import pyarrow as pa
@@ -95,43 +96,22 @@ class LinearCombinationReranker(Reranker):
                 pa.array([nan] * len(vector_results), type=pa.float32()),
             )
             return results
-        # sort both input tables on _rowid
-        combined_list = []
-        vector_list = vector_results.sort_by("_rowid").to_pylist()
-        fts_list = fts_results.sort_by("_rowid").to_pylist()
-        i, j = 0, 0
-        while i < len(vector_list):
-            if j >= len(fts_list):
-                for vi in vector_list[i:]:
-                    vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
-                    combined_list.append(vi)
-                break
-            vi = vector_list[i]
-            fj = fts_list[j]
-            # invert the fts score from relevance to distance
-            inverted_fts_score = self._invert_score(fj["_score"])
-            if vi["_rowid"] == fj["_rowid"]:
-                vi["_relevance_score"] = self._combine_score(
-                    vi["_distance"], inverted_fts_score
-                )
-                vi["_score"] = fj["_score"]  # keep the original score
-                combined_list.append(vi)
-                i += 1
-                j += 1
-            elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
-                vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
-                combined_list.append(vi)
-                i += 1
+        results = defaultdict()
+        for vector_result in vector_results.to_pylist():
+            results[vector_result["_rowid"]] = vector_result
+        for fts_result in fts_results.to_pylist():
+            row_id = fts_result["_rowid"]
+            if row_id in results:
+                results[row_id]["_score"] = fts_result["_score"]
             else:
-                fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
-                combined_list.append(fj)
-                j += 1
-        if j < len(fts_list) - 1:
-            for fj in fts_list[j:]:
-                fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
-                combined_list.append(fj)
+                results[row_id] = fts_result
+        combined_list = []
+        for row_id, result in results.items():
+            vector_score = self._invert_score(result.get("_distance", fill))
+            fts_score = result.get("_score", fill)
+            result["_relevance_score"] = self._combine_score(vector_score, fts_score)
+            combined_list.append(result)

         relevance_score_schema = pa.schema(
             [
@@ -148,10 +128,10 @@ class LinearCombinationReranker(Reranker):
         tbl = self._keep_relevance_score(tbl)
         return tbl

-    def _combine_score(self, score1, score2):
+    def _combine_score(self, vector_score, fts_score):
         # these scores represent distance
-        return 1 - (self.weight * score1 + (1 - self.weight) * score2)
+        return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score)

-    def _invert_score(self, score: float):
+    def _invert_score(self, dist: float):
         # Invert the score between relevance and distance
-        return 1 - score
+        return 1 - dist
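
To make the new score flow concrete, here is a small worked example using the formulas in the hunk above (a sketch only; the 0.7 weight is assumed to be the reranker's default, and the distances/scores are borrowed from the new test below):

```python
# Sketch of the fixed score flow: only the vector distance is inverted,
# the fts score is used as-is, and missing values fall back to `fill`.
weight = 0.7  # assumed default LinearCombinationReranker weight
fill = 1.0    # the fill value the new test passes to merge_results

def invert_score(dist):
    # distance -> relevance, mirroring _invert_score
    return 1 - dist

def combine_score(vector_score, fts_score):
    # weighted combination, mirroring the new _combine_score
    return 1 - (weight * vector_score + (1 - weight) * fts_score)

# row present in both result sets: _distance = 0.2, _score = 0.1
print(combine_score(invert_score(0.2), 0.1))   # 1 - (0.7 * 0.8 + 0.3 * 0.1) = 0.41

# row present only in the fts results: _distance missing -> fill, _score = 0.5
print(combine_score(invert_score(fill), 0.5))  # 1 - (0.7 * 0.0 + 0.3 * 0.5) = 0.85
```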


@@ -3,6 +3,7 @@ import random
 import lancedb
 import numpy as np
+import pyarrow as pa
 import pytest
 from lancedb.conftest import MockTextEmbeddingFunction  # noqa
 from lancedb.embeddings import EmbeddingFunctionRegistry
@@ -281,6 +282,31 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
 @pytest.mark.parametrize("use_tantivy", [True, False])
 def test_linear_combination(tmp_path, use_tantivy):
     reranker = LinearCombinationReranker()
+    vector_results = pa.Table.from_pydict(
+        {
+            "_rowid": [0, 1, 2, 3, 4],
+            "_distance": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "_text": ["a", "b", "c", "d", "e"],
+        }
+    )
+    fts_results = pa.Table.from_pydict(
+        {
+            "_rowid": [1, 2, 3, 4, 5],
+            "_score": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "_text": ["b", "c", "d", "e", "f"],
+        }
+    )
+    combined_results = reranker.merge_results(vector_results, fts_results, 1.0)
+    assert len(combined_results) == 6
+    assert "_rowid" in combined_results.column_names
+    assert "_text" in combined_results.column_names
+    assert "_distance" not in combined_results.column_names
+    assert "_score" not in combined_results.column_names
+    assert "_relevance_score" in combined_results.column_names

     _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)