diff --git a/nodejs/__test__/remote.test.ts b/nodejs/__test__/remote.test.ts index eb62d9f5..96a45986 100644 --- a/nodejs/__test__/remote.test.ts +++ b/nodejs/__test__/remote.test.ts @@ -42,6 +42,28 @@ describe("remote connection", () => { }); }); + it("should accept overall timeout configuration", async () => { + await connect("db://test", { + apiKey: "fake", + clientConfig: { + timeoutConfig: { timeout: 30 }, + }, + }); + + // Test with all timeout parameters + await connect("db://test", { + apiKey: "fake", + clientConfig: { + timeoutConfig: { + timeout: 60, + connectTimeout: 10, + readTimeout: 20, + poolIdleTimeout: 300, + }, + }, + }); + }); + it("should pass down apiKey and userAgent", async () => { await withMockDatabase( (req, res) => { diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 2ec29897..0b4845e5 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -9,6 +9,12 @@ use napi_derive::*; #[napi(object)] #[derive(Debug)] pub struct TimeoutConfig { + /// The overall timeout for the entire request in seconds. This includes + /// connection, send, and read time. If the entire request doesn't complete + /// within this time, it will fail. Default is None (no overall timeout). + /// This can also be set via the environment variable `LANCE_CLIENT_TIMEOUT`, + /// as an integer number of seconds. + pub timeout: Option, /// The timeout for establishing a connection in seconds. Default is 120 /// seconds (2 minutes). This can also be set via the environment variable /// `LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds. @@ -75,6 +81,7 @@ pub struct ClientConfig { impl From for lancedb::remote::TimeoutConfig { fn from(config: TimeoutConfig) -> Self { Self { + timeout: config.timeout.map(std::time::Duration::from_secs_f64), connect_timeout: config .connect_timeout .map(std::time::Duration::from_secs_f64), diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 40502d64..896efbd7 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -17,6 +17,12 @@ class TimeoutConfig: Attributes ---------- + timeout: Optional[timedelta] + The overall timeout for the entire request. This includes connection, + send, and read time. If the entire request doesn't complete within + this time, it will fail. Default is None (no overall timeout). + This can also be set via the environment variable + `LANCE_CLIENT_TIMEOUT`, as an integer number of seconds. connect_timeout: Optional[timedelta] The timeout for establishing a connection. Default is 120 seconds (2 minutes). This can also be set via the environment variable @@ -31,6 +37,7 @@ class TimeoutConfig: `LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer number of seconds. """ + timeout: Optional[timedelta] = None connect_timeout: Optional[timedelta] = None read_timeout: Optional[timedelta] = None pool_idle_timeout: Optional[timedelta] = None @@ -50,6 +57,7 @@ class TimeoutConfig: ) def __post_init__(self): + self.timeout = self.__to_timedelta(self.timeout) self.connect_timeout = self.__to_timedelta(self.connect_timeout) self.read_timeout = self.__to_timedelta(self.read_timeout) self.pool_idle_timeout = self.__to_timedelta(self.pool_idle_timeout) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index ace9f263..5435e210 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -798,6 +798,21 @@ def test_create_client(): assert isinstance(db.client_config, ClientConfig) assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42) + # Test overall timeout parameter + db = lancedb.connect( + **mandatory_args, + client_config=ClientConfig(timeout_config={"timeout": 60}), + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.timeout_config.timeout == timedelta(seconds=60) + + db = lancedb.connect( + **mandatory_args, + client_config={"timeout_config": {"timeout": timedelta(seconds=60)}}, + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.timeout_config.timeout == timedelta(seconds=60) + db = lancedb.connect( **mandatory_args, client_config=ClientConfig(retry_config={"retries": 42}) ) diff --git a/python/src/connection.rs b/python/src/connection.rs index 2e2f64d3..d5fe00cd 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -241,6 +241,7 @@ pub struct PyClientRetryConfig { #[derive(FromPyObject)] pub struct PyClientTimeoutConfig { + timeout: Option, connect_timeout: Option, read_timeout: Option, pool_idle_timeout: Option, @@ -264,6 +265,7 @@ impl From for lancedb::remote::RetryConfig { impl From for lancedb::remote::TimeoutConfig { fn from(value: PyClientTimeoutConfig) -> Self { Self { + timeout: value.timeout, connect_timeout: value.connect_timeout, read_timeout: value.read_timeout, pool_idle_timeout: value.pool_idle_timeout, diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 50bc52a6..ecf58df4 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -41,6 +41,16 @@ impl Default for ClientConfig { /// How to handle timeouts for HTTP requests. #[derive(Clone, Default, Debug)] pub struct TimeoutConfig { + /// The overall timeout for the entire request. + /// + /// This includes connection, send, and read time. If the entire request + /// doesn't complete within this time, it will fail. + /// + /// You can also set the `LANCE_CLIENT_TIMEOUT` environment variable + /// to set this value. Use an integer value in seconds. + /// + /// By default, no overall timeout is set. + pub timeout: Option, /// The timeout for creating a connection to the server. /// /// You can also set the `LANCE_CLIENT_CONNECT_TIMEOUT` environment variable @@ -159,9 +169,9 @@ impl HttpSend for Sender { } impl RestfulLanceDbClient { - fn get_timeout(passed: Option, env_var: &str, default: Duration) -> Result { + fn get_timeout(passed: Option, env_var: &str) -> Result> { if let Some(passed) = passed { - Ok(passed) + Ok(Some(passed)) } else if let Ok(timeout) = std::env::var(env_var) { let timeout = timeout.parse::().map_err(|_| Error::InvalidInput { message: format!( @@ -169,9 +179,9 @@ impl RestfulLanceDbClient { env_var, timeout ), })?; - Ok(Duration::from_secs(timeout)) + Ok(Some(Duration::from_secs(timeout))) } else { - Ok(default) + Ok(None) } } @@ -203,28 +213,34 @@ impl RestfulLanceDbClient { }; // Get the timeouts + let timeout = + Self::get_timeout(client_config.timeout_config.timeout, "LANCE_CLIENT_TIMEOUT")?; let connect_timeout = Self::get_timeout( client_config.timeout_config.connect_timeout, "LANCE_CLIENT_CONNECT_TIMEOUT", - Duration::from_secs(120), - )?; + )? + .unwrap_or_else(|| Duration::from_secs(120)); let read_timeout = Self::get_timeout( client_config.timeout_config.read_timeout, "LANCE_CLIENT_READ_TIMEOUT", - Duration::from_secs(300), - )?; + )? + .unwrap_or_else(|| Duration::from_secs(300)); let pool_idle_timeout = Self::get_timeout( client_config.timeout_config.pool_idle_timeout, // Though it's confusing with the connect_timeout name, this is the // legacy name for this in the Python sync client. So we keep as-is. "LANCE_CLIENT_CONNECTION_TIMEOUT", - Duration::from_secs(300), - )?; + )? + .unwrap_or_else(|| Duration::from_secs(300)); - let client = reqwest::Client::builder() + let mut client_builder = reqwest::Client::builder() .connect_timeout(connect_timeout) .read_timeout(read_timeout) - .pool_idle_timeout(pool_idle_timeout) + .pool_idle_timeout(pool_idle_timeout); + if let Some(timeout) = timeout { + client_builder = client_builder.timeout(timeout); + } + let client = client_builder .default_headers(Self::default_headers( api_key, region, @@ -581,3 +597,51 @@ pub mod test_utils { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_timeout_config_default() { + let config = TimeoutConfig::default(); + assert!(config.timeout.is_none()); + assert!(config.connect_timeout.is_none()); + assert!(config.read_timeout.is_none()); + assert!(config.pool_idle_timeout.is_none()); + } + + #[test] + fn test_timeout_config_with_overall_timeout() { + let config = TimeoutConfig { + timeout: Some(Duration::from_secs(60)), + connect_timeout: Some(Duration::from_secs(10)), + read_timeout: Some(Duration::from_secs(30)), + pool_idle_timeout: Some(Duration::from_secs(300)), + }; + + assert_eq!(config.timeout, Some(Duration::from_secs(60))); + assert_eq!(config.connect_timeout, Some(Duration::from_secs(10))); + assert_eq!(config.read_timeout, Some(Duration::from_secs(30))); + assert_eq!(config.pool_idle_timeout, Some(Duration::from_secs(300))); + } + + #[test] + fn test_client_config_with_timeout() { + let timeout_config = TimeoutConfig { + timeout: Some(Duration::from_secs(120)), + ..Default::default() + }; + + let client_config = ClientConfig { + timeout_config, + ..Default::default() + }; + + assert_eq!( + client_config.timeout_config.timeout, + Some(Duration::from_secs(120)) + ); + } +}