feat: add to_query_object method (#2239)

This PR adds a `to_query_object` method to the various query builders
(hybrid queries are not supported yet). This makes it possible to inspect
the query that is built.

In addition this PR does some normalization between the sync and async
query paths. A few custom defaults were removed in favor of None (with
the default getting set once, in rust).

Also, the synchronous to_batches method will now actually stream results

Also, the remote API now defaults to prefiltering
This commit is contained in:
Weston Pace
2025-03-21 13:01:51 -07:00
committed by GitHub
parent b2a38ac366
commit 9403254442
8 changed files with 867 additions and 177 deletions

View File

@@ -94,6 +94,7 @@ class Query:
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class FTSQuery:
def where(self, filter: str): ...
@@ -108,6 +109,7 @@ class FTSQuery:
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
async def explain_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...
@@ -123,6 +125,7 @@ class VectorQuery:
def nprobes(self, nprobes: int): ...
def bypass_vector_index(self): ...
def nearest_to_text(self, query: dict) -> HybridQuery: ...
def to_query_request(self) -> PyQueryRequest: ...
class HybridQuery:
def where(self, filter: str): ...
@@ -140,6 +143,33 @@ class HybridQuery:
def to_fts_query(self) -> FTSQuery: ...
def get_limit(self) -> int: ...
def get_with_row_id(self) -> bool: ...
def to_query_request(self) -> PyQueryRequest: ...
class PyFullTextSearchQuery:
columns: Optional[List[str]]
query: str
limit: Optional[int]
wand_factor: Optional[float]
class PyQueryRequest:
limit: Optional[int]
offset: Optional[int]
filter: Optional[Union[str, bytes]]
full_text_search: Optional[PyFullTextSearchQuery]
select: Optional[Union[str, List[str]]]
fast_search: Optional[bool]
with_row_id: Optional[bool]
column: Optional[str]
query_vector: Optional[List[pa.Array]]
nprobes: Optional[int]
lower_bound: Optional[float]
upper_bound: Optional[float]
ef: Optional[int]
refine_factor: Optional[int]
distance_type: Optional[str]
bypass_vector_index: Optional[bool]
postfilter: Optional[bool]
norm: Optional[str]
class CompactionStats:
fragments_removed: int

View File

@@ -14,6 +14,7 @@ from typing import (
Tuple,
Type,
Union,
Any,
)
import asyncio
@@ -32,6 +33,8 @@ from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import flatten_columns
from typing_extensions import Annotated
if TYPE_CHECKING:
import sys
import PIL
@@ -41,6 +44,7 @@ if TYPE_CHECKING:
from ._lancedb import FTSQuery as LanceFTSQuery
from ._lancedb import HybridQuery as LanceHybridQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from ._lancedb import PyQueryRequest
from .common import VEC
from .pydantic import LanceModel
from .table import Table
@@ -51,33 +55,116 @@ if TYPE_CHECKING:
from typing_extensions import Self
class Query(pydantic.BaseModel):
"""The LanceDB Query
# Pydantic validation function for vector queries
def ensure_vector_query(
    val: Any,
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
    """Check that ``val`` has the shape of a vector query and return it unchanged.

    Accepted shapes are a list of floats, a list of lists of floats, a
    ``pa.Array`` or a list of ``pa.Array``.

    Parameters
    ----------
    val : Any
        The candidate vector query.

    Returns
    -------
    The same object that was passed in when it matches one of the accepted
    shapes. Values that match none of the shapes (including ``None``) yield
    ``None``.

    Raises
    ------
    ValueError
        If ``val`` is a bare float, an empty list, or a list whose first
        element is an empty list.
    """
    # NOTE(review): this function is attached to `Query.vector` as bare
    # `Annotated` metadata -- confirm pydantic actually invokes it as a validator.
    if isinstance(val, list):
        if not val:
            # Raise (not return) so the error is not mistaken for a valid value.
            raise ValueError("Vector query must be a non-empty list")
        # Only the first element is inspected to classify the whole value.
        sample = val[0]
    else:
        if isinstance(val, float):
            raise ValueError(
                "Vector query must be a list of floats or a list of lists of floats"
            )
        sample = val

    if isinstance(sample, float):
        # val is a list of floats
        return val
    if isinstance(sample, list):
        if not sample:
            raise ValueError("Vector query must be a non-empty list")
        if isinstance(sample[0], float):
            # val is a list of lists of floats
            return val
    elif isinstance(sample, pa.Array):
        # val is an array or a list of arrays
        return val

    # NOTE(review): unrecognised shapes (e.g. a list of ints) fall through and
    # produce None, exactly as before -- confirm whether they should be rejected.
    return None
class FullTextSearchQuery(pydantic.BaseModel):
"""A LanceDB Full Text Search Query
Attributes
----------
vector : List[float]
the vector to search for
filter : Optional[str]
sql filter to refine the query with, optional
prefilter : bool
if True then apply the filter before vector search
k : int
top k results to return
metric : str
the distance metric between a pair of vectors,
columns: List[str]
The columns to search
can support l2 (default), Cosine and Dot.
[metric definitions][search]
columns : Optional[List[str]]
If None, then the table should select the column automatically.
query: str
The query to search for
limit: Optional[int] = None
The limit on the number of results to return
wand_factor: Optional[float] = None
The wand factor to use for the search
"""
columns: Optional[List[str]] = None
query: str
limit: Optional[int] = None
wand_factor: Optional[float] = None
class Query(pydantic.BaseModel):
"""A LanceDB Query
Queries are constructed by the `Table.search` and `Table.query` methods. This
class is a python representation of the query. Normally you will not need to
interact with this class directly. You can build up a query and execute it using
collection methods such as `to_batches()`, `to_arrow()`, `to_pandas()`, etc.
However, you can use the `to_query_object()` method to get the underlying query object.
This can be useful for serializing a query or using it in a different context.
Attributes
----------
filter : Optional[str]
sql filter to refine the query with
limit : Optional[int]
The limit on the number of results to return. If this is a vector or FTS query,
then this is required. If this is a plain SQL query, then this is optional.
offset: Optional[int]
The offset to start fetching results from
This is ignored for vector / FTS search (will be None).
columns : Optional[Union[List[str], Dict[str, str]]]
which columns to return in the results
nprobes : int
The number of probes used - optional
This can be a list of column names or a dictionary. If it is a dictionary,
then the keys are the column names and the values are sql expressions to
use to calculate the result.
If this is None then all columns are returned. This can be expensive.
with_row_id : Optional[bool]
if True then include the row id in the results
vector : Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]]
the vector to search for, if this is a vector search or hybrid search. It will
be None for full text search and plain SQL filtering.
vector_column : Optional[str]
the name of the vector column to use for vector search
If this is None then a default vector column will be used.
distance_type : Optional[str]
the distance type to use for vector search
This can be l2 (default), cosine and dot. See [metric definitions][search] for
more details.
If this is not a vector search this will be None.
postfilter : bool
if True then apply the filter after vector / FTS search. This is ignored for
plain SQL filtering.
nprobes : Optional[int]
The number of IVF partitions to search. If this is None then a default
number of partitions will be used.
- A higher number makes search more accurate but also slower.
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
Will be None if this is not a vector search.
refine_factor : Optional[int]
Refine the results by reading extra elements and re-ranking them in memory.
@@ -85,58 +172,130 @@ class Query(pydantic.BaseModel):
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
offset: int
The offset to start fetching results from
fast_search: bool
Will be None if this is not a vector search.
lower_bound : Optional[float]
The lower bound for distance search
Only results with a distance greater than or equal to this value
will be returned.
This will only be set on vector search.
upper_bound : Optional[float]
The upper bound for distance search
Only results with a distance less than or equal to this value
will be returned.
This will only be set on vector search.
ef : Optional[int]
The size of the nearest neighbor list maintained during HNSW search
This will only be set on vector search.
full_text_query : Optional[Union[str, dict]]
The full text search query
This can be a string or a dictionary. A dictionary will be used to search
multiple columns. The keys are the column names and the values are the
search queries.
This will only be set on FTS or hybrid queries.
fast_search: Optional[bool]
Skip a flat search of unindexed data. This will improve
search performance but search results will not include unindexed data.
- *default False*.
The default is False
"""
# The name of the vector column to use for vector search.
vector_column: Optional[str] = None
# vector to search for
vector: Union[List[float], List[List[float]]]
#
# Note: today this will be floats on the sync path and pa.Array on the async
# path though in the future we should unify this to pa.Array everywhere
vector: Annotated[
Optional[Union[List[float], List[List[float]], pa.Array, List[pa.Array]]],
ensure_vector_query,
] = None
# sql filter to refine the query with
filter: Optional[str] = None
# if True then apply the filter before vector search
prefilter: bool = False
# if True then apply the filter after vector search
postfilter: Optional[bool] = None
# full text search query
full_text_query: Optional[Union[str, dict]] = None
full_text_query: Optional[FullTextSearchQuery] = None
# top k results to return
k: Optional[int] = None
limit: Optional[int] = None
# # metrics
metric: str = "l2"
# distance type to use for vector search
distance_type: Optional[str] = None
# which columns to return in the results
columns: Optional[Union[List[str], Dict[str, str]]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
# number of IVF partitions to search
nprobes: Optional[int] = None
# lower bound for distance search
lower_bound: Optional[float] = None
# upper bound for distance search
upper_bound: Optional[float] = None
# Refine factor.
# multiplier for the number of results to inspect for reranking
refine_factor: Optional[int] = None
with_row_id: bool = False
# if true, include the row id in the results
with_row_id: Optional[bool] = None
offset: int = 0
# offset to start fetching results from
offset: Optional[int] = None
fast_search: bool = False
# if true, will only search the indexed data
fast_search: Optional[bool] = None
# size of the nearest neighbor list maintained during HNSW search
ef: Optional[int] = None
# Default is true. Set to false to enforce a brute force search.
use_index: bool = True
# Bypass the vector index and use a brute force search
bypass_vector_index: Optional[bool] = None
@classmethod
def from_inner(cls, req: PyQueryRequest) -> Self:
    """Build a query object from a native ``PyQueryRequest``.

    Each field of ``req`` is copied onto the matching attribute of a freshly
    constructed instance (``select`` -> ``columns``, ``column`` ->
    ``vector_column``, ``query_vector`` -> ``vector``).

    Parameters
    ----------
    req : PyQueryRequest
        The request object produced by the native query's
        ``to_query_request()``.

    Returns
    -------
    Self
        A new instance populated from ``req``.
    """
    query = cls()
    query.limit = req.limit
    query.offset = req.offset
    query.filter = req.filter
    query.columns = req.select
    query.with_row_id = req.with_row_id
    # Previously dropped: without this a fast_search setting was lost when
    # converting the native request back into a Query.
    query.fast_search = req.fast_search
    query.vector_column = req.column
    query.vector = req.query_vector
    query.distance_type = req.distance_type
    query.nprobes = req.nprobes
    query.lower_bound = req.lower_bound
    query.upper_bound = req.upper_bound
    query.ef = req.ef
    query.refine_factor = req.refine_factor
    query.bypass_vector_index = req.bypass_vector_index
    query.postfilter = req.postfilter

    # Convert the native FTS payload into the pydantic model instead of
    # storing the native object directly.
    fts = req.full_text_search
    if fts is not None:
        query.full_text_query = FullTextSearchQuery(
            columns=fts.columns,
            query=fts.query,
            limit=fts.limit,
            wand_factor=fts.wand_factor,
        )
    return query
class Config:
# This tells pydantic to allow custom types (needed for the `vector` field, since
# pa.Array wouldn't be allowed otherwise)
arbitrary_types_allowed = True
class LanceQueryBuilder(ABC):
@@ -152,8 +311,8 @@ class LanceQueryBuilder(ABC):
query_type: str,
vector_column_name: str,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
fast_search: bool = False,
fts_columns: Optional[Union[str, List[str]]] = None,
fast_search: bool = None,
) -> Self:
"""
Create a query builder based on the given query and query type.
@@ -257,15 +416,15 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"):
self._table = table
self._limit = None
self._offset = 0
self._offset = None
self._columns = None
self._where = None
self._prefilter = True
self._with_row_id = False
self._postfilter = None
self._with_row_id = None
self._vector = None
self._text = None
self._ef = None
self._use_index = True
self._bypass_vector_index = None
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -315,7 +474,7 @@ class LanceQueryBuilder(ABC):
raise NotImplementedError
@abstractmethod
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table:
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
"""
Execute the query and return the results as a pyarrow
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
@@ -451,7 +610,7 @@ class LanceQueryBuilder(ABC):
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
self._postfilter = not prefilter
return self
def with_row_id(self, with_row_id: bool) -> Self:
@@ -497,23 +656,7 @@ class LanceQueryBuilder(ABC):
-------
plan : str
""" # noqa: E501
ds = self._table.to_lance()
return ds.scanner(
nearest={
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._distance_type,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
"use_index": self._use_index,
},
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose)
return self._table._explain_plan(self.to_query_object())
def vector(self, vector: Union[np.ndarray, list]) -> Self:
"""Set the vector to search for.
@@ -561,6 +704,17 @@ class LanceQueryBuilder(ABC):
"""
raise NotImplementedError
@abstractmethod
def to_query_object(self) -> Query:
"""Return a serializable representation of the query
Returns
-------
Query
The serializable representation of the query
"""
raise NotImplementedError
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
@@ -590,19 +744,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str,
str_query: Optional[str] = None,
fast_search: bool = False,
fast_search: bool = None,
):
super().__init__(table)
if self._limit is None:
self._limit = 10
self._query = query
self._distance_type = "l2"
self._nprobes = 20
self._distance_type = None
self._nprobes = None
self._lower_bound = None
self._upper_bound = None
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
self._postfilter = None
self._reranker = None
self._str_query = str_query
self._fast_search = fast_search
@@ -752,6 +904,34 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
return self.to_batches().read_all()
def to_query_object(self) -> Query:
"""
Build a Query object from the current state of this builder

This can be used to inspect or serialize the query. The query vector is
converted to plain Python lists first. The reranker set on this builder is
not part of the returned object.

Returns
-------
Query
    A Query populated from this builder's settings.
"""
# A non-list query is converted with `.tolist()`; a list is used as-is.
vector = self._query if isinstance(self._query, list) else self._query.tolist()
# A list whose first element is an ndarray is treated as a list of ndarrays.
# NOTE(review): `vector[0]` raises IndexError when the query is an empty list.
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
return Query(
vector=vector,
filter=self._where,
postfilter=self._postfilter,
limit=self._limit,
distance_type=self._distance_type,
columns=self._columns,
nprobes=self._nprobes,
lower_bound=self._lower_bound,
upper_bound=self._upper_bound,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
bypass_vector_index=self._bypass_vector_index,
)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
"""
Execute the query and return the result as a RecordBatchReader object.
@@ -768,24 +948,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
vector = self._query if isinstance(self._query, list) else self._query.tolist()
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
query = Query(
vector=vector,
filter=self._where,
prefilter=self._prefilter,
k=self._limit,
metric=self._distance_type,
columns=self._columns,
nprobes=self._nprobes,
lower_bound=self._lower_bound,
upper_bound=self._upper_bound,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
use_index=self._use_index,
)
query = self.to_query_object()
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
rs_table = result_set.read_all()
@@ -798,7 +961,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
return result_set
def where(self, where: str, prefilter: bool = True) -> LanceVectorQueryBuilder:
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
@@ -810,8 +973,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
prefilter: bool, default True
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future.
Returns
-------
@@ -819,7 +980,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
if prefilter is not None:
self._postfilter = not prefilter
return self
def rerank(
@@ -873,7 +1035,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder
The LanceVectorQueryBuilder object.
"""
self._use_index = False
self._bypass_vector_index = True
return self
@@ -885,11 +1047,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
table: "Table",
query: str,
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
fts_columns: Optional[Union[str, List[str]]] = None,
):
super().__init__(table)
if self._limit is None:
self._limit = 10
self._query = query
self._phrase_query = False
self.ordering_field_name = ordering_field_name
@@ -915,6 +1075,19 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self._phrase_query = phrase_query
return self
def to_query_object(self) -> Query:
"""
Build a Query object describing this full text search

The text query and the FTS columns are wrapped in a FullTextSearchQuery;
no vector is set.

Returns
-------
Query
    A Query populated from this builder's settings.
"""
# NOTE(review): the raw `self._query` is used here; `self._phrase_query` and
# `self.ordering_field_name` are not reflected in the returned object.
return Query(
columns=self._columns,
filter=self._where,
limit=self._limit,
postfilter=self._postfilter,
with_row_id=self._with_row_id,
full_text_query=FullTextSearchQuery(
query=self._query, columns=self._fts_columns
),
offset=self._offset,
)
def to_arrow(self) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path()
if exist:
@@ -926,19 +1099,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Phrase query is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
query = Query(
columns=self._columns,
filter=self._where,
k=self._limit,
prefilter=self._prefilter,
with_row_id=self._with_row_id,
full_text_query={
"query": query,
"columns": self._fts_columns,
},
vector=[],
offset=self._offset,
)
query = self.to_query_object()
results = self._table._execute_query(query)
results = results.read_all()
if self._reranker is not None:
@@ -983,8 +1144,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
if self._phrase_query:
query = query.replace('"', "'")
query = f'"{query}"'
limit = self._limit if self._limit is not None else 10
row_ids, scores = search_index(
index, query, self._limit, ordering_field=self.ordering_field_name
index, query, limit, ordering_field=self.ordering_field_name
)
if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("_score", pa.float32())])
@@ -1053,17 +1215,18 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
return self.to_batches().read_all()
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
query = Query(
def to_query_object(self) -> Query:
return Query(
columns=self._columns,
filter=self._where,
k=self._limit,
limit=self._limit,
with_row_id=self._with_row_id,
vector=[],
# not actually respected in remote query
offset=self._offset or 0,
offset=self._offset,
)
return self._table._execute_query(query)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
query = self.to_query_object()
return self._table._execute_query(query, batch_size)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -1098,18 +1261,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
table: "Table",
query: Optional[str] = None,
vector_column: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
fts_columns: Optional[Union[str, List[str]]] = None,
):
super().__init__(table)
self._query = query
self._vector_column = vector_column
self._fts_columns = fts_columns
self._norm = "score"
self._reranker = RRFReranker()
self._norm = None
self._reranker = None
self._nprobes = None
self._refine_factor = None
self._distance_type = None
self._phrase_query = False
self._phrase_query = None
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
@@ -1131,7 +1294,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return vector_query, text_query
def phrase_query(self, phrase_query: bool = True) -> LanceHybridQueryBuilder:
def phrase_query(self, phrase_query: bool = None) -> LanceHybridQueryBuilder:
"""Set whether to use phrase query.
Parameters
@@ -1148,6 +1311,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._phrase_query = phrase_query
return self
def to_query_object(self) -> Query:
raise NotImplementedError("to_query_object not yet supported on a hybrid query")
def to_arrow(self) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
@@ -1169,8 +1335,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, self._prefilter)
self._fts_query.where(self._where, self._prefilter)
self._vector_query.where(self._where, self._postfilter)
self._fts_query.where(self._where, self._postfilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)
@@ -1184,9 +1350,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
if not self._use_index:
if self._bypass_vector_index:
self._vector_query.bypass_vector_index()
if self._reranker is None:
self._reranker = RRFReranker()
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
vector_future = executor.submit(
@@ -1502,12 +1671,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._use_index = False
self._bypass_vector_index = True
return self
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
"""
Construct an AsyncQueryBase
@@ -1516,6 +1685,9 @@ class AsyncQueryBase(object):
"""
self._inner = inner
def to_query_object(self) -> Query:
return Query.from_inner(self._inner.to_query_request())
def where(self, predicate: str) -> Self:
"""
Only return rows matching the given predicate
@@ -1868,7 +2040,7 @@ class AsyncQuery(AsyncQueryBase):
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []
self, query: str, columns: Union[str, List[str], None] = None
) -> AsyncFTSQuery:
"""
Find the documents that are most relevant to the given text query.
@@ -1892,6 +2064,8 @@ class AsyncQuery(AsyncQueryBase):
"""
if isinstance(columns, str):
columns = [columns]
if columns is None:
columns = []
return AsyncFTSQuery(
self._inner.nearest_to_text({"query": query, "columns": columns})
)
@@ -2177,7 +2351,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
return self
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []
self, query: str, columns: Union[str, List[str], None] = None
) -> AsyncHybridQuery:
"""
Find the documents that are most relevant to the given text query,
@@ -2205,6 +2379,8 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
"""
if isinstance(columns, str):
columns = [columns]
if columns is None:
columns = []
return AsyncHybridQuery(
self._inner.nearest_to_text({"query": query, "columns": columns})
)

View File

@@ -282,7 +282,8 @@ class RemoteTable(Table):
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
All query options are defined in [Query][lancedb.query.Query].
All query options are defined in
[LanceVectorQueryBuilder][lancedb.query.LanceVectorQueryBuilder].
Examples
--------
@@ -353,7 +354,16 @@ class RemoteTable(Table):
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
return LOOP.run(self._table._execute_query(query, batch_size=batch_size))
async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
def iter_sync():
try:
while True:
yield LOOP.run(async_iter.__anext__())
except StopAsyncIteration:
return
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]

View File

@@ -101,7 +101,9 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
schema = data.features.arrow_schema
return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
elif isinstance(data, datasets.dataset_dict.DatasetDict):
schema = _schema_from_hf(data, schema)
schema = _schema_from_hf(data, None)
if "split" not in schema.names:
schema = schema.append(pa.field("split", pa.string()))
return pa.RecordBatchReader.from_batches(
schema, _to_batches_with_split(data)
)
@@ -415,7 +417,7 @@ def sanitize_create_table(
return data, schema
def _schema_from_hf(data, schema):
def _schema_from_hf(data, schema) -> pa.Schema:
"""
Extract pyarrow schema from HuggingFace DatasetDict
and validate that they're all the same schema between
@@ -927,7 +929,8 @@ class Table(ABC):
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
All query options are defined in [Query][lancedb.query.Query].
All query options are defined in
[LanceQueryBuilder][lancedb.query.LanceQueryBuilder].
Examples
--------
@@ -2278,7 +2281,19 @@ class LanceTable(Table):
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
return LOOP.run(self._table._execute_query(query, batch_size))
async_iter = LOOP.run(self._table._execute_query(query, batch_size))
def iter_sync():
try:
while True:
yield LOOP.run(async_iter.__anext__())
except StopAsyncIteration:
return
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
def _explain_plan(self, query: Query) -> str:
return LOOP.run(self._table._explain_plan(query))
def _do_merge(
self,
@@ -3053,7 +3068,7 @@ class AsyncTable:
query_type: Literal["auto"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...
) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ...
@overload
async def search(
@@ -3102,7 +3117,7 @@ class AsyncTable:
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
@@ -3264,12 +3279,12 @@ class AsyncTable:
builder = builder.column(vector_column_name)
return builder
elif query_type == "fts":
return self.query().nearest_to_text(query, columns=fts_columns or [])
return self.query().nearest_to_text(query, columns=fts_columns)
elif query_type == "hybrid":
builder = self.query().nearest_to(vector_query)
if vector_column_name:
builder = builder.column(vector_column_name)
return builder.nearest_to_text(query, columns=fts_columns or [])
return builder.nearest_to_text(query, columns=fts_columns)
else:
raise ValueError(f"Unknown query type: '{query_type}'")
@@ -3286,16 +3301,13 @@ class AsyncTable:
"""
return self.query().nearest_to(query_vector)
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
# The sync remote table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
# used for that code path right now.
def _sync_query_to_async(
self, query: Query
) -> AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery | AsyncQuery:
async_query = self.query()
if query.k is not None:
async_query = async_query.limit(query.k)
if query.offset > 0:
if query.limit is not None:
async_query = async_query.limit(query.limit)
if query.offset is not None:
async_query = async_query.offset(query.offset)
if query.columns:
async_query = async_query.select(query.columns)
@@ -3307,35 +3319,49 @@ class AsyncTable:
async_query = async_query.with_row_id()
if query.vector:
# we need the schema to get the vector column type
# to determine whether the vectors is batch queries or not
async_query = (
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
.distance_range(query.lower_bound, query.upper_bound)
async_query = async_query.nearest_to(query.vector).distance_range(
query.lower_bound, query.upper_bound
)
if query.refine_factor:
if query.distance_type is not None:
async_query = async_query.distance_type(query.distance_type)
if query.nprobes is not None:
async_query = async_query.nprobes(query.nprobes)
if query.refine_factor is not None:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if query.ef:
async_query = async_query.ef(query.ef)
if not query.use_index:
if query.bypass_vector_index:
async_query = async_query.bypass_vector_index()
if not query.prefilter:
if query.postfilter:
async_query = async_query.postfilter()
if isinstance(query.full_text_query, str):
async_query = async_query.nearest_to_text(query.full_text_query)
elif isinstance(query.full_text_query, dict):
fts_query = query.full_text_query["query"]
fts_columns = query.full_text_query.get("columns", []) or []
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
if query.full_text_query:
async_query = async_query.nearest_to_text(
query.full_text_query.query, query.full_text_query.columns
)
if query.full_text_query.limit is not None:
async_query = async_query.limit(query.full_text_query.limit)
table = await async_query.to_arrow()
return table.to_reader()
return async_query
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
# The sync table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
# used for that code path right now.
#
# `batch_size` is forwarded as `max_batch_length` to the async query.
# NOTE(review): the value returned is whatever `to_batches` yields; the sync
# callers consume it via `.schema` and `__anext__`, so the declared
# `pa.RecordBatchReader` return type looks inaccurate -- confirm.
async_query = self._sync_query_to_async(query)
return await async_query.to_batches(max_batch_length=batch_size)
async def _explain_plan(self, query: Query) -> str:
# This method is used by the sync table. It maps the sync Query onto the
# equivalent async query and returns that query's plan.
# NOTE(review): `explain_plan()` is called with no arguments, so no verbosity
# option is forwarded from the caller.
async_query = self._sync_query_to_async(query)
return await async_query.explain_plan()
async def _do_merge(
self,

View File

@@ -26,10 +26,12 @@ from lancedb.query import (
AsyncVectorQuery,
LanceVectorQueryBuilder,
Query,
FullTextSearchQuery,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
from importlib.util import find_spec
@pytest.fixture(scope="module")
@@ -392,12 +394,28 @@ def test_query_builder_batches(table):
for item in rs:
rs_list.append(item)
assert isinstance(item, pa.RecordBatch)
assert len(rs_list) == 1
assert len(rs_list[0]["id"]) == 2
assert len(rs_list) == 2
assert len(rs_list[0]["id"]) == 1
assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
assert rs_list[0].to_pandas()["id"][0] == 1
assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
assert rs_list[0].to_pandas()["id"][1] == 2
assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0])
assert rs_list[1].to_pandas()["id"][0] == 2
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(2)
.select(["id", "vector"])
.to_batches(2)
)
rs_list = []
for item in rs:
rs_list.append(item)
assert isinstance(item, pa.RecordBatch)
assert len(rs_list) == 1
assert len(rs_list[0]["id"]) == 2
rs_list = rs_list[0].to_pandas()
assert rs_list["id"][0] == 1
assert rs_list["id"][1] == 2
def test_dynamic_projection(table):
@@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column():
Query(
vector=query,
filter="b < 10",
prefilter=True,
k=2,
metric="cosine",
limit=2,
distance_type="cosine",
columns=["b"],
nprobes=20,
refine_factor=None,
vector_column="foo_vector",
),
None,
@@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable):
@pytest.mark.asyncio
@pytest.mark.slow
async def test_query_reranked_async(table_async: AsyncTable):
# CrossEncoderReranker requires torch
if find_spec("torch") is None:
pytest.skip("torch not installed")
# FTS with rerank
await table_async.create_index("text", config=FTS(with_position=False))
await check_query(
@@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection):
assert "No embedding functions are registered for any columns" in exception_output(
e
)
# Shared helper for the serialization tests below: compares the properties of
# the plain python query object `q` against the expectations given in kwargs.
def check_set_props(q, **kwargs):
    """Assert that `q` carries exactly the expected property values.

    The property names are taken from ``dict(q)``; names starting with an
    underscore are skipped. A property named in ``kwargs`` must compare equal
    to the given value, and every other property must be ``None``.

    Parameters
    ----------
    q:
        The query object to inspect. It must be convertible with ``dict()``
        and expose each listed name as an attribute.
    **kwargs:
        Expected value for each property that should be set.

    Raises
    ------
    AssertionError
        If a property differs from its expected value, or a property that
        was not listed is not ``None``.
    """
    for name in dict(q):
        if name.startswith("_"):
            continue
        actual = getattr(q, name)
        if name not in kwargs:
            assert actual is None, f"{name} should be None"
            continue
        expected = kwargs[name]
        assert expected == actual, f"{name} should be {expected} but is {actual}"
def test_query_serialization_sync(table: lancedb.table.Table):
    """Check the query objects produced by the sync query builders.

    Each case builds a query, serializes it with ``to_query_object`` and
    hands the result to ``check_set_props`` together with the properties
    that are expected to be set.
    """

    def vector_search():
        # Fresh builder for the same two-element query vector each time.
        return table.search([5.0, 6.0])

    def vector_props(**extra):
        # Properties expected on every vector query below, plus extras.
        return dict(vector_column="vector", vector=[5.0, 6.0], **extra)

    def fts_foo():
        # Expected full text query for a search on the string "foo".
        return FullTextSearchQuery(columns=[], query="foo")

    # Plain queries (no vector, no text)
    plain = table.search().where("id = 1").limit(500).offset(10)
    check_set_props(
        plain.to_query_object(), limit=500, offset=10, filter="id = 1"
    )

    projected = table.search().select(["id", "vector"])
    check_set_props(projected.to_query_object(), columns=["id", "vector"])

    with_ids = table.search().with_row_id(True)
    check_set_props(with_ids.to_query_object(), with_row_id=True)

    # Vector queries
    serialized = vector_search().limit(10).to_query_object()
    check_set_props(serialized, **vector_props(limit=10))

    serialized = vector_search().to_query_object()
    check_set_props(serialized, **vector_props())

    postfiltered = vector_search().limit(10).where("id = 1", prefilter=False)
    check_set_props(
        postfiltered.to_query_object(),
        **vector_props(limit=10, filter="id = 1", postfilter=True),
    )

    serialized = vector_search().nprobes(10).refine_factor(5).to_query_object()
    check_set_props(serialized, **vector_props(nprobes=10, refine_factor=5))

    serialized = vector_search().distance_range(0.0, 1.0).to_query_object()
    check_set_props(
        serialized, **vector_props(lower_bound=0.0, upper_bound=1.0)
    )

    serialized = vector_search().distance_type("cosine").to_query_object()
    check_set_props(serialized, **vector_props(distance_type="cosine"))

    serialized = vector_search().ef(7).to_query_object()
    check_set_props(serialized, **vector_props(ef=7))

    serialized = vector_search().bypass_vector_index().to_query_object()
    check_set_props(serialized, **vector_props(bypass_vector_index=True))

    # FTS queries
    serialized = table.search("foo").limit(10).to_query_object()
    check_set_props(serialized, limit=10, full_text_query=fts_foo())

    serialized = table.search("foo", query_type="fts").to_query_object()
    check_set_props(serialized, full_text_query=fts_foo())
@pytest.mark.asyncio
async def test_query_serialization_async(table_async: AsyncTable):
    """Check the query objects produced by the async query builders.

    Each case builds a query, serializes it with ``to_query_object`` and
    hands the result to ``check_set_props`` together with the properties
    that are expected to be set.
    """
    # Plain queries (no vector, no text)
    plain = table_async.query().where("id = 1").limit(500).offset(10)
    check_set_props(
        plain.to_query_object(),
        limit=500,
        offset=10,
        filter="id = 1",
        with_row_id=False,
    )

    projected = table_async.query().select(["id", "vector"])
    check_set_props(
        projected.to_query_object(), columns=["id", "vector"], with_row_id=False
    )

    with_ids = table_async.query().with_row_id()
    check_set_props(with_ids.to_query_object(), with_row_id=True)

    sample_vector = [pa.array([5.0, 6.0], type=pa.float32())]

    async def vector_search():
        # Fresh builder for the same two-element query vector each time.
        return await table_async.search([5.0, 6.0])

    def vector_props(**overrides):
        # Values the vector cases below expect unless a case overrides them.
        props = dict(
            vector=sample_vector,
            postfilter=False,
            nprobes=20,
            with_row_id=False,
            bypass_vector_index=False,
            limit=10,
        )
        props.update(overrides)
        return props

    # Vector queries
    serialized = (await vector_search()).limit(10).to_query_object()
    check_set_props(serialized, **vector_props())

    serialized = (await vector_search()).to_query_object()
    check_set_props(serialized, **vector_props())

    postfiltered = (await vector_search()).limit(10).where("id = 1").postfilter()
    check_set_props(
        postfiltered.to_query_object(),
        **vector_props(filter="id = 1", postfilter=True),
    )

    tuned = (await vector_search()).nprobes(10).refine_factor(5)
    check_set_props(
        tuned.to_query_object(), **vector_props(nprobes=10, refine_factor=5)
    )

    ranged = (await vector_search()).distance_range(0.0, 1.0)
    check_set_props(
        ranged.to_query_object(),
        **vector_props(lower_bound=0.0, upper_bound=1.0),
    )

    serialized = (await vector_search()).distance_type("cosine").to_query_object()
    check_set_props(serialized, **vector_props(distance_type="cosine"))

    serialized = (await vector_search()).ef(7).to_query_object()
    check_set_props(serialized, **vector_props(ef=7))

    serialized = (await vector_search()).bypass_vector_index().to_query_object()
    check_set_props(serialized, **vector_props(bypass_vector_index=True))

    # FTS queries
    serialized = (await table_async.search("foo")).limit(10).to_query_object()
    check_set_props(
        serialized,
        limit=10,
        full_text_query=FullTextSearchQuery(columns=[], query="foo"),
        with_row_id=False,
    )

    text_query = await table_async.search("foo", query_type="fts")
    check_set_props(
        text_query.to_query_object(),
        full_text_query=FullTextSearchQuery(columns=[], query="foo"),
        with_row_id=False,
    )

View File

@@ -315,7 +315,7 @@ def test_query_sync_minimal():
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"prefilter": True,
"refine_factor": None,
"lower_bound": None,
"upper_bound": None,
@@ -340,7 +340,7 @@ def test_query_sync_empty_query():
"filter": "true",
"vector": [],
"columns": ["id"],
"prefilter": False,
"prefilter": True,
"version": None,
}
@@ -478,7 +478,7 @@ def test_query_sync_hybrid():
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": False,
"prefilter": True,
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,