diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index c4d0ff0b..34e2fe7c 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -26,7 +26,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry from ..query import LanceVectorQueryBuilder, LanceQueryBuilder from ..table import Query, Table, _sanitize_data -from ..util import inf_vector_column_query, value_to_sql +from ..util import value_to_sql, infer_vector_column_name from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE from .db import RemoteDBConnection @@ -266,7 +266,7 @@ class RemoteTable(Table): def search( self, - query: Union[VEC, str], + query: Union[VEC, str] = None, vector_column_name: Optional[str] = None, query_type="auto", fts_columns: Optional[Union[str, List[str]]] = None, @@ -305,8 +305,6 @@ class RemoteTable(Table): - *default None*. Acceptable types are: list, np.ndarray, PIL.Image.Image - - If None then the select/where/limit clauses are applied to filter - the table vector_column_name: str, optional The name of the vector column to search. @@ -329,11 +327,15 @@ class RemoteTable(Table): - and also the "_distance" column which is the distance between the query vector and the returned vector. 
""" - if vector_column_name is None and query is not None and query_type != "fts": - try: - vector_column_name = inf_vector_column_query(self.schema) - except Exception as e: - raise e + # empty query builder is not supported in saas, raise error + if query is None and query_type != "hybrid": + raise ValueError("Empty query is not supported") + vector_column_name = infer_vector_column_name( + schema=self.schema, + query_type=query_type, + query=query, + vector_column_name=vector_column_name, + ) return LanceQueryBuilder.create( self, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index d0bd1f38..11d923f5 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -50,7 +50,7 @@ from .query import ( from .util import ( fs_from_uri, get_uri_scheme, - inf_vector_column_query, + infer_vector_column_name, join_uri, safe_import_pandas, safe_import_polars, @@ -1630,13 +1630,12 @@ class LanceTable(Table): and also the "_distance" column which is the distance between the query vector and the returned vector. 
def infer_vector_column_name(
    schema: "pa.Schema",
    query_type: str,
    query: Optional[Any],  # inferred later in query builder
    vector_column_name: Optional[str],
) -> Optional[str]:
    """Resolve the vector column a search should run against.

    An explicitly supplied ``vector_column_name`` always wins and is
    returned unchanged. Otherwise the name is inferred from ``schema`` via
    ``inf_vector_column_query`` when either:

    * ``query_type`` is ``"hybrid"`` (even if ``query`` is ``None``), or
    * a ``query`` was supplied and ``query_type`` is not ``"fts"``.

    In every other case no inference is attempted and ``None`` is returned.

    Parameters
    ----------
    schema: pa.Schema
        Schema of the table being searched; only consulted when inference
        is actually required.
    query_type: str
        The requested query type; ``"fts"`` and ``"hybrid"`` are the two
        values this helper treats specially.
    query: Any, optional
        The user query. Only its presence (``None`` or not) matters here.
    vector_column_name: str, optional
        Caller-provided column name, if any.

    Returns
    -------
    Optional[str]
        The explicit or inferred vector column name, or ``None`` when no
        name was given and the query does not need one.

    Raises
    ------
    Exception
        Whatever ``inf_vector_column_query`` raises is propagated as-is.
    """
    needs_inference = query_type == "hybrid" or (
        query is not None and query_type != "fts"
    )
    if vector_column_name is None and needs_inference:
        # Let inference errors propagate directly; wrapping the call in
        # ``except Exception as e: raise e`` only added a traceback frame.
        return inf_vector_column_query(schema)
    return vector_column_name