From 42fad84ec802a21f8bd82ce5bc071ae21b377b7f Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 19 Mar 2024 22:20:31 +0530 Subject: [PATCH] feat(python): Support reranking for vector and fts (#1103) solves https://github.com/lancedb/lancedb/issues/1086 Usage Reranking with FTS: ``` retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite") pylist = [{"text": "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274."}, {"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan."}, {"text": "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas."}, {"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. "}, {"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."}, {"text": "North Dakota is a state in the United States. 672,591 people lived in North Dakota in the year 2010. The capital and seat of government is Bismarck."}, ] retriever.add(pylist) retriever.create_fts_index("text", replace=True) query = "What is the capital of the United States?" reranker = CohereReranker(return_score="all") print(retriever.search(query, query_type="fts").limit(10).to_pandas()) print(retriever.search(query, query_type="fts").rerank(reranker=reranker).limit(10).to_pandas()) ``` Result ``` text vector score 0 Capital punishment (the death penalty) has exi... [0.099975586, 0.047943115, -0.16723633, -0.183... 0.729602 1 Charlotte Amalie is the capital and largest ci... 
[-0.021255493, 0.03363037, -0.027450562, -0.17... 0.678046 2 The Commonwealth of the Northern Mariana Islan... [0.3684082, 0.30493164, 0.004600525, -0.049407... 0.671521 3 Carson City is the capital city of the America... [0.13989258, 0.14990234, 0.14172363, 0.0546569... 0.667898 4 Washington, D.C. (also known as simply Washing... [-0.0090408325, 0.42578125, 0.3798828, -0.3574... 0.653422 5 North Dakota is a state in the United States. ... [0.55859375, -0.2109375, 0.14526367, 0.1634521... 0.639346 text vector score _relevance_score 0 Washington, D.C. (also known as simply Washing... [-0.0090408325, 0.42578125, 0.3798828, -0.3574... 0.653422 0.979977 1 The Commonwealth of the Northern Mariana Islan... [0.3684082, 0.30493164, 0.004600525, -0.049407... 0.671521 0.299105 2 Capital punishment (the death penalty) has exi... [0.099975586, 0.047943115, -0.16723633, -0.183... 0.729602 0.284874 3 Carson City is the capital city of the America... [0.13989258, 0.14990234, 0.14172363, 0.0546569... 0.667898 0.089614 4 North Dakota is a state in the United States. ... [0.55859375, -0.2109375, 0.14526367, 0.1634521... 0.639346 0.063832 5 Charlotte Amalie is the capital and largest ci... [-0.021255493, 0.03363037, -0.027450562, -0.17... 0.678046 0.041462 ``` ## Vector Search usage: ``` query = "What is the capital of the United States?" reranker = CohereReranker(return_score="all") print(retriever.search(query).limit(10).to_pandas()) print(retriever.search(query).rerank(reranker=reranker, query_string=query).limit(10).to_pandas()) # <-- Note: the string query is passed explicitly via query_string; this is only required when the search query is not a string ``` Results ``` text vector _distance 0 Capital punishment (the death penalty) has exi... [0.099975586, 0.047943115, -0.16723633, -0.183... 39.728973 1 Washington, D.C. (also known as simply Washing... [-0.0090408325, 0.42578125, 0.3798828, -0.3574... 41.384884 2 Carson City is the capital city of the America... [0.13989258, 0.14990234, 0.14172363, 0.0546569... 
55.220200 3 Charlotte Amalie is the capital and largest ci... [-0.021255493, 0.03363037, -0.027450562, -0.17... 58.345654 4 The Commonwealth of the Northern Mariana Islan... [0.3684082, 0.30493164, 0.004600525, -0.049407... 60.060867 5 North Dakota is a state in the United States. ... [0.55859375, -0.2109375, 0.14526367, 0.1634521... 64.260544 text vector _distance _relevance_score 0 Washington, D.C. (also known as simply Washing... [-0.0090408325, 0.42578125, 0.3798828, -0.3574... 41.384884 0.979977 1 The Commonwealth of the Northern Mariana Islan... [0.3684082, 0.30493164, 0.004600525, -0.049407... 60.060867 0.299105 2 Capital punishment (the death penalty) has exi... [0.099975586, 0.047943115, -0.16723633, -0.183... 39.728973 0.284874 3 Carson City is the capital city of the America... [0.13989258, 0.14990234, 0.14172363, 0.0546569... 55.220200 0.089614 4 North Dakota is a state in the United States. ... [0.55859375, -0.2109375, 0.14526367, 0.1634521... 64.260544 0.063832 5 Charlotte Amalie is the capital and largest ci... [-0.021255493, 0.03363037, -0.027450562, -0.17... 
58.345654 0.041462 ``` --- python/python/lancedb/query.py | 70 ++++++++- python/python/lancedb/rerankers/base.py | 56 +++++++ python/python/lancedb/rerankers/cohere.py | 46 ++++-- python/python/lancedb/rerankers/colbert.py | 53 +++++-- .../python/lancedb/rerankers/cross_encoder.py | 42 ++++- python/python/lancedb/rerankers/openai.py | 45 ++++-- python/python/tests/test_rerankers.py | 146 +++++++++++++++--- 7 files changed, 399 insertions(+), 59 deletions(-) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 2addcb24..8a7b231c 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -144,6 +144,9 @@ class LanceQueryBuilder(ABC): # hybrid fts and vector query return LanceHybridQueryBuilder(table, query, vector_column_name) + # remember the string query for reranking purpose + str_query = query if isinstance(query, str) else None + # convert "auto" query_type to "vector", "fts" # or "hybrid" and convert the query to vector if needed query, query_type = cls._resolve_query( @@ -164,7 +167,7 @@ class LanceQueryBuilder(ABC): else: raise TypeError(f"Unsupported query type: {type(query)}") - return LanceVectorQueryBuilder(table, query, vector_column_name) + return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) @classmethod def _resolve_query(cls, table, query, query_type, vector_column_name): @@ -428,6 +431,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): table: "Table", query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str, + str_query: Optional[str] = None, ): super().__init__(table) self._query = query @@ -436,6 +440,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._refine_factor = None self._vector_column = vector_column self._prefilter = False + self._reranker = None + self._str_query = str_query def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. 
@@ -521,7 +527,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): vector_column=self._vector_column, with_row_id=self._with_row_id, ) - return self._table._execute_query(query) + result_set = self._table._execute_query(query) + if self._reranker is not None: + result_set = self._reranker.rerank_vector(self._str_query, result_set) + + return result_set def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder: """Set the where clause. @@ -547,6 +557,42 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._prefilter = prefilter return self + def rerank( + self, reranker: Reranker, query_string: Optional[str] = None + ) -> LanceVectorQueryBuilder: + """Rerank the results using the specified reranker. + + Parameters + ---------- + reranker: Reranker + The reranker to use. + + query_string: Optional[str] + The query to use for reranking. This needs to be specified explicitly here + as the query used for vector search may already be vectorized and the + reranker requires a string query. + This is only required if the query used for vector search is not a string. + Note: This doesn't yet support the case where the query is multimodal or a + list of vectors. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._reranker = reranker + if self._str_query is None and query_string is None: + raise ValueError( + """ + The query used for vector search is not a string. + In this case, the reranker query needs to be specified explicitly. 
+ """ + ) + if query_string is not None and not isinstance(query_string, str): + raise ValueError("Reranking currently only supports string queries") + self._str_query = query_string if query_string is not None else self._str_query + return self + class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" @@ -555,6 +601,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): super().__init__(table) self._query = query self._phrase_query = False + self._reranker = None def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: """Set whether to use phrase query. @@ -641,8 +688,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._with_row_id: output_tbl = output_tbl.append_column("_rowid", row_ids) + + if self._reranker is not None: + output_tbl = self._reranker.rerank_fts(self._query, output_tbl) return output_tbl + def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder: + """Rerank the results using the specified reranker. + + Parameters + ---------- + reranker: Reranker + The reranker to use. + + Returns + ------- + LanceFtsQueryBuilder + The LanceQueryBuilder object. + """ + self._reranker = reranker + return self + class LanceEmptyQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pa.Table: diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 96479dbd..d3881741 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -24,8 +24,59 @@ class Reranker(ABC): raise ValueError("score must be either 'relevance' or 'all'") self.score = return_score + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + """ + Rerank function receives the result from the vector search. 
+ This isn't mandatory to implement + + Parameters + ---------- + query : str + The input query + vector_results : pa.Table + The results from the vector search + + Returns + ------- + pa.Table + The reranked results + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement rerank_vector" + ) + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + """ + Rerank function receives the result from the FTS search. + This isn't mandatory to implement + + Parameters + ---------- + query : str + The input query + fts_results : pa.Table + The results from the FTS search + + Returns + ------- + pa.Table + The reranked results + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement rerank_fts" + ) + @abstractmethod def rerank_hybrid( + self, query: str, vector_results: pa.Table, fts_results: pa.Table, @@ -43,6 +94,11 @@ class Reranker(ABC): The results from the vector search fts_results : pa.Table The results from the FTS search + + Returns + ------- + pa.Table + The reranked results """ pass diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 611da9f8..a1ccb060 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -49,14 +49,8 @@ class CohereReranker(Reranker): ) return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() results = self._client.rerank( query=query, documents=docs, @@ -66,12 +60,22 @@ class CohereReranker(Reranker): indices, scores = list( zip(*[(result.index, result.relevance_score) for result in results]) ) # tuples - combined_results = 
combined_results.take(list(indices)) + result_set = result_set.take(list(indices)) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -79,3 +83,25 @@ class CohereReranker(Reranker): "return_score='all' not implemented for cohere reranker" ) return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + result_set = self._rerank(vector_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_distance"]) + + return result_set + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + result_set = self._rerank(fts_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["score"]) + + return result_set diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index e3a0aa77..87a8e690 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -33,14 +33,8 @@ class ColbertReranker(Reranker): "torch" ) # import here for faster ops later - def rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() tokenizer, model = self._model @@ -59,14 +53,25 @@ class ColbertReranker(Reranker): scores.append(score.item()) # replace 
the self.column column with the docs - combined_results = combined_results.drop(self.column) - combined_results = combined_results.append_column( + result_set = result_set.drop(self.column) + result_set = result_set.append_column( self.column, pa.array(docs, type=pa.string()) ) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -80,6 +85,32 @@ class ColbertReranker(Reranker): return combined_results + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + result_set = self._rerank(vector_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_distance"]) + + result_set = result_set.sort_by([("_relevance_score", "descending")]) + + return result_set + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + result_set = self._rerank(fts_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["score"]) + + result_set = result_set.sort_by([("_relevance_score", "descending")]) + + return result_set + @cached_property def _model(self): transformers = attempt_import_or_raise("transformers") diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index ea2ea099..5a066a13 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -46,6 +46,16 @@ class CrossEncoderReranker(Reranker): return cross_encoder + def _rerank(self, result_set: pa.Table, query: str): + 
passages = result_set[self.column].to_pylist() + cross_inp = [[query, passage] for passage in passages] + cross_scores = self.model.predict(cross_inp) + result_set = result_set.append_column( + "_relevance_score", pa.array(cross_scores, type=pa.float32()) + ) + + return result_set + def rerank_hybrid( self, query: str, @@ -53,13 +63,7 @@ class CrossEncoderReranker(Reranker): fts_results: pa.Table, ): combined_results = self.merge_results(vector_results, fts_results) - passages = combined_results[self.column].to_pylist() - cross_inp = [[query, passage] for passage in passages] - cross_scores = self.model.predict(cross_inp) - combined_results = combined_results.append_column( - "_relevance_score", pa.array(cross_scores, type=pa.float32()) - ) - + combined_results = self._rerank(combined_results, query) # sort the results by _score if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) @@ -72,3 +76,27 @@ class CrossEncoderReranker(Reranker): ) return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + return vector_results + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["score"]) + + fts_results = fts_results.sort_by([("_relevance_score", "descending")]) + return fts_results diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index ca21c9b7..04d9f0d2 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -39,14 +39,8 @@ class OpenaiReranker(Reranker): self.column = column self.api_key = api_key - def 
rerank_hybrid( - self, - query: str, - vector_results: pa.Table, - fts_results: pa.Table, - ): - combined_results = self.merge_results(vector_results, fts_results) - docs = combined_results[self.column].to_pylist() + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() response = self._client.chat.completions.create( model=self.model_name, response_format={"type": "json_object"}, @@ -70,14 +64,25 @@ class OpenaiReranker(Reranker): zip(*[(result["content"], result["relevance_score"]) for result in results]) ) # tuples # replace the self.column column with the docs - combined_results = combined_results.drop(self.column) - combined_results = combined_results.append_column( + result_set = result_set.drop(self.column) + result_set = result_set.append_column( self.column, pa.array(docs, type=pa.string()) ) # add the scores - combined_results = combined_results.append_column( + result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = combined_results.drop_columns(["score", "_distance"]) elif self.score == "all": @@ -91,6 +96,24 @@ class OpenaiReranker(Reranker): return combined_results + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) + if self.score == "relevance": + vector_results = vector_results.drop_columns(["_distance"]) + + vector_results = vector_results.sort_by([("_relevance_score", "descending")]) + + return vector_results + + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) + if self.score == "relevance": + fts_results = fts_results.drop_columns(["score"]) 
+ + fts_results = fts_results.sort_by([("_relevance_score", "descending")]) + + return fts_results + @cached_property def _client(self): openai = attempt_import_or_raise( diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index a912c64b..6465c51a 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -124,8 +124,9 @@ def test_linear_combination(tmp_path): ) def test_cohere_reranker(tmp_path): pytest.importorskip("cohere") + reranker = CohereReranker() table, schema = get_test_table(tmp_path) - # The default reranker + # Hybrid search setting result1 = ( table.search("Our father who art in heaven", query_type="hybrid") .rerank(normalize="score", reranker=CohereReranker()) @@ -133,7 +134,7 @@ def test_cohere_reranker(tmp_path): ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=CohereReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 @@ -143,64 +144,120 @@ def test_cohere_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=CohereReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." 
) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + query = "Our father who art in heaven" + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err def test_cross_encoder_reranker(tmp_path): pytest.importorskip("sentence_transformers") + reranker = CrossEncoderReranker() table, schema = get_test_table(tmp_path) result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=CrossEncoderReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=CrossEncoderReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 - # test explicit hybrid query query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( table.search((query_vector, query), query_type="hybrid") .limit(30) - .rerank(reranker=CrossEncoderReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the 
relevance of the result to the query & should " "be descending." ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err def test_colbert_reranker(tmp_path): pytest.importorskip("transformers") + reranker = ColbertReranker() table, schema = get_test_table(tmp_path) result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=ColbertReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=ColbertReranker()) + .rerank(reranker=reranker) .to_pydantic(schema) ) assert result1 == result2 @@ -211,17 +268,43 @@ def test_colbert_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=ColbertReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." 
) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err @pytest.mark.skipif( @@ -230,9 +313,10 @@ def test_colbert_reranker(tmp_path): def test_openai_reranker(tmp_path): pytest.importorskip("openai") table, schema = get_test_table(tmp_path) + reranker = OpenaiReranker() result1 = ( table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=OpenaiReranker()) + .rerank(normalize="score", reranker=reranker) .to_pydantic(schema) ) result2 = ( @@ -248,14 +332,40 @@ def test_openai_reranker(tmp_path): result = ( table.search((query_vector, query)) .limit(30) - .rerank(reranker=OpenaiReranker()) + .rerank(reranker=reranker) .to_arrow() ) assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + err = ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should " "be descending." 
) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err