mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat(python): Reranker DX improvements (#904)
- Most users might not know how to use the `QueryBuilder` object. Instead, we should just pass the query string. - Add new rerankers: ColBERT and OpenAI.
This commit is contained in:
@@ -7,7 +7,12 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
|
||||
from lancedb.rerankers import (
|
||||
CohereReranker,
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
|
||||
return table, MyTable
|
||||
|
||||
|
||||
## These tests are pretty loose, we should also check for correctness
|
||||
def test_linear_combination(tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CohereReranker())
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CrossEncoderReranker())
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
    """Exercise ``ColbertReranker`` on hybrid searches over the test table.

    The test is skipped when ``transformers`` cannot be imported.  It first
    checks that a hybrid text search reranked with ``normalize="score"``
    equals the same search reranked without an explicit ``normalize``
    argument, then runs a search with an explicit ``(vector, text)`` query
    and checks the number of rows and the ordering of ``_relevance_score``.
    """
    pytest.importorskip("transformers")
    table, schema = get_test_table(tmp_path)
    text = "Our father who art in heaven"

    # Same text query twice: once with ``normalize`` spelled out, once without.
    # NOTE(review): the equality below presumably holds because "score" is the
    # default for ``normalize`` -- confirm against ``rerank``.
    explicit_builder = table.search(text, query_type="hybrid")
    explicit_rows = explicit_builder.rerank(
        normalize="score", reranker=ColbertReranker()
    ).to_pydantic(schema)

    default_builder = table.search(text, query_type="hybrid")
    default_rows = default_builder.rerank(reranker=ColbertReranker()).to_pydantic(
        schema
    )
    assert explicit_rows == default_rows

    # test explicit hybrid query
    # The vector of the first stored row is paired with the text query.
    first_vector = table.to_pandas()["vector"][0]
    tuple_builder = table.search((first_vector, text)).limit(30)
    reranked = tuple_builder.rerank(reranker=ColbertReranker()).to_arrow()

    assert len(reranked) == 30

    # Consecutive differences must never be positive, i.e. scores never rise.
    scores = reranked.column("_relevance_score").to_numpy()
    assert np.all(np.diff(scores) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
    """Exercise ``OpenaiReranker`` on hybrid searches over the test table.

    The test is skipped when the ``OPENAI_API_KEY`` environment variable is
    unset or when ``openai`` cannot be imported.  It first checks that a
    hybrid text search reranked with ``normalize="score"`` equals the same
    search reranked without an explicit ``normalize`` argument, then runs a
    search with an explicit ``(vector, text)`` query and checks the number of
    rows and the ordering of ``_relevance_score``.
    """
    pytest.importorskip("openai")
    table, schema = get_test_table(tmp_path)
    text = "Our father who art in heaven"

    # Same text query twice: once with ``normalize`` spelled out, once without.
    # NOTE(review): the equality below presumably holds because "score" is the
    # default for ``normalize`` -- confirm against ``rerank``.
    explicit_builder = table.search(text, query_type="hybrid")
    explicit_rows = explicit_builder.rerank(
        normalize="score", reranker=OpenaiReranker()
    ).to_pydantic(schema)

    default_builder = table.search(text, query_type="hybrid")
    default_rows = default_builder.rerank(reranker=OpenaiReranker()).to_pydantic(
        schema
    )
    assert explicit_rows == default_rows

    # test explicit hybrid query
    # The vector of the first stored row is paired with the text query.
    first_vector = table.to_pandas()["vector"][0]
    tuple_builder = table.search((first_vector, text)).limit(30)
    reranked = tuple_builder.rerank(reranker=OpenaiReranker()).to_arrow()

    assert len(reranked) == 30

    # Consecutive differences must never be positive, i.e. scores never rise.
    scores = reranked.column("_relevance_score").to_numpy()
    assert np.all(np.diff(scores) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
|
||||
|
||||
Reference in New Issue
Block a user