Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-10 13:52:58 +00:00
feat: add support for filter during merge insert when matched (#948)
Closes #940
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]

 [workspace.dependencies]
-lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.9.12" }
-lance-linalg = { "version" = "=0.9.12" }
-lance-testing = { "version" = "=0.9.12" }
+lance = { "version" = "=0.9.14", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.9.14" }
+lance-linalg = { "version" = "=0.9.14" }
+lance-testing = { "version" = "=0.9.14" }
 # Note that this one does not include pyarrow
 arrow = { version = "50.0", optional = false }
 arrow-array = "50.0"
@@ -525,8 +525,19 @@ export interface MergeInsertArgs {
    * If there are multiple matches then the behavior is undefined.
    * Currently this causes multiple copies of the row to be created
    * but that behavior is subject to change.
+   *
+   * Optionally, a filter can be specified. This should be an SQL
+   * filter where fields with the prefix "target." refer to fields
+   * in the target table (old data) and fields with the prefix
+   * "source." refer to fields in the source table (new data). For
+   * example, the filter "target.lastUpdated < source.lastUpdated" will
+   * only update matched rows when the incoming `lastUpdated` value is
+   * newer.
+   *
+   * Rows that do not match the filter will not be updated. Rows that
+   * do not match the filter do not become "not matched" rows.
   */
-  whenMatchedUpdateAll?: boolean
+  whenMatchedUpdateAll?: string | boolean
  /**
   * If true then rows that exist only in the source table (new data)
   * will be inserted into the target table.
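With the widened type, passing a string turns "update all" into a conditional update. A minimal sketch of the call through the node client, assuming the client is installed as the `vectordb` npm package; the URI, table name, and columns here are hypothetical:

import * as lancedb from 'vectordb'

async function upsertIfNewer (): Promise<void> {
  const db = await lancedb.connect('/tmp/lancedb-demo')
  const table = await db.openTable('events')

  const incoming = [{ id: 1, lastUpdated: 1700000000 }]
  await table.mergeInsert('id', incoming, {
    // Only overwrite a matched row when the incoming record is newer.
    whenMatchedUpdateAll: 'target.lastUpdated < source.lastUpdated',
    // Rows with no match on `id` are inserted as new rows.
    whenNotMatchedInsertAll: true
  })
}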
@@ -885,7 +896,14 @@ export class LocalTable<T = number[]> implements Table<T> {
   }

   async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
-    const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
+    let whenMatchedUpdateAll = false
+    let whenMatchedUpdateAllFilt = null
+    if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
+      whenMatchedUpdateAll = true
+      if (args.whenMatchedUpdateAll !== true) {
+        whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
+      }
+    }
     const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
     let whenNotMatchedBySourceDelete = false
     let whenNotMatchedBySourceDeleteFilt = null
@@ -909,6 +927,7 @@ export class LocalTable<T = number[]> implements Table<T> {
       this._tbl,
       on,
       whenMatchedUpdateAll,
+      whenMatchedUpdateAllFilt,
       whenNotMatchedInsertAll,
       whenNotMatchedBySourceDelete,
       whenNotMatchedBySourceDeleteFilt,
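The normalization in the LocalTable.mergeInsert hunk above collapses the user-facing value into a boolean flag plus an optional filter string before the call into the native module. A minimal standalone sketch of that mapping (the helper name is hypothetical; this sketch also treats an explicit `false` as disabling the clause, which is how the remote client below handles it):

// Illustrative only: maps the user-facing value to the (flag, filter) pair
// that is passed across to the native/remote layer.
function normalizeWhenMatched (value?: string | boolean | null): [boolean, string | null] {
  if (value === undefined || value === null || value === false) {
    return [false, null]            // clause disabled
  }
  if (value === true) {
    return [true, null]             // unconditional "update all matched rows"
  }
  return [true, value]              // string: update only matched rows passing this SQL filter
}

// normalizeWhenMatched(undefined)          -> [false, null]
// normalizeWhenMatched(true)               -> [true, null]
// normalizeWhenMatched('target.age = 1')   -> [true, 'target.age = 1']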
@@ -286,8 +286,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
     const queryParams: any = {
       on
     }
-    if (args.whenMatchedUpdateAll ?? false) {
+    if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
       queryParams.when_matched_update_all = 'true'
+      if (typeof args.whenMatchedUpdateAll === 'string') {
+        queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
+      }
     } else {
       queryParams.when_matched_update_all = 'false'
     }
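On the remote path the same value is flattened into query parameters. A sketch of the mapping produced by the branch above, with the parameter names taken from the code (the helper is illustrative only, not part of the client):

// Illustrative only: the shape of the query parameters built by RemoteTable.mergeInsert.
function buildMergeInsertParams (on: string, whenMatchedUpdateAll?: string | boolean | null): Record<string, string> {
  const params: Record<string, string> = { on }
  if (whenMatchedUpdateAll !== false && whenMatchedUpdateAll !== null && whenMatchedUpdateAll !== undefined) {
    params['when_matched_update_all'] = 'true'
    if (typeof whenMatchedUpdateAll === 'string') {
      params['when_matched_update_all_filt'] = whenMatchedUpdateAll
    }
  } else {
    params['when_matched_update_all'] = 'false'
  }
  return params
}

// buildMergeInsertParams('id', 'target.age = 1')
//   -> { on: 'id', when_matched_update_all: 'true', when_matched_update_all_filt: 'target.age = 1' }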
@@ -540,26 +540,36 @@ describe('LanceDB client', function () {
     const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
     const table = await con.createTable('my_table', data)

     // insert if not exists
     let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
     await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true
     })
     assert.equal(await table.countRows(), 3)
-    assert.equal((await table.filter('age = 2').execute()).length, 1)
+    assert.equal(await table.countRows('age = 2'), 1)

-    newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
+    // conditional update
+    newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
     await table.mergeInsert('id', newData, {
+      whenMatchedUpdateAll: 'target.age = 1'
+    })
+    assert.equal(await table.countRows(), 3)
+    assert.equal(await table.countRows('age = 1'), 1)
+    assert.equal(await table.countRows('age = 3'), 1)
+
+    newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
+    await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true,
       whenMatchedUpdateAll: true
     })
     assert.equal(await table.countRows(), 4)
-    assert.equal((await table.filter('age = 3').execute()).length, 2)
+    assert.equal((await table.filter('age = 4').execute()).length, 2)

-    newData = [{ id: 5, age: 4 }]
+    newData = [{ id: 5, age: 5 }]
     await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true,
       whenMatchedUpdateAll: true,
-      whenNotMatchedBySourceDelete: 'age < 3'
+      whenNotMatchedBySourceDelete: 'age < 4'
     })
     assert.equal(await table.countRows(), 3)
@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
        self._table = table
        self._on = on
        self._when_matched_update_all = False
+       self._when_matched_update_all_condition = None
        self._when_not_matched_insert_all = False
        self._when_not_matched_by_source_delete = False
        self._when_not_matched_by_source_condition = None

-   def when_matched_update_all(self) -> LanceMergeInsertBuilder:
+   def when_matched_update_all(
+       self, *, where: Optional[str] = None
+   ) -> LanceMergeInsertBuilder:
        """
        Rows that exist in both the source table (new data) and
        the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
        but that behavior is subject to change.
        """
        self._when_matched_update_all = True
+       self._when_matched_update_all_condition = where
        return self

    def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
@@ -298,6 +298,10 @@ class RemoteTable(Table):
        )
        params["on"] = merge._on[0]
        params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
+       if merge._when_matched_update_all_condition is not None:
+           params[
+               "when_matched_update_all_filt"
+           ] = merge._when_matched_update_all_condition
        params["when_not_matched_insert_all"] = str(
            merge._when_not_matched_insert_all
        ).lower()
@@ -1459,7 +1459,7 @@ class LanceTable(Table):
        ds = self.to_lance()
        builder = ds.merge_insert(merge._on)
        if merge._when_matched_update_all:
-           builder.when_matched_update_all()
+           builder.when_matched_update_all(merge._when_matched_update_all_condition)
        if merge._when_not_matched_insert_all:
            builder.when_not_matched_insert_all()
        if merge._when_not_matched_by_source_delete:
@@ -3,7 +3,7 @@ name = "lancedb"
 version = "0.5.3"
 dependencies = [
     "deprecation",
-    "pylance==0.9.12",
+    "pylance==0.9.14",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",
@@ -513,8 +513,15 @@ def test_merge_insert(db):
    ).when_matched_update_all().when_not_matched_insert_all().execute(new_data)

    expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
    # These `sort_by` calls can be removed once lance#1892
    # is merged (it fixes the ordering)
    assert table.to_arrow().sort_by("a") == expected

    table.restore(version)

+   # conditional update
+   table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
+       new_data
+   )
+   expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
+   assert table.to_arrow().sort_by("a") == expected
+
+   table.restore(version)
@@ -191,28 +191,34 @@ impl JsTable {
        let key = cx.argument::<JsString>(0)?.value(&mut cx);
        let mut builder = table.merge_insert(&[&key]);
        if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
-           builder.when_matched_update_all();
-       }
-       if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
-           builder.when_not_matched_insert_all();
+           let filter = cx.argument_opt(2).unwrap();
+           if filter.is_a::<JsNull, _>(&mut cx) {
+               builder.when_matched_update_all(None);
+           } else {
+               let filter = filter
+                   .downcast_or_throw::<JsString, _>(&mut cx)?
+                   .deref()
+                   .value(&mut cx);
+               builder.when_matched_update_all(Some(filter));
+           }
        }
        if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
-           if let Some(filter) = cx.argument_opt(4) {
-               if filter.is_a::<JsNull, _>(&mut cx) {
-                   builder.when_not_matched_by_source_delete(None);
-               } else {
-                   let filter = filter
-                       .downcast_or_throw::<JsString, _>(&mut cx)?
-                       .deref()
-                       .value(&mut cx);
-                   builder.when_not_matched_by_source_delete(Some(filter));
-               }
-           } else {
+           builder.when_not_matched_insert_all();
+       }
+       if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
+           let filter = cx.argument_opt(5).unwrap();
+           if filter.is_a::<JsNull, _>(&mut cx) {
                builder.when_not_matched_by_source_delete(None);
+           } else {
+               let filter = filter
+                   .downcast_or_throw::<JsString, _>(&mut cx)?
+                   .deref()
+                   .value(&mut cx);
+               builder.when_not_matched_by_source_delete(Some(filter));
            }
        }

-       let buffer = cx.argument::<JsBuffer>(5)?;
+       let buffer = cx.argument::<JsBuffer>(6)?;
        let (batches, schema) =
            arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
    compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
 pub use lance::dataset::ReadParams;
-use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
+use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -238,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
    /// schema.clone());
    /// // Perform an upsert operation
    /// let mut merge_insert = tbl.merge_insert(&["id"]);
-   /// merge_insert.when_matched_update_all()
+   /// merge_insert.when_matched_update_all(None)
    ///     .when_not_matched_insert_all();
    /// merge_insert.execute(Box::new(new_data)).await.unwrap();
    /// # });
@@ -677,11 +677,14 @@ impl MergeInsert for NativeTable {
    ) -> Result<()> {
        let dataset = Arc::new(self.clone_inner_dataset());
        let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
-       if params.when_matched_update_all {
-           builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
-       } else {
-           builder.when_matched(lance::dataset::WhenMatched::DoNothing);
-       }
+       match (
+           params.when_matched_update_all,
+           params.when_matched_update_all_filt,
+       ) {
+           (false, _) => builder.when_matched(WhenMatched::DoNothing),
+           (true, None) => builder.when_matched(WhenMatched::UpdateAll),
+           (true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
+       };
        if params.when_not_matched_insert_all {
            builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
        } else {
@@ -824,6 +827,7 @@ impl Table for NativeTable {

 #[cfg(test)]
 mod tests {
+    use std::iter;
     use std::sync::atomic::{AtomicBool, Ordering};
     use std::sync::Arc;
@@ -947,14 +951,14 @@ mod tests {
        let uri = tmp_dir.path().to_str().unwrap();

        // Create a dataset with i=0..10
-       let batches = make_test_batches_with_offset(0);
+       let batches = merge_insert_test_batches(0, 0);
        let table = NativeTable::create(&uri, "test", batches, None, None)
            .await
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);

        // Create new data with i=5..15
-       let new_batches = Box::new(make_test_batches_with_offset(5));
+       let new_batches = Box::new(merge_insert_test_batches(5, 1));

        // Perform a "insert if not exists"
        let mut merge_insert_builder = table.merge_insert(&["i"]);
@@ -964,13 +968,27 @@ mod tests {
        assert_eq!(table.count_rows(None).await.unwrap(), 15);

        // Create new data with i=15..25 (no id matches)
-       let new_batches = Box::new(make_test_batches_with_offset(15));
+       let new_batches = Box::new(merge_insert_test_batches(15, 2));
        // Perform a "bulk update" (should not affect anything)
        let mut merge_insert_builder = table.merge_insert(&["i"]);
-       merge_insert_builder.when_matched_update_all();
+       merge_insert_builder.when_matched_update_all(None);
        merge_insert_builder.execute(new_batches).await.unwrap();
        // No new rows should have been inserted
        assert_eq!(table.count_rows(None).await.unwrap(), 15);
+       assert_eq!(
+           table.count_rows(Some("age = 2".to_string())).await.unwrap(),
+           0
+       );
+
+       // Conditional update that only replaces the age=0 data
+       let new_batches = Box::new(merge_insert_test_batches(5, 3));
+       let mut merge_insert_builder = table.merge_insert(&["i"]);
+       merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
+       merge_insert_builder.execute(new_batches).await.unwrap();
+       assert_eq!(
+           table.count_rows(Some("age = 3".to_string())).await.unwrap(),
+           5
+       );
    }

    #[tokio::test]
@@ -1319,23 +1337,35 @@ mod tests {
        assert!(wrapper.called());
    }

-   fn make_test_batches_with_offset(
+   fn merge_insert_test_batches(
        offset: i32,
+       age: i32,
    ) -> impl RecordBatchReader + Send + Sync + 'static {
-       let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
+       let schema = Arc::new(Schema::new(vec![
+           Field::new("i", DataType::Int32, false),
+           Field::new("age", DataType::Int32, false),
+       ]));
        RecordBatchIterator::new(
            vec![RecordBatch::try_new(
                schema.clone(),
-               vec![Arc::new(Int32Array::from_iter_values(
-                   offset..(offset + 10),
-               ))],
+               vec![
+                   Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
+                   Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
+               ],
            )],
            schema,
        )
    }

    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
-       make_test_batches_with_offset(0)
+       let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
+       RecordBatchIterator::new(
+           vec![RecordBatch::try_new(
+               schema.clone(),
+               vec![Arc::new(Int32Array::from_iter_values(0..10))],
+           )],
+           schema,
+       )
    }

    #[tokio::test]
@@ -35,6 +35,7 @@ pub struct MergeInsertBuilder {
    table: Arc<dyn MergeInsert>,
    pub(super) on: Vec<String>,
    pub(super) when_matched_update_all: bool,
+   pub(super) when_matched_update_all_filt: Option<String>,
    pub(super) when_not_matched_insert_all: bool,
    pub(super) when_not_matched_by_source_delete: bool,
    pub(super) when_not_matched_by_source_delete_filt: Option<String>,
@@ -46,6 +47,7 @@ impl MergeInsertBuilder {
            table,
            on,
            when_matched_update_all: false,
+           when_matched_update_all_filt: None,
            when_not_matched_insert_all: false,
            when_not_matched_by_source_delete: false,
            when_not_matched_by_source_delete_filt: None,
@@ -59,8 +61,22 @@ impl MergeInsertBuilder {
    /// If there are multiple matches then the behavior is undefined.
    /// Currently this causes multiple copies of the row to be created
    /// but that behavior is subject to change.
-   pub fn when_matched_update_all(&mut self) -> &mut Self {
+   ///
+   /// An optional condition may be specified. If it is, then only
+   /// matched rows that satisfy the condition will be updated. Any
+   /// rows that do not satisfy the condition will be left as they
+   /// are. Failing to satisfy the condition does not cause a
+   /// "matched row" to become a "not matched" row.
+   ///
+   /// The condition should be an SQL string. Use the prefix
+   /// target. to refer to rows in the target table (old data)
+   /// and the prefix source. to refer to rows in the source
+   /// table (new data).
+   ///
+   /// For example, "target.last_update < source.last_update"
+   pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
        self.when_matched_update_all = true;
+       self.when_matched_update_all_filt = condition;
        self
    }
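The Rust doc comment above spells out the key behavior: a matched row that fails the condition is simply skipped, it does not turn into a "not matched" row, so it is neither updated nor re-inserted. The same semantics surface through the node client; a minimal sketch (package assumed to be `vectordb`; URI, table, and column names are hypothetical):

import * as lancedb from 'vectordb'

async function conditionalUpsertDemo (): Promise<void> {
  const db = await lancedb.connect('/tmp/lancedb-demo')
  const table = await db.createTable('users', [
    { id: 1, lastUpdated: 200 },
    { id: 2, lastUpdated: 100 }
  ])

  await table.mergeInsert('id', [
    { id: 1, lastUpdated: 150 }, // matched, but older -> left untouched (not re-inserted either)
    { id: 2, lastUpdated: 150 }, // matched and newer -> updated
    { id: 3, lastUpdated: 150 }  // no match on id -> inserted
  ], {
    whenMatchedUpdateAll: 'target.lastUpdated < source.lastUpdated',
    whenNotMatchedInsertAll: true
  })

  // Expected end state: id 1 keeps lastUpdated = 200, id 2 becomes 150, id 3 is added.
}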