feat: add timeout to query execution options (#2288)

Closes #2287


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added configurable timeout support for query executions. Users can now
specify maximum wait times for queries, enhancing control over
long-running operations across various integrations.
- **Tests**
- Expanded test coverage to validate timeout behavior in both
synchronous and asynchronous query flows, ensuring timely error
responses when query execution exceeds the specified limit.
- Introduced a new test suite to verify query operations when a timeout
is reached, checking for appropriate error handling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Will Jones
2025-04-04 12:34:41 -07:00
committed by GitHub
parent a38f784081
commit 1cd76b8498
17 changed files with 653 additions and 95 deletions

8
Cargo.lock generated
View File

@@ -4110,7 +4110,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow",
"arrow-array",
@@ -4197,7 +4197,7 @@ dependencies = [
[[package]]
name = "lancedb-node"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4222,7 +4222,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4240,7 +4240,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.22.0-beta.3"
version = "0.22.0-beta.4"
dependencies = [
"arrow",
"env_logger",

View File

@@ -20,3 +20,13 @@ The maximum number of rows to return in a single batch
Batches may have fewer rows if the underlying data is stored
in smaller chunks.
***
### timeoutMs?
```ts
optional timeoutMs: number;
```
Timeout for query execution in milliseconds

60
node/package-lock.json generated
View File

@@ -326,6 +326,66 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.4.tgz",
"integrity": "sha512-uS5AuT3Q4swrtM9JAhF8mM8Nt+kvewmB3DQWGiuYbhmMismSu8WlOHQAs9Yyh8N7NBdWENSTjroSExqjHPdFhQ==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.4.tgz",
"integrity": "sha512-kjn3iTqZSx57ek9PN2AdPvJMx14tFkXc8sUFd3MLhY7FdWafx7Wvl0SLz2LubotJVFd6LMxvsPPNJEM5bEgMOw==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.4.tgz",
"integrity": "sha512-iZlR7ffKC+XA1mGuuwXJojgFcUvXkgMt6pKR6lP3hsxXh8UOTWDljN7jkI8jKHcJez3rrqoqt1VjH3xD69fwtA==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.4.tgz",
"integrity": "sha512-uxLeerlT5FuWzuvHlTDLdLCakyUJ+qJitReoCKT6tKhfcjIkbr+NEoLZEHifJC4dRFPtbddVgiYN6VHlnPPD/w==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.4.tgz",
"integrity": "sha512-QSugxudXooLCF7trudaAo9PfOzX7SFBIiHOoL4N6nwjC61u/JAsoiytw1Xjs/+0pOG5cT2WUMufBzBPgJyOxbw==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": {
"version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -867,6 +867,44 @@ describe("When creating an index", () => {
});
});
describe("When querying a table", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should throw an error when timeout is reached", async () => {
const db = await connect(tmpDir.name);
const data = makeArrowTable([
{ text: "a", vector: [0.1, 0.2] },
{ text: "b", vector: [0.3, 0.4] },
]);
const table = await db.createTable("test", data);
await table.createIndex("text", { config: Index.fts() });
await expect(
table.query().where("text != 'a'").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.query().nearestTo([0.0, 0.0]).toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.search("a", "fts").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table
.query()
.nearestToText("a")
.nearestTo([0.0, 0.0])
.toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
});
});
describe("Read consistency interval", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {

View File

@@ -63,7 +63,7 @@ class RecordBatchIterable<
// biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength),
this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs),
);
}
}
@@ -79,6 +79,11 @@ export interface QueryExecutionOptions {
* in smaller chunks.
*/
maxBatchLength?: number;
/**
* Timeout for query execution in milliseconds
*/
timeoutMs?: number;
}
/**
@@ -283,9 +288,11 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
if (this.inner instanceof Promise) {
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
return this.inner.then((inner) =>
inner.execute(options?.maxBatchLength, options?.timeoutMs),
);
} else {
return this.inner.execute(options?.maxBatchLength);
return this.inner.execute(options?.maxBatchLength, options?.timeoutMs);
}
}

View File

@@ -131,11 +131,15 @@ impl Query {
pub async fn execute(
&self,
max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
@@ -330,11 +334,15 @@ impl VectorQuery {
pub async fn execute(
&self,
max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)

View File

@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa
@@ -94,7 +95,9 @@ class Query:
def postfilter(self): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
async def explain_plan(self, verbose: Optional[bool]) -> str: ...
async def analyze_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ...
@@ -110,7 +113,9 @@ class FTSQuery:
def get_query(self) -> str: ...
def add_query_vector(self, query_vec: pa.Array) -> None: ...
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery:

View File

@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
import abc
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Dict,
@@ -650,7 +651,12 @@ class LanceQueryBuilder(ABC):
"""
return self.to_pandas()
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
def to_pandas(
self,
flatten: Optional[Union[int, bool]] = None,
*,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
@@ -664,12 +670,15 @@ class LanceQueryBuilder(ABC):
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
tbl = flatten_columns(self.to_arrow(), flatten)
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
return tbl.to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -677,34 +686,65 @@ class LanceQueryBuilder(ABC):
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
raise NotImplementedError
@abstractmethod
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
"""
Execute the query and return the results as a pyarrow
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
Parameters
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
raise NotImplementedError
def to_list(self) -> List[dict]:
def to_list(self, *, timeout: Optional[timedelta] = None) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
"""
return self.to_arrow().to_pylist()
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
return self.to_arrow(timeout=timeout).to_pylist()
def to_pydantic(
self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns
-------
@@ -712,19 +752,25 @@ class LanceQueryBuilder(ABC):
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
for row in self.to_arrow(timeout=timeout).to_pylist()
]
def to_polars(self) -> "pl.DataFrame":
def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame":
"""
Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
import polars as pl
return pl.from_arrow(self.to_arrow())
return pl.from_arrow(self.to_arrow(timeout=timeout))
def limit(self, limit: Union[int, None]) -> Self:
"""Set the maximum number of results to return.
@@ -1139,7 +1185,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -1147,8 +1193,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
return self.to_batches().read_all()
return self.to_batches(timeout=timeout).read_all()
def to_query_object(self) -> Query:
"""
@@ -1178,7 +1230,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
bypass_vector_index=self._bypass_vector_index,
)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
"""
Execute the query and return the result as a RecordBatchReader object.
@@ -1186,6 +1244,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
timeout: timedelta, default None
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns
-------
@@ -1195,7 +1256,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
query = self.to_query_object()
result_set = self._table._execute_query(query, batch_size)
result_set = self._table._execute_query(
query, batch_size=batch_size, timeout=timeout
)
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
@@ -1334,7 +1397,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path()
if exist:
return self.tantivy_to_arrow()
@@ -1346,14 +1409,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Use tantivy-based index instead for now."
)
query = self.to_query_object()
results = self._table._execute_query(query)
results = self._table._execute_query(query, timeout=timeout)
results = results.read_all()
if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results)
check_reranker_result(results)
return results
def to_batches(self, /, batch_size: Optional[int] = None):
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches on an FTS query")
def tantivy_to_arrow(self) -> pa.Table:
@@ -1458,8 +1523,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
return self.to_batches().read_all()
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
return self.to_batches(timeout=timeout).read_all()
def to_query_object(self) -> Query:
return Query(
@@ -1470,9 +1535,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
) -> pa.RecordBatchReader:
query = self.to_query_object()
return self._table._execute_query(query, batch_size)
return self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -1560,7 +1627,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def to_query_object(self) -> Query:
raise NotImplementedError("to_query_object not yet supported on a hybrid query")
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
@@ -1603,9 +1670,11 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._reranker = RRFReranker()
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
fts_future = executor.submit(
self._fts_query.with_row_id(True).to_arrow, timeout=timeout
)
vector_future = executor.submit(
self._vector_query.with_row_id(True).to_arrow
self._vector_query.with_row_id(True).to_arrow, timeout=timeout
)
fts_results = fts_future.result()
vector_results = vector_future.result()
@@ -1692,7 +1761,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return results
def to_batches(self):
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches not yet supported on a hybrid query")
@staticmethod
@@ -2056,7 +2127,10 @@ class AsyncQueryBase(object):
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
@@ -2069,34 +2143,56 @@ class AsyncQueryBase(object):
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
return AsyncRecordBatchReader(
await self._inner.execute(max_batch_length, timeout)
)
async def to_arrow(self) -> pa.Table:
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use
[to_batches][lancedb.query.AsyncQueryBase.to_batches]
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
batch_iter = await self.to_batches()
batch_iter = await self.to_batches(timeout=timeout)
return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema
)
async def to_list(self) -> List[dict]:
async def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return (await self.to_arrow()).to_pylist()
return (await self.to_arrow(timeout=timeout)).to_pylist()
async def to_pandas(
self, flatten: Optional[Union[int, bool]] = None
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
@@ -2125,10 +2221,19 @@ class AsyncQueryBase(object):
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
return (
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
).to_pandas()
async def to_polars(self) -> "pl.DataFrame":
async def to_polars(
self,
timeout: Optional[timedelta] = None,
) -> "pl.DataFrame":
"""
Execute the query and collect the results into a Polars DataFrame.
@@ -2137,6 +2242,13 @@ class AsyncQueryBase(object):
[to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
polars separately.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
Examples
--------
@@ -2152,7 +2264,7 @@ class AsyncQueryBase(object):
"""
import polars as pl
return pl.from_arrow(await self.to_arrow())
return pl.from_arrow(await self.to_arrow(timeout=timeout))
async def explain_plan(self, verbose: Optional[bool] = False):
"""Return the execution plan for this query.
@@ -2423,9 +2535,12 @@ class AsyncFTSQuery(AsyncQueryBase):
)
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
reader = await super().to_batches()
reader = await super().to_batches(timeout=timeout)
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker:
results = self._reranker.rerank_fts(self.get_query(), results)
@@ -2649,9 +2764,12 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
reader = await super().to_batches()
reader = await super().to_batches(timeout=timeout)
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker:
results = self._reranker.rerank_vector(self._query_string, results)
@@ -2707,7 +2825,10 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
@@ -2719,8 +2840,8 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
vec_query.with_row_id()
fts_results, vector_results = await asyncio.gather(
fts_query.to_arrow(),
vec_query.to_arrow(),
fts_query.to_arrow(timeout=timeout),
vec_query.to_arrow(timeout=timeout),
)
result = LanceHybridQueryBuilder._combine_hybrid_results(

View File

@@ -355,9 +355,15 @@ class RemoteTable(Table):
)
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
async_iter = LOOP.run(
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync():
try:

View File

@@ -1007,7 +1007,11 @@ class Table(ABC):
@abstractmethod
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ...
@abstractmethod
@@ -2312,9 +2316,15 @@ class LanceTable(Table):
LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
async_iter = LOOP.run(self._table._execute_query(query, batch_size))
async_iter = LOOP.run(
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync():
try:
@@ -3390,7 +3400,11 @@ class AsyncTable:
return async_query
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
# The sync table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
@@ -3398,7 +3412,9 @@ class AsyncTable:
async_query = self._sync_query_to_async(query)
return await async_query.to_batches(max_batch_length=batch_size)
return await async_query.to_batches(
max_batch_length=batch_size, timeout=timeout
)
async def _explain_plan(self, query: Query, verbose: Optional[bool]) -> str:
# This method is used by the sync table

View File

@@ -511,7 +511,8 @@ def test_query_builder_with_different_vector_column():
columns=["b"],
vector_column="foo_vector",
),
None,
batch_size=None,
timeout=None,
)
@@ -1076,3 +1077,67 @@ async def test_query_serialization_async(table_async: AsyncTable):
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
with_row_id=False,
)
def test_query_timeout(tmp_path):
# Use local directory instead of memory:// to add a bit of latency to
# operations so a timeout of zero will trigger exceptions.
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", use_tantivy=False)
with pytest.raises(Exception, match="Query timeout"):
table.search().where("text = 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search("a", query_type="fts").to_pandas(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search(query_type="hybrid").vector([0.0, 0.0]).text("a").to_arrow(
timeout=timedelta(0)
)
@pytest.mark.asyncio
async def test_query_timeout_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = await db.create_table("test", data)
await table.create_index("text", config=FTS())
with pytest.raises(Exception, match="Query timeout"):
await table.query().where("text != 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await table.vector_search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await (await table.search("a", query_type="fts")).to_pandas(
timeout=timedelta(0)
)
with pytest.raises(Exception, match="Query timeout"):
await (
table.query()
.nearest_to_text("a")
.nearest_to([0.0, 0.0])
.to_list(timeout=timedelta(0))
)

View File

@@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use std::time::Duration;
use arrow::array::make_array;
use arrow::array::Array;
@@ -294,10 +295,11 @@ impl Query {
})
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
@@ -305,6 +307,9 @@ impl Query {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
@@ -376,10 +381,11 @@ impl FTSQuery {
self.inner = self.inner.clone().postfilter();
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_
.inner
@@ -391,6 +397,9 @@ impl FTSQuery {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
@@ -513,10 +522,11 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
@@ -524,6 +534,9 @@ impl VectorQuery {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})

View File

@@ -1,8 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::future::Future;
use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
@@ -25,6 +25,7 @@ use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
@@ -525,12 +526,15 @@ pub struct QueryExecutionOptions {
///
/// By default, this is 1024
pub max_batch_length: u32,
/// Max duration to wait for the query to execute before timing out.
pub timeout: Option<Duration>,
}
impl Default for QueryExecutionOptions {
fn default() -> Self {
Self {
max_batch_length: 1024,
timeout: None,
}
}
}
@@ -1007,7 +1011,10 @@ impl VectorQuery {
self
}
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
pub async fn execute_hybrid(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
// clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
@@ -1016,7 +1023,10 @@ impl VectorQuery {
let mut vector_query = self.clone().with_row_id();
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
)?;
let (fts_results, vec_results) = try_join!(
fts_results.try_collect::<Vec<_>>(),
@@ -1074,6 +1084,20 @@ impl VectorQuery {
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
}
async fn inner_execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner).into())
}
}
impl ExecutableQuery for VectorQuery {
@@ -1087,16 +1111,13 @@ impl ExecutableQuery for VectorQuery {
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
if self.request.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
let hybrid_result = async move { self.execute_hybrid(options).await }
.boxed()
.await?;
return Ok(hybrid_result);
}
Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
))
self.inner_execute_with_options(options).await
}
async fn explain_plan(&self, verbose: bool) -> Result<String> {

View File

@@ -13,7 +13,7 @@ use reqwest::{
use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
const REQUEST_ID_HEADER: &str = "x-request-id";
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
/// Configuration for the LanceDB Cloud HTTP client.
#[derive(Clone, Debug)]
@@ -299,7 +299,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
"x-api-key",
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
@@ -307,7 +307,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
"Host",
http::header::HOST,
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
@@ -315,7 +315,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if has_host_override {
headers.insert(
"x-lancedb-database",
HeaderName::from_static("x-lancedb-database"),
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
@@ -323,7 +323,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if db_prefix.is_some() {
headers.insert(
"x-lancedb-database-prefix",
HeaderName::from_static("x-lancedb-database-prefix"),
HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput {
message: format!(
"non-ascii database prefix '{}' provided",
@@ -335,7 +335,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if let Some(v) = options.0.get("account_name") {
headers.insert(
"x-azure-storage-account-name",
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,
@@ -343,7 +343,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if let Some(v) = options.0.get("azure_storage_account_name") {
headers.insert(
"x-azure-storage-account-name",
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,

View File

@@ -20,7 +20,7 @@ use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use http::{HeaderName, StatusCode};
use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
@@ -44,6 +44,8 @@ use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)]
@@ -332,9 +334,19 @@ impl<S: HttpSend> RemoteTable<S> {
async fn execute_query(
&self,
query: &AnyQuery,
_options: QueryExecutionOptions,
options: &QueryExecutionOptions,
) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
if let Some(timeout) = options.timeout {
// Client side timeout
request = request.timeout(timeout);
// Also send to server, so it can abort the query if it takes too long.
// (If it doesn't fit into u64, it's not worth sending anyways.)
if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
request = request.header(REQUEST_TIMEOUT_HEADER, timeout_ms);
}
}
let query_bodies = self.prepare_query_bodies(query).await?;
let requests: Vec<reqwest::RequestBuilder> = query_bodies
@@ -543,7 +555,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let streams = self.execute_query(query, options).await?;
let streams = self.execute_query(query, &options).await?;
if streams.len() == 1 {
let stream = streams.into_iter().next().unwrap();
Ok(Arc::new(OneShotExec::new(stream)))
@@ -559,9 +571,9 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
async fn query(
&self,
query: &AnyQuery,
_options: QueryExecutionOptions,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let streams = self.execute_query(query, _options).await?;
let streams = self.execute_query(query, &options).await?;
if streams.len() == 1 {
Ok(DatasetRecordBatchStream::new(

View File

@@ -68,7 +68,7 @@ use crate::query::{
use crate::utils::{
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
PatchReadParam, PatchWriteParam,
PatchReadParam, PatchWriteParam, TimeoutStream,
};
use self::dataset::DatasetConsistencyWrapper;
@@ -1775,11 +1775,14 @@ impl NativeTable {
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let plan = self.create_plan(query, options).await?;
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
Default::default(),
)?))
let plan = self.create_plan(query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner))
}
/// Check whether the table uses V2 manifest paths.

View File

@@ -3,14 +3,20 @@
use std::sync::Arc;
use arrow_schema::{DataType, Schema};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::RecordBatchStream;
use futures::{FutureExt, Stream};
use lance::arrow::json::JsonDataType;
use lance::dataset::{ReadParams, WriteParams};
use lance::index::vector::utils::infer_vector_dim;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
use std::pin::Pin;
use crate::error::{Error, Result};
use datafusion_physical_plan::SendableRecordBatchStream;
lazy_static! {
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
@@ -178,11 +184,97 @@ pub fn string_to_datatype(s: &str) -> Option<DataType> {
(&json_type).try_into().ok()
}
enum TimeoutState {
NotStarted {
timeout: std::time::Duration,
},
Started {
deadline: Pin<Box<tokio::time::Sleep>>,
timeout: std::time::Duration,
},
Completed,
}
/// A `Stream` wrapper that implements a timeout.
///
/// The timeout starts when the first `poll_next` is called. As soon as the timeout
/// duration has passed, the stream will return an `Err` indicating a timeout error
/// for the next poll.
pub struct TimeoutStream {
inner: SendableRecordBatchStream,
state: TimeoutState,
}
impl TimeoutStream {
pub fn new(inner: SendableRecordBatchStream, timeout: std::time::Duration) -> Self {
Self {
inner,
state: TimeoutState::NotStarted { timeout },
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
timeout: std::time::Duration,
) -> SendableRecordBatchStream {
Box::pin(Self::new(inner, timeout))
}
fn timeout_error(timeout: &std::time::Duration) -> DataFusionError {
DataFusionError::Execution(format!("Query timeout after {} ms", timeout.as_millis()))
}
}
impl RecordBatchStream for TimeoutStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for TimeoutStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match &mut self.state {
TimeoutState::NotStarted { timeout } => {
if timeout.is_zero() {
return std::task::Poll::Ready(Some(Err(Self::timeout_error(timeout))));
}
let deadline = Box::pin(tokio::time::sleep(*timeout));
self.state = TimeoutState::Started {
deadline,
timeout: *timeout,
};
self.poll_next(cx)
}
TimeoutState::Started { deadline, timeout } => match deadline.poll_unpin(cx) {
std::task::Poll::Ready(_) => {
let err = Self::timeout_error(timeout);
self.state = TimeoutState::Completed;
std::task::Poll::Ready(Some(Err(err)))
}
std::task::Poll::Pending => {
let inner = Pin::new(&mut self.inner);
inner.poll_next(cx)
}
},
TimeoutState::Completed => std::task::Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::Int32Array;
use arrow_schema::Field;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{stream, StreamExt};
use tokio::time::sleep;
use arrow_schema::{DataType, Field};
use super::*;
#[test]
fn test_guess_default_column() {
@@ -249,4 +341,85 @@ mod tests {
let expected = DataType::Int32;
assert_eq!(string_to_datatype(string), Some(expected));
}
fn sample_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"col1",
DataType::Int32,
false,
)]));
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap()
}
#[tokio::test]
async fn test_timeout_stream() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
let timeout_duration = std::time::Duration::from_millis(10);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// Poll the stream to get the first batch
let first_result = timeout_stream.next().await;
assert!(first_result.is_some());
assert!(first_result.unwrap().is_ok());
// Sleep for the timeout duration
sleep(timeout_duration).await;
// Poll the stream again and ensure it returns a timeout error
let second_result = timeout_stream.next().await.unwrap();
assert!(second_result.is_err());
assert!(second_result
.unwrap_err()
.to_string()
.contains("Query timeout"));
}
#[tokio::test]
async fn test_timeout_stream_zero_duration() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
// Setup similar to test_timeout_stream
let timeout_duration = std::time::Duration::from_secs(0);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// First poll should immediately return a timeout error
let result = timeout_stream.next().await.unwrap();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Query timeout"));
}
#[tokio::test]
async fn test_timeout_stream_completes_normally() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
// Setup a stream with 2 batches
// Use a longer timeout that won't trigger
let timeout_duration = std::time::Duration::from_secs(1);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// Both polls should return data normally
assert!(timeout_stream.next().await.unwrap().is_ok());
assert!(timeout_stream.next().await.unwrap().is_ok());
// Stream should be empty now
assert!(timeout_stream.next().await.is_none());
}
}