mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 20:32:59 +00:00
feat(python): multi-vector reranking support (#1481)
Currently targeting the following usage:
```
from lancedb.rerankers import CrossEncoderReranker
reranker = CrossEncoderReranker()
query = "hello"
res1 = table.search(query, vector_column_name="vector").limit(3)
res2 = table.search(query, vector_column_name="text_vector").limit(3)
res3 = table.search(query, vector_column_name="meta_vector").limit(3)
reranked = reranker.rerank_multivector(
[res1, res2, res3],
deduplicate=True,
query=query # some reranker models need query
)
```
- This implements rerank_multivector function in the base reranker so
that all rerankers that implement rerank_vector will automatically have
multivector reranking support
- Special case for RRF reranker that just uses its existing
rerank_hybrid fcn to multi-vector reranking.
---------
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from packaging.version import Version
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..table import LanceVectorQueryBuilder
|
||||
|
||||
ARROW_VERSION = Version(pa.__version__)
|
||||
|
||||
|
||||
@@ -130,12 +134,94 @@ class Reranker(ABC):
|
||||
combined = pa.concat_tables(
|
||||
[vector_results, fts_results], **self._concat_tables_args
|
||||
)
|
||||
row_id = combined.column("_rowid")
|
||||
|
||||
# deduplicate
|
||||
mask = np.full((combined.shape[0]), False)
|
||||
_, mask_indices = np.unique(np.array(row_id), return_index=True)
|
||||
mask[mask_indices] = True
|
||||
combined = combined.filter(mask=mask)
|
||||
combined = self._deduplicate(combined)
|
||||
|
||||
return combined
|
||||
|
||||
def rerank_multivector(
|
||||
self,
|
||||
vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
|
||||
query: Union[str, None], # Some rerankers might not need the query
|
||||
deduplicate: bool = False,
|
||||
):
|
||||
"""
|
||||
This is a rerank function that receives the results from multiple
|
||||
vector searches. For example, this can be used to combine the
|
||||
results of two vector searches with different embeddings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector_results : List[pa.Table] or List[LanceVectorQueryBuilder]
|
||||
The results from the vector search. Either accepts the query builder
|
||||
if the results haven't been executed yet or the results in arrow format.
|
||||
query : str or None,
|
||||
The input query. Some rerankers might not need the query to rerank.
|
||||
In that case, it can be set to None explicitly. This is inteded to
|
||||
be handled by the reranker implementations.
|
||||
deduplicate : bool, optional
|
||||
Whether to deduplicate the results based on the `_rowid` column,
|
||||
by default False. Requires `_rowid` to be present in the results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The reranked results
|
||||
"""
|
||||
vector_results = (
|
||||
[vector_results] if not isinstance(vector_results, list) else vector_results
|
||||
)
|
||||
|
||||
# Make sure all elements are of the same type
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
"All elements in vector_results should be of the same type"
|
||||
)
|
||||
|
||||
# avoids circular import
|
||||
if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
|
||||
vector_results = [result.to_arrow() for result in vector_results]
|
||||
elif not isinstance(vector_results[0], pa.Table):
|
||||
raise ValueError(
|
||||
"vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
|
||||
)
|
||||
|
||||
combined = pa.concat_tables(vector_results, **self._concat_tables_args)
|
||||
|
||||
reranked = self.rerank_vector(query, combined)
|
||||
|
||||
# TODO: Allow custom deduplicators here.
|
||||
# currently, this'll just keep the first instance.
|
||||
if deduplicate:
|
||||
if "_rowid" not in combined.column_names:
|
||||
raise ValueError(
|
||||
"'_rowid' is required for deduplication. \
|
||||
add _rowid to search results like this: \
|
||||
`search().with_row_id(True)`"
|
||||
)
|
||||
reranked = self._deduplicate(reranked)
|
||||
|
||||
return reranked
|
||||
|
||||
def _deduplicate(self, table: pa.Table):
|
||||
"""
|
||||
Deduplicate the table based on the `_rowid` column.
|
||||
"""
|
||||
row_id = table.column("_rowid")
|
||||
|
||||
# deduplicate
|
||||
mask = np.full((table.shape[0]), False)
|
||||
_, mask_indices = np.unique(np.array(row_id), return_index=True)
|
||||
mask[mask_indices] = True
|
||||
deduped_table = table.filter(mask=mask)
|
||||
|
||||
return deduped_table
|
||||
|
||||
def _keep_relevance_score(self, combined_results: pa.Table):
|
||||
if self.score == "relevance":
|
||||
if "score" in combined_results.column_names:
|
||||
combined_results = combined_results.drop_columns(["score"])
|
||||
if "_distance" in combined_results.column_names:
|
||||
combined_results = combined_results.drop_columns(["_distance"])
|
||||
return combined_results
|
||||
|
||||
@@ -88,7 +88,7 @@ class CohereReranker(Reranker):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for cohere reranker"
|
||||
|
||||
@@ -73,7 +73,7 @@ class ColbertReranker(Reranker):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
|
||||
@@ -66,7 +66,7 @@ class CrossEncoderReranker(Reranker):
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
# sort the results by _score
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for CrossEncoderReranker"
|
||||
|
||||
@@ -92,7 +92,7 @@ class JinaReranker(Reranker):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for JinaReranker"
|
||||
|
||||
@@ -103,7 +103,7 @@ class LinearCombinationReranker(Reranker):
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
if self.score == "relevance":
|
||||
tbl = tbl.drop_columns(["score", "_distance"])
|
||||
tbl = self._keep_relevance_score(tbl)
|
||||
return tbl
|
||||
|
||||
def _combine_score(self, score1, score2):
|
||||
|
||||
@@ -84,7 +84,7 @@ class OpenaiReranker(Reranker):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
import pyarrow as pa
|
||||
|
||||
from collections import defaultdict
|
||||
from .base import Reranker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..table import LanceVectorQueryBuilder
|
||||
|
||||
|
||||
class RRFReranker(Reranker):
|
||||
"""
|
||||
@@ -55,6 +59,46 @@ class RRFReranker(Reranker):
|
||||
)
|
||||
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_multivector(
|
||||
self,
|
||||
vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
|
||||
query: str = None,
|
||||
deduplicate: bool = True, # noqa: F821 # TODO: automatically deduplicates
|
||||
):
|
||||
"""
|
||||
Overridden method to rerank the results from multiple vector searches.
|
||||
This leverages the RRF hybrid reranking algorithm to combine the
|
||||
results from multiple vector searches as it doesn't support reranking
|
||||
vector results individually.
|
||||
"""
|
||||
# Make sure all elements are of the same type
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
"All elements in vector_results should be of the same type"
|
||||
)
|
||||
|
||||
# avoid circular import
|
||||
if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
|
||||
vector_results = [result.to_arrow() for result in vector_results]
|
||||
elif not isinstance(vector_results[0], pa.Table):
|
||||
raise ValueError(
|
||||
"vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
|
||||
)
|
||||
|
||||
# _rowid is required for RRF reranking
|
||||
if not all("_rowid" in result.column_names for result in vector_results):
|
||||
raise ValueError(
|
||||
"'_rowid' is required for deduplication. \
|
||||
add _rowid to search results like this: \
|
||||
`search().with_row_id(True)`"
|
||||
)
|
||||
|
||||
combined = pa.concat_tables(vector_results, **self._concat_tables_args)
|
||||
empty_table = pa.Table.from_arrays([], names=[])
|
||||
reranked = self.rerank_hybrid(query, combined, empty_table)
|
||||
|
||||
return reranked
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
@@ -25,10 +26,13 @@ def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
meta: str = meta_emb.SourceField()
|
||||
meta_vector: Vector(meta_emb.ndims()) = meta_emb.VectorField()
|
||||
|
||||
# Initialize the table using the schema
|
||||
table = LanceTable.create(
|
||||
@@ -77,7 +81,12 @@ def get_test_table(tmp_path):
|
||||
]
|
||||
|
||||
# Add the phrases and vectors to the table
|
||||
table.add([{"text": p} for p in phrases])
|
||||
table.add(
|
||||
[
|
||||
{"text": p, "meta": phrases[random.randint(0, len(phrases) - 1)]}
|
||||
for p in phrases
|
||||
]
|
||||
)
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text")
|
||||
@@ -88,12 +97,12 @@ def get_test_table(tmp_path):
|
||||
def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
# Hybrid search setting
|
||||
result1 = (
|
||||
table.search(query, query_type="hybrid")
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search(query, query_type="hybrid")
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
@@ -101,7 +110,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
@@ -116,11 +125,16 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
result = (
|
||||
table.search(query, vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
@@ -129,11 +143,13 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
table.search(query_vector, vector_column_name="vector").rerank(
|
||||
reranker=reranker
|
||||
).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
table.search(query, query_type="fts", vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
@@ -141,22 +157,48 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
rs2 = (
|
||||
table.search(query, vector_column_name="meta_vector")
|
||||
.limit(10)
|
||||
.with_row_id(True)
|
||||
)
|
||||
result = reranker.rerank_multivector([rs1, rs2], query)
|
||||
assert len(result) == 20
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
|
||||
def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
table.search(
|
||||
"Our father who art in heaven",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
)
|
||||
.rerank(normalize="score")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = ( # noqa
|
||||
table.search("Our father who art in heaven.", query_type="hybrid")
|
||||
table.search(
|
||||
"Our father who art in heaven.",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
)
|
||||
.rerank(normalize="rank")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result3 = table.search(
|
||||
"Our father who art in heaven..", query_type="hybrid"
|
||||
"Our father who art in heaven..",
|
||||
query_type="hybrid",
|
||||
vector_column_name="vector",
|
||||
).to_pydantic(schema)
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
@@ -164,7 +206,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user