Compare commits

...

1 Commits

Author SHA1 Message Date
ayush chaurasia
d8f43ae0d3 update 2024-07-29 17:54:24 +05:30
2 changed files with 12 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
from functools import cached_property from functools import cached_property
from typing import Union
import pyarrow as pa import pyarrow as pa
@@ -18,13 +19,16 @@ class ColbertReranker(Reranker):
The name of the column to use as input to the cross encoder model. The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance" return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now. options are "relevance" or "all". Only "relevance" is supported for now.
device : str, default None
The device to use for the model. If None, will use "cuda" if available, else "cpu".
""" """
def __init__( def __init__(
self, self,
model_name: str = "colbert-ir/colbertv2.0", model_name: str = "colbert-ir/colbertv2.0",
column: str = "text", column: str = "text",
return_score="relevance", return_score: str="relevance",
device: Union[str, None] = None,
): ):
super().__init__(return_score) super().__init__(return_score)
self.model_name = model_name self.model_name = model_name
@@ -32,6 +36,10 @@ class ColbertReranker(Reranker):
self.torch = attempt_import_or_raise( self.torch = attempt_import_or_raise(
"torch" "torch"
) # import here for faster ops later ) # import here for faster ops later
self.device = device
if device is None:
self.device = "cuda" if self.torch.cuda.is_available() else "cpu"
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
@@ -46,7 +54,7 @@ class ColbertReranker(Reranker):
for document in docs: for document in docs:
document_encoding = tokenizer( document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512 document, return_tensors="pt", truncation=True, max_length=512
) ).to(self.device)
document_embedding = model(**document_encoding).last_hidden_state document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score # Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding) score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
@@ -116,7 +124,7 @@ class ColbertReranker(Reranker):
transformers = attempt_import_or_raise("transformers") transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name) model = transformers.AutoModel.from_pretrained(self.model_name)
model.to(self.device)
return tokenizer, model return tokenizer, model
def maxsim(self, query_embedding, document_embedding): def maxsim(self, query_embedding, document_embedding):

View File

@@ -42,7 +42,7 @@ class CrossEncoderReranker(Reranker):
@cached_property @cached_property
def model(self): def model(self):
sbert = attempt_import_or_raise("sentence_transformers") sbert = attempt_import_or_raise("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name) cross_encoder = sbert.CrossEncoder(self.model_name).to(self.device)
return cross_encoder return cross_encoder