mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 10:22:56 +00:00
feat: add timeout to query execution options (#2288)
Closes #2287 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added configurable timeout support for query executions. Users can now specify maximum wait times for queries, enhancing control over long-running operations across various integrations. - **Tests** - Expanded test coverage to validate timeout behavior in both synchronous and asynchronous query flows, ensuring timely error responses when query execution exceeds the specified limit. - Introduced a new test suite to verify query operations when a timeout is reached, checking for appropriate error handling. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, Literal
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -94,7 +95,9 @@ class Query:
|
||||
def postfilter(self): ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> FTSQuery: ...
|
||||
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
async def explain_plan(self, verbose: Optional[bool]) -> str: ...
|
||||
async def analyze_plan(self) -> str: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
@@ -110,7 +113,9 @@ class FTSQuery:
|
||||
def get_query(self) -> str: ...
|
||||
def add_query_vector(self, query_vec: pa.Array) -> None: ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
|
||||
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class VectorQuery:
|
||||
|
||||
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
||||
import abc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from datetime import timedelta
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
@@ -650,7 +651,12 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
return self.to_pandas()
|
||||
|
||||
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
|
||||
def to_pandas(
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
*,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
@@ -664,12 +670,15 @@ class LanceQueryBuilder(ABC):
|
||||
If flatten is an integer, flatten the nested columns up to the
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
tbl = flatten_columns(self.to_arrow(), flatten)
|
||||
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
|
||||
return tbl.to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
@@ -677,34 +686,65 @@ class LanceQueryBuilder(ABC):
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
def to_batches(
|
||||
self,
|
||||
/,
|
||||
batch_size: Optional[int] = None,
|
||||
*,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
"""
|
||||
Execute the query and return the results as a pyarrow
|
||||
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
The maximum number of selected records in a RecordBatch object.
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_list(self) -> List[dict]:
|
||||
def to_list(self, *, timeout: Optional[timedelta] = None) -> List[dict]:
|
||||
"""
|
||||
Execute the query and return the results as a list of dictionaries.
|
||||
|
||||
Each list entry is a dictionary with the selected column names as keys,
|
||||
or all table columns if `select` is not called. The vector and the "_distance"
|
||||
fields are returned whether or not they're explicitly selected.
|
||||
"""
|
||||
return self.to_arrow().to_pylist()
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
return self.to_arrow(timeout=timeout).to_pylist()
|
||||
|
||||
def to_pydantic(
|
||||
self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
|
||||
) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -712,19 +752,25 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
for row in self.to_arrow(timeout=timeout).to_pylist()
|
||||
]
|
||||
|
||||
def to_polars(self) -> "pl.DataFrame":
|
||||
def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a Polars DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
import polars as pl
|
||||
|
||||
return pl.from_arrow(self.to_arrow())
|
||||
return pl.from_arrow(self.to_arrow(timeout=timeout))
|
||||
|
||||
def limit(self, limit: Union[int, None]) -> Self:
|
||||
"""Set the maximum number of results to return.
|
||||
@@ -1139,7 +1185,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
@@ -1147,8 +1193,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
"""
|
||||
return self.to_batches().read_all()
|
||||
return self.to_batches(timeout=timeout).read_all()
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
"""
|
||||
@@ -1178,7 +1230,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
bypass_vector_index=self._bypass_vector_index,
|
||||
)
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
def to_batches(
|
||||
self,
|
||||
/,
|
||||
batch_size: Optional[int] = None,
|
||||
*,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
"""
|
||||
Execute the query and return the result as a RecordBatchReader object.
|
||||
|
||||
@@ -1186,6 +1244,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
----------
|
||||
batch_size: int
|
||||
The maximum number of selected records in a RecordBatch object.
|
||||
timeout: timedelta, default None
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -1195,7 +1256,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
if isinstance(vector[0], np.ndarray):
|
||||
vector = [v.tolist() for v in vector]
|
||||
query = self.to_query_object()
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
result_set = self._table._execute_query(
|
||||
query, batch_size=batch_size, timeout=timeout
|
||||
)
|
||||
if self._reranker is not None:
|
||||
rs_table = result_set.read_all()
|
||||
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
|
||||
@@ -1334,7 +1397,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
return self.tantivy_to_arrow()
|
||||
@@ -1346,14 +1409,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"Use tantivy-based index instead for now."
|
||||
)
|
||||
query = self.to_query_object()
|
||||
results = self._table._execute_query(query)
|
||||
results = self._table._execute_query(query, timeout=timeout)
|
||||
results = results.read_all()
|
||||
if self._reranker is not None:
|
||||
results = self._reranker.rerank_fts(self._query, results)
|
||||
check_reranker_result(results)
|
||||
return results
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None):
|
||||
def to_batches(
|
||||
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
|
||||
):
|
||||
raise NotImplementedError("to_batches on an FTS query")
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
@@ -1458,8 +1523,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.to_batches().read_all()
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
return self.to_batches(timeout=timeout).read_all()
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
return Query(
|
||||
@@ -1470,9 +1535,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
def to_batches(
|
||||
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
query = self.to_query_object()
|
||||
return self._table._execute_query(query, batch_size)
|
||||
return self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
@@ -1560,7 +1627,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def to_query_object(self) -> Query:
|
||||
raise NotImplementedError("to_query_object not yet supported on a hybrid query")
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
vector_query, fts_query = self._validate_query(
|
||||
self._query, self._vector, self._text
|
||||
)
|
||||
@@ -1603,9 +1670,11 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._reranker = RRFReranker()
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
fts_future = executor.submit(
|
||||
self._fts_query.with_row_id(True).to_arrow, timeout=timeout
|
||||
)
|
||||
vector_future = executor.submit(
|
||||
self._vector_query.with_row_id(True).to_arrow
|
||||
self._vector_query.with_row_id(True).to_arrow, timeout=timeout
|
||||
)
|
||||
fts_results = fts_future.result()
|
||||
vector_results = vector_future.result()
|
||||
@@ -1692,7 +1761,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return results
|
||||
|
||||
def to_batches(self):
|
||||
def to_batches(
|
||||
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
|
||||
):
|
||||
raise NotImplementedError("to_batches not yet supported on a hybrid query")
|
||||
|
||||
@staticmethod
|
||||
@@ -2056,7 +2127,10 @@ class AsyncQueryBase(object):
|
||||
return self
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
self,
|
||||
*,
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
"""
|
||||
Execute the query and return the results as an Apache Arrow RecordBatchReader.
|
||||
@@ -2069,34 +2143,56 @@ class AsyncQueryBase(object):
|
||||
If not specified, a default batch length is used.
|
||||
It is possible for batches to be smaller than the provided length if the
|
||||
underlying data is stored in smaller chunks.
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
"""
|
||||
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
|
||||
return AsyncRecordBatchReader(
|
||||
await self._inner.execute(max_batch_length, timeout)
|
||||
)
|
||||
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and collect the results into an Apache Arrow Table.
|
||||
|
||||
This method will collect all results into memory before returning. If
|
||||
you expect a large number of results, you may want to use
|
||||
[to_batches][lancedb.query.AsyncQueryBase.to_batches]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
"""
|
||||
batch_iter = await self.to_batches()
|
||||
batch_iter = await self.to_batches(timeout=timeout)
|
||||
return pa.Table.from_batches(
|
||||
await batch_iter.read_all(), schema=batch_iter.schema
|
||||
)
|
||||
|
||||
async def to_list(self) -> List[dict]:
|
||||
async def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]:
|
||||
"""
|
||||
Execute the query and return the results as a list of dictionaries.
|
||||
|
||||
Each list entry is a dictionary with the selected column names as keys,
|
||||
or all table columns if `select` is not called. The vector and the "_distance"
|
||||
fields are returned whether or not they're explicitly selected.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
"""
|
||||
return (await self.to_arrow()).to_pylist()
|
||||
return (await self.to_arrow(timeout=timeout)).to_pylist()
|
||||
|
||||
async def to_pandas(
|
||||
self, flatten: Optional[Union[int, bool]] = None
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and collect the results into a pandas DataFrame.
|
||||
@@ -2125,10 +2221,19 @@ class AsyncQueryBase(object):
|
||||
If flatten is an integer, flatten the nested columns up to the
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
"""
|
||||
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
|
||||
return (
|
||||
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
).to_pandas()
|
||||
|
||||
async def to_polars(self) -> "pl.DataFrame":
|
||||
async def to_polars(
|
||||
self,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> "pl.DataFrame":
|
||||
"""
|
||||
Execute the query and collect the results into a Polars DataFrame.
|
||||
|
||||
@@ -2137,6 +2242,13 @@ class AsyncQueryBase(object):
|
||||
[to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
|
||||
polars separately.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
@@ -2152,7 +2264,7 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
import polars as pl
|
||||
|
||||
return pl.from_arrow(await self.to_arrow())
|
||||
return pl.from_arrow(await self.to_arrow(timeout=timeout))
|
||||
|
||||
async def explain_plan(self, verbose: Optional[bool] = False):
|
||||
"""Return the execution plan for this query.
|
||||
@@ -2423,9 +2535,12 @@ class AsyncFTSQuery(AsyncQueryBase):
|
||||
)
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
self,
|
||||
*,
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
reader = await super().to_batches()
|
||||
reader = await super().to_batches(timeout=timeout)
|
||||
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
|
||||
if self._reranker:
|
||||
results = self._reranker.rerank_fts(self.get_query(), results)
|
||||
@@ -2649,9 +2764,12 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
self,
|
||||
*,
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
reader = await super().to_batches()
|
||||
reader = await super().to_batches(timeout=timeout)
|
||||
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
|
||||
if self._reranker:
|
||||
results = self._reranker.rerank_vector(self._query_string, results)
|
||||
@@ -2707,7 +2825,10 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
return self
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
self,
|
||||
*,
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
|
||||
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
|
||||
@@ -2719,8 +2840,8 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
||||
vec_query.with_row_id()
|
||||
|
||||
fts_results, vector_results = await asyncio.gather(
|
||||
fts_query.to_arrow(),
|
||||
vec_query.to_arrow(),
|
||||
fts_query.to_arrow(timeout=timeout),
|
||||
vec_query.to_arrow(timeout=timeout),
|
||||
)
|
||||
|
||||
result = LanceHybridQueryBuilder._combine_hybrid_results(
|
||||
|
||||
@@ -355,9 +355,15 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
self,
|
||||
query: Query,
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
|
||||
async_iter = LOOP.run(
|
||||
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
|
||||
)
|
||||
|
||||
def iter_sync():
|
||||
try:
|
||||
|
||||
@@ -1007,7 +1007,11 @@ class Table(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
self,
|
||||
query: Query,
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader: ...
|
||||
|
||||
@abstractmethod
|
||||
@@ -2312,9 +2316,15 @@ class LanceTable(Table):
|
||||
LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
|
||||
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
self,
|
||||
query: Query,
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
async_iter = LOOP.run(self._table._execute_query(query, batch_size))
|
||||
async_iter = LOOP.run(
|
||||
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
|
||||
)
|
||||
|
||||
def iter_sync():
|
||||
try:
|
||||
@@ -3390,7 +3400,11 @@ class AsyncTable:
|
||||
return async_query
|
||||
|
||||
async def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
self,
|
||||
query: Query,
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
# The sync table calls into this method, so we need to map the
|
||||
# query to the async version of the query and run that here. This is only
|
||||
@@ -3398,7 +3412,9 @@ class AsyncTable:
|
||||
|
||||
async_query = self._sync_query_to_async(query)
|
||||
|
||||
return await async_query.to_batches(max_batch_length=batch_size)
|
||||
return await async_query.to_batches(
|
||||
max_batch_length=batch_size, timeout=timeout
|
||||
)
|
||||
|
||||
async def _explain_plan(self, query: Query, verbose: Optional[bool]) -> str:
|
||||
# This method is used by the sync table
|
||||
|
||||
@@ -511,7 +511,8 @@ def test_query_builder_with_different_vector_column():
|
||||
columns=["b"],
|
||||
vector_column="foo_vector",
|
||||
),
|
||||
None,
|
||||
batch_size=None,
|
||||
timeout=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1076,3 +1077,67 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
|
||||
def test_query_timeout(tmp_path):
|
||||
# Use local directory instead of memory:// to add a bit of latency to
|
||||
# operations so a timeout of zero will trigger exceptions.
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["a", "b"],
|
||||
"vector": pa.FixedSizeListArray.from_arrays(
|
||||
pc.random(4).cast(pa.float32()), 2
|
||||
),
|
||||
}
|
||||
)
|
||||
table = db.create_table("test", data)
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
table.search().where("text = 'a'").to_list(timeout=timedelta(0))
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
table.search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
table.search("a", query_type="fts").to_pandas(timeout=timedelta(0))
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
table.search(query_type="hybrid").vector([0.0, 0.0]).text("a").to_arrow(
|
||||
timeout=timedelta(0)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_timeout_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["a", "b"],
|
||||
"vector": pa.FixedSizeListArray.from_arrays(
|
||||
pc.random(4).cast(pa.float32()), 2
|
||||
),
|
||||
}
|
||||
)
|
||||
table = await db.create_table("test", data)
|
||||
await table.create_index("text", config=FTS())
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
await table.query().where("text != 'a'").to_list(timeout=timedelta(0))
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
await table.vector_search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
await (await table.search("a", query_type="fts")).to_pandas(
|
||||
timeout=timedelta(0)
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
await (
|
||||
table.query()
|
||||
.nearest_to_text("a")
|
||||
.nearest_to([0.0, 0.0])
|
||||
.to_list(timeout=timedelta(0))
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::Array;
|
||||
@@ -294,10 +295,11 @@ impl Query {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None))]
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
max_batch_length: Option<u32>,
|
||||
timeout: Option<Duration>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -305,6 +307,9 @@ impl Query {
|
||||
if let Some(max_batch_length) = max_batch_length {
|
||||
opts.max_batch_length = max_batch_length;
|
||||
}
|
||||
if let Some(timeout) = timeout {
|
||||
opts.timeout = Some(timeout);
|
||||
}
|
||||
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
@@ -376,10 +381,11 @@ impl FTSQuery {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None))]
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
max_batch_length: Option<u32>,
|
||||
timeout: Option<Duration>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_
|
||||
.inner
|
||||
@@ -391,6 +397,9 @@ impl FTSQuery {
|
||||
if let Some(max_batch_length) = max_batch_length {
|
||||
opts.max_batch_length = max_batch_length;
|
||||
}
|
||||
if let Some(timeout) = timeout {
|
||||
opts.timeout = Some(timeout);
|
||||
}
|
||||
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
@@ -513,10 +522,11 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None))]
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
max_batch_length: Option<u32>,
|
||||
timeout: Option<Duration>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -524,6 +534,9 @@ impl VectorQuery {
|
||||
if let Some(max_batch_length) = max_batch_length {
|
||||
opts.max_batch_length = max_batch_length;
|
||||
}
|
||||
if let Some(timeout) = timeout {
|
||||
opts.timeout = Some(timeout);
|
||||
}
|
||||
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user