feat(python): multi-vector reranking support (#1481)

Currently targeting the following usage:
```
from lancedb.rerankers import CrossEncoderReranker

reranker = CrossEncoderReranker()

query = "hello"

res1 = table.search(query, vector_column_name="vector").limit(3)
res2 = table.search(query, vector_column_name="text_vector").limit(3)
res3 = table.search(query, vector_column_name="meta_vector").limit(3)

reranked = reranker.rerank_multivector(
               [res1, res2, res3],  
              deduplicate=True,
              query=query # some reranker models need query
)
```
- This implements rerank_multivector function in the base reranker so
that all rerankers that implement rerank_vector will automatically have
multivector reranking support
- Special case for RRF reranker that just uses its existing
rerank_hybrid fcn to multi-vector reranking.

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Ayush Chaurasia
2024-08-07 01:45:46 +05:30
committed by GitHub
parent d07d7a5980
commit 4769d8eb76
9 changed files with 196 additions and 24 deletions

View File

@@ -1,9 +1,13 @@
from abc import ABC, abstractmethod
from packaging.version import Version
from typing import Union, List, TYPE_CHECKING
import numpy as np
import pyarrow as pa
if TYPE_CHECKING:
from ..table import LanceVectorQueryBuilder
ARROW_VERSION = Version(pa.__version__)
@@ -130,12 +134,94 @@ class Reranker(ABC):
combined = pa.concat_tables(
[vector_results, fts_results], **self._concat_tables_args
)
row_id = combined.column("_rowid")
# deduplicate
mask = np.full((combined.shape[0]), False)
_, mask_indices = np.unique(np.array(row_id), return_index=True)
mask[mask_indices] = True
combined = combined.filter(mask=mask)
combined = self._deduplicate(combined)
return combined
def rerank_multivector(
self,
vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
query: Union[str, None], # Some rerankers might not need the query
deduplicate: bool = False,
):
"""
This is a rerank function that receives the results from multiple
vector searches. For example, this can be used to combine the
results of two vector searches with different embeddings.
Parameters
----------
vector_results : List[pa.Table] or List[LanceVectorQueryBuilder]
The results from the vector search. Either accepts the query builder
if the results haven't been executed yet or the results in arrow format.
query : str or None,
The input query. Some rerankers might not need the query to rerank.
In that case, it can be set to None explicitly. This is inteded to
be handled by the reranker implementations.
deduplicate : bool, optional
Whether to deduplicate the results based on the `_rowid` column,
by default False. Requires `_rowid` to be present in the results.
Returns
-------
pa.Table
The reranked results
"""
vector_results = (
[vector_results] if not isinstance(vector_results, list) else vector_results
)
# Make sure all elements are of the same type
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
raise ValueError(
"All elements in vector_results should be of the same type"
)
# avoids circular import
if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
vector_results = [result.to_arrow() for result in vector_results]
elif not isinstance(vector_results[0], pa.Table):
raise ValueError(
"vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
)
combined = pa.concat_tables(vector_results, **self._concat_tables_args)
reranked = self.rerank_vector(query, combined)
# TODO: Allow custom deduplicators here.
# currently, this'll just keep the first instance.
if deduplicate:
if "_rowid" not in combined.column_names:
raise ValueError(
"'_rowid' is required for deduplication. \
add _rowid to search results like this: \
`search().with_row_id(True)`"
)
reranked = self._deduplicate(reranked)
return reranked
def _deduplicate(self, table: pa.Table):
"""
Deduplicate the table based on the `_rowid` column.
"""
row_id = table.column("_rowid")
# deduplicate
mask = np.full((table.shape[0]), False)
_, mask_indices = np.unique(np.array(row_id), return_index=True)
mask[mask_indices] = True
deduped_table = table.filter(mask=mask)
return deduped_table
def _keep_relevance_score(self, combined_results: pa.Table):
if self.score == "relevance":
if "score" in combined_results.column_names:
combined_results = combined_results.drop_columns(["score"])
if "_distance" in combined_results.column_names:
combined_results = combined_results.drop_columns(["_distance"])
return combined_results

View File

@@ -88,7 +88,7 @@ class CohereReranker(Reranker):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for cohere reranker"

View File

@@ -73,7 +73,7 @@ class ColbertReranker(Reranker):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"

View File

@@ -66,7 +66,7 @@ class CrossEncoderReranker(Reranker):
combined_results = self._rerank(combined_results, query)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for CrossEncoderReranker"

View File

@@ -92,7 +92,7 @@ class JinaReranker(Reranker):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for JinaReranker"

View File

@@ -103,7 +103,7 @@ class LinearCombinationReranker(Reranker):
[("_relevance_score", "descending")]
)
if self.score == "relevance":
tbl = tbl.drop_columns(["score", "_distance"])
tbl = self._keep_relevance_score(tbl)
return tbl
def _combine_score(self, score1, score2):

View File

@@ -84,7 +84,7 @@ class OpenaiReranker(Reranker):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"

View File

@@ -1,8 +1,12 @@
from typing import Union, List, TYPE_CHECKING
import pyarrow as pa
from collections import defaultdict
from .base import Reranker
if TYPE_CHECKING:
from ..table import LanceVectorQueryBuilder
class RRFReranker(Reranker):
"""
@@ -55,6 +59,46 @@ class RRFReranker(Reranker):
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
combined_results = self._keep_relevance_score(combined_results)
return combined_results
def rerank_multivector(
self,
vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
query: str = None,
deduplicate: bool = True, # noqa: F821 # TODO: automatically deduplicates
):
"""
Overridden method to rerank the results from multiple vector searches.
This leverages the RRF hybrid reranking algorithm to combine the
results from multiple vector searches as it doesn't support reranking
vector results individually.
"""
# Make sure all elements are of the same type
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
raise ValueError(
"All elements in vector_results should be of the same type"
)
# avoid circular import
if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
vector_results = [result.to_arrow() for result in vector_results]
elif not isinstance(vector_results[0], pa.Table):
raise ValueError(
"vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
)
# _rowid is required for RRF reranking
if not all("_rowid" in result.column_names for result in vector_results):
raise ValueError(
"'_rowid' is required for deduplication. \
add _rowid to search results like this: \
`search().with_row_id(True)`"
)
combined = pa.concat_tables(vector_results, **self._concat_tables_args)
empty_table = pa.Table.from_arrays([], names=[])
reranked = self.rerank_hybrid(query, combined, empty_table)
return reranked

View File

@@ -1,4 +1,5 @@
import os
import random
import lancedb
import numpy as np
@@ -25,10 +26,13 @@ def get_test_table(tmp_path):
db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
meta: str = meta_emb.SourceField()
meta_vector: Vector(meta_emb.ndims()) = meta_emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
@@ -77,7 +81,12 @@ def get_test_table(tmp_path):
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
table.add(
[
{"text": p, "meta": phrases[random.randint(0, len(phrases) - 1)]}
for p in phrases
]
)
# Create a fts index
table.create_fts_index("text")
@@ -88,12 +97,12 @@ def get_test_table(tmp_path):
def _run_test_reranker(reranker, table, query, query_vector, schema):
# Hybrid search setting
result1 = (
table.search(query, query_type="hybrid")
table.search(query, query_type="hybrid", vector_column_name="vector")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search(query, query_type="hybrid")
table.search(query, query_type="hybrid", vector_column_name="vector")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
@@ -101,7 +110,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search((query_vector, query), vector_column_name="vector")
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
@@ -116,11 +125,16 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
result = (
table.search(query, vector_column_name="vector")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
table.search(query_vector, vector_column_name="vector")
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
@@ -129,11 +143,13 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
table.search(query_vector, vector_column_name="vector").rerank(
reranker=reranker
).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
table.search(query, query_type="fts", vector_column_name="vector")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
@@ -141,22 +157,48 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Multi-vector search setting
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
rs2 = (
table.search(query, vector_column_name="meta_vector")
.limit(10)
.with_row_id(True)
)
result = reranker.rerank_multivector([rs1, rs2], query)
assert len(result) == 20
result_deduped = reranker.rerank_multivector(
[rs1, rs2, rs1], query, deduplicate=True
)
assert len(result_deduped) < 20
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
assert len(result) == 20 and result == result_arrow
def _run_test_hybrid_reranker(reranker, tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
table.search(
"Our father who art in heaven",
query_type="hybrid",
vector_column_name="vector",
)
.rerank(normalize="score")
.to_pydantic(schema)
)
result2 = ( # noqa
table.search("Our father who art in heaven.", query_type="hybrid")
table.search(
"Our father who art in heaven.",
query_type="hybrid",
vector_column_name="vector",
)
.rerank(normalize="rank")
.to_pydantic(schema)
)
result3 = table.search(
"Our father who art in heaven..", query_type="hybrid"
"Our father who art in heaven..",
query_type="hybrid",
vector_column_name="vector",
).to_pydantic(schema)
assert result1 == result3 # 2 & 3 should be the same as they use score as score
@@ -164,7 +206,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search((query_vector, query), vector_column_name="vector")
.limit(30)
.rerank(normalize="score")
.to_arrow()