mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 14:29:56 +00:00
feat(python): multi-vector reranking support (#1481)
Currently targeting the following usage:
```
from lancedb.rerankers import CrossEncoderReranker
reranker = CrossEncoderReranker()
query = "hello"
res1 = table.search(query, vector_column_name="vector").limit(3)
res2 = table.search(query, vector_column_name="text_vector").limit(3)
res3 = table.search(query, vector_column_name="meta_vector").limit(3)
reranked = reranker.rerank_multivector(
[res1, res2, res3],
deduplicate=True,
query=query # some reranker models need query
)
```
- This implements rerank_multivector function in the base reranker so
that all rerankers that implement rerank_vector will automatically have
multivector reranking support
- Special case for RRF reranker that just uses its existing
rerank_hybrid fcn to multi-vector reranking.
---------
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
@@ -25,10 +26,13 @@ def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
meta: str = meta_emb.SourceField()
|
||||
meta_vector: Vector(meta_emb.ndims()) = meta_emb.VectorField()
|
||||
|
||||
# Initialize the table using the schema
|
||||
table = LanceTable.create(
|
||||
@@ -77,7 +81,12 @@ def get_test_table(tmp_path):
|
||||
]
|
||||
|
||||
# Add the phrases and vectors to the table
|
||||
table.add([{"text": p} for p in phrases])
|
||||
table.add(
|
||||
[
|
||||
{"text": p, "meta": phrases[random.randint(0, len(phrases) - 1)]}
|
||||
for p in phrases
|
||||
]
|
||||
)
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text")
|
||||
@@ -88,12 +97,12 @@ def get_test_table(tmp_path):
|
||||
def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
# Hybrid search setting
|
||||
result1 = (
|
||||
table.search(query, query_type="hybrid")
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search(query, query_type="hybrid")
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
@@ -101,7 +110,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
@@ -116,11 +125,16 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
result = (
|
||||
table.search(query, vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
@@ -129,11 +143,13 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
table.search(query_vector, vector_column_name="vector").rerank(
|
||||
reranker=reranker
|
||||
).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
table.search(query, query_type="fts", vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
@@ -141,22 +157,48 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
rs2 = (
|
||||
table.search(query, vector_column_name="meta_vector")
|
||||
.limit(10)
|
||||
.with_row_id(True)
|
||||
)
|
||||
result = reranker.rerank_multivector([rs1, rs2], query)
|
||||
assert len(result) == 20
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
|
||||
def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
table.search(
|
||||
"Our father who art in heaven",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
)
|
||||
.rerank(normalize="score")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = ( # noqa
|
||||
table.search("Our father who art in heaven.", query_type="hybrid")
|
||||
table.search(
|
||||
"Our father who art in heaven.",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
)
|
||||
.rerank(normalize="rank")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result3 = table.search(
|
||||
"Our father who art in heaven..", query_type="hybrid"
|
||||
"Our father who art in heaven..",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
).to_pydantic(schema)
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
@@ -164,7 +206,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user