diff --git a/docs/src/js/classes/MergeInsertBuilder.md b/docs/src/js/classes/MergeInsertBuilder.md index d72ea2ea..5d5b6e81 100644 --- a/docs/src/js/classes/MergeInsertBuilder.md +++ b/docs/src/js/classes/MergeInsertBuilder.md @@ -33,20 +33,20 @@ Construct a MergeInsertBuilder. __Internal use only.__ ### execute() ```ts -execute(data): Promise +execute(data): Promise ``` Executes the merge insert operation -Nothing is returned but the `Table` is updated - #### Parameters * **data**: [`Data`](../type-aliases/Data.md) #### Returns -`Promise`<`void`> +`Promise`<[`MergeStats`](../interfaces/MergeStats.md)> + +Statistics about the merge operation: counts of inserted, updated, and deleted rows *** diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 7f41ba5e..962f07e2 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -54,6 +54,7 @@ - [IndexStatistics](interfaces/IndexStatistics.md) - [IvfFlatOptions](interfaces/IvfFlatOptions.md) - [IvfPqOptions](interfaces/IvfPqOptions.md) +- [MergeStats](interfaces/MergeStats.md) - [OpenTableOptions](interfaces/OpenTableOptions.md) - [OptimizeOptions](interfaces/OptimizeOptions.md) - [OptimizeStats](interfaces/OptimizeStats.md) diff --git a/docs/src/js/interfaces/MergeStats.md b/docs/src/js/interfaces/MergeStats.md new file mode 100644 index 00000000..cb8f05f9 --- /dev/null +++ b/docs/src/js/interfaces/MergeStats.md @@ -0,0 +1,31 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / MergeStats + +# Interface: MergeStats + +## Properties + +### numDeletedRows + +```ts +numDeletedRows: bigint; +``` + +*** + +### numInsertedRows + +```ts +numInsertedRows: bigint; +``` + +*** + +### numUpdatedRows + +```ts +numUpdatedRows: bigint; +``` diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index a248174a..f067f305 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -338,11 +338,16 @@ describe("merge insert", () => { 
{ a: 3, b: "y" }, { a: 4, b: "z" }, ]; - await table + const stats = await table .mergeInsert("a") .whenMatchedUpdateAll() .whenNotMatchedInsertAll() .execute(newData); + + expect(stats.numInsertedRows).toBe(1n); + expect(stats.numUpdatedRows).toBe(2n); + expect(stats.numDeletedRows).toBe(0n); + const expected = [ { a: 1, b: "a" }, { a: 2, b: "x" }, diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index c0604cce..4f3e8106 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -28,6 +28,7 @@ export { FragmentSummaryStats, Tags, TagContents, + MergeStats, } from "./native.js"; export { diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts index 407dca94..19d03cb3 100644 --- a/nodejs/lancedb/merge.ts +++ b/nodejs/lancedb/merge.ts @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors import { Data, Schema, fromDataToBuffer } from "./arrow"; -import { NativeMergeInsertBuilder } from "./native"; +import { MergeStats, NativeMergeInsertBuilder } from "./native"; /** A builder used to create and run a merge insert operation */ export class MergeInsertBuilder { @@ -73,9 +73,9 @@ export class MergeInsertBuilder { /** * Executes the merge insert operation * - * Nothing is returned but the `Table` is updated + * @returns Statistics about the merge operation: counts of inserted, updated, and deleted rows */ - async execute(data: Data): Promise { + async execute(data: Data): Promise { let schema: Schema; if (this.#schema instanceof Promise) { schema = await this.#schema; @@ -84,6 +84,6 @@ export class MergeInsertBuilder { schema = this.#schema; } const buffer = await fromDataToBuffer(data, undefined, schema); - await this.#native.execute(buffer); + return await this.#native.execute(buffer); } } diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs index cbeb3890..4f824034 100644 --- a/nodejs/src/merge.rs +++ b/nodejs/src/merge.rs @@ -37,7 +37,7 @@ impl NativeMergeInsertBuilder { } 
#[napi(catch_unwind)] - pub async fn execute(&self, buf: Buffer) -> napi::Result<()> { + pub async fn execute(&self, buf: Buffer) -> napi::Result { let data = ipc_file_to_batches(buf.to_vec()) .and_then(IntoArrow::into_arrow) .map_err(|e| { @@ -46,12 +46,14 @@ impl NativeMergeInsertBuilder { let this = self.clone(); - this.inner.execute(data).await.map_err(|e| { + let stats = this.inner.execute(data).await.map_err(|e| { napi::Error::from_reason(format!( "Failed to execute merge insert: {}", convert_error(&e) )) - }) + })?; + + Ok(stats.into()) } } @@ -60,3 +62,20 @@ impl From for NativeMergeInsertBuilder { Self { inner } } } + +#[napi(object)] +pub struct MergeStats { + pub num_inserted_rows: BigInt, + pub num_updated_rows: BigInt, + pub num_deleted_rows: BigInt, +} + +impl From for MergeStats { + fn from(stats: lancedb::table::MergeStats) -> Self { + Self { + num_inserted_rows: stats.num_inserted_rows.into(), + num_updated_rows: stats.num_updated_rows.into(), + num_deleted_rows: stats.num_deleted_rows.into(), + } + } +} diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 12da3cea..df130d3c 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -962,10 +962,12 @@ class Table(ABC): >>> table = db.create_table("my_table", data) >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) >>> # Perform a "upsert" operation - >>> table.merge_insert("a") \\ + >>> stats = table.merge_insert("a") \\ ... .when_matched_update_all() \\ ... .when_not_matched_insert_all() \\ ... 
.execute(new_data) + >>> stats + {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -2489,7 +2491,9 @@ class LanceTable(Table): on_bad_vectors: OnBadVectorsType, fill_value: float, ): - LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) + return LOOP.run( + self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) + ) @deprecation.deprecated( deprecated_in="0.21.0", @@ -3277,10 +3281,12 @@ class AsyncTable: >>> table = db.create_table("my_table", data) >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) >>> # Perform a "upsert" operation - >>> table.merge_insert("a") \\ + >>> stats = table.merge_insert("a") \\ ... .when_matched_update_all() \\ ... .when_not_matched_insert_all() \\ ... .execute(new_data) + >>> stats + {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -3636,7 +3642,7 @@ class AsyncTable: ) if isinstance(data, pa.Table): data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches()) - await self._inner.execute_merge_insert( + return await self._inner.execute_merge_insert( data, dict( on=merge._on, diff --git a/python/python/tests/docs/test_merge_insert.py b/python/python/tests/docs/test_merge_insert.py index 6ec67d09..72e4ce4d 100644 --- a/python/python/tests/docs/test_merge_insert.py +++ b/python/python/tests/docs/test_merge_insert.py @@ -18,15 +18,19 @@ def test_upsert(mem_db): {"id": 1, "name": "Bobby"}, {"id": 2, "name": "Charlie"}, ] - ( + stats = ( table.merge_insert("id") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_users) ) table.count_rows() # 3 + stats # {'num_inserted_rows': 
1, 'num_updated_rows': 1, 'num_deleted_rows': 0} # --8<-- [end:upsert_basic] assert table.count_rows() == 3 + assert stats["num_inserted_rows"] == 1 + assert stats["num_updated_rows"] == 1 + assert stats["num_deleted_rows"] == 0 @pytest.mark.asyncio @@ -44,15 +48,19 @@ async def test_upsert_async(mem_db_async): {"id": 1, "name": "Bobby"}, {"id": 2, "name": "Charlie"}, ] - await ( + stats = await ( table.merge_insert("id") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_users) ) await table.count_rows() # 3 + stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0} # --8<-- [end:upsert_basic_async] assert await table.count_rows() == 3 + assert stats["num_inserted_rows"] == 1 + assert stats["num_updated_rows"] == 1 + assert stats["num_deleted_rows"] == 0 def test_insert_if_not_exists(mem_db): @@ -69,10 +77,16 @@ def test_insert_if_not_exists(mem_db): {"domain": "google.com", "name": "Google"}, {"domain": "facebook.com", "name": "Facebook"}, ] - (table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)) + stats = ( + table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains) + ) table.count_rows() # 3 + stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0} # --8<-- [end:insert_if_not_exists] assert table.count_rows() == 3 + assert stats["num_inserted_rows"] == 1 + assert stats["num_updated_rows"] == 0 + assert stats["num_deleted_rows"] == 0 @pytest.mark.asyncio @@ -90,12 +104,16 @@ async def test_insert_if_not_exists_async(mem_db_async): {"domain": "google.com", "name": "Google"}, {"domain": "facebook.com", "name": "Facebook"}, ] - await ( + stats = await ( table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains) ) await table.count_rows() # 3 + stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0} # --8<-- [end:insert_if_not_exists_async] assert await table.count_rows() == 3 + assert 
stats["num_inserted_rows"] == 1 + assert stats["num_updated_rows"] == 0 + assert stats["num_deleted_rows"] == 0 def test_replace_range(mem_db): @@ -113,7 +131,7 @@ def test_replace_range(mem_db): new_chunks = [ {"doc_id": 1, "chunk_id": 0, "text": "Baz"}, ] - ( + stats = ( table.merge_insert(["doc_id", "chunk_id"]) .when_matched_update_all() .when_not_matched_insert_all() @@ -121,8 +139,12 @@ def test_replace_range(mem_db): .execute(new_chunks) ) table.count_rows("doc_id = 1") # 1 + stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1} # --8<-- [end:replace_range] assert table.count_rows("doc_id = 1") == 1 + assert stats["num_inserted_rows"] == 0 + assert stats["num_updated_rows"] == 1 + assert stats["num_deleted_rows"] == 1 @pytest.mark.asyncio @@ -141,7 +163,7 @@ async def test_replace_range_async(mem_db_async): new_chunks = [ {"doc_id": 1, "chunk_id": 0, "text": "Baz"}, ] - await ( + stats = await ( table.merge_insert(["doc_id", "chunk_id"]) .when_matched_update_all() .when_not_matched_insert_all() @@ -149,5 +171,9 @@ async def test_replace_range_async(mem_db_async): .execute(new_chunks) ) await table.count_rows("doc_id = 1") # 1 + stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1} # --8<-- [end:replace_range_async] assert await table.count_rows("doc_id = 1") == 1 + assert stats["num_inserted_rows"] == 0 + assert stats["num_updated_rows"] == 1 + assert stats["num_deleted_rows"] == 1 diff --git a/python/src/table.rs b/python/src/table.rs index 2e267cb8..c8073e05 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -489,8 +489,14 @@ impl Table { } future_into_py(self_.py(), async move { - builder.execute(Box::new(batches)).await.infer_error()?; - Ok(()) + let stats = builder.execute(Box::new(batches)).await.infer_error()?; + Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item("num_inserted_rows", stats.num_inserted_rows)?; + dict.set_item("num_updated_rows", stats.num_updated_rows)?; + 
dict.set_item("num_deleted_rows", stats.num_deleted_rows)?; + Ok(dict.unbind()) + }) }) } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 11ee0041..2e089cdd 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -47,6 +47,7 @@ use crate::{ TableDefinition, UpdateBuilder, }, }; +use lance::dataset::MergeStats; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); @@ -1022,7 +1023,7 @@ impl BaseTable for RemoteTable { &self, params: MergeInsertBuilder, new_data: Box, - ) -> Result<()> { + ) -> Result { self.check_mutable().await?; let query = MergeInsertRequest::try_from(params)?; @@ -1034,9 +1035,11 @@ impl BaseTable for RemoteTable { let (request_id, response) = self.send_streaming(request, new_data, true).await?; + // TODO: the server can respond with these stats in the response body. + // We should test that we can handle both an empty response from an old server + // and a response with stats from a new server. 
self.check_table_response(&request_id, response).await?; - - Ok(()) + Ok(MergeStats::default()) } async fn tags(&self) -> Result> { @@ -1348,7 +1351,12 @@ mod tests { Box::pin(table.count_rows(None).map_ok(|_| ())), Box::pin(table.update().column("a", "a + 1").execute().map_ok(|_| ())), Box::pin(table.add(example_data()).execute().map_ok(|_| ())), - Box::pin(table.merge_insert(&["test"]).execute(example_data())), + Box::pin( + table + .merge_insert(&["test"]) + .execute(example_data()) + .map_ok(|_| ()), + ), Box::pin(table.delete("false")), Box::pin(table.add_columns( NewColumnTransform::SqlExpressions(vec![("x".into(), "y".into())]), diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 9ea3380f..bb404e1c 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -20,6 +20,7 @@ use lance::dataset::cleanup::RemovalStats; use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions}; use lance::dataset::scanner::Scanner; pub use lance::dataset::ColumnAlteration; +pub use lance::dataset::MergeStats; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; pub use lance::dataset::Version; @@ -487,7 +488,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { &self, params: MergeInsertBuilder, new_data: Box, - ) -> Result<()>; + ) -> Result; /// Gets the table tag manager. async fn tags(&self) -> Result>; /// Optimize the dataset. 
@@ -2367,7 +2368,7 @@ impl BaseTable for NativeTable { &self, params: MergeInsertBuilder, new_data: Box, - ) -> Result<()> { + ) -> Result { let dataset = Arc::new(self.dataset.get().await?.clone()); let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( @@ -2394,9 +2395,9 @@ impl BaseTable for NativeTable { builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep); } let job = builder.try_build()?; - let (new_dataset, _stats) = job.execute_reader(new_data).await?; + let (new_dataset, stats) = job.execute_reader(new_data).await?; self.dataset.set_latest(new_dataset.as_ref().clone()).await; - Ok(()) + Ok(stats) } /// Delete rows from the table diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index ea2999a2..f63b0218 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use arrow_array::RecordBatchReader; +use lance::dataset::MergeStats; use crate::Result; @@ -86,8 +87,9 @@ impl MergeInsertBuilder { /// Executes the merge insert operation /// - /// Nothing is returned but the [`super::Table`] is updated - pub async fn execute(self, new_data: Box) -> Result<()> { + /// Returns statistics about the merge operation including the number of rows + /// inserted, updated, and deleted. + pub async fn execute(self, new_data: Box) -> Result { self.table.clone().merge_insert(self, new_data).await } }