mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-04 12:50:40 +00:00
feat(rust): support datafusion expressions for merge insert predicates (#3444)
### Description

This PR exposes native DataFusion expression support in the Rust SDK's `MergeInsertBuilder` via two new builder methods: `when_matched_update_all_expr` and `when_not_matched_by_source_delete_expr`.

For remote LanceDB tables (where operations are serialized over HTTP/JSON to the SaaS backend), native DataFusion expression trees cannot be executed directly. The SDK handles this gracefully by returning a `NotSupported` error.

### Key Changes

- **`MergeFilter` Enum**: Introduced a helper enum to store either a SQL string or a native `datafusion_expr::Expr`.
- **`MergeInsertBuilder`**: Updated `when_matched_update_all_filt` and `when_not_matched_by_source_delete_filt` fields to store the new enum, and added `when_matched_update_all_expr` and `when_not_matched_by_source_delete_expr` builder methods.
- **Execution & Remote Dispatch**: Dispatched the filter variants during local execution, and rejected expression filters with a clean `NotSupported` error in remote table request conversion.
- **Testing**: Added a `test_merge_insert_expr` unit test covering conditional updates and deletes with programmatically built DataFusion expressions.

### Verification

- Added unit test `test_merge_insert_expr`, which compiles and passes.
- Formatted and linted the code.

Closes #3416
This commit is contained in:
@@ -23,6 +23,7 @@ use crate::table::DropColumnsResult;
|
||||
use crate::table::MergeResult;
|
||||
use crate::table::Tags;
|
||||
use crate::table::UpdateResult;
|
||||
use crate::table::merge::MergeFilter;
|
||||
use crate::table::query::create_multi_vector_plan;
|
||||
use crate::table::{AlterColumnsResult, FieldMetadataUpdate, UpdateFieldMetadataResult};
|
||||
use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics};
|
||||
@@ -2266,13 +2267,34 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
|
||||
}
|
||||
let on = value.on[0].clone();
|
||||
|
||||
let when_matched_update_all_filt = match value.when_matched_update_all_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let when_not_matched_by_source_delete_filt =
|
||||
match value.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
on,
|
||||
when_matched_update_all: value.when_matched_update_all,
|
||||
when_matched_update_all_filt: value.when_matched_update_all_filt,
|
||||
when_matched_update_all_filt,
|
||||
when_not_matched_insert_all: value.when_not_matched_insert_all,
|
||||
when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
|
||||
when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
|
||||
when_not_matched_by_source_delete_filt,
|
||||
// Only serialize use_index when it's false for backwards compatibility
|
||||
use_index: value.use_index,
|
||||
})
|
||||
|
||||
@@ -53,6 +53,12 @@ pub struct MergeResult {
|
||||
pub num_rows: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MergeFilter {
|
||||
Sql(String),
|
||||
Expr(datafusion_expr::Expr),
|
||||
}
|
||||
|
||||
/// A builder used to create and run a merge insert operation
|
||||
///
|
||||
/// See [`super::Table::merge_insert`] for more context
|
||||
@@ -61,10 +67,10 @@ pub struct MergeInsertBuilder {
|
||||
table: Arc<dyn BaseTable>,
|
||||
pub(crate) on: Vec<String>,
|
||||
pub(crate) when_matched_update_all: bool,
|
||||
pub(crate) when_matched_update_all_filt: Option<String>,
|
||||
pub(crate) when_matched_update_all_filt: Option<MergeFilter>,
|
||||
pub(crate) when_not_matched_insert_all: bool,
|
||||
pub(crate) when_not_matched_by_source_delete: bool,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<MergeFilter>,
|
||||
pub(crate) timeout: Option<Duration>,
|
||||
pub(crate) use_index: bool,
|
||||
pub(crate) use_lsm_write: Option<bool>,
|
||||
@@ -110,7 +116,14 @@ impl MergeInsertBuilder {
|
||||
/// For example, "target.last_update < source.last_update"
|
||||
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = condition;
|
||||
self.when_matched_update_all_filt = condition.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_matched_update_all`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_matched_update_all_expr(&mut self, condition: datafusion_expr::Expr) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = Some(MergeFilter::Expr(condition));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -132,7 +145,17 @@ impl MergeInsertBuilder {
|
||||
/// limit what rows are deleted.
|
||||
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = filter;
|
||||
self.when_not_matched_by_source_delete_filt = filter.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_not_matched_by_source_delete`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_not_matched_by_source_delete_expr(
|
||||
&mut self,
|
||||
filter: datafusion_expr::Expr,
|
||||
) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = Some(MergeFilter::Expr(filter));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -234,7 +257,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
) {
|
||||
(false, _) => builder.when_matched(WhenMatched::DoNothing),
|
||||
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
|
||||
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
|
||||
(true, Some(MergeFilter::Sql(filt))) => {
|
||||
builder.when_matched(WhenMatched::update_if(&dataset, &filt)?)
|
||||
}
|
||||
(true, Some(MergeFilter::Expr(expr))) => {
|
||||
builder.when_matched(WhenMatched::update_if_expr(expr))
|
||||
}
|
||||
};
|
||||
if params.when_not_matched_insert_all {
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
|
||||
@@ -242,10 +270,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
|
||||
}
|
||||
if params.when_not_matched_by_source_delete {
|
||||
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
} else {
|
||||
WhenNotMatchedBySource::Delete
|
||||
let behavior = match params.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(filter)) => {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
}
|
||||
Some(MergeFilter::Expr(expr)) => WhenNotMatchedBySource::DeleteIf(expr),
|
||||
None => WhenNotMatchedBySource::Delete,
|
||||
};
|
||||
builder.when_not_matched_by_source(behavior);
|
||||
} else {
|
||||
@@ -386,6 +416,45 @@ mod tests {
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 25);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_insert_expr() {
|
||||
use datafusion_expr::{col, lit};
|
||||
|
||||
let conn = connect("memory://").execute().await.unwrap();
|
||||
|
||||
// Create a dataset with i=0..10
|
||||
let batches = merge_insert_test_batches(0, 0);
|
||||
let table = conn
|
||||
.create_table("my_table_expr", batches)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
// Conditional update that only replaces the age=0 data
|
||||
let new_batches = merge_insert_test_batches(5, 3);
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// use expression: target.age = 0
|
||||
let expr = col("target.age").eq(lit(0));
|
||||
merge_insert_builder.when_matched_update_all_expr(expr);
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(
|
||||
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
|
||||
// Delete with expression
|
||||
// Create new batches with i=10..20 (so target rows i=0..9 are not matched by source)
|
||||
let new_batches = merge_insert_test_batches(10, 0); // won't insert or update since we don't enable matched/unmatched actions
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// delete if target.age = 3
|
||||
let delete_expr = col("target.age").eq(lit(3));
|
||||
merge_insert_builder.when_not_matched_by_source_delete_expr(delete_expr);
|
||||
let result = merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(result.num_deleted_rows, 5);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user