mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
fix: add appropriate QueryBuilder overloads to LanceTable.search (#1558)
- Add overloads to Table.search, to preserve the return information of different types of QueryBuilder objects for LanceTable - Fix fts_column type annotation by including making it `Optional` resolves #1550 --------- Co-authored-by: sayandip-dutta <sayandip.dutta@nevaehtech.com> Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -36,7 +37,16 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .merge import LanceMergeInsertBuilder
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
|
||||
from .query import (
|
||||
AsyncQuery,
|
||||
AsyncVectorQuery,
|
||||
LanceEmptyQueryBuilder,
|
||||
LanceFtsQueryBuilder,
|
||||
LanceHybridQueryBuilder,
|
||||
LanceQueryBuilder,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
)
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_scheme,
|
||||
@@ -57,6 +67,8 @@ if TYPE_CHECKING:
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||
|
||||
|
||||
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||
if _check_for_hugging_face(data):
|
||||
@@ -619,7 +631,7 @@ class Table(ABC):
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -1499,11 +1511,51 @@ class LanceTable(Table):
|
||||
self.schema.metadata
|
||||
)
|
||||
|
||||
@overload
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
query_type: Literal["vector"] = "vector",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceVectorQueryBuilder: ...
|
||||
|
||||
@overload
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["fts"] = "fts",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceFtsQueryBuilder: ...
|
||||
|
||||
@overload
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["hybrid"] = "hybrid",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceHybridQueryBuilder: ...
|
||||
|
||||
@overload
|
||||
def search(
|
||||
self,
|
||||
query: None = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceEmptyQueryBuilder: ...
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
|
||||
Reference in New Issue
Block a user