From 415d199c151f1bde8ea4e918dfc41c9875c6efb1 Mon Sep 17 00:00:00 2001 From: Armaan Sandhu <74664101+Ar-maan05@users.noreply.github.com> Date: Thu, 4 Jun 2026 04:17:51 +0530 Subject: [PATCH] feat(rust): support datafusion expressions for merge insert predicates (#3444) ### Description This PR exposes native DataFusion expression support in the Rust SDK's `MergeInsertBuilder` via two new builder methods: `when_matched_update_all_expr` and `when_not_matched_by_source_delete_expr`. For remote LanceDB tables (where operations are serialized over HTTP/JSON to the SaaS backend), native DataFusion expression trees cannot be executed directly. The SDK therefore rejects expression-based filters on remote tables with a `NotSupported` error. ### Key Changes - **`MergeFilter` Enum**: Introduced a helper enum to store either a SQL string or a native `datafusion_expr::Expr`. - **`MergeInsertBuilder`**: Updated `when_matched_update_all_filt` and `when_not_matched_by_source_delete_filt` fields to store the new enum, and added `when_matched_update_all_expr` and `when_not_matched_by_source_delete_expr` builder methods. - **Execution & Remote Dispatch**: Dispatched the filter variants during local execution, and rejected expression filters with a clean `NotSupported` error in remote table request conversion. - **Testing**: Added a `test_merge_insert_expr` unit test covering conditional updates and deletes with programmatically built DataFusion expressions. ### Verification - Added the unit test `test_merge_insert_expr`, which compiles and passes. - Formatted and linted the code. 
Closes #3416 --- rust/lancedb/src/remote/table.rs | 26 +++++++++- rust/lancedb/src/table/merge.rs | 87 ++++++++++++++++++++++++++++---- 2 files changed, 102 insertions(+), 11 deletions(-) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 6dab22590..6d5ea8785 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -23,6 +23,7 @@ use crate::table::DropColumnsResult; use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; +use crate::table::merge::MergeFilter; use crate::table::query::create_multi_vector_plan; use crate::table::{AlterColumnsResult, FieldMetadataUpdate, UpdateFieldMetadataResult}; use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics}; @@ -2266,13 +2267,34 @@ impl TryFrom for MergeInsertRequest { } let on = value.on[0].clone(); + let when_matched_update_all_filt = match value.when_matched_update_all_filt { + Some(MergeFilter::Sql(sql)) => Some(sql), + Some(MergeFilter::Expr(_)) => { + return Err(Error::NotSupported { + message: "DataFusion expressions are not supported on remote tables".into(), + }); + } + None => None, + }; + + let when_not_matched_by_source_delete_filt = + match value.when_not_matched_by_source_delete_filt { + Some(MergeFilter::Sql(sql)) => Some(sql), + Some(MergeFilter::Expr(_)) => { + return Err(Error::NotSupported { + message: "DataFusion expressions are not supported on remote tables".into(), + }); + } + None => None, + }; + Ok(Self { on, when_matched_update_all: value.when_matched_update_all, - when_matched_update_all_filt: value.when_matched_update_all_filt, + when_matched_update_all_filt, when_not_matched_insert_all: value.when_not_matched_insert_all, when_not_matched_by_source_delete: value.when_not_matched_by_source_delete, - when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt, + when_not_matched_by_source_delete_filt, // Only serialize use_index when it's false 
for backwards compatibility use_index: value.use_index, }) diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index b3bda36af..a122ba2f2 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -53,6 +53,12 @@ pub struct MergeResult { pub num_rows: u64, } +#[derive(Debug, Clone)] +pub enum MergeFilter { + Sql(String), + Expr(datafusion_expr::Expr), +} + /// A builder used to create and run a merge insert operation /// /// See [`super::Table::merge_insert`] for more context @@ -61,10 +67,10 @@ pub struct MergeInsertBuilder { table: Arc, pub(crate) on: Vec, pub(crate) when_matched_update_all: bool, - pub(crate) when_matched_update_all_filt: Option, + pub(crate) when_matched_update_all_filt: Option, pub(crate) when_not_matched_insert_all: bool, pub(crate) when_not_matched_by_source_delete: bool, - pub(crate) when_not_matched_by_source_delete_filt: Option, + pub(crate) when_not_matched_by_source_delete_filt: Option, pub(crate) timeout: Option, pub(crate) use_index: bool, pub(crate) use_lsm_write: Option, @@ -110,7 +116,14 @@ impl MergeInsertBuilder { /// For example, "target.last_update < source.last_update" pub fn when_matched_update_all(&mut self, condition: Option) -> &mut Self { self.when_matched_update_all = true; - self.when_matched_update_all_filt = condition; + self.when_matched_update_all_filt = condition.map(MergeFilter::Sql); + self + } + + /// Similar to [`Self::when_matched_update_all`] but accepts a DataFusion logical expression directly. + pub fn when_matched_update_all_expr(&mut self, condition: datafusion_expr::Expr) -> &mut Self { + self.when_matched_update_all = true; + self.when_matched_update_all_filt = Some(MergeFilter::Expr(condition)); self } @@ -132,7 +145,17 @@ impl MergeInsertBuilder { /// limit what rows are deleted. 
pub fn when_not_matched_by_source_delete(&mut self, filter: Option) -> &mut Self { self.when_not_matched_by_source_delete = true; - self.when_not_matched_by_source_delete_filt = filter; + self.when_not_matched_by_source_delete_filt = filter.map(MergeFilter::Sql); + self + } + + /// Similar to [`Self::when_not_matched_by_source_delete`] but accepts a DataFusion logical expression directly. + pub fn when_not_matched_by_source_delete_expr( + &mut self, + filter: datafusion_expr::Expr, + ) -> &mut Self { + self.when_not_matched_by_source_delete = true; + self.when_not_matched_by_source_delete_filt = Some(MergeFilter::Expr(filter)); self } @@ -234,7 +257,12 @@ pub(crate) async fn execute_merge_insert( ) { (false, _) => builder.when_matched(WhenMatched::DoNothing), (true, None) => builder.when_matched(WhenMatched::UpdateAll), - (true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?), + (true, Some(MergeFilter::Sql(filt))) => { + builder.when_matched(WhenMatched::update_if(&dataset, &filt)?) + } + (true, Some(MergeFilter::Expr(expr))) => { + builder.when_matched(WhenMatched::update_if_expr(expr)) + } }; if params.when_not_matched_insert_all { builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll); @@ -242,10 +270,12 @@ pub(crate) async fn execute_merge_insert( builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing); } if params.when_not_matched_by_source_delete { - let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt { - WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)? - } else { - WhenNotMatchedBySource::Delete + let behavior = match params.when_not_matched_by_source_delete_filt { + Some(MergeFilter::Sql(filter)) => { + WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)? 
+ } + Some(MergeFilter::Expr(expr)) => WhenNotMatchedBySource::DeleteIf(expr), + None => WhenNotMatchedBySource::Delete, }; builder.when_not_matched_by_source(behavior); } else { @@ -386,6 +416,45 @@ mod tests { merge_insert_builder.execute(new_batches).await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 25); } + + #[tokio::test] + async fn test_merge_insert_expr() { + use datafusion_expr::{col, lit}; + + let conn = connect("memory://").execute().await.unwrap(); + + // Create a dataset with i=0..10 + let batches = merge_insert_test_batches(0, 0); + let table = conn + .create_table("my_table_expr", batches) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 10); + + // Conditional update that only replaces the age=0 data + let new_batches = merge_insert_test_batches(5, 3); + let mut merge_insert_builder = table.merge_insert(&["i"]); + // use expression: target.age = 0 + let expr = col("target.age").eq(lit(0)); + merge_insert_builder.when_matched_update_all_expr(expr); + merge_insert_builder.execute(new_batches).await.unwrap(); + assert_eq!( + table.count_rows(Some("age = 3".to_string())).await.unwrap(), + 5 + ); + + // Delete with expression + // Create new batches with i=10..20 (so target rows i=0..9 are not matched by source) + let new_batches = merge_insert_test_batches(10, 0); // won't insert or update since we don't enable matched/unmatched actions + let mut merge_insert_builder = table.merge_insert(&["i"]); + // delete if target.age = 3 + let delete_expr = col("target.age").eq(lit(3)); + merge_insert_builder.when_not_matched_by_source_delete_expr(delete_expr); + let result = merge_insert_builder.execute(new_batches).await.unwrap(); + assert_eq!(result.num_deleted_rows, 5); + assert_eq!(table.count_rows(None).await.unwrap(), 5); + } } #[cfg(test)]