chore(python): update Colbert architecture and minor improvements (#1547)

- Update ColBertReranker architecture: the current implementation doesn't use
the correct ColBERT architecture. This PR switches to the implementation from the
`rerankers` library (see the usage sketch after this list). Fixes
https://github.com/lancedb/lancedb/issues/1546
  Benchmark diff (hit rate):
  - Hybrid: 91 vs 87
  - Reranked vector: 85 vs 80

- Re-enable reranking for FTS: reranking in FTS was effectively disabled on main
after last week's FTS updates; there is no blocker to supporting it, so this PR
wires it back in.
- Allow overriding accelerators: most transformer-based rerankers and embedding
functions select a device automatically. This PR allows overriding that by
passing `device`. Fixes:
https://github.com/lancedb/lancedb/issues/1487
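
For illustration, a minimal usage sketch of the first two points, assuming a table
with an FTS index on a "text" column; the builder calls (`search(..., query_type="fts")`,
`rerank`, `to_arrow`) follow the query-builder changes in the diff below and are not
guaranteed beyond that:

    import lancedb
    from lancedb.rerankers import ColbertReranker

    db = lancedb.connect("/tmp/lancedb")   # hypothetical local database
    table = db.open_table("docs")          # hypothetical table with an FTS index on "text"

    # ColbertReranker now delegates to the `rerankers` library ("colbert" model by default).
    reranker = ColbertReranker()

    # FTS reranking is re-enabled: rerank() stores the reranker, and rerank_fts()
    # is applied when the results are materialized.
    results = (
        table.search("single player experience", query_type="fts")
        .rerank(reranker)
        .to_arrow()
    )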

---------

Co-authored-by: BubbleCal <bubble-cal@outlook.com>
Author: Ayush Chaurasia
Date: 2024-08-21 12:26:52 +05:30
Committed by: GitHub
Commit: 7d65dd97cf (parent: 85bb7e54e4)
6 changed files with 37 additions and 69 deletions


@@ -127,6 +127,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
             batch_size=self.batch_size,
             show_progress_bar=self.show_progress_bar,
             normalize_embeddings=self.normalize_embeddings,
+            device=self.device,
         ).tolist()
         return res


@@ -44,6 +44,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
     """

     name: str = "colbert-ir/colbertv2.0"
+    device: str = "cpu"
     _tokenizer: Any = PrivateAttr()
     _model: Any = PrivateAttr()
@@ -53,6 +54,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
         transformers = attempt_import_or_raise("transformers")
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
         self._model = transformers.AutoModel.from_pretrained(self.name)
+        self._model.to(self.device)
         if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat
@@ -75,9 +77,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
         for text in texts:
             encoding = self._tokenizer(
                 text, return_tensors="pt", padding=True, truncation=True
-            )
+            ).to(self.device)
             emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze()
-            embedding.append(emb.detach().numpy())
+            embedding.append(emb.tolist())
         return embedding
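
As an aside, a hedged sketch of using the new `device` field when constructing the
embedding function directly; the import path and the CUDA device string are
assumptions, not part of this diff:

    # Import path below is an assumption about where the class lives in the package.
    from lancedb.embeddings.transformers import TransformersEmbeddingFunction

    # Overrides the "cpu" default added above; the model and the tokenized inputs
    # are moved to the requested device before encoding.
    func = TransformersEmbeddingFunction(device="cuda")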


@@ -727,7 +727,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
             vector=[],
         )
         results = self._table._execute_query(query)
-        return results.read_all()
+        results = results.read_all()
+        if self._reranker is not None:
+            results = self._reranker.rerank_fts(self._query, results)
+        return results

     def tantivy_to_arrow(self) -> pa.Table:
         try:
@@ -825,7 +828,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         LanceFtsQueryBuilder
             The LanceQueryBuilder object.
         """
-        raise NotImplementedError("Reranking is not yet supported for FTS queries.")
+        self._reranker = reranker
+        return self


 class LanceEmptyQueryBuilder(LanceQueryBuilder):


@@ -1,5 +1,3 @@
-from functools import cached_property
-
 import pyarrow as pa

 from ..util import attempt_import_or_raise
@@ -12,7 +10,7 @@ class ColbertReranker(Reranker):
     Parameters
     ----------
-    model_name : str, default "colbert-ir/colbertv2.0"
+    model_name : str, default "colbert" (colbert-ir/colbert-v2.0)
         The name of the cross encoder model to use.
     column : str, default "text"
         The name of the column to use as input to the cross encoder model.
@@ -22,41 +20,26 @@ class ColbertReranker(Reranker):
     def __init__(
         self,
-        model_name: str = "colbert-ir/colbertv2.0",
+        model_name: str = "colbert",
         column: str = "text",
         return_score="relevance",
     ):
         super().__init__(return_score)
         self.model_name = model_name
         self.column = column
-        self.torch = attempt_import_or_raise(
-            "torch"
+        rerankers = attempt_import_or_raise(
+            "rerankers"
         )  # import here for faster ops later
+        self.colbert = rerankers.Reranker(self.model_name, model_type="colbert")

     def _rerank(self, result_set: pa.Table, query: str):
         docs = result_set[self.column].to_pylist()
+        doc_ids = list(range(len(docs)))
+        result = self.colbert.rank(query, docs, doc_ids=doc_ids)
-        tokenizer, model = self._model
+        # get the scores of each document in the same order as the input
+        scores = [result.get_result_by_docid(i).score for i in doc_ids]
-        # Encode the query
-        query_encoding = tokenizer(query, return_tensors="pt")
-        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
-        scores = []
-        # Get score for each document
-        for document in docs:
-            document_encoding = tokenizer(
-                document, return_tensors="pt", truncation=True, max_length=512
-            )
-            document_embedding = model(**document_encoding).last_hidden_state
-            # Calculate MaxSim score
-            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
-            scores.append(score.item())
-        # replace the self.column column with the docs
-        result_set = result_set.drop(self.column)
-        result_set = result_set.append_column(
-            self.column, pa.array(docs, type=pa.string())
-        )
         # add the scores
         result_set = result_set.append_column(
             "_relevance_score", pa.array(scores, type=pa.float32())
@@ -110,31 +93,3 @@ class ColbertReranker(Reranker):
         result_set = result_set.sort_by([("_relevance_score", "descending")])
         return result_set
-
-    @cached_property
-    def _model(self):
-        transformers = attempt_import_or_raise("transformers")
-        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
-        model = transformers.AutoModel.from_pretrained(self.model_name)
-        return tokenizer, model
-
-    def maxsim(self, query_embedding, document_embedding):
-        # Expand dimensions for broadcasting
-        # Query: [batch, length, size] -> [batch, query, 1, size]
-        # Document: [batch, length, size] -> [batch, 1, length, size]
-        expanded_query = query_embedding.unsqueeze(2)
-        expanded_doc = document_embedding.unsqueeze(1)
-        # Compute cosine similarity across the embedding dimension
-        sim_matrix = self.torch.nn.functional.cosine_similarity(
-            expanded_query, expanded_doc, dim=-1
-        )
-        # Take the maximum similarity for each query token (across all document tokens)
-        # sim_matrix shape: [batch_size, query_length, doc_length]
-        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
-        # Average these maximum scores across all query tokens
-        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
-        return avg_max_sim
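
For context, the call pattern into the `rerankers` library that the new implementation
relies on, shown standalone; this mirrors the calls in the hunk above and is a sketch,
not part of the commit:

    from rerankers import Reranker

    # Per the updated docstring, "colbert" corresponds to colbert-ir/colbert-v2.0.
    ranker = Reranker("colbert", model_type="colbert")

    docs = ["a review of a co-op shooter", "a review praising the single player campaign"]
    doc_ids = list(range(len(docs)))

    # Rank the documents against the query, then read the scores back in input order.
    result = ranker.rank("single player experience", docs, doc_ids=doc_ids)
    scores = [result.get_result_by_docid(i).score for i in doc_ids]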


@@ -42,7 +42,8 @@ class CrossEncoderReranker(Reranker):
     @cached_property
     def model(self):
         sbert = attempt_import_or_raise("sentence_transformers")
-        cross_encoder = sbert.CrossEncoder(self.model_name)
+        # Allows overriding the automatically selected device
+        cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device)
         return cross_encoder
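
Analogously, a device override on the cross-encoder path; whether `device` is exposed
as a constructor argument is inferred from the `self.device` reference above and should
be treated as an assumption:

    from lancedb.rerankers import CrossEncoderReranker

    # Pin the cross encoder to a specific accelerator instead of letting
    # sentence-transformers auto-detect one.
    reranker = CrossEncoderReranker(device="cuda")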


@@ -236,33 +236,37 @@ def test_rrf_reranker(tmp_path, use_tantivy):
 @pytest.mark.skipif(
     os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
 )
-def test_cohere_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_cohere_reranker(tmp_path, use_tantivy):
     pytest.importorskip("cohere")
     reranker = CohereReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


-def test_cross_encoder_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_cross_encoder_reranker(tmp_path, use_tantivy):
     pytest.importorskip("sentence_transformers")
     reranker = CrossEncoderReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


-def test_colbert_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_colbert_reranker(tmp_path, use_tantivy):
     pytest.importorskip("transformers")
     reranker = ColbertReranker()
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)


 @pytest.mark.skipif(
     os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
 )
-def test_openai_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_openai_reranker(tmp_path, use_tantivy):
     pytest.importorskip("openai")
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     reranker = OpenaiReranker()
     _run_test_reranker(reranker, table, "single player experience", None, schema)
@@ -270,8 +274,9 @@ def test_openai_reranker(tmp_path):
 @pytest.mark.skipif(
     os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
 )
-def test_jina_reranker(tmp_path):
+@pytest.mark.parametrize("use_tantivy", [True, False])
+def test_jina_reranker(tmp_path, use_tantivy):
     pytest.importorskip("jina")
-    table, schema = get_test_table(tmp_path)
+    table, schema = get_test_table(tmp_path, use_tantivy)
     reranker = JinaReranker()
     _run_test_reranker(reranker, table, "single player experience", None, schema)