From 36d05ea6413660614e0a1be2db476fbe811e14d9 Mon Sep 17 00:00:00 2001 From: Sayandip Dutta Date: Sat, 14 Sep 2024 01:02:30 +0530 Subject: [PATCH] fix: add appropriate QueryBuilder overloads to LanceTable.search (#1558) - Add overloads to Table.search, to preserve the return information of different types of QueryBuilder objects for LanceTable - Fix fts_column type annotation by including making it `Optional` resolves #1550 --------- Co-authored-by: sayandip-dutta Co-authored-by: Will Jones --- python/python/lancedb/table.py | 58 ++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 470874bf..46e3860d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -19,6 +19,7 @@ from typing import ( Optional, Tuple, Union, + overload, ) from urllib.parse import urlparse @@ -36,7 +37,16 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict -from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query +from .query import ( + AsyncQuery, + AsyncVectorQuery, + LanceEmptyQueryBuilder, + LanceFtsQueryBuilder, + LanceHybridQueryBuilder, + LanceQueryBuilder, + LanceVectorQueryBuilder, + Query, +) from .util import ( fs_from_uri, get_uri_scheme, @@ -57,6 +67,8 @@ if TYPE_CHECKING: pd = safe_import_pandas() pl = safe_import_polars() +QueryType = Literal["vector", "fts", "hybrid", "auto"] + def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table: if _check_for_hugging_face(data): @@ -619,7 +631,7 @@ class Table(ABC): self, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, - query_type: str = "auto", + query_type: QueryType = "auto", ordering_field_name: Optional[str] = None, fts_columns: Optional[Union[str, List[str]]] = None, ) -> LanceQueryBuilder: @@ -1499,11 +1511,51 @@ class LanceTable(Table): self.schema.metadata ) + @overload def search( self, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, - query_type: str = "auto", + query_type: Literal["vector"] = "vector", + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> LanceVectorQueryBuilder: ... + + @overload + def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["fts"] = "fts", + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> LanceFtsQueryBuilder: ... + + @overload + def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["hybrid"] = "hybrid", + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> LanceHybridQueryBuilder: ... + + @overload + def search( + self, + query: None = None, + vector_column_name: Optional[str] = None, + query_type: QueryType = "auto", + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> LanceEmptyQueryBuilder: ... + + def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: QueryType = "auto", ordering_field_name: Optional[str] = None, fts_columns: Optional[Union[str, List[str]]] = None, ) -> LanceQueryBuilder: