mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 06:19:57 +00:00
Compare commits
1 Commits
python-v0.
...
ayush/rera
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8f43ae0d3 |
@@ -1,4 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -18,13 +19,16 @@ class ColbertReranker(Reranker):
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
device : str, default "None"
|
||||
The device to use for the model. If "auto", will use "cuda" if available, else "cpu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "colbert-ir/colbertv2.0",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
return_score: str="relevance",
|
||||
device: Union[str, None] = None,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
@@ -32,6 +36,10 @@ class ColbertReranker(Reranker):
|
||||
self.torch = attempt_import_or_raise(
|
||||
"torch"
|
||||
) # import here for faster ops later
|
||||
self.device = device
|
||||
if device is None:
|
||||
self.device = "cuda" if self.torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
@@ -46,7 +54,7 @@ class ColbertReranker(Reranker):
|
||||
for document in docs:
|
||||
document_encoding = tokenizer(
|
||||
document, return_tensors="pt", truncation=True, max_length=512
|
||||
)
|
||||
).to(self.device)
|
||||
document_embedding = model(**document_encoding).last_hidden_state
|
||||
# Calculate MaxSim score
|
||||
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
|
||||
@@ -116,7 +124,7 @@ class ColbertReranker(Reranker):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
||||
model = transformers.AutoModel.from_pretrained(self.model_name)
|
||||
|
||||
model.to(self.device)
|
||||
return tokenizer, model
|
||||
|
||||
def maxsim(self, query_embedding, document_embedding):
|
||||
|
||||
@@ -42,7 +42,7 @@ class CrossEncoderReranker(Reranker):
|
||||
@cached_property
|
||||
def model(self):
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name)
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name).to(self.device)
|
||||
|
||||
return cross_encoder
|
||||
|
||||
|
||||
Reference in New Issue
Block a user