diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 2addcb24..8a7b231c 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -144,6 +144,9 @@ class LanceQueryBuilder(ABC): # hybrid fts and vector query return LanceHybridQueryBuilder(table, query, vector_column_name) + # remember the string query for reranking purpose + str_query = query if isinstance(query, str) else None + # convert "auto" query_type to "vector", "fts" # or "hybrid" and convert the query to vector if needed query, query_type = cls._resolve_query( @@ -164,7 +167,7 @@ class LanceQueryBuilder(ABC): else: raise TypeError(f"Unsupported query type: {type(query)}") - return LanceVectorQueryBuilder(table, query, vector_column_name) + return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) @classmethod def _resolve_query(cls, table, query, query_type, vector_column_name): @@ -428,6 +431,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): table: "Table", query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str, + str_query: Optional[str] = None, ): super().__init__(table) self._query = query @@ -436,6 +440,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._refine_factor = None self._vector_column = vector_column self._prefilter = False + self._reranker = None + self._str_query = str_query def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. @@ -521,7 +527,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): vector_column=self._vector_column, with_row_id=self._with_row_id, ) - return self._table._execute_query(query) + result_set = self._table._execute_query(query) + if self._reranker is not None: + result_set = self._reranker.rerank_vector(self._str_query, result_set) + + return result_set def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder: """Set the where clause. @@ -547,6 +557,42 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._prefilter = prefilter return self + def rerank( + self, reranker: Reranker, query_string: Optional[str] = None + ) -> LanceVectorQueryBuilder: + """Rerank the results using the specified reranker. + + Parameters + ---------- + reranker: Reranker + The reranker to use. + + query_string: Optional[str] + The query to use for reranking. This needs to be specified explicitly here + as the query used for vector search may already be vectorized and the + reranker requires a string query. + This is only required if the query used for vector search is not a string. + Note: This doesn't yet support the case where the query is multimodal or a + list of vectors. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._reranker = reranker + if self._str_query is None and query_string is None: + raise ValueError( + """ + The query used for vector search is not a string. + In this case, the reranker query needs to be specified explicitly. + """ + ) + if query_string is not None and not isinstance(query_string, str): + raise ValueError("Reranking currently only supports string queries") + self._str_query = query_string if query_string is not None else self._str_query + return self + class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" @@ -555,6 +601,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): super().__init__(table) self._query = query self._phrase_query = False + self._reranker = None def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: """Set whether to use phrase query. @@ -641,8 +688,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._with_row_id: output_tbl = output_tbl.append_column("_rowid", row_ids) + + if self._reranker is not None: + output_tbl = self._reranker.rerank_fts(self._query, output_tbl) return output_tbl + def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder: + """Rerank the results using the specified reranker. + + Parameters + ---------- + reranker: Reranker + The reranker to use. + + Returns + ------- + LanceFtsQueryBuilder + The LanceQueryBuilder object. + """ + self._reranker = reranker + return self + class LanceEmptyQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pa.Table: diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 96479dbd..d3881741 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -24,8 +24,59 @@ class Reranker(ABC): raise ValueError("score must be either 'relevance' or 'all'") self.score = return_score + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + """ + Rerank function receives the result from the vector search. + This isn't mandatory to implement + + Parameters + ---------- + query : str + The input query + vector_results : pa.Table + The results from the vector search + + Returns + ------- + pa.Table + The reranked results + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement rerank_vector" + ) + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + """ + Rerank function receives the result from the FTS search. + This isn't mandatory to implement + + Parameters + ---------- + query : str + The input query + fts_results : pa.Table + The results from the FTS search + + Returns + ------- + pa.Table + The reranked results + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement rerank_fts" + ) + @abstractmethod def rerank_hybrid( + self, query: str, vector_results: pa.Table, fts_results: pa.Table, @@ -43,6 +94,11 @@ class Reranker(ABC): The results from the vector search fts_results : pa.Table The results from the FTS search + + Returns + ------- + pa.Table + The reranked results """ pass diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 611da9f8..a1ccb060 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -49,14 +49,8 @@ class CohereReranker(Reranker): ) return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() results = self._client.rerank( query=query, documents=docs, @@ -66,12 +60,22 @@ class CohereReranker(Reranker): indices, scores = list( zip(*[(result.index, result.relevance_score) for result in results]) ) # tuples - combined_results = combined_results.take(list(indices)) + result_set = result_set.take(list(indices)) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -79,3 +83,25 @@ class CohereReranker(Reranker): "return_score='all' not implemented for cohere reranker" ) return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + result_set = self._rerank(vector_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_distance"]) + + return result_set + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + result_set = self._rerank(fts_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["score"]) + + return result_set diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index e3a0aa77..87a8e690 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -33,14 +33,8 @@ class ColbertReranker(Reranker): "torch" ) # import here for faster ops later - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() tokenizer, model = self._model @@ -59,14 +53,25 @@ class ColbertReranker(Reranker): scores.append(score.item()) # replace the self.column column with the docs - combined_results = combined_results.drop(self.column) - combined_results = combined_results.append_column( + result_set = result_set.drop(self.column) + result_set = result_set.append_column( self.column, pa.array(docs, type=pa.string()) ) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -80,6 +85,32 @@ class ColbertReranker(Reranker): return combined_results + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + result_set = self._rerank(vector_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_distance"]) + + result_set = result_set.sort_by([("_relevance_score", "descending")]) + + return result_set + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + result_set = self._rerank(fts_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["score"]) + + result_set = result_set.sort_by([("_relevance_score", "descending")]) + + return result_set + @cached_property def _model(self): transformers = attempt_import_or_raise("transformers") diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index ea2ea099..5a066a13 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -46,6 +46,16 @@ class CrossEncoderReranker(Reranker): return cross_encoder + def _rerank(self, result_set: pa.Table, query: str): + passages = result_set[self.column].to_pylist() + cross_inp = [[query, passage] for passage in passages] + cross_scores = self.model.predict(cross_inp) + result_set = result_set.append_column( + "_relevance_score", pa.array(cross_scores, type=pa.float32()) + ) + + return result_set + def rerank_hybrid( self, query: str, @@ -53,13 +63,7 @@ class CrossEncoderReranker(Reranker): fts_results: pa.Table, ): combined_results = self.merge_results(vector_results, fts_results) - passages = combined_results[self.column].to_pylist() - cross_inp = [[query, passage] for passage in passages] - cross_scores = self.model.predict(cross_inp) - combined_results = combined_results.append_column( - "_relevance_score", pa.array(cross_scores, type=pa.float32()) - ) - + combined_results = self._rerank(combined_results, query) # sort the results by _score if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) @@ -72,3 +76,27 @@ class CrossEncoderReranker(Reranker): ) return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + return vector_results + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["score"]) + + fts_results = fts_results.sort_by([("_relevance_score", "descending")]) + return fts_results diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index ca21c9b7..04d9f0d2 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -39,14 +39,8 @@ class OpenaiReranker(Reranker): self.column = column self.api_key = api_key - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() response = self._client.chat.completions.create( model=self.model_name, response_format={"type": "json_object"}, @@ -70,14 +64,25 @@ class OpenaiReranker(Reranker): zip(*[(result["content"], result["relevance_score"]) for result in results]) ) # tuples # replace the self.column column with the docs - combined_results = combined_results.drop(self.column) - combined_results = combined_results.append_column( + result_set = result_set.drop(self.column) + result_set = result_set.append_column( self.column, pa.array(docs, type=pa.string()) ) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -91,6 +96,24 @@ class OpenaiReranker(Reranker): return combined_results + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + + return vector_results + + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["score"]) + + fts_results = fts_results.sort_by([("_relevance_score", "descending")]) + + return fts_results + @cached_property def _client(self): openai = attempt_import_or_raise( diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index a912c64b..6465c51a 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -124,8 +124,9 @@ def test_linear_combination(tmp_path): ) def test_cohere_reranker(tmp_path): pytest.importorskip("cohere") + reranker = CohereReranker() table, schema = get_test_table(tmp_path) - # The default reranker + # Hybrid search setting result1 = ( table.search("Our father who art in heaven", query_type="hybrid") .rerank(normalize="score", reranker=CohereReranker()) @@ -133,7 +134,7 @@ def test_cohere_reranker(tmp_path): ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=CohereReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 @@ -143,64 +144,120 @@ def test_cohere_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=CohereReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + query = "Our father who art in heaven" + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err def test_cross_encoder_reranker(tmp_path): pytest.importorskip("sentence_transformers") + reranker = CrossEncoderReranker() table, schema = get_test_table(tmp_path) result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=CrossEncoderReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=CrossEncoderReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 - # test explicit hybrid query query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( table.search((query_vector, query), query_type="hybrid") .limit(30) - .rerank(reranker=CrossEncoderReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err def test_colbert_reranker(tmp_path): pytest.importorskip("transformers") + reranker = ColbertReranker() table, schema = get_test_table(tmp_path) result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=ColbertReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=ColbertReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 @@ -211,17 +268,43 @@ def test_colbert_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=ColbertReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err @pytest.mark.skipif( @@ -230,9 +313,10 @@ def test_colbert_reranker(tmp_path): def test_openai_reranker(tmp_path): pytest.importorskip("openai") table, schema = get_test_table(tmp_path) + reranker = OpenaiReranker() result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=OpenaiReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( @@ -248,14 +332,40 @@ def test_openai_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=OpenaiReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err