mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
fix: linear reranker applies wrong score to combine (#2035)
related to #2014 this fixes: - linear reranker may lost some results if the merging consumes all vector results earlier than fts results - linear reranker inverts the fts score but only vector distance can be inverted --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from numpy import nan
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -95,43 +96,22 @@ class LinearCombinationReranker(Reranker):
|
||||
pa.array([nan] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
# sort both input tables on _rowid
|
||||
combined_list = []
|
||||
vector_list = vector_results.sort_by("_rowid").to_pylist()
|
||||
fts_list = fts_results.sort_by("_rowid").to_pylist()
|
||||
i, j = 0, 0
|
||||
while i < len(vector_list):
|
||||
if j >= len(fts_list):
|
||||
for vi in vector_list[i:]:
|
||||
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
|
||||
combined_list.append(vi)
|
||||
break
|
||||
|
||||
vi = vector_list[i]
|
||||
fj = fts_list[j]
|
||||
# invert the fts score from relevance to distance
|
||||
inverted_fts_score = self._invert_score(fj["_score"])
|
||||
if vi["_rowid"] == fj["_rowid"]:
|
||||
vi["_relevance_score"] = self._combine_score(
|
||||
vi["_distance"], inverted_fts_score
|
||||
)
|
||||
vi["_score"] = fj["_score"] # keep the original score
|
||||
combined_list.append(vi)
|
||||
i += 1
|
||||
j += 1
|
||||
elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
|
||||
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
|
||||
combined_list.append(vi)
|
||||
i += 1
|
||||
results = defaultdict()
|
||||
for vector_result in vector_results.to_pylist():
|
||||
results[vector_result["_rowid"]] = vector_result
|
||||
for fts_result in fts_results.to_pylist():
|
||||
row_id = fts_result["_rowid"]
|
||||
if row_id in results:
|
||||
results[row_id]["_score"] = fts_result["_score"]
|
||||
else:
|
||||
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
|
||||
combined_list.append(fj)
|
||||
j += 1
|
||||
if j < len(fts_list) - 1:
|
||||
for fj in fts_list[j:]:
|
||||
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
|
||||
combined_list.append(fj)
|
||||
results[row_id] = fts_result
|
||||
|
||||
combined_list = []
|
||||
for row_id, result in results.items():
|
||||
vector_score = self._invert_score(result.get("_distance", fill))
|
||||
fts_score = result.get("_score", fill)
|
||||
result["_relevance_score"] = self._combine_score(vector_score, fts_score)
|
||||
combined_list.append(result)
|
||||
|
||||
relevance_score_schema = pa.schema(
|
||||
[
|
||||
@@ -148,10 +128,10 @@ class LinearCombinationReranker(Reranker):
|
||||
tbl = self._keep_relevance_score(tbl)
|
||||
return tbl
|
||||
|
||||
def _combine_score(self, score1, score2):
|
||||
def _combine_score(self, vector_score, fts_score):
|
||||
# these scores represent distance
|
||||
return 1 - (self.weight * score1 + (1 - self.weight) * score2)
|
||||
return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score)
|
||||
|
||||
def _invert_score(self, score: float):
|
||||
def _invert_score(self, dist: float):
|
||||
# Invert the score between relevance and distance
|
||||
return 1 - score
|
||||
return 1 - dist
|
||||
|
||||
@@ -3,6 +3,7 @@ import random
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
@@ -281,6 +282,31 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_linear_combination(tmp_path, use_tantivy):
|
||||
reranker = LinearCombinationReranker()
|
||||
|
||||
vector_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0, 1, 2, 3, 4],
|
||||
"_distance": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"_text": ["a", "b", "c", "d", "e"],
|
||||
}
|
||||
)
|
||||
|
||||
fts_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [1, 2, 3, 4, 5],
|
||||
"_score": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"_text": ["b", "c", "d", "e", "f"],
|
||||
}
|
||||
)
|
||||
|
||||
combined_results = reranker.merge_results(vector_results, fts_results, 1.0)
|
||||
assert len(combined_results) == 6
|
||||
assert "_rowid" in combined_results.column_names
|
||||
assert "_text" in combined_results.column_names
|
||||
assert "_distance" not in combined_results.column_names
|
||||
assert "_score" not in combined_results.column_names
|
||||
assert "_relevance_score" in combined_results.column_names
|
||||
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user