mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
fix: allow pass optional args in colbert reranker (#1649)
Fixes https://github.com/lancedb/lancedb/issues/1641
This commit is contained in:
@@ -32,6 +32,9 @@ class AnswerdotaiRerankers(Reranker):
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
**kwargs
|
||||
Additional keyword arguments to pass to the model. For example, 'device'.
|
||||
See AnswerDotAI/rerankers for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,13 +43,14 @@ class AnswerdotaiRerankers(Reranker):
|
||||
model_name: str = "answerdotai/answerai-colbert-small-v1",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.column = column
|
||||
rerankers = attempt_import_or_raise(
|
||||
"rerankers"
|
||||
) # import here for faster ops later
|
||||
self.reranker = rerankers.Reranker(model_name, model_type)
|
||||
self.reranker = rerankers.Reranker(model_name, model_type, **kwargs)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
|
||||
@@ -26,6 +26,9 @@ class ColbertReranker(AnswerdotaiRerankers):
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
**kwargs
|
||||
Additional keyword arguments to pass to the model, for example, 'device'.
|
||||
See AnswerDotAI/rerankers for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -33,10 +36,12 @@ class ColbertReranker(AnswerdotaiRerankers):
|
||||
model_name: str = "colbert-ir/colbertv2.0",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model_type="colbert",
|
||||
model_name=model_name,
|
||||
column=column,
|
||||
return_score=return_score,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user