mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
This changes `lancedb` from a "pure Python" setuptools project to a maturin project and adds a Rust lancedb dependency. The async Python client is extremely minimal (only `connect` and `Connection.table_names` are supported). The purpose of this PR is to get the infrastructure in place for building out the rest of the async client. Although this is not technically a breaking change (no APIs are changing), it is still a considerable change in the way the wheels are built because they now include the native shared library.
110 lines
3.9 KiB
Python
110 lines
3.9 KiB
Python
from functools import cached_property
|
|
|
|
import pyarrow as pa
|
|
|
|
from ..util import attempt_import_or_raise
|
|
from .base import Reranker
|
|
|
|
|
|
class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Each candidate document is scored with a late-interaction "MaxSim"
    score: every query token embedding is matched against its most similar
    document token embedding (cosine similarity), and those per-token maxima
    are averaged over the query tokens.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the ColBERT model to use. It is loaded with
        ``transformers.AutoTokenizer`` and ``transformers.AutoModel``.
    column : str, default "text"
        The name of the column to use as input to the ColBERT model.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.torch = attempt_import_or_raise(
            "torch"
        )  # import here for faster ops later

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """
        Merge vector and FTS results and rerank them with ColBERT MaxSim.

        Parameters
        ----------
        query : str
            The query text to score documents against.
        vector_results : pa.Table
            Results of the vector search.
        fts_results : pa.Table
            Results of the full-text search.

        Returns
        -------
        pa.Table
            The merged results with a float32 ``_relevance_score`` column,
            sorted by that column in descending order. When the score mode is
            "relevance", the ``score`` and ``_distance`` columns are dropped.

        Raises
        ------
        NotImplementedError
            If the reranker was configured with ``return_score="all"``.
        """
        # `self.score` is presumably set from `return_score` by
        # `Reranker.__init__`. Reject the unsupported mode before running any
        # (expensive) model inference.
        if self.score == "all":
            raise NotImplementedError(
                "ColBERT Reranker does not support score='all' yet"
            )

        # NOTE(review): `merge_results` comes from the base class; it is
        # assumed to return a table that contains `self.column`.
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        tokenizer, model = self._model

        scores = []
        # Inference only: disable autograd so no computation graph is built.
        with self.torch.no_grad():
            # Keep one embedding per query token ([1, query_len, hidden]);
            # MaxSim needs token-level query vectors, not a pooled vector.
            query_encoding = tokenizer(
                query, return_tensors="pt", truncation=True, max_length=512
            )
            query_embedding = model(**query_encoding).last_hidden_state

            # Get score for each document
            for document in docs:
                document_encoding = tokenizer(
                    document, return_tensors="pt", truncation=True, max_length=512
                )
                document_embedding = model(**document_encoding).last_hidden_state
                score = self.maxsim(query_embedding, document_embedding)
                scores.append(score.item())

        # replace the self.column column with the docs (as a string column)
        combined_results = combined_results.drop_columns([self.column])
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _model(self):
        """
        Lazily load and cache the (tokenizer, model) pair for `self.model_name`.

        `transformers` is imported here, so it is only required once the
        reranker is actually used.
        """
        transformers = attempt_import_or_raise("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)

        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        """
        Compute the MaxSim score between query and document token embeddings.

        Parameters
        ----------
        query_embedding : torch.Tensor
            Token embeddings shaped ``[batch, query_len, dim]``.
        document_embedding : torch.Tensor
            Token embeddings shaped ``[batch, doc_len, dim]``. Batch
            dimensions broadcast as in ``torch.matmul``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[batch]``: for each query token, the highest
            cosine similarity with any document token, averaged over all
            query tokens.
        """
        functional = self.torch.nn.functional
        # After L2-normalisation a dot product equals cosine similarity, so a
        # batched matmul yields the [batch, query_len, doc_len] similarity
        # matrix without materialising a [batch, query_len, doc_len, dim]
        # broadcast tensor.
        query_unit = functional.normalize(query_embedding, p=2, dim=-1)
        doc_unit = functional.normalize(document_embedding, p=2, dim=-1)
        sim_matrix = self.torch.matmul(query_unit, doc_unit.transpose(-1, -2))

        # Best-matching document token for each query token ...
        max_sim_scores = sim_matrix.max(dim=2).values
        # ... averaged across all query tokens.
        return max_sim_scores.mean(dim=1)
|