From 048f52c2aa042e5cbaebfce3400c345faa1875f2 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Fri, 29 May 2026 08:48:11 -0700 Subject: [PATCH] feat(table): route merge_insert through the MemWAL LSM write path (#3354) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary When an `LsmWriteSpec` is installed on a table (#3396), `merge_insert` upsert calls are dispatched through Lance's MemWAL `ShardWriter` (LSM-style append) instead of the standard merge path. - **`use_lsm_write`** — a `merge_insert` builder option, default `true`; set it `false` to use the standard path for a call even when a spec is set. - **`assume_pre_sharded`** — a `merge_insert` builder option, default `false`; skips the per-row shard check and routes by the first row only. - **`close_lsm_writers`** — drains and closes the table's cached MemWAL shard writers. - The `merge_insert` **`on`** columns default to, and are validated against, the table's unenforced primary key. - Shard writers are cached alongside the dataset (in `DatasetConsistencyWrapper`) and reused for the session. - `MergeResult` gains **`num_rows`** — on the LSM path the insert/update breakdown is unknown until compaction, so only the total is reported. Routing covers all three sharding strategies — bucket (murmur3, Iceberg-compatible), identity, and unsharded. Each `merge_insert` call targets a single shard; the whole input is collected and validated before a single atomic `ShardWriter::put`, so a validation failure leaves the MemWAL untouched. Bindings: Python (`merge_insert(...).use_lsm_write(...)` / `.assume_pre_sharded(...)`, `Table.close_lsm_writers`) and TypeScript (`mergeInsert(...).useLsmWrite(...)` / `.assumePreSharded(...)`, `Table.closeLsmWriters`). ## Context Reconstructed from the original #3354 branch onto current `main`: the branch predated the #3394 (unenforced primary key) / #3396 (`LsmWriteSpec`) split and has been rebuilt on that merged foundation. Depends on Lance `v7.0.0-beta.13`. The MemWAL read path (reading un-flushed shard data back into queries) and remote (LanceDB Cloud) LSM support are follow-ups. --------- Co-authored-by: Jack Ye --- Cargo.lock | 7 + docs/src/js/classes/MergeInsertBuilder.md | 51 + docs/src/js/classes/Table.md | 19 + docs/src/js/interfaces/LsmWriteSpec.md | 5 +- docs/src/js/interfaces/MergeResult.md | 8 + nodejs/__test__/table.test.ts | 94 ++ nodejs/lancedb/merge.ts | 35 + nodejs/lancedb/table.ts | 19 +- nodejs/src/merge.rs | 14 + nodejs/src/table.rs | 7 + python/python/lancedb/_lancedb.pyi | 2 + python/python/lancedb/merge.py | 42 + python/python/lancedb/remote/table.py | 4 + python/python/lancedb/table.py | 21 +- python/python/tests/docs/test_merge_insert.py | 10 +- python/python/tests/test_merge_insert_lsm.py | 196 ++++ python/src/table.rs | 28 +- rust/lancedb/Cargo.toml | 2 +- rust/lancedb/src/remote/client.rs | 14 + rust/lancedb/src/remote/table.rs | 1 + rust/lancedb/src/table.rs | 29 + rust/lancedb/src/table/dataset.rs | 11 + rust/lancedb/src/table/merge.rs | 423 ++++++++ rust/lancedb/src/table/merge/lsm.rs | 998 +++++++++++++++++- 24 files changed, 2020 insertions(+), 20 deletions(-) create mode 100644 python/python/tests/test_merge_insert_lsm.py diff --git a/Cargo.lock b/Cargo.lock index 1d1441d00..1777d9591 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8468,6 +8468,12 @@ dependencies = [ "digest 0.11.3", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -9691,6 +9697,7 @@ dependencies = [ "getrandom 0.4.2", "js-sys", "serde_core", + "sha1_smol", "wasm-bindgen", ] diff --git a/docs/src/js/classes/MergeInsertBuilder.md b/docs/src/js/classes/MergeInsertBuilder.md index ae601c9e2..ac0493bad 100644 --- a/docs/src/js/classes/MergeInsertBuilder.md +++ b/docs/src/js/classes/MergeInsertBuilder.md @@ -76,6 +76,57 @@ the query optimizer chooses a suboptimal path. *** +### useLsmWrite() + +```ts +useLsmWrite(useLsmWrite): MergeInsertBuilder +``` + +Controls whether the merge uses the MemWAL LSM write path. + +By default (unset), a `mergeInsert` on a table with an LSM write spec is +routed through Lance's MemWAL shard writer, and a table without one uses +the standard path. Pass `false` to force the standard path even when a +spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none +is installed. + +#### Parameters + +* **useLsmWrite**: `boolean` + Whether to use the LSM write path. + +#### Returns + +[`MergeInsertBuilder`](MergeInsertBuilder.md) + +*** + +### validateSingleShard() + +```ts +validateSingleShard(validateSingleShard): MergeInsertBuilder +``` + +Controls how an LSM merge checks that its input targets a single shard. + +When a table has an LSM write spec, every row in a `mergeInsert` call must +route to the same shard. When `true` (the default), every row is inspected +to verify this. When `false`, only the first row is inspected and the +shard it routes to is used for the whole input — a faster path for callers +that have already pre-sharded their input. Has no effect on tables without +an LSM write spec. + +#### Parameters + +* **validateSingleShard**: `boolean` + Whether to check every row routes to one shard. Defaults to `true`. + +#### Returns + +[`MergeInsertBuilder`](MergeInsertBuilder.md) + +*** + ### whenMatchedUpdateAll() ```ts diff --git a/docs/src/js/classes/Table.md b/docs/src/js/classes/Table.md index 45fa13362..62b962daf 100644 --- a/docs/src/js/classes/Table.md +++ b/docs/src/js/classes/Table.md @@ -187,6 +187,25 @@ Any attempt to use the table after it is closed will result in an error. *** +### closeLsmWriters() + +```ts +abstract closeLsmWriters(): Promise +``` + +Drain and close any cached MemWAL shard writers held for this table. + +When an [LsmWriteSpec](../interfaces/LsmWriteSpec.md) is installed, `mergeInsert` opens MemWAL +shard writers and caches them for reuse across calls. This closes them, +flushing pending data; writers reopen lazily on the next `mergeInsert`. +It is a no-op when no writers are cached. + +#### Returns + +`Promise`<`void`> + +*** + ### countRows() ```ts diff --git a/docs/src/js/interfaces/LsmWriteSpec.md b/docs/src/js/interfaces/LsmWriteSpec.md index 017e819dc..8a588df6a 100644 --- a/docs/src/js/interfaces/LsmWriteSpec.md +++ b/docs/src/js/interfaces/LsmWriteSpec.md @@ -11,7 +11,10 @@ Specification selecting Lance's MemWAL LSM-style write path for `specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`, `column` and `numBuckets` are required; for `"identity"`, `column` is -required. +required and must be a deterministic function of the unenforced primary +key (every row with a given primary key must always produce the same +`column` value, or upserts of that key can land in different shards and a +stale version can win). ## Properties diff --git a/docs/src/js/interfaces/MergeResult.md b/docs/src/js/interfaces/MergeResult.md index d59049cb8..6114fabfa 100644 --- a/docs/src/js/interfaces/MergeResult.md +++ b/docs/src/js/interfaces/MergeResult.md @@ -32,6 +32,14 @@ numInsertedRows: number; *** +### numRows + +```ts +numRows: number; +``` + +*** + ### numUpdatedRows ```ts diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 7d43ca351..3be56d3c7 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -2625,3 +2625,97 @@ describe("setLsmWriteSpec / unsetLsmWriteSpec", () => { ).rejects.toThrow(); }); }); + +describe("LSM merge insert", () => { + let tmpDir: tmp.DirResult; + + beforeEach(() => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + }); + afterEach(() => tmpDir.removeCallback()); + + async function bucketTable(conn: Connection): Promise { + // The primary key column must be non-nullable. + const table = await conn.createEmptyTable( + "t", + new arrow.Schema([ + new arrow.Field("id", new arrow.Utf8(), false), + new arrow.Field("value", new arrow.Float64(), true), + ]), + ); + await table.add([ + { id: "a", value: 1 }, + { id: "b", value: 2 }, + ]); + await table.setUnenforcedPrimaryKey("id"); + // numBuckets = 1: every row routes to the single bucket. + await table.setLsmWriteSpec({ + specType: "bucket", + column: "id", + numBuckets: 1, + }); + return table; + } + + it("routes merge_insert through the shard writer", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .execute([ + { id: "c", value: 3 }, + { id: "d", value: 4 }, + ]); + // LSM path: rows go to the MemWAL, so only numRows is populated. + expect(res.numRows).toBe(2); + expect(res.version).toBe(0); + expect(res.numInsertedRows).toBe(0); + + await table.closeLsmWriters(); + }); + + it("falls back to the standard path with useLsmWrite(false)", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenNotMatchedInsertAll() + .useLsmWrite(false) + .execute([ + { id: "b", value: 9 }, + { id: "e", value: 5 }, + ]); + // Standard path commits: id="e" inserted ("b" already exists). + expect(res.numInsertedRows).toBe(1); + expect(await table.countRows()).toBe(3); + }); + + it("supports validateSingleShard(false)", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .validateSingleShard(false) + .execute([{ id: "f", value: 6 }]); + expect(res.numRows).toBe(1); + }); + + it("rejects a non-upsert merge under an LSM spec", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + await expect( + table + .mergeInsert("id") + .whenNotMatchedInsertAll() + .execute([{ id: "g", value: 7 }]), + ).rejects.toThrow(); + }); +}); diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts index dc9144fdf..08321427f 100644 --- a/nodejs/lancedb/merge.ts +++ b/nodejs/lancedb/merge.ts @@ -87,6 +87,41 @@ export class MergeInsertBuilder { this.#schema, ); } + /** + * Controls whether the merge uses the MemWAL LSM write path. + * + * By default (unset), a `mergeInsert` on a table with an LSM write spec is + * routed through Lance's MemWAL shard writer, and a table without one uses + * the standard path. Pass `false` to force the standard path even when a + * spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none + * is installed. + * + * @param useLsmWrite - Whether to use the LSM write path. + */ + useLsmWrite(useLsmWrite: boolean): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.useLsmWrite(useLsmWrite), + this.#schema, + ); + } + /** + * Controls how an LSM merge checks that its input targets a single shard. + * + * When a table has an LSM write spec, every row in a `mergeInsert` call must + * route to the same shard. When `true` (the default), every row is inspected + * to verify this. When `false`, only the first row is inspected and the + * shard it routes to is used for the whole input — a faster path for callers + * that have already pre-sharded their input. Has no effect on tables without + * an LSM write spec. + * + * @param validateSingleShard - Whether to check every row routes to one shard. Defaults to `true`. + */ + validateSingleShard(validateSingleShard: boolean): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.validateSingleShard(validateSingleShard), + this.#schema, + ); + } /** * Executes the merge insert operation * diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index fe495392a..ae2e86995 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -161,7 +161,10 @@ export interface Version { * * `specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`, * `column` and `numBuckets` are required; for `"identity"`, `column` is - * required. + * required and must be a deterministic function of the unenforced primary + * key (every row with a given primary key must always produce the same + * `column` value, or upserts of that key can land in different shards and a + * stale version can win). */ export interface LsmWriteSpec { /** One of `"bucket"`, `"identity"`, or `"unsharded"`. */ @@ -567,6 +570,16 @@ export abstract class Table { * @returns {Promise} */ abstract unsetLsmWriteSpec(): Promise; + /** + * Drain and close any cached MemWAL shard writers held for this table. + * + * When an {@link LsmWriteSpec} is installed, `mergeInsert` opens MemWAL + * shard writers and caches them for reuse across calls. This closes them, + * flushing pending data; writers reopen lazily on the next `mergeInsert`. + * It is a no-op when no writers are cached. + * @returns {Promise} + */ + abstract closeLsmWriters(): Promise; /** Retrieve the version of the table */ abstract version(): Promise; @@ -1041,6 +1054,10 @@ export class LocalTable extends Table { return await this.inner.unsetLsmWriteSpec(); } + async closeLsmWriters(): Promise { + return await this.inner.closeLsmWriters(); + } + async version(): Promise { return await this.inner.version(); } diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs index 98d637fb3..5ba9846bc 100644 --- a/nodejs/src/merge.rs +++ b/nodejs/src/merge.rs @@ -50,6 +50,20 @@ impl NativeMergeInsertBuilder { this } + #[napi] + pub fn use_lsm_write(&self, use_lsm_write: bool) -> Self { + let mut this = self.clone(); + this.inner.use_lsm_write(use_lsm_write); + this + } + + #[napi] + pub fn validate_single_shard(&self, validate_single_shard: bool) -> Self { + let mut this = self.clone(); + this.inner.validate_single_shard(validate_single_shard); + this + } + #[napi(catch_unwind)] pub async fn execute(&self, buf: Buffer) -> napi::Result { let data = ipc_file_to_batches(buf.to_vec()) diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 4c5424bc9..16cde35d8 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -391,6 +391,11 @@ impl Table { .default_error() } + #[napi(catch_unwind)] + pub async fn close_lsm_writers(&self) -> napi::Result<()> { + self.inner_ref()?.close_lsm_writers().await.default_error() + } + #[napi(catch_unwind)] pub async fn version(&self) -> napi::Result { self.inner_ref()? @@ -940,6 +945,7 @@ pub struct MergeResult { pub num_updated_rows: i64, pub num_deleted_rows: i64, pub num_attempts: i64, + pub num_rows: i64, } impl From for MergeResult { @@ -950,6 +956,7 @@ impl From for MergeResult { num_updated_rows: value.num_updated_rows as i64, num_deleted_rows: value.num_deleted_rows as i64, num_attempts: value.num_attempts as i64, + num_rows: value.num_rows as i64, } } } diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index db28e0fc8..0148f6575 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -220,6 +220,7 @@ class Table: async def set_unenforced_primary_key(self, columns: List[str]) -> None: ... async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ... async def unset_lsm_write_spec(self) -> None: ... + async def close_lsm_writers(self) -> None: ... @property def tags(self) -> Tags: ... def query(self) -> Query: ... @@ -420,6 +421,7 @@ class MergeResult: num_inserted_rows: int num_deleted_rows: int num_attempts: int + num_rows: int class LsmWriteSpec: """Specification selecting Lance's MemWAL LSM-style write path for diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index b2564740c..6085f5a06 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -34,6 +34,8 @@ class LanceMergeInsertBuilder(object): self._when_not_matched_by_source_condition = None self._timeout = None self._use_index = True + self._use_lsm_write = None + self._validate_single_shard = None def when_matched_update_all( self, *, where: Optional[str] = None @@ -96,6 +98,46 @@ class LanceMergeInsertBuilder(object): self._use_index = use_index return self + def use_lsm_write(self, use_lsm_write: bool) -> LanceMergeInsertBuilder: + """ + Controls whether the merge uses the MemWAL LSM write path. + + By default (unset), a `merge_insert` on a table with an LSM write spec + is routed through Lance's MemWAL shard writer, and a table without one + uses the standard path. Pass `False` to force the standard path even + when a spec is set. Pass `True` to require a spec — `merge_insert` + raises an error if none is installed. + + Parameters + ---------- + use_lsm_write: bool + Whether to use the LSM write path. + """ + self._use_lsm_write = use_lsm_write + return self + + def validate_single_shard( + self, validate_single_shard: bool + ) -> LanceMergeInsertBuilder: + """ + Controls how an LSM merge checks that its input targets a single shard. + + When a table has an LSM write spec, every row in a `merge_insert` call + must route to the same shard. When `True` (the default), every row is + inspected to verify this. When `False`, only the first row is inspected + and the shard it routes to is used for the whole input — a faster path + for callers that have already pre-sharded their input. + + Has no effect on tables without an LSM write spec. + + Parameters + ---------- + validate_single_shard: bool + Whether to check every row routes to one shard. Defaults to `True`. + """ + self._validate_single_shard = validate_single_shard + return self + def execute( self, new_data: DATA, diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 3d4155269..73bdbb8b1 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -792,6 +792,10 @@ class RemoteTable(Table): """Not supported on LanceDB Cloud.""" return LOOP.run(self._table.unset_lsm_write_spec()) + def close_lsm_writers(self) -> None: + """No-op on LanceDB Cloud (no local shard writers).""" + return LOOP.run(self._table.close_lsm_writers()) + def drop_index(self, index_name: str): return LOOP.run(self._table.drop_index(index_name)) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 407709d17..2de369419 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1251,7 +1251,7 @@ class Table(ABC): ... .when_not_matched_insert_all() \\ ... .execute(new_data) >>> res - MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1) + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -3601,6 +3601,11 @@ class LanceTable(Table): [`AsyncTable.unset_lsm_write_spec`][lancedb.AsyncTable.unset_lsm_write_spec].""" return LOOP.run(self._table.unset_lsm_write_spec()) + def close_lsm_writers(self) -> None: + """Close cached MemWAL shard writers. See + [`AsyncTable.close_lsm_writers`][lancedb.AsyncTable.close_lsm_writers].""" + return LOOP.run(self._table.close_lsm_writers()) + def uses_v2_manifest_paths(self) -> bool: """ Check if the table is using the new v2 manifest paths. @@ -4209,6 +4214,16 @@ class AsyncTable: """ await self._inner.unset_lsm_write_spec() + async def close_lsm_writers(self) -> None: + """Drain and close any cached MemWAL shard writers for this table. + + When an LSM write spec is installed, `merge_insert` opens MemWAL shard + writers and caches them for reuse across calls. This closes them, + flushing pending data; writers reopen lazily on the next + `merge_insert`. It is a no-op when no writers are cached. + """ + await self._inner.close_lsm_writers() + @property def name(self) -> str: """The name of the table.""" @@ -4659,7 +4674,7 @@ class AsyncTable: ... .when_not_matched_insert_all() \\ ... .execute(new_data) >>> res - MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1) + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -5039,6 +5054,8 @@ class AsyncTable: when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition, timeout=merge._timeout, use_index=merge._use_index, + use_lsm_write=merge._use_lsm_write, + validate_single_shard=merge._validate_single_shard, ), ) diff --git a/python/python/tests/docs/test_merge_insert.py b/python/python/tests/docs/test_merge_insert.py index 228faa31b..adf812219 100644 --- a/python/python/tests/docs/test_merge_insert.py +++ b/python/python/tests/docs/test_merge_insert.py @@ -57,7 +57,7 @@ async def test_upsert_async(mem_db_async): await table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=2) # --8<-- [end:upsert_basic_async] assert await table.count_rows() == 3 assert res.version == 2 @@ -86,7 +86,7 @@ def test_insert_if_not_exists(mem_db): table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=0, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=1) # --8<-- [end:insert_if_not_exists] assert table.count_rows() == 3 assert res.version == 2 @@ -116,7 +116,7 @@ async def test_insert_if_not_exists_async(mem_db_async): await table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=0, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=1) # --8<-- [end:insert_if_not_exists] assert await table.count_rows() == 3 assert res.version == 2 @@ -150,7 +150,7 @@ def test_replace_range(mem_db): table.count_rows("doc_id = 1") # 1 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=0, num_deleted_rows=1) + # num_inserted_rows=0, num_deleted_rows=1, num_rows=1) # --8<-- [end:insert_if_not_exists] assert table.count_rows("doc_id = 1") == 1 assert res.version == 2 @@ -185,7 +185,7 @@ async def test_replace_range_async(mem_db_async): await table.count_rows("doc_id = 1") # 1 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=0, num_deleted_rows=1) + # num_inserted_rows=0, num_deleted_rows=1, num_rows=1) # --8<-- [end:insert_if_not_exists] assert await table.count_rows("doc_id = 1") == 1 assert res.version == 2 diff --git a/python/python/tests/test_merge_insert_lsm.py b/python/python/tests/test_merge_insert_lsm.py new file mode 100644 index 000000000..abdfb306d --- /dev/null +++ b/python/python/tests/test_merge_insert_lsm.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +"""Tests for the MemWAL LSM ``merge_insert`` dispatch.""" + +from datetime import timedelta + +import lancedb +import pyarrow as pa +import pytest +from lancedb._lancedb import LsmWriteSpec + +SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.int64(), nullable=False), + ] +) + +REGION_SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("region", pa.utf8(), nullable=False), + ] +) + + +def _reader(ids): + batch = pa.RecordBatch.from_arrays( + [ + pa.array(ids, type=pa.int64()), + pa.array(list(range(len(ids))), type=pa.int64()), + ], + schema=SCHEMA, + ) + return pa.RecordBatchReader.from_batches(SCHEMA, [batch]) + + +def _region_reader(rows): + batch = pa.RecordBatch.from_arrays( + [ + pa.array([row[0] for row in rows], type=pa.int64()), + pa.array([row[1] for row in rows], type=pa.utf8()), + ], + schema=REGION_SCHEMA, + ) + return pa.RecordBatchReader.from_batches(REGION_SCHEMA, [batch]) + + +def _bucket_table(tmp_path): + """A table with ``id`` as the primary key and a single-bucket LSM spec.""" + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + # num_buckets = 1: every row routes to the single bucket. + table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1)) + return table + + +def test_lsm_merge_insert_bucket(tmp_path): + table = _bucket_table(tmp_path) + # Empty `on` defaults to the primary key. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([3, 4, 5])) + ) + # LSM path: rows go to the MemWAL, so only num_rows is populated. + assert result.num_rows == 3 + assert result.version == 0 + assert result.num_inserted_rows == 0 + assert result.num_updated_rows == 0 + + +def test_lsm_merge_insert_unsharded(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + table.set_lsm_write_spec(LsmWriteSpec.unsharded()) + result = ( + table.merge_insert("id") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([10, 11, 12, 13])) + ) + assert result.num_rows == 4 + + +def test_lsm_merge_insert_identity(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _region_reader([(1, "us"), (2, "us")])) + table.set_unenforced_primary_key("id") + table.set_lsm_write_spec(LsmWriteSpec.identity("region")) + # All rows share one identity value, so they route to one shard. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_region_reader([(3, "us"), (4, "us")])) + ) + assert result.num_rows == 2 + + +def test_lsm_merge_insert_use_lsm_write_false(tmp_path): + table = _bucket_table(tmp_path) # rows id = 1, 2, 3 + # use_lsm_write(False) opts out: the standard path runs and commits. + result = ( + table.merge_insert("id") + .when_not_matched_insert_all() + .use_lsm_write(False) + .execute(_reader([3, 4, 5])) + ) + assert result.num_inserted_rows == 2 + assert table.count_rows() == 5 + + +def test_lsm_merge_insert_validate_single_shard_off(tmp_path): + table = _bucket_table(tmp_path) + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .validate_single_shard(False) + .execute(_reader([6, 7, 8])) + ) + assert result.num_rows == 3 + + +def test_lsm_merge_insert_use_lsm_write_true_requires_spec(tmp_path): + # A table with a primary key but no LSM write spec installed. + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + with pytest.raises(Exception, match="use_lsm_write"): + ( + table.merge_insert("id") + .when_matched_update_all() + .when_not_matched_insert_all() + .use_lsm_write(True) + .execute(_reader([4])) + ) + + +def test_lsm_merge_insert_rejects_on_not_primary_key(tmp_path): + table = _bucket_table(tmp_path) + with pytest.raises(Exception, match="primary key"): + ( + table.merge_insert("value") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([1])) + ) + + +def test_lsm_merge_insert_rejects_non_upsert(tmp_path): + table = _bucket_table(tmp_path) + # Insert-only (no when_matched_update_all) is not the upsert shape. + with pytest.raises(Exception, match="upsert"): + table.merge_insert([]).when_not_matched_insert_all().execute(_reader([4])) + + +def test_lsm_close_writers(tmp_path): + table = _bucket_table(tmp_path) + ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([7, 8])) + ) + table.close_lsm_writers() + # The writer reopens lazily on the next merge_insert. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([9])) + ) + assert result.num_rows == 1 + + +@pytest.mark.asyncio +async def test_async_lsm_merge_insert(tmp_path): + db = await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) + table = await db.create_table("t", _reader([1, 2, 3])) + await table.set_unenforced_primary_key("id") + await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1)) + + builder = ( + table.merge_insert([]).when_matched_update_all().when_not_matched_insert_all() + ) + result = await builder.execute(_reader([3, 4, 5])) + assert result.num_rows == 3 + await table.close_lsm_writers() diff --git a/python/src/table.rs b/python/src/table.rs index 546bec555..302c2bb46 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -143,18 +143,20 @@ pub struct MergeResult { pub num_inserted_rows: u64, pub num_deleted_rows: u64, pub num_attempts: u32, + pub num_rows: u64, } #[pymethods] impl MergeResult { pub fn __repr__(&self) -> String { format!( - "MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={})", + "MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={}, num_rows={})", self.version, self.num_updated_rows, self.num_inserted_rows, self.num_deleted_rows, - self.num_attempts + self.num_attempts, + self.num_rows ) } } @@ -167,6 +169,7 @@ impl From for MergeResult { num_inserted_rows: result.num_inserted_rows, num_deleted_rows: result.num_deleted_rows, num_attempts: result.num_attempts, + num_rows: result.num_rows, } } } @@ -194,6 +197,12 @@ impl LsmWriteSpec { } /// Identity sharding — shard by the raw value of `column`. + /// + /// `column` must be a deterministic function of the unenforced primary + /// key: every row with a given primary key must always produce the same + /// `column` value, or upserts of that key can land in different shards + /// and a stale version can win. Typically `column` is the primary key + /// itself or a stable attribute of it. #[staticmethod] pub fn identity(column: String) -> Self { Self { @@ -933,6 +942,12 @@ impl Table { if let Some(use_index) = parameters.use_index { builder.use_index(use_index); } + if let Some(use_lsm_write) = parameters.use_lsm_write { + builder.use_lsm_write(use_lsm_write); + } + if let Some(validate_single_shard) = parameters.validate_single_shard { + builder.validate_single_shard(validate_single_shard); + } future_into_py(self_.py(), async move { let res = builder.execute(Box::new(batches)).await.infer_error()?; @@ -971,6 +986,13 @@ impl Table { }) } + pub fn close_lsm_writers(self_: PyRef<'_, Self>) -> PyResult> { + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + inner.close_lsm_writers().await.infer_error() + }) + } + pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { @@ -1124,6 +1146,8 @@ pub struct MergeInsertParams { when_not_matched_by_source_condition: Option, timeout: Option, use_index: Option, + use_lsm_write: Option, + validate_single_shard: Option, } #[pyclass] diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index f83044324..b42d8d235 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -75,7 +75,7 @@ reqwest = { version = "0.12.0", default-features = false, features = [ "stream", ], optional = true } http = { version = "1", optional = true } # Matching what is in reqwest -uuid = { version = "1.7.0", features = ["v4"] } +uuid = { version = "1.7.0", features = ["v4", "v5"] } polars-arrow = { version = ">=0.37,<0.40.0", optional = true } polars = { version = ">=0.37,<0.40.0", optional = true } hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [ diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 378698ca6..6a44f7f1c 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -908,6 +908,15 @@ mod tests { use serial_test::serial; use std::time::Duration; + // Serializes the env-var-mutating tests below: cargo test runs tests in + // parallel, but several of these tests read and write the same process- + // global env vars (`LANCEDB_USER_ID*`), so they would race without this. + static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + fn lock_env() -> std::sync::MutexGuard<'static, ()> { + ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()) + } + #[test] fn test_timeout_config_default() { let config = TimeoutConfig::default(); @@ -1166,6 +1175,7 @@ mod tests { #[test] #[serial(user_id_env)] fn test_resolve_user_id_none() { + let _guard = lock_env(); let config = ClientConfig::default(); // Clear env vars that might be set from other tests // SAFETY: This is only called in tests @@ -1179,6 +1189,7 @@ mod tests { #[test] #[serial(user_id_env)] fn test_resolve_user_id_from_env() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", "env-user-id"); @@ -1194,6 +1205,7 @@ mod tests { #[test] #[serial(user_id_env)] fn test_resolve_user_id_from_env_key() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::remove_var("LANCEDB_USER_ID"); @@ -1215,6 +1227,7 @@ mod tests { #[test] #[serial(user_id_env)] fn test_resolve_user_id_direct_takes_precedence() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", "env-user-id"); @@ -1233,6 +1246,7 @@ mod tests { #[test] #[serial(user_id_env)] fn test_resolve_user_id_empty_env_ignored() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", ""); diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 4f034ba43..dc16b61c6 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1805,6 +1805,7 @@ impl BaseTable for RemoteTable { num_inserted_rows: 0, num_updated_rows: 0, num_attempts: 0, + num_rows: 0, }); } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 96f0cffcf..ca34bbdf3 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -366,6 +366,14 @@ impl LsmWriteSpec { /// Construct an identity-sharding spec (shard by the raw value of /// `column`) with no maintained indexes. + /// + /// `column` must be a deterministic function of the unenforced primary + /// key: every row with a given primary key must always produce the same + /// `column` value. MemWAL dedups upserts by primary key but tracks + /// generations per shard, so if the same key is written with two + /// different `column` values its versions land in different shards and a + /// stale value can win. Typically `column` is the primary key itself, or + /// a stable attribute of it (e.g. a tenant id). pub fn identity(column: impl Into) -> Self { Self::Identity { column: column.into(), @@ -580,6 +588,13 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { message: "unset_lsm_write_spec is not supported on this table type".into(), }) } + /// Drain and close any cached MemWAL shard writers for this table. + /// + /// The default implementation is a no-op; table types that maintain + /// MemWAL shard writers override it. + async fn close_lsm_writers(&self) -> Result<()> { + Ok(()) + } /// Gets the table tag manager. async fn tags(&self) -> Result>; /// Optimize the dataset. @@ -1386,6 +1401,16 @@ impl Table { self.inner.unset_lsm_write_spec().await } + /// Drain and close any cached MemWAL shard writers held for this table. + /// + /// When an [`LsmWriteSpec`] is installed, `merge_insert` opens MemWAL shard + /// writers and caches them for reuse across calls. This closes them, + /// flushing pending data; writers reopen lazily on the next `merge_insert`. + /// It is a no-op when no writers are cached. + pub async fn close_lsm_writers(&self) -> Result<()> { + self.inner.close_lsm_writers().await + } + /// Retrieve the version of the table /// /// LanceDb supports versioning. Every operation that modifies the table increases @@ -2829,6 +2854,10 @@ impl BaseTable for NativeTable { merge::lsm::unset_lsm_write_spec(self).await } + async fn close_lsm_writers(&self) -> Result<()> { + merge::lsm::close_lsm_writers(self).await + } + /// Delete rows from the table async fn delete(&self, predicate: Predicate<'_>) -> Result { delete::execute_delete(self, predicate).await diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs index 584d45a2f..b4673d876 100644 --- a/rust/lancedb/src/table/dataset.rs +++ b/rust/lancedb/src/table/dataset.rs @@ -8,6 +8,7 @@ use std::{ use lance::{Dataset, dataset::refs}; +use crate::table::merge::lsm::ShardWriterCache; use crate::{Error, error::Result, utils::background_cache::BackgroundCache}; /// A wrapper around a [Dataset] that provides consistency checks. @@ -18,6 +19,10 @@ use crate::{Error, error::Result, utils::background_cache::BackgroundCache}; pub struct DatasetConsistencyWrapper { state: Arc>, consistency: ConsistencyMode, + /// The single MemWAL `ShardWriter` for this dataset, co-located so it is + /// cached for the session and shares the dataset's lifecycle. A dataset + /// writes to one shard at a time. Shared by `Arc` across clones. + shard_writer: Arc, } /// The current dataset and whether it is pinned to a specific version. @@ -67,9 +72,15 @@ impl DatasetConsistencyWrapper { pinned_version: None, })), consistency, + shard_writer: Arc::new(ShardWriterCache::default()), } } + /// The MemWAL `ShardWriter` cache co-located with this dataset. + pub(crate) fn shard_writer(&self) -> &Arc { + &self.shard_writer + } + /// Get the current dataset. /// /// Behavior depends on the consistency mode: diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index def78aa4f..b3bda36af 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -41,6 +41,16 @@ pub struct MergeResult { /// A value of 1 means the operation succeeded on the first try. #[serde(default)] pub num_attempts: u32, + /// Total number of rows written. + /// + /// On the standard `merge_insert` path this equals + /// `num_inserted_rows + num_updated_rows`. On the MemWAL LSM write path the + /// insert/update breakdown is not known until compaction; in that mode + /// `num_inserted_rows`, `num_updated_rows`, `num_deleted_rows`, `version` + /// and `num_attempts` are all `0` and this field holds the total number of + /// rows written through the shard writer. + #[serde(default)] + pub num_rows: u64, } /// A builder used to create and run a merge insert operation @@ -57,6 +67,8 @@ pub struct MergeInsertBuilder { pub(crate) when_not_matched_by_source_delete_filt: Option, pub(crate) timeout: Option, pub(crate) use_index: bool, + pub(crate) use_lsm_write: Option, + pub(crate) validate_single_shard: bool, } impl MergeInsertBuilder { @@ -71,6 +83,8 @@ impl MergeInsertBuilder { when_not_matched_by_source_delete_filt: None, timeout: None, use_index: true, + use_lsm_write: None, + validate_single_shard: true, } } @@ -150,6 +164,34 @@ impl MergeInsertBuilder { self } + /// Controls whether `merge_insert` uses the MemWAL LSM write path. + /// + /// By default (unset), a `merge_insert` on a table with an + /// [`LsmWriteSpec`](super::LsmWriteSpec) installed is routed through + /// Lance's MemWAL shard writer, and a table without one uses the standard + /// path. Calling this with `false` forces the standard path even when a + /// spec is set. Calling it with `true` requires a spec — `merge_insert` + /// errors if none is installed. + pub fn use_lsm_write(&mut self, use_lsm_write: bool) -> &mut Self { + self.use_lsm_write = Some(use_lsm_write); + self + } + + /// Controls how an LSM `merge_insert` checks that its input targets a + /// single shard. + /// + /// When a table has an LSM write spec, every row in a `merge_insert` call + /// must route to the same shard. When `true` (the default), every row is + /// inspected to verify this. When `false`, only the first row is inspected + /// and the shard it routes to is used for the whole input — a faster path + /// for callers that have already pre-sharded their input. + /// + /// Has no effect on tables without an LSM write spec. + pub fn validate_single_shard(&mut self, validate_single_shard: bool) -> &mut Self { + self.validate_single_shard = validate_single_shard; + self + } + /// Executes the merge insert operation /// /// Returns version and statistics about the merge operation including the number of rows @@ -167,6 +209,23 @@ pub(crate) async fn execute_merge_insert( params: MergeInsertBuilder, new_data: Box, ) -> Result { + match lsm::lsm_dispatch_decision(table, ¶ms).await? { + lsm::LsmDispatch::Lsm(plan) => { + let future = + lsm::execute_lsm_merge_insert(table, plan, params.validate_single_shard, new_data); + return match params.timeout { + Some(timeout) => match tokio::time::timeout(timeout, future).await { + Ok(result) => result, + Err(_) => Err(Error::Runtime { + message: "merge insert timed out".to_string(), + }), + }, + None => future.await, + }; + } + lsm::LsmDispatch::Standard => {} + } + let dataset = table.dataset.get().await?; let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( @@ -219,6 +278,7 @@ pub(crate) async fn execute_merge_insert( num_inserted_rows: stats.num_inserted_rows, num_deleted_rows: stats.num_deleted_rows, num_attempts: stats.num_attempts, + num_rows: stats.num_inserted_rows + stats.num_updated_rows, }) } @@ -327,3 +387,366 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 25); } } + +#[cfg(test)] +mod lsm_tests { + use std::sync::Arc; + + use arrow_array::{ + Int64Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + }; + use arrow_schema::{DataType, Field, Schema}; + use tempfile::{TempDir, tempdir}; + + use crate::connect; + use crate::error::Error; + use crate::table::{LsmWriteSpec, Table}; + + /// A reader of `[id: Int64, value: Int64]` rows; `value` is `0..n`. + fn id_value_reader(ids: Vec) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Int64, false), + ])); + let n = ids.len() as i64; + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(Int64Array::from_iter_values(0..n)), + ], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) + } + + /// A reader of `[id: Int64, region: Utf8]` rows. + fn id_region_reader(rows: Vec<(i64, &str)>) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("region", DataType::Utf8, false), + ])); + let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); + let regions: Vec<&str> = rows.iter().map(|(_, region)| *region).collect(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(regions)), + ], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) + } + + /// A multi-batch reader of `[id: Int64, region: Utf8]` rows. + fn id_region_multi_reader(batches: Vec>) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("region", DataType::Utf8, false), + ])); + let records: Vec<_> = batches + .into_iter() + .map(|rows| { + let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); + let regions: Vec<&str> = rows.iter().map(|(_, region)| *region).collect(); + Ok(RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(regions)), + ], + ) + .unwrap()) + }) + .collect(); + Box::new(RecordBatchIterator::new(records, schema)) + } + + /// Create an `[id, value]` table with `id` as the unenforced primary key. + async fn id_value_table(dir: &TempDir) -> Table { + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_value_reader(vec![1, 2, 3])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + } + + #[tokio::test] + async fn lsm_merge_insert_bucket() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + // num_buckets = 1: every row routes to the single bucket. + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // Empty `on` defaults to the primary key. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_value_reader(vec![3, 4, 5])) + .await + .unwrap(); + + // LSM path: rows go to the MemWAL, the breakdown is unknown until + // compaction, so only `num_rows` is populated. + assert_eq!(result.num_rows, 3); + assert_eq!(result.version, 0); + assert_eq!(result.num_inserted_rows, 0); + assert_eq!(result.num_updated_rows, 0); + } + + #[tokio::test] + async fn lsm_merge_insert_unsharded() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::unsharded()) + .await + .unwrap(); + + let mut builder = table.merge_insert(&["id"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_value_reader(vec![10, 11, 12, 13])) + .await + .unwrap(); + assert_eq!(result.num_rows, 4); + } + + #[tokio::test] + async fn lsm_merge_insert_identity() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us"), (2, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // All rows share one identity value, so they route to one shard. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_region_reader(vec![(3, "us"), (4, "us")])) + .await + .unwrap(); + assert_eq!(result.num_rows, 2); + } + + #[tokio::test] + async fn lsm_merge_insert_use_lsm_write_false_falls_back() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // use_lsm_write(false) opts out: the standard path runs and commits. + let mut builder = table.merge_insert(&["id"]); + builder.when_not_matched_insert_all().use_lsm_write(false); + let result = builder + .execute(id_value_reader(vec![3, 4, 5])) + .await + .unwrap(); + + assert_eq!(result.num_inserted_rows, 2); + assert_eq!(table.count_rows(None).await.unwrap(), 5); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_on_not_primary_key() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + let mut builder = table.merge_insert(&["value"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder.execute(id_value_reader(vec![1])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_non_upsert() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // Insert-only (no when_matched_update_all) is not the upsert shape. + let mut builder = table.merge_insert(&[]); + builder.when_not_matched_insert_all(); + let err = builder.execute(id_value_reader(vec![4])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_close_writers_then_reopen() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder.execute(id_value_reader(vec![7, 8])).await.unwrap(); + + table.close_lsm_writers().await.unwrap(); + + // The writer reopens lazily on the next merge_insert. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder.execute(id_value_reader(vec![9])).await.unwrap(); + assert_eq!(result.num_rows, 1); + } + + #[tokio::test] + async fn lsm_merge_insert_multi_batch() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // Multiple batches that all route to one shard are written together. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_region_multi_reader(vec![ + vec![(2, "us"), (3, "us")], + vec![(4, "us")], + ])) + .await + .unwrap(); + assert_eq!(result.num_rows, 3); + + // Batches that route to different shards are rejected; the validation + // runs before any write, so no partial write is left behind. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder + .execute(id_region_multi_reader(vec![ + vec![(5, "us")], + vec![(6, "eu")], + ])) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_use_lsm_write_true_requires_spec() { + let dir = tempdir().unwrap(); + // id_value_table sets a primary key but no LSM write spec. + let table = id_value_table(&dir).await; + + let mut builder = table.merge_insert(&["id"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all() + .use_lsm_write(true); + let err = builder.execute(id_value_reader(vec![4])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_second_shard() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // The first merge_insert opens the single writer for shard "us". + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder + .execute(id_region_reader(vec![(2, "us")])) + .await + .unwrap(); + + // A merge_insert routing to a different shard is rejected. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder + .execute(id_region_reader(vec![(3, "eu")])) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + + // After closing the writer, a different shard can be written. + table.close_lsm_writers().await.unwrap(); + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder + .execute(id_region_reader(vec![(4, "eu")])) + .await + .unwrap(); + } +} diff --git a/rust/lancedb/src/table/merge/lsm.rs b/rust/lancedb/src/table/merge/lsm.rs index 51d04f5e0..80246d59f 100644 --- a/rust/lancedb/src/table/merge/lsm.rs +++ b/rust/lancedb/src/table/merge/lsm.rs @@ -1,26 +1,71 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -//! MemWAL LSM write-path spec management. +//! MemWAL LSM write path for `merge_insert`. //! -//! [`set_lsm_write_spec`] installs a [`super::super::LsmWriteSpec`] on a -//! table, which selects Lance's MemWAL LSM-style write path for future -//! `merge_insert` calls. [`unset_lsm_write_spec`] removes it. The actual -//! `merge_insert` dispatch and writer are a follow-up. +//! [`set_lsm_write_spec`] installs an [`LsmWriteSpec`] on a table by creating +//! Lance's MemWAL index; [`unset_lsm_write_spec`] removes it. Once a spec is +//! installed, `merge_insert` upsert calls are dispatched through Lance's MemWAL +//! `ShardWriter` (LSM-style append) instead of the standard merge path — see +//! [`lsm_dispatch_decision`] and [`execute_lsm_merge_insert`]. +//! +//! Each `merge_insert` call must target a single shard: every row must route +//! to the same shard under the installed sharding spec (bucket / identity / +//! unsharded). [`MergeInsertBuilder::validate_single_shard`] controls whether +//! every row is checked or only the first. A dataset writes to one shard at a +//! time; its writer is cached in the [`ShardWriterCache`] held alongside the +//! dataset, and [`close_lsm_writers`] closes it. -use lance::dataset::mem_wal::DatasetMemWalExt; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, RecordBatchReader}; +use arrow_schema::{DataType, Schema as ArrowSchema, SchemaRef}; +use lance::Dataset; +use lance::dataset::mem_wal::{ + DatasetMemWalExt, ShardWriter, ShardWriterConfig, evaluate_sharding_spec, +}; use lance::index::DatasetIndexExt; +use lance_core::datatypes::Schema as LanceSchema; +use lance_index::mem_wal::{MemWalIndexDetails, ShardingSpec}; +use tokio::sync::RwLock; +use uuid::Uuid; use crate::error::{Error, Result}; +use crate::table::merge::{MergeInsertBuilder, MergeResult}; use crate::table::{LsmWriteSpec, NativeTable}; +/// Spec id of the sole sharding spec installed by [`set_lsm_write_spec`]. +/// Must match Lance's `InitializeMemWalBuilder` (`SHARDING_SPEC_ID`). +const SHARDING_SPEC_ID: u32 = 1; + +/// Transform name recorded by `bucket_sharding`. +const BUCKET_TRANSFORM: &str = "bucket"; +/// Transform name recorded by `identity_sharding`. +const IDENTITY_TRANSFORM: &str = "identity"; +/// Transform name recorded by `unsharded`. +const UNSHARDED_TRANSFORM: &str = "unsharded"; + +/// Parameter key holding the bucket count on the bucket transform. +const NUM_BUCKETS_PARAM: &str = "num_buckets"; + +/// Fixed namespace UUID for deriving deterministic shard ids. Hardcoded so +/// derivations stay stable across processes. +const SHARD_NAMESPACE: Uuid = Uuid::from_u128(0x4c53_4d57_5249_5445_5f53_4841_5244_3031); + // ============================================================================= // set_lsm_write_spec // ============================================================================= /// Install an [`LsmWriteSpec`] on the table. /// -/// The bucket / unsharded sharding spec is constructed and validated by Lance's +/// The bucket / identity / unsharded sharding spec is constructed and validated +/// by Lance's /// [`InitializeMemWalBuilder`](lance::dataset::mem_wal::InitializeMemWalBuilder). #[allow(clippy::redundant_pub_crate)] pub(crate) async fn set_lsm_write_spec(table: &NativeTable, spec: LsmWriteSpec) -> Result<()> { @@ -78,7 +123,8 @@ pub(crate) async fn set_lsm_write_spec(table: &NativeTable, spec: LsmWriteSpec) /// Remove the [`LsmWriteSpec`] from the table by dropping the MemWAL index. /// -/// Errors if no spec is currently set. +/// Any cached shard writers are drained and closed first. Errors if no spec is +/// currently set. #[allow(clippy::redundant_pub_crate)] pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { table.dataset.ensure_mutable()?; @@ -92,6 +138,8 @@ pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { } } + table.dataset.shard_writer().drain_and_close().await?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset .drop_index(lance_index::mem_wal::MEM_WAL_INDEX_NAME) @@ -99,3 +147,937 @@ pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { table.dataset.update(dataset); Ok(()) } + +// ============================================================================= +// close_lsm_writers +// ============================================================================= + +/// Drain and close every cached MemWAL shard writer for the table. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn close_lsm_writers(table: &NativeTable) -> Result<()> { + table.dataset.shard_writer().drain_and_close().await +} + +// ============================================================================= +// ShardWriter cache +// ============================================================================= + +/// Per-table cache holding the single open MemWAL `ShardWriter`. +/// +/// Held by [`DatasetConsistencyWrapper`](crate::table::dataset::DatasetConsistencyWrapper) +/// so the writer lives where the dataset lives — cached for the session and +/// reused across `merge_insert` calls. A dataset writes to one shard at a +/// time; routing a `merge_insert` to a different shard requires closing the +/// current writer first via [`close_lsm_writers`]. `ShardWriter::put` takes +/// `&self`, so concurrent puts on the cached writer are safe; `close` consumes +/// the writer, so the entry wraps it in `RwLock>`. +#[derive(Default)] +#[allow(clippy::redundant_pub_crate)] +pub(crate) struct ShardWriterCache { + /// `Some((shard_id, entry))` once a writer has been opened for the session. + slot: RwLock)>>, +} + +impl std::fmt::Debug for ShardWriterCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ShardWriterCache").finish_non_exhaustive() + } +} + +struct ShardWriterEntry { + inner: RwLock>, +} + +impl ShardWriterEntry { + fn new(writer: ShardWriter) -> Self { + Self { + inner: RwLock::new(Some(writer)), + } + } + + async fn put(&self, batches: Vec) -> Result<()> { + let guard = self.inner.read().await; + let writer = guard.as_ref().ok_or_else(|| Error::Runtime { + message: "merge_insert: shard writer was closed before this write".to_string(), + })?; + writer.put(batches).await.map_err(|e| Error::Runtime { + message: format!("merge_insert: shard writer put failed: {}", e), + })?; + Ok(()) + } + + async fn close(&self) -> Result<()> { + let writer = { self.inner.write().await.take() }; + if let Some(writer) = writer { + writer.close().await.map_err(|e| Error::Runtime { + message: format!("merge_insert: shard writer close failed: {}", e), + })?; + } + Ok(()) + } +} + +impl ShardWriterCache { + /// Return the cached writer, opening one for `shard_id` with `config` if + /// the slot is empty. Errors if a writer is already open for a *different* + /// shard — the caller must close it first. + async fn writer_for_shard( + &self, + dataset: &Dataset, + shard_id: Uuid, + config: ShardWriterConfig, + ) -> Result> { + { + let guard = self.slot.read().await; + if let Some((cached, entry)) = guard.as_ref() { + check_shard_match(*cached, shard_id)?; + return Ok(entry.clone()); + } + } + let mut guard = self.slot.write().await; + // Re-check: another caller may have opened the writer meanwhile. + if let Some((cached, entry)) = guard.as_ref() { + check_shard_match(*cached, shard_id)?; + return Ok(entry.clone()); + } + let writer = dataset + .mem_wal_writer(shard_id, config) + .await + .map_err(|e| Error::Runtime { + message: format!( + "merge_insert: failed to open MemWAL shard writer for shard {}: {}", + shard_id, e + ), + })?; + let entry = Arc::new(ShardWriterEntry::new(writer)); + *guard = Some((shard_id, entry.clone())); + Ok(entry) + } + + /// Close the cached writer, if any, and clear the slot. + #[allow(clippy::redundant_pub_crate)] + pub(crate) async fn drain_and_close(&self) -> Result<()> { + let cached = { self.slot.write().await.take() }; + if let Some((_, entry)) = cached { + entry.close().await?; + } + Ok(()) + } +} + +/// Error if a cached writer is open for a shard other than the one needed. +fn check_shard_match(cached: Uuid, wanted: Uuid) -> Result<()> { + if cached == wanted { + return Ok(()); + } + Err(Error::InvalidInput { + message: format!( + "merge_insert: a shard writer is already open for shard {} but this input routes to shard {}; call close_lsm_writers before writing to a different shard", + cached, wanted + ), + }) +} + +// ============================================================================= +// merge_insert LSM dispatch +// ============================================================================= + +/// How the installed sharding spec routes rows to shards. +#[derive(Debug, Clone)] +enum LsmMode { + /// Hash-bucket the routing column into `num_buckets` shards. + Bucket { spec: ShardingSpec }, + /// Shard by the raw value of the routing column. + Identity { spec: ShardingSpec }, + /// Route every row to a single shard. + Unsharded, +} + +/// Resolved plan for routing a `merge_insert` through the MemWAL write path. +#[derive(Debug)] +#[allow(clippy::redundant_pub_crate)] +pub(crate) struct LsmPlan { + mode: LsmMode, + writer_config_defaults: HashMap, +} + +/// Outcome of [`lsm_dispatch_decision`]. +#[allow(clippy::redundant_pub_crate)] +pub(crate) enum LsmDispatch { + /// No LSM write spec applies; use the standard `merge_insert` path. + Standard, + /// Route the `merge_insert` through the MemWAL shard writer. + Lsm(LsmPlan), +} + +/// Decide whether a `merge_insert` should be routed through the MemWAL write +/// path, validating the builder against the installed spec. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn lsm_dispatch_decision( + table: &NativeTable, + params: &MergeInsertBuilder, +) -> Result { + // `Some(false)` is an explicit opt-out: use the standard path. + if params.use_lsm_write == Some(false) { + return Ok(LsmDispatch::Standard); + } + + let dataset = table.dataset.get().await?; + let Some(details) = dataset.mem_wal_index_details().await? else { + // No LSM write spec installed. `Some(true)` explicitly asked for the + // LSM path, which is meaningless without a spec; `None` (the default) + // just falls back to the standard path. + if params.use_lsm_write == Some(true) { + return Err(Error::InvalidInput { + message: "merge_insert: use_lsm_write(true) requires an LSM write spec on the table; call set_lsm_write_spec first".to_string(), + }); + } + return Ok(LsmDispatch::Standard); + }; + + let pk_cols: Vec = dataset + .schema() + .unenforced_primary_key() + .iter() + .map(|f| f.name.clone()) + .collect(); + if pk_cols.is_empty() { + return Err(Error::Runtime { + message: "merge_insert: table has a MemWAL index but no unenforced primary key" + .to_string(), + }); + } + if !params.on.is_empty() && params.on != pk_cols { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: `on` columns {:?} must match the table's unenforced primary key {:?} when an LSM write spec is set; pass an empty `on` to default to the primary key", + params.on, pk_cols + ), + }); + } + + if !is_upsert_only(params) { + return Err(Error::InvalidInput { + message: "merge_insert: when an LSM write spec is set, only the upsert form (when_matched_update_all without a filter + when_not_matched_insert_all, no by-source delete) is supported; call use_lsm_write(false) to use the standard merge_insert path".to_string(), + }); + } + + let mode = resolve_lsm_mode(&details)?; + Ok(LsmDispatch::Lsm(LsmPlan { + mode, + writer_config_defaults: details.writer_config_defaults, + })) +} + +/// Returns true if the builder requests the upsert-only shape the LSM write +/// path can honor. +fn is_upsert_only(params: &MergeInsertBuilder) -> bool { + params.when_matched_update_all + && params.when_matched_update_all_filt.is_none() + && params.when_not_matched_insert_all + && !params.when_not_matched_by_source_delete + && params.when_not_matched_by_source_delete_filt.is_none() +} + +/// Read the sharding mode from the MemWAL index details. +fn resolve_lsm_mode(details: &MemWalIndexDetails) -> Result { + let spec = details + .sharding_specs + .first() + .cloned() + .ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL index has no sharding spec".to_string(), + })?; + let field = spec.fields.first().ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL index has an empty sharding spec".to_string(), + })?; + match field.transform.as_deref() { + Some(BUCKET_TRANSFORM) => { + field + .parameters + .get(NUM_BUCKETS_PARAM) + .and_then(|s| s.parse::().ok()) + .filter(|n| *n > 0) + .ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL bucket spec has a missing or invalid num_buckets parameter".to_string(), + })?; + Ok(LsmMode::Bucket { spec }) + } + Some(IDENTITY_TRANSFORM) => Ok(LsmMode::Identity { spec }), + Some(UNSHARDED_TRANSFORM) => Ok(LsmMode::Unsharded), + other => Err(Error::Runtime { + message: format!( + "merge_insert: MemWAL index has an unsupported sharding transform {:?}", + other + ), + }), + } +} + +// ============================================================================= +// LSM merge_insert execution +// ============================================================================= + +/// Execute a `merge_insert` through the MemWAL shard writer cache. +/// +/// The entire input is collected, schema-aligned, and shard-validated before +/// anything is written, then issued as a single atomic `ShardWriter::put` — so +/// a validation failure (e.g. input spanning shards) never leaves a partial +/// write behind. When `validate_single_shard` is set, every row is checked to +/// route to one shard; when disabled, only the first row of the whole input is. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn execute_lsm_merge_insert( + table: &NativeTable, + plan: LsmPlan, + validate_single_shard: bool, + new_data: Box, +) -> Result { + let dataset = table.dataset.get().await?; + let target_schema: SchemaRef = Arc::new(ArrowSchema::from(dataset.schema())); + + // Collect, align and shard-validate the whole input before writing + // anything. `ShardWriter::put` is atomic over the batch vector, so any + // failure raised here leaves the MemWAL untouched. + let mut batches: Vec = Vec::new(); + let mut total_rows: u64 = 0; + + for batch in new_data { + let batch = batch.map_err(|e| Error::Arrow { source: e })?; + if batch.num_rows() == 0 { + continue; + } + let batch = align_batch_schema(batch, &target_schema)?; + total_rows += batch.num_rows() as u64; + batches.push(batch); + } + + // Empty input (or only empty batches): nothing to write. + let Some(shard_id) = resolve_input_shard( + &plan.mode, + dataset.schema(), + &batches, + validate_single_shard, + )? + else { + return Ok(lsm_merge_result(0)); + }; + + let config = shard_writer_config_from_defaults(&plan.writer_config_defaults); + let writer = table + .dataset + .shard_writer() + .writer_for_shard(dataset.as_ref(), shard_id, config) + .await?; + writer.put(batches).await?; + + Ok(lsm_merge_result(total_rows)) +} + +/// Resolve the target shard for a collected input. +fn resolve_input_shard( + mode: &LsmMode, + schema: &LanceSchema, + batches: &[RecordBatch], + validate_single_shard: bool, +) -> Result> { + let mut shard_id: Option = None; + for batch in batches { + if batch.num_rows() == 0 { + continue; + } + if !validate_single_shard && shard_id.is_some() { + continue; + } + let batch_shard = resolve_batch_shard(mode, schema, batch, validate_single_shard)?; + match shard_id { + Some(seen) if seen != batch_shard => { + return Err(Error::InvalidInput { + message: "merge_insert: input batches route to multiple shards; each merge_insert call must target a single shard".to_string(), + }); + } + _ => shard_id = Some(batch_shard), + } + } + Ok(shard_id) +} + +/// Compute the target shard id for a non-empty batch. When +/// `validate_single_shard` is set, every row is checked to route to the same +/// shard; otherwise only the first row is inspected. +fn resolve_batch_shard( + mode: &LsmMode, + schema: &LanceSchema, + batch: &RecordBatch, + validate_single_shard: bool, +) -> Result { + let routing_batch = if validate_single_shard { + batch.clone() + } else { + batch.slice(0, 1) + }; + match mode { + LsmMode::Unsharded => Ok(unsharded_shard_id()), + LsmMode::Bucket { spec } => { + let values = evaluate_lsm_shard_values(&routing_batch, spec, schema)?; + let buckets = values + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::Runtime { + message: format!( + "merge_insert: MemWAL bucket evaluator returned {:?}; expected Int32", + values.data_type() + ), + })?; + let first = buckets.value(0); + if validate_single_shard { + for row in 1..routing_batch.num_rows() { + let bucket = buckets.value(row); + if bucket != first { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: input row 0 hashes to bucket {} but row {} hashes to bucket {}; each merge_insert call must target a single bucket (pre-shard the input, or set validate_single_shard(false) to route by the first row only)", + first, row, bucket + ), + }); + } + } + } + Ok(bucket_shard_id(u32::try_from(first).map_err(|_| { + Error::Runtime { + message: format!( + "merge_insert: MemWAL bucket evaluator returned negative bucket {}", + first + ), + } + })?)) + } + LsmMode::Identity { spec } => { + let values = evaluate_lsm_shard_values(&routing_batch, spec, schema)?; + let first = encode_scalar(values.as_ref(), 0)?; + if validate_single_shard { + for row in 1..routing_batch.num_rows() { + if encode_scalar(values.as_ref(), row)? != first { + return Err(Error::InvalidInput { + message: "merge_insert: input rows have differing values for identity-sharding column; each merge_insert call must target a single shard (pre-shard the input, or set validate_single_shard(false) to route by the first row only)".to_string(), + }); + } + } + } + Ok(identity_shard_id(&first)) + } + } +} + +fn evaluate_lsm_shard_values( + batch: &RecordBatch, + spec: &ShardingSpec, + schema: &LanceSchema, +) -> Result { + let values = evaluate_sharding_spec(batch, spec, schema)?; + if values.num_columns() != 1 { + return Err(Error::Runtime { + message: format!( + "merge_insert: MemWAL sharding spec evaluated to {} fields; expected exactly one", + values.num_columns() + ), + }); + } + Ok(values.column(0).clone()) +} + +/// Encode one cell of an identity-sharding column to comparable bytes. +fn encode_scalar(array: &dyn Array, row: usize) -> Result> { + if array.is_null(row) { + return Err(Error::InvalidInput { + message: "merge_insert: identity sharding does not support null routing values" + .to_string(), + }); + } + Ok(match array.data_type() { + DataType::Int8 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int16 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int32 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int64 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt8 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt16 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt32 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt64 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Utf8 => array.as_string::().value(row).as_bytes().to_vec(), + DataType::LargeUtf8 => array.as_string::().value(row).as_bytes().to_vec(), + DataType::Boolean => vec![u8::from(array.as_boolean().value(row))], + other => { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: identity sharding does not support column dtype {:?}", + other + ), + }); + } + }) +} + +/// Deterministic shard id for a bucket index. +fn bucket_shard_id(bucket: u32) -> Uuid { + Uuid::new_v5(&SHARD_NAMESPACE, format!("bucket-{}", bucket).as_bytes()) +} + +/// Deterministic shard id for an identity value. +fn identity_shard_id(value: &[u8]) -> Uuid { + let mut name = b"identity-".to_vec(); + name.extend_from_slice(value); + Uuid::new_v5(&SHARD_NAMESPACE, &name) +} + +/// Deterministic shard id for the single unsharded shard. +fn unsharded_shard_id() -> Uuid { + Uuid::new_v5(&SHARD_NAMESPACE, b"unsharded") +} + +/// Build a [`ShardWriterConfig`] from the persisted `writer_config_defaults`. +/// +/// Unknown or unparseable keys are ignored; absent keys keep the +/// [`ShardWriterConfig`] default. The shard id is set by `mem_wal_writer`. +fn shard_writer_config_from_defaults(defaults: &HashMap) -> ShardWriterConfig { + let mut config = ShardWriterConfig::default().with_shard_spec_id(SHARDING_SPEC_ID); + let bool_of = |key: &str| defaults.get(key).and_then(|s| s.parse::().ok()); + let usize_of = |key: &str| defaults.get(key).and_then(|s| s.parse::().ok()); + let millis_of = |key: &str| { + defaults + .get(key) + .and_then(|s| s.parse::().ok()) + .map(Duration::from_millis) + }; + + if let Some(v) = bool_of("durable_write") { + config = config.with_durable_write(v); + } + if let Some(v) = bool_of("sync_indexed_write") { + config = config.with_sync_indexed_write(v); + } + if let Some(v) = usize_of("max_wal_buffer_size") { + config = config.with_max_wal_buffer_size(v); + } + if let Some(v) = usize_of("max_memtable_size") { + config = config.with_max_memtable_size(v); + } + if let Some(v) = usize_of("max_memtable_rows") { + config = config.with_max_memtable_rows(v); + } + if let Some(v) = usize_of("max_memtable_batches") { + config = config.with_max_memtable_batches(v); + } + if let Some(v) = usize_of("manifest_scan_batch_size") { + config = config.with_manifest_scan_batch_size(v); + } + if let Some(v) = usize_of("max_unflushed_memtable_bytes") { + config = config.with_max_unflushed_memtable_bytes(v); + } + if let Some(v) = millis_of("backpressure_log_interval_ms") { + config = config.with_backpressure_log_interval(v); + } + if let Some(v) = usize_of("async_index_buffer_rows") { + config = config.with_async_index_buffer_rows(v); + } + if let Some(v) = millis_of("async_index_interval_ms") { + config = config.with_async_index_interval(v); + } + if let Some(v) = bool_of("enable_memtable") { + config = config.with_enable_memtable(v); + } + if let Some(v) = millis_of("max_wal_flush_interval_ms") { + config = config.with_max_wal_flush_interval(v); + } + if let Some(v) = millis_of("stats_log_interval_ms") { + config = config.with_stats_log_interval(Some(v)); + } + config +} + +/// Re-attach the dataset's Arrow schema (including field metadata) to a +/// user-supplied input batch. The MemWAL `ShardWriter` checks batch schemas +/// against the dataset schema by exact equality, so input readers built +/// without the primary-key metadata must be rewrapped before being put. +/// +/// Columns are matched by name; column order in the input is irrelevant. +fn align_batch_schema(batch: RecordBatch, target: &SchemaRef) -> Result { + if batch.schema() == *target { + return Ok(batch); + } + let mut columns = Vec::with_capacity(target.fields().len()); + for field in target.fields() { + let column = batch + .column_by_name(field.name()) + .ok_or_else(|| Error::InvalidInput { + message: format!( + "merge_insert: input is missing column '{}' required by the table schema", + field.name() + ), + })?; + if column.data_type() != field.data_type() { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: input column '{}' has dtype {:?}, expected {:?}", + field.name(), + column.data_type(), + field.data_type() + ), + }); + } + columns.push(column.clone()); + } + RecordBatch::try_new(target.clone(), columns).map_err(|e| Error::Arrow { source: e }) +} + +/// Build the [`MergeResult`] for an LSM-path `merge_insert`. +/// +/// The insert/update breakdown is not known until LSM compaction, so only the +/// total row count is reported. +fn lsm_merge_result(num_rows: u64) -> MergeResult { + MergeResult { + version: 0, + num_inserted_rows: 0, + num_updated_rows: 0, + num_deleted_rows: 0, + num_attempts: 0, + num_rows, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, BooleanArray, Int32Array, Int64Array, StringArray, UInt64Array}; + use arrow_schema::Field; + use lance_index::mem_wal::ShardingField; + + fn lance_schema(batch: &RecordBatch) -> LanceSchema { + LanceSchema::try_from(batch.schema().as_ref()).unwrap() + } + + fn single_field_spec(field: ShardingField) -> ShardingSpec { + ShardingSpec { + spec_id: SHARDING_SPEC_ID, + fields: vec![field], + } + } + + fn bucket_mode(source_id: i32, num_buckets: u32) -> LsmMode { + LsmMode::Bucket { + spec: single_field_spec(ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![source_id], + transform: Some(BUCKET_TRANSFORM.to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([( + NUM_BUCKETS_PARAM.to_string(), + num_buckets.to_string(), + )]), + }), + } + } + + fn identity_mode(source_id: i32) -> LsmMode { + LsmMode::Identity { + spec: single_field_spec(ShardingField { + field_id: "identity".to_string(), + source_ids: vec![source_id], + transform: Some(IDENTITY_TRANSFORM.to_string()), + expression: None, + result_type: "utf8".to_string(), + parameters: HashMap::new(), + }), + } + } + + fn bucket_values(batch: &RecordBatch, num_buckets: u32) -> Vec { + let LsmMode::Bucket { spec } = bucket_mode(0, num_buckets) else { + unreachable!(); + }; + let values = evaluate_lsm_shard_values(batch, &spec, &lance_schema(batch)).unwrap(); + values.as_primitive::().values().to_vec() + } + + #[test] + fn bucket_assignments_are_pinned() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + )]) + .unwrap(); + assert_eq!(bucket_values(&batch, 8), vec![1, 5, 0]); + } + + #[test] + fn bucket_int32_uses_lance_evaluator() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(3)])) as ArrayRef, + )]) + .unwrap(); + assert_eq!(bucket_values(&batch, 8), vec![2, 7, 0, 1]); + } + + #[test] + fn bucket_accepts_lance_supported_scalar_types() { + let bool_batch = RecordBatch::try_from_iter([( + "id", + Arc::new(BooleanArray::from(vec![true])) as ArrayRef, + )]) + .unwrap(); + assert!( + resolve_batch_shard( + &bucket_mode(0, 8), + &lance_schema(&bool_batch), + &bool_batch, + true + ) + .is_ok() + ); + + let u64_batch = RecordBatch::try_from_iter([( + "id", + Arc::new(UInt64Array::from(vec![1_u64])) as ArrayRef, + )]) + .unwrap(); + assert!( + resolve_batch_shard( + &bucket_mode(0, 8), + &lance_schema(&u64_batch), + &u64_batch, + true + ) + .is_ok() + ); + } + + #[test] + fn shard_ids_are_deterministic_and_distinct() { + assert_eq!(bucket_shard_id(3), bucket_shard_id(3)); + assert_ne!(bucket_shard_id(3), bucket_shard_id(4)); + assert_ne!(bucket_shard_id(0), unsharded_shard_id()); + assert_eq!( + identity_shard_id(b"tenant-a"), + identity_shard_id(b"tenant-a") + ); + assert_ne!( + identity_shard_id(b"tenant-a"), + identity_shard_id(b"tenant-b") + ); + } + + #[test] + fn encode_scalar_distinguishes_values() { + let ints = Int64Array::from(vec![1, 2]); + assert_ne!( + encode_scalar(&ints, 0).unwrap(), + encode_scalar(&ints, 1).unwrap() + ); + let strs = StringArray::from(vec!["x", "y"]); + assert_ne!( + encode_scalar(&strs, 0).unwrap(), + encode_scalar(&strs, 1).unwrap() + ); + } + + #[test] + fn writer_config_from_defaults_parses_known_keys() { + let defaults = HashMap::from([ + ("durable_write".to_string(), "false".to_string()), + ("max_memtable_rows".to_string(), "4096".to_string()), + ("async_index_interval_ms".to_string(), "250".to_string()), + ("unknown_key".to_string(), "ignored".to_string()), + ]); + let config = shard_writer_config_from_defaults(&defaults); + assert!(!config.durable_write); + assert_eq!(config.max_memtable_rows, 4096); + assert_eq!(config.async_index_interval, Duration::from_millis(250)); + assert_eq!(config.shard_spec_id, SHARDING_SPEC_ID); + } + + #[test] + fn align_batch_schema_reorders_columns() { + let target: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let source = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![ + Field::new("v", DataType::Int64, false), + Field::new("id", DataType::Int64, false), + ])), + vec![ + Arc::new(Int64Array::from(vec![10, 20])), + Arc::new(Int64Array::from(vec![1, 2])), + ], + ) + .unwrap(); + let aligned = align_batch_schema(source, &target).unwrap(); + assert_eq!(aligned.schema(), target); + assert_eq!( + aligned.column(0).as_primitive::().values(), + &[1, 2] + ); + } + + #[test] + fn align_batch_schema_rejects_missing_column() { + let target: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let source = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int64, + false, + )])), + vec![Arc::new(Int64Array::from(vec![1, 2]))], + ) + .unwrap(); + assert!(matches!( + align_batch_schema(source, &target), + Err(Error::InvalidInput { .. }) + )); + } + + fn utf8_batch(col: &str, values: Vec<&str>) -> RecordBatch { + RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + col, + DataType::Utf8, + true, + )])), + vec![Arc::new(StringArray::from(values))], + ) + .unwrap() + } + + #[test] + fn resolve_batch_shard_bucket_same_bucket() { + let mode = bucket_mode(0, 8); + let batch = utf8_batch("id", vec!["a", "a"]); + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true).unwrap(), + bucket_shard_id(1) + ); + } + + #[test] + fn resolve_batch_shard_bucket_rejects_mixed() { + let mode = bucket_mode(0, 8); + let batch = utf8_batch("id", vec!["a", "b"]); + // validate_single_shard rejects a batch that spans buckets. + assert!(matches!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true), + Err(Error::InvalidInput { .. }) + )); + // With validation off, only row 0 is inspected, so it is accepted. + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, false).unwrap(), + bucket_shard_id(1) + ); + } + + #[test] + fn resolve_batch_shard_bucket_routes_nulls_to_zero() { + let mode = bucket_mode(0, 8); + let batch = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int64, + true, + )])), + vec![Arc::new(Int64Array::from(vec![None, None]))], + ) + .unwrap(); + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true).unwrap(), + bucket_shard_id(0) + ); + } + + #[test] + fn resolve_batch_shard_rejects_missing_routing_column() { + let mode = bucket_mode(0, 8); + let schema = LanceSchema::try_from(&ArrowSchema::new(vec![Field::new( + "id", + DataType::Utf8, + true, + )])) + .unwrap(); + let batch = utf8_batch("other", vec!["a"]); + assert!(resolve_batch_shard(&mode, &schema, &batch, true).is_err()); + } + + #[test] + fn resolve_batch_shard_identity_groups_by_value() { + let mode = identity_mode(0); + let same = utf8_batch("region", vec!["us", "us"]); + let mixed = utf8_batch("region", vec!["us", "eu"]); + assert!(resolve_batch_shard(&mode, &lance_schema(&same), &same, true).is_ok()); + assert!(matches!( + resolve_batch_shard(&mode, &lance_schema(&mixed), &mixed, true), + Err(Error::InvalidInput { .. }) + )); + // With validation off, the mixed batch is accepted (row 0 only). + assert!(resolve_batch_shard(&mode, &lance_schema(&mixed), &mixed, false).is_ok()); + } + + #[test] + fn resolve_input_shard_validation_off_only_uses_first_input_row() { + let mode = bucket_mode(0, 8); + let first = utf8_batch("id", vec!["a"]); + let second = utf8_batch("id", vec!["b"]); + let schema = lance_schema(&first); + assert_eq!( + resolve_input_shard(&mode, &schema, &[first.clone(), second.clone()], false).unwrap(), + Some(bucket_shard_id(1)) + ); + assert!(matches!( + resolve_input_shard(&mode, &schema, &[first, second], true), + Err(Error::InvalidInput { .. }) + )); + } + + #[test] + fn resolve_batch_shard_unsharded_is_constant() { + let batch = utf8_batch("anything", vec!["a", "b", "c"]); + assert_eq!( + resolve_batch_shard(&LsmMode::Unsharded, &lance_schema(&batch), &batch, true).unwrap(), + unsharded_shard_id() + ); + } +}