From 18484d0b6c9c2c8276f8a9d2678c320a8f227062 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 14 Sep 2024 23:48:09 +0530 Subject: [PATCH] fix: allow pass optional args in colbert reranker (#1649) Fixes https://github.com/lancedb/lancedb/issues/1641 --- python/python/lancedb/rerankers/answerdotai.py | 6 +++++- python/python/lancedb/rerankers/colbert.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py index 3c2fcb2d..3940fe4b 100644 --- a/python/python/lancedb/rerankers/answerdotai.py +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -32,6 +32,9 @@ class AnswerdotaiRerankers(Reranker): The name of the column to use as input to the cross encoder model. return_score : str, default "relevance" options are "relevance" or "all". Only "relevance" is supported for now. + **kwargs + Additional keyword arguments to pass to the model. For example, 'device'. + See AnswerDotAI/rerankers for more information. """ def __init__( @@ -40,13 +43,14 @@ class AnswerdotaiRerankers(Reranker): model_name: str = "answerdotai/answerai-colbert-small-v1", column: str = "text", return_score="relevance", + **kwargs, ): super().__init__(return_score) self.column = column rerankers = attempt_import_or_raise( "rerankers" ) # import here for faster ops later - self.reranker = rerankers.Reranker(model_name, model_type) + self.reranker = rerankers.Reranker(model_name, model_type, **kwargs) def _rerank(self, result_set: pa.Table, query: str): docs = result_set[self.column].to_pylist() diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index cffdd0ba..b40c0b4b 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -26,6 +26,9 @@ class ColbertReranker(AnswerdotaiRerankers): The name of the column to use as input to the cross encoder model. return_score : str, default "relevance" options are "relevance" or "all". Only "relevance" is supported for now. + **kwargs + Additional keyword arguments to pass to the model, for example, 'device'. + See AnswerDotAI/rerankers for more information. """ def __init__( @@ -33,10 +36,12 @@ class ColbertReranker(AnswerdotaiRerankers): model_name: str = "colbert-ir/colbertv2.0", column: str = "text", return_score="relevance", + **kwargs, ): super().__init__( model_type="colbert", model_name=model_name, column=column, return_score=return_score, + **kwargs, )