mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
feat: show reranker info in hybrid search explain plan (#3006)
Closes #3000 The hybrid search `explain_plan` now shows the reranker as the top-level node with the vector and FTS sub-plans indented underneath, instead of just listing them separately with no reranker context. **Before:** ``` Vector Search Plan: ProjectionExec: ... FTS Search Plan: ProjectionExec: ... ``` **After:** ``` RRFReranker(K=60) Vector Search Plan: ProjectionExec: ... FTS Search Plan: ProjectionExec: ... ``` Other rerankers display similarly ; e.g. `LinearCombinationReranker(weight=0.7, fill=1.0)`, `MRRReranker(weight_vector=0.5, weight_fts=0.5)`, `CohereReranker(model_name=name)`. --------- Signed-off-by: dask-58 <googldhruv@gmail.com> Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -2118,19 +2118,17 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
""" # noqa: E501
|
||||
self._create_query_builders()
|
||||
|
||||
results = ["Vector Search Plan:"]
|
||||
results.append(
|
||||
self._table._explain_plan(
|
||||
self._vector_query.to_query_object(), verbose=verbose
|
||||
)
|
||||
reranker_label = str(self._reranker) if self._reranker else "No reranker"
|
||||
vector_plan = self._table._explain_plan(
|
||||
self._vector_query.to_query_object(), verbose=verbose
|
||||
)
|
||||
results.append("FTS Search Plan:")
|
||||
results.append(
|
||||
self._table._explain_plan(
|
||||
self._fts_query.to_query_object(), verbose=verbose
|
||||
)
|
||||
fts_plan = self._table._explain_plan(
|
||||
self._fts_query.to_query_object(), verbose=verbose
|
||||
)
|
||||
return "\n".join(results)
|
||||
# Indent sub-plans under the reranker
|
||||
indented_vector = "\n".join(" " + line for line in vector_plan.splitlines())
|
||||
indented_fts = "\n".join(" " + line for line in fts_plan.splitlines())
|
||||
return f"{reranker_label}\n {indented_vector}\n {indented_fts}"
|
||||
|
||||
def analyze_plan(self):
|
||||
"""Execute the query and display with runtime metrics.
|
||||
@@ -3164,23 +3162,20 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
... plan = await table.query().nearest_to([1.0, 2.0]).nearest_to_text("hello").explain_plan(True)
|
||||
... print(plan)
|
||||
>>> asyncio.run(doctest_example()) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
Vector Search Plan:
|
||||
ProjectionExec: expr=[vector@0 as vector, text@3 as text, _distance@2 as _distance]
|
||||
Take: columns="vector, _rowid, _distance, (text)"
|
||||
CoalesceBatchesExec: target_batch_size=1024
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST], preserve_partitioning=[false]
|
||||
KNNVectorDistance: metric=l2
|
||||
LanceRead: uri=..., projection=[vector], ...
|
||||
<BLANKLINE>
|
||||
FTS Search Plan:
|
||||
ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score]
|
||||
Take: columns="_rowid, _score, (vector), (text)"
|
||||
CoalesceBatchesExec: target_batch_size=1024
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
MatchQuery: column=text, query=hello
|
||||
<BLANKLINE>
|
||||
RRFReranker(K=60)
|
||||
ProjectionExec: expr=[vector@0 as vector, text@3 as text, _distance@2 as _distance]
|
||||
Take: columns="vector, _rowid, _distance, (text)"
|
||||
CoalesceBatchesExec: target_batch_size=1024
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST], preserve_partitioning=[false]
|
||||
KNNVectorDistance: metric=l2
|
||||
LanceRead: uri=..., projection=[vector], ...
|
||||
ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score]
|
||||
Take: columns="_rowid, _score, (vector), (text)"
|
||||
CoalesceBatchesExec: target_batch_size=1024
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
MatchQuery: column=text, query=hello
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -3192,12 +3187,12 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
plan : str
|
||||
""" # noqa: E501
|
||||
|
||||
results = ["Vector Search Plan:"]
|
||||
results.append(await self._inner.to_vector_query().explain_plan(verbose))
|
||||
results.append("FTS Search Plan:")
|
||||
results.append(await self._inner.to_fts_query().explain_plan(verbose))
|
||||
|
||||
return "\n".join(results)
|
||||
vector_plan = await self._inner.to_vector_query().explain_plan(verbose)
|
||||
fts_plan = await self._inner.to_fts_query().explain_plan(verbose)
|
||||
# Indent sub-plans under the reranker
|
||||
indented_vector = "\n".join(" " + line for line in vector_plan.splitlines())
|
||||
indented_fts = "\n".join(" " + line for line in fts_plan.splitlines())
|
||||
return f"{self._reranker}\n {indented_vector}\n {indented_fts}"
|
||||
|
||||
async def analyze_plan(self):
|
||||
"""
|
||||
|
||||
@@ -42,10 +42,18 @@ class AnswerdotaiRerankers(Reranker):
|
||||
rerankers = attempt_import_or_raise(
|
||||
"rerankers"
|
||||
) # import here for faster ops later
|
||||
self.model_name = model_name
|
||||
self.model_type = model_type
|
||||
self.reranker = rerankers.Reranker(
|
||||
model_name=model_name, model_type=model_type, **kwargs
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"AnswerdotaiRerankers(model_type={self.model_type}, "
|
||||
f"model_name={self.model_name})"
|
||||
)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
|
||||
@@ -40,6 +40,9 @@ class Reranker(ABC):
|
||||
if ARROW_VERSION.major <= 13:
|
||||
self._concat_tables_args = {"promote": True}
|
||||
|
||||
def __str__(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@@ -44,6 +44,9 @@ class CohereReranker(Reranker):
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
|
||||
def __str__(self):
|
||||
return f"CohereReranker(model_name={self.model_name})"
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
|
||||
@@ -50,6 +50,9 @@ class CrossEncoderReranker(Reranker):
|
||||
if self.device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def __str__(self):
|
||||
return f"CrossEncoderReranker(model_name={self.model_name})"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
|
||||
@@ -45,6 +45,9 @@ class JinaReranker(Reranker):
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
|
||||
def __str__(self):
|
||||
return f"JinaReranker(model_name={self.model_name})"
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
import requests
|
||||
|
||||
@@ -38,6 +38,9 @@ class LinearCombinationReranker(Reranker):
|
||||
self.weight = weight
|
||||
self.fill = fill
|
||||
|
||||
def __str__(self):
|
||||
return f"LinearCombinationReranker(weight={self.weight}, fill={self.fill})"
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str, # noqa: F821
|
||||
|
||||
@@ -54,6 +54,12 @@ class MRRReranker(Reranker):
|
||||
self.weight_vector = weight_vector
|
||||
self.weight_fts = weight_fts
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"MRRReranker(weight_vector={self.weight_vector}, "
|
||||
f"weight_fts={self.weight_fts})"
|
||||
)
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str, # noqa: F821
|
||||
|
||||
@@ -43,6 +43,9 @@ class OpenaiReranker(Reranker):
|
||||
self.column = column
|
||||
self.api_key = api_key
|
||||
|
||||
def __str__(self):
|
||||
return f"OpenaiReranker(model_name={self.model_name})"
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
|
||||
@@ -36,6 +36,9 @@ class RRFReranker(Reranker):
|
||||
super().__init__(return_score)
|
||||
self.K = K
|
||||
|
||||
def __str__(self):
|
||||
return f"RRFReranker(K={self.K})"
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str, # noqa: F821
|
||||
|
||||
@@ -52,6 +52,9 @@ class VoyageAIReranker(Reranker):
|
||||
self.api_key = api_key
|
||||
self.truncation = truncation
|
||||
|
||||
def __str__(self):
|
||||
return f"VoyageAIReranker(model_name={self.model_name})"
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
voyageai = attempt_import_or_raise("voyageai")
|
||||
|
||||
@@ -163,9 +163,7 @@ async def test_explain_plan(table: AsyncTable):
|
||||
table.query().nearest_to_text("dog").nearest_to([0.1, 0.1]).explain_plan(True)
|
||||
)
|
||||
|
||||
assert "Vector Search Plan" in plan
|
||||
assert "KNNVectorDistance" in plan
|
||||
assert "FTS Search Plan" in plan
|
||||
assert "LanceRead" in plan
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user