mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-07 04:12:59 +00:00)
feat(python): Reranker DX improvements (#904)

- Most users might not know how to use the `QueryBuilder` object, so rerankers now receive the plain query string instead.
- Add new rerankers: ColBERT, OpenAI.
committed by Weston Pace · parent 39cc2fd62b · commit d07817a562
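The core of the DX change, before the per-file diffs: `rerank_hybrid()` implementations now receive the plain query string instead of the query builder. A minimal sketch of a subclass under the new contract (toy class name; `merge_results` is the real helper inherited from `Reranker`):

```python
import pyarrow as pa

from lancedb.rerankers import Reranker


class PassthroughReranker(Reranker):
    # Before #904, implementations received the query builder and reached into
    # its internals (e.g. query_builder._query); now the string arrives directly.
    def rerank_hybrid(
        self, query: str, vector_results: pa.Table, fts_results: pa.Table
    ):
        # Merge the two result sets with the built-in helper and return as-is.
        return self.merge_results(vector_results, fts_results)
```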
.github/workflows/docs_test.yml (vendored, 9 changes)
@@ -49,6 +49,9 @@ jobs:
   test-node:
     name: Test doc nodejs code
     runs-on: "ubuntu-latest"
+    timeout-minutes: 45
+    strategy:
+      fail-fast: false
     steps:
       - name: Checkout
        uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
         uses: swatinem/rust-cache@v2
       - name: Install node dependencies
         run: |
+          sudo swapoff -a
+          sudo fallocate -l 8G /swapfile
+          sudo chmod 600 /swapfile
+          sudo mkswap /swapfile
+          sudo swapon /swapfile
+          sudo swapon --show
           cd node
           npm ci
           npm run build-release
@@ -130,6 +130,60 @@ Arguments
 Only returns `_relevance_score`. Does not support `return_score = "all"`.
+
+### ColBERT Reranker
+This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.
+
+The ColBERT reranker calculates the relevance of the given docs against the query and doesn't take the existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for a `text` column to rerank the results, but you can specify a different column to use as input to the model as described below.
+
+```python
+from lancedb.rerankers import ColbertReranker
+
+reranker = ColbertReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
+    The name of the ColBERT model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the model.
+* `return_score` : `str`, default `"relevance"`
+    options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+
+!!! Note
+    Only returns `_relevance_score`. Does not support `return_score = "all"`.
+
+### OpenAI Reranker
+This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
+
+!!! Note
+    This prompts a chat model, not a dedicated reranker model, to rerank the results. It should be treated as experimental.
+
+!!! Tip
+    You might hit the model's token limit, so set the search `limit` based on your token budget.
+
+```python
+from lancedb.rerankers import OpenaiReranker
+
+reranker = OpenaiReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"gpt-3.5-turbo-1106"`
+    The name of the chat model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the model.
+* `return_score` : `str`, default `"relevance"`
+    options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+* `api_key` : `str`, default `None`
+    The API key to use. If None, the OPENAI_API_KEY environment variable is used.
+
+
 ## Building Custom Rerankers
 You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
 
@@ -146,7 +200,7 @@ class MyReranker(Reranker):
         self.param1 = param1
         self.param2 = param2
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
@@ -168,7 +222,7 @@ import pyarrow as pa
 class MyReranker(Reranker):
     ...
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
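Putting the documented pieces together, here is what a complete custom reranker could look like under the new contract. This is a sketch, not code from the repo: the distance inversion is made up, and it assumes merged rows keep a `_distance` column (null for FTS-only hits):

```python
import pyarrow as pa

from lancedb.rerankers import Reranker


class DistanceOnlyReranker(Reranker):
    """Toy reranker: scores merged results purely by inverted vector distance."""

    def rerank_hybrid(
        self, query: str, vector_results: pa.Table, fts_results: pa.Table
    ):
        # Use the built-in merging function to combine and dedupe both sides.
        combined = self.merge_results(vector_results, fts_results)
        distances = combined["_distance"].to_pylist()
        # Map distance to a "higher is better" score; FTS-only rows (no
        # vector distance) sink to the bottom.
        scores = [0.0 if d is None else 1.0 / (1.0 + d) for d in distances]
        combined = combined.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        return combined.sort_by([("_relevance_score", "descending")])
```

Usage is the same as the shipped rerankers: `table.search("harmony hall", query_type="hybrid").rerank(reranker=DistanceOnlyReranker()).to_pandas()`.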
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(self, table: "Table", query: str, vector_column: str):
         super().__init__(table)
         self._validate_fts_index()
-        self._query = query
         vector_query, fts_query = self._validate_query(query)
         self._fts_query = LanceFtsQueryBuilder(table, fts_query)
         vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         # rerankers might need to preserve this score to support `return_score="all"`
         fts_results = self._normalize_scores(fts_results, "score")
 
-        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
+        results = self._reranker.rerank_hybrid(
+            self._fts_query._query, vector_results, fts_results
+        )
+
         if not isinstance(results, pa.Table):  # Enforce type
             raise TypeError(
                 f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
             )
+
+        # apply limit after reranking
+        results = results.slice(length=self._limit)
+
         if not self._with_row_id:
             results = results.drop(["_rowid"])
         return results
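Worth noting why the slice lands here: the child vector and FTS queries each fetch results, the reranker scores the merged (possibly larger) set, and only then is the table trimmed, so the limit applies to reranked order rather than retrieval order (the `limit()` override in the next hunk records `self._limit` for this). A toy illustration of the trim step with plain pyarrow (made-up rows):

```python
import pyarrow as pa

# Five merged rows, already sorted by the reranker.
reranked = pa.table(
    {
        "text": ["a", "b", "c", "d", "e"],
        "_relevance_score": [0.9, 0.7, 0.5, 0.2, 0.1],
    }
)

limit = 3
trimmed = reranked.slice(length=limit)  # keep the top `limit` reranked rows
assert trimmed.num_rows == 3
```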
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.limit(limit)
         self._fts_query.limit(limit)
+        self._limit = limit
+
         return self
 
     def select(self, columns: list) -> LanceHybridQueryBuilder:
@@ -1,11 +1,15 @@
 from .base import Reranker
 from .cohere import CohereReranker
+from .colbert import ColbertReranker
 from .cross_encoder import CrossEncoderReranker
 from .linear_combination import LinearCombinationReranker
+from .openai import OpenaiReranker
 
 __all__ = [
     "Reranker",
     "CrossEncoderReranker",
     "CohereReranker",
     "LinearCombinationReranker",
+    "OpenaiReranker",
+    "ColbertReranker",
 ]
@@ -1,12 +1,8 @@
-import typing
 from abc import ABC, abstractmethod
 
 import numpy as np
 import pyarrow as pa
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class Reranker(ABC):
     def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):
 
     @abstractmethod
     def rerank_hybrid(
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
@@ -41,8 +37,8 @@ class Reranker(ABC):
 
         Parameters
         ----------
-        query_builder : "lancedb.HybridQueryBuilder"
-            The query builder object that was used to generate the results
+        query : str
+            The input query
         vector_results : pa.Table
             The results from the vector search
         fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
         """
         pass
 
-    def rerank_vector(
-        query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
-    ):
-        """
-        Rerank function receives the individual results from the vector search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.VectorQueryBuilder"
-            The query builder object that was used to generate the results
-        vector_results : pa.Table
-            The results from the vector search
-        """
-        raise NotImplementedError("Vector Reranking is not implemented")
-
-    def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
-        """
-        Rerank function receives the individual results from the FTS search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.FTSQueryBuilder"
-            The query builder object that was used to generate the results
-        fts_results : pa.Table
-            The results from the FTS search
-        """
-        raise NotImplementedError("FTS Reranking is not implemented")
-
     def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
         """
         Merge the results from the vector and FTS search. This is a vanilla merging
@@ -1,5 +1,4 @@
 import os
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -8,9 +7,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class CohereReranker(Reranker):
     """
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         docs = combined_results[self.column].to_pylist()
         results = self._client.rerank(
-            query=query_builder._query,
+            query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
python/lancedb/rerankers/colbert.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+from functools import cached_property
+
+import pyarrow as pa
+
+from ..util import safe_import
+from .base import Reranker
+
+
+class ColbertReranker(Reranker):
+    """
+    Reranks the results using the ColBERT model.
+
+    Parameters
+    ----------
+    model_name : str, default "colbert-ir/colbertv2.0"
+        The name of the ColBERT model to use.
+    column : str, default "text"
+        The name of the column to use as input to the model.
+    return_score : str, default "relevance"
+        options are "relevance" or "all". Only "relevance" is supported for now.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "colbert-ir/colbertv2.0",
+        column: str = "text",
+        return_score="relevance",
+    ):
+        super().__init__(return_score)
+        self.model_name = model_name
+        self.column = column
+        self.torch = safe_import("torch")  # import here for faster ops later
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        docs = combined_results[self.column].to_pylist()
+
+        tokenizer, model = self._model
+
+        # Encode the query
+        query_encoding = tokenizer(query, return_tensors="pt")
+        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
+        scores = []
+        # Get score for each document
+        for document in docs:
+            document_encoding = tokenizer(
+                document, return_tensors="pt", truncation=True, max_length=512
+            )
+            document_embedding = model(**document_encoding).last_hidden_state
+            # Calculate MaxSim score
+            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
+            scores.append(score.item())
+
+        # replace the self.column column with the docs
+        combined_results = combined_results.drop(self.column)
+        combined_results = combined_results.append_column(
+            self.column, pa.array(docs, type=pa.string())
+        )
+        # add the scores
+        combined_results = combined_results.append_column(
+            "_relevance_score", pa.array(scores, type=pa.float32())
+        )
+        if self.score == "relevance":
+            combined_results = combined_results.drop_columns(["score", "_distance"])
+        elif self.score == "all":
+            raise NotImplementedError(
+                "ColBERT Reranker does not support score='all' yet"
+            )
+
+        combined_results = combined_results.sort_by(
+            [("_relevance_score", "descending")]
+        )
+
+        return combined_results
+
+    @cached_property
+    def _model(self):
+        transformers = safe_import("transformers")
+        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
+        model = transformers.AutoModel.from_pretrained(self.model_name)
+
+        return tokenizer, model
+
+    def maxsim(self, query_embedding, document_embedding):
+        # Expand dimensions for broadcasting
+        # Query: [batch, length, size] -> [batch, query, 1, size]
+        # Document: [batch, length, size] -> [batch, 1, length, size]
+        expanded_query = query_embedding.unsqueeze(2)
+        expanded_doc = document_embedding.unsqueeze(1)
+
+        # Compute cosine similarity across the embedding dimension
+        sim_matrix = self.torch.nn.functional.cosine_similarity(
+            expanded_query, expanded_doc, dim=-1
+        )
+
+        # Take the maximum similarity for each query token (across all document tokens)
+        # sim_matrix shape: [batch_size, query_length, doc_length]
+        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
+
+        # Average these maximum scores across all query tokens
+        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
+        return avg_max_sim
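To make the broadcasting in `maxsim` concrete, here is a small standalone shape check with random tensors (illustrative only; note that since the query embedding above is mean-pooled to a single vector, the "query length" axis is 1 in practice):

```python
import torch
import torch.nn.functional as F

# Shapes as maxsim receives them in ColbertReranker:
query_embedding = torch.randn(1, 1, 768)      # [batch, query_len (=1 after pooling), hidden]
document_embedding = torch.randn(1, 12, 768)  # [batch, doc_len, hidden]

expanded_query = query_embedding.unsqueeze(2)   # [1, 1, 1, 768]
expanded_doc = document_embedding.unsqueeze(1)  # [1, 1, 12, 768]

# Cosine similarity over the hidden dim -> [1, 1, 12]: one score per doc token
sim_matrix = F.cosine_similarity(expanded_query, expanded_doc, dim=-1)

max_sim, _ = sim_matrix.max(dim=2)  # best doc token per query token -> [1, 1]
score = max_sim.mean(dim=1)         # average over query tokens -> [1]
print(score.shape, score.item())
```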
@@ -1,4 +1,3 @@
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -7,9 +6,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class CrossEncoderReranker(Reranker):
     """
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query_builder._query, passage] for passage in passages]
+        cross_inp = [[query, passage] for passage in passages]
         cross_scores = self.model.predict(cross_inp)
         combined_results = combined_results.append_column(
             "_relevance_score", pa.array(cross_scores, type=pa.float32())
@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
+        query: str,  # noqa: F821
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
python/lancedb/rerankers/openai.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+import json
+import os
+from functools import cached_property
+from typing import Optional
+
+import pyarrow as pa
+
+from ..util import safe_import
+from .base import Reranker
+
+
+class OpenaiReranker(Reranker):
+    """
+    Reranks the results using the OpenAI API.
+    WARNING: This is a prompt-based reranker that uses a chat model, not a
+    dedicated reranker API. It should be treated as experimental.
+
+    Parameters
+    ----------
+    model_name : str, default "gpt-3.5-turbo-1106"
+        The name of the chat model to use.
+    column : str, default "text"
+        The name of the column to use as input to the model.
+    return_score : str, default "relevance"
+        options are "relevance" or "all". Only "relevance" is supported for now.
+    api_key : str, default None
+        The API key to use. If None, will use the OPENAI_API_KEY environment variable.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gpt-3.5-turbo-1106",
+        column: str = "text",
+        return_score="relevance",
+        api_key: Optional[str] = None,
+    ):
+        super().__init__(return_score)
+        self.model_name = model_name
+        self.column = column
+        self.api_key = api_key
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        docs = combined_results[self.column].to_pylist()
+        response = self._client.chat.completions.create(
+            model=self.model_name,
+            response_format={"type": "json_object"},
+            temperature=0,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are an expert relevance ranker. Given a list of\
+                    documents and a query, your job is to determine how relevant\
+                    each document is for answering the query. Your output is JSON,\
+                    which is a list of documents. Each document has two fields,\
+                    content and relevance_score. relevance_score is from 0.0 to\
+                    1.0 indicating the relevance of the text to the given query.\
+                    Make sure to include all documents in the response.",
+                },
+                {"role": "user", "content": f"Query: {query} Docs: {docs}"},
+            ],
+        )
+        results = json.loads(response.choices[0].message.content)["documents"]
+        docs, scores = list(
+            zip(*[(result["content"], result["relevance_score"]) for result in results])
+        )  # tuples
+        # replace the self.column column with the docs
+        combined_results = combined_results.drop(self.column)
+        combined_results = combined_results.append_column(
+            self.column, pa.array(docs, type=pa.string())
+        )
+        # add the scores
+        combined_results = combined_results.append_column(
+            "_relevance_score", pa.array(scores, type=pa.float32())
+        )
+        if self.score == "relevance":
+            combined_results = combined_results.drop_columns(["score", "_distance"])
+        elif self.score == "all":
+            raise NotImplementedError(
+                "OpenAI Reranker does not support score='all' yet"
+            )
+
+        combined_results = combined_results.sort_by(
+            [("_relevance_score", "descending")]
+        )
+
+        return combined_results
+
+    @cached_property
+    def _client(self):
+        openai = safe_import("openai")  # TODO: force version or handle versions < 1.0
+        if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
+            raise ValueError(
+                "OPENAI_API_KEY not set. Either set it in your environment or "
+                "pass it as the `api_key` argument to the OpenaiReranker."
+            )
+        return openai.OpenAI(api_key=self.api_key or os.environ.get("OPENAI_API_KEY"))
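For reference, the parsing step above expects the chat model to answer with JSON of roughly this shape (illustrative values only; real model output varies and is not guaranteed to echo every document back):

```python
import json

# What the reranker hopes to find in response.choices[0].message.content:
raw = """
{
  "documents": [
    {"content": "first passage ...", "relevance_score": 0.92},
    {"content": "second passage ...", "relevance_score": 0.35}
  ]
}
"""
results = json.loads(raw)["documents"]
docs, scores = zip(*[(r["content"], r["relevance_score"]) for r in results])
print(docs)    # ('first passage ...', 'second passage ...')
print(scores)  # (0.92, 0.35)
```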
@@ -446,7 +446,7 @@ class Table(ABC):
             *default "vector"*
         query_type: str
             *default "auto"*.
-            Acceptable types are: "vector", "fts", or "auto"
+            Acceptable types are: "vector", "fts", "hybrid", or "auto"
 
             - If "auto" then the query type is inferred from the query;
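The tests below exercise both hybrid forms: the text-only form, where the query string is embedded automatically, and the explicit form, where a `(query_vector, query_text)` tuple is passed to `search()`. A minimal usage sketch (hypothetical database path, table name, and embedding):

```python
import lancedb

db = lancedb.connect("/tmp/my_lancedb")  # hypothetical path
table = db.open_table("my_table")        # hypothetical table with an FTS index

# Text-only hybrid: the string is embedded and also used for full-text search.
results = table.search("harmony hall", query_type="hybrid").limit(10).to_pandas()

# Explicit hybrid: supply the vector and the text separately as a tuple.
query_vector = [0.1] * 128  # hypothetical embedding matching the table's dim
results = table.search((query_vector, "harmony hall")).limit(10).to_pandas()
```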
@@ -7,7 +7,12 @@ import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction  # noqa
 from lancedb.embeddings import EmbeddingFunctionRegistry
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.rerankers import CohereReranker, CrossEncoderReranker
+from lancedb.rerankers import (
+    CohereReranker,
+    ColbertReranker,
+    CrossEncoderReranker,
+    OpenaiReranker,
+)
 from lancedb.table import LanceTable
 
 
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
     return table, MyTable
 
 
-## These tests are pretty loose, we should also check for correctness
 def test_linear_combination(tmp_path):
     table, schema = get_test_table(tmp_path)
     # The default reranker
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
 
     assert result1 == result3  # 2 & 3 should be the same as they use score as score
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(normalize="score")
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CohereReranker())
+        .rerank(reranker=CohereReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(reranker=CohereReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CrossEncoderReranker())
+        .rerank(reranker=CrossEncoderReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query), query_type="hybrid")
+        .limit(30)
         .rerank(reranker=CrossEncoderReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+
+
+def test_colbert_reranker(tmp_path):
+    pytest.importorskip("transformers")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=ColbertReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )
+
+
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
+)
+def test_openai_reranker(tmp_path):
+    pytest.importorskip("openai")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=OpenaiReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )