mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
fix(python): preserve original distance and score in hybrid queries (#2061)
Fixes #2031. When we do a hybrid search, we normalize the vector distances and the FTS scores. The normalization is done in place because the rerankers expect the `_distance` and `_score` columns to hold the normalized values; as a result, the normalized values were also what the query returned. This change restores the original distances and scores after reranking by matching on row IDs.
This commit is contained in:
@@ -20,6 +20,7 @@ import asyncio
|
||||
import deprecation
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
import pydantic
|
||||
|
||||
@@ -1189,18 +1190,52 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score")
|
||||
|
||||
# normalize the scores to be between 0 and 1, 0 being most relevant
|
||||
vector_results = LanceHybridQueryBuilder._normalize_scores(
|
||||
vector_results, "_distance"
|
||||
)
|
||||
# We check whether the results (vector and FTS) are empty, because when
|
||||
# they are, they often are missing the _rowid column, which causes an error
|
||||
if vector_results.num_rows > 0:
|
||||
distance_i = vector_results.column_names.index("_distance")
|
||||
original_distances = vector_results.column(distance_i)
|
||||
original_distance_row_ids = vector_results.column("_rowid")
|
||||
vector_results = vector_results.set_column(
|
||||
distance_i,
|
||||
vector_results.field(distance_i),
|
||||
LanceHybridQueryBuilder._normalize_scores(original_distances),
|
||||
)
|
||||
|
||||
# In fts higher scores represent relevance. Not inverting them here as
|
||||
# rerankers might need to preserve this score to support `return_score="all"`
|
||||
fts_results = LanceHybridQueryBuilder._normalize_scores(fts_results, "_score")
|
||||
if fts_results.num_rows > 0:
|
||||
score_i = fts_results.column_names.index("_score")
|
||||
original_scores = fts_results.column(score_i)
|
||||
original_score_row_ids = fts_results.column("_rowid")
|
||||
fts_results = fts_results.set_column(
|
||||
score_i,
|
||||
fts_results.field(score_i),
|
||||
LanceHybridQueryBuilder._normalize_scores(original_scores),
|
||||
)
|
||||
|
||||
results = reranker.rerank_hybrid(fts_query, vector_results, fts_results)
|
||||
|
||||
check_reranker_result(results)
|
||||
|
||||
if "_distance" in results.column_names:
|
||||
# restore the original distances
|
||||
indices = pc.index_in(
|
||||
results["_rowid"], original_distance_row_ids, skip_nulls=True
|
||||
)
|
||||
original_distances = pc.take(original_distances, indices)
|
||||
distance_i = results.column_names.index("_distance")
|
||||
results = results.set_column(distance_i, "_distance", original_distances)
|
||||
|
||||
if "_score" in results.column_names:
|
||||
# restore the original scores
|
||||
indices = pc.index_in(
|
||||
results["_rowid"], original_score_row_ids, skip_nulls=True
|
||||
)
|
||||
original_scores = pc.take(original_scores, indices)
|
||||
score_i = results.column_names.index("_score")
|
||||
results = results.set_column(score_i, "_score", original_scores)
|
||||
|
||||
results = results.slice(length=limit)
|
||||
|
||||
if not with_row_ids:
|
||||
@@ -1230,28 +1265,23 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
return results
|
||||
|
||||
@staticmethod
def _normalize_scores(scores: pa.Array, invert=False) -> pa.Array:
    """Min-max scale ``scores`` so that the smallest value maps to 0.

    Parameters
    ----------
    scores : pa.Array
        Distances or relevance scores to rescale.
    invert : bool, default False
        If True, return ``1 - normalized`` instead of ``normalized``.

    Returns
    -------
    pa.Array
        The rescaled values. Empty input and input without any non-null
        value are returned unchanged. When all values are (nearly) equal,
        every non-null value becomes 0 before the optional inversion.
    """
    if len(scores) == 0:
        return scores

    # Bind the extrema to names that do not shadow the min/max builtins.
    extrema = pc.min_max(scores)
    lowest, highest = extrema["min"], extrema["max"]
    low, high = lowest.as_py(), highest.as_py()
    if low is None or high is None:
        # No non-null value was found, so there is nothing to rescale.
        return scores

    if np.isclose(high, low):
        # The spread is zero or just floating-point noise. Dividing by it
        # would stretch noise across the whole [0, 1] range, so treat the
        # scores as tied instead. x - x keeps the type and any nulls.
        scores = pc.subtract(scores, scores)
    else:
        scores = pc.divide(
            pc.subtract(scores, lowest), pc.subtract(highest, lowest)
        )

    if invert:
        scores = pc.subtract(1, scores)

    return scores
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
import lancedb
|
||||
|
||||
from lancedb.query import LanceHybridQueryBuilder
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
@@ -110,3 +112,23 @@ async def test_explain_plan(table: AsyncTable):
|
||||
assert "KNNVectorDistance" in plan
|
||||
assert "FTS Search Plan" in plan
|
||||
assert "LanceScan" in plan
|
||||
|
||||
|
||||
def test_normalize_scores():
    """Check ``_normalize_scores`` with and without inversion.

    Each case pairs raw scores with the expected non-inverted result; the
    inverted expectation is derived from it as ``1 - expected``.
    """
    cases = [
        (pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),
        (pa.array([2.0, 10.0, 20.0]), pa.array([0.0, 8.0 / 18.0, 1.0])),
        (pa.array([0.0, 0.0, 0.0]), pa.array([0.0, 0.0, 0.0])),
        # Nearly identical scores are expected to collapse to 0 rather than
        # be stretched across [0, 1].
        (pa.array([10.0, 9.9999999999999]), pa.array([0.0, 0.0])),
    ]

    for raw, expected in cases:
        for invert in [True, False]:
            result = LanceHybridQueryBuilder._normalize_scores(raw, invert)

            # Derive the expectation for this pass without overwriting
            # `expected`, so the inverted values cannot leak into the
            # non-inverted pass.
            wanted = pc.subtract(1.0, expected) if invert else expected

            # Compare element-wise with a float tolerance. Asserting on the
            # array returned by pc.equal() would only test that it is
            # non-empty.
            assert result.to_pylist() == pytest.approx(
                wanted.to_pylist()
            ), f"Expected {wanted} but got {result} for invert={invert}"
|
||||
|
||||
@@ -4,6 +4,7 @@ import random
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
@@ -316,6 +317,55 @@ def test_rrf_reranker(tmp_path, use_tantivy):
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
|
||||
|
||||
def test_rrf_reranker_distance():
    """Hybrid results must carry the same ``_distance`` / ``_score`` values
    that the standalone vector and FTS searches report for the same row ids.

    The table holds 1024 random 32-dimensional vectors, each with the text
    "hello", and is queried with an all-zero vector and the text "hello".
    """
    dim = 32
    num_rows = 1024
    query_vector = [0.0] * dim

    def by_rowid(rows, field):
        # Map each row id to the value of `field` in that row.
        return {row["_rowid"]: row[field] for row in rows}

    vectors = pa.FixedSizeListArray.from_arrays(
        pc.random(dim * num_rows).cast(pa.float32()), dim
    )
    data = pa.table({"vector": vectors, "text": pa.array(["hello"] * num_rows)})

    table = lancedb.connect("memory://").create_table("test", data)
    table.create_index(num_partitions=1, num_sub_vectors=2)
    table.create_fts_index("text", use_tantivy=False)

    hybrid_query = (
        table.search(query_type="hybrid")
        .vector(query_vector)
        .text("hello")
        .with_row_id(True)
        .rerank(RRFReranker(return_score="all"))
    )
    hybrid_rows = hybrid_query.to_list()
    hybrid_distances = by_rowid(hybrid_rows, "_distance")
    hybrid_scores = by_rowid(hybrid_rows, "_score")

    vector_rows = table.search(query_vector).with_row_id(True).to_list()
    vector_distances = by_rowid(vector_rows, "_distance")

    fts_rows = table.search("hello", query_type="fts").with_row_id(True).to_list()
    fts_scores = by_rowid(fts_rows, "_score")

    # Row ids returned by both the hybrid and the plain vector search.
    shared_vector_ids = [
        rowid for rowid in hybrid_distances if rowid in vector_distances
    ]
    for rowid in shared_vector_ids:
        assert hybrid_distances[rowid] == vector_distances[rowid], "Distance mismatch"
    assert shared_vector_ids, "No results matched between hybrid and vector search"

    # Row ids returned by both the hybrid and the FTS search with a real score.
    shared_fts_ids = [
        rowid for rowid in hybrid_scores if fts_scores.get(rowid) is not None
    ]
    for rowid in shared_fts_ids:
        assert hybrid_scores[rowid] == fts_scores[rowid], "Score mismatch"
    assert shared_fts_ids, "No results matched between hybrid and fts search"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user