Compare commits

...

1 Commits

Author SHA1 Message Date
ayush chaurasia
d8f43ae0d3 update 2024-07-29 17:54:24 +05:30
2 changed files with 12 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
from functools import cached_property
from typing import Union
import pyarrow as pa
@@ -18,13 +19,16 @@ class ColbertReranker(Reranker):
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
device : str, default "None"
The device to use for the model. If "auto", will use "cuda" if available, else "cpu".
"""
def __init__(
self,
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
return_score: str="relevance",
device: Union[str, None] = None,
):
super().__init__(return_score)
self.model_name = model_name
@@ -32,6 +36,10 @@ class ColbertReranker(Reranker):
self.torch = attempt_import_or_raise(
"torch"
) # import here for faster ops later
self.device = device
if device is None:
self.device = "cuda" if self.torch.cuda.is_available() else "cpu"
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
@@ -46,7 +54,7 @@ class ColbertReranker(Reranker):
for document in docs:
document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512
)
).to(self.device)
document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
@@ -116,7 +124,7 @@ class ColbertReranker(Reranker):
transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)
model.to(self.device)
return tokenizer, model
def maxsim(self, query_embedding, document_embedding):

View File

@@ -42,7 +42,7 @@ class CrossEncoderReranker(Reranker):
@cached_property
def model(self):
sbert = attempt_import_or_raise("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
cross_encoder = sbert.CrossEncoder(self.model_name).to(self.device)
return cross_encoder