diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index cab8486b..38a5f14f 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -94,6 +94,7 @@ class Query:
     def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
     def nearest_to_text(self, query: dict) -> FTSQuery: ...
    async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
+    def to_query_request(self) -> PyQueryRequest: ...

 class FTSQuery:
     def where(self, filter: str): ...
@@ -108,6 +109,7 @@ class FTSQuery:
     def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
     async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
     async def explain_plan(self) -> str: ...
+    def to_query_request(self) -> PyQueryRequest: ...

 class VectorQuery:
     async def execute(self) -> RecordBatchStream: ...
@@ -123,6 +125,7 @@ class VectorQuery:
     def nprobes(self, nprobes: int): ...
     def bypass_vector_index(self): ...
     def nearest_to_text(self, query: dict) -> HybridQuery: ...
+    def to_query_request(self) -> PyQueryRequest: ...

 class HybridQuery:
     def where(self, filter: str): ...
@@ -140,6 +143,33 @@ class HybridQuery:
     def to_fts_query(self) -> FTSQuery: ...
     def get_limit(self) -> int: ...
     def get_with_row_id(self) -> bool: ...
+    def to_query_request(self) -> PyQueryRequest: ...
+
+class PyFullTextSearchQuery:
+    columns: Optional[List[str]]
+    query: str
+    limit: Optional[int]
+    wand_factor: Optional[float]
+
+class PyQueryRequest:
+    limit: Optional[int]
+    offset: Optional[int]
+    filter: Optional[Union[str, bytes]]
+    full_text_search: Optional[PyFullTextSearchQuery]
+    select: Optional[Union[str, List[str]]]
+    fast_search: Optional[bool]
+    with_row_id: Optional[bool]
+    column: Optional[str]
+    query_vector: Optional[List[pa.Array]]
+    nprobes: Optional[int]
+    lower_bound: Optional[float]
+    upper_bound: Optional[float]
+    ef: Optional[int]
+    refine_factor: Optional[int]
+    distance_type: Optional[str]
+    bypass_vector_index: Optional[bool]
+    postfilter: Optional[bool]
+    norm: Optional[str]

 class CompactionStats:
     fragments_removed: int
diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index fdbc8751..7c7447be 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -14,6 +14,7 @@ from typing import (
     Tuple,
     Type,
     Union,
+    Any,
 )

 import asyncio
@@ -32,6 +33,8 @@ from .rerankers.rrf import RRFReranker
 from .rerankers.util import check_reranker_result
 from .util import flatten_columns

+from typing_extensions import Annotated
+
 if TYPE_CHECKING:
     import sys
     import PIL
@@ -41,6 +44,7 @@ if TYPE_CHECKING:
     from ._lancedb import FTSQuery as LanceFTSQuery
     from ._lancedb import HybridQuery as LanceHybridQuery
     from ._lancedb import VectorQuery as LanceVectorQuery
+    from ._lancedb import PyQueryRequest
     from .common import VEC
     from .pydantic import LanceModel
     from .table import Table
@@ -51,33 +55,116 @@ if TYPE_CHECKING:

     from typing_extensions import Self


-class Query(pydantic.BaseModel):
-    """The LanceDB Query
+# Pydantic validation function for vector queries
+def ensure_vector_query(
+    val: Any,
+) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
+    if isinstance(val, list):
+        if len(val) == 0:
+            raise ValueError("Vector query must be a non-empty list")
+        sample = val[0]
+    else:
+        if isinstance(val, (int, float)):
+            raise ValueError(
+                "Vector query must be a list of floats or a list of lists of floats"
+            )
+        sample = val
+    if isinstance(sample, pa.Array):
+        # val is an array or a list of arrays
+        return val
+    if isinstance(sample, list):
+        if len(sample) == 0:
+            raise ValueError("Vector query must be a non-empty list")
+        if isinstance(sample[0], (int, float)):
+            # val is a list of lists of floats
+            return val
+    if isinstance(sample, (int, float)):
+        # val is a list of floats
+        return val
+    raise ValueError(f"Unrecognized vector query type: {type(val)}")
+
+
+class FullTextSearchQuery(pydantic.BaseModel):
+    """A LanceDB Full Text Search Query

     Attributes
     ----------
-    vector : List[float]
-        the vector to search for
-    filter : Optional[str]
-        sql filter to refine the query with, optional
-    prefilter : bool
-        if True then apply the filter before vector search
-    k : int
-        top k results to return
-    metric : str
-        the distance metric between a pair of vectors,
+    columns: Optional[List[str]]
+        The columns to search

-        can support l2 (default), Cosine and Dot.
-        [metric definitions][search]
-    columns : Optional[List[str]]
+        If None, then the table should select the column automatically.
+    query: str
+        The query to search for
+    limit: Optional[int] = None
+        The limit on the number of results to return
+    wand_factor: Optional[float] = None
+        The wand factor to use for the search
+    """
+
+    columns: Optional[List[str]] = None
+    query: str
+    limit: Optional[int] = None
+    wand_factor: Optional[float] = None
+
+
+class Query(pydantic.BaseModel):
+    """A LanceDB Query
+
+    Queries are constructed by the `Table.search` and `Table.query` methods. This
+    class is a python representation of the query. Normally you will not need to
+    interact with this class directly. You can build up a query and execute it using
+    collection methods such as `to_batches()`, `to_arrow()`, `to_pandas()`, etc.
+
+    However, you can use the `to_query_object()` method to get the underlying query
+    object. This can be useful for serializing a query or using it in a different
+    context.
+
+    Attributes
+    ----------
+    filter : Optional[str]
+        sql filter to refine the query with
+    limit : Optional[int]
+        The limit on the number of results to return. If this is a vector or FTS
+        query, then this is required. If this is a plain SQL query, then this is
+        optional.
+    offset: Optional[int]
+        The offset to start fetching results from
+
+        This is ignored for vector / FTS search (will be None).
+    columns : Optional[Union[List[str], Dict[str, str]]]
+        which columns to return in the results
+
+        This can be a list of column names or a dictionary. If it is a dictionary,
+        then the keys are the column names and the values are sql expressions to
+        use to calculate the result.
+
+        If this is None then all columns are returned. This can be expensive.
+    with_row_id : Optional[bool]
+        if True then include the row id in the results
+    vector : Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]]
+        the vector to search for, if this is a vector search or hybrid search. It
+        will be None for full text search and plain SQL filtering.
+    vector_column : Optional[str]
+        the name of the vector column to use for vector search
+
+        If this is None then a default vector column will be used.
+    distance_type : Optional[str]
+        the distance type to use for vector search
+
+        This can be l2 (default), cosine, or dot. See [metric definitions][search]
+        for more details.
+
+        If this is not a vector search this will be None.
+    postfilter : Optional[bool]
+        if True then apply the filter after vector / FTS search. This is ignored
+        for plain SQL filtering.
+    nprobes : Optional[int]
+        The number of IVF partitions to search. If this is None then a default
+        number of partitions will be used.
-        A higher number makes search more accurate but also slower.
-        See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice.
+
+        Will be None if this is not a vector search.
     refine_factor : Optional[int]
         Refine the results by reading extra elements and re-ranking them in memory.
@@ -85,58 +172,130 @@ class Query(pydantic.BaseModel):
-        See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice.
-    offset: int
-        The offset to start fetching results from
-    fast_search: bool
+
+        Will be None if this is not a vector search.
+    lower_bound : Optional[float]
+        The lower bound for distance search
+
+        Only results with a distance greater than or equal to this value
+        will be returned.
+
+        This will only be set on vector search.
+    upper_bound : Optional[float]
+        The upper bound for distance search
+
+        Only results with a distance less than or equal to this value
+        will be returned.
+
+        This will only be set on vector search.
+    ef : Optional[int]
+        The size of the nearest neighbor list maintained during HNSW search
+
+        This will only be set on vector search.
+    full_text_query : Optional[FullTextSearchQuery]
+        The full text search query
+
+        This holds the query string, the columns to search (if any), and any
+        other full text search parameters.
+
+        This will only be set on FTS or hybrid queries.
+    fast_search: Optional[bool]
         Skip a flat search of unindexed data. This will improve search performance
         but search results will not include unindexed data.
-
-        *default False*.
+        The default is False.
     """

+    # The name of the vector column to use for vector search.
     vector_column: Optional[str] = None

     # vector to search for
-    vector: Union[List[float], List[List[float]]]
+    #
+    # Note: today this will be floats on the sync path and pa.Array on the async
+    # path though in the future we should unify this to pa.Array everywhere
+    vector: Annotated[
+        Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]],
+        ensure_vector_query,
+    ] = None

     # sql filter to refine the query with
     filter: Optional[str] = None

-    # if True then apply the filter before vector search
-    prefilter: bool = False
+    # if True then apply the filter after vector search
+    postfilter: Optional[bool] = None

     # full text search query
-    full_text_query: Optional[Union[str, dict]] = None
+    full_text_query: Optional[FullTextSearchQuery] = None

     # top k results to return
-    k: Optional[int] = None
+    limit: Optional[int] = None

-    # # metrics
-    metric: str = "l2"
+    # distance type to use for vector search
+    distance_type: Optional[str] = None

     # which columns to return in the results
     columns: Optional[Union[List[str], Dict[str, str]]] = None

-    # optional query parameters for tuning the results,
-    # e.g. `{"nprobes": "10", "refine_factor": "10"}`
-    nprobes: int = 10
+    # number of IVF partitions to search
+    nprobes: Optional[int] = None

+    # lower bound for distance search
     lower_bound: Optional[float] = None

+    # upper bound for distance search
     upper_bound: Optional[float] = None

-    # Refine factor.
+    # multiplier for the number of results to inspect for reranking
     refine_factor: Optional[int] = None

-    with_row_id: bool = False
+    # if true, include the row id in the results
+    with_row_id: Optional[bool] = None

-    offset: int = 0
+    # offset to start fetching results from
+    offset: Optional[int] = None

-    fast_search: bool = False
+    # if true, will only search the indexed data
+    fast_search: Optional[bool] = None

+    # size of the nearest neighbor list maintained during HNSW search
     ef: Optional[int] = None

-    # Default is true. Set to false to enforce a brute force search.
-    use_index: bool = True
+    # Bypass the vector index and use a brute force search
+    bypass_vector_index: Optional[bool] = None
+
+    @classmethod
+    def from_inner(cls, req: PyQueryRequest) -> Self:
+        query = cls()
+        query.limit = req.limit
+        query.offset = req.offset
+        query.filter = req.filter
+        query.columns = req.select
+        query.with_row_id = req.with_row_id
+        query.vector_column = req.column
+        query.vector = req.query_vector
+        query.distance_type = req.distance_type
+        query.nprobes = req.nprobes
+        query.lower_bound = req.lower_bound
+        query.upper_bound = req.upper_bound
+        query.ef = req.ef
+        query.refine_factor = req.refine_factor
+        query.bypass_vector_index = req.bypass_vector_index
+        query.postfilter = req.postfilter
+        if req.full_text_search is not None:
+            query.full_text_query = FullTextSearchQuery(
+                columns=req.full_text_search.columns,
+                query=req.full_text_search.query,
+                limit=req.full_text_search.limit,
+                wand_factor=req.full_text_search.wand_factor,
+            )
+        return query
+
+    class Config:
+        # This tells pydantic to allow custom types (needed for the `vector` query
+        # since pa.Array wouldn't be allowed otherwise)
+        arbitrary_types_allowed = True


 class LanceQueryBuilder(ABC):
@@ -152,8 +311,8 @@ class LanceQueryBuilder(ABC):
         query_type: str,
         vector_column_name: str,
         ordering_field_name: Optional[str] = None,
-        fts_columns: Union[str, List[str]] = [],
-        fast_search: bool = False,
+        fts_columns: Optional[Union[str, List[str]]] = None,
+        fast_search: Optional[bool] = None,
     ) -> Self:
         """
         Create a query builder based on the given query and query type.
@@ -257,15 +416,15 @@ class LanceQueryBuilder(ABC):
     def __init__(self, table: "Table"):
         self._table = table
         self._limit = None
-        self._offset = 0
+        self._offset = None
         self._columns = None
         self._where = None
-        self._prefilter = True
-        self._with_row_id = False
+        self._postfilter = None
+        self._with_row_id = None
         self._vector = None
         self._text = None
         self._ef = None
-        self._use_index = True
+        self._bypass_vector_index = None

     @deprecation.deprecated(
         deprecated_in="0.3.1",
@@ -315,7 +474,7 @@ class LanceQueryBuilder(ABC):
         raise NotImplementedError

     @abstractmethod
-    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table:
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
         """
         Execute the query and return the results as a pyarrow
         [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
@@ -451,7 +610,7 @@ class LanceQueryBuilder(ABC):
             The LanceQueryBuilder object.
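+
+        Examples
+        --------
+        A minimal sketch, assuming ``tbl`` is an open table with an ``id``
+        column:
+
+        >>> query = tbl.search([0.5, 1.3]).where("id > 10")  # doctest: +SKIP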
""" self._where = where - self._prefilter = prefilter + self._postfilter = not prefilter return self def with_row_id(self, with_row_id: bool) -> Self: @@ -497,23 +656,7 @@ class LanceQueryBuilder(ABC): ------- plan : str """ # noqa: E501 - ds = self._table.to_lance() - return ds.scanner( - nearest={ - "column": self._vector_column, - "q": self._query, - "k": self._limit, - "metric": self._distance_type, - "nprobes": self._nprobes, - "refine_factor": self._refine_factor, - "use_index": self._use_index, - }, - prefilter=self._prefilter, - filter=self._str_query, - limit=self._limit, - with_row_id=self._with_row_id, - offset=self._offset, - ).explain_plan(verbose) + return self._table._explain_plan(self.to_query_object()) def vector(self, vector: Union[np.ndarray, list]) -> Self: """Set the vector to search for. @@ -561,6 +704,17 @@ class LanceQueryBuilder(ABC): """ raise NotImplementedError + @abstractmethod + def to_query_object(self) -> Query: + """Return a serializable representation of the query + + Returns + ------- + Query + The serializable representation of the query + """ + raise NotImplementedError + class LanceVectorQueryBuilder(LanceQueryBuilder): """ @@ -590,19 +744,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str, str_query: Optional[str] = None, - fast_search: bool = False, + fast_search: bool = None, ): super().__init__(table) - if self._limit is None: - self._limit = 10 self._query = query - self._distance_type = "l2" - self._nprobes = 20 + self._distance_type = None + self._nprobes = None self._lower_bound = None self._upper_bound = None self._refine_factor = None self._vector_column = vector_column - self._prefilter = False + self._postfilter = None self._reranker = None self._str_query = str_query self._fast_search = fast_search @@ -752,6 +904,34 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): """ return self.to_batches().read_all() + def to_query_object(self) -> Query: + """ + Build a Query object + + This can be used to serialize a query + """ + vector = self._query if isinstance(self._query, list) else self._query.tolist() + if isinstance(vector[0], np.ndarray): + vector = [v.tolist() for v in vector] + return Query( + vector=vector, + filter=self._where, + postfilter=self._postfilter, + limit=self._limit, + distance_type=self._distance_type, + columns=self._columns, + nprobes=self._nprobes, + lower_bound=self._lower_bound, + upper_bound=self._upper_bound, + refine_factor=self._refine_factor, + vector_column=self._vector_column, + with_row_id=self._with_row_id, + offset=self._offset, + fast_search=self._fast_search, + ef=self._ef, + bypass_vector_index=self._bypass_vector_index, + ) + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: """ Execute the query and return the result as a RecordBatchReader object. 
@@ -768,24 +948,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
         if isinstance(vector[0], np.ndarray):
             vector = [v.tolist() for v in vector]
-        query = Query(
-            vector=vector,
-            filter=self._where,
-            prefilter=self._prefilter,
-            k=self._limit,
-            metric=self._distance_type,
-            columns=self._columns,
-            nprobes=self._nprobes,
-            lower_bound=self._lower_bound,
-            upper_bound=self._upper_bound,
-            refine_factor=self._refine_factor,
-            vector_column=self._vector_column,
-            with_row_id=self._with_row_id,
-            offset=self._offset,
-            fast_search=self._fast_search,
-            ef=self._ef,
-            use_index=self._use_index,
-        )
+        query = self.to_query_object()
         result_set = self._table._execute_query(query, batch_size)
         if self._reranker is not None:
             rs_table = result_set.read_all()
@@ -798,7 +961,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):

         return result_set

-    def where(self, where: str, prefilter: bool = True) -> LanceVectorQueryBuilder:
+    def where(self, where: str, prefilter: Optional[bool] = None) -> LanceVectorQueryBuilder:
         """Set the where clause.

         Parameters
@@ -810,8 +973,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
-        prefilter: bool, default True
+        prefilter: Optional[bool], default None
             If True, apply the filter before vector search, otherwise the
             filter is applied on the result of vector search.
-            This feature is **EXPERIMENTAL** and may be removed and modified
-            without warning in the future.

         Returns
         -------
@@ -819,7 +980,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             The LanceQueryBuilder object.
         """
         self._where = where
-        self._prefilter = prefilter
+        if prefilter is not None:
+            self._postfilter = not prefilter
         return self

     def rerank(
@@ -873,7 +1035,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         LanceVectorQueryBuilder
             The LanceVectorQueryBuilder object.
         """
-        self._use_index = False
+        self._bypass_vector_index = True
         return self


 class LanceFtsQueryBuilder(LanceQueryBuilder):
@@ -885,11 +1047,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         table: "Table",
         query: str,
         ordering_field_name: Optional[str] = None,
-        fts_columns: Union[str, List[str]] = [],
+        fts_columns: Optional[Union[str, List[str]]] = None,
     ):
         super().__init__(table)
-        if self._limit is None:
-            self._limit = 10
         self._query = query
         self._phrase_query = False
         self.ordering_field_name = ordering_field_name
@@ -915,6 +1075,19 @@
         self._phrase_query = phrase_query
         return self

+    def to_query_object(self) -> Query:
+        return Query(
+            columns=self._columns,
+            filter=self._where,
+            limit=self._limit,
+            postfilter=self._postfilter,
+            with_row_id=self._with_row_id,
+            full_text_query=FullTextSearchQuery(
+                query=self._query, columns=self._fts_columns
+            ),
+            offset=self._offset,
+        )
+
     def to_arrow(self) -> pa.Table:
         path, fs, exist = self._table._get_fts_index_path()
         if exist:
@@ -926,19 +1099,7 @@
                 "Phrase query is not yet supported in Lance FTS. "
                 "Use tantivy-based index instead for now."
             )
-        query = Query(
-            columns=self._columns,
-            filter=self._where,
-            k=self._limit,
-            prefilter=self._prefilter,
-            with_row_id=self._with_row_id,
-            full_text_query={
-                "query": query,
-                "columns": self._fts_columns,
-            },
-            vector=[],
-            offset=self._offset,
-        )
+        query = self.to_query_object()
         results = self._table._execute_query(query)
         results = results.read_all()
         if self._reranker is not None:
@@ -983,8 +1144,9 @@
         if self._phrase_query:
             query = query.replace('"', "'")
             query = f'"{query}"'
+        limit = self._limit if self._limit is not None else 10
         row_ids, scores = search_index(
-            index, query, self._limit, ordering_field=self.ordering_field_name
+            index, query, limit, ordering_field=self.ordering_field_name
         )
         if len(row_ids) == 0:
             empty_schema = pa.schema([pa.field("_score", pa.float32())])
@@ -1053,17 +1215,18 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
     def to_arrow(self) -> pa.Table:
         return self.to_batches().read_all()

-    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
-        query = Query(
+    def to_query_object(self) -> Query:
+        return Query(
             columns=self._columns,
             filter=self._where,
-            k=self._limit,
+            limit=self._limit,
             with_row_id=self._with_row_id,
-            vector=[],
-            # not actually respected in remote query
-            offset=self._offset or 0,
+            offset=self._offset,
         )
-        return self._table._execute_query(query)
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        query = self.to_query_object()
+        return self._table._execute_query(query, batch_size)

     def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
         """Rerank the results using the specified reranker.
@@ -1098,18 +1261,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         table: "Table",
         query: Optional[str] = None,
         vector_column: Optional[str] = None,
-        fts_columns: Union[str, List[str]] = [],
+        fts_columns: Optional[Union[str, List[str]]] = None,
     ):
         super().__init__(table)
         self._query = query
         self._vector_column = vector_column
         self._fts_columns = fts_columns
-        self._norm = "score"
-        self._reranker = RRFReranker()
+        self._norm = None
+        self._reranker = None
         self._nprobes = None
         self._refine_factor = None
         self._distance_type = None
-        self._phrase_query = False
+        self._phrase_query = None

     def _validate_query(self, query, vector=None, text=None):
         if query is not None and (vector is not None or text is not None):
@@ -1131,7 +1294,7 @@

         return vector_query, text_query

-    def phrase_query(self, phrase_query: bool = True) -> LanceHybridQueryBuilder:
+    def phrase_query(self, phrase_query: Optional[bool] = None) -> LanceHybridQueryBuilder:
         """Set whether to use phrase query.
         Parameters
@@ -1148,6 +1311,9 @@
         self._phrase_query = phrase_query
         return self

+    def to_query_object(self) -> Query:
+        raise NotImplementedError("to_query_object not yet supported on a hybrid query")
+
     def to_arrow(self) -> pa.Table:
         vector_query, fts_query = self._validate_query(
             self._query, self._vector, self._text
@@ -1169,8 +1335,8 @@
             self._vector_query.select(self._columns)
             self._fts_query.select(self._columns)
         if self._where:
-            self._vector_query.where(self._where, self._prefilter)
-            self._fts_query.where(self._where, self._prefilter)
+            prefilter = True if self._postfilter is None else not self._postfilter
+            self._vector_query.where(self._where, prefilter)
+            self._fts_query.where(self._where, prefilter)
         if self._with_row_id:
             self._vector_query.with_row_id(True)
             self._fts_query.with_row_id(True)
@@ -1184,9 +1350,12 @@
             self._vector_query.refine_factor(self._refine_factor)
         if self._ef:
             self._vector_query.ef(self._ef)
-        if not self._use_index:
+        if self._bypass_vector_index:
             self._vector_query.bypass_vector_index()

+        if self._reranker is None:
+            self._reranker = RRFReranker()
+
         with ThreadPoolExecutor() as executor:
             fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
             vector_future = executor.submit(
@@ -1502,12 +1671,12 @@
         LanceHybridQueryBuilder
             The LanceHybridQueryBuilder object.
         """
-        self._use_index = False
+        self._bypass_vector_index = True
         return self


 class AsyncQueryBase(object):
-    def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
+    def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
         """
         Construct an AsyncQueryBase

@@ -1516,6 +1685,9 @@
         """
         self._inner = inner

+    def to_query_object(self) -> Query:
+        return Query.from_inner(self._inner.to_query_request())
+
     def where(self, predicate: str) -> Self:
         """
         Only return rows matching the given predicate
@@ -1868,7 +2040,7 @@ class AsyncQuery(AsyncQueryBase):
         )

     def nearest_to_text(
-        self, query: str, columns: Union[str, List[str]] = []
+        self, query: str, columns: Union[str, List[str], None] = None
     ) -> AsyncFTSQuery:
         """
         Find the documents that are most relevant to the given text query.
@@ -1892,6 +2064,8 @@
         """
         if isinstance(columns, str):
             columns = [columns]
+        if columns is None:
+            columns = []
         return AsyncFTSQuery(
             self._inner.nearest_to_text({"query": query, "columns": columns})
         )
@@ -2177,7 +2351,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
         return self

     def nearest_to_text(
-        self, query: str, columns: Union[str, List[str]] = []
+        self, query: str, columns: Union[str, List[str], None] = None
     ) -> AsyncHybridQuery:
         """
         Find the documents that are most relevant to the given text query,
@@ -2205,6 +2379,8 @@
         """
         if isinstance(columns, str):
             columns = [columns]
+        if columns is None:
+            columns = []
         return AsyncHybridQuery(
             self._inner.nearest_to_text({"query": query, "columns": columns})
         )
diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py
index 64fe5973..2a362996 100644
--- a/python/python/lancedb/remote/table.py
+++ b/python/python/lancedb/remote/table.py
@@ -282,7 +282,8 @@ class RemoteTable(Table):
         """Create a search query to find the nearest neighbors
         of the given query vector.
         We currently support [vector search][search]
-        All query options are defined in [Query][lancedb.query.Query].
+        All query options are defined in
+        [LanceVectorQueryBuilder][lancedb.query.LanceVectorQueryBuilder].

         Examples
         --------
@@ -353,7 +354,16 @@ class RemoteTable(Table):
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
     ) -> pa.RecordBatchReader:
-        return LOOP.run(self._table._execute_query(query, batch_size=batch_size))
+        async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
+
+        def iter_sync():
+            try:
+                while True:
+                    yield LOOP.run(async_iter.__anext__())
+            except StopAsyncIteration:
+                return
+
+        return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())

     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index ac7633ab..a1ca619b 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -101,7 +101,9 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
         schema = data.features.arrow_schema
         return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
     elif isinstance(data, datasets.dataset_dict.DatasetDict):
-        schema = _schema_from_hf(data, schema)
+        schema = _schema_from_hf(data, None)
+        if "split" not in schema.names:
+            schema = schema.append(pa.field("split", pa.string()))
         return pa.RecordBatchReader.from_batches(
             schema, _to_batches_with_split(data)
         )
@@ -415,7 +417,7 @@ def sanitize_create_table(
     return data, schema


-def _schema_from_hf(data, schema):
+def _schema_from_hf(data, schema) -> pa.Schema:
     """
     Extract pyarrow schema from HuggingFace DatasetDict
     and validate that they're all the same schema between
@@ -927,7 +929,8 @@ class Table(ABC):
         of the given query vector. We currently support [vector search][search] and
         [full-text search][experimental-full-text-search].

-        All query options are defined in [Query][lancedb.query.Query].
+        All query options are defined in
+        [LanceQueryBuilder][lancedb.query.LanceQueryBuilder].

         Examples
         --------
@@ -2278,7 +2281,19 @@ class LanceTable(Table):
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
     ) -> pa.RecordBatchReader:
-        return LOOP.run(self._table._execute_query(query, batch_size))
+        async_iter = LOOP.run(self._table._execute_query(query, batch_size))
+
+        def iter_sync():
+            try:
+                while True:
+                    yield LOOP.run(async_iter.__anext__())
+            except StopAsyncIteration:
+                return
+
+        return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
+
+    def _explain_plan(self, query: Query) -> str:
+        return LOOP.run(self._table._explain_plan(query))

     def _do_merge(
         self,
@@ -3053,7 +3068,7 @@ class AsyncTable:
         query_type: Literal["auto"] = ...,
         ordering_field_name: Optional[str] = None,
         fts_columns: Optional[Union[str, List[str]]] = None,
-    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...
+    ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ...

     @overload
     async def search(
@@ -3102,7 +3117,7 @@ class AsyncTable:
         query_type: QueryType = "auto",
         ordering_field_name: Optional[str] = None,
         fts_columns: Optional[Union[str, List[str]]] = None,
-    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
+    ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]:
         """Create a search query to find the nearest neighbors
         of the given query vector.
         We currently support [vector search][search] and
         [full-text search][experimental-full-text-search].
@@ -3264,12 +3279,12 @@ class AsyncTable:
             builder = builder.column(vector_column_name)
             return builder
         elif query_type == "fts":
-            return self.query().nearest_to_text(query, columns=fts_columns or [])
+            return self.query().nearest_to_text(query, columns=fts_columns)
         elif query_type == "hybrid":
             builder = self.query().nearest_to(vector_query)
             if vector_column_name:
                 builder = builder.column(vector_column_name)
-            return builder.nearest_to_text(query, columns=fts_columns or [])
+            return builder.nearest_to_text(query, columns=fts_columns)
         else:
             raise ValueError(f"Unknown query type: '{query_type}'")

@@ -3286,16 +3301,13 @@ class AsyncTable:
         """
         return self.query().nearest_to(query_vector)

-    async def _execute_query(
-        self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        # The sync remote table calls into this method, so we need to map the
-        # query to the async version of the query and run that here. This is only
-        # used for that code path right now.
+    def _sync_query_to_async(
+        self, query: Query
+    ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery, AsyncQuery]:
         async_query = self.query()
-        if query.k is not None:
-            async_query = async_query.limit(query.k)
-        if query.offset > 0:
+        if query.limit is not None:
+            async_query = async_query.limit(query.limit)
+        if query.offset is not None:
             async_query = async_query.offset(query.offset)
         if query.columns:
             async_query = async_query.select(query.columns)
@@ -3307,35 +3319,49 @@ class AsyncTable:
             async_query = async_query.with_row_id()

         if query.vector:
-            # we need the schema to get the vector column type
-            # to determine whether the vectors is batch queries or not
-            async_query = (
-                async_query.nearest_to(query.vector)
-                .distance_type(query.metric)
-                .nprobes(query.nprobes)
-                .distance_range(query.lower_bound, query.upper_bound)
+            async_query = async_query.nearest_to(query.vector).distance_range(
+                query.lower_bound, query.upper_bound
             )
-            if query.refine_factor:
+            if query.distance_type is not None:
+                async_query = async_query.distance_type(query.distance_type)
+            if query.nprobes is not None:
+                async_query = async_query.nprobes(query.nprobes)
+            if query.refine_factor is not None:
                 async_query = async_query.refine_factor(query.refine_factor)
             if query.vector_column:
                 async_query = async_query.column(query.vector_column)
             if query.ef:
                 async_query = async_query.ef(query.ef)
-            if not query.use_index:
+            if query.bypass_vector_index:
                 async_query = async_query.bypass_vector_index()
-            if not query.prefilter:
+            if query.postfilter:
                 async_query = async_query.postfilter()

-        if isinstance(query.full_text_query, str):
-            async_query = async_query.nearest_to_text(query.full_text_query)
-        elif isinstance(query.full_text_query, dict):
-            fts_query = query.full_text_query["query"]
-            fts_columns = query.full_text_query.get("columns", []) or []
-            async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
+        if query.full_text_query:
+            async_query = async_query.nearest_to_text(
+                query.full_text_query.query, query.full_text_query.columns
+            )
+            if query.full_text_query.limit is not None:
+                async_query = async_query.limit(query.full_text_query.limit)

-        table = await async_query.to_arrow()
-        return table.to_reader()
+        return async_query
+
+    async def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
+        # The sync table calls into this method, so we need to map the
+        # query to the async version of the query and run that here. This is only
+        # used for that code path right now.
+
+        async_query = self._sync_query_to_async(query)
+
+        return await async_query.to_batches(max_batch_length=batch_size)
+
+    async def _explain_plan(self, query: Query) -> str:
+        # This method is used by the sync table
+        async_query = self._sync_query_to_async(query)
+        return await async_query.explain_plan()

     async def _do_merge(
         self,
diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py
index 54c6c69b..03b7dee9 100644
--- a/python/python/tests/test_query.py
+++ b/python/python/tests/test_query.py
@@ -26,10 +26,12 @@ from lancedb.query import (
     AsyncVectorQuery,
     LanceVectorQueryBuilder,
     Query,
+    FullTextSearchQuery,
 )
 from lancedb.rerankers.cross_encoder import CrossEncoderReranker
 from lancedb.table import AsyncTable, LanceTable
 from utils import exception_output
+from importlib.util import find_spec


 @pytest.fixture(scope="module")
@@ -392,12 +394,28 @@ def test_query_builder_batches(table):
     for item in rs:
         rs_list.append(item)
         assert isinstance(item, pa.RecordBatch)
-    assert len(rs_list) == 1
-    assert len(rs_list[0]["id"]) == 2
+    assert len(rs_list) == 2
+    assert len(rs_list[0]["id"]) == 1
     assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
     assert rs_list[0].to_pandas()["id"][0] == 1
-    assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
-    assert rs_list[0].to_pandas()["id"][1] == 2
+    assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0])
+    assert rs_list[1].to_pandas()["id"][0] == 2
+
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(2)
+        .select(["id", "vector"])
+        .to_batches(2)
+    )
+    rs_list = []
+    for item in rs:
+        rs_list.append(item)
+        assert isinstance(item, pa.RecordBatch)
+    assert len(rs_list) == 1
+    assert len(rs_list[0]["id"]) == 2
+    rs_list = rs_list[0].to_pandas()
+    assert rs_list["id"][0] == 1
+    assert rs_list["id"][1] == 2


 def test_dynamic_projection(table):
@@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column():
         Query(
             vector=query,
             filter="b < 10",
-            prefilter=True,
-            k=2,
-            metric="cosine",
+            limit=2,
+            distance_type="cosine",
             columns=["b"],
-            nprobes=20,
-            refine_factor=None,
             vector_column="foo_vector",
         ),
         None,
@@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable):
 @pytest.mark.asyncio
 @pytest.mark.slow
 async def test_query_reranked_async(table_async: AsyncTable):
+    # CrossEncoderReranker requires torch
+    if find_spec("torch") is None:
+        pytest.skip("torch not installed")
+
     # FTS with rerank
     await table_async.create_index("text", config=FTS(with_position=False))
     await check_query(
@@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection):
     assert "No embedding functions are registered for any columns" in exception_output(
         e
     )
+
+
+# Helper method used in the following tests. Looks at the simple python object `q`
+# and checks that the properties match the expected values in kwargs.
+def check_set_props(q, **kwargs):
+    for k in dict(q):
+        if not k.startswith("_"):
+            if k in kwargs:
+                assert kwargs[k] == getattr(q, k), (
+                    f"{k} should be {kwargs[k]} but is {getattr(q, k)}"
+                )
+            else:
+                assert getattr(q, k) is None, f"{k} should be None"
+
+
+def test_query_serialization_sync(table: lancedb.table.Table):
+    # Simple queries
+    q = table.search().where("id = 1").limit(500).offset(10).to_query_object()
+    check_set_props(q, limit=500, offset=10, filter="id = 1")
+
+    q = table.search().select(["id", "vector"]).to_query_object()
+    check_set_props(q, columns=["id", "vector"])
+
+    q = table.search().with_row_id(True).to_query_object()
+    check_set_props(q, with_row_id=True)
+
+    # Vector queries
+    q = table.search([5.0, 6.0]).limit(10).to_query_object()
+    check_set_props(q, limit=10, vector_column="vector", vector=[5.0, 6.0])
+
+    q = table.search([5.0, 6.0]).to_query_object()
+    check_set_props(q, vector_column="vector", vector=[5.0, 6.0])
+
+    q = (
+        table.search([5.0, 6.0])
+        .limit(10)
+        .where("id = 1", prefilter=False)
+        .to_query_object()
+    )
+    check_set_props(
+        q,
+        limit=10,
+        vector_column="vector",
+        filter="id = 1",
+        postfilter=True,
+        vector=[5.0, 6.0],
+    )
+
+    q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object()
+    check_set_props(
+        q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5
+    )
+
+    q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object()
+    check_set_props(
+        q, vector_column="vector", vector=[5.0, 6.0], lower_bound=0.0, upper_bound=1.0
+    )
+
+    q = table.search([5.0, 6.0]).distance_type("cosine").to_query_object()
+    check_set_props(
+        q, distance_type="cosine", vector_column="vector", vector=[5.0, 6.0]
+    )
+
+    q = table.search([5.0, 6.0]).ef(7).to_query_object()
+    check_set_props(q, ef=7, vector_column="vector", vector=[5.0, 6.0])
+
+    q = table.search([5.0, 6.0]).bypass_vector_index().to_query_object()
+    check_set_props(
+        q, bypass_vector_index=True, vector_column="vector", vector=[5.0, 6.0]
+    )
+
+    # FTS queries
+    q = table.search("foo").limit(10).to_query_object()
+    check_set_props(
+        q, limit=10, full_text_query=FullTextSearchQuery(columns=[], query="foo")
+    )
+
+    q = table.search("foo", query_type="fts").to_query_object()
+    check_set_props(q, full_text_query=FullTextSearchQuery(columns=[], query="foo"))
+
+
+@pytest.mark.asyncio
+async def test_query_serialization_async(table_async: AsyncTable):
+    # Simple queries
+    q = table_async.query().where("id = 1").limit(500).offset(10).to_query_object()
+    check_set_props(q, limit=500, offset=10, filter="id = 1", with_row_id=False)
+
+    q = table_async.query().select(["id", "vector"]).to_query_object()
+    check_set_props(q, columns=["id", "vector"], with_row_id=False)
+
+    q = table_async.query().with_row_id().to_query_object()
+    check_set_props(q, with_row_id=True)
+
+    sample_vector = [pa.array([5.0, 6.0], type=pa.float32())]
+
+    # Vector queries
+    q = (await table_async.search([5.0, 6.0])).limit(10).to_query_object()
+    check_set_props(
+        q,
+        limit=10,
+        vector=sample_vector,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+    )
+
+    q = (await table_async.search([5.0, 6.0])).to_query_object()
+    check_set_props(
+        q,
+        vector=sample_vector,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+        limit=10,
+    )
+
+    q = (
+        (await table_async.search([5.0, 6.0]))
+        .limit(10)
+        .where("id = 1")
+        .postfilter()
+        .to_query_object()
+    )
+    check_set_props(
+        q,
+        limit=10,
+        filter="id = 1",
+        postfilter=True,
+        vector=sample_vector,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+    )
+
+    q = (
+        (await table_async.search([5.0, 6.0]))
+        .nprobes(10)
+        .refine_factor(5)
+        .to_query_object()
+    )
+    check_set_props(
+        q,
+        vector=sample_vector,
+        nprobes=10,
+        refine_factor=5,
+        postfilter=False,
+        with_row_id=False,
+        bypass_vector_index=False,
+        limit=10,
+    )
+
+    q = (
+        (await table_async.search([5.0, 6.0]))
+        .distance_range(0.0, 1.0)
+        .to_query_object()
+    )
+    check_set_props(
+        q,
+        vector=sample_vector,
+        lower_bound=0.0,
+        upper_bound=1.0,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+        limit=10,
+    )
+
+    q = (await table_async.search([5.0, 6.0])).distance_type("cosine").to_query_object()
+    check_set_props(
+        q,
+        distance_type="cosine",
+        vector=sample_vector,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+        limit=10,
+    )
+
+    q = (await table_async.search([5.0, 6.0])).ef(7).to_query_object()
+    check_set_props(
+        q,
+        ef=7,
+        vector=sample_vector,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        bypass_vector_index=False,
+        limit=10,
+    )
+
+    q = (await table_async.search([5.0, 6.0])).bypass_vector_index().to_query_object()
+    check_set_props(
+        q,
+        bypass_vector_index=True,
+        vector=sample_vector,
+        postfilter=False,
+        nprobes=20,
+        with_row_id=False,
+        limit=10,
+    )
+
+    # FTS queries
+    q = (await table_async.search("foo")).limit(10).to_query_object()
+    check_set_props(
+        q,
+        limit=10,
+        full_text_query=FullTextSearchQuery(columns=[], query="foo"),
+        with_row_id=False,
+    )
+
+    q = (await table_async.search("foo", query_type="fts")).to_query_object()
+    check_set_props(
+        q,
+        full_text_query=FullTextSearchQuery(columns=[], query="foo"),
+        with_row_id=False,
+    )
diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py
index b46d6880..642e2443 100644
--- a/python/python/tests/test_remote_db.py
+++ b/python/python/tests/test_remote_db.py
@@ -315,7 +315,7 @@ def test_query_sync_minimal():
         assert body == {
             "distance_type": "l2",
             "k": 10,
-            "prefilter": False,
+            "prefilter": True,
             "refine_factor": None,
             "lower_bound": None,
             "upper_bound": None,
@@ -340,7 +340,7 @@ def test_query_sync_empty_query():
             "filter": "true",
             "vector": [],
             "columns": ["id"],
-            "prefilter": False,
+            "prefilter": True,
             "version": None,
         }
@@ -478,7 +478,7 @@ def test_query_sync_hybrid():
         assert body == {
             "distance_type": "l2",
             "k": 42,
-            "prefilter": False,
+            "prefilter": True,
             "refine_factor": None,
             "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             "nprobes": 20,
diff --git a/python/src/query.rs b/python/src/query.rs
index f3d1511c..7b941356 100644
--- a/python/src/query.rs
+++ b/python/src/query.rs
@@ -1,19 +1,28 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

+use std::sync::Arc;
+
 use arrow::array::make_array;
+use arrow::array::Array;
 use arrow::array::ArrayData;
 use arrow::pyarrow::FromPyArrow;
+use arrow::pyarrow::IntoPyArrow;
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::QueryExecutionOptions;
+use lancedb::query::QueryFilter;
 use lancedb::query::{
     ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
 };
+use lancedb::table::AnyQuery;
+use pyo3::exceptions::PyNotImplementedError;
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::{PyAnyMethods, PyDictMethods};
 use pyo3::pymethods;
 use pyo3::types::PyDict;
+use pyo3::types::PyList;
 use pyo3::Bound;
+use pyo3::IntoPyObject;
 use pyo3::PyAny;
 use pyo3::PyRef;
 use pyo3::PyResult;
@@ -24,6 +33,156 @@ use crate::arrow::RecordBatchStream;
 use crate::error::PythonErrorExt;
 use crate::util::parse_distance_type;

+// Python representation of full text search parameters
+#[derive(Clone)]
+#[pyclass(get_all)]
+pub struct PyFullTextSearchQuery {
+    pub columns: Vec<String>,
+    pub query: String,
+    pub limit: Option<i64>,
+    pub wand_factor: Option<f32>,
+}
+
+impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
+    fn from(query: FullTextSearchQuery) -> Self {
+        PyFullTextSearchQuery {
+            columns: query.columns,
+            query: query.query,
+            limit: query.limit,
+            wand_factor: query.wand_factor,
+        }
+    }
+}
+
+// Python representation of query vector(s)
+#[derive(Clone)]
+pub struct PyQueryVectors(Vec<Arc<dyn Array>>);
+
+impl<'py> IntoPyObject<'py> for PyQueryVectors {
+    type Target = PyList;
+    type Output = Bound<'py, Self::Target>;
+    type Error = PyErr;
+
+    fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
+        let py_objs = self
+            .0
+            .into_iter()
+            .map(|v| v.to_data().into_pyarrow(py))
+            .collect::<Result<Vec<_>, _>>()?;
+        PyList::new(py, py_objs)
+    }
+}
+
+// Python representation of a query
+#[pyclass(get_all)]
+pub struct PyQueryRequest {
+    pub limit: Option<usize>,
+    pub offset: Option<usize>,
+    pub filter: Option<PyQueryFilter>,
+    pub full_text_search: Option<PyFullTextSearchQuery>,
+    pub select: PySelect,
+    pub fast_search: Option<bool>,
+    pub with_row_id: Option<bool>,
+    pub column: Option<String>,
+    pub query_vector: Option<PyQueryVectors>,
+    pub nprobes: Option<usize>,
+    pub lower_bound: Option<f32>,
+    pub upper_bound: Option<f32>,
+    pub ef: Option<usize>,
+    pub refine_factor: Option<u32>,
+    pub distance_type: Option<String>,
+    pub bypass_vector_index: Option<bool>,
+    pub postfilter: Option<bool>,
+    pub norm: Option<String>,
+}
+
+impl From<AnyQuery> for PyQueryRequest {
+    fn from(query: AnyQuery) -> Self {
+        match query {
+            AnyQuery::Query(query_request) => PyQueryRequest {
+                limit: query_request.limit,
+                offset: query_request.offset,
+                filter: query_request.filter.map(PyQueryFilter),
+                full_text_search: query_request
+                    .full_text_search
+                    .map(PyFullTextSearchQuery::from),
+                select: PySelect(query_request.select),
+                fast_search: Some(query_request.fast_search),
+                with_row_id: Some(query_request.with_row_id),
+                column: None,
+                query_vector: None,
+                nprobes: None,
+                lower_bound: None,
+                upper_bound: None,
+                ef: None,
+                refine_factor: None,
+                distance_type: None,
+                bypass_vector_index: None,
+                postfilter: None,
+                norm: None,
+            },
+            AnyQuery::VectorQuery(vector_query) => PyQueryRequest {
+                limit: vector_query.base.limit,
+                offset: vector_query.base.offset,
+                filter: vector_query.base.filter.map(PyQueryFilter),
+                full_text_search: None,
+                select: PySelect(vector_query.base.select),
+                fast_search: Some(vector_query.base.fast_search),
+                with_row_id: Some(vector_query.base.with_row_id),
+                column: vector_query.column,
+                query_vector: Some(PyQueryVectors(vector_query.query_vector)),
+                nprobes: Some(vector_query.nprobes),
+                lower_bound: vector_query.lower_bound,
+                upper_bound: vector_query.upper_bound,
+                ef: vector_query.ef,
+                refine_factor: vector_query.refine_factor,
+                distance_type: vector_query.distance_type.map(|d| d.to_string()),
+                bypass_vector_index: Some(!vector_query.use_index),
+                postfilter: Some(!vector_query.base.prefilter),
+                norm: vector_query.base.norm.map(|n| n.to_string()),
+            },
+        }
+    }
+}
+
+// Python representation of query selection
+#[derive(Clone)]
+pub struct PySelect(Select);
+
+impl<'py> IntoPyObject<'py> for PySelect {
+    type Target = PyAny;
+    type Output = Bound<'py, Self::Target>;
+    type Error = PyErr;
+
+    fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
+        match self.0 {
+            Select::All => Ok(py.None().into_bound(py).into_any()),
+            Select::Columns(columns) => Ok(columns.into_pyobject(py)?.into_any()),
+            Select::Dynamic(columns) => Ok(columns.into_pyobject(py)?.into_any()),
+        }
+    }
+}
+
+// Python representation of query filter
+#[derive(Clone)]
+pub struct PyQueryFilter(QueryFilter);
+
+impl<'py> IntoPyObject<'py> for PyQueryFilter {
+    type Target = PyAny;
+    type Output = Bound<'py, Self::Target>;
+    type Error = PyErr;
+
+    fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
+        match self.0 {
+            QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
+                "Datafusion filter has no conversion to Python",
+            )),
+            QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
+            QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
+        }
+    }
+}
+
 #[pyclass]
 pub struct Query {
     inner: LanceDbQuery,
@@ -121,6 +280,10 @@ impl Query {
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))
         })
     }
+
+    pub fn to_query_request(&self) -> PyQueryRequest {
+        PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request()))
+    }
 }

 #[pyclass]
@@ -205,6 +368,12 @@ impl FTSQuery {
     pub fn get_query(&self) -> String {
         self.fts_query.query.clone()
     }
+
+    pub fn to_query_request(&self) -> PyQueryRequest {
+        let mut req = self.inner.clone().into_request();
+        req.full_text_search = Some(self.fts_query.clone());
+        PyQueryRequest::from(AnyQuery::Query(req))
+    }
 }

 #[pyclass]
@@ -319,6 +488,10 @@ impl VectorQuery {
             inner_fts: fts_query,
         })
     }
+
+    pub fn to_query_request(&self) -> PyQueryRequest {
+        PyQueryRequest::from(AnyQuery::VectorQuery(self.inner.clone().into_request()))
+    }
 }

 #[pyclass]
@@ -421,4 +594,17 @@ impl HybridQuery {
     pub fn get_with_row_id(&mut self) -> bool {
         self.inner_fts.inner.current_request().with_row_id
     }
+
+    pub fn to_query_request(&self) -> PyQueryRequest {
+        let mut req = self.inner_fts.to_query_request();
+        let vec_req = self.inner_vec.to_query_request();
+        req.query_vector = vec_req.query_vector;
+        req.column = vec_req.column;
+        req.distance_type = vec_req.distance_type;
+        req.nprobes = vec_req.nprobes;
+        req.ef = vec_req.ef;
+        req.refine_factor = vec_req.refine_factor;
+        req.lower_bound = vec_req.lower_bound;
+        req.upper_bound = vec_req.upper_bound;
+        req
+    }
 }
diff --git a/rust/lancedb/src/rerankers.rs b/rust/lancedb/src/rerankers.rs
index 338230c0..53f54abc 100644
--- a/rust/lancedb/src/rerankers.rs
+++ b/rust/lancedb/src/rerankers.rs
@@ -1,7 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

-use std::collections::BTreeSet;
+use std::{collections::BTreeSet, str::FromStr};

 use arrow::{
     array::downcast_array,
@@ -24,6 +24,29 @@ pub enum NormalizeMethod {
     Rank,
 }

+impl FromStr for NormalizeMethod {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self> {
+        match s.to_lowercase().as_str() {
+            "score" => Ok(NormalizeMethod::Score),
+            "rank" => Ok(NormalizeMethod::Rank),
+            _ => Err(Error::InvalidInput {
+                message: format!("invalid normalize method: {}", s),
+            }),
+        }
+    }
+}
+
+impl std::fmt::Display for NormalizeMethod {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            NormalizeMethod::Score => write!(f, "score"),
+            NormalizeMethod::Rank => write!(f, "rank"),
+        }
+    }
+}
+
 /// Interface for a reranker. A reranker is used to rerank the results from a
 /// vector and FTS search. This is useful for combining the results from both
 /// search methods.
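
Usage sketch for the new `to_query_object` serialization path (illustrative only:
the database path, table data, and the call into the private `_execute_query`
helper are assumptions made for the demo, not part of the API added above):

```python
import lancedb

db = lancedb.connect("/tmp/lancedb-demo")
tbl = db.create_table("items", [{"id": 1, "vector": [1.0, 2.0]}])

# Build a sync query, then capture it as a serializable pydantic `Query`.
q = tbl.search([1.0, 2.0]).limit(5).where("id > 0").to_query_object()

# `Query` is a plain pydantic model, so it round-trips through JSON
# (pydantic v2 shown; v1 would use `q.json()`).
print(q.model_dump_json())

# Replay the captured query through the table internals.
reader = tbl._execute_query(q)
print(reader.read_all())
```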