fix: allow pass optional args in colbert reranker (#1649)

Fixes https://github.com/lancedb/lancedb/issues/1641
This commit is contained in:
Ayush Chaurasia
2024-09-14 23:48:09 +05:30
committed by GitHub
parent c02ee3c80c
commit 18484d0b6c
2 changed files with 10 additions and 1 deletions

View File

@@ -32,6 +32,9 @@ class AnswerdotaiRerankers(Reranker):
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
**kwargs
Additional keyword arguments to pass to the model. For example, 'device'.
See AnswerDotAI/rerankers for more information.
"""
def __init__(
@@ -40,13 +43,14 @@ class AnswerdotaiRerankers(Reranker):
model_name: str = "answerdotai/answerai-colbert-small-v1",
column: str = "text",
return_score="relevance",
**kwargs,
):
super().__init__(return_score)
self.column = column
rerankers = attempt_import_or_raise(
"rerankers"
) # import here for faster ops later
self.reranker = rerankers.Reranker(model_name, model_type)
self.reranker = rerankers.Reranker(model_name, model_type, **kwargs)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()

View File

@@ -26,6 +26,9 @@ class ColbertReranker(AnswerdotaiRerankers):
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
**kwargs
Additional keyword arguments to pass to the model, for example, 'device'.
See AnswerDotAI/rerankers for more information.
"""
def __init__(
@@ -33,10 +36,12 @@ class ColbertReranker(AnswerdotaiRerankers):
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
**kwargs,
):
super().__init__(
model_type="colbert",
model_name=model_name,
column=column,
return_score=return_score,
**kwargs,
)