feat(python): allow explicit hybrid search query pattern in SaaS (feat parity) (#1698)

-  fixes https://github.com/lancedb/lancedb/issues/1697.
- unifies vector column inference logic for remote and local table to
prevent future disparities.
- Updates docstring in RemoteTable to specify empty queries are not
supported
This commit is contained in:
Ayush Chaurasia
2024-09-25 21:04:00 +05:30
committed by GitHub
parent f00b21c98c
commit 2f2721e242
3 changed files with 36 additions and 18 deletions

View File

@@ -26,7 +26,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection from .db import RemoteDBConnection
@@ -266,7 +266,7 @@ class RemoteTable(Table):
def search( def search(
self, self,
query: Union[VEC, str], query: Union[VEC, str] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type="auto", query_type="auto",
fts_columns: Optional[Union[str, List[str]]] = None, fts_columns: Optional[Union[str, List[str]]] = None,
@@ -305,8 +305,6 @@ class RemoteTable(Table):
- *default None*. - *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional vector_column_name: str, optional
The name of the vector column to search. The name of the vector column to search.
@@ -329,11 +327,15 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query - and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if vector_column_name is None and query is not None and query_type != "fts": # empty query builder is not supported in saas, raise error
try: if query is None and query_type != "hybrid":
vector_column_name = inf_vector_column_query(self.schema) raise ValueError("Empty query is not supported")
except Exception as e: vector_column_name = infer_vector_column_name(
raise e schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
self, self,

View File

@@ -50,7 +50,7 @@ from .query import (
from .util import ( from .util import (
fs_from_uri, fs_from_uri,
get_uri_scheme, get_uri_scheme,
inf_vector_column_query, infer_vector_column_name,
join_uri, join_uri,
safe_import_pandas, safe_import_pandas,
safe_import_polars, safe_import_polars,
@@ -1630,13 +1630,12 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if ( vector_column_name = infer_vector_column_name(
vector_column_name is None and query is not None and query_type != "fts" schema=self.schema,
) or (vector_column_name is None and query_type == "hybrid"): query_type=query_type,
try: query=query,
vector_column_name = inf_vector_column_query(self.schema) vector_column_name=vector_column_name,
except Exception as e: )
raise e
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
self, self,

View File

@@ -9,7 +9,7 @@ import pathlib
import warnings import warnings
from datetime import date, datetime from datetime import date, datetime
from functools import singledispatch from functools import singledispatch
from typing import Tuple, Union from typing import Tuple, Union, Optional, Any
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
@@ -212,6 +212,23 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
return vector_col_name return vector_col_name
def infer_vector_column_name(
schema: pa.Schema,
query_type: str,
query: Optional[Any], # inferred later in query builder
vector_column_name: Optional[str],
):
if (vector_column_name is None and query is not None and query_type != "fts") or (
vector_column_name is None and query_type == "hybrid"
):
try:
vector_column_name = inf_vector_column_query(schema)
except Exception as e:
raise e
return vector_column_name
@singledispatch @singledispatch
def value_to_sql(value): def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type") raise NotImplementedError("SQL conversion is not implemented for this type")