fix: add appropriate QueryBuilder overloads to LanceTable.search (#1558)

- Add overloads to LanceTable.search so that type checkers preserve the
specific QueryBuilder subtype returned for each query type
- Fix the `fts_columns` type annotation by making it `Optional`

resolves #1550

---------

Co-authored-by: sayandip-dutta <sayandip.dutta@nevaehtech.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Sayandip Dutta
2024-09-14 01:02:30 +05:30
committed by GitHub
parent 7ed86cadfb
commit 36d05ea641

View File

@@ -19,6 +19,7 @@ from typing import (
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import urlparse
@@ -36,7 +37,16 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
from .query import (
AsyncQuery,
AsyncVectorQuery,
LanceEmptyQueryBuilder,
LanceFtsQueryBuilder,
LanceHybridQueryBuilder,
LanceQueryBuilder,
LanceVectorQueryBuilder,
Query,
)
from .util import (
fs_from_uri,
get_uri_scheme,
@@ -57,6 +67,8 @@ if TYPE_CHECKING:
pd = safe_import_pandas()
pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data):
@@ -619,7 +631,7 @@ class Table(ABC):
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
@@ -1499,11 +1511,51 @@ class LanceTable(Table):
self.schema.metadata
)
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
query_type: Literal["vector"] = "vector",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceVectorQueryBuilder: ...
# Typing-only overload of search(): when query_type is the literal "fts",
# the declared return type is LanceFtsQueryBuilder. The body is a stub (`...`).
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["fts"] = "fts",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceFtsQueryBuilder: ...
# Typing-only overload of search(): when query_type is the literal "hybrid",
# the declared return type is LanceHybridQueryBuilder. The body is a stub (`...`).
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["hybrid"] = "hybrid",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceHybridQueryBuilder: ...
# Typing-only overload of search(): when query is None (the default) and
# query_type is any QueryType value, the declared return type is
# LanceEmptyQueryBuilder. The body is a stub (`...`).
# NOTE(review): type checkers generally select the first matching overload, and
# the overloads declared before this one also accept query=None and give
# query_type a default, so a bare `search()` may resolve to an earlier overload
# instead of this one -- confirm with mypy/pyright.
@overload
def search(
self,
query: None = None,
vector_column_name: Optional[str] = None,
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceEmptyQueryBuilder: ...
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder: