diff --git a/docs/src/reranking/linear_combination.md b/docs/src/reranking/linear_combination.md
index 4a27907c..de9bfec4 100644
--- a/docs/src/reranking/linear_combination.md
+++ b/docs/src/reranking/linear_combination.md
@@ -1,6 +1,9 @@
 # Linear Combination Reranker
 
-This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search.
+!!! note
+    This reranker is deprecated. It is recommended to use the `RRFReranker` instead if you want a score-based reranker.
+
+It combines the results of semantic and full-text search using a linear combination of the scores. The weight for the linear combination can be specified; it defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.
 
 !!! note
     Supported Query Types: Hybrid
diff --git a/docs/src/reranking/rrf.md b/docs/src/reranking/rrf.md
index 972c2443..dce7fc2a 100644
--- a/docs/src/reranking/rrf.md
+++ b/docs/src/reranking/rrf.md
@@ -1,6 +1,6 @@
 # Reciprocal Rank Fusion Reranker
 
-Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
+This is the default re-ranker used by LanceDB hybrid search. Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/ranks of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
 
 !!! note
 
diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 13b0460c..48b34860 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -36,6 +36,7 @@ from . import __version__
 from .arrow import AsyncRecordBatchReader
 from .rerankers.base import Reranker
 from .rerankers.rrf import RRFReranker
+from .rerankers.util import check_reranker_result
 from .util import safe_import_pandas
 
 if TYPE_CHECKING:
@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         if self._reranker is not None:
             rs_table = result_set.read_all()
             result_set = self._reranker.rerank_vector(self._str_query, rs_table)
+            check_reranker_result(result_set)
             # convert result_set back to RecordBatchReader
             result_set = pa.RecordBatchReader.from_batches(
                 result_set.schema, result_set.to_batches()
@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         results = results.read_all()
         if self._reranker is not None:
             results = self._reranker.rerank_fts(self._query, results)
+            check_reranker_result(results)
         return results
 
     def tantivy_to_arrow(self) -> pa.Table:
@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(
         self,
         table: "Table",
-        query: str = None,
-        vector_column: str = None,
+        query: Optional[str] = None,
+        vector_column: Optional[str] = None,
         fts_columns: Union[str, List[str]] = [],
     ):
         super().__init__(table)
@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
             self._fts_query._query, vector_results, fts_results
         )
 
-        if not isinstance(results, pa.Table):  # Enforce type
-            raise TypeError(
-                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
-            )
+        check_reranker_result(results)
 
         # apply limit after reranking
         results = results.slice(length=self._limit)
@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
 
     def rerank(
         self,
-        normalize="score",
         reranker: Reranker = RRFReranker(),
+        normalize: str = "score",
     ) -> LanceHybridQueryBuilder:
         """
         Rerank the hybrid search results using the specified reranker. The reranker
@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
 
         Parameters
         ----------
+        reranker: Reranker, default RRFReranker()
+            The reranker to use. Must be an instance of Reranker class.
         normalize: str, default "score"
             The method to normalize the scores. Can be "rank" or "score".
             If "rank", the scores are converted to ranks and then normalized.
             If "score", the scores are normalized directly.
-        reranker: Reranker, default RRFReranker()
-            The reranker to use. Must be an instance of Reranker class.
         Returns
         -------
         LanceHybridQueryBuilder
diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py
index 65ed43e7..7f90d40a 100644
--- a/python/python/lancedb/rerankers/base.py
+++ b/python/python/lancedb/rerankers/base.py
@@ -105,7 +105,7 @@ class Reranker(ABC):
         query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
-    ):
+    ) -> pa.Table:
         """
         Rerank function receives the individual results from the vector and FTS search
         results. You can choose to use any of the results to generate the final results,
diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py
index 6ab18427..8bcfb5e3 100644
--- a/python/python/lancedb/rerankers/linear_combination.py
+++ b/python/python/lancedb/rerankers/linear_combination.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from numpy import nan
 import pyarrow as pa
 
 from .base import Reranker
@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
     def merge_results(
         self, vector_results: pa.Table, fts_results: pa.Table, fill: float
     ):
-        # If both are empty then just return an empty table
-        if len(vector_results) == 0 and len(fts_results) == 0:
-            return vector_results
-        # If one is empty then return the other
+        # If one result set is empty, return the other one with a _relevance_score
+        # column equal to the existing vector or FTS score
         if len(vector_results) == 0:
-            return fts_results
+            results = fts_results.append_column(
+                "_relevance_score",
+                pa.array(fts_results["_score"], type=pa.float32()),
+            )
+            if self.score == "relevance":
+                results = self._keep_relevance_score(results)
+            elif self.score == "all":
+                results = results.append_column(
+                    "_distance",
+                    pa.array([nan] * len(fts_results), type=pa.float32()),
+                )
+            return results
+
         if len(fts_results) == 0:
-            return vector_results
+            # invert the distance to get a relevance score
+            results = vector_results.append_column(
+                "_relevance_score",
+                pa.array(
+                    [
+                        self._invert_score(distance)
+                        for distance in vector_results["_distance"].to_pylist()
+                    ],
+                    type=pa.float32(),
+                ),
+            )
+            if self.score == "relevance":
+                results = self._keep_relevance_score(results)
+            elif self.score == "all":
+                results = results.append_column(
+                    "_score",
+                    pa.array([nan] * len(vector_results), type=pa.float32()),
+                )
+            return results
 
         # sort both input tables on _rowid
         combined_list = []
diff --git a/python/python/lancedb/rerankers/util.py b/python/python/lancedb/rerankers/util.py
new file mode 100644
index 00000000..0edcfc20
--- /dev/null
+++ b/python/python/lancedb/rerankers/util.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The Lance Authors
+
+import pyarrow as pa
+
+
+def check_reranker_result(result):
+    if not isinstance(result, pa.Table):  # Enforce type
+        raise TypeError(
+            f"Rerankers must return a pyarrow.Table, got {type(result)}"
+        )
+
+    # Enforce that the `_relevance_score` column is present in the result
+    # of every reranking method
+    if "_relevance_score" not in result.column_names:
+        raise ValueError(
+            "Rerankers must return a pyarrow.Table with a column "
+            "named `_relevance_score`"
+        )
diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py
index 23132066..f2f7c6cc 100644
--- a/python/python/tests/test_rerankers.py
+++ b/python/python/tests/test_rerankers.py
@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
     )
     assert len(result) == 30
 
-    err = (
+    ascending_relevance_err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
 
     # Vector search setting
     result = (
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) == 30
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
     result_explicit = (
         table.search(query_vector, vector_column_name="vector")
         .rerank(reranker=reranker, query_string=query)
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) > 0
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
+
+    # Empty FTS results (the nonsense query matches nothing)
+    empty_fts_query = "abcxyz" * 100
+    result = (
+        table.search(query_type="hybrid", vector_column_name="vector")
+        .vector(query_vector)
+        .text(empty_fts_query)
+        .limit(30)
+        .rerank(reranker=reranker)
+        .to_arrow()
+    )
+
+    # should still return a _relevance_score column
+    assert "_relevance_score" in result.column_names
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
 
     # Multi-vector search setting
     rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
     result_deduped = reranker.rerank_multivector(
        [rs1, rs2, rs1], query, deduplicate=True
    )
-    assert len(result_deduped) < 20
+    assert len(result_deduped) <= 20
 
     result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
     assert len(result) == 20 and result == result_arrow
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
         .vector(query_vector)
         .text(query)
         .limit(30)
-        .rerank(normalize="score")
+        .rerank(reranker, normalize="score")
         .to_arrow()
     )
     assert len(result) == 30
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
     table.search(query, query_type="hybrid", vector_column_name="vector").text(
         query
     ).to_arrow()
-
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+    ascending_relevance_err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
+
+    # Test with empty FTS results (the nonsense query matches nothing)
+    empty_fts_query = "abcxyz" * 100
+    result = (
+        table.search(query_type="hybrid", vector_column_name="vector")
+        .vector(query_vector)
+        .text(empty_fts_query)
+        .limit(30)
+        .rerank(reranker=reranker)
+        .to_arrow()
+    )
+    # should still return a _relevance_score column
+    assert "_relevance_score" in result.column_names
+    assert np.all(
+        np.diff(result.column("_relevance_score").to_numpy()) <= 0
+    ), ascending_relevance_err
 
 
 @pytest.mark.parametrize("use_tantivy", [True, False])
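
For reference, the contract enforced by the new `check_reranker_result` helper can be exercised directly, since the whole helper is added in this diff. Below is a minimal sketch, not part of the change itself; it assumes a build of `lancedb` that includes the new `lancedb.rerankers.util` module, and the table contents are made up for illustration.

```python
import pyarrow as pa

from lancedb.rerankers.util import check_reranker_result

# A result that satisfies the contract: a pyarrow.Table carrying a
# `_relevance_score` column (the tests above additionally assert that
# the column is sorted in descending order).
ok = pa.table(
    {
        "text": ["doc a", "doc b"],
        "_relevance_score": pa.array([0.9, 0.4], type=pa.float32()),
    }
)
check_reranker_result(ok)  # passes silently

# Not a pyarrow.Table -> TypeError
try:
    check_reranker_result(ok.to_pylist())
except TypeError as exc:
    print(exc)

# Missing the `_relevance_score` column -> ValueError
try:
    check_reranker_result(ok.select(["text"]))
except ValueError as exc:
    print(exc)
```

Because the helper now runs after `rerank_vector`, `rerank_fts`, and `rerank_hybrid` in the query builders, a reranker that violates either rule fails fast at query time instead of producing a malformed result table.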
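
The `-> pa.Table` annotation on `Reranker.rerank_hybrid` plus the `_relevance_score` checks also describe what a custom reranker has to produce. A hypothetical, illustration-only subclass that meets the contract is sketched below; the class name and scoring formula are assumptions, not part of LanceDB, and only the hybrid path is implemented.

```python
import pyarrow as pa

from lancedb.rerankers.base import Reranker


class DistanceOnlyReranker(Reranker):
    """Hypothetical example: ignores the FTS hits and scores purely by vector distance."""

    def rerank_hybrid(
        self, query: str, vector_results: pa.Table, fts_results: pa.Table
    ) -> pa.Table:
        # Invert distances so that a smaller distance yields a larger relevance,
        # then expose them through the required `_relevance_score` column.
        scores = [
            1.0 / (1.0 + d) for d in vector_results["_distance"].to_pylist()
        ]
        results = vector_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        # `_relevance_score` is expected to be sorted in descending order.
        return results.sort_by([("_relevance_score", "descending")])
```

Such a class would plug into `rerank()` for hybrid queries; the vector-only and FTS-only code paths touched in this diff would likely need their own `rerank_vector` / `rerank_fts` overrides before those query types could use it.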