From 94032544424408fde85bc2be423e69e00359d65e Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 21 Mar 2025 13:01:51 -0700 Subject: [PATCH] feat: add to_query_object method (#2239) This PR adds a `to_query_object` method to the various query builders (hybrid queries are not supported yet). This makes it possible to inspect the query that is built. In addition, this PR does some normalization between the sync and async query paths. A few custom defaults were removed in favor of None (with the default getting set once, in Rust). Also, the synchronous to_batches method will now actually stream results. Also, the remote API now defaults to prefiltering. --- python/python/lancedb/_lancedb.pyi | 30 ++ python/python/lancedb/query.py | 430 ++++++++++++++++++-------- python/python/lancedb/remote/table.py | 14 +- python/python/lancedb/table.py | 96 +++--- python/python/tests/test_query.py | 257 ++++++++++++++- python/python/tests/test_remote_db.py | 6 +- python/src/query.rs | 186 +++++++++++ rust/lancedb/src/rerankers.rs | 25 +- 8 files changed, 867 insertions(+), 177 deletions(-) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index cab8486b..38a5f14f 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -94,6 +94,7 @@ class Query: def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ... async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... + def to_query_request(self) -> PyQueryRequest: ... class FTSQuery: def where(self, filter: str): ... @@ -108,6 +109,7 @@ class FTSQuery: def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... async def explain_plan(self) -> str: ... + def to_query_request(self) -> PyQueryRequest: ... class VectorQuery: async def execute(self) -> RecordBatchStream: ... 
@@ -123,6 +125,7 @@ class VectorQuery: def nprobes(self, nprobes: int): ... def bypass_vector_index(self): ... def nearest_to_text(self, query: dict) -> HybridQuery: ... + def to_query_request(self) -> PyQueryRequest: ... class HybridQuery: def where(self, filter: str): ... @@ -140,6 +143,33 @@ class HybridQuery: def to_fts_query(self) -> FTSQuery: ... def get_limit(self) -> int: ... def get_with_row_id(self) -> bool: ... + def to_query_request(self) -> PyQueryRequest: ... + +class PyFullTextSearchQuery: + columns: Optional[List[str]] + query: str + limit: Optional[int] + wand_factor: Optional[float] + +class PyQueryRequest: + limit: Optional[int] + offset: Optional[int] + filter: Optional[Union[str, bytes]] + full_text_search: Optional[PyFullTextSearchQuery] + select: Optional[Union[str, List[str]]] + fast_search: Optional[bool] + with_row_id: Optional[bool] + column: Optional[str] + query_vector: Optional[List[pa.Array]] + nprobes: Optional[int] + lower_bound: Optional[float] + upper_bound: Optional[float] + ef: Optional[int] + refine_factor: Optional[int] + distance_type: Optional[str] + bypass_vector_index: Optional[bool] + postfilter: Optional[bool] + norm: Optional[str] class CompactionStats: fragments_removed: int diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index fdbc8751..7c7447be 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -14,6 +14,7 @@ from typing import ( Tuple, Type, Union, + Any, ) import asyncio @@ -32,6 +33,8 @@ from .rerankers.rrf import RRFReranker from .rerankers.util import check_reranker_result from .util import flatten_columns +from typing_extensions import Annotated + if TYPE_CHECKING: import sys import PIL @@ -41,6 +44,7 @@ if TYPE_CHECKING: from ._lancedb import FTSQuery as LanceFTSQuery from ._lancedb import HybridQuery as LanceHybridQuery from ._lancedb import VectorQuery as LanceVectorQuery + from ._lancedb import PyQueryRequest from .common import VEC from 
.pydantic import LanceModel from .table import Table @@ -51,33 +55,116 @@ if TYPE_CHECKING: from typing_extensions import Self -class Query(pydantic.BaseModel): - """The LanceDB Query +# Pydantic validation function for vector queries +def ensure_vector_query( + val: Any, +) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]: + if isinstance(val, list): + if len(val) == 0: + raise ValueError("Vector query must be a non-empty list") + sample = val[0] + else: + if isinstance(val, float): + raise ValueError( + "Vector query must be a list of floats or a list of lists of floats" + ) + sample = val + if isinstance(sample, pa.Array): + # val is array or list of array + return val + if isinstance(sample, list): + if len(sample) == 0: + raise ValueError("Vector query must be a non-empty list") + if isinstance(sample[0], float): + # val is list of list of floats + return val + if isinstance(sample, float): + # val is a list of floats + return val + + +class FullTextSearchQuery(pydantic.BaseModel): + """A LanceDB Full Text Search Query Attributes ---------- - vector : List[float] - the vector to search for - filter : Optional[str] - sql filter to refine the query with, optional - prefilter : bool - if True then apply the filter before vector search - k : int - top k results to return - metric : str - the distance metric between a pair of vectors, + columns: List[str] + The columns to search - can support l2 (default), Cosine and Dot. - [metric definitions][search] - columns : Optional[List[str]] + If None, then the table should select the column automatically. 
+ query: str + The query to search for + limit: Optional[int] = None + The limit on the number of results to return + wand_factor: Optional[float] = None + The wand factor to use for the search + """ + + columns: Optional[List[str]] = None + query: str + limit: Optional[int] = None + wand_factor: Optional[float] = None + + +class Query(pydantic.BaseModel): + """A LanceDB Query + + Queries are constructed by the `Table.search` and `Table.query` methods. This + class is a python representation of the query. Normally you will not need to + interact with this class directly. You can build up a query and execute it using + collection methods such as `to_batches()`, `to_arrow()`, `to_pandas()`, etc. + + However, you can call `to_query_object()` to get the underlying query object. + This can be useful for serializing a query or using it in a different context. + + Attributes + ---------- + filter : Optional[str] + sql filter to refine the query with + limit : Optional[int] + The limit on the number of results to return. If this is a vector or FTS query, + then this is required. If this is a plain SQL query, then this is optional. + offset: Optional[int] + The offset to start fetching results from + + This is ignored for vector / FTS search (will be None). + columns : Optional[Union[List[str], Dict[str, str]]] which columns to return in the results - nprobes : int - The number of probes used - optional + + This can be a list of column names or a dictionary. If it is a dictionary, + then the keys are the column names and the values are sql expressions to + use to calculate the result. + + If this is None then all columns are returned. This can be expensive. + with_row_id : Optional[bool] + if True then include the row id in the results + vector : Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]] + the vector to search for, if this is a vector search or hybrid search. It will + be None for full text search and plain SQL filtering. 
+ vector_column : Optional[str] + the name of the vector column to use for vector search + + If this is None then a default vector column will be used. + distance_type : Optional[str] + the distance type to use for vector search + + This can be l2 (default), cosine and dot. See [metric definitions][search] for + more details. + + If this is not a vector search this will be None. + postfilter : bool + if True then apply the filter after vector / FTS search. This is ignored for + plain SQL filtering. + nprobes : Optional[int] + The number of IVF partitions to search. If this is None then a default + number of partitions will be used. - A higher number makes search more accurate but also slower. - See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. + + Will be None if this is not a vector search. refine_factor : Optional[int] Refine the results by reading extra elements and re-ranking them in memory. @@ -85,58 +172,130 @@ class Query(pydantic.BaseModel): - See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. - offset: int - The offset to start fetching results from - fast_search: bool + + Will be None if this is not a vector search. + lower_bound : Optional[float] + The lower bound for distance search + + Only results with a distance greater than or equal to this value + will be returned. + + This will only be set on vector search. + upper_bound : Optional[float] + The upper bound for distance search + + Only results with a distance less than or equal to this value + will be returned. + + This will only be set on vector search. + ef : Optional[int] + The size of the nearest neighbor list maintained during HNSW search + + This will only be set on vector search. + full_text_query : Optional[Union[str, dict]] + The full text search query + + This can be a string or a dictionary. A dictionary will be used to search + multiple columns. The keys are the column names and the values are the + search queries. 
+ + This will only be set on FTS or hybrid queries. + fast_search: Optional[bool] Skip a flat search of unindexed data. This will improve search performance but search results will not include unindexed data. - - *default False*. + The default is False """ + # The name of the vector column to use for vector search. vector_column: Optional[str] = None # vector to search for - vector: Union[List[float], List[List[float]]] + # + # Note: today this will be floats on the sync path and pa.Array on the async + # path though in the future we should unify this to pa.Array everywhere + vector: Annotated[ + Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]], + ensure_vector_query, + ] = None # sql filter to refine the query with filter: Optional[str] = None - # if True then apply the filter before vector search - prefilter: bool = False + # if True then apply the filter after vector search + postfilter: Optional[bool] = None # full text search query - full_text_query: Optional[Union[str, dict]] = None + full_text_query: Optional[FullTextSearchQuery] = None # top k results to return - k: Optional[int] = None + limit: Optional[int] = None - # # metrics - metric: str = "l2" + # distance type to use for vector search + distance_type: Optional[str] = None # which columns to return in the results columns: Optional[Union[List[str], Dict[str, str]]] = None - # optional query parameters for tuning the results, - # e.g. `{"nprobes": "10", "refine_factor": "10"}` - nprobes: int = 10 + # number of IVF partitions to search + nprobes: Optional[int] = None + # lower bound for distance search lower_bound: Optional[float] = None + + # upper bound for distance search upper_bound: Optional[float] = None - # Refine factor. 
+ # multiplier for the number of results to inspect for reranking refine_factor: Optional[int] = None - with_row_id: bool = False + # if true, include the row id in the results + with_row_id: Optional[bool] = None - offset: int = 0 + # offset to start fetching results from + offset: Optional[int] = None - fast_search: bool = False + # if true, will only search the indexed data + fast_search: Optional[bool] = None + # size of the nearest neighbor list maintained during HNSW search ef: Optional[int] = None - # Default is true. Set to false to enforce a brute force search. - use_index: bool = True + # Bypass the vector index and use a brute force search + bypass_vector_index: Optional[bool] = None + + @classmethod + def from_inner(cls, req: PyQueryRequest) -> Self: + query = cls() + query.limit = req.limit + query.offset = req.offset + query.filter = req.filter + query.full_text_query = req.full_text_search + query.columns = req.select + query.with_row_id = req.with_row_id + query.vector_column = req.column + query.vector = req.query_vector + query.distance_type = req.distance_type + query.nprobes = req.nprobes + query.lower_bound = req.lower_bound + query.upper_bound = req.upper_bound + query.ef = req.ef + query.refine_factor = req.refine_factor + query.bypass_vector_index = req.bypass_vector_index + query.postfilter = req.postfilter + if req.full_text_search is not None: + query.full_text_query = FullTextSearchQuery( + columns=req.full_text_search.columns, + query=req.full_text_search.query, + limit=req.full_text_search.limit, + wand_factor=req.full_text_search.wand_factor, + ) + return query + + class Config: + # This tells pydantic to allow custom types (needed for the `vector` query since + # pa.Array wouldn't be allowed otherwise) + arbitrary_types_allowed = True class LanceQueryBuilder(ABC): @@ -152,8 +311,8 @@ class LanceQueryBuilder(ABC): query_type: str, vector_column_name: str, ordering_field_name: Optional[str] = None, - fts_columns: Union[str, List[str]] = 
[], - fast_search: bool = False, + fts_columns: Optional[Union[str, List[str]]] = None, + fast_search: bool = None, ) -> Self: """ Create a query builder based on the given query and query type. @@ -257,15 +416,15 @@ class LanceQueryBuilder(ABC): def __init__(self, table: "Table"): self._table = table self._limit = None - self._offset = 0 + self._offset = None self._columns = None self._where = None - self._prefilter = True - self._with_row_id = False + self._postfilter = None + self._with_row_id = None self._vector = None self._text = None self._ef = None - self._use_index = True + self._bypass_vector_index = None @deprecation.deprecated( deprecated_in="0.3.1", @@ -315,7 +474,7 @@ class LanceQueryBuilder(ABC): raise NotImplementedError @abstractmethod - def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table: + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: """ Execute the query and return the results as a pyarrow [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html) @@ -451,7 +610,7 @@ class LanceQueryBuilder(ABC): The LanceQueryBuilder object. """ self._where = where - self._prefilter = prefilter + self._postfilter = not prefilter return self def with_row_id(self, with_row_id: bool) -> Self: @@ -497,23 +656,7 @@ class LanceQueryBuilder(ABC): ------- plan : str """ # noqa: E501 - ds = self._table.to_lance() - return ds.scanner( - nearest={ - "column": self._vector_column, - "q": self._query, - "k": self._limit, - "metric": self._distance_type, - "nprobes": self._nprobes, - "refine_factor": self._refine_factor, - "use_index": self._use_index, - }, - prefilter=self._prefilter, - filter=self._str_query, - limit=self._limit, - with_row_id=self._with_row_id, - offset=self._offset, - ).explain_plan(verbose) + return self._table._explain_plan(self.to_query_object()) def vector(self, vector: Union[np.ndarray, list]) -> Self: """Set the vector to search for. 
@@ -561,6 +704,17 @@ class LanceQueryBuilder(ABC): """ raise NotImplementedError + @abstractmethod + def to_query_object(self) -> Query: + """Return a serializable representation of the query + + Returns + ------- + Query + The serializable representation of the query + """ + raise NotImplementedError + class LanceVectorQueryBuilder(LanceQueryBuilder): """ @@ -590,19 +744,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str, str_query: Optional[str] = None, - fast_search: bool = False, + fast_search: bool = None, ): super().__init__(table) - if self._limit is None: - self._limit = 10 self._query = query - self._distance_type = "l2" - self._nprobes = 20 + self._distance_type = None + self._nprobes = None self._lower_bound = None self._upper_bound = None self._refine_factor = None self._vector_column = vector_column - self._prefilter = False + self._postfilter = None self._reranker = None self._str_query = str_query self._fast_search = fast_search @@ -752,6 +904,34 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): """ return self.to_batches().read_all() + def to_query_object(self) -> Query: + """ + Build a Query object + + This can be used to serialize a query + """ + vector = self._query if isinstance(self._query, list) else self._query.tolist() + if isinstance(vector[0], np.ndarray): + vector = [v.tolist() for v in vector] + return Query( + vector=vector, + filter=self._where, + postfilter=self._postfilter, + limit=self._limit, + distance_type=self._distance_type, + columns=self._columns, + nprobes=self._nprobes, + lower_bound=self._lower_bound, + upper_bound=self._upper_bound, + refine_factor=self._refine_factor, + vector_column=self._vector_column, + with_row_id=self._with_row_id, + offset=self._offset, + fast_search=self._fast_search, + ef=self._ef, + bypass_vector_index=self._bypass_vector_index, + ) + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: 
""" Execute the query and return the result as a RecordBatchReader object. @@ -768,24 +948,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): vector = self._query if isinstance(self._query, list) else self._query.tolist() if isinstance(vector[0], np.ndarray): vector = [v.tolist() for v in vector] - query = Query( - vector=vector, - filter=self._where, - prefilter=self._prefilter, - k=self._limit, - metric=self._distance_type, - columns=self._columns, - nprobes=self._nprobes, - lower_bound=self._lower_bound, - upper_bound=self._upper_bound, - refine_factor=self._refine_factor, - vector_column=self._vector_column, - with_row_id=self._with_row_id, - offset=self._offset, - fast_search=self._fast_search, - ef=self._ef, - use_index=self._use_index, - ) + query = self.to_query_object() result_set = self._table._execute_query(query, batch_size) if self._reranker is not None: rs_table = result_set.read_all() @@ -798,7 +961,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): return result_set - def where(self, where: str, prefilter: bool = True) -> LanceVectorQueryBuilder: + def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder: """Set the where clause. Parameters @@ -810,8 +973,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): prefilter: bool, default True If True, apply the filter before vector search, otherwise the filter is applied on the result of vector search. - This feature is **EXPERIMENTAL** and may be removed and modified - without warning in the future. Returns ------- @@ -819,7 +980,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): The LanceQueryBuilder object. """ self._where = where - self._prefilter = prefilter + if prefilter is not None: + self._postfilter = not prefilter return self def rerank( @@ -873,7 +1035,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceVectorQueryBuilder object. 
""" - self._use_index = False + self._bypass_vector_index = True return self @@ -885,11 +1047,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): table: "Table", query: str, ordering_field_name: Optional[str] = None, - fts_columns: Union[str, List[str]] = [], + fts_columns: Optional[Union[str, List[str]]] = None, ): super().__init__(table) - if self._limit is None: - self._limit = 10 self._query = query self._phrase_query = False self.ordering_field_name = ordering_field_name @@ -915,6 +1075,19 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): self._phrase_query = phrase_query return self + def to_query_object(self) -> Query: + return Query( + columns=self._columns, + filter=self._where, + limit=self._limit, + postfilter=self._postfilter, + with_row_id=self._with_row_id, + full_text_query=FullTextSearchQuery( + query=self._query, columns=self._fts_columns + ), + offset=self._offset, + ) + def to_arrow(self) -> pa.Table: path, fs, exist = self._table._get_fts_index_path() if exist: @@ -926,19 +1099,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): "Phrase query is not yet supported in Lance FTS. " "Use tantivy-based index instead for now." 
) - query = Query( - columns=self._columns, - filter=self._where, - k=self._limit, - prefilter=self._prefilter, - with_row_id=self._with_row_id, - full_text_query={ - "query": query, - "columns": self._fts_columns, - }, - vector=[], - offset=self._offset, - ) + query = self.to_query_object() results = self._table._execute_query(query) results = results.read_all() if self._reranker is not None: @@ -983,8 +1144,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._phrase_query: query = query.replace('"', "'") query = f'"{query}"' + limit = self._limit if self._limit is not None else 10 row_ids, scores = search_index( - index, query, self._limit, ordering_field=self.ordering_field_name + index, query, limit, ordering_field=self.ordering_field_name ) if len(row_ids) == 0: empty_schema = pa.schema([pa.field("_score", pa.float32())]) @@ -1053,17 +1215,18 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pa.Table: return self.to_batches().read_all() - def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: - query = Query( + def to_query_object(self) -> Query: + return Query( columns=self._columns, filter=self._where, - k=self._limit, + limit=self._limit, with_row_id=self._with_row_id, - vector=[], - # not actually respected in remote query - offset=self._offset or 0, + offset=self._offset, ) - return self._table._execute_query(query) + + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: + query = self.to_query_object() + return self._table._execute_query(query, batch_size) def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder: """Rerank the results using the specified reranker. 
@@ -1098,18 +1261,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): table: "Table", query: Optional[str] = None, vector_column: Optional[str] = None, - fts_columns: Union[str, List[str]] = [], + fts_columns: Optional[Union[str, List[str]]] = None, ): super().__init__(table) self._query = query self._vector_column = vector_column self._fts_columns = fts_columns - self._norm = "score" - self._reranker = RRFReranker() + self._norm = None + self._reranker = None self._nprobes = None self._refine_factor = None self._distance_type = None - self._phrase_query = False + self._phrase_query = None def _validate_query(self, query, vector=None, text=None): if query is not None and (vector is not None or text is not None): @@ -1131,7 +1294,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): return vector_query, text_query - def phrase_query(self, phrase_query: bool = True) -> LanceHybridQueryBuilder: + def phrase_query(self, phrase_query: bool = None) -> LanceHybridQueryBuilder: """Set whether to use phrase query. 
Parameters @@ -1148,6 +1311,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._phrase_query = phrase_query return self + def to_query_object(self) -> Query: + raise NotImplementedError("to_query_object not yet supported on a hybrid query") + def to_arrow(self) -> pa.Table: vector_query, fts_query = self._validate_query( self._query, self._vector, self._text @@ -1169,8 +1335,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._vector_query.select(self._columns) self._fts_query.select(self._columns) if self._where: - self._vector_query.where(self._where, self._prefilter) - self._fts_query.where(self._where, self._prefilter) + self._vector_query.where(self._where, not self._postfilter) + self._fts_query.where(self._where, not self._postfilter) if self._with_row_id: self._vector_query.with_row_id(True) self._fts_query.with_row_id(True) @@ -1184,9 +1350,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._vector_query.refine_factor(self._refine_factor) if self._ef: self._vector_query.ef(self._ef) - if not self._use_index: + if self._bypass_vector_index: self._vector_query.bypass_vector_index() + if self._reranker is None: + self._reranker = RRFReranker() + with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) vector_future = executor.submit( @@ -1502,12 +1671,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): LanceHybridQueryBuilder The LanceHybridQueryBuilder object. 
""" - self._use_index = False + self._bypass_vector_index = True return self class AsyncQueryBase(object): - def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]): + def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]): """ Construct an AsyncQueryBase @@ -1516,6 +1685,9 @@ class AsyncQueryBase(object): """ self._inner = inner + def to_query_object(self) -> Query: + return Query.from_inner(self._inner.to_query_request()) + def where(self, predicate: str) -> Self: """ Only return rows matching the given predicate @@ -1868,7 +2040,7 @@ class AsyncQuery(AsyncQueryBase): ) def nearest_to_text( - self, query: str, columns: Union[str, List[str]] = [] + self, query: str, columns: Union[str, List[str], None] = None ) -> AsyncFTSQuery: """ Find the documents that are most relevant to the given text query. @@ -1892,6 +2064,8 @@ class AsyncQuery(AsyncQueryBase): """ if isinstance(columns, str): columns = [columns] + if columns is None: + columns = [] return AsyncFTSQuery( self._inner.nearest_to_text({"query": query, "columns": columns}) ) @@ -2177,7 +2351,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): return self def nearest_to_text( - self, query: str, columns: Union[str, List[str]] = [] + self, query: str, columns: Union[str, List[str], None] = None ) -> AsyncHybridQuery: """ Find the documents that are most relevant to the given text query, @@ -2205,6 +2379,8 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): """ if isinstance(columns, str): columns = [columns] + if columns is None: + columns = [] return AsyncHybridQuery( self._inner.nearest_to_text({"query": query, "columns": columns}) ) diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 64fe5973..2a362996 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -282,7 +282,8 @@ class RemoteTable(Table): """Create a search query to find the nearest neighbors of the given query vector. 
We currently support [vector search][search] - All query options are defined in [Query][lancedb.query.Query]. + All query options are defined in + [LanceVectorQueryBuilder][lancedb.query.LanceVectorQueryBuilder]. Examples -------- @@ -353,7 +354,16 @@ class RemoteTable(Table): def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - return LOOP.run(self._table._execute_query(query, batch_size=batch_size)) + async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size)) + + def iter_sync(): + try: + while True: + yield LOOP.run(async_iter.__anext__()) + except StopAsyncIteration: + return + + return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index ac7633ab..a1ca619b 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -101,7 +101,9 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader: schema = data.features.arrow_schema return pa.RecordBatchReader.from_batches(schema, data.data.to_batches()) elif isinstance(data, datasets.dataset_dict.DatasetDict): - schema = _schema_from_hf(data, schema) + schema = _schema_from_hf(data, None) + if "split" not in schema.names: + schema = schema.append(pa.field("split", pa.string())) return pa.RecordBatchReader.from_batches( schema, _to_batches_with_split(data) ) @@ -415,7 +417,7 @@ def sanitize_create_table( return data, schema -def _schema_from_hf(data, schema): +def _schema_from_hf(data, schema) -> pa.Schema: """ Extract pyarrow schema from HuggingFace DatasetDict and validate that they're all the same schema between @@ -927,7 +929,8 @@ class Table(ABC): of the given query vector. 
We currently support [vector search][search] and [full-text search][experimental-full-text-search]. - All query options are defined in [Query][lancedb.query.Query]. + All query options are defined in + [LanceQueryBuilder][lancedb.query.LanceQueryBuilder]. Examples -------- @@ -2278,7 +2281,19 @@ class LanceTable(Table): def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - return LOOP.run(self._table._execute_query(query, batch_size)) + async_iter = LOOP.run(self._table._execute_query(query, batch_size)) + + def iter_sync(): + try: + while True: + yield LOOP.run(async_iter.__anext__()) + except StopAsyncIteration: + return + + return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) + + def _explain_plan(self, query: Query) -> str: + return LOOP.run(self._table._explain_plan(query)) def _do_merge( self, @@ -3053,7 +3068,7 @@ class AsyncTable: query_type: Literal["auto"] = ..., ordering_field_name: Optional[str] = None, fts_columns: Optional[Union[str, List[str]]] = None, - ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ... + ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ... @overload async def search( @@ -3102,7 +3117,7 @@ class AsyncTable: query_type: QueryType = "auto", ordering_field_name: Optional[str] = None, fts_columns: Optional[Union[str, List[str]]] = None, - ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: + ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] and [full-text search][experimental-full-text-search]. 
@@ -3264,12 +3279,12 @@ class AsyncTable: builder = builder.column(vector_column_name) return builder elif query_type == "fts": - return self.query().nearest_to_text(query, columns=fts_columns or []) + return self.query().nearest_to_text(query, columns=fts_columns) elif query_type == "hybrid": builder = self.query().nearest_to(vector_query) if vector_column_name: builder = builder.column(vector_column_name) - return builder.nearest_to_text(query, columns=fts_columns or []) + return builder.nearest_to_text(query, columns=fts_columns) else: raise ValueError(f"Unknown query type: '{query_type}'") @@ -3286,16 +3301,13 @@ class AsyncTable: """ return self.query().nearest_to(query_vector) - async def _execute_query( - self, query: Query, batch_size: Optional[int] = None - ) -> pa.RecordBatchReader: - # The sync remote table calls into this method, so we need to map the - # query to the async version of the query and run that here. This is only - # used for that code path right now. + def _sync_query_to_async( + self, query: Query + ) -> AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery | AsyncQuery: async_query = self.query() - if query.k is not None: - async_query = async_query.limit(query.k) - if query.offset > 0: + if query.limit is not None: + async_query = async_query.limit(query.limit) + if query.offset is not None: async_query = async_query.offset(query.offset) if query.columns: async_query = async_query.select(query.columns) @@ -3307,35 +3319,49 @@ class AsyncTable: async_query = async_query.with_row_id() if query.vector: - # we need the schema to get the vector column type - # to determine whether the vectors is batch queries or not - async_query = ( - async_query.nearest_to(query.vector) - .distance_type(query.metric) - .nprobes(query.nprobes) - .distance_range(query.lower_bound, query.upper_bound) + async_query = async_query.nearest_to(query.vector).distance_range( + query.lower_bound, query.upper_bound ) - if query.refine_factor: + if query.distance_type is 
not None: + async_query = async_query.distance_type(query.distance_type) + if query.nprobes is not None: + async_query = async_query.nprobes(query.nprobes) + if query.refine_factor is not None: async_query = async_query.refine_factor(query.refine_factor) if query.vector_column: async_query = async_query.column(query.vector_column) if query.ef: async_query = async_query.ef(query.ef) - if not query.use_index: + if query.bypass_vector_index: async_query = async_query.bypass_vector_index() - if not query.prefilter: + if query.postfilter: async_query = async_query.postfilter() - if isinstance(query.full_text_query, str): - async_query = async_query.nearest_to_text(query.full_text_query) - elif isinstance(query.full_text_query, dict): - fts_query = query.full_text_query["query"] - fts_columns = query.full_text_query.get("columns", []) or [] - async_query = async_query.nearest_to_text(fts_query, columns=fts_columns) + if query.full_text_query: + async_query = async_query.nearest_to_text( + query.full_text_query.query, query.full_text_query.columns + ) + if query.full_text_query.limit is not None: + async_query = async_query.limit(query.full_text_query.limit) - table = await async_query.to_arrow() - return table.to_reader() + return async_query + + async def _execute_query( + self, query: Query, batch_size: Optional[int] = None + ) -> pa.RecordBatchReader: + # The sync table calls into this method, so we need to map the + # query to the async version of the query and run that here. This is only + # used for that code path right now. 
+ + async_query = self._sync_query_to_async(query) + + return await async_query.to_batches(max_batch_length=batch_size) + + async def _explain_plan(self, query: Query) -> str: + # This method is used by the sync table + async_query = self._sync_query_to_async(query) + return await async_query.explain_plan() async def _do_merge( self, diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 54c6c69b..03b7dee9 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -26,10 +26,12 @@ from lancedb.query import ( AsyncVectorQuery, LanceVectorQueryBuilder, Query, + FullTextSearchQuery, ) from lancedb.rerankers.cross_encoder import CrossEncoderReranker from lancedb.table import AsyncTable, LanceTable from utils import exception_output +from importlib.util import find_spec @pytest.fixture(scope="module") @@ -392,12 +394,28 @@ def test_query_builder_batches(table): for item in rs: rs_list.append(item) assert isinstance(item, pa.RecordBatch) - assert len(rs_list) == 1 - assert len(rs_list[0]["id"]) == 2 + assert len(rs_list) == 2 + assert len(rs_list[0]["id"]) == 1 assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0]) assert rs_list[0].to_pandas()["id"][0] == 1 - assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0]) - assert rs_list[0].to_pandas()["id"][1] == 2 + assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0]) + assert rs_list[1].to_pandas()["id"][0] == 2 + + rs = ( + LanceVectorQueryBuilder(table, [0, 0], "vector") + .limit(2) + .select(["id", "vector"]) + .to_batches(2) + ) + rs_list = [] + for item in rs: + rs_list.append(item) + assert isinstance(item, pa.RecordBatch) + assert len(rs_list) == 1 + assert len(rs_list[0]["id"]) == 2 + rs_list = rs_list[0].to_pandas() + assert rs_list["id"][0] == 1 + assert rs_list["id"][1] == 2 def test_dynamic_projection(table): @@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column(): Query( vector=query, filter="b < 10", - 
prefilter=True, - k=2, - metric="cosine", + limit=2, + distance_type="cosine", columns=["b"], - nprobes=20, - refine_factor=None, vector_column="foo_vector", ), None, @@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable): @pytest.mark.asyncio @pytest.mark.slow async def test_query_reranked_async(table_async: AsyncTable): + # CrossEncoderReranker requires torch + if find_spec("torch") is None: + pytest.skip("torch not installed") + # FTS with rerank await table_async.create_index("text", config=FTS(with_position=False)) await check_query( @@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection): assert "No embedding functions are registered for any columns" in exception_output( e ) + + +# Helper method used in the following tests. Looks at the simple python object `q` and +# checks that the properties match the expected values in kwargs. +def check_set_props(q, **kwargs): + for k in dict(q): + if not k.startswith("_"): + if k in kwargs: + assert kwargs[k] == getattr(q, k), ( + f"{k} should be {kwargs[k]} but is {getattr(q, k)}" + ) + else: + assert getattr(q, k) is None, f"{k} should be None" + + +def test_query_serialization_sync(table: lancedb.table.Table): + # Simple queries + q = table.search().where("id = 1").limit(500).offset(10).to_query_object() + check_set_props(q, limit=500, offset=10, filter="id = 1") + + q = table.search().select(["id", "vector"]).to_query_object() + check_set_props(q, columns=["id", "vector"]) + + q = table.search().with_row_id(True).to_query_object() + check_set_props(q, with_row_id=True) + + # Vector queries + q = table.search([5.0, 6.0]).limit(10).to_query_object() + check_set_props(q, limit=10, vector_column="vector", vector=[5.0, 6.0]) + + q = table.search([5.0, 6.0]).to_query_object() + check_set_props(q, vector_column="vector", vector=[5.0, 6.0]) + + q = ( + table.search([5.0, 6.0]) + .limit(10) + .where("id = 1", prefilter=False) + .to_query_object() + ) + check_set_props( + 
q, + limit=10, + vector_column="vector", + filter="id = 1", + postfilter=True, + vector=[5.0, 6.0], + ) + + q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object() + check_set_props( + q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5 + ) + + q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object() + check_set_props( + q, vector_column="vector", vector=[5.0, 6.0], lower_bound=0.0, upper_bound=1.0 + ) + + q = table.search([5.0, 6.0]).distance_type("cosine").to_query_object() + check_set_props( + q, distance_type="cosine", vector_column="vector", vector=[5.0, 6.0] + ) + + q = table.search([5.0, 6.0]).ef(7).to_query_object() + check_set_props(q, ef=7, vector_column="vector", vector=[5.0, 6.0]) + + q = table.search([5.0, 6.0]).bypass_vector_index().to_query_object() + check_set_props( + q, bypass_vector_index=True, vector_column="vector", vector=[5.0, 6.0] + ) + + # FTS queries + q = table.search("foo").limit(10).to_query_object() + check_set_props( + q, limit=10, full_text_query=FullTextSearchQuery(columns=[], query="foo") + ) + + q = table.search("foo", query_type="fts").to_query_object() + check_set_props(q, full_text_query=FullTextSearchQuery(columns=[], query="foo")) + + +@pytest.mark.asyncio +async def test_query_serialization_async(table_async: AsyncTable): + # Simple queries + q = table_async.query().where("id = 1").limit(500).offset(10).to_query_object() + check_set_props(q, limit=500, offset=10, filter="id = 1", with_row_id=False) + + q = table_async.query().select(["id", "vector"]).to_query_object() + check_set_props(q, columns=["id", "vector"], with_row_id=False) + + q = table_async.query().with_row_id().to_query_object() + check_set_props(q, with_row_id=True) + + sample_vector = [pa.array([5.0, 6.0], type=pa.float32())] + + # Vector queries + q = (await table_async.search([5.0, 6.0])).limit(10).to_query_object() + check_set_props( + q, + limit=10, + vector=sample_vector, + postfilter=False, + 
nprobes=20, + with_row_id=False, + bypass_vector_index=False, + ) + + q = (await table_async.search([5.0, 6.0])).to_query_object() + check_set_props( + q, + vector=sample_vector, + postfilter=False, + nprobes=20, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = ( + (await table_async.search([5.0, 6.0])) + .limit(10) + .where("id = 1") + .postfilter() + .to_query_object() + ) + check_set_props( + q, + limit=10, + filter="id = 1", + postfilter=True, + vector=sample_vector, + nprobes=20, + with_row_id=False, + bypass_vector_index=False, + ) + + q = ( + (await table_async.search([5.0, 6.0])) + .nprobes(10) + .refine_factor(5) + .to_query_object() + ) + check_set_props( + q, + vector=sample_vector, + nprobes=10, + refine_factor=5, + postfilter=False, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = ( + (await table_async.search([5.0, 6.0])) + .distance_range(0.0, 1.0) + .to_query_object() + ) + check_set_props( + q, + vector=sample_vector, + lower_bound=0.0, + upper_bound=1.0, + postfilter=False, + nprobes=20, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = (await table_async.search([5.0, 6.0])).distance_type("cosine").to_query_object() + check_set_props( + q, + distance_type="cosine", + vector=sample_vector, + postfilter=False, + nprobes=20, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = (await table_async.search([5.0, 6.0])).ef(7).to_query_object() + check_set_props( + q, + ef=7, + vector=sample_vector, + postfilter=False, + nprobes=20, + with_row_id=False, + bypass_vector_index=False, + limit=10, + ) + + q = (await table_async.search([5.0, 6.0])).bypass_vector_index().to_query_object() + check_set_props( + q, + bypass_vector_index=True, + vector=sample_vector, + postfilter=False, + nprobes=20, + with_row_id=False, + limit=10, + ) + + # FTS queries + q = (await table_async.search("foo")).limit(10).to_query_object() + check_set_props( + q, + limit=10, + 
full_text_query=FullTextSearchQuery(columns=[], query="foo"), + with_row_id=False, + ) + + q = (await table_async.search("foo", query_type="fts")).to_query_object() + check_set_props( + q, + full_text_query=FullTextSearchQuery(columns=[], query="foo"), + with_row_id=False, + ) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index b46d6880..642e2443 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -315,7 +315,7 @@ def test_query_sync_minimal(): assert body == { "distance_type": "l2", "k": 10, - "prefilter": False, + "prefilter": True, "refine_factor": None, "lower_bound": None, "upper_bound": None, @@ -340,7 +340,7 @@ def test_query_sync_empty_query(): "filter": "true", "vector": [], "columns": ["id"], - "prefilter": False, + "prefilter": True, "version": None, } @@ -478,7 +478,7 @@ def test_query_sync_hybrid(): assert body == { "distance_type": "l2", "k": 42, - "prefilter": False, + "prefilter": True, "refine_factor": None, "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "nprobes": 20, diff --git a/python/src/query.rs b/python/src/query.rs index f3d1511c..7b941356 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -1,19 +1,28 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +use std::sync::Arc; + use arrow::array::make_array; +use arrow::array::Array; use arrow::array::ArrayData; use arrow::pyarrow::FromPyArrow; +use arrow::pyarrow::IntoPyArrow; use lancedb::index::scalar::FullTextSearchQuery; use lancedb::query::QueryExecutionOptions; +use lancedb::query::QueryFilter; use lancedb::query::{ ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, }; +use lancedb::table::AnyQuery; +use pyo3::exceptions::PyNotImplementedError; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::{PyAnyMethods, PyDictMethods}; use pyo3::pymethods; use pyo3::types::PyDict; +use 
pyo3::types::PyList; use pyo3::Bound; +use pyo3::IntoPyObject; use pyo3::PyAny; use pyo3::PyRef; use pyo3::PyResult; @@ -24,6 +33,156 @@ use crate::arrow::RecordBatchStream; use crate::error::PythonErrorExt; use crate::util::parse_distance_type; +// Python representation of full text search parameters +#[derive(Clone)] +#[pyclass(get_all)] +pub struct PyFullTextSearchQuery { + pub columns: Vec, + pub query: String, + pub limit: Option, + pub wand_factor: Option, +} + +impl From for PyFullTextSearchQuery { + fn from(query: FullTextSearchQuery) -> Self { + PyFullTextSearchQuery { + columns: query.columns, + query: query.query, + limit: query.limit, + wand_factor: query.wand_factor, + } + } +} + +// Python representation of query vector(s) +#[derive(Clone)] +pub struct PyQueryVectors(Vec>); + +impl<'py> IntoPyObject<'py> for PyQueryVectors { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult { + let py_objs = self + .0 + .into_iter() + .map(|v| v.to_data().into_pyarrow(py)) + .collect::, _>>()?; + PyList::new(py, py_objs) + } +} + +// Python representation of a query +#[pyclass(get_all)] +pub struct PyQueryRequest { + pub limit: Option, + pub offset: Option, + pub filter: Option, + pub full_text_search: Option, + pub select: PySelect, + pub fast_search: Option, + pub with_row_id: Option, + pub column: Option, + pub query_vector: Option, + pub nprobes: Option, + pub lower_bound: Option, + pub upper_bound: Option, + pub ef: Option, + pub refine_factor: Option, + pub distance_type: Option, + pub bypass_vector_index: Option, + pub postfilter: Option, + pub norm: Option, +} + +impl From for PyQueryRequest { + fn from(query: AnyQuery) -> Self { + match query { + AnyQuery::Query(query_request) => PyQueryRequest { + limit: query_request.limit, + offset: query_request.offset, + filter: query_request.filter.map(PyQueryFilter), + full_text_search: query_request + .full_text_search 
+ .map(PyFullTextSearchQuery::from), + select: PySelect(query_request.select), + fast_search: Some(query_request.fast_search), + with_row_id: Some(query_request.with_row_id), + column: None, + query_vector: None, + nprobes: None, + lower_bound: None, + upper_bound: None, + ef: None, + refine_factor: None, + distance_type: None, + bypass_vector_index: None, + postfilter: None, + norm: None, + }, + AnyQuery::VectorQuery(vector_query) => PyQueryRequest { + limit: vector_query.base.limit, + offset: vector_query.base.offset, + filter: vector_query.base.filter.map(PyQueryFilter), + full_text_search: None, + select: PySelect(vector_query.base.select), + fast_search: Some(vector_query.base.fast_search), + with_row_id: Some(vector_query.base.with_row_id), + column: vector_query.column, + query_vector: Some(PyQueryVectors(vector_query.query_vector)), + nprobes: Some(vector_query.nprobes), + lower_bound: vector_query.lower_bound, + upper_bound: vector_query.upper_bound, + ef: vector_query.ef, + refine_factor: vector_query.refine_factor, + distance_type: vector_query.distance_type.map(|d| d.to_string()), + bypass_vector_index: Some(!vector_query.use_index), + postfilter: Some(!vector_query.base.prefilter), + norm: vector_query.base.norm.map(|n| n.to_string()), + }, + } + } +} + +// Python representation of query selection +#[derive(Clone)] +pub struct PySelect(Select); + +impl<'py> IntoPyObject<'py> for PySelect { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult { + match self.0 { + Select::All => Ok(py.None().into_bound(py).into_any()), + Select::Columns(columns) => Ok(columns.into_pyobject(py)?.into_any()), + Select::Dynamic(columns) => Ok(columns.into_pyobject(py)?.into_any()), + } + } +} + +// Python representation of query filter +#[derive(Clone)] +pub struct PyQueryFilter(QueryFilter); + +impl<'py> IntoPyObject<'py> for PyQueryFilter { + type Target = PyAny; + type Output 
= Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult { + match self.0 { + QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err( + "Datafusion filter has no conversion to Python", + )), + QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()), + QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()), + } + } +} + #[pyclass] pub struct Query { inner: LanceDbQuery, @@ -121,6 +280,10 @@ impl Query { .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } + + pub fn to_query_request(&self) -> PyQueryRequest { + PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request())) + } } #[pyclass] @@ -205,6 +368,12 @@ impl FTSQuery { pub fn get_query(&self) -> String { self.fts_query.query.clone() } + + pub fn to_query_request(&self) -> PyQueryRequest { + let mut req = self.inner.clone().into_request(); + req.full_text_search = Some(self.fts_query.clone()); + PyQueryRequest::from(AnyQuery::Query(req)) + } } #[pyclass] @@ -319,6 +488,10 @@ impl VectorQuery { inner_fts: fts_query, }) } + + pub fn to_query_request(&self) -> PyQueryRequest { + PyQueryRequest::from(AnyQuery::VectorQuery(self.inner.clone().into_request())) + } } #[pyclass] @@ -421,4 +594,17 @@ impl HybridQuery { pub fn get_with_row_id(&mut self) -> bool { self.inner_fts.inner.current_request().with_row_id } + + pub fn to_query_request(&self) -> PyQueryRequest { + let mut req = self.inner_fts.to_query_request(); + let vec_req = self.inner_vec.to_query_request(); + req.query_vector = vec_req.query_vector; + req.column = vec_req.column; + req.distance_type = vec_req.distance_type; + req.ef = vec_req.ef; + req.refine_factor = vec_req.refine_factor; + req.lower_bound = vec_req.lower_bound; + req.upper_bound = vec_req.upper_bound; + req + } } diff --git a/rust/lancedb/src/rerankers.rs b/rust/lancedb/src/rerankers.rs index 338230c0..53f54abc 100644 --- a/rust/lancedb/src/rerankers.rs +++ 
b/rust/lancedb/src/rerankers.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::collections::BTreeSet; +use std::{collections::BTreeSet, str::FromStr}; use arrow::{ array::downcast_array, @@ -24,6 +24,29 @@ pub enum NormalizeMethod { Rank, } +impl FromStr for NormalizeMethod { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "score" => Ok(NormalizeMethod::Score), + "rank" => Ok(NormalizeMethod::Rank), + _ => Err(Error::InvalidInput { + message: format!("invalid normalize method: {}", s), + }), + } + } +} + +impl std::fmt::Display for NormalizeMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NormalizeMethod::Score => write!(f, "score"), + NormalizeMethod::Rank => write!(f, "rank"), + } + } +} + /// Interface for a reranker. A reranker is used to rerank the results from a /// vector and FTS search. This is useful for combining the results from both /// search methods.