Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-03 18:32:55 +00:00

Compare commits (1 commit): python-v0. ... ayush/rera

Commit d8f43ae0d3
@@ -1,4 +1,5 @@
 from functools import cached_property
+from typing import Union

 import pyarrow as pa

@@ -18,13 +19,16 @@ class ColbertReranker(Reranker):
         The name of the column to use as input to the cross encoder model.
     return_score : str, default "relevance"
         options are "relevance" or "all". Only "relevance" is supported for now.
+    device : str, default "None"
+        The device to use for the model. If "auto", will use "cuda" if available, else "cpu".
     """

     def __init__(
         self,
         model_name: str = "colbert-ir/colbertv2.0",
         column: str = "text",
-        return_score="relevance",
+        return_score: str = "relevance",
+        device: Union[str, None] = None,
     ):
         super().__init__(return_score)
         self.model_name = model_name
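A minimal usage sketch of the new device argument, assuming the class is importable as lancedb.rerankers.ColbertReranker (the import path is not part of this diff):

from lancedb.rerankers import ColbertReranker

# Pin the model to a specific device; with device=None (the default) the
# reranker falls back to "cuda" when torch.cuda.is_available(), else "cpu".
reranker = ColbertReranker(
    model_name="colbert-ir/colbertv2.0",
    column="text",
    device="cpu",
)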
@@ -32,6 +36,10 @@ class ColbertReranker(Reranker):
         self.torch = attempt_import_or_raise(
             "torch"
         )  # import here for faster ops later
+        self.device = device
+        if device is None:
+            self.device = "cuda" if self.torch.cuda.is_available() else "cpu"
+

     def _rerank(self, result_set: pa.Table, query: str):
         docs = result_set[self.column].to_pylist()
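The device-resolution branch added above, shown standalone as a small sketch (it assumes only that torch is installed; resolve_device is an illustrative name, not part of the codebase):

import torch

def resolve_device(device=None):
    # An explicit device wins; otherwise prefer CUDA when it is available.
    if device is None:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device

print(resolve_device())       # "cuda" on a GPU machine, otherwise "cpu"
print(resolve_device("cpu"))  # always "cpu"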
@@ -46,7 +54,7 @@ class ColbertReranker(Reranker):
         for document in docs:
             document_encoding = tokenizer(
                 document, return_tensors="pt", truncation=True, max_length=512
-            )
+            ).to(self.device)
            document_embedding = model(**document_encoding).last_hidden_state
             # Calculate MaxSim score
             score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
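The maxsim body itself is not part of this diff; the sketch below follows the standard ColBERT MaxSim formulation (for each query token, take the maximum similarity over all document tokens, then sum over query tokens), with batch-first shapes assumed:

import torch

def maxsim_sketch(query_embedding: torch.Tensor, document_embedding: torch.Tensor) -> torch.Tensor:
    # query_embedding:    (1, query_len, dim)
    # document_embedding: (1, doc_len, dim)
    sim = torch.matmul(query_embedding, document_embedding.transpose(-1, -2))  # (1, query_len, doc_len)
    # Best-matching document token per query token, summed over query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)  # shape (1,)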
@@ -116,7 +124,7 @@ class ColbertReranker(Reranker):
         transformers = attempt_import_or_raise("transformers")
         tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
         model = transformers.AutoModel.from_pretrained(self.model_name)
+        model.to(self.device)
         return tokenizer, model

     def maxsim(self, query_embedding, document_embedding):
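A self-contained sketch of the load-and-move pattern above, using the Hugging Face transformers API directly (the model name and query text here are illustrative):

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "colbert-ir/colbertv2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.to(device)  # inputs produced by the tokenizer must live on the same device

encoding = tokenizer("a short query", return_tensors="pt", truncation=True, max_length=512).to(device)
with torch.no_grad():
    embedding = model(**encoding).last_hidden_state  # (1, seq_len, dim)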
@@ -42,7 +42,7 @@ class CrossEncoderReranker(Reranker):
     @cached_property
     def model(self):
         sbert = attempt_import_or_raise("sentence_transformers")
-        cross_encoder = sbert.CrossEncoder(self.model_name)
+        cross_encoder = sbert.CrossEncoder(self.model_name).to(self.device)

         return cross_encoder

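For context, a hedged sketch of scoring query/passage pairs with a sentence-transformers CrossEncoder on an explicit device (the model name and pairs below are illustrative, not taken from this diff):

from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
scores = cross_encoder.predict([
    ("what is lancedb", "LanceDB is an embedded vector database."),
    ("what is lancedb", "Bananas are rich in potassium."),
])
print(scores)  # higher score = more relevant passage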