From a9727eb31864940264d069d1dca84ba293a49e43 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 9 Feb 2024 10:26:14 -0800 Subject: [PATCH] feat: add support for filter during merge insert when matched (#948) Closes #940 --- Cargo.toml | 8 ++-- node/src/index.ts | 23 +++++++++++- node/src/remote/index.ts | 5 ++- node/src/test/test.ts | 20 +++++++--- python/lancedb/merge.py | 6 ++- python/lancedb/remote/table.py | 4 ++ python/lancedb/table.py | 2 +- python/pyproject.toml | 2 +- python/tests/test_table.py | 11 +++++- rust/ffi/node/src/table.rs | 38 +++++++++++-------- rust/vectordb/src/table.rs | 64 +++++++++++++++++++++++--------- rust/vectordb/src/table/merge.rs | 18 ++++++++- 12 files changed, 150 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 972aa328..a13c53c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"] categories = ["database-implementations"] [workspace.dependencies] -lance = { "version" = "=0.9.12", "features" = ["dynamodb"] } -lance-index = { "version" = "=0.9.12" } -lance-linalg = { "version" = "=0.9.12" } -lance-testing = { "version" = "=0.9.12" } +lance = { "version" = "=0.9.14", "features" = ["dynamodb"] } +lance-index = { "version" = "=0.9.14" } +lance-linalg = { "version" = "=0.9.14" } +lance-testing = { "version" = "=0.9.14" } # Note that this one does not include pyarrow arrow = { version = "50.0", optional = false } arrow-array = "50.0" diff --git a/node/src/index.ts b/node/src/index.ts index 2607cc2d..012e7305 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -525,8 +525,19 @@ export interface MergeInsertArgs { * If there are multiple matches then the behavior is undefined. * Currently this causes multiple copies of the row to be created * but that behavior is subject to change. + * + * Optionally, a filter can be specified. This should be an SQL + * filter where fields with the prefix "target." refer to fields + * in the target table (old data) and fields with the prefix + * "source." refer to fields in the source table (new data). For + * example, the filter "target.lastUpdated < source.lastUpdated" will + * only update matched rows when the incoming `lastUpdated` value is + * newer. + * + * Rows that do not match the filter will not be updated. Rows that + * do not match the filter do become "not matched" rows. */ - whenMatchedUpdateAll?: boolean + whenMatchedUpdateAll?: string | boolean /** * If true then rows that exist only in the source table (new data) * will be inserted into the target table. @@ -885,7 +896,14 @@ export class LocalTable implements Table { } async mergeInsert (on: string, data: Array> | ArrowTable, args: MergeInsertArgs): Promise { - const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false + let whenMatchedUpdateAll = false + let whenMatchedUpdateAllFilt = null + if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) { + whenMatchedUpdateAll = true + if (args.whenMatchedUpdateAll !== true) { + whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll + } + } const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false let whenNotMatchedBySourceDelete = false let whenNotMatchedBySourceDeleteFilt = null @@ -909,6 +927,7 @@ export class LocalTable implements Table { this._tbl, on, whenMatchedUpdateAll, + whenMatchedUpdateAllFilt, whenNotMatchedInsertAll, whenNotMatchedBySourceDelete, whenNotMatchedBySourceDeleteFilt, diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 28f43581..9255f31f 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -286,8 +286,11 @@ export class RemoteTable implements Table { const queryParams: any = { on } - if (args.whenMatchedUpdateAll ?? false) { + if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) { queryParams.when_matched_update_all = 'true' + if (typeof args.whenMatchedUpdateAll === 'string') { + queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll + } } else { queryParams.when_matched_update_all = 'false' } diff --git a/node/src/test/test.ts b/node/src/test/test.ts index cda140fd..20b05087 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -540,26 +540,36 @@ describe('LanceDB client', function () { const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }] const table = await con.createTable('my_table', data) + // insert if not exists let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }] await table.mergeInsert('id', newData, { whenNotMatchedInsertAll: true }) assert.equal(await table.countRows(), 3) - assert.equal((await table.filter('age = 2').execute()).length, 1) + assert.equal(await table.countRows('age = 2'), 1) - newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }] + // conditional update + newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }] + await table.mergeInsert('id', newData, { + whenMatchedUpdateAll: 'target.age = 1' + }) + assert.equal(await table.countRows(), 3) + assert.equal(await table.countRows('age = 1'), 1) + assert.equal(await table.countRows('age = 3'), 1) + + newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }] await table.mergeInsert('id', newData, { whenNotMatchedInsertAll: true, whenMatchedUpdateAll: true }) assert.equal(await table.countRows(), 4) - assert.equal((await table.filter('age = 3').execute()).length, 2) + assert.equal((await table.filter('age = 4').execute()).length, 2) - newData = [{ id: 5, age: 4 }] + newData = [{ id: 5, age: 5 }] await table.mergeInsert('id', newData, { whenNotMatchedInsertAll: true, whenMatchedUpdateAll: true, - whenNotMatchedBySourceDelete: 'age < 3' + whenNotMatchedBySourceDelete: 'age < 4' }) assert.equal(await table.countRows(), 3) diff --git a/python/lancedb/merge.py b/python/lancedb/merge.py index 9180635a..69671c5e 100644 --- a/python/lancedb/merge.py +++ b/python/lancedb/merge.py @@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object): self._table = table self._on = on self._when_matched_update_all = False + self._when_matched_update_all_condition = None self._when_not_matched_insert_all = False self._when_not_matched_by_source_delete = False self._when_not_matched_by_source_condition = None - def when_matched_update_all(self) -> LanceMergeInsertBuilder: + def when_matched_update_all( + self, *, where: Optional[str] = None + ) -> LanceMergeInsertBuilder: """ Rows that exist in both the source table (new data) and the target table (old data) will be updated, replacing @@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object): but that behavior is subject to change. """ self._when_matched_update_all = True + self._when_matched_update_all_condition = where return self def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder: diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 341837b3..925690c3 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -298,6 +298,10 @@ class RemoteTable(Table): ) params["on"] = merge._on[0] params["when_matched_update_all"] = str(merge._when_matched_update_all).lower() + if merge._when_matched_update_all_condition is not None: + params[ + "when_matched_update_all_filt" + ] = merge._when_matched_update_all_condition params["when_not_matched_insert_all"] = str( merge._when_not_matched_insert_all ).lower() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index edcd9132..fb790268 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -1467,7 +1467,7 @@ class LanceTable(Table): ds = self.to_lance() builder = ds.merge_insert(merge._on) if merge._when_matched_update_all: - builder.when_matched_update_all() + builder.when_matched_update_all(merge._when_matched_update_all_condition) if merge._when_not_matched_insert_all: builder.when_not_matched_insert_all() if merge._when_not_matched_by_source_delete: diff --git a/python/pyproject.toml b/python/pyproject.toml index 68f14491..09209506 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lancedb" version = "0.5.3" dependencies = [ "deprecation", - "pylance==0.9.12", + "pylance==0.9.14", "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", diff --git a/python/tests/test_table.py b/python/tests/test_table.py index b7b8e00a..02e8b1f0 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -513,8 +513,15 @@ def test_merge_insert(db): ).when_matched_update_all().when_not_matched_insert_all().execute(new_data) expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) - # These `sort_by` calls can be removed once lance#1892 - # is merged (it fixes the ordering) + assert table.to_arrow().sort_by("a") == expected + + table.restore(version) + + # conditional update + table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute( + new_data + ) + expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]}) assert table.to_arrow().sort_by("a") == expected table.restore(version) diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index ac5690f5..bb6bbfea 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -191,28 +191,34 @@ impl JsTable { let key = cx.argument::(0)?.value(&mut cx); let mut builder = table.merge_insert(&[&key]); if cx.argument::(1)?.value(&mut cx) { - builder.when_matched_update_all(); - } - if cx.argument::(2)?.value(&mut cx) { - builder.when_not_matched_insert_all(); + let filter = cx.argument_opt(2).unwrap(); + if filter.is_a::(&mut cx) { + builder.when_matched_update_all(None); + } else { + let filter = filter + .downcast_or_throw::(&mut cx)? + .deref() + .value(&mut cx); + builder.when_matched_update_all(Some(filter)); + } } if cx.argument::(3)?.value(&mut cx) { - if let Some(filter) = cx.argument_opt(4) { - if filter.is_a::(&mut cx) { - builder.when_not_matched_by_source_delete(None); - } else { - let filter = filter - .downcast_or_throw::(&mut cx)? - .deref() - .value(&mut cx); - builder.when_not_matched_by_source_delete(Some(filter)); - } - } else { + builder.when_not_matched_insert_all(); + } + if cx.argument::(4)?.value(&mut cx) { + let filter = cx.argument_opt(5).unwrap(); + if filter.is_a::(&mut cx) { builder.when_not_matched_by_source_delete(None); + } else { + let filter = filter + .downcast_or_throw::(&mut cx)? + .deref() + .value(&mut cx); + builder.when_not_matched_by_source_delete(Some(filter)); } } - let buffer = cx.argument::(5)?; + let buffer = cx.argument::(6)?; let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?; diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index 5f889060..aa9cd129 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -27,7 +27,7 @@ use lance::dataset::optimize::{ compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions, }; pub use lance::dataset::ReadParams; -use lance::dataset::{Dataset, UpdateBuilder, WriteParams}; +use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams}; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::io::WrappingObjectStore; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; @@ -238,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync { /// schema.clone()); /// // Perform an upsert operation /// let mut merge_insert = tbl.merge_insert(&["id"]); - /// merge_insert.when_matched_update_all() + /// merge_insert.when_matched_update_all(None) /// .when_not_matched_insert_all(); /// merge_insert.execute(Box::new(new_data)).await.unwrap(); /// # }); @@ -677,11 +677,14 @@ impl MergeInsert for NativeTable { ) -> Result<()> { let dataset = Arc::new(self.clone_inner_dataset()); let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; - if params.when_matched_update_all { - builder.when_matched(lance::dataset::WhenMatched::UpdateAll); - } else { - builder.when_matched(lance::dataset::WhenMatched::DoNothing); - } + match ( + params.when_matched_update_all, + params.when_matched_update_all_filt, + ) { + (false, _) => builder.when_matched(WhenMatched::DoNothing), + (true, None) => builder.when_matched(WhenMatched::UpdateAll), + (true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?), + }; if params.when_not_matched_insert_all { builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll); } else { @@ -824,6 +827,7 @@ impl Table for NativeTable { #[cfg(test)] mod tests { + use std::iter; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -947,14 +951,14 @@ mod tests { let uri = tmp_dir.path().to_str().unwrap(); // Create a dataset with i=0..10 - let batches = make_test_batches_with_offset(0); + let batches = merge_insert_test_batches(0, 0); let table = NativeTable::create(&uri, "test", batches, None, None) .await .unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); // Create new data with i=5..15 - let new_batches = Box::new(make_test_batches_with_offset(5)); + let new_batches = Box::new(merge_insert_test_batches(5, 1)); // Perform a "insert if not exists" let mut merge_insert_builder = table.merge_insert(&["i"]); @@ -964,13 +968,27 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 15); // Create new data with i=15..25 (no id matches) - let new_batches = Box::new(make_test_batches_with_offset(15)); + let new_batches = Box::new(merge_insert_test_batches(15, 2)); // Perform a "bulk update" (should not affect anything) let mut merge_insert_builder = table.merge_insert(&["i"]); - merge_insert_builder.when_matched_update_all(); + merge_insert_builder.when_matched_update_all(None); merge_insert_builder.execute(new_batches).await.unwrap(); // No new rows should have been inserted assert_eq!(table.count_rows(None).await.unwrap(), 15); + assert_eq!( + table.count_rows(Some("age = 2".to_string())).await.unwrap(), + 0 + ); + + // Conditional update that only replaces the age=0 data + let new_batches = Box::new(merge_insert_test_batches(5, 3)); + let mut merge_insert_builder = table.merge_insert(&["i"]); + merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string())); + merge_insert_builder.execute(new_batches).await.unwrap(); + assert_eq!( + table.count_rows(Some("age = 3".to_string())).await.unwrap(), + 5 + ); } #[tokio::test] @@ -1319,23 +1337,35 @@ mod tests { assert!(wrapper.called()); } - fn make_test_batches_with_offset( + fn merge_insert_test_batches( offset: i32, + age: i32, ) -> impl RecordBatchReader + Send + Sync + 'static { - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("age", DataType::Int32, false), + ])); RecordBatchIterator::new( vec![RecordBatch::try_new( schema.clone(), - vec![Arc::new(Int32Array::from_iter_values( - offset..(offset + 10), - ))], + vec![ + Arc::new(Int32Array::from_iter_values(offset..(offset + 10))), + Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))), + ], )], schema, ) } fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { - make_test_batches_with_offset(0) + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + RecordBatchIterator::new( + vec![RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10))], + )], + schema, + ) } #[tokio::test] diff --git a/rust/vectordb/src/table/merge.rs b/rust/vectordb/src/table/merge.rs index 26caa235..38a8fa13 100644 --- a/rust/vectordb/src/table/merge.rs +++ b/rust/vectordb/src/table/merge.rs @@ -35,6 +35,7 @@ pub struct MergeInsertBuilder { table: Arc, pub(super) on: Vec, pub(super) when_matched_update_all: bool, + pub(super) when_matched_update_all_filt: Option, pub(super) when_not_matched_insert_all: bool, pub(super) when_not_matched_by_source_delete: bool, pub(super) when_not_matched_by_source_delete_filt: Option, @@ -46,6 +47,7 @@ impl MergeInsertBuilder { table, on, when_matched_update_all: false, + when_matched_update_all_filt: None, when_not_matched_insert_all: false, when_not_matched_by_source_delete: false, when_not_matched_by_source_delete_filt: None, @@ -59,8 +61,22 @@ impl MergeInsertBuilder { /// If there are multiple matches then the behavior is undefined. /// Currently this causes multiple copies of the row to be created /// but that behavior is subject to change. - pub fn when_matched_update_all(&mut self) -> &mut Self { + /// + /// An optional condition may be specified. If it is, then only + /// matched rows that satisfy the condtion will be updated. Any + /// rows that do not satisfy the condition will be left as they + /// are. Failing to satisfy the condition does not cause a + /// "matched row" to become a "not matched" row. + /// + /// The condition should be an SQL string. Use the prefix + /// target. to refer to rows in the target table (old data) + /// and the prefix source. to refer to rows in the source + /// table (new data). + /// + /// For example, "target.last_update < source.last_update" + pub fn when_matched_update_all(&mut self, condition: Option) -> &mut Self { self.when_matched_update_all = true; + self.when_matched_update_all_filt = condition; self }