diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 04602c49f..5ddc1d795 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -145,6 +145,7 @@ impl From for lancedb::remote::ClientConfig { id_delimiter: config.id_delimiter, tls_config: config.tls_config.map(Into::into), header_provider: None, // the header provider is set separately later + mem_wal_enabled: None, // mem_wal is set per-operation in merge_insert } } } diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index b2564740c..a0362a604 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -34,6 +34,7 @@ class LanceMergeInsertBuilder(object): self._when_not_matched_by_source_condition = None self._timeout = None self._use_index = True + self._mem_wal = False def when_matched_update_all( self, *, where: Optional[str] = None @@ -96,6 +97,47 @@ class LanceMergeInsertBuilder(object): self._use_index = use_index return self + def mem_wal(self, enabled: bool = True) -> LanceMergeInsertBuilder: + """ + Enable MemWAL (Memory Write-Ahead Log) mode for this merge insert operation. + + When enabled, the merge insert will route data through a memory node service + that buffers writes before flushing to storage. This is only supported for + remote (LanceDB Cloud) tables. + + **Important:** MemWAL only supports the upsert pattern. You must use: + - `when_matched_update_all()` (without a filter condition) + - `when_not_matched_insert_all()` + + MemWAL does NOT support: + - `when_matched_update_all(where=...)` with a filter condition + - `when_not_matched_by_source_delete()` + + Parameters + ---------- + enabled: bool + Whether to enable MemWAL mode. Defaults to `True`. + + Raises + ------ + NotImplementedError + If used on a native (local) table, as MemWAL is only supported for + remote tables. + ValueError + If the merge insert pattern is not supported by MemWAL. + + Examples + -------- + >>> # Correct usage with MemWAL + >>> table.merge_insert(["id"]) \\ + ... .when_matched_update_all() \\ + ... .when_not_matched_insert_all() \\ + ... .mem_wal() \\ + ... .execute(new_data) + """ + self._mem_wal = enabled + return self + def execute( self, new_data: DATA, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index e19449cc8..c61f29c2c 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -4181,6 +4181,7 @@ class AsyncTable: when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition, timeout=merge._timeout, use_index=merge._use_index, + mem_wal=merge._mem_wal, ), ) diff --git a/python/src/connection.rs b/python/src/connection.rs index a8b218a8e..341054fd1 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -506,6 +506,7 @@ pub struct PyClientConfig { id_delimiter: Option, tls_config: Option, header_provider: Option>, + mem_wal_enabled: Option, } #[derive(FromPyObject)] @@ -590,6 +591,7 @@ impl From for lancedb::remote::ClientConfig { id_delimiter: value.id_delimiter, tls_config: value.tls_config.map(Into::into), header_provider, + mem_wal_enabled: value.mem_wal_enabled, } } } diff --git a/python/src/table.rs b/python/src/table.rs index e988cadb4..cf9a0a7a6 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -710,6 +710,9 @@ impl Table { if let Some(use_index) = parameters.use_index { builder.use_index(use_index); } + if let Some(mem_wal) = parameters.mem_wal { + builder.mem_wal(mem_wal); + } future_into_py(self_.py(), async move { let res = builder.execute(Box::new(batches)).await.infer_error()?; @@ -870,6 +873,7 @@ pub struct MergeInsertParams { when_not_matched_by_source_condition: Option, timeout: Option, use_index: Option, + mem_wal: Option, } #[pyclass] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index e745a921b..657938492 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -784,13 +784,19 @@ impl ConnectBuilder { message: "An api_key is required when connecting to LanceDb Cloud".to_string(), })?; + // Propagate mem_wal_enabled from options to client_config + let mut client_config = self.request.client_config; + if options.mem_wal_enabled.is_some() { + client_config.mem_wal_enabled = options.mem_wal_enabled; + } + let storage_options = StorageOptions(options.storage_options.clone()); let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( &self.request.uri, &api_key, ®ion, options.host_override, - self.request.client_config, + client_config, storage_options.into(), )?); Ok(Connection { diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index ac318c014..46bdc8619 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -14,6 +14,7 @@ use crate::remote::db::RemoteOptions; use crate::remote::retry::{ResolvedRetryConfig, RetryCounter}; const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id"); +const MEM_WAL_ENABLED_HEADER: HeaderName = HeaderName::from_static("x-lancedb-mem-wal-enabled"); /// Configuration for TLS/mTLS settings. #[derive(Clone, Debug, Default)] @@ -52,6 +53,10 @@ pub struct ClientConfig { pub tls_config: Option, /// Provider for custom headers to be added to each request pub header_provider: Option>, + /// Enable MemWAL write path for streaming writes. + /// When true, write operations will use the MemWAL architecture + /// for high-performance streaming writes. + pub mem_wal_enabled: Option, } impl std::fmt::Debug for ClientConfig { @@ -67,6 +72,7 @@ impl std::fmt::Debug for ClientConfig { "header_provider", &self.header_provider.as_ref().map(|_| "Some(...)"), ) + .field("mem_wal_enabled", &self.mem_wal_enabled) .finish() } } @@ -81,6 +87,7 @@ impl Default for ClientConfig { id_delimiter: None, tls_config: None, header_provider: None, + mem_wal_enabled: None, } } } @@ -477,6 +484,11 @@ impl RestfulLanceDbClient { ); } + // Add MemWAL header if enabled + if let Some(true) = config.mem_wal_enabled { + headers.insert(MEM_WAL_ENABLED_HEADER, HeaderValue::from_static("true")); + } + Ok(headers) } diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index b80c1cea1..79fefa61e 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -78,6 +78,7 @@ pub const OPT_REMOTE_PREFIX: &str = "remote_database_"; pub const OPT_REMOTE_API_KEY: &str = "remote_database_api_key"; pub const OPT_REMOTE_REGION: &str = "remote_database_region"; pub const OPT_REMOTE_HOST_OVERRIDE: &str = "remote_database_host_override"; +pub const OPT_REMOTE_MEM_WAL_ENABLED: &str = "remote_database_mem_wal_enabled"; // TODO: add support for configuring client config via key/value options #[derive(Clone, Debug, Default)] @@ -98,6 +99,12 @@ pub struct RemoteDatabaseOptions { /// These options are only used for LanceDB Enterprise and only a subset of options /// are supported. pub storage_options: HashMap, + /// Enable MemWAL write path for high-performance streaming writes. + /// + /// When enabled, write operations (insert, merge_insert, etc.) will use + /// the MemWAL architecture which buffers writes in memory and Write-Ahead Log + /// before asynchronously merging to the base table. + pub mem_wal_enabled: Option, } impl RemoteDatabaseOptions { @@ -109,6 +116,9 @@ impl RemoteDatabaseOptions { let api_key = map.get(OPT_REMOTE_API_KEY).cloned(); let region = map.get(OPT_REMOTE_REGION).cloned(); let host_override = map.get(OPT_REMOTE_HOST_OVERRIDE).cloned(); + let mem_wal_enabled = map + .get(OPT_REMOTE_MEM_WAL_ENABLED) + .map(|v| v.to_lowercase() == "true"); let storage_options = map .iter() .filter(|(key, _)| !key.starts_with(OPT_REMOTE_PREFIX)) @@ -119,6 +129,7 @@ impl RemoteDatabaseOptions { region, host_override, storage_options, + mem_wal_enabled, }) } } @@ -137,6 +148,12 @@ impl DatabaseOptions for RemoteDatabaseOptions { if let Some(host_override) = &self.host_override { map.insert(OPT_REMOTE_HOST_OVERRIDE.to_string(), host_override.clone()); } + if let Some(mem_wal_enabled) = &self.mem_wal_enabled { + map.insert( + OPT_REMOTE_MEM_WAL_ENABLED.to_string(), + mem_wal_enabled.to_string(), + ); + } } } @@ -181,6 +198,20 @@ impl RemoteDatabaseOptionsBuilder { self.options.host_override = Some(host_override); self } + + /// Enable MemWAL write path for high-performance streaming writes. + /// + /// When enabled, write operations will use the MemWAL architecture + /// which buffers writes in memory and Write-Ahead Log before + /// asynchronously merging to the base table. + /// + /// # Arguments + /// + /// * `enabled` - Whether to enable MemWAL writes + pub fn mem_wal_enabled(mut self, enabled: bool) -> Self { + self.options.mem_wal_enabled = Some(enabled); + self + } } #[derive(Debug)] diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index a633e00ab..0043c4424 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -62,6 +62,7 @@ use std::time::Duration; use tokio::sync::RwLock; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); +const MEM_WAL_ENABLED_HEADER: HeaderName = HeaderName::from_static("x-lancedb-mem-wal-enabled"); const METRIC_TYPE_KEY: &str = "metric_type"; const INDEX_TYPE_KEY: &str = "index_type"; const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(30); @@ -1359,6 +1360,7 @@ impl BaseTable for RemoteTable { self.check_mutable().await?; let timeout = params.timeout; + let mem_wal = params.mem_wal; let query = MergeInsertRequest::try_from(params)?; let mut request = self @@ -1374,6 +1376,10 @@ impl BaseTable for RemoteTable { } } + if mem_wal { + request = request.header(MEM_WAL_ENABLED_HEADER, "true"); + } + let (request_id, response) = self.send_streaming(request, new_data, true).await?; let response = self.check_table_response(&request_id, response).await?; diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index d8805acb8..31f9567ec 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -55,6 +55,7 @@ pub struct MergeInsertBuilder { pub(crate) when_not_matched_by_source_delete_filt: Option, pub(crate) timeout: Option, pub(crate) use_index: bool, + pub(crate) mem_wal: bool, } impl MergeInsertBuilder { @@ -69,6 +70,7 @@ impl MergeInsertBuilder { when_not_matched_by_source_delete_filt: None, timeout: None, use_index: true, + mem_wal: false, } } @@ -148,13 +150,65 @@ impl MergeInsertBuilder { self } + /// Enables MemWAL (Memory Write-Ahead Log) mode for this merge insert operation. + /// + /// When enabled, the merge insert will route data through a memory node service + /// that buffers writes before flushing to storage. This is only supported for + /// remote (LanceDB Cloud) tables. + /// + /// If not set, defaults to `false`. + pub fn mem_wal(&mut self, enabled: bool) -> &mut Self { + self.mem_wal = enabled; + self + } + /// Executes the merge insert operation /// /// Returns version and statistics about the merge operation including the number of rows /// inserted, updated, and deleted. pub async fn execute(self, new_data: Box) -> Result { + // Validate MemWAL constraints before execution + if self.mem_wal { + self.validate_mem_wal_pattern()?; + } self.table.clone().merge_insert(self, new_data).await } + + /// Validate that the merge insert pattern is supported by MemWAL. + /// + /// MemWAL only supports the upsert pattern: + /// - when_matched_update_all (without filter) + /// - when_not_matched_insert_all + /// - NO when_not_matched_by_source_delete + fn validate_mem_wal_pattern(&self) -> Result<()> { + // Must have when_matched_update_all without filter + if !self.when_matched_update_all { + return Err(Error::InvalidInput { + message: "MemWAL requires when_matched_update_all() to be set".to_string(), + }); + } + if self.when_matched_update_all_filt.is_some() { + return Err(Error::InvalidInput { + message: "MemWAL does not support conditional when_matched_update_all (no filter allowed)".to_string(), + }); + } + + // Must have when_not_matched_insert_all + if !self.when_not_matched_insert_all { + return Err(Error::InvalidInput { + message: "MemWAL requires when_not_matched_insert_all() to be set".to_string(), + }); + } + + // Must NOT have when_not_matched_by_source_delete + if self.when_not_matched_by_source_delete { + return Err(Error::InvalidInput { + message: "MemWAL does not support when_not_matched_by_source_delete()".to_string(), + }); + } + + Ok(()) + } } /// Internal implementation of the merge insert logic @@ -165,6 +219,14 @@ pub(crate) async fn execute_merge_insert( params: MergeInsertBuilder, new_data: Box, ) -> Result { + if params.mem_wal { + return Err(Error::NotSupported { + message: "MemWAL is not supported for native (local) tables. \ + MemWAL is only available for remote (LanceDB Cloud) tables." + .to_string(), + }); + } + let dataset = table.dataset.get().await?; let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( @@ -324,4 +386,139 @@ mod tests { merge_insert_builder.execute(new_batches).await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 25); } + + #[tokio::test] + async fn test_mem_wal_validation_valid_pattern() { + let conn = connect("memory://").execute().await.unwrap(); + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("mem_wal_test", batches) + .execute() + .await + .unwrap(); + + // Valid MemWAL pattern: when_matched_update_all + when_not_matched_insert_all + let new_batches = merge_insert_test_batches(5, 1); + let mut builder = table.merge_insert(&["i"]); + builder.when_matched_update_all(None); + builder.when_not_matched_insert_all(); + builder.mem_wal(true); + + // Should fail because native tables don't support MemWAL, but validation passes + let result = builder.execute(new_batches).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("MemWAL is not supported for native"), + "Expected native table error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_mem_wal_validation_missing_when_matched() { + let conn = connect("memory://").execute().await.unwrap(); + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("mem_wal_test2", batches) + .execute() + .await + .unwrap(); + + // Missing when_matched_update_all + let new_batches = merge_insert_test_batches(5, 1); + let mut builder = table.merge_insert(&["i"]); + builder.when_not_matched_insert_all(); + builder.mem_wal(true); + + let result = builder.execute(new_batches).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("requires when_matched_update_all"), + "Expected validation error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_mem_wal_validation_missing_when_not_matched() { + let conn = connect("memory://").execute().await.unwrap(); + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("mem_wal_test3", batches) + .execute() + .await + .unwrap(); + + // Missing when_not_matched_insert_all + let new_batches = merge_insert_test_batches(5, 1); + let mut builder = table.merge_insert(&["i"]); + builder.when_matched_update_all(None); + builder.mem_wal(true); + + let result = builder.execute(new_batches).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("requires when_not_matched_insert_all"), + "Expected validation error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_mem_wal_validation_with_filter() { + let conn = connect("memory://").execute().await.unwrap(); + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("mem_wal_test4", batches) + .execute() + .await + .unwrap(); + + // With conditional filter - not allowed + let new_batches = merge_insert_test_batches(5, 1); + let mut builder = table.merge_insert(&["i"]); + builder.when_matched_update_all(Some("target.age > 0".to_string())); + builder.when_not_matched_insert_all(); + builder.mem_wal(true); + + let result = builder.execute(new_batches).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("does not support conditional"), + "Expected filter validation error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_mem_wal_validation_with_delete() { + let conn = connect("memory://").execute().await.unwrap(); + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("mem_wal_test5", batches) + .execute() + .await + .unwrap(); + + // With when_not_matched_by_source_delete - not allowed + let new_batches = merge_insert_test_batches(5, 1); + let mut builder = table.merge_insert(&["i"]); + builder.when_matched_update_all(None); + builder.when_not_matched_insert_all(); + builder.when_not_matched_by_source_delete(None); + builder.mem_wal(true); + + let result = builder.execute(new_batches).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("does not support when_not_matched_by_source_delete"), + "Expected delete validation error, got: {}", + err + ); + } }