Files
lancedb/python/python/tests/docs/test_merge_insert.py
Heng Ge 048f52c2aa feat(table): route merge_insert through the MemWAL LSM write path (#3354)
## Summary

When an `LsmWriteSpec` is installed on a table (#3396), `merge_insert` upsert calls are dispatched through Lance's MemWAL `ShardWriter` (LSM-style append) instead of the standard merge path.

- **`use_lsm_write`** — a `merge_insert` builder option, default `true`; set it to `false` to use the standard path for a call even when a spec is set.
- **`assume_pre_sharded`** — a `merge_insert` builder option, default `false`; skips the per-row shard check and routes by the first row only.
- **`close_lsm_writers`** — drains and closes the table's cached MemWAL shard writers.
- The `merge_insert` **`on`** columns default to, and are validated against, the table's unenforced primary key.
- Shard writers are cached alongside the dataset (in
  `DatasetConsistencyWrapper`) and reused for the session.
- `MergeResult` gains **`num_rows`** — on the LSM path the insert/update
  breakdown is unknown until compaction, so only the total is reported.

Routing covers all three sharding strategies — bucket (murmur3, Iceberg-compatible), identity, and unsharded. Each `merge_insert` call targets a single shard; the whole input is collected and validated before a single atomic `ShardWriter::put`, so a validation failure leaves the MemWAL untouched.

Bindings: Python (`merge_insert(...).use_lsm_write(...)` /
`.assume_pre_sharded(...)`, `Table.close_lsm_writers`) and TypeScript
(`mergeInsert(...).useLsmWrite(...)` / `.assumePreSharded(...)`,
`Table.closeLsmWriters`).

## Context

Reconstructed from the original #3354 branch onto current `main`: the branch predated the #3394 (unenforced primary key) / #3396 (`LsmWriteSpec`) split and has been rebuilt on that merged foundation. Depends on Lance `v7.0.0-beta.13`.

The MemWAL read path (reading unflushed shard data back into queries) and remote (LanceDB Cloud) LSM support are follow-ups.

---------

Co-authored-by: Jack Ye <yezhaoqin@gmail.com>
2026-05-29 08:48:11 -07:00

195 lines
5.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pytest
def test_upsert(mem_db):
    """Upsert rows keyed on ``id`` with the synchronous API.

    Row ``id=1`` matches and is updated, ``id=2`` has no match and is
    inserted, and ``id=0`` is absent from the source and left untouched.
    The lines between the ``--8<--`` markers form the ``upsert_basic``
    documentation snippet, so test-only checks live after the end marker.
    """
    db = mem_db
    # --8<-- [start:upsert_basic]
    table = db.create_table(
        "users",
        [
            {"id": 0, "name": "Alice"},
            {"id": 1, "name": "Bob"},
        ],
    )
    new_users = [
        {"id": 1, "name": "Bobby"},
        {"id": 2, "name": "Charlie"},
    ]
    res = (
        table.merge_insert("id")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute(new_users)
    )
    table.count_rows()  # 3
    res
    # MergeResult(version=2, num_updated_rows=1,
    # num_inserted_rows=1, num_deleted_rows=0, num_rows=2)
    # --8<-- [end:upsert_basic]
    # Three rows total: untouched id=0, updated id=1, newly inserted id=2.
    assert table.count_rows() == 3
    # Same expectation as the async variant of this test.
    assert res.version == 2
    assert res.num_inserted_rows == 1
    assert res.num_deleted_rows == 0
    assert res.num_updated_rows == 1
@pytest.mark.asyncio
async def test_upsert_async(mem_db_async):
    """Upsert rows keyed on ``id`` with the async API.

    Row ``id=1`` matches and is updated, ``id=2`` has no match and is
    inserted, and ``id=0`` is absent from the source and left untouched.
    The lines between the ``--8<--`` markers form the
    ``upsert_basic_async`` documentation snippet.
    """
    db = mem_db_async
    # --8<-- [start:upsert_basic_async]
    table = await db.create_table(
        "users",
        [
            {"id": 0, "name": "Alice"},
            {"id": 1, "name": "Bob"},
        ],
    )
    new_users = [
        {"id": 1, "name": "Bobby"},
        {"id": 2, "name": "Charlie"},
    ]
    res = await (
        table.merge_insert("id")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute(new_users)
    )
    await table.count_rows()  # 3
    res
    # MergeResult(version=2, num_updated_rows=1,
    # num_inserted_rows=1, num_deleted_rows=0, num_rows=2)
    # --8<-- [end:upsert_basic_async]
    # Three rows total: untouched id=0, updated id=1, newly inserted id=2.
    assert await table.count_rows() == 3
    # The merge is expected to produce table version 2.
    assert res.version == 2
    assert res.num_inserted_rows == 1
    assert res.num_deleted_rows == 0
    assert res.num_updated_rows == 1
def test_insert_if_not_exists(mem_db):
    """Insert only rows whose ``domain`` key is not already present (sync API).

    No ``when_matched_*`` clause is configured, so the existing
    ``google.com`` row is left unchanged and only ``facebook.com`` is
    inserted. The lines between the ``--8<--`` markers form the
    ``insert_if_not_exists`` documentation snippet.
    """
    db = mem_db
    # --8<-- [start:insert_if_not_exists]
    table = db.create_table(
        "domains",
        [
            {"domain": "google.com", "name": "Google"},
            {"domain": "github.com", "name": "GitHub"},
        ],
    )
    new_domains = [
        {"domain": "google.com", "name": "Google"},
        {"domain": "facebook.com", "name": "Facebook"},
    ]
    res = (
        table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
    )
    table.count_rows()  # 3
    res
    # MergeResult(version=2, num_updated_rows=0,
    # num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
    # --8<-- [end:insert_if_not_exists]
    # Two original rows plus the one new domain.
    assert table.count_rows() == 3
    assert res.version == 2
    assert res.num_inserted_rows == 1
    assert res.num_deleted_rows == 0
    # The matching google.com row is not updated: no when_matched clause.
    assert res.num_updated_rows == 0
@pytest.mark.asyncio
async def test_insert_if_not_exists_async(mem_db_async):
    """Insert only rows whose ``domain`` key is not already present (async API).

    No ``when_matched_*`` clause is configured, so the existing
    ``google.com`` row is left unchanged and only ``facebook.com`` is
    inserted. The lines between the ``--8<--`` markers form the
    ``insert_if_not_exists_async`` documentation snippet; the end marker
    must carry the same section name as the start marker.
    """
    db = mem_db_async
    # --8<-- [start:insert_if_not_exists_async]
    table = await db.create_table(
        "domains",
        [
            {"domain": "google.com", "name": "Google"},
            {"domain": "github.com", "name": "GitHub"},
        ],
    )
    new_domains = [
        {"domain": "google.com", "name": "Google"},
        {"domain": "facebook.com", "name": "Facebook"},
    ]
    res = await (
        table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
    )
    await table.count_rows()  # 3
    res
    # MergeResult(version=2, num_updated_rows=0,
    # num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
    # --8<-- [end:insert_if_not_exists_async]
    # Two original rows plus the one new domain.
    assert await table.count_rows() == 3
    assert res.version == 2
    assert res.num_inserted_rows == 1
    assert res.num_deleted_rows == 0
    # The matching google.com row is not updated: no when_matched clause.
    assert res.num_updated_rows == 0
def test_replace_range(mem_db):
    """Replace every chunk of one document with a new set (sync API).

    Rows are keyed on ``(doc_id, chunk_id)``. Chunk ``(1, 0)`` matches and
    is updated. Chunk ``(1, 1)`` is absent from the source and satisfies
    ``doc_id = 1``, so it is deleted. Chunks of ``doc_id = 0`` fall outside
    the delete filter and are kept. The lines between the ``--8<--`` markers
    form the ``replace_range`` documentation snippet.
    """
    db = mem_db
    # --8<-- [start:replace_range]
    table = db.create_table(
        "chunks",
        [
            {"doc_id": 0, "chunk_id": 0, "text": "Hello"},
            {"doc_id": 0, "chunk_id": 1, "text": "World"},
            {"doc_id": 1, "chunk_id": 0, "text": "Foo"},
            {"doc_id": 1, "chunk_id": 1, "text": "Bar"},
        ],
    )
    new_chunks = [
        {"doc_id": 1, "chunk_id": 0, "text": "Baz"},
    ]
    res = (
        table.merge_insert(["doc_id", "chunk_id"])
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .when_not_matched_by_source_delete("doc_id = 1")
        .execute(new_chunks)
    )
    table.count_rows("doc_id = 1")  # 1
    res
    # MergeResult(version=2, num_updated_rows=1,
    # num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
    # --8<-- [end:replace_range]
    # Only the updated chunk (1, 0) remains for document 1.
    assert table.count_rows("doc_id = 1") == 1
    assert res.version == 2
    assert res.num_inserted_rows == 0
    assert res.num_deleted_rows == 1
    assert res.num_updated_rows == 1
@pytest.mark.asyncio
async def test_replace_range_async(mem_db_async):
    """Replace every chunk of one document with a new set (async API).

    Rows are keyed on ``(doc_id, chunk_id)``. Chunk ``(1, 0)`` matches and
    is updated. Chunk ``(1, 1)`` is absent from the source and satisfies
    ``doc_id = 1``, so it is deleted. Chunks of ``doc_id = 0`` fall outside
    the delete filter and are kept. The lines between the ``--8<--`` markers
    form the ``replace_range_async`` documentation snippet.
    """
    db = mem_db_async
    # --8<-- [start:replace_range_async]
    table = await db.create_table(
        "chunks",
        [
            {"doc_id": 0, "chunk_id": 0, "text": "Hello"},
            {"doc_id": 0, "chunk_id": 1, "text": "World"},
            {"doc_id": 1, "chunk_id": 0, "text": "Foo"},
            {"doc_id": 1, "chunk_id": 1, "text": "Bar"},
        ],
    )
    new_chunks = [
        {"doc_id": 1, "chunk_id": 0, "text": "Baz"},
    ]
    res = await (
        table.merge_insert(["doc_id", "chunk_id"])
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .when_not_matched_by_source_delete("doc_id = 1")
        .execute(new_chunks)
    )
    await table.count_rows("doc_id = 1")  # 1
    res
    # MergeResult(version=2, num_updated_rows=1,
    # num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
    # --8<-- [end:replace_range_async]
    # Only the updated chunk (1, 0) remains for document 1.
    assert await table.count_rows("doc_id = 1") == 1
    assert res.version == 2
    assert res.num_inserted_rows == 0
    assert res.num_deleted_rows == 1
    assert res.num_updated_rows == 1