mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-12 23:02:59 +00:00
@@ -104,4 +104,4 @@ class LanceMergeInsertBuilder(object):
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
return self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
|
||||
@@ -2464,7 +2464,31 @@ class AsyncTable:
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
schema = await self.schema()
|
||||
if on_bad_vectors is None:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data, _ = _sanitize_data(
|
||||
new_data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
if isinstance(data, pa.Table):
|
||||
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
||||
await self._inner.execute_merge_insert(
|
||||
data,
|
||||
dict(
|
||||
on=merge._on,
|
||||
when_matched_update_all=merge._when_matched_update_all,
|
||||
when_matched_update_all_condition=merge._when_matched_update_all_condition,
|
||||
when_not_matched_insert_all=merge._when_not_matched_insert_all,
|
||||
when_not_matched_by_source_delete=merge._when_not_matched_by_source_delete,
|
||||
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
|
||||
),
|
||||
)
|
||||
|
||||
async def delete(self, where: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
@@ -636,11 +636,13 @@ def test_merge_insert(db):
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||
"a > 2"
|
||||
).execute(new_data)
|
||||
(
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -658,6 +660,75 @@ def test_merge_insert(db):
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_insert_async(db_async: AsyncConnection):
|
||||
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
||||
table = await db_async.create_table("some_table", data=data)
|
||||
assert await table.count_rows() == 3
|
||||
version = await table.version()
|
||||
|
||||
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
|
||||
# upsert
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# conditional update
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all(where="target.b = 'b'")
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# insert-if-not-exists
|
||||
await table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# replace-range
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# replace-range no condition
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete()
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
|
||||
Reference in New Issue
Block a user