diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 1aa8d6a1..f8d45eed 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict from numpy import nan import pyarrow as pa @@ -95,43 +96,22 @@ class LinearCombinationReranker(Reranker): pa.array([nan] * len(vector_results), type=pa.float32()), ) return results - - # sort both input tables on _rowid - combined_list = [] - vector_list = vector_results.sort_by("_rowid").to_pylist() - fts_list = fts_results.sort_by("_rowid").to_pylist() - i, j = 0, 0 - while i < len(vector_list): - if j >= len(fts_list): - for vi in vector_list[i:]: - vi["_relevance_score"] = self._combine_score(vi["_distance"], fill) - combined_list.append(vi) - break - - vi = vector_list[i] - fj = fts_list[j] - # invert the fts score from relevance to distance - inverted_fts_score = self._invert_score(fj["_score"]) - if vi["_rowid"] == fj["_rowid"]: - vi["_relevance_score"] = self._combine_score( - vi["_distance"], inverted_fts_score - ) - vi["_score"] = fj["_score"] # keep the original score - combined_list.append(vi) - i += 1 - j += 1 - elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]: - vi["_relevance_score"] = self._combine_score(vi["_distance"], fill) - combined_list.append(vi) - i += 1 + results = defaultdict() + for vector_result in vector_results.to_pylist(): + results[vector_result["_rowid"]] = vector_result + for fts_result in fts_results.to_pylist(): + row_id = fts_result["_rowid"] + if row_id in results: + results[row_id]["_score"] = fts_result["_score"] else: - fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill) - combined_list.append(fj) - j += 1 - if j < len(fts_list) - 1: - for fj in fts_list[j:]: - fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill) - combined_list.append(fj) + results[row_id] = fts_result + + combined_list = [] + for row_id, result in results.items(): + vector_score = self._invert_score(result.get("_distance", fill)) + fts_score = result.get("_score", fill) + result["_relevance_score"] = self._combine_score(vector_score, fts_score) + combined_list.append(result) relevance_score_schema = pa.schema( [ @@ -148,10 +128,10 @@ class LinearCombinationReranker(Reranker): tbl = self._keep_relevance_score(tbl) return tbl - def _combine_score(self, score1, score2): + def _combine_score(self, vector_score, fts_score): # these scores represent distance - return 1 - (self.weight * score1 + (1 - self.weight) * score2) + return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score) - def _invert_score(self, score: float): + def _invert_score(self, dist: float): # Invert the score between relevance and distance - return 1 - score + return 1 - dist diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 4e1c6898..12689658 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -3,6 +3,7 @@ import random import lancedb import numpy as np +import pyarrow as pa import pytest from lancedb.conftest import MockTextEmbeddingFunction # noqa from lancedb.embeddings import EmbeddingFunctionRegistry @@ -281,6 +282,31 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy): @pytest.mark.parametrize("use_tantivy", [True, False]) def test_linear_combination(tmp_path, use_tantivy): reranker = LinearCombinationReranker() + + vector_results = pa.Table.from_pydict( + { + "_rowid": [0, 1, 2, 3, 4], + "_distance": [0.1, 0.2, 0.3, 0.4, 0.5], + "_text": ["a", "b", "c", "d", "e"], + } + ) + + fts_results = pa.Table.from_pydict( + { + "_rowid": [1, 2, 3, 4, 5], + "_score": [0.1, 0.2, 0.3, 0.4, 0.5], + "_text": ["b", "c", "d", "e", "f"], + } + ) + + combined_results = reranker.merge_results(vector_results, fts_results, 1.0) + assert len(combined_results) == 6 + assert "_rowid" in combined_results.column_names + assert "_text" in combined_results.column_names + assert "_distance" not in combined_results.column_names + assert "_score" not in combined_results.column_names + assert "_relevance_score" in combined_results.column_names + _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)