mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
feat(python): Support reranking for vector and fts (#1103)
solves https://github.com/lancedb/lancedb/issues/1086

Usage

Reranking with FTS:

```
retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [
    {"text": "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274."},
    {"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan."},
    {"text": "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas."},
    {"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. "},
    {"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."},
    {"text": "North Dakota is a state in the United States. 672,591 people lived in North Dakota in the year 2010. The capital and seat of government is Bismarck."},
]
retriever.add(pylist)
retriever.create_fts_index("text", replace=True)

query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query, query_type="fts").limit(10).to_pandas())
print(retriever.search(query, query_type="fts").rerank(reranker=reranker).limit(10).to_pandas())
```

Result

```
   text                                                vector                                              score
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602
1  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046
2  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898
4  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346

   text                                                vector                                              score     _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422  0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521  0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602  0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898  0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346  0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046  0.041462
```

## Vector Search usage:

```
query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query).limit(10).to_pandas())
print(retriever.search(query).rerank(reranker=reranker, query_string=query).limit(10).to_pandas())  # <-- Note: passing extra string query here
```

Results

```
   text                                                vector                                              _distance
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973
1  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884
2  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200
3  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654
4  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544

   text                                                vector                                              _distance  _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884  0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867  0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973  0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200  0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544  0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654  0.041462
```
This commit is contained in:
committed by
Weston Pace
parent
b36c750cc7
commit
42fad84ec8
@@ -144,6 +144,9 @@ class LanceQueryBuilder(ABC):
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
# remember the string query for reranking purpose
|
||||
str_query = query if isinstance(query, str) else None
|
||||
|
||||
# convert "auto" query_type to "vector", "fts"
|
||||
# or "hybrid" and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
@@ -164,7 +167,7 @@ class LanceQueryBuilder(ABC):
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)
|
||||
|
||||
@classmethod
|
||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||
@@ -428,6 +431,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
table: "Table",
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str,
|
||||
str_query: Optional[str] = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
@@ -436,6 +440,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
self._reranker = None
|
||||
self._str_query = str_query
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
@@ -521,7 +527,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
result_set = self._table._execute_query(query)
|
||||
if self._reranker is not None:
|
||||
result_set = self._reranker.rerank_vector(self._str_query, result_set)
|
||||
|
||||
return result_set
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
@@ -547,6 +557,42 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._prefilter = prefilter
|
||||
return self
|
||||
|
||||
def rerank(
|
||||
self, reranker: Reranker, query_string: Optional[str] = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker
|
||||
The reranker to use.
|
||||
|
||||
query_string: Optional[str]
|
||||
The query to use for reranking. This needs to be specified explicitly here
|
||||
as the query used for vector search may already be vectorized and the
|
||||
reranker requires a string query.
|
||||
This is only required if the query used for vector search is not a string.
|
||||
Note: This doesn't yet support the case where the query is multimodal or a
|
||||
list of vectors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._reranker = reranker
|
||||
if self._str_query is None and query_string is None:
|
||||
raise ValueError(
|
||||
"""
|
||||
The query used for vector search is not a string.
|
||||
In this case, the reranker query needs to be specified explicitly.
|
||||
"""
|
||||
)
|
||||
if query_string is not None and not isinstance(query_string, str):
|
||||
raise ValueError("Reranking currently only supports string queries")
|
||||
self._str_query = query_string if query_string is not None else self._str_query
|
||||
return self
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"""A builder for full text search for LanceDB."""
|
||||
@@ -555,6 +601,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._phrase_query = False
|
||||
self._reranker = None
|
||||
|
||||
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
|
||||
"""Set whether to use phrase query.
|
||||
@@ -641,8 +688,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
if self._with_row_id:
|
||||
output_tbl = output_tbl.append_column("_rowid", row_ids)
|
||||
|
||||
if self._reranker is not None:
|
||||
output_tbl = self._reranker.rerank_fts(self._query, output_tbl)
|
||||
return output_tbl
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker
|
||||
The reranker to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceFtsQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._reranker = reranker
|
||||
return self
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
|
||||
@@ -24,8 +24,59 @@ class Reranker(ABC):
|
||||
raise ValueError("score must be either 'relevance' or 'all'")
|
||||
self.score = return_score
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
"""
|
||||
Rerank function receives the result from the vector search.
|
||||
This isn't mandatory to implement
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The input query
|
||||
vector_results : pa.Table
|
||||
The results from the vector search
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The reranked results
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement rerank_vector"
|
||||
)
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
"""
|
||||
Rerank function receives the result from the FTS search.
|
||||
This isn't mandatory to implement
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
The input query
|
||||
fts_results : pa.Table
|
||||
The results from the FTS search
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The reranked results
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement rerank_fts"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
@@ -43,6 +94,11 @@ class Reranker(ABC):
|
||||
The results from the vector search
|
||||
fts_results : pa.Table
|
||||
The results from the FTS search
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The reranked results
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -49,14 +49,8 @@ class CohereReranker(Reranker):
|
||||
)
|
||||
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
results = self._client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
@@ -66,12 +60,22 @@ class CohereReranker(Reranker):
|
||||
indices, scores = list(
|
||||
zip(*[(result.index, result.relevance_score) for result in results])
|
||||
) # tuples
|
||||
combined_results = combined_results.take(list(indices))
|
||||
result_set = result_set.take(list(indices))
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
@@ -79,3 +83,25 @@ class CohereReranker(Reranker):
|
||||
"return_score='all' not implemented for cohere reranker"
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["score"])
|
||||
|
||||
return result_set
|
||||
|
||||
@@ -33,14 +33,8 @@ class ColbertReranker(Reranker):
|
||||
"torch"
|
||||
) # import here for faster ops later
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
|
||||
tokenizer, model = self._model
|
||||
|
||||
@@ -59,14 +53,25 @@ class ColbertReranker(Reranker):
|
||||
scores.append(score.item())
|
||||
|
||||
# replace the self.column column with the docs
|
||||
combined_results = combined_results.drop(self.column)
|
||||
combined_results = combined_results.append_column(
|
||||
result_set = result_set.drop(self.column)
|
||||
result_set = result_set.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
)
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
@@ -80,6 +85,32 @@ class ColbertReranker(Reranker):
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["score"])
|
||||
|
||||
result_set = result_set.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return result_set
|
||||
|
||||
@cached_property
|
||||
def _model(self):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
|
||||
@@ -46,6 +46,16 @@ class CrossEncoderReranker(Reranker):
|
||||
|
||||
return cross_encoder
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
passages = result_set[self.column].to_pylist()
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.predict(cross_inp)
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
@@ -53,13 +63,7 @@ class CrossEncoderReranker(Reranker):
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
passages = combined_results[self.column].to_pylist()
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.predict(cross_inp)
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||
)
|
||||
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
# sort the results by _score
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
@@ -72,3 +76,27 @@ class CrossEncoderReranker(Reranker):
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
return fts_results
|
||||
|
||||
@@ -39,14 +39,8 @@ class OpenaiReranker(Reranker):
|
||||
self.column = column
|
||||
self.api_key = api_key
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
response_format={"type": "json_object"},
|
||||
@@ -70,14 +64,25 @@ class OpenaiReranker(Reranker):
|
||||
zip(*[(result["content"], result["relevance_score"]) for result in results])
|
||||
) # tuples
|
||||
# replace the self.column column with the docs
|
||||
combined_results = combined_results.drop(self.column)
|
||||
combined_results = combined_results.append_column(
|
||||
result_set = result_set.drop(self.column)
|
||||
result_set = result_set.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
)
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
@@ -91,6 +96,24 @@ class OpenaiReranker(Reranker):
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return fts_results
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
openai = attempt_import_or_raise(
|
||||
|
||||
@@ -124,8 +124,9 @@ def test_linear_combination(tmp_path):
|
||||
)
|
||||
def test_cohere_reranker(tmp_path):
|
||||
pytest.importorskip("cohere")
|
||||
reranker = CohereReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
# Hybrid search setting
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CohereReranker())
|
||||
@@ -133,7 +134,7 @@ def test_cohere_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=CohereReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
@@ -143,64 +144,120 @@ def test_cohere_reranker(tmp_path):
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
query = "Our father who art in heaven"
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because a vector query is provided without a reranking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CrossEncoderReranker())
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because a vector query is provided without a reranking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=ColbertReranker())
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
@@ -211,17 +268,43 @@ def test_colbert_reranker(tmp_path):
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because a vector query is provided without a reranking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -230,9 +313,10 @@ def test_colbert_reranker(tmp_path):
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
reranker = OpenaiReranker()
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=OpenaiReranker())
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
@@ -248,14 +332,40 @@ def test_openai_reranker(tmp_path):
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because a vector query is provided without a reranking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
Reference in New Issue
Block a user