From 7d65dd97cf713618547ed289d8f1cd56221d0018 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 21 Aug 2024 12:26:52 +0530 Subject: [PATCH] chore(python): update Colbert architecture and minor improvements (#1547) - Update ColBertReranker architecture: The current implementation doesn't use the right arch. This PR uses the implementation in Rerankers library. Fixes https://github.com/lancedb/lancedb/issues/1546 Benchmark diff (hit rate): Hybrid - 91 vs 87 reranked vector - 85 vs 80 - Reranking in FTS is basically disabled in main after last week's FTS updates. I think there's no blocker in supporting that? - Allow overriding accelerators: Most transformer based Rerankers and Embedding automatically select device. This PR allows overriding those settings by passing `device`. Fixes: https://github.com/lancedb/lancedb/issues/1487 --------- Co-authored-by: BubbleCal --- .../python/lancedb/embeddings/instructor.py | 1 + .../python/lancedb/embeddings/transformers.py | 6 +- python/python/lancedb/query.py | 8 ++- python/python/lancedb/rerankers/colbert.py | 63 +++---------------- .../python/lancedb/rerankers/cross_encoder.py | 3 +- python/python/tests/test_rerankers.py | 25 +++++--- 6 files changed, 37 insertions(+), 69 deletions(-) diff --git a/python/python/lancedb/embeddings/instructor.py b/python/python/lancedb/embeddings/instructor.py index 98206bc5..a6022be6 100644 --- a/python/python/lancedb/embeddings/instructor.py +++ b/python/python/lancedb/embeddings/instructor.py @@ -127,6 +127,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, normalize_embeddings=self.normalize_embeddings, + device=self.device, ).tolist() return res diff --git a/python/python/lancedb/embeddings/transformers.py b/python/python/lancedb/embeddings/transformers.py index a20f27ff..dba5b161 100644 --- a/python/python/lancedb/embeddings/transformers.py +++ b/python/python/lancedb/embeddings/transformers.py @@ -44,6 +44,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction): """ name: str = "colbert-ir/colbertv2.0" + device: str = "cpu" _tokenizer: Any = PrivateAttr() _model: Any = PrivateAttr() @@ -53,6 +54,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction): transformers = attempt_import_or_raise("transformers") self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name) self._model = transformers.AutoModel.from_pretrained(self.name) + self._model.to(self.device) if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat @@ -75,9 +77,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction): for text in texts: encoding = self._tokenizer( text, return_tensors="pt", padding=True, truncation=True - ) + ).to(self.device) emb = self._model(**encoding).last_hidden_state.mean(dim=1).squeeze() - embedding.append(emb.detach().numpy()) + embedding.append(emb.tolist()) return embedding diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 874a606a..6c3c71bd 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -727,7 +727,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): vector=[], ) results = self._table._execute_query(query) - return results.read_all() + results = results.read_all() + if self._reranker is not None: + results = self._reranker.rerank_fts(self._query, results) + return results def tantivy_to_arrow(self) -> pa.Table: try: @@ -825,7 +828,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): LanceFtsQueryBuilder The LanceQueryBuilder object. """ - raise NotImplementedError("Reranking is not yet supported for FTS queries.") + self._reranker = reranker + return self class LanceEmptyQueryBuilder(LanceQueryBuilder): diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index 77ef58a1..5e8701b3 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -1,5 +1,3 @@ -from functools import cached_property - import pyarrow as pa from ..util import attempt_import_or_raise @@ -12,7 +10,7 @@ class ColbertReranker(Reranker): Parameters ---------- - model_name : str, default "colbert-ir/colbertv2.0" + model_name : str, default "colbert" (colbert-ir/colbert-v2.0) The name of the cross encoder model to use. column : str, default "text" The name of the column to use as input to the cross encoder model. @@ -22,41 +20,26 @@ class ColbertReranker(Reranker): def __init__( self, - model_name: str = "colbert-ir/colbertv2.0", + model_name: str = "colbert", column: str = "text", return_score="relevance", ): super().__init__(return_score) self.model_name = model_name self.column = column - self.torch = attempt_import_or_raise( - "torch" + rerankers = attempt_import_or_raise( + "rerankers" ) # import here for faster ops later + self.colbert = rerankers.Reranker(self.model_name, model_type="colbert") def _rerank(self, result_set: pa.Table, query: str): docs = result_set[self.column].to_pylist() + doc_ids = list(range(len(docs))) + result = self.colbert.rank(query, docs, doc_ids=doc_ids) - tokenizer, model = self._model + # get the scores of each document in the same order as the input + scores = [result.get_result_by_docid(i).score for i in doc_ids] - # Encode the query - query_encoding = tokenizer(query, return_tensors="pt") - query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1) - scores = [] - # Get score for each document - for document in docs: - document_encoding = tokenizer( - document, return_tensors="pt", truncation=True, max_length=512 - ) - document_embedding = model(**document_encoding).last_hidden_state - # Calculate MaxSim score - score = self.maxsim(query_embedding.unsqueeze(0), document_embedding) - scores.append(score.item()) - - # replace the self.column column with the docs - result_set = result_set.drop(self.column) - result_set = result_set.append_column( - self.column, pa.array(docs, type=pa.string()) - ) # add the scores result_set = result_set.append_column( "_relevance_score", pa.array(scores, type=pa.float32()) @@ -110,31 +93,3 @@ class ColbertReranker(Reranker): result_set = result_set.sort_by([("_relevance_score", "descending")]) return result_set - - @cached_property - def _model(self): - transformers = attempt_import_or_raise("transformers") - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) - model = transformers.AutoModel.from_pretrained(self.model_name) - - return tokenizer, model - - def maxsim(self, query_embedding, document_embedding): - # Expand dimensions for broadcasting - # Query: [batch, length, size] -> [batch, query, 1, size] - # Document: [batch, length, size] -> [batch, 1, length, size] - expanded_query = query_embedding.unsqueeze(2) - expanded_doc = document_embedding.unsqueeze(1) - - # Compute cosine similarity across the embedding dimension - sim_matrix = self.torch.nn.functional.cosine_similarity( - expanded_query, expanded_doc, dim=-1 - ) - - # Take the maximum similarity for each query token (across all document tokens) - # sim_matrix shape: [batch_size, query_length, doc_length] - max_sim_scores, _ = self.torch.max(sim_matrix, dim=2) - - # Average these maximum scores across all query tokens - avg_max_sim = self.torch.mean(max_sim_scores, dim=1) - return avg_max_sim diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index 88396fc3..05673673 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -42,7 +42,8 @@ class CrossEncoderReranker(Reranker): @cached_property def model(self): sbert = attempt_import_or_raise("sentence_transformers") - cross_encoder = sbert.CrossEncoder(self.model_name) + # Allows overriding the automatically selected device + cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device) return cross_encoder diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 2c27b61d..442328d9 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -236,33 +236,37 @@ def test_rrf_reranker(tmp_path, use_tantivy): @pytest.mark.skipif( os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" ) -def test_cohere_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_cohere_reranker(tmp_path, use_tantivy): pytest.importorskip("cohere") reranker = CohereReranker() - table, schema = get_test_table(tmp_path) + table, schema = get_test_table(tmp_path, use_tantivy) _run_test_reranker(reranker, table, "single player experience", None, schema) -def test_cross_encoder_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_cross_encoder_reranker(tmp_path, use_tantivy): pytest.importorskip("sentence_transformers") reranker = CrossEncoderReranker() - table, schema = get_test_table(tmp_path) + table, schema = get_test_table(tmp_path, use_tantivy) _run_test_reranker(reranker, table, "single player experience", None, schema) -def test_colbert_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_colbert_reranker(tmp_path, use_tantivy): pytest.importorskip("transformers") reranker = ColbertReranker() - table, schema = get_test_table(tmp_path) + table, schema = get_test_table(tmp_path, use_tantivy) _run_test_reranker(reranker, table, "single player experience", None, schema) @pytest.mark.skipif( os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" ) -def test_openai_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_openai_reranker(tmp_path, use_tantivy): pytest.importorskip("openai") - table, schema = get_test_table(tmp_path) + table, schema = get_test_table(tmp_path, use_tantivy) reranker = OpenaiReranker() _run_test_reranker(reranker, table, "single player experience", None, schema) @@ -270,8 +274,9 @@ def test_openai_reranker(tmp_path): @pytest.mark.skipif( os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set" ) -def test_jina_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_jina_reranker(tmp_path, use_tantivy): pytest.importorskip("jina") - table, schema = get_test_table(tmp_path) + table, schema = get_test_table(tmp_path, use_tantivy) reranker = JinaReranker() _run_test_reranker(reranker, table, "single player experience", None, schema)