diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 751135209..926b51527 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -2118,19 +2118,17 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): """ # noqa: E501 self._create_query_builders() - results = ["Vector Search Plan:"] - results.append( - self._table._explain_plan( - self._vector_query.to_query_object(), verbose=verbose - ) + reranker_label = str(self._reranker) if self._reranker else "No reranker" + vector_plan = self._table._explain_plan( + self._vector_query.to_query_object(), verbose=verbose ) - results.append("FTS Search Plan:") - results.append( - self._table._explain_plan( - self._fts_query.to_query_object(), verbose=verbose - ) + fts_plan = self._table._explain_plan( + self._fts_query.to_query_object(), verbose=verbose ) - return "\n".join(results) + # Indent sub-plans under the reranker + indented_vector = "\n".join(" " + line for line in vector_plan.splitlines()) + indented_fts = "\n".join(" " + line for line in fts_plan.splitlines()) + return f"{reranker_label}\n {indented_vector}\n {indented_fts}" def analyze_plan(self): """Execute the query and display with runtime metrics. @@ -3164,23 +3162,20 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase): ... plan = await table.query().nearest_to([1.0, 2.0]).nearest_to_text("hello").explain_plan(True) ... print(plan) >>> asyncio.run(doctest_example()) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE - Vector Search Plan: - ProjectionExec: expr=[vector@0 as vector, text@3 as text, _distance@2 as _distance] - Take: columns="vector, _rowid, _distance, (text)" - CoalesceBatchesExec: target_batch_size=1024 - GlobalLimitExec: skip=0, fetch=10 - FilterExec: _distance@2 IS NOT NULL - SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST], preserve_partitioning=[false] - KNNVectorDistance: metric=l2 - LanceRead: uri=..., projection=[vector], ... - - FTS Search Plan: - ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score] - Take: columns="_rowid, _score, (vector), (text)" - CoalesceBatchesExec: target_batch_size=1024 - GlobalLimitExec: skip=0, fetch=10 - MatchQuery: column=text, query=hello - + RRFReranker(K=60) + ProjectionExec: expr=[vector@0 as vector, text@3 as text, _distance@2 as _distance] + Take: columns="vector, _rowid, _distance, (text)" + CoalesceBatchesExec: target_batch_size=1024 + GlobalLimitExec: skip=0, fetch=10 + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST], preserve_partitioning=[false] + KNNVectorDistance: metric=l2 + LanceRead: uri=..., projection=[vector], ... + ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score] + Take: columns="_rowid, _score, (vector), (text)" + CoalesceBatchesExec: target_batch_size=1024 + GlobalLimitExec: skip=0, fetch=10 + MatchQuery: column=text, query=hello Parameters ---------- @@ -3192,12 +3187,12 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase): plan : str """ # noqa: E501 - results = ["Vector Search Plan:"] - results.append(await self._inner.to_vector_query().explain_plan(verbose)) - results.append("FTS Search Plan:") - results.append(await self._inner.to_fts_query().explain_plan(verbose)) - - return "\n".join(results) + vector_plan = await self._inner.to_vector_query().explain_plan(verbose) + fts_plan = await self._inner.to_fts_query().explain_plan(verbose) + # Indent sub-plans under the reranker + indented_vector = "\n".join(" " + line for line in vector_plan.splitlines()) + indented_fts = "\n".join(" " + line for line in fts_plan.splitlines()) + return f"{self._reranker}\n {indented_vector}\n {indented_fts}" async def analyze_plan(self): """ diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py index d615ed317..0b44569e3 100644 --- a/python/python/lancedb/rerankers/answerdotai.py +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -42,10 +42,18 @@ class AnswerdotaiRerankers(Reranker): rerankers = attempt_import_or_raise( "rerankers" ) # import here for faster ops later + self.model_name = model_name + self.model_type = model_type self.reranker = rerankers.Reranker( model_name=model_name, model_type=model_type, **kwargs ) + def __str__(self): + return ( + f"AnswerdotaiRerankers(model_type={self.model_type}, " + f"model_name={self.model_name})" + ) + def _rerank(self, result_set: pa.Table, query: str): result_set = self._handle_empty_results(result_set) if len(result_set) == 0: diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 0c546a772..7bc7ff105 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -40,6 +40,9 @@ class Reranker(ABC): if ARROW_VERSION.major <= 13: self._concat_tables_args = {"promote": True} + def __str__(self): + return self.__class__.__name__ + def rerank_vector( self, query: str, diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 107960671..efd752c22 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -44,6 +44,9 @@ class CohereReranker(Reranker): self.top_n = top_n self.api_key = api_key + def __str__(self): + return f"CohereReranker(model_name={self.model_name})" + @cached_property def _client(self): cohere = attempt_import_or_raise("cohere") diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index fd24eed11..30ad9d382 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -50,6 +50,9 @@ class CrossEncoderReranker(Reranker): if self.device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" + def __str__(self): + return f"CrossEncoderReranker(model_name={self.model_name})" + @cached_property def model(self): sbert = attempt_import_or_raise("sentence_transformers") diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 84ab21f4f..349a8eca5 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -45,6 +45,9 @@ class JinaReranker(Reranker): self.top_n = top_n self.api_key = api_key + def __str__(self): + return f"JinaReranker(model_name={self.model_name})" + @cached_property def _client(self): import requests diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index a22c7b8d9..9f1d645c9 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -38,6 +38,9 @@ class LinearCombinationReranker(Reranker): self.weight = weight self.fill = fill + def __str__(self): + return f"LinearCombinationReranker(weight={self.weight}, fill={self.fill})" + def rerank_hybrid( self, query: str, # noqa: F821 diff --git a/python/python/lancedb/rerankers/mrr.py b/python/python/lancedb/rerankers/mrr.py index c4860ff23..e2d5d1a97 100644 --- a/python/python/lancedb/rerankers/mrr.py +++ b/python/python/lancedb/rerankers/mrr.py @@ -54,6 +54,12 @@ class MRRReranker(Reranker): self.weight_vector = weight_vector self.weight_fts = weight_fts + def __str__(self): + return ( + f"MRRReranker(weight_vector={self.weight_vector}, " + f"weight_fts={self.weight_fts})" + ) + def rerank_hybrid( self, query: str, # noqa: F821 diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 7b181e806..4cbbb0822 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -43,6 +43,9 @@ class OpenaiReranker(Reranker): self.column = column self.api_key = api_key + def __str__(self): + return f"OpenaiReranker(model_name={self.model_name})" + def _rerank(self, result_set: pa.Table, query: str): result_set = self._handle_empty_results(result_set) if len(result_set) == 0: diff --git a/python/python/lancedb/rerankers/rrf.py b/python/python/lancedb/rerankers/rrf.py index cbb38caee..f512d7a8c 100644 --- a/python/python/lancedb/rerankers/rrf.py +++ b/python/python/lancedb/rerankers/rrf.py @@ -36,6 +36,9 @@ class RRFReranker(Reranker): super().__init__(return_score) self.K = K + def __str__(self): + return f"RRFReranker(K={self.K})" + def rerank_hybrid( self, query: str, # noqa: F821 diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py index 47fb45c04..0f41ba927 100644 --- a/python/python/lancedb/rerankers/voyageai.py +++ b/python/python/lancedb/rerankers/voyageai.py @@ -52,6 +52,9 @@ class VoyageAIReranker(Reranker): self.api_key = api_key self.truncation = truncation + def __str__(self): + return f"VoyageAIReranker(model_name={self.model_name})" + @cached_property def _client(self): voyageai = attempt_import_or_raise("voyageai") diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py index 3957568a5..bb6f2befc 100644 --- a/python/python/tests/test_hybrid_query.py +++ b/python/python/tests/test_hybrid_query.py @@ -163,9 +163,7 @@ async def test_explain_plan(table: AsyncTable): table.query().nearest_to_text("dog").nearest_to([0.1, 0.1]).explain_plan(True) ) - assert "Vector Search Plan" in plan assert "KNNVectorDistance" in plan - assert "FTS Search Plan" in plan assert "LanceRead" in plan