diff --git a/Cargo.lock b/Cargo.lock index e0f52f5d..2f9d326c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -128,9 +128,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "arbitrary" @@ -390,9 +390,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a194f9d963d8099596278594b3107448656ba73831c9d8c783e613ce86da64" +checksum = "b37fc50485c4f3f736a4fb14199f6d5f5ba008d7f28fe710306c92780f004c07" dependencies = [ "flate2", "futures-core", @@ -564,9 +564,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.28.0" +version = "0.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9f7720b74ed28ca77f90769a71fd8c637a0137f6fae4ae947e1050229cff57f" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" dependencies = [ "bindgen", "cc", @@ -882,7 +882,7 @@ dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "h2 0.4.8", + "h2 0.4.9", "http 0.2.12", "http 1.3.1", "http-body 0.4.6", @@ -1185,9 +1185,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", @@ -2377,7 +2377,16 @@ version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys", + "dirs-sys 0.4.1", +] + +[[package]] +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + "dirs-sys 0.5.0", ] [[package]] @@ -2388,10 +2397,22 @@ checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" dependencies = [ "libc", "option-ext", - "redox_users", + "redox_users 0.4.6", "windows-sys 0.48.0", ] +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users 0.5.0", + "windows-sys 0.59.0", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -2558,9 +2579,9 @@ dependencies = [ [[package]] name = "ethnum" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" +checksum = "0939f82868b77ef93ce3c3c3daf2b3c526b456741da5a1a4559e590965b6026b" [[package]] name = "event-listener" @@ -3049,9 +3070,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" dependencies = [ "atomic-waker", "bytes", @@ -3138,7 +3159,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1" dependencies = [ - "dirs", + "dirs 5.0.1", "futures", "http 1.3.1", "indicatif", @@ -3286,7 +3307,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.8", + "h2 0.4.9", "http 1.3.1", "http-body 1.0.1", "httparse", @@ -3645,9 +3666,9 @@ checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" [[package]] name = "jiff" -version = "0.2.6" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f33145a5cbea837164362c7bd596106eb7c5198f97d1ba6f6ebb3223952e488" +checksum = "5a064218214dc6a10fbae5ec5fa888d80c45d611aba169222fc272072bf7aef6" dependencies = [ "jiff-static", "log", @@ -3658,9 +3679,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.6" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43ce13c40ec6956157a3635d97a1ee2df323b263f09ea14165131289cb0f5c19" +checksum = "199b7932d97e325aff3a7030e141eafe7f2c6268e1d1b24859b753a627f45254" dependencies = [ "proc-macro2", "quote", @@ -3965,7 +3986,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "deepsize", - "dirs", + "dirs 5.0.1", "fst", "futures", "half", @@ -4115,7 +4136,7 @@ dependencies = [ [[package]] name = "lancedb" -version = "0.19.0-beta.8" +version = "0.19.0-beta.9" dependencies = [ "arrow", "arrow-array", @@ -4202,7 +4223,7 @@ dependencies = [ [[package]] name = "lancedb-node" -version = "0.19.0-beta.8" +version = "0.19.0-beta.9" dependencies = [ "arrow-array", "arrow-ipc", @@ -4227,7 +4248,7 @@ dependencies = [ [[package]] name = "lancedb-nodejs" -version = "0.19.0-beta.8" +version = "0.19.0-beta.9" dependencies = [ "arrow-array", "arrow-ipc", @@ -4245,7 +4266,7 @@ dependencies = [ [[package]] name = "lancedb-python" -version = "0.22.0-beta.8" +version = "0.22.0-beta.9" dependencies = [ "arrow", "env_logger", @@ -4342,9 +4363,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libloading" @@ -4368,9 +4389,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.11" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72" [[package]] name = "libredox" @@ -5637,9 +5658,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -5837,13 +5858,13 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.10" +version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" +checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b" dependencies = [ "bytes", "getrandom 0.3.2", - "rand 0.9.0", + "rand 0.9.1", "ring", "rustc-hash 2.1.1", "rustls 0.23.26", @@ -5903,13 +5924,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.24", ] [[package]] @@ -6084,6 +6104,17 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "redox_users" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" +dependencies = [ + "getrandom 0.2.15", + "libredox", + "thiserror 2.0.12", +] + [[package]] name = "regex" version = "1.11.1" @@ -6152,7 +6183,7 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2 0.4.8", + "h2 0.4.9", "http 1.3.1", "http-body 1.0.1", "http-body-util", @@ -6701,11 +6732,11 @@ dependencies = [ [[package]] name = "shellexpand" -version = "3.1.0" +version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b" +checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb" dependencies = [ - "dirs", + "dirs 6.0.0", ] [[package]] @@ -6716,9 +6747,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] diff --git a/rust/lancedb/src/index/waiter.rs b/rust/lancedb/src/index/waiter.rs index 8fde37b7..a20bfa2b 100644 --- a/rust/lancedb/src/index/waiter.rs +++ b/rust/lancedb/src/index/waiter.rs @@ -20,7 +20,7 @@ pub async fn wait_for_index( ) -> Result<()> { if timeout > MAX_WAIT { return Err(Error::InvalidInput { - message: format!("timeout must be less than {:?}", MAX_WAIT).to_string(), + message: format!("timeout must be less than {:?}", MAX_WAIT), }); } let start = Instant::now(); @@ -84,7 +84,6 @@ pub async fn wait_for_index( message: format!( "timed out waiting for indices: {:?} after {:?}", remaining, timeout - ) - .to_string(), + ), }) } diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs index a154a662..b8bc96b1 100644 --- a/rust/lancedb/src/remote.rs +++ b/rust/lancedb/src/remote.rs @@ -8,6 +8,7 @@ pub(crate) mod client; pub(crate) mod db; +mod retry; pub(crate) mod table; pub(crate) mod util; diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 9d30bea7..50bc52a6 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -1,17 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::{collections::HashMap, future::Future, str::FromStr, time::Duration}; - use http::HeaderName; use log::debug; use reqwest::{ header::{HeaderMap, HeaderValue}, - Request, RequestBuilder, Response, + Body, Request, RequestBuilder, Response, }; +use std::{collections::HashMap, future::Future, str::FromStr, time::Duration}; use crate::error::{Error, Result}; use crate::remote::db::RemoteOptions; +use crate::remote::retry::{ResolvedRetryConfig, RetryCounter}; const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id"); @@ -118,41 +118,14 @@ pub struct RetryConfig { /// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable /// to set this value. Use a comma-separated list of integer values. /// - /// The default is 429, 500, 502, 503. + /// Note that write operations will never be retried on 5xx errors as this may + /// result in duplicated writes. + /// + /// The default is 409, 429, 500, 502, 503, 504. pub statuses: Option>, // TODO: should we allow customizing methods? } -#[derive(Debug, Clone)] -struct ResolvedRetryConfig { - retries: u8, - connect_retries: u8, - read_retries: u8, - backoff_factor: f32, - backoff_jitter: f32, - statuses: Vec, -} - -impl TryFrom for ResolvedRetryConfig { - type Error = Error; - - fn try_from(retry_config: RetryConfig) -> Result { - Ok(Self { - retries: retry_config.retries.unwrap_or(3), - connect_retries: retry_config.connect_retries.unwrap_or(3), - read_retries: retry_config.read_retries.unwrap_or(3), - backoff_factor: retry_config.backoff_factor.unwrap_or(0.25), - backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25), - statuses: retry_config - .statuses - .unwrap_or_else(|| vec![429, 500, 502, 503]) - .into_iter() - .map(|status| reqwest::StatusCode::from_u16(status).unwrap()) - .collect(), - }) - } -} - // We use the `HttpSend` trait to abstract over the `reqwest::Client` so that // we can mock responses in tests. Based on the patterns from this blog post: // https://write.as/balrogboogie/testing-reqwest-based-clients @@ -160,8 +133,8 @@ impl TryFrom for ResolvedRetryConfig { pub struct RestfulLanceDbClient { client: reqwest::Client, host: String, - retry_config: ResolvedRetryConfig, - sender: S, + pub(crate) retry_config: ResolvedRetryConfig, + pub(crate) sender: S, } pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static { @@ -375,74 +348,69 @@ impl RestfulLanceDbClient { self.client.post(full_uri) } - pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> { + pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> { let (client, request) = req.build_split(); let mut request = request.unwrap(); + let request_id = self.extract_request_id(&mut request); + self.log_request(&request, &request_id); - // Set a request id. - // TODO: allow the user to supply this, through middleware? - let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) { - request_id.to_str().unwrap().to_string() - } else { - let request_id = uuid::Uuid::new_v4().to_string(); - let header = HeaderValue::from_str(&request_id).unwrap(); - request.headers_mut().insert(REQUEST_ID_HEADER, header); - request_id - }; - - if log::log_enabled!(log::Level::Debug) { - let content_type = request - .headers() - .get("content-type") - .map(|v| v.to_str().unwrap()); - if content_type == Some("application/json") { - let body = request.body().as_ref().unwrap().as_bytes().unwrap(); - let body = String::from_utf8_lossy(body); - debug!( - "Sending request_id={}: {:?} with body {}", - request_id, request, body - ); - } else { - debug!("Sending request_id={}: {:?}", request_id, request); - } - } - - if with_retry { - self.send_with_retry_impl(client, request, request_id).await - } else { - let response = self - .sender - .send(&client, request) - .await - .err_to_http(request_id.clone())?; - debug!( - "Received response for request_id={}: {:?}", - request_id, &response - ); - Ok((request_id, response)) - } + let response = self + .sender + .send(&client, request) + .await + .err_to_http(request_id.clone())?; + debug!( + "Received response for request_id={}: {:?}", + request_id, &response + ); + Ok((request_id, response)) } - async fn send_with_retry_impl( + /// Send the request using retries configured in the RetryConfig. + /// If retry_5xx is false, 5xx requests will not be retried regardless of the statuses configured + /// in the RetryConfig. + /// Since this requires arrow serialization, this is implemented here instead of in RestfulLanceDbClient + pub async fn send_with_retry( &self, - client: reqwest::Client, - req: Request, - request_id: String, + req_builder: RequestBuilder, + mut make_body: Option Result + Send + 'static>>, + retry_5xx: bool, ) -> Result<(String, Response)> { - let mut retry_counter = RetryCounter::new(&self.retry_config, request_id); + let retry_config = &self.retry_config; + let non_5xx_statuses = retry_config + .statuses + .iter() + .filter(|s| !s.is_server_error()) + .cloned() + .collect::>(); + + // clone and build the request to extract the request id + let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime { + message: "Attempted to retry a request that cannot be cloned".to_string(), + })?; + let (_, r) = tmp_req.build_split(); + let mut r = r.unwrap(); + let request_id = self.extract_request_id(&mut r); + let mut retry_counter = RetryCounter::new(retry_config, request_id.clone()); loop { - // This only works if the request body is not a stream. If it is - // a stream, we can't use the retry path. We would need to implement - // an outer retry. - let request = req.try_clone().ok_or_else(|| Error::Runtime { + let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime { message: "Attempted to retry a request that cannot be cloned".to_string(), })?; - let response = self - .sender - .send(&client, request) - .await - .map(|r| (r.status(), r)); + + // set the streaming body on the request builder after clone + if let Some(body_gen) = make_body.as_mut() { + let body = body_gen()?; + req_builder = req_builder.body(body); + } + + let (c, request) = req_builder.build_split(); + let mut request = request.unwrap(); + self.set_request_id(&mut request, &request_id.clone()); + self.log_request(&request, &request_id); + + let response = self.sender.send(&c, request).await.map(|r| (r.status(), r)); + match response { Ok((status, response)) if status.is_success() => { debug!( @@ -451,7 +419,10 @@ impl RestfulLanceDbClient { ); return Ok((retry_counter.request_id, response)); } - Ok((status, response)) if self.retry_config.statuses.contains(&status) => { + Ok((status, response)) + if (retry_5xx && retry_config.statuses.contains(&status)) + || non_5xx_statuses.contains(&status) => + { let source = self .check_response(&retry_counter.request_id, response) .await @@ -480,6 +451,47 @@ impl RestfulLanceDbClient { } } + fn log_request(&self, request: &Request, request_id: &String) { + if log::log_enabled!(log::Level::Debug) { + let content_type = request + .headers() + .get("content-type") + .map(|v| v.to_str().unwrap()); + if content_type == Some("application/json") { + let body = request.body().as_ref().unwrap().as_bytes().unwrap(); + let body = String::from_utf8_lossy(body); + debug!( + "Sending request_id={}: {:?} with body {}", + request_id, request, body + ); + } else { + debug!("Sending request_id={}: {:?}", request_id, request); + } + } + } + + /// Extract the request ID from the request headers. + /// If the request ID header is not set, this will generate a new one and set + /// it on the request headers + pub fn extract_request_id(&self, request: &mut Request) -> String { + // Set a request id. + // TODO: allow the user to supply this, through middleware? + let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) { + request_id.to_str().unwrap().to_string() + } else { + let request_id = uuid::Uuid::new_v4().to_string(); + self.set_request_id(request, &request_id); + request_id + }; + request_id + } + + /// Set the request ID header + pub fn set_request_id(&self, request: &mut Request, request_id: &str) { + let header = HeaderValue::from_str(request_id).unwrap(); + request.headers_mut().insert(REQUEST_ID_HEADER, header); + } + pub async fn check_response(&self, request_id: &str, response: Response) -> Result { // Try to get the response text, but if that fails, just return the status code let status = response.status(); @@ -501,91 +513,6 @@ impl RestfulLanceDbClient { } } -struct RetryCounter<'a> { - request_failures: u8, - connect_failures: u8, - read_failures: u8, - config: &'a ResolvedRetryConfig, - request_id: String, -} - -impl<'a> RetryCounter<'a> { - fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self { - Self { - request_failures: 0, - connect_failures: 0, - read_failures: 0, - config, - request_id, - } - } - - fn check_out_of_retries( - &self, - source: Box, - status_code: Option, - ) -> Result<()> { - if self.request_failures >= self.config.retries - || self.connect_failures >= self.config.connect_retries - || self.read_failures >= self.config.read_retries - { - Err(Error::Retry { - request_id: self.request_id.clone(), - request_failures: self.request_failures, - max_request_failures: self.config.retries, - connect_failures: self.connect_failures, - max_connect_failures: self.config.connect_retries, - read_failures: self.read_failures, - max_read_failures: self.config.read_retries, - source, - status_code, - }) - } else { - Ok(()) - } - } - - fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> { - self.request_failures += 1; - let status_code = if let crate::Error::Http { status_code, .. } = &source { - *status_code - } else { - None - }; - self.check_out_of_retries(Box::new(source), status_code) - } - - fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> { - self.connect_failures += 1; - let status_code = source.status(); - self.check_out_of_retries(Box::new(source), status_code) - } - - fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> { - self.read_failures += 1; - let status_code = source.status(); - self.check_out_of_retries(Box::new(source), status_code) - } - - fn next_sleep_time(&self) -> Duration { - let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32)); - let jitter = rand::random::() * self.config.backoff_jitter; - let sleep_time = Duration::from_secs_f32(backoff + jitter); - debug!( - "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}", - self.request_id, - self.connect_failures, - self.config.connect_retries, - self.request_failures, - self.config.retries, - self.read_failures, - self.config.read_retries, - sleep_time - ); - sleep_time - } -} - pub trait RequestResultExt { type Output; fn err_to_http(self, request_id: String) -> Result; diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 21703efb..e87ad668 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -255,7 +255,7 @@ impl Database for RemoteDatabase { if let Some(start_after) = request.start_after { req = req.query(&[("page_token", start_after)]); } - let (request_id, rsp) = self.client.send(req, true).await?; + let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?; let rsp = self.client.check_response(&request_id, rsp).await?; let version = parse_server_version(&request_id, &rsp)?; let tables = rsp @@ -302,7 +302,7 @@ impl Database for RemoteDatabase { .body(data_buffer) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); - let (request_id, rsp) = self.client.send(req, false).await?; + let (request_id, rsp) = self.client.send(req).await?; if rsp.status() == StatusCode::BAD_REQUEST { let body = rsp.text().await.err_to_http(request_id.clone())?; @@ -362,7 +362,7 @@ impl Database for RemoteDatabase { let req = self .client .post(&format!("/v1/table/{}/describe/", request.name)); - let (request_id, rsp) = self.client.send(req, true).await?; + let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?; if rsp.status() == StatusCode::NOT_FOUND { return Err(crate::Error::TableNotFound { name: request.name }); } @@ -383,7 +383,7 @@ impl Database for RemoteDatabase { .client .post(&format!("/v1/table/{}/rename/", current_name)); let req = req.json(&serde_json::json!({ "new_table_name": new_name })); - let (request_id, resp) = self.client.send(req, false).await?; + let (request_id, resp) = self.client.send(req).await?; self.client.check_response(&request_id, resp).await?; let table = self.table_cache.remove(current_name).await; if let Some(table) = table { @@ -394,7 +394,7 @@ impl Database for RemoteDatabase { async fn drop_table(&self, name: &str) -> Result<()> { let req = self.client.post(&format!("/v1/table/{}/drop/", name)); - let (request_id, resp) = self.client.send(req, true).await?; + let (request_id, resp) = self.client.send(req).await?; self.client.check_response(&request_id, resp).await?; self.table_cache.remove(name).await; Ok(()) diff --git a/rust/lancedb/src/remote/retry.rs b/rust/lancedb/src/remote/retry.rs new file mode 100644 index 00000000..a20d9166 --- /dev/null +++ b/rust/lancedb/src/remote/retry.rs @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use crate::remote::RetryConfig; +use crate::Error; +use log::debug; +use std::time::Duration; + +pub struct RetryCounter<'a> { + pub request_failures: u8, + pub connect_failures: u8, + pub read_failures: u8, + pub config: &'a ResolvedRetryConfig, + pub request_id: String, +} + +impl<'a> RetryCounter<'a> { + pub(crate) fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self { + Self { + request_failures: 0, + connect_failures: 0, + read_failures: 0, + config, + request_id, + } + } + + fn check_out_of_retries( + &self, + source: Box, + status_code: Option, + ) -> crate::Result<()> { + if self.request_failures >= self.config.retries + || self.connect_failures >= self.config.connect_retries + || self.read_failures >= self.config.read_retries + { + Err(Error::Retry { + request_id: self.request_id.clone(), + request_failures: self.request_failures, + max_request_failures: self.config.retries, + connect_failures: self.connect_failures, + max_connect_failures: self.config.connect_retries, + read_failures: self.read_failures, + max_read_failures: self.config.read_retries, + source, + status_code, + }) + } else { + Ok(()) + } + } + + pub fn increment_request_failures(&mut self, source: crate::Error) -> crate::Result<()> { + self.request_failures += 1; + let status_code = if let crate::Error::Http { status_code, .. } = &source { + *status_code + } else { + None + }; + self.check_out_of_retries(Box::new(source), status_code) + } + + pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> { + self.connect_failures += 1; + let status_code = source.status(); + self.check_out_of_retries(Box::new(source), status_code) + } + + pub fn increment_read_failures(&mut self, source: reqwest::Error) -> crate::Result<()> { + self.read_failures += 1; + let status_code = source.status(); + self.check_out_of_retries(Box::new(source), status_code) + } + + pub fn next_sleep_time(&self) -> Duration { + let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32)); + let jitter = rand::random::() * self.config.backoff_jitter; + let sleep_time = Duration::from_secs_f32(backoff + jitter); + debug!( + "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}", + self.request_id, + self.connect_failures, + self.config.connect_retries, + self.request_failures, + self.config.retries, + self.read_failures, + self.config.read_retries, + sleep_time + ); + sleep_time + } +} + +#[derive(Debug, Clone)] +pub struct ResolvedRetryConfig { + pub retries: u8, + pub connect_retries: u8, + pub read_retries: u8, + pub backoff_factor: f32, + pub backoff_jitter: f32, + pub statuses: Vec, +} + +impl TryFrom for ResolvedRetryConfig { + type Error = Error; + + fn try_from(retry_config: RetryConfig) -> crate::Result { + Ok(Self { + retries: retry_config.retries.unwrap_or(3), + connect_retries: retry_config.connect_retries.unwrap_or(3), + read_retries: retry_config.read_retries.unwrap_or(3), + backoff_factor: retry_config.backoff_factor.unwrap_or(0.25), + backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25), + statuses: retry_config + .statuses + .unwrap_or_else(|| vec![409, 429, 500, 502, 503, 504]) + .into_iter() + .map(|status| reqwest::StatusCode::from_u16(status).unwrap()) + .collect(), + }) + } +} diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 4a01d296..46d2422b 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -7,7 +7,7 @@ use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; use crate::table::{AddDataMode, AnyQuery, Filter}; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{DistanceType, Error, Table}; -use arrow_array::RecordBatchReader; +use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_ipc::reader::FileReader; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; @@ -21,6 +21,7 @@ use lance::arrow::json::{JsonDataType, JsonSchema}; use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::{ColumnAlteration, NewColumnTransform, Version}; use lance_datafusion::exec::{execute_plan, OneShotExec}; +use reqwest::{RequestBuilder, Response}; use serde::{Deserialize, Serialize}; use std::io::Cursor; use std::pin::Pin; @@ -83,7 +84,7 @@ impl RemoteTable { let body = serde_json::json!({ "version": version }); request = request.json(&body); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -127,6 +128,61 @@ impl RemoteTable { Ok(reqwest::Body::wrap_stream(body_stream)) } + /// Buffer the reader into memory + async fn buffer_reader( + reader: &mut R, + ) -> Result<(SchemaRef, Vec)> { + let schema = reader.schema(); + let mut batches = Vec::new(); + for batch in reader { + batches.push(batch?); + } + Ok((schema, batches)) + } + + /// Create a new RecordBatchReader from buffered data + fn make_reader(schema: SchemaRef, batches: Vec) -> impl RecordBatchReader { + let iter = batches.into_iter().map(Ok); + RecordBatchIterator::new(iter, schema) + } + + async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> { + let res = if with_retry { + self.client.send_with_retry(req, None, true).await? + } else { + self.client.send(req).await? + }; + Ok(res) + } + + /// Send the request with streaming body. + /// This will use retries if with_retry is set and the number of configured retries is > 0. + /// If retries are enabled, the stream will be buffered into memory. + async fn send_streaming( + &self, + req: RequestBuilder, + mut data: Box, + with_retry: bool, + ) -> Result<(String, Response)> { + if !with_retry || self.client.retry_config.retries == 0 { + let body = Self::reader_as_body(data)?; + return self.client.send(req.body(body)).await; + } + + // to support retries, buffer into memory and clone the batches on each retry + let (schema, batches) = Self::buffer_reader(&mut *data).await?; + let make_body = Box::new(move || { + let reader = Self::make_reader(schema.clone(), batches.clone()); + Self::reader_as_body(Box::new(reader)) + }); + let res = self + .client + .send_with_retry(req, Some(make_body), false) + .await?; + + Ok(res) + } + async fn check_table_response( &self, request_id: &str, @@ -353,7 +409,7 @@ impl RemoteTable { .collect(); let futures = requests.into_iter().map(|req| async move { - let (request_id, response) = self.client.send(req, true).await?; + let (request_id, response) = self.send(req, true).await?; self.read_arrow_stream(&request_id, response).await }); let streams = futures::future::try_join_all(futures); @@ -471,7 +527,7 @@ impl BaseTable for RemoteTable { let body = serde_json::json!({ "version": version }); request = request.json(&body); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; self.checkout_latest().await?; Ok(()) @@ -481,7 +537,7 @@ impl BaseTable for RemoteTable { let request = self .client .post(&format!("/v1/table/{}/version/list/", self.name)); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; #[derive(Deserialize)] @@ -527,7 +583,7 @@ impl BaseTable for RemoteTable { request = request.json(&body); } - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -545,12 +601,10 @@ impl BaseTable for RemoteTable { data: Box, ) -> Result<()> { self.check_mutable().await?; - let body = Self::reader_as_body(data)?; let mut request = self .client .post(&format!("/v1/table/{}/insert/", self.name)) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) - .body(body); + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); match add.mode { AddDataMode::Append => {} @@ -559,8 +613,7 @@ impl BaseTable for RemoteTable { } } - let (request_id, response) = self.client.send(request, false).await?; - + let (request_id, response) = self.send_streaming(request, data, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) @@ -628,7 +681,7 @@ impl BaseTable for RemoteTable { .collect::>(); let futures = requests.into_iter().map(|req| async move { - let (request_id, response) = self.client.send(req, true).await?; + let (request_id, response) = self.send(req, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -670,7 +723,7 @@ impl BaseTable for RemoteTable { .collect(); let futures = requests.into_iter().map(|req| async move { - let (request_id, response) = self.client.send(req, true).await?; + let (request_id, response) = self.send(req, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -712,7 +765,7 @@ impl BaseTable for RemoteTable { "predicate": update.filter, })); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; @@ -726,7 +779,7 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/delete/", self.name)) .json(&body); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) } @@ -812,7 +865,7 @@ impl BaseTable for RemoteTable { let request = request.json(&body); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; @@ -836,21 +889,21 @@ impl BaseTable for RemoteTable { new_data: Box, ) -> Result<()> { self.check_mutable().await?; + let query = MergeInsertRequest::try_from(params)?; - let body = Self::reader_as_body(new_data)?; let request = self .client .post(&format!("/v1/table/{}/merge_insert/", self.name)) .query(&query) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) - .body(body); + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send_streaming(request, new_data, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) } + async fn optimize(&self, _action: OptimizeAction) -> Result { self.check_mutable().await?; Err(Error::NotSupported { @@ -879,7 +932,7 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/add_columns/", self.name)) .json(&body); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; // todo: self.check_table_response(&request_id, response).await?; Ok(()) } @@ -918,7 +971,7 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/alter_columns/", self.name)) .json(&body); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) } @@ -930,7 +983,7 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/drop_columns/", self.name)) .json(&body); - let (request_id, response) = self.client.send(request, false).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) } @@ -944,7 +997,7 @@ impl BaseTable for RemoteTable { let body = serde_json::json!({ "version": version }); request = request.json(&body); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; #[derive(Deserialize)] @@ -1001,7 +1054,7 @@ impl BaseTable for RemoteTable { let body = serde_json::json!({ "version": version }); request = request.json(&body); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; if response.status() == StatusCode::NOT_FOUND { return Ok(None); @@ -1011,7 +1064,6 @@ impl BaseTable for RemoteTable { let body = response.text().await.err_to_http(request_id.clone())?; - println!("body: {:?}", body); let stats = serde_json::from_str(&body).map_err(|e| Error::Http { source: format!("Failed to parse index statistics: {}", e).into(), request_id, @@ -1026,7 +1078,7 @@ impl BaseTable for RemoteTable { "/v1/table/{}/index/{}/drop/", self.name, index_name )); - let (request_id, response) = self.client.send(request, true).await?; + let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) } @@ -1487,6 +1539,42 @@ mod tests { assert_eq!(&body, &expected_body); } + #[tokio::test] + async fn test_merge_insert_retries_on_409() { + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data = Box::new(RecordBatchIterator::new( + [Ok(batch.clone())], + batch.schema(), + )); + + // Default parameters + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/"); + + let params = request.url().query_pairs().collect::>(); + assert_eq!(params["on"], "some_col"); + assert_eq!(params["when_matched_update_all"], "false"); + assert_eq!(params["when_not_matched_insert_all"], "false"); + assert_eq!(params["when_not_matched_by_source_delete"], "false"); + assert!(!params.contains_key("when_matched_update_all_filt")); + assert!(!params.contains_key("when_not_matched_by_source_delete_filt")); + + http::Response::builder().status(409).body("").unwrap() + }); + + let e = table + .merge_insert(&["some_col"]) + .execute(data) + .await + .unwrap_err(); + assert!(e.to_string().contains("Hit retry limit")); + } + #[tokio::test] async fn test_delete() { let table = Table::new_with_handler("my_table", |request| {