feat(python): Reranker DX improvements (#904)

- Most users might not know how to use the `QueryBuilder` object. Instead, we
should just pass the query string.
- Add new rerankers: ColBERT, OpenAI
This commit is contained in:
Ayush Chaurasia
2024-02-06 13:59:31 +05:30
committed by Weston Pace
parent 39cc2fd62b
commit d07817a562
12 changed files with 400 additions and 68 deletions

View File

@@ -49,6 +49,9 @@ jobs:
test-node:
name: Test doc nodejs code
runs-on: "ubuntu-latest"
timeout-minutes: 45
strategy:
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
uses: swatinem/rust-cache@v2
- name: Install node dependencies
run: |
sudo swapoff -a
sudo fallocate -l 8G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
sudo swapon --show
cd node
npm ci
npm run build-release

View File

@@ -130,6 +130,60 @@ Arguments
Only returns `_relevance_score`. Does not support `return_score = "all"`.
### ColBERT Reranker
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.
The ColBERT reranker scores each document's relevance against the query and does not take the existing FTS and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it uses the `text` column as input to the model, but you can specify a different column name as described below.
```python
from lancedb.rerankers import ColbertReranker
reranker = ColbertReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
The name of the ColBERT model to use.
* `column` : `str`, default `"text"`
The name of the column to use as input to the reranker model.
* `return_score` : `str`, default `"relevance"`
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
!!! Note
Only returns `_relevance_score`. Does not support `return_score = "all"`.
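For instance, a minimal sketch of pointing the reranker at a different column (the `description` column name here is hypothetical; use whichever column holds your text):
```python
from lancedb.rerankers import ColbertReranker

# "description" is a hypothetical column name for illustration
reranker = ColbertReranker(model_name="colbert-ir/colbertv2.0", column="description")
results = (
    table.search("harmony hall", query_type="hybrid")
    .rerank(reranker=reranker)
    .to_pandas()
)
```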
### OpenAI Reranker
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
!!! Note
This reranker prompts a chat model to rerank the results rather than using a dedicated reranker model, so it should be treated as experimental.
!!! Tip
The combined results can exceed the model's context window, so set the search `limit` based on your token budget, as shown in the second snippet below.
```python
from lancedb.rerankers import OpenaiReranker
reranker = OpenaiReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
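If the merged candidate set is large, the prompt can exceed the model's context window. A minimal sketch of capping it (the limit of 10 is an arbitrary assumption; tune it to your token budget):
```python
from lancedb.rerankers import OpenaiReranker

results = (
    table.search("harmony hall", query_type="hybrid")
    .limit(10)  # arbitrary cap so the prompt stays within the context window
    .rerank(reranker=OpenaiReranker())
    .to_pandas()
)
```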
Arguments
----------------
* `model_name` : `str`, default `"gpt-3.5-turbo-1106"`
The name of the chat model to use.
* `column` : `str`, default `"text"`
The name of the column to use as input to the model.
* `return_score` : `str`, default `"relevance"`
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
* `api_key` : `str`, default `None`
The API key to use. If None, will use the `OPENAI_API_KEY` environment variable.
## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
@@ -146,7 +200,7 @@ class MyReranker(Reranker):
self.param1 = param1
self.param2 = param2
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
@@ -168,7 +222,7 @@ import pyarrow as pa
class MyReranker(Reranker):
...
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
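For reference, a more complete, self-contained sketch using the new string-query signature. The keyword-overlap scoring is purely illustrative and not something the library provides; only `Reranker`, `merge_results()`, and the `rerank_hybrid(query, vector_results, fts_results)` signature come from LanceDB.
```python
import pyarrow as pa

from lancedb.rerankers import Reranker


class KeywordOverlapReranker(Reranker):
    """Illustrative reranker: scores each merged result by query-term overlap."""

    def __init__(self, column: str = "text", return_score: str = "relevance"):
        super().__init__(return_score)
        self.column = column

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        # Use the built-in merging function, then score with a toy heuristic
        combined = self.merge_results(vector_results, fts_results)
        query_terms = set(query.lower().split())
        scores = [
            len(query_terms & set(doc.lower().split())) / max(len(query_terms), 1)
            for doc in combined[self.column].to_pylist()
        ]
        combined = combined.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        return combined.sort_by([("_relevance_score", "descending")])
```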

View File

@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table)
self._validate_fts_index()
self._query = query
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
# rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score")
results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
results = self._reranker.rerank_hybrid(
self._fts_query._query, vector_results, fts_results
)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
# apply limit after reranking
results = results.slice(length=self._limit)
if not self._with_row_id:
results = results.drop(["_rowid"])
return results
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
self._limit = limit
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:

View File

@@ -1,11 +1,15 @@
from .base import Reranker
from .cohere import CohereReranker
from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
__all__ = [
"Reranker",
"CrossEncoderReranker",
"CohereReranker",
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
]

View File

@@ -1,12 +1,8 @@
import typing
from abc import ABC, abstractmethod
import numpy as np
import pyarrow as pa
if typing.TYPE_CHECKING:
import lancedb
class Reranker(ABC):
def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):
@abstractmethod
def rerank_hybrid(
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
@@ -41,8 +37,8 @@ class Reranker(ABC):
Parameters
----------
query_builder : "lancedb.HybridQueryBuilder"
The query builder object that was used to generate the results
query : str
The input query
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
"""
pass
def rerank_vector(
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
):
"""
Rerank function receives the individual results from the vector search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.VectorQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
"""
raise NotImplementedError("Vector Reranking is not implemented")
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
"""
Rerank function receives the individual results from the FTS search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.FTSQueryBuilder"
The query builder object that was used to generate the results
fts_results : pa.Table
The results from the FTS search
"""
raise NotImplementedError("FTS Reranking is not implemented")
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
"""
Merge the results from the vector and FTS search. This is a vanilla merging

View File

@@ -1,5 +1,4 @@
import os
import typing
from functools import cached_property
from typing import Union
@@ -8,9 +7,6 @@ import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CohereReranker(Reranker):
"""
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
results = self._client.rerank(
query=query_builder._query,
query=query,
documents=docs,
top_n=self.top_n,
model=self.model_name,

View File

@@ -0,0 +1,107 @@
from functools import cached_property
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
class ColbertReranker(Reranker):
"""
Reranks the results using the ColBERT model.
Parameters
----------
model_name : str, default "colbert-ir/colbertv2.0"
The name of the ColBERT model to use.
column : str, default "text"
The name of the column to use as input to the reranker model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
"""
def __init__(
self,
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = safe_import("torch") # import here for faster ops later
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
tokenizer, model = self._model
# Encode the query
query_encoding = tokenizer(query, return_tensors="pt")
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
scores = []
# Get score for each document
for document in docs:
document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512
)
document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
scores.append(score.item())
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _model(self):
transformers = safe_import("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)
return tokenizer, model
def maxsim(self, query_embedding, document_embedding):
# Expand dimensions for broadcasting
# Query: [batch, query_len, emb_dim] -> [batch, query_len, 1, emb_dim]
# Document: [batch, doc_len, emb_dim] -> [batch, 1, doc_len, emb_dim]
expanded_query = query_embedding.unsqueeze(2)
expanded_doc = document_embedding.unsqueeze(1)
# Compute cosine similarity across the embedding dimension
sim_matrix = self.torch.nn.functional.cosine_similarity(
expanded_query, expanded_doc, dim=-1
)
# Take the maximum similarity for each query token (across all document tokens)
# sim_matrix shape: [batch_size, query_length, doc_length]
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
# Average these maximum scores across all query tokens
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
return avg_max_sim
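For intuition, a minimal sketch of what `maxsim()` above computes on toy tensors (assuming `torch` is installed; the shapes are arbitrary and chosen only for illustration):
```python
import torch
import torch.nn.functional as F

# One pooled query vector and a 4-token document, both 8-dimensional
query_embedding = torch.randn(1, 1, 8)      # [batch, query_len, emb_dim]
document_embedding = torch.randn(1, 4, 8)   # [batch, doc_len, emb_dim]

# Same broadcasting trick as maxsim()
sim = F.cosine_similarity(
    query_embedding.unsqueeze(2),           # [1, 1, 1, 8]
    document_embedding.unsqueeze(1),        # [1, 1, 4, 8]
    dim=-1,
)                                           # -> [1, 1, 4]
best_per_query_token, _ = torch.max(sim, dim=2)  # best-matching doc token per query token
score = torch.mean(best_per_query_token, dim=1)  # average over query tokens -> [1]
print(score)
```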

View File

@@ -1,4 +1,3 @@
import typing
from functools import cached_property
from typing import Union
@@ -7,9 +6,6 @@ import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CrossEncoderReranker(Reranker):
"""
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist()
cross_inp = [[query_builder._query, passage] for passage in passages]
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())

View File

@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
query: str, # noqa: F821
vector_results: pa.Table,
fts_results: pa.Table,
):

View File

@@ -0,0 +1,102 @@
import json
import os
from functools import cached_property
from typing import Optional
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
class OpenaiReranker(Reranker):
"""
Reranks the results using the OpenAI API.
WARNING: This is a prompt-based reranker that uses a chat model rather than
a dedicated reranker API. It should be treated as experimental.
Parameters
----------
model_name : str, default "gpt-3.5-turbo-1106 "
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
api_key : str, default None
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
"""
def __init__(
self,
model_name: str = "gpt-3.5-turbo-1106",
column: str = "text",
return_score="relevance",
api_key: Optional[str] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.api_key = api_key
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
response = self._client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
temperature=0,
messages=[
{
"role": "system",
"content": "You are an expert relevance ranker. Given a list of\
documents and a query, your job is to determine how relevant\
each document is for answering the query. Your output is JSON,\
which is a list of documents. Each document has two fields,\
content and relevance_score. relevance_score is from 0.0 to\
1.0 indicating the relevance of the text to the given query.\
Make sure to include all documents in the response.",
},
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
],
)
results = json.loads(response.choices[0].message.content)["documents"]
docs, scores = list(
zip(*[(result["content"], result["relevance_score"]) for result in results])
) # tuples
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _client(self):
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
raise ValueError(
"OPENAI_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the OpenaiReranker."
)
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)

View File

@@ -446,7 +446,7 @@ class Table(ABC):
*default "vector"*
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", or "auto"
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;

View File

@@ -7,7 +7,12 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
from lancedb.rerankers import (
CohereReranker,
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
)
from lancedb.table import LanceTable
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
return table, MyTable
## These tests are pretty loose, we should also check for correctness
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
assert result1 == result3 # 2 & 3 should be the same as they use score as score
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query))
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CohereReranker())
.rerank(reranker=CohereReranker())
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query))
.limit(30)
.rerank(reranker=CohereReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CrossEncoderReranker())
.rerank(reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=ColbertReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=ColbertReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=OpenaiReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)