From 4769d8eb76db8468feb2846dacb21812e4b4bdea Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 7 Aug 2024 01:45:46 +0530 Subject: [PATCH] feat(python): multi-vector reranking support (#1481) Currently targeting the following usage: ``` from lancedb.rerankers import CrossEncoderReranker reranker = CrossEncoderReranker() query = "hello" res1 = table.search(query, vector_column_name="vector").limit(3) res2 = table.search(query, vector_column_name="text_vector").limit(3) res3 = table.search(query, vector_column_name="meta_vector").limit(3) reranked = reranker.rerank_multivector( [res1, res2, res3], deduplicate=True, query=query # some reranker models need query ) ``` - This implements rerank_multivector function in the base reranker so that all rerankers that implement rerank_vector will automatically have multivector reranking support - Special case for RRF reranker that just uses its existing rerank_hybrid fcn to multi-vector reranking. --------- Co-authored-by: Weston Pace --- python/python/lancedb/rerankers/base.py | 96 ++++++++++++++++++- python/python/lancedb/rerankers/cohere.py | 2 +- python/python/lancedb/rerankers/colbert.py | 2 +- .../python/lancedb/rerankers/cross_encoder.py | 2 +- python/python/lancedb/rerankers/jinaai.py | 2 +- .../lancedb/rerankers/linear_combination.py | 2 +- python/python/lancedb/rerankers/openai.py | 2 +- python/python/lancedb/rerankers/rrf.py | 46 ++++++++- python/python/tests/test_rerankers.py | 66 ++++++++++--- 9 files changed, 196 insertions(+), 24 deletions(-) diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 75a767f4..85536e1c 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod from packaging.version import Version +from typing import Union, List, TYPE_CHECKING import numpy as np import pyarrow as pa +if TYPE_CHECKING: + from ..table import LanceVectorQueryBuilder + ARROW_VERSION = Version(pa.__version__) @@ -130,12 +134,94 @@ class Reranker(ABC): combined = pa.concat_tables( [vector_results, fts_results], **self._concat_tables_args ) - row_id = combined.column("_rowid") # deduplicate - mask = np.full((combined.shape[0]), False) - _, mask_indices = np.unique(np.array(row_id), return_index=True) - mask[mask_indices] = True - combined = combined.filter(mask=mask) + combined = self._deduplicate(combined) return combined + + def rerank_multivector( + self, + vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]], + query: Union[str, None], # Some rerankers might not need the query + deduplicate: bool = False, + ): + """ + This is a rerank function that receives the results from multiple + vector searches. For example, this can be used to combine the + results of two vector searches with different embeddings. + + Parameters + ---------- + vector_results : List[pa.Table] or List[LanceVectorQueryBuilder] + The results from the vector search. Either accepts the query builder + if the results haven't been executed yet or the results in arrow format. + query : str or None, + The input query. Some rerankers might not need the query to rerank. + In that case, it can be set to None explicitly. This is inteded to + be handled by the reranker implementations. + deduplicate : bool, optional + Whether to deduplicate the results based on the `_rowid` column, + by default False. Requires `_rowid` to be present in the results. + + Returns + ------- + pa.Table + The reranked results + """ + vector_results = ( + [vector_results] if not isinstance(vector_results, list) else vector_results + ) + + # Make sure all elements are of the same type + if not all(isinstance(v, type(vector_results[0])) for v in vector_results): + raise ValueError( + "All elements in vector_results should be of the same type" + ) + + # avoids circular import + if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder": + vector_results = [result.to_arrow() for result in vector_results] + elif not isinstance(vector_results[0], pa.Table): + raise ValueError( + "vector_results should be a list of pa.Table or LanceVectorQueryBuilder" + ) + + combined = pa.concat_tables(vector_results, **self._concat_tables_args) + + reranked = self.rerank_vector(query, combined) + + # TODO: Allow custom deduplicators here. + # currently, this'll just keep the first instance. + if deduplicate: + if "_rowid" not in combined.column_names: + raise ValueError( + "'_rowid' is required for deduplication. \ + add _rowid to search results like this: \ + `search().with_row_id(True)`" + ) + reranked = self._deduplicate(reranked) + + return reranked + + def _deduplicate(self, table: pa.Table): + """ + Deduplicate the table based on the `_rowid` column. + """ + row_id = table.column("_rowid") + + # deduplicate + mask = np.full((table.shape[0]), False) + _, mask_indices = np.unique(np.array(row_id), return_index=True) + mask[mask_indices] = True + deduped_table = table.filter(mask=mask) + + return deduped_table + + def _keep_relevance_score(self, combined_results: pa.Table): + if self.score == "relevance": + if "score" in combined_results.column_names: + combined_results = combined_results.drop_columns(["score"]) + if "_distance" in combined_results.column_names: + combined_results = combined_results.drop_columns(["_distance"]) + return combined_results diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 4018d44c..c925f54f 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -88,7 +88,7 @@ class CohereReranker(Reranker): combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": raise NotImplementedError( "return_score='all' not implemented for cohere reranker" diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index 87a8e690..e09f2029 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -73,7 +73,7 @@ class ColbertReranker(Reranker): combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": raise NotImplementedError( "OpenAI Reranker does not support score='all' yet" diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index c88b354a..daf02f75 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -66,7 +66,7 @@ class CrossEncoderReranker(Reranker): combined_results = self._rerank(combined_results, query) # sort the results by _score if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": raise NotImplementedError( "return_score='all' not implemented for CrossEncoderReranker" diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 3d3f13c0..d8f22b02 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -92,7 +92,7 @@ class JinaReranker(Reranker): combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": raise NotImplementedError( "return_score='all' not implemented for JinaReranker" diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 3eb19b1d..983fa901 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -103,7 +103,7 @@ class LinearCombinationReranker(Reranker): [("_relevance_score", "descending")] ) if self.score == "relevance": - tbl = tbl.drop_columns(["score", "_distance"]) + tbl = self._keep_relevance_score(tbl) return tbl def _combine_score(self, score1, score2): diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 04d9f0d2..d24a4bcc 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -84,7 +84,7 @@ class OpenaiReranker(Reranker): combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": raise NotImplementedError( "OpenAI Reranker does not support score='all' yet" diff --git a/python/python/lancedb/rerankers/rrf.py b/python/python/lancedb/rerankers/rrf.py index a4d95a39..23ed1dc1 100644 --- a/python/python/lancedb/rerankers/rrf.py +++ b/python/python/lancedb/rerankers/rrf.py @@ -1,8 +1,12 @@ +from typing import Union, List, TYPE_CHECKING import pyarrow as pa from collections import defaultdict from .base import Reranker +if TYPE_CHECKING: + from ..table import LanceVectorQueryBuilder + class RRFReranker(Reranker): """ @@ -55,6 +59,46 @@ class RRFReranker(Reranker): ) if self.score == "relevance": - combined_results = combined_results.drop_columns(["score", "_distance"]) + combined_results = self._keep_relevance_score(combined_results) return combined_results + + def rerank_multivector( + self, + vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]], + query: str = None, + deduplicate: bool = True, # noqa: F821 # TODO: automatically deduplicates + ): + """ + Overridden method to rerank the results from multiple vector searches. + This leverages the RRF hybrid reranking algorithm to combine the + results from multiple vector searches as it doesn't support reranking + vector results individually. + """ + # Make sure all elements are of the same type + if not all(isinstance(v, type(vector_results[0])) for v in vector_results): + raise ValueError( + "All elements in vector_results should be of the same type" + ) + + # avoid circular import + if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder": + vector_results = [result.to_arrow() for result in vector_results] + elif not isinstance(vector_results[0], pa.Table): + raise ValueError( + "vector_results should be a list of pa.Table or LanceVectorQueryBuilder" + ) + + # _rowid is required for RRF reranking + if not all("_rowid" in result.column_names for result in vector_results): + raise ValueError( + "'_rowid' is required for deduplication. \ + add _rowid to search results like this: \ + `search().with_row_id(True)`" + ) + + combined = pa.concat_tables(vector_results, **self._concat_tables_args) + empty_table = pa.Table.from_arrays([], names=[]) + reranked = self.rerank_hybrid(query, combined, empty_table) + + return reranked diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index b7ffc10c..d2d90e42 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -1,4 +1,5 @@ import os +import random import lancedb import numpy as np @@ -25,10 +26,13 @@ def get_test_table(tmp_path): db = lancedb.connect(tmp_path) # Create a LanceDB table schema with a vector and a text column emb = EmbeddingFunctionRegistry.get_instance().get("test")() + meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")() class MyTable(LanceModel): text: str = emb.SourceField() vector: Vector(emb.ndims()) = emb.VectorField() + meta: str = meta_emb.SourceField() + meta_vector: Vector(meta_emb.ndims()) = meta_emb.VectorField() # Initialize the table using the schema table = LanceTable.create( @@ -77,7 +81,12 @@ def get_test_table(tmp_path): ] # Add the phrases and vectors to the table - table.add([{"text": p} for p in phrases]) + table.add( + [ + {"text": p, "meta": phrases[random.randint(0, len(phrases) - 1)]} + for p in phrases + ] + ) # Create a fts index table.create_fts_index("text") @@ -88,12 +97,12 @@ def get_test_table(tmp_path): def _run_test_reranker(reranker, table, query, query_vector, schema): # Hybrid search setting result1 = ( - table.search(query, query_type="hybrid") + table.search(query, query_type="hybrid", vector_column_name="vector") .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( - table.search(query, query_type="hybrid") + table.search(query, query_type="hybrid", vector_column_name="vector") .rerank(reranker=reranker) .to_pydantic(schema) ) @@ -101,7 +110,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search((query_vector, query), vector_column_name="vector") .limit(30) .rerank(reranker=reranker) .to_arrow() @@ -116,11 +125,16 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + result = ( + table.search(query, vector_column_name="vector") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) assert len(result) == 30 assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err result_explicit = ( - table.search(query_vector) + table.search(query_vector, vector_column_name="vector") .rerank(reranker=reranker, query_string=query) .limit(30) .to_arrow() @@ -129,11 +143,13 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): with pytest.raises( ValueError ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + table.search(query_vector, vector_column_name="vector").rerank( + reranker=reranker + ).limit(30).to_arrow() # FTS search setting result = ( - table.search(query, query_type="fts") + table.search(query, query_type="fts", vector_column_name="vector") .rerank(reranker=reranker) .limit(30) .to_arrow() @@ -141,22 +157,48 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): assert len(result) > 0 assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + # Multi-vector search setting + rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True) + rs2 = ( + table.search(query, vector_column_name="meta_vector") + .limit(10) + .with_row_id(True) + ) + result = reranker.rerank_multivector([rs1, rs2], query) + assert len(result) == 20 + result_deduped = reranker.rerank_multivector( + [rs1, rs2, rs1], query, deduplicate=True + ) + assert len(result_deduped) < 20 + result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query) + assert len(result) == 20 and result == result_arrow + def _run_test_hybrid_reranker(reranker, tmp_path): table, schema = get_test_table(tmp_path) # The default reranker result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") + table.search( + "Our father who art in heaven", + query_type="hybrid", + vector_column_name="vector", + ) .rerank(normalize="score") .to_pydantic(schema) ) result2 = ( # noqa - table.search("Our father who art in heaven.", query_type="hybrid") + table.search( + "Our father who art in heaven.", + query_type="hybrid", + vector_column_name="vector", + ) .rerank(normalize="rank") .to_pydantic(schema) ) result3 = table.search( - "Our father who art in heaven..", query_type="hybrid" + "Our father who art in heaven..", + query_type="hybrid", + vector_column_name="vector", ).to_pydantic(schema) assert result1 == result3 # 2 & 3 should be the same as they use score as score @@ -164,7 +206,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search((query_vector, query), vector_column_name="vector") .limit(30) .rerank(normalize="score") .to_arrow()