feat: add support for filter during merge insert when matched (#948)

Closes #940
This commit is contained in:
Weston Pace
2024-02-09 10:26:14 -08:00
committed by GitHub
parent 48d55bf952
commit a9727eb318
12 changed files with 150 additions and 51 deletions

View File

@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
self._table = table
self._on = on
self._when_matched_update_all = False
self._when_matched_update_all_condition = None
self._when_not_matched_insert_all = False
self._when_not_matched_by_source_delete = False
self._when_not_matched_by_source_condition = None
def when_matched_update_all(self) -> LanceMergeInsertBuilder:
def when_matched_update_all(
self, *, where: Optional[str] = None
) -> LanceMergeInsertBuilder:
"""
Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
but that behavior is subject to change.
"""
self._when_matched_update_all = True
self._when_matched_update_all_condition = where
return self
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:

View File

@@ -298,6 +298,10 @@ class RemoteTable(Table):
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params[
"when_matched_update_all_filt"
] = merge._when_matched_update_all_condition
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()

View File

@@ -1467,7 +1467,7 @@ class LanceTable(Table):
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
builder.when_matched_update_all()
builder.when_matched_update_all(merge._when_matched_update_all_condition)
if merge._when_not_matched_insert_all:
builder.when_not_matched_insert_all()
if merge._when_not_matched_by_source_delete:

View File

@@ -3,7 +3,7 @@ name = "lancedb"
version = "0.5.3"
dependencies = [
"deprecation",
"pylance==0.9.12",
"pylance==0.9.14",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",

View File

@@ -513,8 +513,15 @@ def test_merge_insert(db):
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
# These `sort_by` calls can be removed once lance#1892
# is merged (it fixes the ordering)
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# conditional update
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
new_data
)
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)