mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 11:52:57 +00:00
chore: choose appropriate args for concat_table based on pyarrow version & refactor reranker tests (#1455)
This commit is contained in:
@@ -35,7 +35,7 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
def _compute_one_embedding(self, row):
|
def _compute_one_embedding(self, row):
|
||||||
emb = np.array([float(hash(c)) for c in row[:10]])
|
emb = np.array([float(hash(c)) for c in row[:10]])
|
||||||
emb /= np.linalg.norm(emb)
|
emb /= np.linalg.norm(emb)
|
||||||
return emb
|
return emb if len(emb) == 10 else [0] * 10
|
||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
return 10
|
return 10
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
|
ARROW_VERSION = Version(pa.__version__)
|
||||||
|
|
||||||
|
|
||||||
class Reranker(ABC):
|
class Reranker(ABC):
|
||||||
def __init__(self, return_score: str = "relevance"):
|
def __init__(self, return_score: str = "relevance"):
|
||||||
@@ -23,6 +26,11 @@ class Reranker(ABC):
|
|||||||
if return_score not in ["relevance", "all"]:
|
if return_score not in ["relevance", "all"]:
|
||||||
raise ValueError("score must be either 'relevance' or 'all'")
|
raise ValueError("score must be either 'relevance' or 'all'")
|
||||||
self.score = return_score
|
self.score = return_score
|
||||||
|
# Set the merge args based on the arrow version here to avoid checking it at
|
||||||
|
# each query
|
||||||
|
self._concat_tables_args = {"promote_options": "default"}
|
||||||
|
if ARROW_VERSION.major <= 13:
|
||||||
|
self._concat_tables_args = {"promote": True}
|
||||||
|
|
||||||
def rerank_vector(
|
def rerank_vector(
|
||||||
self,
|
self,
|
||||||
@@ -119,7 +127,9 @@ class Reranker(ABC):
|
|||||||
fts_results : pa.Table
|
fts_results : pa.Table
|
||||||
The results from the FTS search
|
The results from the FTS search
|
||||||
"""
|
"""
|
||||||
combined = pa.concat_tables([vector_results, fts_results], promote=True)
|
combined = pa.concat_tables(
|
||||||
|
[vector_results, fts_results], **self._concat_tables_args
|
||||||
|
)
|
||||||
row_id = combined.column("_rowid")
|
row_id = combined.column("_rowid")
|
||||||
|
|
||||||
# deduplicate
|
# deduplicate
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from lancedb.rerankers import (
|
|||||||
ColbertReranker,
|
ColbertReranker,
|
||||||
CrossEncoderReranker,
|
CrossEncoderReranker,
|
||||||
OpenaiReranker,
|
OpenaiReranker,
|
||||||
|
JinaReranker,
|
||||||
)
|
)
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
@@ -82,6 +83,63 @@ def get_test_table(tmp_path):
|
|||||||
return table, MyTable
|
return table, MyTable
|
||||||
|
|
||||||
|
|
||||||
|
def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||||
|
# Hybrid search setting
|
||||||
|
result1 = (
|
||||||
|
table.search(query, query_type="hybrid")
|
||||||
|
.rerank(normalize="score", reranker=reranker)
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = (
|
||||||
|
table.search(query, query_type="hybrid")
|
||||||
|
.rerank(reranker=reranker)
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query))
|
||||||
|
.limit(30)
|
||||||
|
.rerank(reranker=reranker)
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
err = (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||||
|
|
||||||
|
# Vector search setting
|
||||||
|
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||||
|
assert len(result) == 30
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||||
|
result_explicit = (
|
||||||
|
table.search(query_vector)
|
||||||
|
.rerank(reranker=reranker, query_string=query)
|
||||||
|
.limit(30)
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
assert len(result_explicit) == 30
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError
|
||||||
|
): # This raises an error because vector query is provided without reanking query
|
||||||
|
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||||
|
|
||||||
|
# FTS search setting
|
||||||
|
result = (
|
||||||
|
table.search(query, query_type="fts")
|
||||||
|
.rerank(reranker=reranker)
|
||||||
|
.limit(30)
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
assert len(result) > 0
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||||
|
|
||||||
|
|
||||||
def test_linear_combination(tmp_path):
|
def test_linear_combination(tmp_path):
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
# The default reranker
|
# The default reranker
|
||||||
@@ -126,185 +184,21 @@ def test_cohere_reranker(tmp_path):
|
|||||||
pytest.importorskip("cohere")
|
pytest.importorskip("cohere")
|
||||||
reranker = CohereReranker()
|
reranker = CohereReranker()
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
# Hybrid search setting
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=CohereReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_cross_encoder_reranker(tmp_path):
|
def test_cross_encoder_reranker(tmp_path):
|
||||||
pytest.importorskip("sentence_transformers")
|
pytest.importorskip("sentence_transformers")
|
||||||
reranker = CrossEncoderReranker()
|
reranker = CrossEncoderReranker()
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
result1 = (
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query), query_type="hybrid")
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_colbert_reranker(tmp_path):
|
def test_colbert_reranker(tmp_path):
|
||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
reranker = ColbertReranker()
|
reranker = ColbertReranker()
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
result1 = (
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
# test explicit hybrid query
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@@ -314,58 +208,14 @@ def test_openai_reranker(tmp_path):
|
|||||||
pytest.importorskip("openai")
|
pytest.importorskip("openai")
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
reranker = OpenaiReranker()
|
reranker = OpenaiReranker()
|
||||||
result1 = (
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=OpenaiReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
# test explicit hybrid query
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
|
||||||
err = (
|
)
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
def test_jina_reranker(tmp_path):
|
||||||
"represents the relevance of the result to the query & should "
|
pytest.importorskip("jina")
|
||||||
"be descending."
|
table, schema = get_test_table(tmp_path)
|
||||||
)
|
reranker = JinaReranker()
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|||||||
Reference in New Issue
Block a user