feat: add support for filter during merge insert when matched (#948)

Closes #940
This commit is contained in:
Weston Pace
2024-02-09 10:26:14 -08:00
parent 069ad267bd
commit 41ccb48160
12 changed files with 150 additions and 51 deletions

View File

@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
lance = { "version" = "=0.9.14", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.14" }
lance-linalg = { "version" = "=0.9.14" }
lance-testing = { "version" = "=0.9.14" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -525,8 +525,19 @@ export interface MergeInsertArgs {
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*
* Optionally, a filter can be specified. This should be an SQL
* filter where fields with the prefix "target." refer to fields
* in the target table (old data) and fields with the prefix
* "source." refer to fields in the source table (new data). For
* example, the filter "target.lastUpdated < source.lastUpdated" will
* only update matched rows when the incoming `lastUpdated` value is
* newer.
*
* Rows that do not match the filter will not be updated. Rows that
* do not match the filter do not become "not matched" rows.
*/
whenMatchedUpdateAll?: boolean
whenMatchedUpdateAll?: string | boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
@@ -885,7 +896,14 @@ export class LocalTable<T = number[]> implements Table<T> {
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
let whenMatchedUpdateAll = false
let whenMatchedUpdateAllFilt = null
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
whenMatchedUpdateAll = true
if (args.whenMatchedUpdateAll !== true) {
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
}
}
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
@@ -909,6 +927,7 @@ export class LocalTable<T = number[]> implements Table<T> {
this._tbl,
on,
whenMatchedUpdateAll,
whenMatchedUpdateAllFilt,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,

View File

@@ -286,8 +286,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
queryParams.when_matched_update_all = 'true'
if (typeof args.whenMatchedUpdateAll === 'string') {
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
}
} else {
queryParams.when_matched_update_all = 'false'
}

View File

@@ -540,26 +540,36 @@ describe('LanceDB client', function () {
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
// insert if not exists
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)
assert.equal(await table.countRows('age = 2'), 1)
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
// conditional update
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
await table.mergeInsert('id', newData, {
whenMatchedUpdateAll: 'target.age = 1'
})
assert.equal(await table.countRows(), 3)
assert.equal(await table.countRows('age = 1'), 1)
assert.equal(await table.countRows('age = 3'), 1)
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)
assert.equal((await table.filter('age = 4').execute()).length, 2)
newData = [{ id: 5, age: 4 }]
newData = [{ id: 5, age: 5 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
whenNotMatchedBySourceDelete: 'age < 4'
})
assert.equal(await table.countRows(), 3)

View File

@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
self._table = table
self._on = on
self._when_matched_update_all = False
self._when_matched_update_all_condition = None
self._when_not_matched_insert_all = False
self._when_not_matched_by_source_delete = False
self._when_not_matched_by_source_condition = None
def when_matched_update_all(self) -> LanceMergeInsertBuilder:
def when_matched_update_all(
self, *, where: Optional[str] = None
) -> LanceMergeInsertBuilder:
"""
Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
but that behavior is subject to change.
"""
self._when_matched_update_all = True
self._when_matched_update_all_condition = where
return self
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:

View File

@@ -298,6 +298,10 @@ class RemoteTable(Table):
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params[
"when_matched_update_all_filt"
] = merge._when_matched_update_all_condition
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()

View File

@@ -1459,7 +1459,7 @@ class LanceTable(Table):
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
builder.when_matched_update_all()
builder.when_matched_update_all(merge._when_matched_update_all_condition)
if merge._when_not_matched_insert_all:
builder.when_not_matched_insert_all()
if merge._when_not_matched_by_source_delete:

View File

@@ -3,7 +3,7 @@ name = "lancedb"
version = "0.5.3"
dependencies = [
"deprecation",
"pylance==0.9.12",
"pylance==0.9.14",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",

View File

@@ -513,8 +513,15 @@ def test_merge_insert(db):
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
# These `sort_by` calls can be removed once lance#1892
# is merged (it fixes the ordering)
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# conditional update
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
new_data
)
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)

View File

@@ -191,28 +191,34 @@ impl JsTable {
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
builder.when_matched_update_all();
}
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
builder.when_not_matched_insert_all();
let filter = cx.argument_opt(2).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_matched_update_all(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_matched_update_all(Some(filter));
}
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
if let Some(filter) = cx.argument_opt(4) {
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
} else {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
let filter = cx.argument_opt(5).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
}
let buffer = cx.argument::<JsBuffer>(5)?;
let buffer = cx.argument::<JsBuffer>(6)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;

View File

@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -238,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all()
/// merge_insert.when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
@@ -677,11 +677,14 @@ impl MergeInsert for NativeTable {
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
if params.when_matched_update_all {
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
} else {
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
}
match (
params.when_matched_update_all,
params.when_matched_update_all_filt,
) {
(false, _) => builder.when_matched(WhenMatched::DoNothing),
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
};
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
@@ -824,6 +827,7 @@ impl Table for NativeTable {
#[cfg(test)]
mod tests {
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@@ -947,14 +951,14 @@ mod tests {
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = make_test_batches_with_offset(0);
let batches = merge_insert_test_batches(0, 0);
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
let new_batches = Box::new(merge_insert_test_batches(5, 1));
// Perform an "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
@@ -964,13 +968,27 @@ mod tests {
assert_eq!(table.count_rows(None).await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
let new_batches = Box::new(merge_insert_test_batches(15, 2));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all();
merge_insert_builder.when_matched_update_all(None);
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows(None).await.unwrap(), 15);
assert_eq!(
table.count_rows(Some("age = 2".to_string())).await.unwrap(),
0
);
// Conditional update that only replaces the age=0 data
let new_batches = Box::new(merge_insert_test_batches(5, 3));
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
5
);
}
#[tokio::test]
@@ -1319,23 +1337,35 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches_with_offset(
fn merge_insert_test_batches(
offset: i32,
age: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let schema = Arc::new(Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("age", DataType::Int32, false),
]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
],
)],
schema,
)
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
#[tokio::test]

View File

@@ -35,6 +35,7 @@ pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_matched_update_all_filt: Option<String>,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
@@ -46,6 +47,7 @@ impl MergeInsertBuilder {
table,
on,
when_matched_update_all: false,
when_matched_update_all_filt: None,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
@@ -59,8 +61,22 @@ impl MergeInsertBuilder {
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
///
/// An optional condition may be specified. If it is, then only
/// matched rows that satisfy the condition will be updated. Any
/// rows that do not satisfy the condition will be left as they
/// are. Failing to satisfy the condition does not cause a
/// "matched row" to become a "not matched" row.
///
/// The condition should be an SQL string. Use the prefix
/// `target.` to refer to fields in the target table (old data)
/// and the prefix `source.` to refer to fields in the source
/// table (new data).
///
/// For example, "target.last_update < source.last_update"
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
// Turn on the "when matched, update all" action for this merge insert.
self.when_matched_update_all = true;
// Store the optional SQL filter, replacing any filter set by an earlier call.
// `None` means matched rows are updated unconditionally.
self.when_matched_update_all_filt = condition;
// Return the builder so further configuration calls can be chained.
self
}