Files
lancedb/python/python/lancedb/rerankers/colbert.py
Ayush Chaurasia 0a387a5429 feat(python): Support reranking for vector and fts (#1103)
solves https://github.com/lancedb/lancedb/issues/1086

Reranking usage with FTS:
```
retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [{"text": "Carson City is the capital city of the American state of Nevada. At the  2010 United States Census, Carson City had a population of 55,274."},
          {"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan."},
        {"text": "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas."},
        {"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. "},
        {"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."},
        {"text": "North Dakota is a state in the United States. 672,591 people lived in North Dakota in the year 2010. The capital and seat of government is Bismarck."},
        ]
retriever.add(pylist)
retriever.create_fts_index("text", replace=True)

query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query, query_type="fts").limit(10).to_pandas())
print(retriever.search(query, query_type="fts").rerank(reranker=reranker).limit(10).to_pandas())
```
Results
```
                                                text                                             vector     score
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602
1  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046
2  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898
4  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346
                                                text                                             vector     score  _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  0.653422          0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  0.671521          0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  0.729602          0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  0.667898          0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  0.639346          0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  0.678046          0.041462
```

## Vector Search usage:
```
query = "What is the capital of the United States?"
reranker = CohereReranker(return_score="all")
print(retriever.search(query).limit(10).to_pandas())
print(retriever.search(query).rerank(reranker=reranker, query=query).limit(10).to_pandas()) # <-- Note: passing extra string query here
```

Results
```
                                                text                                             vector  _distance
0  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973
1  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884
2  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200
3  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654
4  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867
5  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544
                                                text                                             vector  _distance  _relevance_score
0  Washington, D.C. (also known as simply Washing...  [-0.0090408325, 0.42578125, 0.3798828, -0.3574...  41.384884          0.979977
1  The Commonwealth of the Northern Mariana Islan...  [0.3684082, 0.30493164, 0.004600525, -0.049407...  60.060867          0.299105
2  Capital punishment (the death penalty) has exi...  [0.099975586, 0.047943115, -0.16723633, -0.183...  39.728973          0.284874
3  Carson City is the capital city of the America...  [0.13989258, 0.14990234, 0.14172363, 0.0546569...  55.220200          0.089614
4  North Dakota is a state in the United States. ...  [0.55859375, -0.2109375, 0.14526367, 0.1634521...  64.260544          0.063832
5  Charlotte Amalie is the capital and largest ci...  [-0.021255493, 0.03363037, -0.027450562, -0.17...  58.345654          0.041462
```
2024-03-19 22:20:31 +05:30

141 lines
4.7 KiB
Python

from functools import cached_property
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Each document is scored against the query with a MaxSim similarity
    (see :meth:`maxsim`) and the score is stored in a ``_relevance_score``
    column, by which the returned table is sorted in descending order.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the ColBERT model to load through ``transformers``.
    column : str, default "text"
        The name of the column whose values are scored against the query.
    return_score : str, default "relevance"
        Options are "relevance" or "all". With "relevance" the original
        ``score`` / ``_distance`` columns are dropped from the result.
        With "all" they are kept for vector and FTS reranking; hybrid
        reranking does not support "all" yet and raises
        ``NotImplementedError``.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score: str = "relevance",
    ):
        # NOTE(review): the methods below read ``self.score``; presumably the
        # base class derives it from ``return_score`` -- confirm in Reranker.
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        # Import torch once here so later calls can reuse the module handle.
        self.torch = attempt_import_or_raise("torch")

    def _rerank(self, result_set: pa.Table, query: str) -> pa.Table:
        """
        Score every row of ``result_set`` against ``query``.

        Parameters
        ----------
        result_set : pa.Table
            Table containing at least the ``self.column`` column.
        query : str
            The query text to score the documents against.

        Returns
        -------
        pa.Table
            The input table with ``self.column`` re-appended as a string
            column and a new float32 ``_relevance_score`` column. The rows
            are NOT sorted here; callers sort them.
        """
        docs = result_set[self.column].to_pylist()
        tokenizer, model = self._model

        scores = []
        # Inference only: disabling autograd avoids building a gradient graph
        # for every forward pass and does not change the computed values.
        with self.torch.no_grad():
            # Encode the query. The token embeddings are averaged over dim 1,
            # so MaxSim below sees a single pooled query vector.
            # NOTE(review): token-level late interaction would keep the
            # per-token query embeddings instead -- confirm this is intended.
            query_encoding = tokenizer(query, return_tensors="pt")
            query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
            # Loop-invariant reshape to the 3-D layout maxsim() expects.
            query_embedding = query_embedding.unsqueeze(0)

            # Get score for each document
            for document in docs:
                document_encoding = tokenizer(
                    document, return_tensors="pt", truncation=True, max_length=512
                )
                document_embedding = model(**document_encoding).last_hidden_state
                # Calculate MaxSim score
                score = self.maxsim(query_embedding, document_embedding)
                scores.append(score.item())

        # Replace the self.column column with the docs (as a plain string
        # column appended at the end of the table).
        result_set = result_set.drop_columns([self.column])
        result_set = result_set.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # Add the scores
        result_set = result_set.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        return result_set

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank the merged results of a vector search and a full-text search.

        Parameters
        ----------
        query : str
            The query text.
        vector_results : pa.Table
            Results of the vector search.
        fts_results : pa.Table
            Results of the full-text search.

        Returns
        -------
        pa.Table
            Merged results sorted by ``_relevance_score`` descending.

        Raises
        ------
        NotImplementedError
            If the reranker was configured with ``return_score="all"``.
        """
        # Fail fast: do not run the model when the request cannot be served.
        if self.score == "all":
            raise NotImplementedError(
                "ColbertReranker does not support score='all' yet"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )
        return combined_results

    def rerank_vector(
        self,
        query: str,
        vector_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank the results of a vector search.

        Parameters
        ----------
        query : str
            The query text.
        vector_results : pa.Table
            Results of the vector search.

        Returns
        -------
        pa.Table
            Results sorted by ``_relevance_score`` descending. The
            ``_distance`` column is dropped when only the relevance score
            was requested.
        """
        result_set = self._rerank(vector_results, query)
        if self.score == "relevance":
            result_set = result_set.drop_columns(["_distance"])
        result_set = result_set.sort_by([("_relevance_score", "descending")])
        return result_set

    def rerank_fts(
        self,
        query: str,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank the results of a full-text search.

        Parameters
        ----------
        query : str
            The query text.
        fts_results : pa.Table
            Results of the full-text search.

        Returns
        -------
        pa.Table
            Results sorted by ``_relevance_score`` descending. The ``score``
            column is dropped when only the relevance score was requested.
        """
        result_set = self._rerank(fts_results, query)
        if self.score == "relevance":
            result_set = result_set.drop_columns(["score"])
        result_set = result_set.sort_by([("_relevance_score", "descending")])
        return result_set

    @cached_property
    def _model(self):
        """
        Load the tokenizer and model for ``self.model_name``.

        Cached on the instance, so the pretrained weights are loaded on
        first access only.

        Returns
        -------
        tuple
            ``(tokenizer, model)`` as returned by ``transformers``.
        """
        transformers = attempt_import_or_raise("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)
        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        """
        Compute the MaxSim score between a query and a document.

        For every query vector the maximum cosine similarity over all
        document token vectors is taken; these maxima are then averaged.

        Parameters
        ----------
        query_embedding : torch.Tensor
            Tensor of shape [batch, query_length, size].
        document_embedding : torch.Tensor
            Tensor of shape [batch, doc_length, size].

        Returns
        -------
        torch.Tensor
            Tensor of shape [batch] holding the averaged maximum similarity.
        """
        # Expand dimensions for broadcasting
        # Query:    [batch, query_length, size] -> [batch, query_length, 1, size]
        # Document: [batch, doc_length, size]   -> [batch, 1, doc_length, size]
        expanded_query = query_embedding.unsqueeze(2)
        expanded_doc = document_embedding.unsqueeze(1)

        # Compute cosine similarity across the embedding dimension
        # sim_matrix shape: [batch, query_length, doc_length]
        sim_matrix = self.torch.nn.functional.cosine_similarity(
            expanded_query, expanded_doc, dim=-1
        )

        # Take the maximum similarity for each query token (across all document tokens)
        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)

        # Average these maximum scores across all query tokens
        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
        return avg_max_sim