feat(python): Support reranking for vector and fts (#1103)

solves https://github.com/lancedb/lancedb/issues/1086

Usage Reranking with FTS:
```
retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [{"text": "Carson City is the capital city of the American state of Nevada. At the  2010 United States Census, Carson City had a population of 55,274."},
          {"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan."},
        {"text": "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas."},
        {"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. "},
        {"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."},
        {"text": "North Dakota is a state in the United States. 672,591 people lived in North Dakota in the year 2010. The capital and seat of government is Bismarck."},
        ]
retriever.add(pylist)
retriever.create_fts_index("text", replace=True)

query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query, query_type="fts").limit(10).to_pandas())
print(retriever.search(query, query_type="fts").rerank(reranker=reranker).limit(10).to_pandas())
```
Result
```
                                                text                                             vector     score
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602
1  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046
2  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898
4  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346
                                                text                                             vector     score  _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422          0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521          0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602          0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898          0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346          0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046          0.041462
```

## Vector Search usage:
```
query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query).limit(10).to_pandas())
print(retriever.search(query).rerank(reranker=reranker, query=query).limit(10).to_pandas()) # <-- Note: passing extra string query here
```

Results
```
                                                text                                             vector  _distance
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973
1  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884
2  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200
3  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654
4  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544
                                                text                                             vector  _distance  _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884          0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867          0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973          0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200          0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544          0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654          0.041462
```
This commit is contained in:
Ayush Chaurasia
2024-03-19 22:20:31 +05:30
committed by Weston Pace
parent b36c750cc7
commit 42fad84ec8
7 changed files with 399 additions and 59 deletions

View File

@@ -144,6 +144,9 @@ class LanceQueryBuilder(ABC):
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
# remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
query, query_type = cls._resolve_query(
@@ -164,7 +167,7 @@ class LanceQueryBuilder(ABC):
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name)
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)
@classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name):
@@ -428,6 +431,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str,
str_query: Optional[str] = None,
):
super().__init__(table)
self._query = query
@@ -436,6 +440,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
self._reranker = None
self._str_query = str_query
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
@@ -521,7 +527,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
vector_column=self._vector_column,
with_row_id=self._with_row_id,
)
return self._table._execute_query(query)
result_set = self._table._execute_query(query)
if self._reranker is not None:
result_set = self._reranker.rerank_vector(self._str_query, result_set)
return result_set
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
"""Set the where clause.
@@ -547,6 +557,42 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._prefilter = prefilter
return self
def rerank(
self, reranker: Reranker, query_string: Optional[str] = None
) -> LanceVectorQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
query_string: Optional[str]
The query to use for reranking. This needs to be specified explicitly here
as the query used for vector search may already be vectorized and the
reranker requires a string query.
This is only required if the query used for vector search is not a string.
Note: This doesn't yet support the case where the query is multimodal or a
list of vectors.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._reranker = reranker
if self._str_query is None and query_string is None:
raise ValueError(
"""
The query used for vector search is not a string.
In this case, the reranker query needs to be specified explicitly.
"""
)
if query_string is not None and not isinstance(query_string, str):
raise ValueError("Reranking currently only supports string queries")
self._str_query = query_string if query_string is not None else self._str_query
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
@@ -555,6 +601,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
super().__init__(table)
self._query = query
self._phrase_query = False
self._reranker = None
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
"""Set whether to use phrase query.
@@ -641,8 +688,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
if self._with_row_id:
output_tbl = output_tbl.append_column("_rowid", row_ids)
if self._reranker is not None:
output_tbl = self._reranker.rerank_fts(self._query, output_tbl)
return output_tbl
def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
LanceFtsQueryBuilder
The LanceQueryBuilder object.
"""
self._reranker = reranker
return self
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:

View File

@@ -24,8 +24,59 @@ class Reranker(ABC):
raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
"""
Rerank function receives the result from the vector search.
This isn't mandatory to implement
Parameters
----------
query : str
The input query
vector_results : pa.Table
The results from the vector search
Returns
-------
pa.Table
The reranked results
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement rerank_vector"
)
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
"""
Rerank function receives the result from the FTS search.
This isn't mandatory to implement
Parameters
----------
query : str
The input query
fts_results : pa.Table
The results from the FTS search
Returns
-------
pa.Table
The reranked results
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement rerank_fts"
)
@abstractmethod
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
@@ -43,6 +94,11 @@ class Reranker(ABC):
The results from the vector search
fts_results : pa.Table
The results from the FTS search
Returns
-------
pa.Table
The reranked results
"""
pass

View File

@@ -49,14 +49,8 @@ class CohereReranker(Reranker):
)
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
results = self._client.rerank(
query=query,
documents=docs,
@@ -66,12 +60,22 @@ class CohereReranker(Reranker):
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])
) # tuples
combined_results = combined_results.take(list(indices))
result_set = result_set.take(list(indices))
# add the scores
combined_results = combined_results.append_column(
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
@@ -79,3 +83,25 @@ class CohereReranker(Reranker):
"return_score='all' not implemented for cohere reranker"
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
return result_set

View File

@@ -33,14 +33,8 @@ class ColbertReranker(Reranker):
"torch"
) # import here for faster ops later
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
tokenizer, model = self._model
@@ -59,14 +53,25 @@ class ColbertReranker(Reranker):
scores.append(score.item())
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
result_set = result_set.drop(self.column)
result_set = result_set.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
@@ -80,6 +85,32 @@ class ColbertReranker(Reranker):
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
@cached_property
def _model(self):
transformers = attempt_import_or_raise("transformers")

View File

@@ -46,6 +46,16 @@ class CrossEncoderReranker(Reranker):
return cross_encoder
def _rerank(self, result_set: pa.Table, query: str):
passages = result_set[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
result_set = result_set.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
@@ -53,13 +63,7 @@ class CrossEncoderReranker(Reranker):
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
combined_results = self._rerank(combined_results, query)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
@@ -72,3 +76,27 @@ class CrossEncoderReranker(Reranker):
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -39,14 +39,8 @@ class OpenaiReranker(Reranker):
self.column = column
self.api_key = api_key
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
response = self._client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
@@ -70,14 +64,25 @@ class OpenaiReranker(Reranker):
zip(*[(result["content"], result["relevance_score"]) for result in results])
) # tuples
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
result_set = result_set.drop(self.column)
result_set = result_set.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
@@ -91,6 +96,24 @@ class OpenaiReranker(Reranker):
return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results
@cached_property
def _client(self):
openai = attempt_import_or_raise(

View File

@@ -124,8 +124,9 @@ def test_linear_combination(tmp_path):
)
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
reranker = CohereReranker()
table, schema = get_test_table(tmp_path)
# The default reranker
# Hybrid search setting
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
@@ -133,7 +134,7 @@ def test_cohere_reranker(tmp_path):
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CohereReranker())
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
@@ -143,64 +144,120 @@ def test_cohere_reranker(tmp_path):
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=CohereReranker())
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
query = "Our father who art in heaven"
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CrossEncoderReranker())
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CrossEncoderReranker())
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
reranker = ColbertReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=ColbertReranker())
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=ColbertReranker())
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
@@ -211,17 +268,43 @@ def test_colbert_reranker(tmp_path):
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=ColbertReranker())
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
@pytest.mark.skipif(
@@ -230,9 +313,10 @@ def test_colbert_reranker(tmp_path):
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
reranker = OpenaiReranker()
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=OpenaiReranker())
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
@@ -248,14 +332,40 @@ def test_openai_reranker(tmp_path):
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=OpenaiReranker())
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because vector query is provided without reanking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err