Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-26 22:59:57 +00:00)
chore(python): update ColBERT reranker architecture and minor improvements (#1547)

- Update ColBertReranker architecture: the current implementation does not use the correct ColBERT architecture (it mean-pools the query into a single vector before applying MaxSim, instead of keeping per-token query embeddings). This PR delegates to the implementation in the `rerankers` library instead. Fixes https://github.com/lancedb/lancedb/issues/1546
  Benchmark diff (hit rate): hybrid 91 vs 87; reranked vector 85 vs 80.
- Re-enable reranking for FTS: reranking on the FTS path has effectively been disabled on main since last week's FTS updates. There appears to be no blocker to supporting it, so this PR wires it back up.
- Allow overriding accelerators: most transformer-based rerankers and embedding functions select a device automatically. This PR allows overriding that selection by passing `device`. Fixes https://github.com/lancedb/lancedb/issues/1487
---------
Co-authored-by: BubbleCal <bubble-cal@outlook.com>
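For reviewers, a minimal usage sketch of the two user-facing changes. The database path and table/column names are hypothetical; it assumes lancedb plus the `rerankers` package are installed:

import lancedb
from lancedb.rerankers import ColbertReranker

db = lancedb.connect("/tmp/lancedb-demo")  # hypothetical path
table = db.open_table("docs")  # hypothetical table with a "text" column

# The default model name now resolves through the rerankers library
# ("colbert" -> colbert-ir/colbertv2.0); any ColBERT checkpoint also works.
reranker = ColbertReranker()

# Reranking now also applies to plain FTS queries, not just hybrid/vector.
results = (
    table.search("single player experience", query_type="fts")
    .rerank(reranker=reranker)
    .to_pandas()
)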
@@ -127,6 +127,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
                 batch_size=self.batch_size,
                 show_progress_bar=self.show_progress_bar,
                 normalize_embeddings=self.normalize_embeddings,
+                device=self.device,
             ).tolist()
         return res
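The same `device` field threads through the embedding registry; a sketch of the override from the caller's side (assuming "instructor" is the registry key for InstructorEmbeddingFunction):

from lancedb.embeddings import get_registry

# Pin the device instead of letting the model auto-select an accelerator;
# "cuda" or "mps" work the same way where available.
instructor = get_registry().get("instructor").create(device="cpu")
embeddings = instructor.compute_source_embeddings(["a query about gaming"])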
@@ -44,6 +44,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
     """

     name: str = "colbert-ir/colbertv2.0"
+    device: str = "cpu"
     _tokenizer: Any = PrivateAttr()
     _model: Any = PrivateAttr()
@@ -53,6 +54,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
         transformers = attempt_import_or_raise("transformers")
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
         self._model = transformers.AutoModel.from_pretrained(self.name)
+        self._model.to(self.device)

         if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat
@@ -75,9 +77,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
         for text in texts:
             encoding = self._tokenizer(
                 text, return_tensors="pt", padding=True, truncation=True
-            )
+            ).to(self.device)
             emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
-            embedding.append(emb.detach().numpy())
+            embedding.append(emb.tolist())

         return embedding
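Two details in this hunk are easy to miss: the tokenized batch must be moved to the model's device or the forward pass fails on a device mismatch, and `.numpy()` raises on tensors that require grad (and on CUDA tensors), while `.tolist()` detaches and copies to host in one step. A standalone illustration, assuming torch is installed:

import torch

emb = torch.randn(4, requires_grad=True)

# emb.numpy() would raise here; tolist() works for any tensor, on any device.
print(emb.tolist())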
@@ -727,7 +727,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
             vector=[],
         )
         results = self._table._execute_query(query)
-        return results.read_all()
+        results = results.read_all()
+        if self._reranker is not None:
+            results = self._reranker.rerank_fts(self._query, results)
+        return results

     def tantivy_to_arrow(self) -> pa.Table:
         try:
@@ -825,7 +828,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
        LanceFtsQueryBuilder
            The LanceQueryBuilder object.
        """
-        raise NotImplementedError("Reranking is not yet supported for FTS queries.")
+        self._reranker = reranker
+        return self


 class LanceEmptyQueryBuilder(LanceQueryBuilder):
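With the NotImplementedError gone, LanceFtsQueryBuilder hands its Arrow results to whatever reranker is attached via `rerank_fts(query, results)`. A toy custom reranker showing that hook's contract (the length-based scoring and the "text" column are illustrative only, not part of this PR):

import pyarrow as pa
from lancedb.rerankers import Reranker

class LengthReranker(Reranker):
    """Toy reranker that scores documents by text length."""

    # This is the hook the re-enabled FTS path calls.
    def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
        scores = [float(len(t)) for t in fts_results["text"].to_pylist()]
        fts_results = fts_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        return fts_results.sort_by([("_relevance_score", "descending")])

    # Abstract in the base class; this toy only supports the FTS path.
    def rerank_hybrid(self, query, vector_results, fts_results):
        raise NotImplementedError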
@@ -1,5 +1,3 @@
-from functools import cached_property
-
 import pyarrow as pa

 from ..util import attempt_import_or_raise
@@ -12,7 +10,7 @@ class ColbertReranker(Reranker):

    Parameters
    ----------
-    model_name : str, default "colbert-ir/colbertv2.0"
+    model_name : str, default "colbert" (colbert-ir/colbertv2.0)
        The name of the cross encoder model to use.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
@@ -22,41 +20,26 @@ class ColbertReranker(Reranker):

     def __init__(
         self,
-        model_name: str = "colbert-ir/colbertv2.0",
+        model_name: str = "colbert",
         column: str = "text",
         return_score="relevance",
     ):
         super().__init__(return_score)
         self.model_name = model_name
         self.column = column
-        self.torch = attempt_import_or_raise(
-            "torch"
+        rerankers = attempt_import_or_raise(
+            "rerankers"
         )  # import here for faster ops later
+        self.colbert = rerankers.Reranker(self.model_name, model_type="colbert")

     def _rerank(self, result_set: pa.Table, query: str):
         docs = result_set[self.column].to_pylist()
+        doc_ids = list(range(len(docs)))
+        result = self.colbert.rank(query, docs, doc_ids=doc_ids)

-        tokenizer, model = self._model
+        # get the scores of each document in the same order as the input
+        scores = [result.get_result_by_docid(i).score for i in doc_ids]

-        # Encode the query
-        query_encoding = tokenizer(query, return_tensors="pt")
-        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
-        scores = []
-        # Get score for each document
-        for document in docs:
-            document_encoding = tokenizer(
-                document, return_tensors="pt", truncation=True, max_length=512
-            )
-            document_embedding = model(**document_encoding).last_hidden_state
-            # Calculate MaxSim score
-            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
-            scores.append(score.item())
-
-        # replace the self.column column with the docs
-        result_set = result_set.drop(self.column)
-        result_set = result_set.append_column(
-            self.column, pa.array(docs, type=pa.string())
-        )
         # add the scores
         result_set = result_set.append_column(
             "_relevance_score", pa.array(scores, type=pa.float32())
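For context on the new dependency: the `rerankers` package wraps ColBERT checkpoints behind a uniform `rank` API, which is exactly what `_rerank` now delegates to. A standalone sketch (downloads the checkpoint on first use):

from rerankers import Reranker

colbert = Reranker("colbert", model_type="colbert")

docs = ["a co-op shooter", "a single player RPG", "a racing sim"]
result = colbert.rank("single player experience", docs, doc_ids=[0, 1, 2])

# rank() returns results in rank order; get_result_by_docid restores
# the original input order, which is what _rerank needs.
scores = [result.get_result_by_docid(i).score for i in range(len(docs))]
print(scores)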
@@ -110,31 +93,3 @@ class ColbertReranker(Reranker):
             result_set = result_set.sort_by([("_relevance_score", "descending")])

         return result_set
-
-    @cached_property
-    def _model(self):
-        transformers = attempt_import_or_raise("transformers")
-        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
-        model = transformers.AutoModel.from_pretrained(self.model_name)
-
-        return tokenizer, model
-
-    def maxsim(self, query_embedding, document_embedding):
-        # Expand dimensions for broadcasting
-        # Query: [batch, length, size] -> [batch, query, 1, size]
-        # Document: [batch, length, size] -> [batch, 1, length, size]
-        expanded_query = query_embedding.unsqueeze(2)
-        expanded_doc = document_embedding.unsqueeze(1)
-
-        # Compute cosine similarity across the embedding dimension
-        sim_matrix = self.torch.nn.functional.cosine_similarity(
-            expanded_query, expanded_doc, dim=-1
-        )
-
-        # Take the maximum similarity for each query token (across all document tokens)
-        # sim_matrix shape: [batch_size, query_length, doc_length]
-        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
-
-        # Average these maximum scores across all query tokens
-        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
-        return avg_max_sim
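Worth noting why the removed code was "not the right arch": true ColBERT late interaction keeps per-token query embeddings, but the old `_rerank` mean-pooled the query into a single vector before calling `maxsim`. For reference, the MaxSim pooling itself in self-contained form (random tensors stand in for token embeddings; assumes torch is installed):

import torch

query_emb = torch.randn(1, 5, 128)  # [batch, query_tokens, dim]
doc_emb = torch.randn(1, 40, 128)   # [batch, doc_tokens, dim]

# Cosine similarity for every (query token, doc token) pair:
# [1, 5, 1, 128] x [1, 1, 40, 128] -> [1, 5, 40]
sim = torch.nn.functional.cosine_similarity(
    query_emb.unsqueeze(2), doc_emb.unsqueeze(1), dim=-1
)

# MaxSim: best-matching doc token per query token, averaged over the query.
score = sim.max(dim=2).values.mean(dim=1)
print(score)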
@@ -42,7 +42,8 @@ class CrossEncoderReranker(Reranker):
     @cached_property
     def model(self):
         sbert = attempt_import_or_raise("sentence_transformers")
-        cross_encoder = sbert.CrossEncoder(self.model_name)
+        # Allows overriding the automatically selected device
+        cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device)

         return cross_encoder
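The cross encoder gets the same override; a usage sketch, assuming the constructor exposes the `device` field this hunk reads from `self.device`:

from lancedb.rerankers import CrossEncoderReranker

# Pin to CPU rather than letting sentence-transformers auto-select a device.
reranker = CrossEncoderReranker(device="cpu")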
@@ -236,33 +236,37 @@ def test_rrf_reranker(tmp_path, use_tantivy):
 @pytest.mark.skipif(
     os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
 )
-def test_cohere_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_cohere_reranker(tmp_path, use_tantivy):
     pytest.importorskip("cohere")
     reranker = CohereReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


-def test_cross_encoder_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_cross_encoder_reranker(tmp_path, use_tantivy):
     pytest.importorskip("sentence_transformers")
     reranker = CrossEncoderReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


-def test_colbert_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_colbert_reranker(tmp_path, use_tantivy):
     pytest.importorskip("transformers")
     reranker = ColbertReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


 @pytest.mark.skipif(
     os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
 )
-def test_openai_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_openai_reranker(tmp_path, use_tantivy):
     pytest.importorskip("openai")
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     reranker = OpenaiReranker()
     _run_test_reranker(reranker, table, "single player experience", None, schema)
@@ -270,8 +274,9 @@ def test_openai_reranker(tmp_path):
 @pytest.mark.skipif(
     os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
 )
-def test_jina_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_jina_reranker(tmp_path, use_tantivy):
     pytest.importorskip("jina")
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     reranker = JinaReranker()
     _run_test_reranker(reranker, table, "single player experience", None, schema)