diff --git a/Cargo.lock b/Cargo.lock index 32c11833..c43950d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4110,7 +4110,7 @@ dependencies = [ [[package]] name = "lancedb" -version = "0.19.0-beta.3" +version = "0.19.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4197,7 +4197,7 @@ dependencies = [ [[package]] name = "lancedb-node" -version = "0.19.0-beta.3" +version = "0.19.0-beta.4" dependencies = [ "arrow-array", "arrow-ipc", @@ -4222,7 +4222,7 @@ dependencies = [ [[package]] name = "lancedb-nodejs" -version = "0.19.0-beta.3" +version = "0.19.0-beta.4" dependencies = [ "arrow-array", "arrow-ipc", @@ -4240,7 +4240,7 @@ dependencies = [ [[package]] name = "lancedb-python" -version = "0.22.0-beta.3" +version = "0.22.0-beta.4" dependencies = [ "arrow", "env_logger", diff --git a/docs/src/js/interfaces/QueryExecutionOptions.md b/docs/src/js/interfaces/QueryExecutionOptions.md index dd3495bd..46b06ce7 100644 --- a/docs/src/js/interfaces/QueryExecutionOptions.md +++ b/docs/src/js/interfaces/QueryExecutionOptions.md @@ -20,3 +20,13 @@ The maximum number of rows to return in a single batch Batches may have fewer rows if the underlying data is stored in smaller chunks. + +*** + +### timeoutMs? + +```ts +optional timeoutMs: number; +``` + +Timeout for query execution in milliseconds diff --git a/node/package-lock.json b/node/package-lock.json index 55976938..eaf93546 100644 --- a/node/package-lock.json +++ b/node/package-lock.json @@ -326,6 +326,66 @@ "@jridgewell/sourcemap-codec": "^1.4.10" } }, + "node_modules/@lancedb/vectordb-darwin-arm64": { + "version": "0.19.0-beta.4", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.4.tgz", + "integrity": "sha512-uS5AuT3Q4swrtM9JAhF8mM8Nt+kvewmB3DQWGiuYbhmMismSu8WlOHQAs9Yyh8N7NBdWENSTjroSExqjHPdFhQ==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@lancedb/vectordb-darwin-x64": { + "version": "0.19.0-beta.4", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.4.tgz", + "integrity": "sha512-kjn3iTqZSx57ek9PN2AdPvJMx14tFkXc8sUFd3MLhY7FdWafx7Wvl0SLz2LubotJVFd6LMxvsPPNJEM5bEgMOw==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@lancedb/vectordb-linux-arm64-gnu": { + "version": "0.19.0-beta.4", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.4.tgz", + "integrity": "sha512-iZlR7ffKC+XA1mGuuwXJojgFcUvXkgMt6pKR6lP3hsxXh8UOTWDljN7jkI8jKHcJez3rrqoqt1VjH3xD69fwtA==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@lancedb/vectordb-linux-x64-gnu": { + "version": "0.19.0-beta.4", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.4.tgz", + "integrity": "sha512-uxLeerlT5FuWzuvHlTDLdLCakyUJ+qJitReoCKT6tKhfcjIkbr+NEoLZEHifJC4dRFPtbddVgiYN6VHlnPPD/w==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@lancedb/vectordb-win32-x64-msvc": { + "version": "0.19.0-beta.4", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.4.tgz", + "integrity": "sha512-QSugxudXooLCF7trudaAo9PfOzX7SFBIiHOoL4N6nwjC61u/JAsoiytw1Xjs/+0pOG5cT2WUMufBzBPgJyOxbw==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/@neon-rs/cli": { "version": "0.0.160", "resolved": 
"https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz", diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 944755fb..b306126e 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -867,6 +867,44 @@ describe("When creating an index", () => { }); }); +describe("When querying a table", () => { + let tmpDir: tmp.DirResult; + beforeEach(() => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + }); + afterEach(() => tmpDir.removeCallback()); + + it("should throw an error when timeout is reached", async () => { + const db = await connect(tmpDir.name); + const data = makeArrowTable([ + { text: "a", vector: [0.1, 0.2] }, + { text: "b", vector: [0.3, 0.4] }, + ]); + const table = await db.createTable("test", data); + await table.createIndex("text", { config: Index.fts() }); + + await expect( + table.query().where("text != 'a'").toArray({ timeoutMs: 0 }), + ).rejects.toThrow("Query timeout"); + + await expect( + table.query().nearestTo([0.0, 0.0]).toArrow({ timeoutMs: 0 }), + ).rejects.toThrow("Query timeout"); + + await expect( + table.search("a", "fts").toArray({ timeoutMs: 0 }), + ).rejects.toThrow("Query timeout"); + + await expect( + table + .query() + .nearestToText("a") + .nearestTo([0.0, 0.0]) + .toArrow({ timeoutMs: 0 }), + ).rejects.toThrow("Query timeout"); + }); +}); + describe("Read consistency interval", () => { let tmpDir: tmp.DirResult; beforeEach(() => { diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 0db81982..a8143a2e 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -63,7 +63,7 @@ class RecordBatchIterable< // biome-ignore lint/suspicious/noExplicitAny: skip [Symbol.asyncIterator](): AsyncIterator, any, undefined> { return new RecordBatchIterator( - this.inner.execute(this.options?.maxBatchLength), + this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs), ); } } @@ -79,6 +79,11 @@ export interface QueryExecutionOptions { * in smaller chunks. 
*/ maxBatchLength?: number; + + /** + * Timeout for query execution in milliseconds + */ + timeoutMs?: number; } /** @@ -283,9 +288,11 @@ export class QueryBase options?: Partial, ): Promise { if (this.inner instanceof Promise) { - return this.inner.then((inner) => inner.execute(options?.maxBatchLength)); + return this.inner.then((inner) => + inner.execute(options?.maxBatchLength, options?.timeoutMs), + ); } else { - return this.inner.execute(options?.maxBatchLength); + return this.inner.execute(options?.maxBatchLength, options?.timeoutMs); } } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index d945d48e..59dc72f6 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -131,11 +131,15 @@ impl Query { pub async fn execute( &self, max_batch_length: Option, + timeout_ms: Option, ) -> napi::Result { let mut execution_opts = QueryExecutionOptions::default(); if let Some(max_batch_length) = max_batch_length { execution_opts.max_batch_length = max_batch_length; } + if let Some(timeout_ms) = timeout_ms { + execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64)) + } let inner_stream = self .inner .execute_with_options(execution_opts) @@ -330,11 +334,15 @@ impl VectorQuery { pub async fn execute( &self, max_batch_length: Option, + timeout_ms: Option, ) -> napi::Result { let mut execution_opts = QueryExecutionOptions::default(); if let Some(max_batch_length) = max_batch_length { execution_opts.max_batch_length = max_batch_length; } + if let Some(timeout_ms) = timeout_ms { + execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64)) + } let inner_stream = self .inner .execute_with_options(execution_opts) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 29266ee0..3ac0d67e 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Dict, List, Optional, Tuple, Any, Union, Literal import pyarrow as pa @@ -94,7 +95,9 @@ class Query: def postfilter(self): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ... - async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... + async def execute( + self, max_batch_length: Optional[int], timeout: Optional[timedelta] + ) -> RecordBatchStream: ... async def explain_plan(self, verbose: Optional[bool]) -> str: ... async def analyze_plan(self) -> str: ... def to_query_request(self) -> PyQueryRequest: ... @@ -110,7 +113,9 @@ class FTSQuery: def get_query(self) -> str: ... def add_query_vector(self, query_vec: pa.Array) -> None: ... def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... - async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... + async def execute( + self, max_batch_length: Optional[int], timeout: Optional[timedelta] + ) -> RecordBatchStream: ... def to_query_request(self) -> PyQueryRequest: ... 
class VectorQuery: diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index cb1aded8..06d4b1b6 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod import abc from concurrent.futures import ThreadPoolExecutor from enum import Enum +from datetime import timedelta from typing import ( TYPE_CHECKING, Dict, @@ -650,7 +651,12 @@ class LanceQueryBuilder(ABC): """ return self.to_pandas() - def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame": + def to_pandas( + self, + flatten: Optional[Union[int, bool]] = None, + *, + timeout: Optional[timedelta] = None, + ) -> "pd.DataFrame": """ Execute the query and return the results as a pandas DataFrame. In addition to the selected columns, LanceDB also returns a vector @@ -664,12 +670,15 @@ class LanceQueryBuilder(ABC): If flatten is an integer, flatten the nested columns up to the specified depth. If unspecified, do not flatten the nested columns. + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. """ - tbl = flatten_columns(self.to_arrow(), flatten) + tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten) return tbl.to_pandas() @abstractmethod - def to_arrow(self) -> pa.Table: + def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: """ Execute the query and return the results as an [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table). @@ -677,34 +686,65 @@ class LanceQueryBuilder(ABC): In addition to the selected columns, LanceDB also returns a vector and also the "_distance" column which is the distance between the query vector and the returned vectors. + + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. """ raise NotImplementedError @abstractmethod - def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: + def to_batches( + self, + /, + batch_size: Optional[int] = None, + *, + timeout: Optional[timedelta] = None, + ) -> pa.RecordBatchReader: """ Execute the query and return the results as a pyarrow [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html) + + Parameters + ---------- + batch_size: int + The maximum number of selected records in a RecordBatch object. + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. """ raise NotImplementedError - def to_list(self) -> List[dict]: + def to_list(self, *, timeout: Optional[timedelta] = None) -> List[dict]: """ Execute the query and return the results as a list of dictionaries. Each list entry is a dictionary with the selected column names as keys, or all table columns if `select` is not called. The vector and the "_distance" fields are returned whether or not they're explicitly selected. - """ - return self.to_arrow().to_pylist() - def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]: + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. + """ + return self.to_arrow(timeout=timeout).to_pylist() + + def to_pydantic( + self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None + ) -> List[LanceModel]: """Return the table as a list of pydantic models. 
Parameters ---------- model: Type[LanceModel] The pydantic model to use. + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. Returns ------- @@ -712,19 +752,25 @@ class LanceQueryBuilder(ABC): """ return [ model(**{k: v for k, v in row.items() if k in model.field_names()}) - for row in self.to_arrow().to_pylist() + for row in self.to_arrow(timeout=timeout).to_pylist() ] - def to_polars(self) -> "pl.DataFrame": + def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame": """ Execute the query and return the results as a Polars DataFrame. In addition to the selected columns, LanceDB also returns a vector and also the "_distance" column which is the distance between the query vector and the returned vector. + + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. """ import polars as pl - return pl.from_arrow(self.to_arrow()) + return pl.from_arrow(self.to_arrow(timeout=timeout)) def limit(self, limit: Union[int, None]) -> Self: """Set the maximum number of results to return. @@ -1139,7 +1185,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._refine_factor = refine_factor return self - def to_arrow(self) -> pa.Table: + def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: """ Execute the query and return the results as an [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table). @@ -1147,8 +1193,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): In addition to the selected columns, LanceDB also returns a vector and also the "_distance" column which is the distance between the query vector and the returned vectors. + + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If None, wait indefinitely. """ - return self.to_batches().read_all() + return self.to_batches(timeout=timeout).read_all() def to_query_object(self) -> Query: """ @@ -1178,7 +1230,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): bypass_vector_index=self._bypass_vector_index, ) - def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: + def to_batches( + self, + /, + batch_size: Optional[int] = None, + *, + timeout: Optional[timedelta] = None, + ) -> pa.RecordBatchReader: """ Execute the query and return the result as a RecordBatchReader object. @@ -1186,6 +1244,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): ---------- batch_size: int The maximum number of selected records in a RecordBatch object. + timeout: timedelta, default None + The maximum time to wait for the query to complete. + If None, wait indefinitely. 
Returns ------- @@ -1195,7 +1256,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): if isinstance(vector[0], np.ndarray): vector = [v.tolist() for v in vector] query = self.to_query_object() - result_set = self._table._execute_query(query, batch_size) + result_set = self._table._execute_query( + query, batch_size=batch_size, timeout=timeout + ) if self._reranker is not None: rs_table = result_set.read_all() result_set = self._reranker.rerank_vector(self._str_query, rs_table) @@ -1334,7 +1397,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): offset=self._offset, ) - def to_arrow(self) -> pa.Table: + def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: path, fs, exist = self._table._get_fts_index_path() if exist: return self.tantivy_to_arrow() @@ -1346,14 +1409,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): "Use tantivy-based index instead for now." ) query = self.to_query_object() - results = self._table._execute_query(query) + results = self._table._execute_query(query, timeout=timeout) results = results.read_all() if self._reranker is not None: results = self._reranker.rerank_fts(self._query, results) check_reranker_result(results) return results - def to_batches(self, /, batch_size: Optional[int] = None): + def to_batches( + self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None + ): raise NotImplementedError("to_batches on an FTS query") def tantivy_to_arrow(self) -> pa.Table: @@ -1458,8 +1523,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder): - def to_arrow(self) -> pa.Table: - return self.to_batches().read_all() + def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: + return self.to_batches(timeout=timeout).read_all() def to_query_object(self) -> Query: return Query( @@ -1470,9 +1535,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): offset=self._offset, ) - def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: + def to_batches( + self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None + ) -> pa.RecordBatchReader: query = self.to_query_object() - return self._table._execute_query(query, batch_size) + return self._table._execute_query(query, batch_size=batch_size, timeout=timeout) def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder: """Rerank the results using the specified reranker. 
@@ -1560,7 +1627,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): def to_query_object(self) -> Query: raise NotImplementedError("to_query_object not yet supported on a hybrid query") - def to_arrow(self) -> pa.Table: + def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: vector_query, fts_query = self._validate_query( self._query, self._vector, self._text ) @@ -1603,9 +1670,11 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._reranker = RRFReranker() with ThreadPoolExecutor() as executor: - fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) + fts_future = executor.submit( + self._fts_query.with_row_id(True).to_arrow, timeout=timeout + ) vector_future = executor.submit( - self._vector_query.with_row_id(True).to_arrow + self._vector_query.with_row_id(True).to_arrow, timeout=timeout ) fts_results = fts_future.result() vector_results = vector_future.result() @@ -1692,7 +1761,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): return results - def to_batches(self): + def to_batches( + self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None + ): raise NotImplementedError("to_batches not yet supported on a hybrid query") @staticmethod @@ -2056,7 +2127,10 @@ class AsyncQueryBase(object): return self async def to_batches( - self, *, max_batch_length: Optional[int] = None + self, + *, + max_batch_length: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> AsyncRecordBatchReader: """ Execute the query and return the results as an Apache Arrow RecordBatchReader. @@ -2069,34 +2143,56 @@ class AsyncQueryBase(object): If not specified, a default batch length is used. It is possible for batches to be smaller than the provided length if the underlying data is stored in smaller chunks. + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If not specified, no timeout is applied. If the query does not + complete within the specified time, an error will be raised. """ - return AsyncRecordBatchReader(await self._inner.execute(max_batch_length)) + return AsyncRecordBatchReader( + await self._inner.execute(max_batch_length, timeout) + ) - async def to_arrow(self) -> pa.Table: + async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table: """ Execute the query and collect the results into an Apache Arrow Table. This method will collect all results into memory before returning. If you expect a large number of results, you may want to use [to_batches][lancedb.query.AsyncQueryBase.to_batches] + + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If not specified, no timeout is applied. If the query does not + complete within the specified time, an error will be raised. """ - batch_iter = await self.to_batches() + batch_iter = await self.to_batches(timeout=timeout) return pa.Table.from_batches( await batch_iter.read_all(), schema=batch_iter.schema ) - async def to_list(self) -> List[dict]: + async def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]: """ Execute the query and return the results as a list of dictionaries. Each list entry is a dictionary with the selected column names as keys, or all table columns if `select` is not called. The vector and the "_distance" fields are returned whether or not they're explicitly selected. + + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If not specified, no timeout is applied. 
If the query does not + complete within the specified time, an error will be raised. """ - return (await self.to_arrow()).to_pylist() + return (await self.to_arrow(timeout=timeout)).to_pylist() async def to_pandas( - self, flatten: Optional[Union[int, bool]] = None + self, + flatten: Optional[Union[int, bool]] = None, + timeout: Optional[timedelta] = None, ) -> "pd.DataFrame": """ Execute the query and collect the results into a pandas DataFrame. @@ -2125,10 +2221,19 @@ class AsyncQueryBase(object): If flatten is an integer, flatten the nested columns up to the specified depth. If unspecified, do not flatten the nested columns. + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If not specified, no timeout is applied. If the query does not + complete within the specified time, an error will be raised. """ - return (flatten_columns(await self.to_arrow(), flatten)).to_pandas() + return ( + flatten_columns(await self.to_arrow(timeout=timeout), flatten) + ).to_pandas() - async def to_polars(self) -> "pl.DataFrame": + async def to_polars( + self, + timeout: Optional[timedelta] = None, + ) -> "pl.DataFrame": """ Execute the query and collect the results into a Polars DataFrame. @@ -2137,6 +2242,13 @@ class AsyncQueryBase(object): [to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to polars separately. + Parameters + ---------- + timeout: Optional[timedelta] + The maximum time to wait for the query to complete. + If not specified, no timeout is applied. If the query does not + complete within the specified time, an error will be raised. + Examples -------- @@ -2152,7 +2264,7 @@ class AsyncQueryBase(object): """ import polars as pl - return pl.from_arrow(await self.to_arrow()) + return pl.from_arrow(await self.to_arrow(timeout=timeout)) async def explain_plan(self, verbose: Optional[bool] = False): """Return the execution plan for this query. 
@@ -2423,9 +2535,12 @@ class AsyncFTSQuery(AsyncQueryBase): ) async def to_batches( - self, *, max_batch_length: Optional[int] = None + self, + *, + max_batch_length: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> AsyncRecordBatchReader: - reader = await super().to_batches() + reader = await super().to_batches(timeout=timeout) results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: results = self._reranker.rerank_fts(self.get_query(), results) @@ -2649,9 +2764,12 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()})) async def to_batches( - self, *, max_batch_length: Optional[int] = None + self, + *, + max_batch_length: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> AsyncRecordBatchReader: - reader = await super().to_batches() + reader = await super().to_batches(timeout=timeout) results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: results = self._reranker.rerank_vector(self._query_string, results) @@ -2707,7 +2825,10 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase): return self async def to_batches( - self, *, max_batch_length: Optional[int] = None + self, + *, + max_batch_length: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> AsyncRecordBatchReader: fts_query = AsyncFTSQuery(self._inner.to_fts_query()) vec_query = AsyncVectorQuery(self._inner.to_vector_query()) @@ -2719,8 +2840,8 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase): vec_query.with_row_id() fts_results, vector_results = await asyncio.gather( - fts_query.to_arrow(), - vec_query.to_arrow(), + fts_query.to_arrow(timeout=timeout), + vec_query.to_arrow(timeout=timeout), ) result = LanceHybridQueryBuilder._combine_hybrid_results( diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 50976720..59bcb5bb 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -355,9 +355,15 @@ class RemoteTable(Table): ) def _execute_query( - self, query: Query, batch_size: Optional[int] = None + self, + query: Query, + *, + batch_size: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> pa.RecordBatchReader: - async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size)) + async_iter = LOOP.run( + self._table._execute_query(query, batch_size=batch_size, timeout=timeout) + ) def iter_sync(): try: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index d64610a2..7e28237d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1007,7 +1007,11 @@ class Table(ABC): @abstractmethod def _execute_query( - self, query: Query, batch_size: Optional[int] = None + self, + query: Query, + *, + batch_size: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> pa.RecordBatchReader: ... 
@abstractmethod @@ -2312,9 +2316,15 @@ class LanceTable(Table): LOOP.run(self._table.update(values, where=where, updates_sql=values_sql)) def _execute_query( - self, query: Query, batch_size: Optional[int] = None + self, + query: Query, + *, + batch_size: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> pa.RecordBatchReader: - async_iter = LOOP.run(self._table._execute_query(query, batch_size)) + async_iter = LOOP.run( + self._table._execute_query(query, batch_size=batch_size, timeout=timeout) + ) def iter_sync(): try: @@ -3390,7 +3400,11 @@ class AsyncTable: return async_query async def _execute_query( - self, query: Query, batch_size: Optional[int] = None + self, + query: Query, + *, + batch_size: Optional[int] = None, + timeout: Optional[timedelta] = None, ) -> pa.RecordBatchReader: # The sync table calls into this method, so we need to map the # query to the async version of the query and run that here. This is only @@ -3398,7 +3412,9 @@ class AsyncTable: async_query = self._sync_query_to_async(query) - return await async_query.to_batches(max_batch_length=batch_size) + return await async_query.to_batches( + max_batch_length=batch_size, timeout=timeout + ) async def _explain_plan(self, query: Query, verbose: Optional[bool]) -> str: # This method is used by the sync table diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 95ebd2f2..7e3aaf81 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -511,7 +511,8 @@ def test_query_builder_with_different_vector_column(): columns=["b"], vector_column="foo_vector", ), - None, + batch_size=None, + timeout=None, ) @@ -1076,3 +1077,67 @@ async def test_query_serialization_async(table_async: AsyncTable): full_text_query=FullTextSearchQuery(columns=[], query="foo"), with_row_id=False, ) + + +def test_query_timeout(tmp_path): + # Use local directory instead of memory:// to add a bit of latency to + # operations so a timeout of zero will trigger exceptions. 
+ db = lancedb.connect(tmp_path) + data = pa.table( + { + "text": ["a", "b"], + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(4).cast(pa.float32()), 2 + ), + } + ) + table = db.create_table("test", data) + table.create_fts_index("text", use_tantivy=False) + + with pytest.raises(Exception, match="Query timeout"): + table.search().where("text = 'a'").to_list(timeout=timedelta(0)) + + with pytest.raises(Exception, match="Query timeout"): + table.search([0.0, 0.0]).to_arrow(timeout=timedelta(0)) + + with pytest.raises(Exception, match="Query timeout"): + table.search("a", query_type="fts").to_pandas(timeout=timedelta(0)) + + with pytest.raises(Exception, match="Query timeout"): + table.search(query_type="hybrid").vector([0.0, 0.0]).text("a").to_arrow( + timeout=timedelta(0) + ) + + +@pytest.mark.asyncio +async def test_query_timeout_async(tmp_path): + db = await lancedb.connect_async(tmp_path) + data = pa.table( + { + "text": ["a", "b"], + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(4).cast(pa.float32()), 2 + ), + } + ) + table = await db.create_table("test", data) + await table.create_index("text", config=FTS()) + + with pytest.raises(Exception, match="Query timeout"): + await table.query().where("text != 'a'").to_list(timeout=timedelta(0)) + + with pytest.raises(Exception, match="Query timeout"): + await table.vector_search([0.0, 0.0]).to_arrow(timeout=timedelta(0)) + + with pytest.raises(Exception, match="Query timeout"): + await (await table.search("a", query_type="fts")).to_pandas( + timeout=timedelta(0) + ) + + with pytest.raises(Exception, match="Query timeout"): + await ( + table.query() + .nearest_to_text("a") + .nearest_to([0.0, 0.0]) + .to_list(timeout=timedelta(0)) + ) diff --git a/python/src/query.rs b/python/src/query.rs index 9411016d..9dd7477c 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::Arc; +use std::time::Duration; use arrow::array::make_array; use arrow::array::Array; @@ -294,10 +295,11 @@ impl Query { }) } - #[pyo3(signature = (max_batch_length=None))] + #[pyo3(signature = (max_batch_length=None, timeout=None))] pub fn execute( self_: PyRef<'_, Self>, max_batch_length: Option, + timeout: Option, ) -> PyResult> { let inner = self_.inner.clone(); future_into_py(self_.py(), async move { @@ -305,6 +307,9 @@ impl Query { if let Some(max_batch_length) = max_batch_length { opts.max_batch_length = max_batch_length; } + if let Some(timeout) = timeout { + opts.timeout = Some(timeout); + } let inner_stream = inner.execute_with_options(opts).await.infer_error()?; Ok(RecordBatchStream::new(inner_stream)) }) @@ -376,10 +381,11 @@ impl FTSQuery { self.inner = self.inner.clone().postfilter(); } - #[pyo3(signature = (max_batch_length=None))] + #[pyo3(signature = (max_batch_length=None, timeout=None))] pub fn execute( self_: PyRef<'_, Self>, max_batch_length: Option, + timeout: Option, ) -> PyResult> { let inner = self_ .inner @@ -391,6 +397,9 @@ impl FTSQuery { if let Some(max_batch_length) = max_batch_length { opts.max_batch_length = max_batch_length; } + if let Some(timeout) = timeout { + opts.timeout = Some(timeout); + } let inner_stream = inner.execute_with_options(opts).await.infer_error()?; Ok(RecordBatchStream::new(inner_stream)) }) @@ -513,10 +522,11 @@ impl VectorQuery { self.inner = self.inner.clone().bypass_vector_index() } - #[pyo3(signature = (max_batch_length=None))] + #[pyo3(signature = (max_batch_length=None, timeout=None))] pub fn execute( self_: 
PyRef<'_, Self>, max_batch_length: Option, + timeout: Option, ) -> PyResult> { let inner = self_.inner.clone(); future_into_py(self_.py(), async move { @@ -524,6 +534,9 @@ impl VectorQuery { if let Some(max_batch_length) = max_batch_length { opts.max_batch_length = max_batch_length; } + if let Some(timeout) = timeout { + opts.timeout = Some(timeout); + } let inner_stream = inner.execute_with_options(opts).await.infer_error()?; Ok(RecordBatchStream::new(inner_stream)) }) diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 11b413a2..056c3aa6 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::future::Future; use std::sync::Arc; +use std::{future::Future, time::Duration}; use arrow::compute::concat_batches; use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array}; @@ -25,6 +25,7 @@ use crate::error::{Error, Result}; use crate::rerankers::rrf::RRFReranker; use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker}; use crate::table::BaseTable; +use crate::utils::TimeoutStream; use crate::DistanceType; use crate::{arrow::SendableRecordBatchStream, table::AnyQuery}; @@ -525,12 +526,15 @@ pub struct QueryExecutionOptions { /// /// By default, this is 1024 pub max_batch_length: u32, + /// Max duration to wait for the query to execute before timing out. + pub timeout: Option, } impl Default for QueryExecutionOptions { fn default() -> Self { Self { max_batch_length: 1024, + timeout: None, } } } @@ -1007,7 +1011,10 @@ impl VectorQuery { self } - pub async fn execute_hybrid(&self) -> Result { + pub async fn execute_hybrid( + &self, + options: QueryExecutionOptions, + ) -> Result { // clone query and specify we want to include row IDs, which can be needed for reranking let mut fts_query = Query::new(self.parent.clone()); fts_query.request = self.request.base.clone(); @@ -1016,7 +1023,10 @@ impl VectorQuery { let mut vector_query = self.clone().with_row_id(); vector_query.request.base.full_text_search = None; - let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?; + let (fts_results, vec_results) = try_join!( + fts_query.execute_with_options(options.clone()), + vector_query.inner_execute_with_options(options) + )?; let (fts_results, vec_results) = try_join!( fts_results.try_collect::>(), @@ -1074,6 +1084,20 @@ impl VectorQuery { RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])), )) } + + async fn inner_execute_with_options( + &self, + options: QueryExecutionOptions, + ) -> Result { + let plan = self.create_plan(options.clone()).await?; + let inner = execute_plan(plan, Default::default())?; + let inner = if let Some(timeout) = options.timeout { + TimeoutStream::new_boxed(inner, timeout) + } else { + inner + }; + Ok(DatasetRecordBatchStream::new(inner).into()) + } } impl ExecutableQuery for VectorQuery { @@ -1087,16 +1111,13 @@ impl ExecutableQuery for VectorQuery { options: QueryExecutionOptions, ) -> Result { if self.request.base.full_text_search.is_some() { - let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?; + let hybrid_result = async move { self.execute_hybrid(options).await } + .boxed() + .await?; return Ok(hybrid_result); } - Ok(SendableRecordBatchStream::from( - DatasetRecordBatchStream::new(execute_plan( - self.create_plan(options).await?, - Default::default(), - )?), - )) + 
self.inner_execute_with_options(options).await } async fn explain_plan(&self, verbose: bool) -> Result { diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 89e06f7a..9d30bea7 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -13,7 +13,7 @@ use reqwest::{ use crate::error::{Error, Result}; use crate::remote::db::RemoteOptions; -const REQUEST_ID_HEADER: &str = "x-request-id"; +const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id"); /// Configuration for the LanceDB Cloud HTTP client. #[derive(Clone, Debug)] @@ -299,7 +299,7 @@ impl RestfulLanceDbClient { ) -> Result { let mut headers = HeaderMap::new(); headers.insert( - "x-api-key", + HeaderName::from_static("x-api-key"), HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput { message: "non-ascii api key provided".to_string(), })?, @@ -307,7 +307,7 @@ impl RestfulLanceDbClient { if region == "local" { let host = format!("{}.local.api.lancedb.com", db_name); headers.insert( - "Host", + http::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput { message: format!("non-ascii database name '{}' provided", db_name), })?, @@ -315,7 +315,7 @@ impl RestfulLanceDbClient { } if has_host_override { headers.insert( - "x-lancedb-database", + HeaderName::from_static("x-lancedb-database"), HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput { message: format!("non-ascii database name '{}' provided", db_name), })?, @@ -323,7 +323,7 @@ impl RestfulLanceDbClient { } if db_prefix.is_some() { headers.insert( - "x-lancedb-database-prefix", + HeaderName::from_static("x-lancedb-database-prefix"), HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput { message: format!( "non-ascii database prefix '{}' provided", @@ -335,7 +335,7 @@ impl RestfulLanceDbClient { if let Some(v) = options.0.get("account_name") { headers.insert( - "x-azure-storage-account-name", + HeaderName::from_static("x-azure-storage-account-name"), HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { message: format!("non-ascii storage account name '{}' provided", db_name), })?, @@ -343,7 +343,7 @@ impl RestfulLanceDbClient { } if let Some(v) = options.0.get("azure_storage_account_name") { headers.insert( - "x-azure-storage-account-name", + HeaderName::from_static("x-azure-storage-account-name"), HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { message: format!("non-ascii storage account name '{}' provided", db_name), })?, diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 5226e92d..d6817d89 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -20,7 +20,7 @@ use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use futures::TryStreamExt; use http::header::CONTENT_TYPE; -use http::StatusCode; +use http::{HeaderName, StatusCode}; use lance::arrow::json::{JsonDataType, JsonSchema}; use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::{ColumnAlteration, NewColumnTransform, Version}; @@ -44,6 +44,8 @@ use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; use super::ARROW_STREAM_CONTENT_TYPE; +const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); + #[derive(Debug)] pub struct RemoteTable { #[allow(dead_code)] @@ -332,9 +334,19 @@ impl RemoteTable { async fn 
execute_query( &self, query: &AnyQuery, - _options: QueryExecutionOptions, + options: &QueryExecutionOptions, ) -> Result>>> { - let request = self.client.post(&format!("/v1/table/{}/query/", self.name)); + let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name)); + + if let Some(timeout) = options.timeout { + // Client side timeout + request = request.timeout(timeout); + // Also send to server, so it can abort the query if it takes too long. + // (If it doesn't fit into u64, it's not worth sending anyways.) + if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) { + request = request.header(REQUEST_TIMEOUT_HEADER, timeout_ms); + } + } let query_bodies = self.prepare_query_bodies(query).await?; let requests: Vec = query_bodies @@ -543,7 +555,7 @@ impl BaseTable for RemoteTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result> { - let streams = self.execute_query(query, options).await?; + let streams = self.execute_query(query, &options).await?; if streams.len() == 1 { let stream = streams.into_iter().next().unwrap(); Ok(Arc::new(OneShotExec::new(stream))) @@ -559,9 +571,9 @@ impl BaseTable for RemoteTable { async fn query( &self, query: &AnyQuery, - _options: QueryExecutionOptions, + options: QueryExecutionOptions, ) -> Result { - let streams = self.execute_query(query, _options).await?; + let streams = self.execute_query(query, &options).await?; if streams.len() == 1 { Ok(DatasetRecordBatchStream::new( diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index f0688927..f90b23dd 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -68,7 +68,7 @@ use crate::query::{ use crate::utils::{ default_vector_column, supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type, - PatchReadParam, PatchWriteParam, + PatchReadParam, PatchWriteParam, TimeoutStream, }; use self::dataset::DatasetConsistencyWrapper; @@ -1775,11 +1775,14 @@ impl NativeTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result { - let plan = self.create_plan(query, options).await?; - Ok(DatasetRecordBatchStream::new(execute_plan( - plan, - Default::default(), - )?)) + let plan = self.create_plan(query, options.clone()).await?; + let inner = execute_plan(plan, Default::default())?; + let inner = if let Some(timeout) = options.timeout { + TimeoutStream::new_boxed(inner, timeout) + } else { + inner + }; + Ok(DatasetRecordBatchStream::new(inner)) } /// Check whether the table uses V2 manifest paths. diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index bcb544d5..3f44bb0f 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -3,14 +3,20 @@ use std::sync::Arc; -use arrow_schema::{DataType, Schema}; +use arrow_array::RecordBatch; +use arrow_schema::{DataType, Schema, SchemaRef}; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_execution::RecordBatchStream; +use futures::{FutureExt, Stream}; use lance::arrow::json::JsonDataType; use lance::dataset::{ReadParams, WriteParams}; use lance::index::vector::utils::infer_vector_dim; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lazy_static::lazy_static; +use std::pin::Pin; use crate::error::{Error, Result}; +use datafusion_physical_plan::SendableRecordBatchStream; lazy_static! 
{ static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap(); @@ -178,11 +184,97 @@ pub fn string_to_datatype(s: &str) -> Option { (&json_type).try_into().ok() } +enum TimeoutState { + NotStarted { + timeout: std::time::Duration, + }, + Started { + deadline: Pin>, + timeout: std::time::Duration, + }, + Completed, +} + +/// A `Stream` wrapper that implements a timeout. +/// +/// The timeout starts when the first `poll_next` is called. As soon as the timeout +/// duration has passed, the stream will return an `Err` indicating a timeout error +/// for the next poll. +pub struct TimeoutStream { + inner: SendableRecordBatchStream, + state: TimeoutState, +} + +impl TimeoutStream { + pub fn new(inner: SendableRecordBatchStream, timeout: std::time::Duration) -> Self { + Self { + inner, + state: TimeoutState::NotStarted { timeout }, + } + } + + pub fn new_boxed( + inner: SendableRecordBatchStream, + timeout: std::time::Duration, + ) -> SendableRecordBatchStream { + Box::pin(Self::new(inner, timeout)) + } + + fn timeout_error(timeout: &std::time::Duration) -> DataFusionError { + DataFusionError::Execution(format!("Query timeout after {} ms", timeout.as_millis())) + } +} + +impl RecordBatchStream for TimeoutStream { + fn schema(&self) -> SchemaRef { + self.inner.schema() + } +} + +impl Stream for TimeoutStream { + type Item = DataFusionResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match &mut self.state { + TimeoutState::NotStarted { timeout } => { + if timeout.is_zero() { + return std::task::Poll::Ready(Some(Err(Self::timeout_error(timeout)))); + } + let deadline = Box::pin(tokio::time::sleep(*timeout)); + self.state = TimeoutState::Started { + deadline, + timeout: *timeout, + }; + self.poll_next(cx) + } + TimeoutState::Started { deadline, timeout } => match deadline.poll_unpin(cx) { + std::task::Poll::Ready(_) => { + let err = Self::timeout_error(timeout); + self.state = TimeoutState::Completed; + std::task::Poll::Ready(Some(Err(err))) + } + std::task::Poll::Pending => { + let inner = Pin::new(&mut self.inner); + inner.poll_next(cx) + } + }, + TimeoutState::Completed => std::task::Poll::Ready(None), + } + } +} + #[cfg(test)] mod tests { - use super::*; + use arrow_array::Int32Array; + use arrow_schema::Field; + use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + use futures::{stream, StreamExt}; + use tokio::time::sleep; - use arrow_schema::{DataType, Field}; + use super::*; #[test] fn test_guess_default_column() { @@ -249,4 +341,85 @@ mod tests { let expected = DataType::Int32; assert_eq!(string_to_datatype(string), Some(expected)); } + + fn sample_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap() + } + + #[tokio::test] + async fn test_timeout_stream() { + let batch = sample_batch(); + let schema = batch.schema(); + let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); + + let sendable_stream: SendableRecordBatchStream = + Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream)); + let timeout_duration = std::time::Duration::from_millis(10); + let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration); + + // Poll the stream to get the first batch + let first_result = timeout_stream.next().await; + assert!(first_result.is_some()); + 
assert!(first_result.unwrap().is_ok()); + + // Sleep for the timeout duration + sleep(timeout_duration).await; + + // Poll the stream again and ensure it returns a timeout error + let second_result = timeout_stream.next().await.unwrap(); + assert!(second_result.is_err()); + assert!(second_result + .unwrap_err() + .to_string() + .contains("Query timeout")); + } + + #[tokio::test] + async fn test_timeout_stream_zero_duration() { + let batch = sample_batch(); + let schema = batch.schema(); + let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); + + let sendable_stream: SendableRecordBatchStream = + Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream)); + + // Setup similar to test_timeout_stream + let timeout_duration = std::time::Duration::from_secs(0); + let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration); + + // First poll should immediately return a timeout error + let result = timeout_stream.next().await.unwrap(); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Query timeout")); + } + + #[tokio::test] + async fn test_timeout_stream_completes_normally() { + let batch = sample_batch(); + let schema = batch.schema(); + let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); + + let sendable_stream: SendableRecordBatchStream = + Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream)); + + // Setup a stream with 2 batches + // Use a longer timeout that won't trigger + let timeout_duration = std::time::Duration::from_secs(1); + let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration); + + // Both polls should return data normally + assert!(timeout_stream.next().await.unwrap().is_ok()); + assert!(timeout_stream.next().await.unwrap().is_ok()); + // Stream should be empty now + assert!(timeout_stream.next().await.is_none()); + } }
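
---

A brief usage sketch of the timeout support added in this diff. The database path, table name, and data below are illustrative; the calls themselves (`to_arrow(timeout=...)`, `to_list(timeout=...)`, and their async counterparts) follow the signatures introduced in `python/python/lancedb/query.py`. A query that does not complete within the given `timedelta` raises an error containing "Query timeout".

```python
from datetime import timedelta

import lancedb

# Synchronous API: `timeout` is a keyword-only argument on the
# result-collecting methods (to_arrow, to_list, to_pandas, to_polars, ...).
db = lancedb.connect("/tmp/lancedb-demo")  # illustrative path
table = db.create_table(
    "demo",
    [{"text": "a", "vector": [0.1, 0.2]}, {"text": "b", "vector": [0.3, 0.4]}],
)
results = table.search([0.0, 0.0]).to_arrow(timeout=timedelta(seconds=5))


# Asynchronous API: the same keyword is accepted by to_batches, to_arrow,
# to_list, to_pandas, and to_polars on AsyncQueryBase.
async def run():  # hypothetical driver function
    adb = await lancedb.connect_async("/tmp/lancedb-demo")
    atable = await adb.open_table("demo")
    return await atable.query().where("text != 'a'").to_list(
        timeout=timedelta(seconds=5)
    )
```

The TypeScript bindings expose the same behavior as `timeoutMs` on `QueryExecutionOptions`, e.g. `table.query().toArray({ timeoutMs: 5000 })`. Against a remote table, the timeout is additionally applied to the HTTP request itself and forwarded to the server in the `x-request-timeout-ms` header so the server can abort long-running queries.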
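
Under the hood, both native paths wrap the result stream in the new `TimeoutStream` from `rust/lancedb/src/utils.rs`. The sketch below is a simplified Python analogue (not part of this diff) of its semantics, under the assumption that an async-generator model is close enough: the clock starts at the first poll rather than at construction, a zero duration fails immediately, and once the deadline passes the stream produces a single timeout error.

```python
import asyncio
import time
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def timeout_stream(inner: AsyncIterator[T], timeout: float) -> AsyncIterator[T]:
    """Hypothetical analogue of the Rust TimeoutStream (timeout in seconds)."""
    # An async generator's body does not run until the first __anext__ call,
    # so the deadline is effectively measured from the first poll, matching
    # the TimeoutState::NotStarted -> Started transition in the Rust code.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            # Mirrors TimeoutStream::timeout_error, including the
            # zero-duration fast path that fails on the very first poll.
            raise TimeoutError(f"Query timeout after {int(timeout * 1000)} ms")
        try:
            # Simplification: wait_for cancels the in-flight fetch on timeout,
            # whereas the Rust stream just returns the error on its next poll.
            item = await asyncio.wait_for(inner.__anext__(), remaining)
        except StopAsyncIteration:
            return  # inner stream finished before the deadline
        except asyncio.TimeoutError:
            raise TimeoutError(f"Query timeout after {int(timeout * 1000)} ms")
        yield item
```

One design point worth noting: because the timer starts on the first poll, a query that is built but never executed does not burn its timeout budget, and `execute_with_options` only pays for the wrapper when `options.timeout` is `Some`.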