Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-03 18:32:55 +00:00
feat(python): Reranker DX improvements (#904)
- Most users might not know how to use the `QueryBuilder` object, so rerankers now receive the plain string query instead (see the usage sketch below).
- Add new rerankers: ColBERT and OpenAI.
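For context, a minimal sketch of the resulting usage (the database path and table name are illustrative, not part of this commit); callers just chain `.rerank(...)` on a hybrid search and the reranker receives the query string internally:

import lancedb
from lancedb.rerankers import ColbertReranker

db = lancedb.connect("/tmp/lancedb")   # hypothetical local database
table = db.open_table("my_docs")       # assumes a table with both a vector and an FTS index

results = (
    table.search("Our father who art in heaven", query_type="hybrid")
    .limit(30)
    .rerank(reranker=ColbertReranker())
    .to_arrow()
)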
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(self, table: "Table", query: str, vector_column: str):
         super().__init__(table)
         self._validate_fts_index()
-        self._query = query
         vector_query, fts_query = self._validate_query(query)
         self._fts_query = LanceFtsQueryBuilder(table, fts_query)
         vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         # rerankers might need to preserve this score to support `return_score="all"`
         fts_results = self._normalize_scores(fts_results, "score")

-        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
+        results = self._reranker.rerank_hybrid(
+            self._fts_query._query, vector_results, fts_results
+        )
+
+        if not isinstance(results, pa.Table):  # Enforce type
+            raise TypeError(
+                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
+            )

         # apply limit after reranking
         results = results.slice(length=self._limit)

         if not self._with_row_id:
             results = results.drop(["_rowid"])
         return results
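The new isinstance guard makes a misbehaving reranker fail fast instead of surfacing a confusing error downstream. A hedged sketch of what the guard catches (BadReranker is hypothetical, not part of this diff):

import pandas as pd
import pyarrow as pa

class BadReranker:
    def rerank_hybrid(self, query, vector_results, fts_results):
        # Wrong return type: a pandas DataFrame instead of a pyarrow.Table
        return pd.DataFrame({"text": ["..."]})

results = BadReranker().rerank_hybrid("q", pa.table({"text": ["a"]}), pa.table({"text": ["b"]}))
if not isinstance(results, pa.Table):
    raise TypeError(f"rerank_hybrid must return a pyarrow.Table, got {type(results)}")
# -> TypeError: rerank_hybrid must return a pyarrow.Table, got <class 'pandas.core.frame.DataFrame'>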
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.limit(limit)
         self._fts_query.limit(limit)
+        self._limit = limit
+
         return self

     def select(self, columns: list) -> LanceHybridQueryBuilder:
@@ -1,11 +1,15 @@
 from .base import Reranker
 from .cohere import CohereReranker
+from .colbert import ColbertReranker
 from .cross_encoder import CrossEncoderReranker
 from .linear_combination import LinearCombinationReranker
+from .openai import OpenaiReranker

 __all__ = [
     "Reranker",
     "CrossEncoderReranker",
     "CohereReranker",
     "LinearCombinationReranker",
+    "OpenaiReranker",
+    "ColbertReranker",
 ]
@@ -1,12 +1,8 @@
-import typing
 from abc import ABC, abstractmethod

 import numpy as np
 import pyarrow as pa

-if typing.TYPE_CHECKING:
-    import lancedb
-

 class Reranker(ABC):
     def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):

     @abstractmethod
     def rerank_hybrid(
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
@@ -41,8 +37,8 @@ class Reranker(ABC):

        Parameters
        ----------
-       query_builder : "lancedb.HybridQueryBuilder"
-           The query builder object that was used to generate the results
+       query : str
+           The input query
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
         """
         pass

-    def rerank_vector(
-        query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
-    ):
-        """
-        Rerank function receives the individual results from the vector search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.VectorQueryBuilder"
-            The query builder object that was used to generate the results
-        vector_results : pa.Table
-            The results from the vector search
-        """
-        raise NotImplementedError("Vector Reranking is not implemented")
-
-    def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
-        """
-        Rerank function receives the individual results from the FTS search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.FTSQueryBuilder"
-            The query builder object that was used to generate the results
-        fts_results : pa.Table
-            The results from the FTS search
-        """
-        raise NotImplementedError("FTS Reranking is not implemented")
-
     def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
         """
         Merge the results from the vector and FTS search. This is a vanilla merging
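With the narrowed abstract signature above, a custom reranker only needs the query string plus the two Arrow result tables. A minimal sketch of a subclass under the new contract (the overlap scoring is a placeholder for illustration, not part of this commit):

import pyarrow as pa

from lancedb.rerankers import Reranker


class NaiveOverlapReranker(Reranker):
    """Hypothetical reranker that scores documents by query-term overlap."""

    def __init__(self, column: str = "text", return_score: str = "relevance"):
        super().__init__(return_score)
        self.column = column

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        # merge_results comes from the Reranker base class
        combined = self.merge_results(vector_results, fts_results)
        terms = set(query.lower().split())
        scores = [
            len(terms & set(doc.lower().split())) / max(len(terms), 1)
            for doc in combined[self.column].to_pylist()
        ]
        combined = combined.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        # Must return a pyarrow.Table; the query builder now enforces this
        return combined.sort_by([("_relevance_score", "descending")])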
@@ -1,5 +1,4 @@
 import os
-import typing
 from functools import cached_property
 from typing import Union

@@ -8,9 +7,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker

-if typing.TYPE_CHECKING:
-    import lancedb
-

 class CohereReranker(Reranker):
     """
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):

     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         docs = combined_results[self.column].to_pylist()
         results = self._client.rerank(
-            query=query_builder._query,
+            query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
107  python/lancedb/rerankers/colbert.py  Normal file
@@ -0,0 +1,107 @@
from functools import cached_property

import pyarrow as pa

from ..util import safe_import
from .base import Reranker


class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the ColBERT model to use.
    column : str, default "text"
        The name of the column to use as input to the ColBERT model.
    return_score : str, default "relevance"
        Options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.torch = safe_import("torch")  # import here for faster ops later

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        tokenizer, model = self._model

        # Encode the query
        query_encoding = tokenizer(query, return_tensors="pt")
        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
        scores = []
        # Get score for each document
        for document in docs:
            document_encoding = tokenizer(
                document, return_tensors="pt", truncation=True, max_length=512
            )
            document_embedding = model(**document_encoding).last_hidden_state
            # Calculate MaxSim score
            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
            scores.append(score.item())

        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "ColBERT Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _model(self):
        transformers = safe_import("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)

        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        # Expand dimensions for broadcasting
        # Query: [batch, query_length, size] -> [batch, query_length, 1, size]
        # Document: [batch, doc_length, size] -> [batch, 1, doc_length, size]
        expanded_query = query_embedding.unsqueeze(2)
        expanded_doc = document_embedding.unsqueeze(1)

        # Compute cosine similarity across the embedding dimension
        sim_matrix = self.torch.nn.functional.cosine_similarity(
            expanded_query, expanded_doc, dim=-1
        )

        # Take the maximum similarity for each query token (across all document tokens)
        # sim_matrix shape: [batch_size, query_length, doc_length]
        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)

        # Average these maximum scores across all query tokens
        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
        return avg_max_sim
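Note that the implementation above mean-pools the query tokens before MaxSim, so the max reduces over document tokens for a single pooled query vector. A small shape check of maxsim with random tensors (illustrative only; it relies on PyTorch broadcasting in cosine_similarity, as the code above does):

import torch

query_embedding = torch.randn(1, 1, 768)      # pooled query: [batch, 1, size]
document_embedding = torch.randn(1, 24, 768)  # document tokens: [batch, doc_len, size]

sim = torch.nn.functional.cosine_similarity(
    query_embedding.unsqueeze(2),             # [1, 1, 1, 768]
    document_embedding.unsqueeze(1),          # [1, 1, 24, 768]
    dim=-1,
)                                             # -> [1, 1, 24]
max_sim, _ = sim.max(dim=2)                   # best-matching document token: [1, 1]
score = max_sim.mean(dim=1)                   # averaged over query positions: [1]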
@@ -1,4 +1,3 @@
-import typing
 from functools import cached_property
 from typing import Union

@@ -7,9 +6,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker

-if typing.TYPE_CHECKING:
-    import lancedb
-

 class CrossEncoderReranker(Reranker):
     """
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):

     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query_builder._query, passage] for passage in passages]
+        cross_inp = [[query, passage] for passage in passages]
         cross_scores = self.model.predict(cross_inp)
         combined_results = combined_results.append_column(
             "_relevance_score", pa.array(cross_scores, type=pa.float32())
@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):

     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
+        query: str,  # noqa: F821
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
102  python/lancedb/rerankers/openai.py  Normal file
@@ -0,0 +1,102 @@
import json
import os
from functools import cached_property
from typing import Optional

import pyarrow as pa

from ..util import safe_import
from .base import Reranker


class OpenaiReranker(Reranker):
    """
    Reranks the results using the OpenAI API.
    WARNING: This is a prompt-based reranker that uses a chat model, not a
    dedicated reranker API. It should be treated as experimental.

    Parameters
    ----------
    model_name : str, default "gpt-3.5-turbo-1106"
        The name of the chat model to use.
    column : str, default "text"
        The name of the column to use as input to the chat model.
    return_score : str, default "relevance"
        Options are "relevance" or "all". Only "relevance" is supported for now.
    api_key : str, default None
        The API key to use. If None, will use the OPENAI_API_KEY environment variable.
    """

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo-1106",
        column: str = "text",
        return_score="relevance",
        api_key: Optional[str] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.api_key = api_key

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()
        response = self._client.chat.completions.create(
            model=self.model_name,
            response_format={"type": "json_object"},
            temperature=0,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert relevance ranker. Given a list of\
                     documents and a query, your job is to determine how relevant\
                     each document is for answering the query. Your output is JSON,\
                     which is a list of documents. Each document has two fields,\
                     content and relevance_score. relevance_score is from 0.0 to\
                     1.0 indicating the relevance of the text to the given query.\
                     Make sure to include all documents in the response.",
                },
                {"role": "user", "content": f"Query: {query} Docs: {docs}"},
            ],
        )
        results = json.loads(response.choices[0].message.content)["documents"]
        docs, scores = list(
            zip(*[(result["content"], result["relevance_score"]) for result in results])
        )  # tuples
        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "OpenAI Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _client(self):
        openai = safe_import("openai")  # TODO: force version or handle versions < 1.0
        if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
            raise ValueError(
                "OPENAI_API_KEY not set. Either set it in your environment or \
                pass it as the `api_key` argument to the OpenaiReranker."
            )
        return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
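The parser above expects the chat completion to return a JSON object with a top-level "documents" key; a sketch of the response shape it handles (values illustrative):

import json

response_content = '''
{
  "documents": [
    {"content": "doc A ...", "relevance_score": 0.92},
    {"content": "doc B ...", "relevance_score": 0.35}
  ]
}
'''

results = json.loads(response_content)["documents"]
docs, scores = zip(*[(r["content"], r["relevance_score"]) for r in results])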
@@ -447,7 +447,7 @@ class Table(ABC):
             *default "vector"*
         query_type: str
             *default "auto"*.
-            Acceptable types are: "vector", "fts", or "auto"
+            Acceptable types are: "vector", "fts", "hybrid", or "auto"

             - If "auto" then the query type is inferred from the query;
@@ -7,7 +7,12 @@ import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction  # noqa
 from lancedb.embeddings import EmbeddingFunctionRegistry
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.rerankers import CohereReranker, CrossEncoderReranker
+from lancedb.rerankers import (
+    CohereReranker,
+    ColbertReranker,
+    CrossEncoderReranker,
+    OpenaiReranker,
+)
 from lancedb.table import LanceTable

@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
     return table, MyTable

-
 ## These tests are pretty loose, we should also check for correctness
 def test_linear_combination(tmp_path):
     table, schema = get_test_table(tmp_path)
     # The default reranker
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):

    assert result1 == result3  # 2 & 3 should be the same as they use score as score

+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
    result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
        .rerank(normalize="score")
        .to_arrow()
    )
+
+    assert len(result) == 30
+
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CohereReranker())
+        .rerank(reranker=CohereReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
    result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
        .rerank(reranker=CohereReranker())
        .to_arrow()
    )
+
+    assert len(result) == 30
+
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CrossEncoderReranker())
+        .rerank(reranker=CrossEncoderReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
    result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query), query_type="hybrid")
+        .limit(30)
        .rerank(reranker=CrossEncoderReranker())
        .to_arrow()
    )
+
+    assert len(result) == 30
+
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
+
+
+def test_colbert_reranker(tmp_path):
+    pytest.importorskip("transformers")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=ColbertReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )
+
+
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
+)
+def test_openai_reranker(tmp_path):
+    pytest.importorskip("openai")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=OpenaiReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )