mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: add to_query_object method (#2239)
This PR adds a `to_query_object` method to the various query builders (hybrid queries are not supported yet), which makes it possible to inspect the query that is built. In addition, this PR normalizes some behavior between the sync and async query paths: a few custom defaults were removed in favor of None (so that the default is set once, in Rust). Also, the synchronous `to_batches` method now actually streams results, and the remote API now defaults to prefiltering.
This commit is contained in:
@@ -94,6 +94,7 @@ class Query:
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> FTSQuery: ...
|
||||
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class FTSQuery:
|
||||
def where(self, filter: str): ...
|
||||
@@ -108,6 +109,7 @@ class FTSQuery:
|
||||
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
|
||||
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
|
||||
async def explain_plan(self) -> str: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class VectorQuery:
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
@@ -123,6 +125,7 @@ class VectorQuery:
|
||||
def nprobes(self, nprobes: int): ...
|
||||
def bypass_vector_index(self): ...
|
||||
def nearest_to_text(self, query: dict) -> HybridQuery: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class HybridQuery:
|
||||
def where(self, filter: str): ...
|
||||
@@ -140,6 +143,33 @@ class HybridQuery:
|
||||
def to_fts_query(self) -> FTSQuery: ...
|
||||
def get_limit(self) -> int: ...
|
||||
def get_with_row_id(self) -> bool: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class PyFullTextSearchQuery:
|
||||
columns: Optional[List[str]]
|
||||
query: str
|
||||
limit: Optional[int]
|
||||
wand_factor: Optional[float]
|
||||
|
||||
class PyQueryRequest:
|
||||
limit: Optional[int]
|
||||
offset: Optional[int]
|
||||
filter: Optional[Union[str, bytes]]
|
||||
full_text_search: Optional[PyFullTextSearchQuery]
|
||||
select: Optional[Union[str, List[str]]]
|
||||
fast_search: Optional[bool]
|
||||
with_row_id: Optional[bool]
|
||||
column: Optional[str]
|
||||
query_vector: Optional[List[pa.Array]]
|
||||
nprobes: Optional[int]
|
||||
lower_bound: Optional[float]
|
||||
upper_bound: Optional[float]
|
||||
ef: Optional[int]
|
||||
refine_factor: Optional[int]
|
||||
distance_type: Optional[str]
|
||||
bypass_vector_index: Optional[bool]
|
||||
postfilter: Optional[bool]
|
||||
norm: Optional[str]
|
||||
|
||||
class CompactionStats:
|
||||
fragments_removed: int
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
Any,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
@@ -32,6 +33,8 @@ from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
import PIL
|
||||
@@ -41,6 +44,7 @@ if TYPE_CHECKING:
|
||||
from ._lancedb import FTSQuery as LanceFTSQuery
|
||||
from ._lancedb import HybridQuery as LanceHybridQuery
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from ._lancedb import PyQueryRequest
|
||||
from .common import VEC
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
@@ -51,33 +55,116 @@ if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""The LanceDB Query
|
||||
# Pydantic validation function for vector queries
|
||||
def ensure_vector_query(
    val: Any,
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
    """Check that ``val`` has one of the accepted vector-query shapes.

    Accepted shapes are a list of floats, a list of lists of floats, a
    ``pa.Array`` or a list of ``pa.Array``.

    Parameters
    ----------
    val : Any
        The candidate vector query value.

    Returns
    -------
    The input ``val`` unchanged when it matches an accepted shape. Any other
    value that is not explicitly rejected (including ``None``) yields ``None``.

    Raises
    ------
    ValueError
        If ``val`` is a bare float, an empty list, or a list whose first
        element is an empty list.
    """
    # NOTE(review): this function is attached to the `vector` field through
    # `Annotated[...]`; confirm that pydantic actually invokes a bare callable
    # placed there (it may need to be wrapped in a validator marker).
    if isinstance(val, float):
        raise ValueError(
            "Vector query must be a list of floats or a list of lists of floats"
        )

    if isinstance(val, list):
        if not val:
            # Previously the exception was *returned*, so an empty vector was
            # accepted and replaced by a ValueError instance.
            raise ValueError("Vector query must be a non-empty list")
        sample = val[0]
    else:
        sample = val

    if isinstance(sample, float):
        # val is a list of floats
        return val

    if isinstance(sample, list):
        if not sample:
            raise ValueError("Vector query must be a non-empty list")
        if isinstance(sample[0], float):
            # val is a list of lists of floats
            return val
        return None

    if isinstance(sample, pa.Array):
        # val is an array or a list of arrays
        return val

    # Unrecognised input (or None): keep the original fall-through result.
    return None
|
||||
|
||||
|
||||
class FullTextSearchQuery(pydantic.BaseModel):
|
||||
"""A LanceDB Full Text Search Query
|
||||
|
||||
Attributes
|
||||
----------
|
||||
vector : List[float]
|
||||
the vector to search for
|
||||
filter : Optional[str]
|
||||
sql filter to refine the query with, optional
|
||||
prefilter : bool
|
||||
if True then apply the filter before vector search
|
||||
k : int
|
||||
top k results to return
|
||||
metric : str
|
||||
the distance metric between a pair of vectors,
|
||||
columns: List[str]
|
||||
The columns to search
|
||||
|
||||
can support l2 (default), Cosine and Dot.
|
||||
[metric definitions][search]
|
||||
columns : Optional[List[str]]
|
||||
If None, then the table should select the column automatically.
|
||||
query: str
|
||||
The query to search for
|
||||
limit: Optional[int] = None
|
||||
The limit on the number of results to return
|
||||
wand_factor: Optional[float] = None
|
||||
The wand factor to use for the search
|
||||
"""
|
||||
|
||||
columns: Optional[List[str]] = None
|
||||
query: str
|
||||
limit: Optional[int] = None
|
||||
wand_factor: Optional[float] = None
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""A LanceDB Query
|
||||
|
||||
Queries are constructed by the `Table.search` and `Table.query` methods. This
|
||||
class is a python representation of the query. Normally you will not need to
|
||||
interact with this class directly. You can build up a query and execute it using
|
||||
collection methods such as `to_batches()`, `to_arrow()`, `to_pandas()`, etc.
|
||||
|
||||
However, you can use the `to_query_object()` method to get the underlying query object.
|
||||
This can be useful for serializing a query or using it in a different context.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filter : Optional[str]
|
||||
sql filter to refine the query with
|
||||
limit : Optional[int]
|
||||
The limit on the number of results to return. If this is a vector or FTS query,
|
||||
then this is required. If this is a plain SQL query, then this is optional.
|
||||
offset: Optional[int]
|
||||
The offset to start fetching results from
|
||||
|
||||
This is ignored for vector / FTS search (will be None).
|
||||
columns : Optional[Union[List[str], Dict[str, str]]]
|
||||
which columns to return in the results
|
||||
nprobes : int
|
||||
The number of probes used - optional
|
||||
|
||||
This can be a list of column names or a dictionary. If it is a dictionary,
|
||||
then the keys are the column names and the values are sql expressions to
|
||||
use to calculate the result.
|
||||
|
||||
If this is None then all columns are returned. This can be expensive.
|
||||
with_row_id : Optional[bool]
|
||||
if True then include the row id in the results
|
||||
vector : Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]]
|
||||
the vector to search for, if this a vector search or hybrid search. It will
|
||||
be None for full text search and plain SQL filtering.
|
||||
vector_column : Optional[str]
|
||||
the name of the vector column to use for vector search
|
||||
|
||||
If this is None then a default vector column will be used.
|
||||
distance_type : Optional[str]
|
||||
the distance type to use for vector search
|
||||
|
||||
This can be l2 (default), cosine and dot. See [metric definitions][search] for
|
||||
more details.
|
||||
|
||||
If this is not a vector search this will be None.
|
||||
postfilter : bool
|
||||
if True then apply the filter after vector / FTS search. This is ignored for
|
||||
plain SQL filtering.
|
||||
nprobes : Optional[int]
|
||||
The number of IVF partitions to search. If this is None then a default
|
||||
number of partitions will be used.
|
||||
|
||||
- A higher number makes search more accurate but also slower.
|
||||
|
||||
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
|
||||
Will be None if this is not a vector search.
|
||||
refine_factor : Optional[int]
|
||||
Refine the results by reading extra elements and re-ranking them in memory.
|
||||
|
||||
@@ -85,58 +172,130 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
offset: int
|
||||
The offset to start fetching results from
|
||||
fast_search: bool
|
||||
|
||||
Will be None if this is not a vector search.
|
||||
lower_bound : Optional[float]
|
||||
The lower bound for distance search
|
||||
|
||||
Only results with a distance greater than or equal to this value
|
||||
will be returned.
|
||||
|
||||
This will only be set on vector search.
|
||||
upper_bound : Optional[float]
|
||||
The upper bound for distance search
|
||||
|
||||
Only results with a distance less than or equal to this value
|
||||
will be returned.
|
||||
|
||||
This will only be set on vector search.
|
||||
ef : Optional[int]
|
||||
The size of the nearest neighbor list maintained during HNSW search
|
||||
|
||||
This will only be set on vector search.
|
||||
full_text_query : Optional[Union[str, dict]]
|
||||
The full text search query
|
||||
|
||||
This can be a string or a dictionary. A dictionary will be used to search
|
||||
multiple columns. The keys are the column names and the values are the
|
||||
search queries.
|
||||
|
||||
This will only be set on FTS or hybrid queries.
|
||||
fast_search: Optional[bool]
|
||||
Skip a flat search of unindexed data. This will improve
|
||||
search performance but search results will not include unindexed data.
|
||||
|
||||
- *default False*.
|
||||
The default is False
|
||||
"""
|
||||
|
||||
# The name of the vector column to use for vector search.
|
||||
vector_column: Optional[str] = None
|
||||
|
||||
# vector to search for
|
||||
vector: Union[List[float], List[List[float]]]
|
||||
#
|
||||
# Note: today this will be floats on the sync path and pa.Array on the async
|
||||
# path though in the future we should unify this to pa.Array everywhere
|
||||
vector: Annotated[
|
||||
Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]],
|
||||
ensure_vector_query,
|
||||
] = None
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# if True then apply the filter before vector search
|
||||
prefilter: bool = False
|
||||
# if True then apply the filter after vector search
|
||||
postfilter: Optional[bool] = None
|
||||
|
||||
# full text search query
|
||||
full_text_query: Optional[Union[str, dict]] = None
|
||||
full_text_query: Optional[FullTextSearchQuery] = None
|
||||
|
||||
# top k results to return
|
||||
k: Optional[int] = None
|
||||
limit: Optional[int] = None
|
||||
|
||||
# # metrics
|
||||
metric: str = "l2"
|
||||
# distance type to use for vector search
|
||||
distance_type: Optional[str] = None
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
# number of IVF partitions to search
|
||||
nprobes: Optional[int] = None
|
||||
|
||||
# lower bound for distance search
|
||||
lower_bound: Optional[float] = None
|
||||
|
||||
# upper bound for distance search
|
||||
upper_bound: Optional[float] = None
|
||||
|
||||
# Refine factor.
|
||||
# multiplier for the number of results to inspect for reranking
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
with_row_id: bool = False
|
||||
# if true, include the row id in the results
|
||||
with_row_id: Optional[bool] = None
|
||||
|
||||
offset: int = 0
|
||||
# offset to start fetching results from
|
||||
offset: Optional[int] = None
|
||||
|
||||
fast_search: bool = False
|
||||
# if true, will only search the indexed data
|
||||
fast_search: Optional[bool] = None
|
||||
|
||||
# size of the nearest neighbor list maintained during HNSW search
|
||||
ef: Optional[int] = None
|
||||
|
||||
# Default is true. Set to false to enforce a brute force search.
|
||||
use_index: bool = True
|
||||
# Bypass the vector index and use a brute force search
|
||||
bypass_vector_index: Optional[bool] = None
|
||||
|
||||
@classmethod
def from_inner(cls, req: PyQueryRequest) -> Self:
    """Build a query object from a native ``PyQueryRequest``.

    Every field of ``req`` is copied onto a freshly constructed instance via
    attribute assignment. Fields whose names differ between the two types are
    mapped as ``select -> columns``, ``column -> vector_column``,
    ``query_vector -> vector`` and ``full_text_search -> full_text_query``.

    Parameters
    ----------
    req : PyQueryRequest
        The request to convert.

    Returns
    -------
    Self
        A new instance of ``cls`` mirroring ``req``.
    """
    query = cls()
    query.limit = req.limit
    query.offset = req.offset
    query.filter = req.filter
    query.columns = req.select
    query.with_row_id = req.with_row_id
    # `fast_search` is part of the request but was previously never copied,
    # so it was silently dropped from the resulting query object.
    query.fast_search = req.fast_search
    query.vector_column = req.column
    query.vector = req.query_vector
    query.distance_type = req.distance_type
    query.nprobes = req.nprobes
    query.lower_bound = req.lower_bound
    query.upper_bound = req.upper_bound
    query.ef = req.ef
    query.refine_factor = req.refine_factor
    query.bypass_vector_index = req.bypass_vector_index
    query.postfilter = req.postfilter

    # Convert the native full text search request into the python model
    # instead of first storing the raw native object and then overwriting it.
    fts = req.full_text_search
    if fts is None:
        query.full_text_query = None
    else:
        query.full_text_query = FullTextSearchQuery(
            columns=fts.columns,
            query=fts.query,
            limit=fts.limit,
            wand_factor=fts.wand_factor,
        )
    return query
|
||||
|
||||
class Config:
|
||||
# This tells pydantic to allow custom types (needed for the `vector` query since
|
||||
# pa.Array wouldn't be allowed otherwise)
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
@@ -152,8 +311,8 @@ class LanceQueryBuilder(ABC):
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
fast_search: bool = False,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
fast_search: bool = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Create a query builder based on the given query and query type.
|
||||
@@ -257,15 +416,15 @@ class LanceQueryBuilder(ABC):
|
||||
def __init__(self, table: "Table"):
|
||||
self._table = table
|
||||
self._limit = None
|
||||
self._offset = 0
|
||||
self._offset = None
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._prefilter = True
|
||||
self._with_row_id = False
|
||||
self._postfilter = None
|
||||
self._with_row_id = None
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
self._use_index = True
|
||||
self._bypass_vector_index = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -315,7 +474,7 @@ class LanceQueryBuilder(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table:
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
"""
|
||||
Execute the query and return the results as a pyarrow
|
||||
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
|
||||
@@ -451,7 +610,7 @@ class LanceQueryBuilder(ABC):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
self._prefilter = prefilter
|
||||
self._postfilter = not prefilter
|
||||
return self
|
||||
|
||||
def with_row_id(self, with_row_id: bool) -> Self:
|
||||
@@ -497,23 +656,7 @@ class LanceQueryBuilder(ABC):
|
||||
-------
|
||||
plan : str
|
||||
""" # noqa: E501
|
||||
ds = self._table.to_lance()
|
||||
return ds.scanner(
|
||||
nearest={
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._distance_type,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
"use_index": self._use_index,
|
||||
},
|
||||
prefilter=self._prefilter,
|
||||
filter=self._str_query,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
).explain_plan(verbose)
|
||||
return self._table._explain_plan(self.to_query_object())
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> Self:
|
||||
"""Set the vector to search for.
|
||||
@@ -561,6 +704,17 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def to_query_object(self) -> Query:
    """Describe this builder's query as a serializable `Query`.

    Concrete builders must override this method.

    Returns
    -------
    Query
        The serializable representation of the query

    Raises
    ------
    NotImplementedError
        Always, when the base implementation is invoked.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
@@ -590,19 +744,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str,
|
||||
str_query: Optional[str] = None,
|
||||
fast_search: bool = False,
|
||||
fast_search: bool = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
if self._limit is None:
|
||||
self._limit = 10
|
||||
self._query = query
|
||||
self._distance_type = "l2"
|
||||
self._nprobes = 20
|
||||
self._distance_type = None
|
||||
self._nprobes = None
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
self._postfilter = None
|
||||
self._reranker = None
|
||||
self._str_query = str_query
|
||||
self._fast_search = fast_search
|
||||
@@ -752,6 +904,34 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
return self.to_batches().read_all()
|
||||
|
||||
def to_query_object(self) -> Query:
    """
    Build a Query object describing this vector search.

    The query vector is converted to plain python lists first; every other
    builder setting is passed through unchanged.

    This can be used to serialize a query
    """
    # Anything that is not already a list is converted with `.tolist()`.
    if isinstance(self._query, list):
        vector = self._query
    else:
        vector = self._query.tolist()
    # A list whose first entry is an ndarray is treated as multiple vectors.
    if isinstance(vector[0], np.ndarray):
        vector = [row.tolist() for row in vector]

    settings = dict(
        vector=vector,
        filter=self._where,
        postfilter=self._postfilter,
        limit=self._limit,
        distance_type=self._distance_type,
        columns=self._columns,
        nprobes=self._nprobes,
        lower_bound=self._lower_bound,
        upper_bound=self._upper_bound,
        refine_factor=self._refine_factor,
        vector_column=self._vector_column,
        with_row_id=self._with_row_id,
        offset=self._offset,
        fast_search=self._fast_search,
        ef=self._ef,
        bypass_vector_index=self._bypass_vector_index,
    )
    return Query(**settings)
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
"""
|
||||
Execute the query and return the result as a RecordBatchReader object.
|
||||
@@ -768,24 +948,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
vector = self._query if isinstance(self._query, list) else self._query.tolist()
|
||||
if isinstance(vector[0], np.ndarray):
|
||||
vector = [v.tolist() for v in vector]
|
||||
query = Query(
|
||||
vector=vector,
|
||||
filter=self._where,
|
||||
prefilter=self._prefilter,
|
||||
k=self._limit,
|
||||
metric=self._distance_type,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
lower_bound=self._lower_bound,
|
||||
upper_bound=self._upper_bound,
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
use_index=self._use_index,
|
||||
)
|
||||
query = self.to_query_object()
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
if self._reranker is not None:
|
||||
rs_table = result_set.read_all()
|
||||
@@ -798,7 +961,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return result_set
|
||||
|
||||
def where(self, where: str, prefilter: bool = True) -> LanceVectorQueryBuilder:
|
||||
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
@@ -810,8 +973,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
prefilter: bool, default True
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||
without warning in the future.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -819,7 +980,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
self._prefilter = prefilter
|
||||
if prefilter is not None:
|
||||
self._postfilter = not prefilter
|
||||
return self
|
||||
|
||||
def rerank(
|
||||
@@ -873,7 +1035,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
LanceVectorQueryBuilder
|
||||
The LanceVectorQueryBuilder object.
|
||||
"""
|
||||
self._use_index = False
|
||||
self._bypass_vector_index = True
|
||||
return self
|
||||
|
||||
|
||||
@@ -885,11 +1047,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
table: "Table",
|
||||
query: str,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
if self._limit is None:
|
||||
self._limit = 10
|
||||
self._query = query
|
||||
self._phrase_query = False
|
||||
self.ordering_field_name = ordering_field_name
|
||||
@@ -915,6 +1075,19 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
self._phrase_query = phrase_query
|
||||
return self
|
||||
|
||||
def to_query_object(self) -> Query:
    """Build a Query object describing this full text search.

    The search text and the FTS columns are wrapped in a
    `FullTextSearchQuery`; the remaining builder settings are passed through
    unchanged.
    """
    text_search = FullTextSearchQuery(
        query=self._query, columns=self._fts_columns
    )
    settings = dict(
        columns=self._columns,
        filter=self._where,
        limit=self._limit,
        postfilter=self._postfilter,
        with_row_id=self._with_row_id,
        full_text_query=text_search,
        offset=self._offset,
    )
    return Query(**settings)
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
@@ -926,19 +1099,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"Phrase query is not yet supported in Lance FTS. "
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
query = Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
k=self._limit,
|
||||
prefilter=self._prefilter,
|
||||
with_row_id=self._with_row_id,
|
||||
full_text_query={
|
||||
"query": query,
|
||||
"columns": self._fts_columns,
|
||||
},
|
||||
vector=[],
|
||||
offset=self._offset,
|
||||
)
|
||||
query = self.to_query_object()
|
||||
results = self._table._execute_query(query)
|
||||
results = results.read_all()
|
||||
if self._reranker is not None:
|
||||
@@ -983,8 +1144,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
if self._phrase_query:
|
||||
query = query.replace('"', "'")
|
||||
query = f'"{query}"'
|
||||
limit = self._limit if self._limit is not None else 10
|
||||
row_ids, scores = search_index(
|
||||
index, query, self._limit, ordering_field=self.ordering_field_name
|
||||
index, query, limit, ordering_field=self.ordering_field_name
|
||||
)
|
||||
if len(row_ids) == 0:
|
||||
empty_schema = pa.schema([pa.field("_score", pa.float32())])
|
||||
@@ -1053,17 +1215,18 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.to_batches().read_all()
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
query = Query(
|
||||
def to_query_object(self) -> Query:
|
||||
return Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
k=self._limit,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
vector=[],
|
||||
# not actually respected in remote query
|
||||
offset=self._offset or 0,
|
||||
offset=self._offset,
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
query = self.to_query_object()
|
||||
return self._table._execute_query(query, batch_size)
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
@@ -1098,18 +1261,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
table: "Table",
|
||||
query: Optional[str] = None,
|
||||
vector_column: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._vector_column = vector_column
|
||||
self._fts_columns = fts_columns
|
||||
self._norm = "score"
|
||||
self._reranker = RRFReranker()
|
||||
self._norm = None
|
||||
self._reranker = None
|
||||
self._nprobes = None
|
||||
self._refine_factor = None
|
||||
self._distance_type = None
|
||||
self._phrase_query = False
|
||||
self._phrase_query = None
|
||||
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
if query is not None and (vector is not None or text is not None):
|
||||
@@ -1131,7 +1294,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return vector_query, text_query
|
||||
|
||||
def phrase_query(self, phrase_query: bool = True) -> LanceHybridQueryBuilder:
|
||||
def phrase_query(self, phrase_query: bool = None) -> LanceHybridQueryBuilder:
|
||||
"""Set whether to use phrase query.
|
||||
|
||||
Parameters
|
||||
@@ -1148,6 +1311,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._phrase_query = phrase_query
|
||||
return self
|
||||
|
||||
def to_query_object(self) -> Query:
    """Hybrid queries cannot be converted to a `Query` object yet.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "to_query_object not yet supported on a hybrid query"
    raise NotImplementedError(message)
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
vector_query, fts_query = self._validate_query(
|
||||
self._query, self._vector, self._text
|
||||
@@ -1169,8 +1335,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.select(self._columns)
|
||||
self._fts_query.select(self._columns)
|
||||
if self._where:
|
||||
self._vector_query.where(self._where, self._prefilter)
|
||||
self._fts_query.where(self._where, self._prefilter)
|
||||
self._vector_query.where(self._where, self._postfilter)
|
||||
self._fts_query.where(self._where, self._postfilter)
|
||||
if self._with_row_id:
|
||||
self._vector_query.with_row_id(True)
|
||||
self._fts_query.with_row_id(True)
|
||||
@@ -1184,9 +1350,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
if self._ef:
|
||||
self._vector_query.ef(self._ef)
|
||||
if not self._use_index:
|
||||
if self._bypass_vector_index:
|
||||
self._vector_query.bypass_vector_index()
|
||||
|
||||
if self._reranker is None:
|
||||
self._reranker = RRFReranker()
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
vector_future = executor.submit(
|
||||
@@ -1502,12 +1671,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._use_index = False
|
||||
self._bypass_vector_index = True
|
||||
return self
|
||||
|
||||
|
||||
class AsyncQueryBase(object):
|
||||
def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
|
||||
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
|
||||
"""
|
||||
Construct an AsyncQueryBase
|
||||
|
||||
@@ -1516,6 +1685,9 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
self._inner = inner
|
||||
|
||||
def to_query_object(self) -> Query:
    """Return a `Query` built from the wrapped native query.

    The wrapped query is asked for its request via `to_query_request()` and
    the result is converted with `Query.from_inner`.
    """
    request = self._inner.to_query_request()
    return Query.from_inner(request)
|
||||
|
||||
def where(self, predicate: str) -> Self:
|
||||
"""
|
||||
Only return rows matching the given predicate
|
||||
@@ -1868,7 +2040,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
self, query: str, columns: Union[str, List[str], None] = None
|
||||
) -> AsyncFTSQuery:
|
||||
"""
|
||||
Find the documents that are most relevant to the given text query.
|
||||
@@ -1892,6 +2064,8 @@ class AsyncQuery(AsyncQueryBase):
|
||||
"""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
if columns is None:
|
||||
columns = []
|
||||
return AsyncFTSQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
@@ -2177,7 +2351,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
return self
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
self, query: str, columns: Union[str, List[str], None] = None
|
||||
) -> AsyncHybridQuery:
|
||||
"""
|
||||
Find the documents that are most relevant to the given text query,
|
||||
@@ -2205,6 +2379,8 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
"""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
if columns is None:
|
||||
columns = []
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
|
||||
@@ -282,7 +282,8 @@ class RemoteTable(Table):
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
|
||||
All query options are defined in [Query][lancedb.query.Query].
|
||||
All query options are defined in
|
||||
[LanceVectorQueryBuilder][lancedb.query.LanceVectorQueryBuilder].
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -353,7 +354,16 @@ class RemoteTable(Table):
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
return LOOP.run(self._table._execute_query(query, batch_size=batch_size))
|
||||
async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
|
||||
|
||||
def iter_sync():
|
||||
try:
|
||||
while True:
|
||||
yield LOOP.run(async_iter.__anext__())
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
|
||||
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
|
||||
@@ -101,7 +101,9 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
|
||||
schema = data.features.arrow_schema
|
||||
return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
|
||||
elif isinstance(data, datasets.dataset_dict.DatasetDict):
|
||||
schema = _schema_from_hf(data, schema)
|
||||
schema = _schema_from_hf(data, None)
|
||||
if "split" not in schema.names:
|
||||
schema = schema.append(pa.field("split", pa.string()))
|
||||
return pa.RecordBatchReader.from_batches(
|
||||
schema, _to_batches_with_split(data)
|
||||
)
|
||||
@@ -415,7 +417,7 @@ def sanitize_create_table(
|
||||
return data, schema
|
||||
|
||||
|
||||
def _schema_from_hf(data, schema):
|
||||
def _schema_from_hf(data, schema) -> pa.Schema:
|
||||
"""
|
||||
Extract pyarrow schema from HuggingFace DatasetDict
|
||||
and validate that they're all the same schema between
|
||||
@@ -927,7 +929,8 @@ class Table(ABC):
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][experimental-full-text-search].
|
||||
|
||||
All query options are defined in [Query][lancedb.query.Query].
|
||||
All query options are defined in
|
||||
[LanceQueryBuilder][lancedb.query.LanceQueryBuilder].
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -2278,7 +2281,19 @@ class LanceTable(Table):
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
return LOOP.run(self._table._execute_query(query, batch_size))
|
||||
async_iter = LOOP.run(self._table._execute_query(query, batch_size))
|
||||
|
||||
def iter_sync():
|
||||
try:
|
||||
while True:
|
||||
yield LOOP.run(async_iter.__anext__())
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
|
||||
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
|
||||
|
||||
def _explain_plan(self, query: Query) -> str:
|
||||
return LOOP.run(self._table._explain_plan(query))
|
||||
|
||||
def _do_merge(
|
||||
self,
|
||||
@@ -3053,7 +3068,7 @@ class AsyncTable:
|
||||
query_type: Literal["auto"] = ...,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...
|
||||
) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ...
|
||||
|
||||
@overload
|
||||
async def search(
|
||||
@@ -3102,7 +3117,7 @@ class AsyncTable:
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
|
||||
) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][experimental-full-text-search].
|
||||
@@ -3264,12 +3279,12 @@ class AsyncTable:
|
||||
builder = builder.column(vector_column_name)
|
||||
return builder
|
||||
elif query_type == "fts":
|
||||
return self.query().nearest_to_text(query, columns=fts_columns or [])
|
||||
return self.query().nearest_to_text(query, columns=fts_columns)
|
||||
elif query_type == "hybrid":
|
||||
builder = self.query().nearest_to(vector_query)
|
||||
if vector_column_name:
|
||||
builder = builder.column(vector_column_name)
|
||||
return builder.nearest_to_text(query, columns=fts_columns or [])
|
||||
return builder.nearest_to_text(query, columns=fts_columns)
|
||||
else:
|
||||
raise ValueError(f"Unknown query type: '{query_type}'")
|
||||
|
||||
@@ -3286,16 +3301,13 @@ class AsyncTable:
|
||||
"""
|
||||
return self.query().nearest_to(query_vector)
|
||||
|
||||
async def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
# The sync remote table calls into this method, so we need to map the
|
||||
# query to the async version of the query and run that here. This is only
|
||||
# used for that code path right now.
|
||||
def _sync_query_to_async(
|
||||
self, query: Query
|
||||
) -> AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery | AsyncQuery:
|
||||
async_query = self.query()
|
||||
if query.k is not None:
|
||||
async_query = async_query.limit(query.k)
|
||||
if query.offset > 0:
|
||||
if query.limit is not None:
|
||||
async_query = async_query.limit(query.limit)
|
||||
if query.offset is not None:
|
||||
async_query = async_query.offset(query.offset)
|
||||
if query.columns:
|
||||
async_query = async_query.select(query.columns)
|
||||
@@ -3307,35 +3319,49 @@ class AsyncTable:
|
||||
async_query = async_query.with_row_id()
|
||||
|
||||
if query.vector:
|
||||
# we need the schema to get the vector column type
|
||||
# to determine whether the vectors is batch queries or not
|
||||
async_query = (
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
.nprobes(query.nprobes)
|
||||
.distance_range(query.lower_bound, query.upper_bound)
|
||||
async_query = async_query.nearest_to(query.vector).distance_range(
|
||||
query.lower_bound, query.upper_bound
|
||||
)
|
||||
if query.refine_factor:
|
||||
if query.distance_type is not None:
|
||||
async_query = async_query.distance_type(query.distance_type)
|
||||
if query.nprobes is not None:
|
||||
async_query = async_query.nprobes(query.nprobes)
|
||||
if query.refine_factor is not None:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
async_query = async_query.column(query.vector_column)
|
||||
if query.ef:
|
||||
async_query = async_query.ef(query.ef)
|
||||
if not query.use_index:
|
||||
if query.bypass_vector_index:
|
||||
async_query = async_query.bypass_vector_index()
|
||||
|
||||
if not query.prefilter:
|
||||
if query.postfilter:
|
||||
async_query = async_query.postfilter()
|
||||
|
||||
if isinstance(query.full_text_query, str):
|
||||
async_query = async_query.nearest_to_text(query.full_text_query)
|
||||
elif isinstance(query.full_text_query, dict):
|
||||
fts_query = query.full_text_query["query"]
|
||||
fts_columns = query.full_text_query.get("columns", []) or []
|
||||
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
|
||||
if query.full_text_query:
|
||||
async_query = async_query.nearest_to_text(
|
||||
query.full_text_query.query, query.full_text_query.columns
|
||||
)
|
||||
if query.full_text_query.limit is not None:
|
||||
async_query = async_query.limit(query.full_text_query.limit)
|
||||
|
||||
table = await async_query.to_arrow()
|
||||
return table.to_reader()
|
||||
return async_query
|
||||
|
||||
async def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
# The sync table calls into this method, so we need to map the
|
||||
# query to the async version of the query and run that here. This is only
|
||||
# used for that code path right now.
|
||||
|
||||
async_query = self._sync_query_to_async(query)
|
||||
|
||||
return await async_query.to_batches(max_batch_length=batch_size)
|
||||
|
||||
async def _explain_plan(self, query: Query) -> str:
|
||||
# This method is used by the sync table
|
||||
async_query = self._sync_query_to_async(query)
|
||||
return await async_query.explain_plan()
|
||||
|
||||
async def _do_merge(
|
||||
self,
|
||||
|
||||
@@ -26,10 +26,12 @@ from lancedb.query import (
|
||||
AsyncVectorQuery,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
FullTextSearchQuery,
|
||||
)
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
from utils import exception_output
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -392,12 +394,28 @@ def test_query_builder_batches(table):
|
||||
for item in rs:
|
||||
rs_list.append(item)
|
||||
assert isinstance(item, pa.RecordBatch)
|
||||
assert len(rs_list) == 1
|
||||
assert len(rs_list[0]["id"]) == 2
|
||||
assert len(rs_list) == 2
|
||||
assert len(rs_list[0]["id"]) == 1
|
||||
assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
|
||||
assert rs_list[0].to_pandas()["id"][0] == 1
|
||||
assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
|
||||
assert rs_list[0].to_pandas()["id"][1] == 2
|
||||
assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0])
|
||||
assert rs_list[1].to_pandas()["id"][0] == 2
|
||||
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.limit(2)
|
||||
.select(["id", "vector"])
|
||||
.to_batches(2)
|
||||
)
|
||||
rs_list = []
|
||||
for item in rs:
|
||||
rs_list.append(item)
|
||||
assert isinstance(item, pa.RecordBatch)
|
||||
assert len(rs_list) == 1
|
||||
assert len(rs_list[0]["id"]) == 2
|
||||
rs_list = rs_list[0].to_pandas()
|
||||
assert rs_list["id"][0] == 1
|
||||
assert rs_list["id"][1] == 2
|
||||
|
||||
|
||||
def test_dynamic_projection(table):
|
||||
@@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column():
|
||||
Query(
|
||||
vector=query,
|
||||
filter="b < 10",
|
||||
prefilter=True,
|
||||
k=2,
|
||||
metric="cosine",
|
||||
limit=2,
|
||||
distance_type="cosine",
|
||||
columns=["b"],
|
||||
nprobes=20,
|
||||
refine_factor=None,
|
||||
vector_column="foo_vector",
|
||||
),
|
||||
None,
|
||||
@@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_query_reranked_async(table_async: AsyncTable):
|
||||
# CrossEncoderReranker requires torch
|
||||
if find_spec("torch") is None:
|
||||
pytest.skip("torch not installed")
|
||||
|
||||
# FTS with rerank
|
||||
await table_async.create_index("text", config=FTS(with_position=False))
|
||||
await check_query(
|
||||
@@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection):
|
||||
assert "No embedding functions are registered for any columns" in exception_output(
|
||||
e
|
||||
)
|
||||
|
||||
|
||||
# Helper method used in the following tests. Looks at the simple python object `q` and
|
||||
# checks that the properties match the expected values in kwargs.
|
||||
def check_set_props(q, **kwargs):
|
||||
for k in dict(q):
|
||||
if not k.startswith("_"):
|
||||
if k in kwargs:
|
||||
assert kwargs[k] == getattr(q, k), (
|
||||
f"{k} should be {kwargs[k]} but is {getattr(q, k)}"
|
||||
)
|
||||
else:
|
||||
assert getattr(q, k) is None, f"{k} should be None"
|
||||
|
||||
|
||||
def test_query_serialization_sync(table: lancedb.table.Table):
|
||||
# Simple queries
|
||||
q = table.search().where("id = 1").limit(500).offset(10).to_query_object()
|
||||
check_set_props(q, limit=500, offset=10, filter="id = 1")
|
||||
|
||||
q = table.search().select(["id", "vector"]).to_query_object()
|
||||
check_set_props(q, columns=["id", "vector"])
|
||||
|
||||
q = table.search().with_row_id(True).to_query_object()
|
||||
check_set_props(q, with_row_id=True)
|
||||
|
||||
# Vector queries
|
||||
q = table.search([5.0, 6.0]).limit(10).to_query_object()
|
||||
check_set_props(q, limit=10, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = table.search([5.0, 6.0]).to_query_object()
|
||||
check_set_props(q, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = (
|
||||
table.search([5.0, 6.0])
|
||||
.limit(10)
|
||||
.where("id = 1", prefilter=False)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
vector_column="vector",
|
||||
filter="id = 1",
|
||||
postfilter=True,
|
||||
vector=[5.0, 6.0],
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object()
|
||||
check_set_props(
|
||||
q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object()
|
||||
check_set_props(
|
||||
q, vector_column="vector", vector=[5.0, 6.0], lower_bound=0.0, upper_bound=1.0
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).distance_type("cosine").to_query_object()
|
||||
check_set_props(
|
||||
q, distance_type="cosine", vector_column="vector", vector=[5.0, 6.0]
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).ef(7).to_query_object()
|
||||
check_set_props(q, ef=7, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = table.search([5.0, 6.0]).bypass_vector_index().to_query_object()
|
||||
check_set_props(
|
||||
q, bypass_vector_index=True, vector_column="vector", vector=[5.0, 6.0]
|
||||
)
|
||||
|
||||
# FTS queries
|
||||
q = table.search("foo").limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q, limit=10, full_text_query=FullTextSearchQuery(columns=[], query="foo")
|
||||
)
|
||||
|
||||
q = table.search("foo", query_type="fts").to_query_object()
|
||||
check_set_props(q, full_text_query=FullTextSearchQuery(columns=[], query="foo"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_serialization_async(table_async: AsyncTable):
|
||||
# Simple queries
|
||||
q = table_async.query().where("id = 1").limit(500).offset(10).to_query_object()
|
||||
check_set_props(q, limit=500, offset=10, filter="id = 1", with_row_id=False)
|
||||
|
||||
q = table_async.query().select(["id", "vector"]).to_query_object()
|
||||
check_set_props(q, columns=["id", "vector"], with_row_id=False)
|
||||
|
||||
q = table_async.query().with_row_id().to_query_object()
|
||||
check_set_props(q, with_row_id=True)
|
||||
|
||||
sample_vector = [pa.array([5.0, 6.0], type=pa.float32())]
|
||||
|
||||
# Vector queries
|
||||
q = (await table_async.search([5.0, 6.0])).limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.limit(10)
|
||||
.where("id = 1")
|
||||
.postfilter()
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
filter="id = 1",
|
||||
postfilter=True,
|
||||
vector=sample_vector,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.nprobes(10)
|
||||
.refine_factor(5)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
nprobes=10,
|
||||
refine_factor=5,
|
||||
postfilter=False,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.distance_range(0.0, 1.0)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
lower_bound=0.0,
|
||||
upper_bound=1.0,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).distance_type("cosine").to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
distance_type="cosine",
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).ef(7).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
ef=7,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).bypass_vector_index().to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
bypass_vector_index=True,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# FTS queries
|
||||
q = (await table_async.search("foo")).limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
q = (await table_async.search("foo", query_type="fts")).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
@@ -315,7 +315,7 @@ def test_query_sync_minimal():
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"prefilter": True,
|
||||
"refine_factor": None,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
@@ -340,7 +340,7 @@ def test_query_sync_empty_query():
|
||||
"filter": "true",
|
||||
"vector": [],
|
||||
"columns": ["id"],
|
||||
"prefilter": False,
|
||||
"prefilter": True,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
@@ -478,7 +478,7 @@ def test_query_sync_hybrid():
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 42,
|
||||
"prefilter": False,
|
||||
"prefilter": True,
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
|
||||
Reference in New Issue
Block a user