diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 909de4312..d84d7feb8 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -72,6 +72,10 @@ impl ServerVersion { pub fn support_structural_fts(&self) -> bool { self.0 >= semver::Version::new(0, 3, 0) } + + pub fn support_multipart_write(&self) -> bool { + self.0 >= semver::Version::new(0, 4, 0) + } } pub const OPT_REMOTE_PREFIX: &str = "remote_database_"; diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 13edae94f..244952a17 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -10,6 +10,7 @@ use super::ARROW_STREAM_CONTENT_TYPE; use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; +use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions}; use crate::index::Index; use crate::index::IndexStatistics; use crate::index::waiter::wait_for_index; @@ -23,7 +24,7 @@ use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; use crate::table::query::create_multi_vector_plan; -use crate::table::{AnyQuery, Filter, TableStatistics}; +use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics}; use crate::utils::background_cache::BackgroundCache; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{DistanceType, Error}; @@ -43,7 +44,7 @@ use async_trait::async_trait; use datafusion_common::DataFusionError; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use http::header::CONTENT_TYPE; use http::{HeaderName, StatusCode}; use lance::arrow::json::{JsonDataType, JsonSchema}; @@ -614,6 +615,66 @@ impl RemoteTable { Ok(bodies) } + async fn create_multipart_write(&self) -> Result { + let request = self.client.post(&format!( + "/v1/table/{}/multipart_write/create", + self.identifier + )); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse multipart create response: {}", e).into(), + request_id, + status_code: None, + })?; + parsed["upload_id"] + .as_str() + .map(|s| s.to_string()) + .ok_or_else(|| Error::Http { + source: "Missing upload_id in multipart create response".into(), + request_id: String::new(), + status_code: None, + }) + } + + async fn complete_multipart_write(&self, upload_id: &str) -> Result { + let request = self + .client + .post(&format!( + "/v1/table/{}/multipart_write/complete", + self.identifier + )) + .query(&[("upload_id", upload_id)]); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse multipart complete response: {}", e).into(), + request_id, + status_code: None, + })?; + let version = parsed["version"].as_u64().ok_or_else(|| Error::Http { + source: "Missing version in multipart complete response".into(), + request_id: String::new(), + status_code: None, + })?; + Ok(AddResult { version }) + } + + async fn abort_multipart_write(&self, upload_id: &str) -> Result<()> { + let request = self + .client + .post(&format!( + "/v1/table/{}/multipart_write/abort", + self.identifier + )) + .query(&[("upload_id", upload_id)]); + let (request_id, response) = self.send(request, true).await?; + self.check_table_response(&request_id, response).await?; + Ok(()) + } + async fn check_mutable(&self) -> Result<()> { let read_guard = self.version.read().await; match *read_guard { @@ -817,6 +878,19 @@ mod test_utils { } pub fn new_mock_with_config(name: String, handler: F, config: ClientConfig) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + Self::new_mock_with_version_and_config(name, handler, None, config) + } + + pub fn new_mock_with_version_and_config( + name: String, + handler: F, + version: Option, + config: ClientConfig, + ) -> Self where F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, T: Into, @@ -827,7 +901,7 @@ mod test_utils { name: name.clone(), namespace: vec![], identifier: name, - server_version: ServerVersion::default(), + server_version: version.map(ServerVersion).unwrap_or_default(), version: RwLock::new(None), location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), @@ -836,6 +910,185 @@ mod test_utils { } } +impl RemoteTable { + fn is_retryable_write_error(&self, err: &Error) -> bool { + match err { + Error::Http { + source, + status_code, + .. + } => { + // Don't retry read errors (is_body/is_decode): the + // server may have committed the write already, and + // without an idempotency key we'd duplicate data. + source + .downcast_ref::() + .is_some_and(|e| e.is_connect()) + || status_code.is_some_and(|s| self.client.retry_config.statuses.contains(&s)) + } + // send_with_retry exhausted its internal retries on a retryable + // status. The outer loop can still retry the whole operation with + // a fresh session. + Error::Retry { status_code, .. } => { + status_code.is_some_and(|s| self.client.retry_config.statuses.contains(&s)) + } + _ => false, + } + } + + async fn add_single_partition(&self, output: PreprocessingOutput) -> Result { + use crate::remote::retry::RetryCounter; + + let mut insert: Arc = Arc::new(RemoteInsertExec::new( + self.name.clone(), + self.identifier.clone(), + self.client.clone(), + output.plan, + output.overwrite, + )); + + let mut retry_counter = + RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); + + loop { + let stream = execute_plan(insert.clone(), Default::default())?; + let result: Result> = stream.try_collect().await.map_err(Error::from); + + match result { + Ok(_) => { + let add_result = insert + .as_any() + .downcast_ref::>() + .and_then(|i| i.add_result()) + .unwrap_or(AddResult { version: 0 }); + + if output.overwrite { + self.invalidate_schema_cache(); + } + + return Ok(add_result); + } + Err(err) if output.rescannable && self.is_retryable_write_error(&err) => { + retry_counter.increment_from_error(err)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + insert = insert.reset_state()?; + continue; + } + Err(err) => return Err(err), + } + } + } + + async fn add_multipart( + &self, + output: PreprocessingOutput, + num_partitions: usize, + ) -> Result { + use crate::remote::retry::RetryCounter; + + let mut retry_counter = + RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); + + loop { + let upload_id = self.create_multipart_write().await?; + + let result = self + .execute_multipart_inserts(&upload_id, &output, num_partitions) + .await; + + match result { + Ok(()) => match self.complete_multipart_write(&upload_id).await { + Ok(result) => { + if output.overwrite { + self.invalidate_schema_cache(); + } + return Ok(result); + } + Err(e) => { + if let Err(abort_err) = self.abort_multipart_write(&upload_id).await { + log::warn!( + "Failed to abort multipart write {}: {}", + upload_id, + abort_err + ); + } + if output.rescannable && self.is_retryable_write_error(&e) { + retry_counter.increment_from_error(e)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + return Err(e); + } + }, + Err(e) => { + if let Err(abort_err) = self.abort_multipart_write(&upload_id).await { + log::warn!( + "Failed to abort multipart write {}: {}", + upload_id, + abort_err + ); + } + if output.rescannable && self.is_retryable_write_error(&e) { + retry_counter.increment_from_error(e)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + return Err(e); + } + } + } + } + + async fn execute_multipart_inserts( + &self, + upload_id: &str, + output: &PreprocessingOutput, + num_partitions: usize, + ) -> Result<()> { + let plan = Arc::new( + datafusion_physical_plan::repartition::RepartitionExec::try_new( + output.plan.clone(), + datafusion_physical_plan::Partitioning::RoundRobinBatch(num_partitions), + )?, + ) as Arc; + + let insert = Arc::new(RemoteInsertExec::new_multipart( + self.name.clone(), + self.identifier.clone(), + self.client.clone(), + plan, + output.overwrite, + upload_id.to_string(), + )); + + let task_ctx = Arc::new(datafusion_execution::TaskContext::default()); + let mut join_set = tokio::task::JoinSet::new(); + for partition in 0..num_partitions { + let exec = insert.clone(); + let ctx = task_ctx.clone(); + join_set.spawn(async move { + let mut stream = exec + .execute(partition, ctx) + .map_err(|e| -> Error { e.into() })?; + while let Some(batch) = stream.next().await { + batch.map_err(|e| -> Error { e.into() })?; + } + Ok::<_, Error>(()) + }); + } + + // JoinSet aborts all remaining tasks when dropped, so if we return + // early on error the orphaned tasks are automatically cancelled. + while let Some(result) = join_set.join_next().await { + result.map_err(|e| Error::Runtime { + message: format!("Insert task panicked: {}", e), + })??; + } + + Ok(()) + } +} + #[async_trait] impl BaseTable for RemoteTable { fn as_any(&self) -> &dyn std::any::Any { @@ -986,74 +1239,44 @@ impl BaseTable for RemoteTable { status_code: None, }) } - async fn add(&self, add: AddDataBuilder) -> Result { - use crate::remote::retry::RetryCounter; - + async fn add(&self, mut add: AddDataBuilder) -> Result { self.check_mutable().await?; let table_schema = self.schema().await?; let table_def = TableDefinition::try_from_rich_schema(table_schema.clone())?; + + let num_partitions = if let Some(parallelism) = add.write_parallelism { + if parallelism > 1 && self.server_version.support_multipart_write() { + parallelism + } else { + 1 + } + } else if self.server_version.support_multipart_write() { + // Peek at the first batch to estimate write partitions, same as NativeTable. + let mut peeked = PeekedScannable::new(add.data); + let n = if let Some(first_batch) = peeked.peek().await { + let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); + estimate_write_partitions( + first_batch.get_array_memory_size(), + first_batch.num_rows(), + peeked.num_rows(), + max_partitions, + ) + } else { + 1 + }; + add.data = Box::new(peeked); + n + } else { + 1 + }; + let output = add.into_plan(&table_schema, &table_def)?; - let mut insert: Arc = Arc::new(RemoteInsertExec::new( - self.name.clone(), - self.identifier.clone(), - self.client.clone(), - output.plan, - output.overwrite, - )); - - let mut retry_counter = - RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string()); - - loop { - let stream = execute_plan(insert.clone(), Default::default())?; - let result: Result> = stream.try_collect().await.map_err(Error::from); - - match result { - Ok(_) => { - let add_result = insert - .as_any() - .downcast_ref::>() - .and_then(|i| i.add_result()) - .unwrap_or(AddResult { version: 0 }); - - if output.overwrite { - self.invalidate_schema_cache(); - } - - return Ok(add_result); - } - Err(err) if output.rescannable => { - let retryable = match &err { - Error::Http { - source, - status_code, - .. - } => { - // Don't retry read errors (is_body/is_decode): the - // server may have committed the write already, and - // without an idempotency key we'd duplicate data. - source - .downcast_ref::() - .is_some_and(|e| e.is_connect()) - || status_code - .is_some_and(|s| self.client.retry_config.statuses.contains(&s)) - } - _ => false, - }; - - if retryable { - retry_counter.increment_from_error(err)?; - tokio::time::sleep(retry_counter.next_sleep_time()).await; - insert = insert.reset_state()?; - continue; - } - - return Err(err); - } - Err(err) => return Err(err), - } + if num_partitions > 1 { + self.add_multipart(output, num_partitions).await + } else { + self.add_single_partition(output).await } } @@ -1811,6 +2034,7 @@ mod tests { use super::*; + use crate::remote::client::{ClientConfig, RetryConfig}; use crate::table::AddDataMode; use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type}; @@ -4831,4 +5055,516 @@ mod tests { assert_eq!(data.len(), 1); assert_eq!(data[0].as_ref().unwrap(), &expected_data); } + + fn schema_json() -> &'static str { + r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"# + } + + fn simple_describe_response() -> http::Response { + http::Response::builder() + .status(200) + .body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json())) + .unwrap() + } + + #[tokio::test] + async fn test_multipart_write_happy_path() { + use std::sync::Mutex; + + let create_count = Arc::new(AtomicUsize::new(0)); + let insert_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + let upload_ids = Arc::new(Mutex::new(Vec::::new())); + + let create_count_c = create_count.clone(); + let insert_count_c = insert_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + let upload_ids_c = upload_ids.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + create_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-123"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + insert_count_c.fetch_add(1, Ordering::SeqCst); + let uid = url::form_urlencoded::parse(query.as_bytes()) + .find(|(k, _)| k == "upload_id") + .map(|(_, v)| v.to_string()); + upload_ids_c + .lock() + .unwrap() + .push(uid.expect("missing upload_id on insert")); + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + let uid = url::form_urlencoded::parse(query.as_bytes()) + .find(|(k, _)| k == "upload_id") + .map(|(_, v)| v.to_string()); + upload_ids_c + .lock() + .unwrap() + .push(uid.expect("missing upload_id on complete")); + return http::Response::builder() + .status(200) + .body(r#"{"version": 5}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 5); + assert_eq!(create_count.load(Ordering::SeqCst), 1); + assert!( + insert_count.load(Ordering::SeqCst) > 1, + "Expected multiple insert calls, got {}", + insert_count.load(Ordering::SeqCst) + ); + assert_eq!(complete_count.load(Ordering::SeqCst), 1); + assert_eq!(abort_count.load(Ordering::SeqCst), 0); + + let ids = upload_ids.lock().unwrap(); + assert!( + ids.iter().all(|id| id == "test-upload-123"), + "All requests should use the same upload_id, got: {:?}", + *ids + ); + } + + #[tokio::test] + async fn test_multipart_write_fallback_old_server() { + let insert_count = Arc::new(AtomicUsize::new(0)); + let create_count = Arc::new(AtomicUsize::new(0)); + + let insert_count_c = insert_count.clone(); + let create_count_c = create_count.clone(); + + // Server version 0.3.0 does not support multipart writes + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 3, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path.contains("multipart_write") { + create_count_c.fetch_add(1, Ordering::SeqCst); + panic!("Should not call multipart write endpoints on old server"); + } + + if path == "/v1/table/my_table/insert/" { + let query = request.url().query().unwrap_or(""); + assert!( + !query.contains("upload_id"), + "Should not have upload_id for old server" + ); + insert_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 2); + assert_eq!(create_count.load(Ordering::SeqCst), 0); + assert_eq!(insert_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_small_data_single_partition() { + let insert_count = Arc::new(AtomicUsize::new(0)); + let create_count = Arc::new(AtomicUsize::new(0)); + + let insert_count_c = insert_count.clone(); + let create_count_c = create_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path.contains("multipart_write") { + create_count_c.fetch_add(1, Ordering::SeqCst); + panic!("Should not call multipart write endpoints for small data"); + } + + if path == "/v1/table/my_table/insert/" { + let query = request.url().query().unwrap_or(""); + assert!( + !query.contains("upload_id"), + "Should not have upload_id for small data" + ); + insert_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + // Small data: only 3 rows + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).execute().await.unwrap(); + + assert_eq!(result.version, 2); + assert_eq!(create_count.load(Ordering::SeqCst), 0); + assert_eq!(insert_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_abort_on_insert_failure() { + let create_count = Arc::new(AtomicUsize::new(0)); + let insert_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let insert_count_c = insert_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + create_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-456"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + let count = insert_count_c.fetch_add(1, Ordering::SeqCst); + // Fail on the first insert with non-retryable status + if count == 0 { + return http::Response::builder() + .status(400) + .body("Bad Request".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 5}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).write_parallelism(2).execute().await; + + assert!(result.is_err()); + assert_eq!(create_count.load(Ordering::SeqCst), 1); + assert_eq!(complete_count.load(Ordering::SeqCst), 0); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_abort_on_complete_failure() { + let abort_count = Arc::new(AtomicUsize::new(0)); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "test-upload-789"}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + return http::Response::builder() + .status(400) + .body("Bad Request".to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table.add(vec![batch]).write_parallelism(2).execute().await; + + assert!(result.is_err()); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } + + fn retry_config_no_backoff() -> ClientConfig { + ClientConfig { + retry_config: RetryConfig { + retries: Some(3), + connect_retries: Some(3), + read_retries: Some(3), + backoff_factor: Some(0.0), + backoff_jitter: Some(0.0), + statuses: Some(vec![502, 503]), + }, + ..Default::default() + } + } + + #[tokio::test] + async fn test_multipart_write_retry_on_partition_failure() { + // All inserts for the first upload session return 503 (retryable). + // After exhausting internal retries, the outer loop retries with a + // new session and succeeds. + let create_count = Arc::new(AtomicUsize::new(0)); + let complete_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let complete_count_c = complete_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version_and_config( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + let n = create_count_c.fetch_add(1, Ordering::SeqCst); + let body = format!(r#"{{"upload_id": "upload-{}"}}"#, n + 1); + return http::Response::builder().status(200).body(body).unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + // Fail all inserts for the first session + if query.contains("upload_id=upload-1") { + return http::Response::builder() + .status(503) + .body("Service Unavailable".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + complete_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(r#"{"version": 7}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + retry_config_no_backoff(), + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 7); + assert_eq!(create_count.load(Ordering::SeqCst), 2); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + assert_eq!(complete_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_multipart_write_retry_on_complete_failure() { + // Complete returns 503 for the first session, succeeds for the second. + let create_count = Arc::new(AtomicUsize::new(0)); + let abort_count = Arc::new(AtomicUsize::new(0)); + + let create_count_c = create_count.clone(); + let abort_count_c = abort_count.clone(); + + let table = Table::new_with_handler_version_and_config( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + let query = request.url().query().unwrap_or(""); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + + if path == "/v1/table/my_table/multipart_write/create" { + let n = create_count_c.fetch_add(1, Ordering::SeqCst); + let body = format!(r#"{{"upload_id": "upload-{}"}}"#, n + 1); + return http::Response::builder().status(200).body(body).unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/complete" { + // Fail complete for first session + if query.contains("upload_id=upload-1") { + return http::Response::builder() + .status(503) + .body("Service Unavailable".to_string()) + .unwrap(); + } + return http::Response::builder() + .status(200) + .body(r#"{"version": 9}"#.to_string()) + .unwrap(); + } + + if path == "/v1/table/my_table/multipart_write/abort" { + abort_count_c.fetch_add(1, Ordering::SeqCst); + return http::Response::builder() + .status(200) + .body(String::new()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }, + retry_config_no_backoff(), + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let result = table + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 9); + assert_eq!(create_count.load(Ordering::SeqCst), 2); + assert_eq!(abort_count.load(Ordering::SeqCst), 1); + } } diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index c8637281e..bc13010c4 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -12,7 +12,9 @@ use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; use futures::StreamExt; use http::header::CONTENT_TYPE; @@ -25,10 +27,12 @@ use crate::table::datafusion::insert::COUNT_SCHEMA; /// ExecutionPlan for inserting data into a remote LanceDB table. /// -/// This plan: -/// 1. Requires single partition (no parallel remote inserts yet) -/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint -/// 3. Stores AddResult for retrieval after execution +/// Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint. +/// +/// When `upload_id` is set, inserts are staged as part of a multipart write +/// session and the plan supports multiple partitions for parallel uploads. +/// Without `upload_id`, the plan requires a single partition and commits +/// immediately. #[derive(Debug)] pub struct RemoteInsertExec { table_name: String, @@ -38,10 +42,11 @@ pub struct RemoteInsertExec { overwrite: bool, properties: PlanProperties, add_result: Arc>>, + upload_id: Option, } impl RemoteInsertExec { - /// Create a new RemoteInsertExec. + /// Create a new single-partition RemoteInsertExec. pub fn new( table_name: String, identifier: String, @@ -49,10 +54,49 @@ impl RemoteInsertExec { input: Arc, overwrite: bool, ) -> Self { + Self::new_inner(table_name, identifier, client, input, overwrite, None) + } + + /// Create a multi-partition RemoteInsertExec for use with multipart writes. + /// + /// Each partition's insert is staged under the given `upload_id` without + /// committing. The caller is responsible for calling the complete (or abort) + /// endpoint after all partitions finish. + pub fn new_multipart( + table_name: String, + identifier: String, + client: RestfulLanceDbClient, + input: Arc, + overwrite: bool, + upload_id: String, + ) -> Self { + Self::new_inner( + table_name, + identifier, + client, + input, + overwrite, + Some(upload_id), + ) + } + + fn new_inner( + table_name: String, + identifier: String, + client: RestfulLanceDbClient, + input: Arc, + overwrite: bool, + upload_id: Option, + ) -> Self { + let num_partitions = if upload_id.is_some() { + input.output_partitioning().partition_count() + } else { + 1 + }; let schema = COUNT_SCHEMA.clone(); let properties = PlanProperties::new( EquivalenceProperties::new(schema), - datafusion_physical_plan::Partitioning::UnknownPartitioning(1), + datafusion_physical_plan::Partitioning::UnknownPartitioning(num_partitions), datafusion_physical_plan::execution_plan::EmissionType::Final, datafusion_physical_plan::execution_plan::Boundedness::Bounded, ); @@ -65,6 +109,7 @@ impl RemoteInsertExec { overwrite, properties, add_result: Arc::new(Mutex::new(None)), + upload_id, } } @@ -174,8 +219,11 @@ impl ExecutionPlan for RemoteInsertExec { } fn required_input_distribution(&self) -> Vec { - // Until we have a separate commit endpoint, we need to do all inserts in a single partition - vec![datafusion_physical_plan::Distribution::SinglePartition] + if self.upload_id.is_some() { + vec![datafusion_physical_plan::Distribution::UnspecifiedDistribution] + } else { + vec![datafusion_physical_plan::Distribution::SinglePartition] + } } fn benefits_from_input_partitioning(&self) -> Vec { @@ -191,12 +239,13 @@ impl ExecutionPlan for RemoteInsertExec { "RemoteInsertExec requires exactly one child".to_string(), )); } - Ok(Arc::new(Self::new( + Ok(Arc::new(Self::new_inner( self.table_name.clone(), self.identifier.clone(), self.client.clone(), children[0].clone(), self.overwrite, + self.upload_id.clone(), ))) } @@ -205,18 +254,20 @@ impl ExecutionPlan for RemoteInsertExec { partition: usize, context: Arc, ) -> DataFusionResult { - if partition != 0 { + if self.upload_id.is_none() && partition != 0 { return Err(DataFusionError::Internal( - "RemoteInsertExec only supports single partition execution".to_string(), + "RemoteInsertExec only supports single partition execution without upload_id" + .to_string(), )); } - let input_stream = self.input.execute(0, context)?; + let input_stream = self.input.execute(partition, context)?; let client = self.client.clone(); let identifier = self.identifier.clone(); let overwrite = self.overwrite; let add_result = self.add_result.clone(); let table_name = self.table_name.clone(); + let upload_id = self.upload_id.clone(); let stream = futures::stream::once(async move { let mut request = client @@ -226,6 +277,9 @@ impl ExecutionPlan for RemoteInsertExec { if overwrite { request = request.query(&[("mode", "overwrite")]); } + if let Some(ref uid) = upload_id { + request = request.query(&[("upload_id", uid.as_str())]); + } let (error_tx, mut error_rx) = tokio::sync::oneshot::channel(); let body = Self::stream_as_http_body(input_stream, error_tx)?; @@ -262,28 +316,30 @@ impl ExecutionPlan for RemoteInsertExec { let (request_id, response) = result?; - let body_text = response.text().await.map_err(|e| { - DataFusionError::External(Box::new(Error::Http { - source: Box::new(e), - request_id: request_id.clone(), - status_code: None, - })) - })?; - - let parsed_result = if body_text.trim().is_empty() { - // Backward compatible with old servers - AddResult { version: 0 } - } else { - serde_json::from_str(&body_text).map_err(|e| { + // For multipart writes, the staging response is not the final + // version. Only parse AddResult for non-multipart inserts. + if upload_id.is_none() { + let body_text = response.text().await.map_err(|e| { DataFusionError::External(Box::new(Error::Http { - source: format!("Failed to parse add response: {}", e).into(), + source: Box::new(e), request_id: request_id.clone(), status_code: None, })) - })? - }; + })?; + + let parsed_result = if body_text.trim().is_empty() { + // Backward compatible with old servers + AddResult { version: 0 } + } else { + serde_json::from_str(&body_text).map_err(|e| { + DataFusionError::External(Box::new(Error::Http { + source: format!("Failed to parse add response: {}", e).into(), + request_id: request_id.clone(), + status_code: None, + })) + })? + }; - { let mut res_lock = add_result.lock().map_err(|_| { DataFusionError::Execution("Failed to acquire lock for add_result".to_string()) })?; diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index db0636a1c..da9c8283b 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -75,6 +75,7 @@ pub mod query; pub mod schema_evolution; pub mod update; use crate::index::waiter::wait_for_index; +pub(crate) use add_data::PreprocessingOutput; pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior}; pub use chrono::Duration; pub use delete::DeleteResult; @@ -440,6 +441,34 @@ mod test_utils { embedding_registry: Arc::new(MemoryRegistry::new()), } } + + pub fn new_with_handler_version_and_config( + name: impl Into, + version: semver::Version, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + config: crate::remote::ClientConfig, + ) -> Self + where + T: Into, + { + let inner = Arc::new( + crate::remote::table::RemoteTable::new_mock_with_version_and_config( + name.into(), + handler.clone(), + Some(version), + config.clone(), + ), + ); + let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config( + handler, config, + )); + Self { + inner, + database: Some(database), + // Registry is unused. + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } } } @@ -2198,21 +2227,26 @@ impl BaseTable for NativeTable { let table_schema = Schema::from(&ds.schema().clone()); - // Peek at the first batch to estimate a good partition count for - // write parallelism. - let mut peeked = PeekedScannable::new(add.data); - let num_partitions = if let Some(first_batch) = peeked.peek().await { - let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); - estimate_write_partitions( - first_batch.get_array_memory_size(), - first_batch.num_rows(), - peeked.num_rows(), - max_partitions, - ) + let num_partitions = if let Some(parallelism) = add.write_parallelism { + parallelism } else { - 1 + // Peek at the first batch to estimate a good partition count for + // write parallelism. + let mut peeked = PeekedScannable::new(add.data); + let n = if let Some(first_batch) = peeked.peek().await { + let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus(); + estimate_write_partitions( + first_batch.get_array_memory_size(), + first_batch.num_rows(), + peeked.num_rows(), + max_partitions, + ) + } else { + 1 + }; + add.data = Box::new(peeked); + n }; - add.data = Box::new(peeked); let output = add.into_plan(&table_schema, &table_def)?; diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs index 5921c54ea..bbafd6ce2 100644 --- a/rust/lancedb/src/table/add_data.rs +++ b/rust/lancedb/src/table/add_data.rs @@ -52,6 +52,7 @@ pub struct AddDataBuilder { pub(crate) write_options: WriteOptions, pub(crate) on_nan_vectors: NaNVectorBehavior, pub(crate) embedding_registry: Option>, + pub(crate) write_parallelism: Option, } impl std::fmt::Debug for AddDataBuilder { @@ -77,6 +78,7 @@ impl AddDataBuilder { write_options: WriteOptions::default(), on_nan_vectors: NaNVectorBehavior::default(), embedding_registry, + write_parallelism: None, } } @@ -101,7 +103,22 @@ impl AddDataBuilder { self } + /// Set the number of parallel write streams. + /// + /// By default, the number of streams is estimated from the data size. + /// Setting this to `1` disables parallel writes. + pub fn write_parallelism(mut self, parallelism: usize) -> Self { + self.write_parallelism = Some(parallelism); + self + } + pub async fn execute(self) -> Result { + if self.write_parallelism.map(|p| p == 0).unwrap_or(false) { + return Err(Error::InvalidInput { + message: "write_parallelism must be greater than 0".to_string(), + }); + } + self.parent.clone().add(self).await }