fix(python): preserve original distance and score in hybrid queries (#2061)

Fixes #2031

When we do hybrid search, we normalize the scores. We do this
calculation in-place, because the Rerankers expect the `_distance` and
`_score` columns to be the normalized ones. So I've changed the logic so
that we restore the original distance and scores by matching on row ids.
This commit is contained in:
Will Jones
2025-01-23 13:54:26 -08:00
committed by GitHub
parent 52b79d2b1e
commit 28e1b70e4b
3 changed files with 126 additions and 24 deletions

View File

@@ -3,7 +3,9 @@
import lancedb
from lancedb.query import LanceHybridQueryBuilder
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
@@ -110,3 +112,23 @@ async def test_explain_plan(table: AsyncTable):
assert "KNNVectorDistance" in plan
assert "FTS Search Plan" in plan
assert "LanceScan" in plan
def test_normalize_scores():
cases = [
(pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),
(pa.array([2.0, 10.0, 20.0]), pa.array([0.0, 8.0 / 18.0, 1.0])),
(pa.array([0.0, 0.0, 0.0]), pa.array([0.0, 0.0, 0.0])),
(pa.array([10.0, 9.9999999999999]), pa.array([0.0, 0.0])),
]
for input, expected in cases:
for invert in [True, False]:
result = LanceHybridQueryBuilder._normalize_scores(input, invert)
if invert:
expected = pc.subtract(1.0, expected)
assert pc.equal(
result, expected
), f"Expected {expected} but got {result} for invert={invert}"

View File

@@ -4,6 +4,7 @@ import random
import lancedb
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
@@ -316,6 +317,55 @@ def test_rrf_reranker(tmp_path, use_tantivy):
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
def test_rrf_reranker_distance():
data = pa.table(
{
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(32 * 1024).cast(pa.float32()), 32
),
"text": pa.array(["hello"] * 1024),
}
)
db = lancedb.connect("memory://")
table = db.create_table("test", data)
table.create_index(num_partitions=1, num_sub_vectors=2)
table.create_fts_index("text", use_tantivy=False)
reranker = RRFReranker(return_score="all")
hybrid_results = (
table.search(query_type="hybrid")
.vector([0.0] * 32)
.text("hello")
.with_row_id(True)
.rerank(reranker)
.to_list()
)
hybrid_distances = {row["_rowid"]: row["_distance"] for row in hybrid_results}
hybrid_scores = {row["_rowid"]: row["_score"] for row in hybrid_results}
vector_results = table.search([0.0] * 32).with_row_id(True).to_list()
vector_distances = {row["_rowid"]: row["_distance"] for row in vector_results}
fts_results = table.search("hello", query_type="fts").with_row_id(True).to_list()
fts_scores = {row["_rowid"]: row["_score"] for row in fts_results}
found_match = False
for rowid, distance in hybrid_distances.items():
if rowid in vector_distances:
found_match = True
assert distance == vector_distances[rowid], "Distance mismatch"
assert found_match, "No results matched between hybrid and vector search"
found_match = False
for rowid, score in hybrid_scores.items():
if rowid in fts_scores and fts_scores[rowid] is not None:
found_match = True
assert score == fts_scores[rowid], "Score mismatch"
assert found_match, "No results matched between hybrid and fts search"
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)