diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 35201de9..c45346c1 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -125,12 +125,14 @@ class LanceQueryBuilder(ABC): @classmethod def create( - cls, - table: "Table", - query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], - query_type: str, - vector_column_name: str, - ordering_field_name: str = None, + cls, + table: "Table", + query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], + query_type: str, + vector_column_name: str, + ordering_field_name: Optional[str] = None, + fts_columns: Union[str, List[str]] = [], + fast_search: bool = False, ) -> LanceQueryBuilder: """ Create a query builder based on the given query and query type. @@ -147,13 +149,18 @@ class LanceQueryBuilder(ABC): If "auto", the query type is inferred based on the query. vector_column_name: str The name of the vector column to use for vector search. + fast_search: bool + Skip flat search of unindexed data. """ - if query is None: - return LanceEmptyQueryBuilder(table) - + # Check hybrid search first as it supports empty query pattern if query_type == "hybrid": # hybrid fts and vector query - return LanceHybridQueryBuilder(table, query, vector_column_name) + return LanceHybridQueryBuilder( + table, query, vector_column_name, fts_columns=fts_columns + ) + + if query is None: + return LanceEmptyQueryBuilder(table) # remember the string query for reranking purpose str_query = query if isinstance(query, str) else None @@ -165,12 +172,17 @@ class LanceQueryBuilder(ABC): ) if query_type == "hybrid": - return LanceHybridQueryBuilder(table, query, vector_column_name) + return LanceHybridQueryBuilder( + table, query, vector_column_name, fts_columns=fts_columns + ) if isinstance(query, str): # fts return LanceFtsQueryBuilder( - table, query, ordering_field_name=ordering_field_name + table, + query, + ordering_field_name=ordering_field_name, + fts_columns=fts_columns, ) if isinstance(query, list): @@ -180,7 +192,9 @@ class LanceQueryBuilder(ABC): else: raise TypeError(f"Unsupported query type: {type(query)}") - return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) + return LanceVectorQueryBuilder( + table, query, vector_column_name, str_query, fast_search + ) @classmethod def _resolve_query(cls, table, query, query_type, vector_column_name): @@ -196,8 +210,6 @@ class LanceQueryBuilder(ABC): elif query_type == "auto": if isinstance(query, (list, np.ndarray)): return query, "vector" - if isinstance(query, tuple): - return query, "hybrid" else: conf = table.embedding_functions.get(vector_column_name) if conf is not None: @@ -224,9 +236,14 @@ class LanceQueryBuilder(ABC): def __init__(self, table: "Table"): self._table = table self._limit = 10 + self._offset = 0 self._columns = None self._where = None + self._prefilter = False self._with_row_id = False + self._vector = None + self._text = None + self._ef = None @deprecation.deprecated( deprecated_in="0.3.1", @@ -337,11 +354,13 @@ class LanceQueryBuilder(ABC): ---------- limit: int The maximum number of results to return. - By default the query is limited to the first 10. - Call this method and pass 0, a negative value, - or None to remove the limit. - *WARNING* if you have a large dataset, removing - the limit can potentially result in reading a + The default query limit is 10 results. + For ANN/KNN queries, you must specify a limit. + Entering 0, a negative number, or None will reset + the limit to the default value of 10. + *WARNING* if you have a large dataset, setting + the limit to a large number, e.g. the table size, + can potentially result in reading a large amount of data into memory and cause out of memory issues. @@ -351,11 +370,33 @@ class LanceQueryBuilder(ABC): The LanceQueryBuilder object. """ if limit is None or limit <= 0: - self._limit = None + if isinstance(self, LanceVectorQueryBuilder): + raise ValueError("Limit is required for ANN/KNN queries") + else: + self._limit = None else: self._limit = limit return self + def offset(self, offset: int) -> LanceQueryBuilder: + """Set the offset for the results. + + Parameters + ---------- + offset: int + The offset to start fetching results from. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + if offset is None or offset <= 0: + self._offset = 0 + else: + self._offset = offset + return self + def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder: """Set the columns to return. @@ -417,6 +458,80 @@ class LanceQueryBuilder(ABC): self._with_row_id = with_row_id return self + def explain_plan(self, verbose: Optional[bool] = False) -> str: + """Return the execution plan for this query. + + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", [{"vector": [99, 99]}]) + >>> query = [100, 100] + >>> plan = table.search(query).explain_plan(True) + >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance] + GlobalLimitExec: skip=0, fetch=10 + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] + KNNVectorDistance: metric=l2 + LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + + Parameters + ---------- + verbose : bool, default False + Use a verbose output format. + + Returns + ------- + plan : str + """ # noqa: E501 + ds = self._table.to_lance() + return ds.scanner( + nearest={ + "column": self._vector_column, + "q": self._query, + "k": self._limit, + "metric": self._metric, + "nprobes": self._nprobes, + "refine_factor": self._refine_factor, + }, + prefilter=self._prefilter, + filter=self._str_query, + limit=self._limit, + with_row_id=self._with_row_id, + offset=self._offset, + ).explain_plan(verbose) + + def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder: + """Set the vector to search for. + + Parameters + ---------- + vector: np.ndarray or list + The vector to search for. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + raise NotImplementedError + + def text(self, text: str) -> LanceQueryBuilder: + """Set the text to search for. + + Parameters + ---------- + text: str + The text to search for. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + raise NotImplementedError + class LanceVectorQueryBuilder(LanceQueryBuilder): """ @@ -440,11 +555,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): """ def __init__( - self, - table: "Table", - query: Union[np.ndarray, list, "PIL.Image.Image"], - vector_column: str, - str_query: Optional[str] = None, + self, + table: "Table", + query: Union[np.ndarray, list, "PIL.Image.Image"], + vector_column: str, + str_query: Optional[str] = None, + fast_search: bool = False, ): super().__init__(table) self._query = query @@ -455,13 +571,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._prefilter = False self._reranker = None self._str_query = str_query + self._fast_search = fast_search - def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: + def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. Parameters ---------- - metric: "L2" or "cosine" + metric: "L2" or "cosine" or "dot" The distance metric to use. By default "L2" is used. Returns @@ -469,7 +586,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceQueryBuilder object. """ - self._metric = metric + self._metric = metric.lower() return self def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: @@ -494,6 +611,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def ef(self, ef: int) -> LanceVectorQueryBuilder: + """Set the number of candidates to consider during search. + + Higher values will yield better recall (more likely to find vectors if + they exist) at the expense of latency. + + This only applies to the HNSW-related index. + The default value is 1.5 * limit. + + Parameters + ---------- + ef: int + The number of candidates to consider during search. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._ef = ef + return self + def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder: """Set the refine factor to use, increasing the number of vectors sampled. @@ -554,15 +693,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): refine_factor=self._refine_factor, vector_column=self._vector_column, with_row_id=self._with_row_id, + offset=self._offset, + fast_search=self._fast_search, + ef=self._ef, ) result_set = self._table._execute_query(query, batch_size) - if self._reranker is not None: - rs_table = result_set.read_all() - result_set = self._reranker.rerank_vector(self._str_query, rs_table) - # convert result_set back to RecordBatchReader - result_set = pa.RecordBatchReader.from_batches( - result_set.schema, result_set.to_batches() - ) return result_set @@ -591,7 +726,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): return self def rerank( - self, reranker: Reranker, query_string: Optional[str] = None + self, reranker: Reranker, query_string: Optional[str] = None ) -> LanceVectorQueryBuilder: """Rerank the results using the specified reranker. @@ -756,12 +891,34 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pa.Table: - ds = self._table.to_lance() - return ds.to_table( + return self.to_batches().read_all() + + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: + query = Query( columns=self._columns, filter=self._where, - limit=self._limit, + k=self._limit or 10, + with_row_id=self._with_row_id, + vector=[], + # not actually respected in remote query + offset=self._offset or 0, ) + return self._table._execute_query(query) + + def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder: + """Rerank the results using the specified reranker. + + Parameters + ---------- + reranker: Reranker + The reranker to use. + + Returns + ------- + LanceEmptyQueryBuilder + The LanceQueryBuilder object. + """ + raise NotImplementedError("Reranking is not yet supported.") class LanceHybridQueryBuilder(LanceQueryBuilder): diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index c6a0078b..d6fadea0 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -22,6 +22,7 @@ from lance import json_to_schema from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.merge import LanceMergeInsertBuilder +from lancedb.query import LanceQueryBuilder from ..query import LanceVectorQueryBuilder from ..table import Query, Table, _sanitize_data @@ -228,10 +229,21 @@ class RemoteTable(Table): content_type=ARROW_STREAM_CONTENT_TYPE, ) + def query( + self, + query: Union[VEC, str] = None, + query_type: str = "vector", + vector_column_name: Optional[str] = None, + fast_search: bool = False, + ) -> LanceVectorQueryBuilder: + return self.search(query, query_type, vector_column_name, fast_search) + def search( self, - query: Union[VEC, str], + query: Union[VEC, str] = None, + query_type: str = "vector", vector_column_name: Optional[str] = None, + fast_search: bool = False, ) -> LanceVectorQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -278,6 +290,11 @@ class RemoteTable(Table): - If the table has multiple vector columns then the *vector_column_name* needs to be specified. Otherwise, an error is raised. + fast_search: bool, optional + Skip a flat search of unindexed data. This may improve + search performance but search results will not include unindexed data. + + - *default False*. Returns ------- LanceQueryBuilder @@ -293,7 +310,14 @@ class RemoteTable(Table): """ if vector_column_name is None: vector_column_name = inf_vector_column_query(self.schema) - return LanceVectorQueryBuilder(self, query, vector_column_name) + + return LanceQueryBuilder.create( + self, + query, + query_type, + vector_column_name=vector_column_name, + fast_search=fast_search, + ) def _execute_query( self, query: Query, batch_size: Optional[int] = None diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index a775d5c7..92b82b85 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -39,3 +39,12 @@ def test_remote_db(): table = conn["test"] table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) table.search([1.0, 2.0]).to_pandas() + + +def test_empty_query_with_filter(): + conn = lancedb.connect("db://client-will-be-injected", api_key="fake") + setattr(conn, "_client", FakeLanceDBClient()) + + table = conn["test"] + table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) + print(table.query().select(["vector"]).where("foo == bar").to_arrow())