diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index 87a8e690..60187202 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -1,4 +1,5 @@ from functools import cached_property +from typing import Union import pyarrow as pa @@ -18,13 +19,16 @@ class ColbertReranker(Reranker): The name of the column to use as input to the cross encoder model. return_score : str, default "relevance" options are "relevance" or "all". Only "relevance" is supported for now. + device : str, default "None" + The device to use for the model. If "auto", will use "cuda" if available, else "cpu". """ def __init__( self, model_name: str = "colbert-ir/colbertv2.0", column: str = "text", - return_score="relevance", + return_score: str="relevance", + device: Union[str, None] = None, ): super().__init__(return_score) self.model_name = model_name @@ -32,6 +36,10 @@ class ColbertReranker(Reranker): self.torch = attempt_import_or_raise( "torch" ) # import here for faster ops later + self.device = device + if device is None: + self.device = "cuda" if self.torch.cuda.is_available() else "cpu" + def _rerank(self, result_set: pa.Table, query: str): docs = result_set[self.column].to_pylist() @@ -46,7 +54,7 @@ class ColbertReranker(Reranker): for document in docs: document_encoding = tokenizer( document, return_tensors="pt", truncation=True, max_length=512 - ) + ).to(self.device) document_embedding = model(**document_encoding).last_hidden_state # Calculate MaxSim score score = self.maxsim(query_embedding.unsqueeze(0), document_embedding) @@ -116,7 +124,7 @@ class ColbertReranker(Reranker): transformers = attempt_import_or_raise("transformers") tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) model = transformers.AutoModel.from_pretrained(self.model_name) - + model.to(self.device) return tokenizer, model def maxsim(self, query_embedding, document_embedding): diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index c88b354a..d417fae1 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -42,7 +42,7 @@ class CrossEncoderReranker(Reranker): @cached_property def model(self): sbert = attempt_import_or_raise("sentence_transformers") - cross_encoder = sbert.CrossEncoder(self.model_name) + cross_encoder = sbert.CrossEncoder(self.model_name).to(self.device) return cross_encoder