From e6fd8d071ed089e72cc74e4be5c75a78bea30b9b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 20 Mar 2026 13:19:07 -0700 Subject: [PATCH] feat(rust): parallel inserts for remote tables via multipart write (#3071) Similar to https://github.com/lancedb/lancedb/pull/3062, we can write in parallel to remote tables if the input data source is large enough. We take advantage of new endpoints coming in server version 0.4.0, which allow writing data in multiple requests and then committing at the end in a single request. To make testing easier, I also introduce a `write_parallelism` parameter. In the future, we can expose that in Python and NodeJS so users can manually specify the parallelism they get. Closes #2861 --------- Co-authored-by: Claude Opus 4.6 --- rust/lancedb/src/remote/db.rs | 4 + rust/lancedb/src/remote/table.rs | 866 ++++++++++++++++++++++-- rust/lancedb/src/remote/table/insert.rs | 116 +++- rust/lancedb/src/table.rs | 60 +- rust/lancedb/src/table/add_data.rs | 17 + 5 files changed, 955 insertions(+), 108 deletions(-) diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 909de4312..d84d7feb8 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -72,6 +72,10 @@ impl ServerVersion { pub fn support_structural_fts(&self) -> bool { self.0 >= semver::Version::new(0, 3, 0) } + + pub fn support_multipart_write(&self) -> bool { + self.0 >= semver::Version::new(0, 4, 0) + } } pub const OPT_REMOTE_PREFIX: &str = "remote_database_"; diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 13edae94f..244952a17 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -10,6 +10,7 @@ use super::ARROW_STREAM_CONTENT_TYPE; use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; +use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions}; use 
crate::index::Index; use crate::index::IndexStatistics; use crate::index::waiter::wait_for_index; @@ -23,7 +24,7 @@ use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; use crate::table::query::create_multi_vector_plan; -use crate::table::{AnyQuery, Filter, TableStatistics}; +use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics}; use crate::utils::background_cache::BackgroundCache; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{DistanceType, Error}; @@ -43,7 +44,7 @@ use async_trait::async_trait; use datafusion_common::DataFusionError; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use http::header::CONTENT_TYPE; use http::{HeaderName, StatusCode}; use lance::arrow::json::{JsonDataType, JsonSchema}; @@ -614,6 +615,66 @@ impl RemoteTable { Ok(bodies) } + async fn create_multipart_write(&self) -> Result { + let request = self.client.post(&format!( + "/v1/table/{}/multipart_write/create", + self.identifier + )); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse multipart create response: {}", e).into(), + request_id, + status_code: None, + })?; + parsed["upload_id"] + .as_str() + .map(|s| s.to_string()) + .ok_or_else(|| Error::Http { + source: "Missing upload_id in multipart create response".into(), + request_id: String::new(), + status_code: None, + }) + } + + async fn complete_multipart_write(&self, upload_id: &str) -> Result { + let request = self + .client + .post(&format!( + "/v1/table/{}/multipart_write/complete", + 
self.identifier + )) + .query(&[("upload_id", upload_id)]); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse multipart complete response: {}", e).into(), + request_id, + status_code: None, + })?; + let version = parsed["version"].as_u64().ok_or_else(|| Error::Http { + source: "Missing version in multipart complete response".into(), + request_id: String::new(), + status_code: None, + })?; + Ok(AddResult { version }) + } + + async fn abort_multipart_write(&self, upload_id: &str) -> Result<()> { + let request = self + .client + .post(&format!( + "/v1/table/{}/multipart_write/abort", + self.identifier + )) + .query(&[("upload_id", upload_id)]); + let (request_id, response) = self.send(request, true).await?; + self.check_table_response(&request_id, response).await?; + Ok(()) + } + async fn check_mutable(&self) -> Result<()> { let read_guard = self.version.read().await; match *read_guard { @@ -817,6 +878,19 @@ mod test_utils { } pub fn new_mock_with_config(name: String, handler: F, config: ClientConfig) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + Self::new_mock_with_version_and_config(name, handler, None, config) + } + + pub fn new_mock_with_version_and_config( + name: String, + handler: F, + version: Option, + config: ClientConfig, + ) -> Self where F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, T: Into, @@ -827,7 +901,7 @@ mod test_utils { name: name.clone(), namespace: vec![], identifier: name, - server_version: ServerVersion::default(), + server_version: version.map(ServerVersion).unwrap_or_default(), version: RwLock::new(None), location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, 
SCHEMA_CACHE_REFRESH_WINDOW), @@ -836,6 +910,185 @@ mod test_utils { } } +impl RemoteTable { + fn is_retryable_write_error(&self, err: &Error) -> bool { + match err { + Error::Http { + source, + status_code, + .. + } => { + // Don't retry read errors (is_body/is_decode): the + // server may have committed the write already, and + // without an idempotency key we'd duplicate data. + source + .downcast_ref::() + .is_some_and(|e| e.is_connect()) + || status_code.is_some_and(|s| self.client.retry_config.statuses.contains(&s)) + } + // send_with_retry exhausted its internal retries on a retryable + // status. The outer loop can still retry the whole operation with + // a fresh session. + Error::Retry { status_code, .. } => { + status_code.is_some_and(|s| self.client.retry_config.statuses.contains(&s)) + } + _ => false, + } + } + + async fn add_single_partition(&self, output: PreprocessingOutput) -> Result { + use crate::remote::retry::RetryCounter; + + let mut insert: Arc = Arc::new(RemoteInsertExec::new( + self.name.clone(), + self.identifier.clone(), + self.client.clone(), + output.plan, + output.overwrite, + )); + + let mut retry_counter = + RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); + + loop { + let stream = execute_plan(insert.clone(), Default::default())?; + let result: Result> = stream.try_collect().await.map_err(Error::from); + + match result { + Ok(_) => { + let add_result = insert + .as_any() + .downcast_ref::>() + .and_then(|i| i.add_result()) + .unwrap_or(AddResult { version: 0 }); + + if output.overwrite { + self.invalidate_schema_cache(); + } + + return Ok(add_result); + } + Err(err) if output.rescannable && self.is_retryable_write_error(&err) => { + retry_counter.increment_from_error(err)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + insert = insert.reset_state()?; + continue; + } + Err(err) => return Err(err), + } + } + } + + async fn add_multipart( + &self, + output: PreprocessingOutput, + 
num_partitions: usize, + ) -> Result { + use crate::remote::retry::RetryCounter; + + let mut retry_counter = + RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); + + loop { + let upload_id = self.create_multipart_write().await?; + + let result = self + .execute_multipart_inserts(&upload_id, &output, num_partitions) + .await; + + match result { + Ok(()) => match self.complete_multipart_write(&upload_id).await { + Ok(result) => { + if output.overwrite { + self.invalidate_schema_cache(); + } + return Ok(result); + } + Err(e) => { + if let Err(abort_err) = self.abort_multipart_write(&upload_id).await { + log::warn!( + "Failed to abort multipart write {}: {}", + upload_id, + abort_err + ); + } + if output.rescannable && self.is_retryable_write_error(&e) { + retry_counter.increment_from_error(e)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + return Err(e); + } + }, + Err(e) => { + if let Err(abort_err) = self.abort_multipart_write(&upload_id).await { + log::warn!( + "Failed to abort multipart write {}: {}", + upload_id, + abort_err + ); + } + if output.rescannable && self.is_retryable_write_error(&e) { + retry_counter.increment_from_error(e)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + return Err(e); + } + } + } + } + + async fn execute_multipart_inserts( + &self, + upload_id: &str, + output: &PreprocessingOutput, + num_partitions: usize, + ) -> Result<()> { + let plan = Arc::new( + datafusion_physical_plan::repartition::RepartitionExec::try_new( + output.plan.clone(), + datafusion_physical_plan::Partitioning::RoundRobinBatch(num_partitions), + )?, + ) as Arc; + + let insert = Arc::new(RemoteInsertExec::new_multipart( + self.name.clone(), + self.identifier.clone(), + self.client.clone(), + plan, + output.overwrite, + upload_id.to_string(), + )); + + let task_ctx = Arc::new(datafusion_execution::TaskContext::default()); + let mut join_set = tokio::task::JoinSet::new(); + for 
partition in 0..num_partitions { + let exec = insert.clone(); + let ctx = task_ctx.clone(); + join_set.spawn(async move { + let mut stream = exec + .execute(partition, ctx) + .map_err(|e| -> Error { e.into() })?; + while let Some(batch) = stream.next().await { + batch.map_err(|e| -> Error { e.into() })?; + } + Ok::<_, Error>(()) + }); + } + + // JoinSet aborts all remaining tasks when dropped, so if we return + // early on error the orphaned tasks are automatically cancelled. + while let Some(result) = join_set.join_next().await { + result.map_err(|e| Error::Runtime { + message: format!("Insert task panicked: {}", e), + })??; + } + + Ok(()) + } +} + #[async_trait] impl BaseTable for RemoteTable { fn as_any(&self) -> &dyn std::any::Any { @@ -986,74 +1239,44 @@ impl BaseTable for RemoteTable { status_code: None, }) } - async fn add(&self, add: AddDataBuilder) -> Result { - use crate::remote::retry::RetryCounter; - + async fn add(&self, mut add: AddDataBuilder) -> Result { self.check_mutable().await?; let table_schema = self.schema().await?; let table_def = TableDefinition::try_from_rich_schema(table_schema.clone())?; + + let num_partitions = if let Some(parallelism) = add.write_parallelism { + if parallelism > 1 && self.server_version.support_multipart_write() { + parallelism + } else { + 1 + } + } else if self.server_version.support_multipart_write() { + // Peek at the first batch to estimate write partitions, same as NativeTable. 
+ let mut peeked = PeekedScannable::new(add.data); + let n = if let Some(first_batch) = peeked.peek().await { + let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); + estimate_write_partitions( + first_batch.get_array_memory_size(), + first_batch.num_rows(), + peeked.num_rows(), + max_partitions, + ) + } else { + 1 + }; + add.data = Box::new(peeked); + n + } else { + 1 + }; + let output = add.into_plan(&table_schema, &table_def)?; - let mut insert: Arc = Arc::new(RemoteInsertExec::new( - self.name.clone(), - self.identifier.clone(), - self.client.clone(), - output.plan, - output.overwrite, - )); - - let mut retry_counter = - RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); - - loop { - let stream = execute_plan(insert.clone(), Default::default())?; - let result: Result> = stream.try_collect().await.map_err(Error::from); - - match result { - Ok(_) => { - let add_result = insert - .as_any() - .downcast_ref::>() - .and_then(|i| i.add_result()) - .unwrap_or(AddResult { version: 0 }); - - if output.overwrite { - self.invalidate_schema_cache(); - } - - return Ok(add_result); - } - Err(err) if output.rescannable => { - let retryable = match &err { - Error::Http { - source, - status_code, - .. - } => { - // Don't retry read errors (is_body/is_decode): the - // server may have committed the write already, and - // without an idempotency key we'd duplicate data. 
- source - .downcast_ref::() - .is_some_and(|e| e.is_connect()) - || status_code - .is_some_and(|s| self.client.retry_config.statuses.contains(&s)) - } - _ => false, - }; - - if retryable { - retry_counter.increment_from_error(err)?; - tokio::time::sleep(retry_counter.next_sleep_time()).await; - insert = insert.reset_state()?; - continue; - } - - return Err(err); - } - Err(err) => return Err(err), - } + if num_partitions > 1 { + self.add_multipart(output, num_partitions).await + } else { + self.add_single_partition(output).await } } @@ -1811,6 +2034,7 @@ mod tests { use super::*; + use crate::remote::client::{ClientConfig, RetryConfig}; use crate::table::AddDataMode; use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type}; @@ -4831,4 +5055,516 @@ mod tests { assert_eq!(data.len(), 1); assert_eq!(data[0].as_ref().unwrap(), &expected_data); } + + fn schema_json() -> &'static str { + r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"# + } + + fn simple_describe_response() -> http::Response { + http::Response::builder() + .status(200) + .body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json())) + .unwrap() + } + + #[tokio::test] + async fn test_multipart_write_happy_path() { + use std::sync::Mutex; + + let create_count = Arc::new(AtomicUsize::new(0)); + let insert_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + let upload_ids = Arc::new(Mutex::new(Vec::::new())); + + let create_count_c = create_count.clone(); + let insert_count_c = insert_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + let upload_ids_c = upload_ids.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == 
"/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + create_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-123"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + insert_count_c.fetch_add(1, Ordering::SeqCst); + let uid = url::form_urlencoded::parse(query.as_bytes()) + .find(|(k, _)| k == "upload_id") + .map(|(_, v)| v.to_string()); + upload_ids_c + .lock() + .unwrap() + .push(uid.expect("missing upload_id on insert")); + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + let uid = url::form_urlencoded::parse(query.as_bytes()) + .find(|(k, _)| k == "upload_id") + .map(|(_, v)| v.to_string()); + upload_ids_c + .lock() + .unwrap() + .push(uid.expect("missing upload_id on complete")); + return http::Response::builder() + .status(200) + .body(r#"{"version": 5}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 5); + assert_eq!(create_count.load(Ordering::SeqCst), 1); + assert!( + insert_count.load(Ordering::SeqCst) > 1, + "Expected multiple insert calls, got {}", + insert_count.load(Ordering::SeqCst) + ); + assert_eq!(complete_count.load(Ordering::SeqCst), 1); + assert_eq!(abort_count.load(Ordering::SeqCst), 0); + + let ids = upload_ids.lock().unwrap(); + assert!( + 
ids.iter().all(|id| id == "test-upload-123"), + "All requests should use the same upload_id, got: {:?}", + *ids + ); + } + + #[tokio::test] + async fn test_multipart_write_fallback_old_server() { + let insert_count = Arc::new(AtomicUsize::new(0)); + let create_count = Arc::new(AtomicUsize::new(0)); + + let insert_count_c = insert_count.clone(); + let create_count_c = create_count.clone(); + + // Server version 0.3.0 does not support multipart writes + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 3, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path.contains("multipart_write") { + create_count_c.fetch_add(1, Ordering::SeqCst); + panic!("Should not call multipart write endpoints on old server"); + } + + if path == "/v1/table/my_table/insert/" { + let query = request.url().query().unwrap_or(""); + assert!( + !query.contains("upload_id"), + "Should not have upload_id for old server" + ); + insert_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 2); + assert_eq!(create_count.load(Ordering::SeqCst), 0); + assert_eq!(insert_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_small_data_single_partition() { + let insert_count = Arc::new(AtomicUsize::new(0)); + let create_count = Arc::new(AtomicUsize::new(0)); + + let insert_count_c = insert_count.clone(); + let create_count_c = create_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = 
request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path.contains("multipart_write") { + create_count_c.fetch_add(1, Ordering::SeqCst); + panic!("Should not call multipart write endpoints for small data"); + } + + if path == "/v1/table/my_table/insert/" { + let query = request.url().query().unwrap_or(""); + assert!( + !query.contains("upload_id"), + "Should not have upload_id for small data" + ); + insert_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + // Small data: only 3 rows + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).execute().await.unwrap(); + + assert_eq!(result.version, 2); + assert_eq!(create_count.load(Ordering::SeqCst), 0); + assert_eq!(insert_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_abort_on_insert_failure() { + let create_count = Arc::new(AtomicUsize::new(0)); + let insert_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let insert_count_c = insert_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + create_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-456"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + let count = 
insert_count_c.fetch_add(1, Ordering::SeqCst); + // Fail on the first insert with non-retryable status + if count == 0 { + return http::Response::builder() + .status(400) + .body("Bad Request".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 5}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).write_parallelism(2).execute().await; + + assert!(result.is_err()); + assert_eq!(create_count.load(Ordering::SeqCst), 1); + assert_eq!(complete_count.load(Ordering::SeqCst), 0); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_abort_on_complete_failure() { + let abort_count = Arc::new(AtomicUsize::new(0)); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-789"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + return 
http::Response::builder() + .status(400) + .body("Bad Request".to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).write_parallelism(2).execute().await; + + assert!(result.is_err()); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } + + fn retry_config_no_backoff() -> ClientConfig { + ClientConfig { + retry_config: RetryConfig { + retries: Some(3), + connect_retries: Some(3), + read_retries: Some(3), + backoff_factor: Some(0.0), + backoff_jitter: Some(0.0), + statuses: Some(vec![502, 503]), + }, + ..Default::default() + } + } + + #[tokio::test] + async fn test_multipart_write_retry_on_partition_failure() { + // All inserts for the first upload session return 503 (retryable). + // After exhausting internal retries, the outer loop retries with a + // new session and succeeds. 
+ let create_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version_and_config( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + let n = create_count_c.fetch_add(1, Ordering::SeqCst); + let body = format!(r#"{{"upload_id": "upload-{}"}}"#, n + 1); + return http::Response::builder().status(200).body(body).unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + // Fail all inserts for the first session + if query.contains("upload_id=upload-1") { + return http::Response::builder() + .status(503) + .body("Service Unavailable".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 7}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + retry_config_no_backoff(), + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 7); + assert_eq!(create_count.load(Ordering::SeqCst), 2); + 
assert_eq!(abort_count.load(Ordering::SeqCst), 1); + assert_eq!(complete_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_retry_on_complete_failure() { + // Complete returns 503 for the first session, succeeds for the second. + let create_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version_and_config( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + let n = create_count_c.fetch_add(1, Ordering::SeqCst); + let body = format!(r#"{{"upload_id": "upload-{}"}}"#, n + 1); + return http::Response::builder().status(200).body(body).unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + // Fail complete for first session + if query.contains("upload_id=upload-1") { + return http::Response::builder() + .status(503) + .body("Service Unavailable".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 9}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + retry_config_no_backoff(), + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + 
assert_eq!(result.version, 9); + assert_eq!(create_count.load(Ordering::SeqCst), 2); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } } diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index c8637281e..bc13010c4 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -12,7 +12,9 @@ use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; use futures::StreamExt; use http::header::CONTENT_TYPE; @@ -25,10 +27,12 @@ use crate::table::datafusion::insert::COUNT_SCHEMA; /// ExecutionPlan for inserting data into a remote LanceDB table. /// -/// This plan: -/// 1. Requires single partition (no parallel remote inserts yet) -/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint -/// 3. Stores AddResult for retrieval after execution +/// Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint. +/// +/// When `upload_id` is set, inserts are staged as part of a multipart write +/// session and the plan supports multiple partitions for parallel uploads. +/// Without `upload_id`, the plan requires a single partition and commits +/// immediately. #[derive(Debug)] pub struct RemoteInsertExec { table_name: String, @@ -38,10 +42,11 @@ pub struct RemoteInsertExec { overwrite: bool, properties: PlanProperties, add_result: Arc>>, + upload_id: Option, } impl RemoteInsertExec { - /// Create a new RemoteInsertExec. + /// Create a new single-partition RemoteInsertExec. 
pub fn new( table_name: String, identifier: String, @@ -49,10 +54,49 @@ impl RemoteInsertExec { input: Arc, overwrite: bool, ) -> Self { + Self::new_inner(table_name, identifier, client, input, overwrite, None) + } + + /// Create a multi-partition RemoteInsertExec for use with multipart writes. + /// + /// Each partition's insert is staged under the given `upload_id` without + /// committing. The caller is responsible for calling the complete (or abort) + /// endpoint after all partitions finish. + pub fn new_multipart( + table_name: String, + identifier: String, + client: RestfulLanceDbClient, + input: Arc, + overwrite: bool, + upload_id: String, + ) -> Self { + Self::new_inner( + table_name, + identifier, + client, + input, + overwrite, + Some(upload_id), + ) + } + + fn new_inner( + table_name: String, + identifier: String, + client: RestfulLanceDbClient, + input: Arc, + overwrite: bool, + upload_id: Option, + ) -> Self { + let num_partitions = if upload_id.is_some() { + input.output_partitioning().partition_count() + } else { + 1 + }; let schema = COUNT_SCHEMA.clone(); let properties = PlanProperties::new( EquivalenceProperties::new(schema), - datafusion_physical_plan::Partitioning::UnknownPartitioning(1), + datafusion_physical_plan::Partitioning::UnknownPartitioning(num_partitions), datafusion_physical_plan::execution_plan::EmissionType::Final, datafusion_physical_plan::execution_plan::Boundedness::Bounded, ); @@ -65,6 +109,7 @@ impl RemoteInsertExec { overwrite, properties, add_result: Arc::new(Mutex::new(None)), + upload_id, } } @@ -174,8 +219,11 @@ impl ExecutionPlan for RemoteInsertExec { } fn required_input_distribution(&self) -> Vec { - // Until we have a separate commit endpoint, we need to do all inserts in a single partition - vec![datafusion_physical_plan::Distribution::SinglePartition] + if self.upload_id.is_some() { + vec![datafusion_physical_plan::Distribution::UnspecifiedDistribution] + } else { + 
vec![datafusion_physical_plan::Distribution::SinglePartition] + } } fn benefits_from_input_partitioning(&self) -> Vec { @@ -191,12 +239,13 @@ impl ExecutionPlan for RemoteInsertExec { "RemoteInsertExec requires exactly one child".to_string(), )); } - Ok(Arc::new(Self::new( + Ok(Arc::new(Self::new_inner( self.table_name.clone(), self.identifier.clone(), self.client.clone(), children[0].clone(), self.overwrite, + self.upload_id.clone(), ))) } @@ -205,18 +254,20 @@ impl ExecutionPlan for RemoteInsertExec { partition: usize, context: Arc, ) -> DataFusionResult { - if partition != 0 { + if self.upload_id.is_none() && partition != 0 { return Err(DataFusionError::Internal( - "RemoteInsertExec only supports single partition execution".to_string(), + "RemoteInsertExec only supports single partition execution without upload_id" + .to_string(), )); } - let input_stream = self.input.execute(0, context)?; + let input_stream = self.input.execute(partition, context)?; let client = self.client.clone(); let identifier = self.identifier.clone(); let overwrite = self.overwrite; let add_result = self.add_result.clone(); let table_name = self.table_name.clone(); + let upload_id = self.upload_id.clone(); let stream = futures::stream::once(async move { let mut request = client @@ -226,6 +277,9 @@ impl ExecutionPlan for RemoteInsertExec { if overwrite { request = request.query(&[("mode", "overwrite")]); } + if let Some(ref uid) = upload_id { + request = request.query(&[("upload_id", uid.as_str())]); + } let (error_tx, mut error_rx) = tokio::sync::oneshot::channel(); let body = Self::stream_as_http_body(input_stream, error_tx)?; @@ -262,28 +316,30 @@ impl ExecutionPlan for RemoteInsertExec { let (request_id, response) = result?; - let body_text = response.text().await.map_err(|e| { - DataFusionError::External(Box::new(Error::Http { - source: Box::new(e), - request_id: request_id.clone(), - status_code: None, - })) - })?; - - let parsed_result = if body_text.trim().is_empty() { - // 
Backward compatible with old servers - AddResult { version: 0 } - } else { - serde_json::from_str(&body_text).map_err(|e| { + // For multipart writes, the staging response is not the final + // version. Only parse AddResult for non-multipart inserts. + if upload_id.is_none() { + let body_text = response.text().await.map_err(|e| { DataFusionError::External(Box::new(Error::Http { - source: format!("Failed to parse add response: {}", e).into(), + source: Box::new(e), request_id: request_id.clone(), status_code: None, })) - })? - }; + })?; + + let parsed_result = if body_text.trim().is_empty() { + // Backward compatible with old servers + AddResult { version: 0 } + } else { + serde_json::from_str(&body_text).map_err(|e| { + DataFusionError::External(Box::new(Error::Http { + source: format!("Failed to parse add response: {}", e).into(), + request_id: request_id.clone(), + status_code: None, + })) + })? + }; - { let mut res_lock = add_result.lock().map_err(|_| { DataFusionError::Execution("Failed to acquire lock for add_result".to_string()) })?; diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index db0636a1c..da9c8283b 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -75,6 +75,7 @@ pub mod query; pub mod schema_evolution; pub mod update; use crate::index::waiter::wait_for_index; +pub(crate) use add_data::PreprocessingOutput; pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior}; pub use chrono::Duration; pub use delete::DeleteResult; @@ -440,6 +441,34 @@ mod test_utils { embedding_registry: Arc::new(MemoryRegistry::new()), } } + + pub fn new_with_handler_version_and_config( + name: impl Into, + version: semver::Version, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + config: crate::remote::ClientConfig, + ) -> Self + where + T: Into, + { + let inner = Arc::new( + crate::remote::table::RemoteTable::new_mock_with_version_and_config( + name.into(), + 
handler.clone(), + Some(version), + config.clone(), + ), + ); + let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config( + handler, config, + )); + Self { + inner, + database: Some(database), + // Registry is unused. + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } } } @@ -2198,21 +2227,26 @@ impl BaseTable for NativeTable { let table_schema = Schema::from(&ds.schema().clone()); - // Peek at the first batch to estimate a good partition count for - // write parallelism. - let mut peeked = PeekedScannable::new(add.data); - let num_partitions = if let Some(first_batch) = peeked.peek().await { - let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); - estimate_write_partitions( - first_batch.get_array_memory_size(), - first_batch.num_rows(), - peeked.num_rows(), - max_partitions, - ) + let num_partitions = if let Some(parallelism) = add.write_parallelism { + parallelism } else { - 1 + // Peek at the first batch to estimate a good partition count for + // write parallelism. 
+ let mut peeked = PeekedScannable::new(add.data); + let n = if let Some(first_batch) = peeked.peek().await { + let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); + estimate_write_partitions( + first_batch.get_array_memory_size(), + first_batch.num_rows(), + peeked.num_rows(), + max_partitions, + ) + } else { + 1 + }; + add.data = Box::new(peeked); + n }; - add.data = Box::new(peeked); let output = add.into_plan(&table_schema, &table_def)?; diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs index 5921c54ea..bbafd6ce2 100644 --- a/rust/lancedb/src/table/add_data.rs +++ b/rust/lancedb/src/table/add_data.rs @@ -52,6 +52,7 @@ pub struct AddDataBuilder { pub(crate) write_options: WriteOptions, pub(crate) on_nan_vectors: NaNVectorBehavior, pub(crate) embedding_registry: Option>, + pub(crate) write_parallelism: Option, } impl std::fmt::Debug for AddDataBuilder { @@ -77,6 +78,7 @@ impl AddDataBuilder { write_options: WriteOptions::default(), on_nan_vectors: NaNVectorBehavior::default(), embedding_registry, + write_parallelism: None, } } @@ -101,7 +103,22 @@ impl AddDataBuilder { self } + /// Set the number of parallel write streams. + /// + /// By default, the number of streams is estimated from the data size. + /// Setting this to `1` disables parallel writes. + pub fn write_parallelism(mut self, parallelism: usize) -> Self { + self.write_parallelism = Some(parallelism); + self + } + pub async fn execute(self) -> Result { + if self.write_parallelism.map(|p| p == 0).unwrap_or(false) { + return Err(Error::InvalidInput { + message: "write_parallelism must be greater than 0".to_string(), + }); + } + self.parent.clone().add(self).await }