mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-25 07:50:40 +00:00
feat: add timeout to query execution options (#2288)
Closes #2287 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added configurable timeout support for query executions. Users can now specify maximum wait times for queries, enhancing control over long-running operations across various integrations. - **Tests** - Expanded test coverage to validate timeout behavior in both synchronous and asynchronous query flows, ensuring timely error responses when query execution exceeds the specified limit. - Introduced a new test suite to verify query operations when a timeout is reached, checking for appropriate error handling. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
|
||||
@@ -25,6 +25,7 @@ use crate::error::{Error, Result};
|
||||
use crate::rerankers::rrf::RRFReranker;
|
||||
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
|
||||
use crate::table::BaseTable;
|
||||
use crate::utils::TimeoutStream;
|
||||
use crate::DistanceType;
|
||||
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
|
||||
|
||||
@@ -525,12 +526,15 @@ pub struct QueryExecutionOptions {
|
||||
///
|
||||
/// By default, this is 1024
|
||||
pub max_batch_length: u32,
|
||||
/// Max duration to wait for the query to execute before timing out.
|
||||
pub timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Default for QueryExecutionOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_batch_length: 1024,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1007,7 +1011,10 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
|
||||
pub async fn execute_hybrid(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
// clone query and specify we want to include row IDs, which can be needed for reranking
|
||||
let mut fts_query = Query::new(self.parent.clone());
|
||||
fts_query.request = self.request.base.clone();
|
||||
@@ -1016,7 +1023,10 @@ impl VectorQuery {
|
||||
let mut vector_query = self.clone().with_row_id();
|
||||
|
||||
vector_query.request.base.full_text_search = None;
|
||||
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
fts_query.execute_with_options(options.clone()),
|
||||
vector_query.inner_execute_with_options(options)
|
||||
)?;
|
||||
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
fts_results.try_collect::<Vec<_>>(),
|
||||
@@ -1074,6 +1084,20 @@ impl VectorQuery {
|
||||
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
|
||||
))
|
||||
}
|
||||
|
||||
async fn inner_execute_with_options(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let plan = self.create_plan(options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
inner
|
||||
};
|
||||
Ok(DatasetRecordBatchStream::new(inner).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutableQuery for VectorQuery {
|
||||
@@ -1087,16 +1111,13 @@ impl ExecutableQuery for VectorQuery {
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
if self.request.base.full_text_search.is_some() {
|
||||
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
|
||||
let hybrid_result = async move { self.execute_hybrid(options).await }
|
||||
.boxed()
|
||||
.await?;
|
||||
return Ok(hybrid_result);
|
||||
}
|
||||
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
DatasetRecordBatchStream::new(execute_plan(
|
||||
self.create_plan(options).await?,
|
||||
Default::default(),
|
||||
)?),
|
||||
))
|
||||
self.inner_execute_with_options(options).await
|
||||
}
|
||||
|
||||
async fn explain_plan(&self, verbose: bool) -> Result<String> {
|
||||
|
||||
@@ -13,7 +13,7 @@ use reqwest::{
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::db::RemoteOptions;
|
||||
|
||||
const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
/// Configuration for the LanceDB Cloud HTTP client.
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -299,7 +299,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-api-key",
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
@@ -307,7 +307,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
"Host",
|
||||
http::header::HOST,
|
||||
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
@@ -315,7 +315,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
}
|
||||
if has_host_override {
|
||||
headers.insert(
|
||||
"x-lancedb-database",
|
||||
HeaderName::from_static("x-lancedb-database"),
|
||||
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
@@ -323,7 +323,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
}
|
||||
if db_prefix.is_some() {
|
||||
headers.insert(
|
||||
"x-lancedb-database-prefix",
|
||||
HeaderName::from_static("x-lancedb-database-prefix"),
|
||||
HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput {
|
||||
message: format!(
|
||||
"non-ascii database prefix '{}' provided",
|
||||
@@ -335,7 +335,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
|
||||
if let Some(v) = options.0.get("account_name") {
|
||||
headers.insert(
|
||||
"x-azure-storage-account-name",
|
||||
HeaderName::from_static("x-azure-storage-account-name"),
|
||||
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii storage account name '{}' provided", db_name),
|
||||
})?,
|
||||
@@ -343,7 +343,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
}
|
||||
if let Some(v) = options.0.get("azure_storage_account_name") {
|
||||
headers.insert(
|
||||
"x-azure-storage-account-name",
|
||||
HeaderName::from_static("x-azure-storage-account-name"),
|
||||
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii storage account name '{}' provided", db_name),
|
||||
})?,
|
||||
|
||||
@@ -20,7 +20,7 @@ use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
|
||||
use futures::TryStreamExt;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use http::StatusCode;
|
||||
use http::{HeaderName, StatusCode};
|
||||
use lance::arrow::json::{JsonDataType, JsonSchema};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
|
||||
@@ -44,6 +44,8 @@ use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||
use super::db::ServerVersion;
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
|
||||
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RemoteTable<S: HttpSend = Sender> {
|
||||
#[allow(dead_code)]
|
||||
@@ -332,9 +334,19 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
async fn execute_query(
|
||||
&self,
|
||||
query: &AnyQuery,
|
||||
_options: QueryExecutionOptions,
|
||||
options: &QueryExecutionOptions,
|
||||
) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
|
||||
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
|
||||
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
|
||||
|
||||
if let Some(timeout) = options.timeout {
|
||||
// Client side timeout
|
||||
request = request.timeout(timeout);
|
||||
// Also send to server, so it can abort the query if it takes too long.
|
||||
// (If it doesn't fit into u64, it's not worth sending anyways.)
|
||||
if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
|
||||
request = request.header(REQUEST_TIMEOUT_HEADER, timeout_ms);
|
||||
}
|
||||
}
|
||||
|
||||
let query_bodies = self.prepare_query_bodies(query).await?;
|
||||
let requests: Vec<reqwest::RequestBuilder> = query_bodies
|
||||
@@ -543,7 +555,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
query: &AnyQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let streams = self.execute_query(query, options).await?;
|
||||
let streams = self.execute_query(query, &options).await?;
|
||||
if streams.len() == 1 {
|
||||
let stream = streams.into_iter().next().unwrap();
|
||||
Ok(Arc::new(OneShotExec::new(stream)))
|
||||
@@ -559,9 +571,9 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn query(
|
||||
&self,
|
||||
query: &AnyQuery,
|
||||
_options: QueryExecutionOptions,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
let streams = self.execute_query(query, _options).await?;
|
||||
let streams = self.execute_query(query, &options).await?;
|
||||
|
||||
if streams.len() == 1 {
|
||||
Ok(DatasetRecordBatchStream::new(
|
||||
|
||||
@@ -68,7 +68,7 @@ use crate::query::{
|
||||
use crate::utils::{
|
||||
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
|
||||
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
|
||||
PatchReadParam, PatchWriteParam,
|
||||
PatchReadParam, PatchWriteParam, TimeoutStream,
|
||||
};
|
||||
|
||||
use self::dataset::DatasetConsistencyWrapper;
|
||||
@@ -1775,11 +1775,14 @@ impl NativeTable {
|
||||
query: &AnyQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
let plan = self.create_plan(query, options).await?;
|
||||
Ok(DatasetRecordBatchStream::new(execute_plan(
|
||||
plan,
|
||||
Default::default(),
|
||||
)?))
|
||||
let plan = self.create_plan(query, options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
inner
|
||||
};
|
||||
Ok(DatasetRecordBatchStream::new(inner))
|
||||
}
|
||||
|
||||
/// Check whether the table uses V2 manifest paths.
|
||||
|
||||
@@ -3,14 +3,20 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::{DataType, Schema};
|
||||
use arrow_array::RecordBatch;
|
||||
use arrow_schema::{DataType, Schema, SchemaRef};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use datafusion_execution::RecordBatchStream;
|
||||
use futures::{FutureExt, Stream};
|
||||
use lance::arrow::json::JsonDataType;
|
||||
use lance::dataset::{ReadParams, WriteParams};
|
||||
use lance::index::vector::utils::infer_vector_dim;
|
||||
use lance::io::{ObjectStoreParams, WrappingObjectStore};
|
||||
use lazy_static::lazy_static;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use datafusion_physical_plan::SendableRecordBatchStream;
|
||||
|
||||
lazy_static! {
|
||||
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
|
||||
@@ -178,11 +184,97 @@ pub fn string_to_datatype(s: &str) -> Option<DataType> {
|
||||
(&json_type).try_into().ok()
|
||||
}
|
||||
|
||||
enum TimeoutState {
|
||||
NotStarted {
|
||||
timeout: std::time::Duration,
|
||||
},
|
||||
Started {
|
||||
deadline: Pin<Box<tokio::time::Sleep>>,
|
||||
timeout: std::time::Duration,
|
||||
},
|
||||
Completed,
|
||||
}
|
||||
|
||||
/// A `Stream` wrapper that implements a timeout.
|
||||
///
|
||||
/// The timeout starts when the first `poll_next` is called. As soon as the timeout
|
||||
/// duration has passed, the stream will return an `Err` indicating a timeout error
|
||||
/// for the next poll.
|
||||
pub struct TimeoutStream {
|
||||
inner: SendableRecordBatchStream,
|
||||
state: TimeoutState,
|
||||
}
|
||||
|
||||
impl TimeoutStream {
|
||||
pub fn new(inner: SendableRecordBatchStream, timeout: std::time::Duration) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
state: TimeoutState::NotStarted { timeout },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_boxed(
|
||||
inner: SendableRecordBatchStream,
|
||||
timeout: std::time::Duration,
|
||||
) -> SendableRecordBatchStream {
|
||||
Box::pin(Self::new(inner, timeout))
|
||||
}
|
||||
|
||||
fn timeout_error(timeout: &std::time::Duration) -> DataFusionError {
|
||||
DataFusionError::Execution(format!("Query timeout after {} ms", timeout.as_millis()))
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordBatchStream for TimeoutStream {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
self.inner.schema()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for TimeoutStream {
|
||||
type Item = DataFusionResult<RecordBatch>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
match &mut self.state {
|
||||
TimeoutState::NotStarted { timeout } => {
|
||||
if timeout.is_zero() {
|
||||
return std::task::Poll::Ready(Some(Err(Self::timeout_error(timeout))));
|
||||
}
|
||||
let deadline = Box::pin(tokio::time::sleep(*timeout));
|
||||
self.state = TimeoutState::Started {
|
||||
deadline,
|
||||
timeout: *timeout,
|
||||
};
|
||||
self.poll_next(cx)
|
||||
}
|
||||
TimeoutState::Started { deadline, timeout } => match deadline.poll_unpin(cx) {
|
||||
std::task::Poll::Ready(_) => {
|
||||
let err = Self::timeout_error(timeout);
|
||||
self.state = TimeoutState::Completed;
|
||||
std::task::Poll::Ready(Some(Err(err)))
|
||||
}
|
||||
std::task::Poll::Pending => {
|
||||
let inner = Pin::new(&mut self.inner);
|
||||
inner.poll_next(cx)
|
||||
}
|
||||
},
|
||||
TimeoutState::Completed => std::task::Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use arrow_array::Int32Array;
|
||||
use arrow_schema::Field;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use futures::{stream, StreamExt};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use arrow_schema::{DataType, Field};
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_guess_default_column() {
|
||||
@@ -249,4 +341,85 @@ mod tests {
|
||||
let expected = DataType::Int32;
|
||||
assert_eq!(string_to_datatype(string), Some(expected));
|
||||
}
|
||||
|
||||
fn sample_batch() -> RecordBatch {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"col1",
|
||||
DataType::Int32,
|
||||
false,
|
||||
)]));
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream() {
|
||||
let batch = sample_batch();
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream =
|
||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
||||
let timeout_duration = std::time::Duration::from_millis(10);
|
||||
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
|
||||
|
||||
// Poll the stream to get the first batch
|
||||
let first_result = timeout_stream.next().await;
|
||||
assert!(first_result.is_some());
|
||||
assert!(first_result.unwrap().is_ok());
|
||||
|
||||
// Sleep for the timeout duration
|
||||
sleep(timeout_duration).await;
|
||||
|
||||
// Poll the stream again and ensure it returns a timeout error
|
||||
let second_result = timeout_stream.next().await.unwrap();
|
||||
assert!(second_result.is_err());
|
||||
assert!(second_result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Query timeout"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_zero_duration() {
|
||||
let batch = sample_batch();
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream =
|
||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
||||
|
||||
// Setup similar to test_timeout_stream
|
||||
let timeout_duration = std::time::Duration::from_secs(0);
|
||||
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
|
||||
|
||||
// First poll should immediately return a timeout error
|
||||
let result = timeout_stream.next().await.unwrap();
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Query timeout"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_completes_normally() {
|
||||
let batch = sample_batch();
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream =
|
||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
||||
|
||||
// Setup a stream with 2 batches
|
||||
// Use a longer timeout that won't trigger
|
||||
let timeout_duration = std::time::Duration::from_secs(1);
|
||||
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
|
||||
|
||||
// Both polls should return data normally
|
||||
assert!(timeout_stream.next().await.unwrap().is_ok());
|
||||
assert!(timeout_stream.next().await.unwrap().is_ok());
|
||||
// Stream should be empty now
|
||||
assert!(timeout_stream.next().await.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user