diff --git a/docs/src/js/classes/MergeInsertBuilder.md b/docs/src/js/classes/MergeInsertBuilder.md index 5d5b6e81..5d407d95 100644 --- a/docs/src/js/classes/MergeInsertBuilder.md +++ b/docs/src/js/classes/MergeInsertBuilder.md @@ -33,7 +33,7 @@ Construct a MergeInsertBuilder. __Internal use only.__ ### execute() ```ts -execute(data): Promise +execute(data): Promise ``` Executes the merge insert operation @@ -44,9 +44,9 @@ Executes the merge insert operation #### Returns -`Promise`<[`MergeStats`](../interfaces/MergeStats.md)> +`Promise`<[`MergeResult`](../interfaces/MergeResult.md)> -Statistics about the merge operation: counts of inserted, updated, and deleted rows +the merge result *** diff --git a/docs/src/js/classes/Table.md b/docs/src/js/classes/Table.md index 0c1d0a3b..0bad38a5 100644 --- a/docs/src/js/classes/Table.md +++ b/docs/src/js/classes/Table.md @@ -40,7 +40,7 @@ Returns the name of the table ### add() ```ts -abstract add(data, options?): Promise +abstract add(data, options?): Promise ``` Insert records into this Table. @@ -54,14 +54,17 @@ Insert records into this Table. #### Returns -`Promise`<`void`> +`Promise`<[`AddResult`](../interfaces/AddResult.md)> + +A promise that resolves to an object +containing the new version number of the table *** ### addColumns() ```ts -abstract addColumns(newColumnTransforms): Promise +abstract addColumns(newColumnTransforms): Promise ``` Add new columns with defined values. @@ -76,14 +79,17 @@ Add new columns with defined values. #### Returns -`Promise`<`void`> +`Promise`<[`AddColumnsResult`](../interfaces/AddColumnsResult.md)> + +A promise that resolves to an object +containing the new version number of the table after adding the columns. *** ### alterColumns() ```ts -abstract alterColumns(columnAlterations): Promise +abstract alterColumns(columnAlterations): Promise ``` Alter the name or nullability of columns. @@ -96,7 +102,10 @@ Alter the name or nullability of columns. 
#### Returns -`Promise`<`void`> +`Promise`<[`AlterColumnsResult`](../interfaces/AlterColumnsResult.md)> + +A promise that resolves to an object +containing the new version number of the table after altering the columns. *** @@ -252,7 +261,7 @@ await table.createIndex("my_float_col"); ### delete() ```ts -abstract delete(predicate): Promise +abstract delete(predicate): Promise ``` Delete the rows that satisfy the predicate. @@ -263,7 +272,10 @@ Delete the rows that satisfy the predicate. #### Returns -`Promise`<`void`> +`Promise`<[`DeleteResult`](../interfaces/DeleteResult.md)> + +A promise that resolves to an object +containing the new version number of the table *** @@ -284,7 +296,7 @@ Return a brief description of the table ### dropColumns() ```ts -abstract dropColumns(columnNames): Promise +abstract dropColumns(columnNames): Promise ``` Drop one or more columns from the dataset @@ -303,7 +315,10 @@ then call ``cleanup_files`` to remove the old files. #### Returns -`Promise`<`void`> +`Promise`<[`DropColumnsResult`](../interfaces/DropColumnsResult.md)> + +A promise that resolves to an object +containing the new version number of the table after dropping the columns. 
*** @@ -678,7 +693,7 @@ Return the table as an arrow table #### update(opts) ```ts -abstract update(opts): Promise +abstract update(opts): Promise ``` Update existing records in the Table @@ -689,7 +704,10 @@ Update existing records in the Table ##### Returns -`Promise`<`void`> +`Promise`<[`UpdateResult`](../interfaces/UpdateResult.md)> + +A promise that resolves to an object containing +the number of rows updated and the new version number ##### Example @@ -700,7 +718,7 @@ table.update({where:"x = 2", values:{"vector": [10, 10]}}) #### update(opts) ```ts -abstract update(opts): Promise +abstract update(opts): Promise ``` Update existing records in the Table @@ -711,7 +729,10 @@ Update existing records in the Table ##### Returns -`Promise`<`void`> +`Promise`<[`UpdateResult`](../interfaces/UpdateResult.md)> + +A promise that resolves to an object containing +the number of rows updated and the new version number ##### Example @@ -722,7 +743,7 @@ table.update({where:"x = 2", valuesSql:{"x": "x + 1"}}) #### update(updates, options) ```ts -abstract update(updates, options?): Promise +abstract update(updates, options?): Promise ``` Update existing records in the Table @@ -745,10 +766,6 @@ repeatedly calilng this method. * **updates**: `Record`<`string`, `string`> \| `Map`<`string`, `string`> the columns to update - Keys in the map should specify the name of the column to update. - Values in the map provide the new value of the column. These can - be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions - based on the row being updated (e.g. "my_col + 1") * **options?**: `Partial`<[`UpdateOptions`](../interfaces/UpdateOptions.md)> additional options to control @@ -756,7 +773,15 @@ repeatedly calilng this method. 
##### Returns -`Promise`<`void`> +`Promise`<[`UpdateResult`](../interfaces/UpdateResult.md)> + +A promise that resolves to an object +containing the number of rows updated and the new version number + +Keys in the map should specify the name of the column to update. +Values in the map provide the new value of the column. These can +be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions +based on the row being updated (e.g. "my_col + 1") *** diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 962f07e2..c9779d79 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -34,13 +34,18 @@ ## Interfaces +- [AddColumnsResult](interfaces/AddColumnsResult.md) - [AddColumnsSql](interfaces/AddColumnsSql.md) - [AddDataOptions](interfaces/AddDataOptions.md) +- [AddResult](interfaces/AddResult.md) +- [AlterColumnsResult](interfaces/AlterColumnsResult.md) - [ClientConfig](interfaces/ClientConfig.md) - [ColumnAlteration](interfaces/ColumnAlteration.md) - [CompactionStats](interfaces/CompactionStats.md) - [ConnectionOptions](interfaces/ConnectionOptions.md) - [CreateTableOptions](interfaces/CreateTableOptions.md) +- [DeleteResult](interfaces/DeleteResult.md) +- [DropColumnsResult](interfaces/DropColumnsResult.md) - [ExecutableQuery](interfaces/ExecutableQuery.md) - [FragmentStatistics](interfaces/FragmentStatistics.md) - [FragmentSummaryStats](interfaces/FragmentSummaryStats.md) @@ -54,7 +59,7 @@ - [IndexStatistics](interfaces/IndexStatistics.md) - [IvfFlatOptions](interfaces/IvfFlatOptions.md) - [IvfPqOptions](interfaces/IvfPqOptions.md) -- [MergeStats](interfaces/MergeStats.md) +- [MergeResult](interfaces/MergeResult.md) - [OpenTableOptions](interfaces/OpenTableOptions.md) - [OptimizeOptions](interfaces/OptimizeOptions.md) - [OptimizeStats](interfaces/OptimizeStats.md) @@ -65,6 +70,7 @@ - [TableStatistics](interfaces/TableStatistics.md) - [TimeoutConfig](interfaces/TimeoutConfig.md) - [UpdateOptions](interfaces/UpdateOptions.md) +- 
[UpdateResult](interfaces/UpdateResult.md) - [Version](interfaces/Version.md) ## Type Aliases diff --git a/docs/src/js/interfaces/AddColumnsResult.md b/docs/src/js/interfaces/AddColumnsResult.md new file mode 100644 index 00000000..fbc9b0b5 --- /dev/null +++ b/docs/src/js/interfaces/AddColumnsResult.md @@ -0,0 +1,15 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / AddColumnsResult + +# Interface: AddColumnsResult + +## Properties + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/AddResult.md b/docs/src/js/interfaces/AddResult.md new file mode 100644 index 00000000..7a90e03a --- /dev/null +++ b/docs/src/js/interfaces/AddResult.md @@ -0,0 +1,15 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / AddResult + +# Interface: AddResult + +## Properties + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/AlterColumnsResult.md b/docs/src/js/interfaces/AlterColumnsResult.md new file mode 100644 index 00000000..73ab22eb --- /dev/null +++ b/docs/src/js/interfaces/AlterColumnsResult.md @@ -0,0 +1,15 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / AlterColumnsResult + +# Interface: AlterColumnsResult + +## Properties + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/DeleteResult.md b/docs/src/js/interfaces/DeleteResult.md new file mode 100644 index 00000000..f2b18633 --- /dev/null +++ b/docs/src/js/interfaces/DeleteResult.md @@ -0,0 +1,15 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / DeleteResult + +# Interface: DeleteResult + +## Properties + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/DropColumnsResult.md b/docs/src/js/interfaces/DropColumnsResult.md new file mode 100644 index 00000000..8e2440b2 --- /dev/null +++ 
b/docs/src/js/interfaces/DropColumnsResult.md @@ -0,0 +1,15 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / DropColumnsResult + +# Interface: DropColumnsResult + +## Properties + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/MergeResult.md b/docs/src/js/interfaces/MergeResult.md new file mode 100644 index 00000000..9874fd5d --- /dev/null +++ b/docs/src/js/interfaces/MergeResult.md @@ -0,0 +1,39 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / MergeResult + +# Interface: MergeResult + +## Properties + +### numDeletedRows + +```ts +numDeletedRows: number; +``` + +*** + +### numInsertedRows + +```ts +numInsertedRows: number; +``` + +*** + +### numUpdatedRows + +```ts +numUpdatedRows: number; +``` + +*** + +### version + +```ts +version: number; +``` diff --git a/docs/src/js/interfaces/MergeStats.md b/docs/src/js/interfaces/MergeStats.md deleted file mode 100644 index cb8f05f9..00000000 --- a/docs/src/js/interfaces/MergeStats.md +++ /dev/null @@ -1,31 +0,0 @@ -[**@lancedb/lancedb**](../README.md) • **Docs** - -*** - -[@lancedb/lancedb](../globals.md) / MergeStats - -# Interface: MergeStats - -## Properties - -### numDeletedRows - -```ts -numDeletedRows: bigint; -``` - -*** - -### numInsertedRows - -```ts -numInsertedRows: bigint; -``` - -*** - -### numUpdatedRows - -```ts -numUpdatedRows: bigint; -``` diff --git a/docs/src/js/interfaces/UpdateResult.md b/docs/src/js/interfaces/UpdateResult.md new file mode 100644 index 00000000..3dd8d812 --- /dev/null +++ b/docs/src/js/interfaces/UpdateResult.md @@ -0,0 +1,23 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / UpdateResult + +# Interface: UpdateResult + +## Properties + +### rowsUpdated + +```ts +rowsUpdated: number; +``` + +*** + +### version + +```ts +version: number; +``` diff --git a/nodejs/__test__/table.test.ts 
b/nodejs/__test__/table.test.ts index f067f305..dcc385b3 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -34,6 +34,7 @@ import { } from "../lancedb/embedding"; import { Index } from "../lancedb/indices"; import { instanceOfFullTextQuery } from "../lancedb/query"; + describe.each([arrow15, arrow16, arrow17, arrow18])( "Given a table", ( @@ -95,7 +96,9 @@ }); it("should overwrite data if asked", async () => { - await table.add([{ id: 1 }, { id: 2 }]); + const addRes = await table.add([{ id: 1 }, { id: 2 }]); + expect(addRes).toHaveProperty("version"); + expect(addRes.version).toBe(2); await table.add([{ id: 1 }], { mode: "overwrite" }); await expect(table.countRows()).resolves.toBe(1); }); @@ -111,7 +114,11 @@ await table.add([{ id: 1 }]); expect(await table.countRows("id == 1")).toBe(1); expect(await table.countRows("id == 7")).toBe(0); - await table.update({ id: "7" }); + const updateRes = await table.update({ id: "7" }); + expect(updateRes).toHaveProperty("version"); + expect(updateRes.version).toBe(3); + expect(updateRes).toHaveProperty("rowsUpdated"); + expect(updateRes.rowsUpdated).toBe(1); expect(await table.countRows("id == 1")).toBe(0); expect(await table.countRows("id == 7")).toBe(1); await table.add([{ id: 2 }]); @@ -338,15 +345,16 @@ { a: 3, b: "y" }, { a: 4, b: "z" }, ]; - const stats = await table + const mergeInsertRes = await table .mergeInsert("a") .whenMatchedUpdateAll() .whenNotMatchedInsertAll() .execute(newData); - - expect(stats.numInsertedRows).toBe(1n); - expect(stats.numUpdatedRows).toBe(2n); - expect(stats.numDeletedRows).toBe(0n); + expect(mergeInsertRes).toHaveProperty("version"); + expect(mergeInsertRes.version).toBe(2); + expect(mergeInsertRes.numInsertedRows).toBe(1); + expect(mergeInsertRes.numUpdatedRows).toBe(2); + 
expect(mergeInsertRes.numDeletedRows).toBe(0); const expected = [ { a: 1, b: "a" }, @@ -365,10 +373,12 @@ describe("merge insert", () => { { a: 3, b: "y" }, { a: 4, b: "z" }, ]; - await table + const mergeInsertRes = await table .mergeInsert("a") .whenMatchedUpdateAll({ where: "target.b = 'b'" }) .execute(newData); + expect(mergeInsertRes).toHaveProperty("version"); + expect(mergeInsertRes.version).toBe(2); const expected = [ { a: 1, b: "a" }, @@ -1028,15 +1038,19 @@ describe("schema evolution", function () { { id: 1n, vector: [0.1, 0.2] }, ]); // Can create a non-nullable column only through addColumns at the moment. - await table.addColumns([ + const addColumnsRes = await table.addColumns([ { name: "price", valueSql: "cast(10.0 as double)" }, ]); + expect(addColumnsRes).toHaveProperty("version"); + expect(addColumnsRes.version).toBe(2); expect(await table.schema()).toEqual(schema); - await table.alterColumns([ + const alterColumnsRes = await table.alterColumns([ { path: "id", rename: "new_id" }, { path: "price", nullable: true }, ]); + expect(alterColumnsRes).toHaveProperty("version"); + expect(alterColumnsRes.version).toBe(3); const expectedSchema = new Schema([ new Field("new_id", new Int64(), true), @@ -1154,7 +1168,9 @@ describe("schema evolution", function () { const table = await con.createTable("vectors", [ { id: 1n, vector: [0.1, 0.2] }, ]); - await table.dropColumns(["vector"]); + const dropColumnsRes = await table.dropColumns(["vector"]); + expect(dropColumnsRes).toHaveProperty("version"); + expect(dropColumnsRes.version).toBe(2); const expectedSchema = new Schema([new Field("id", new Int64(), true)]); expect(await table.schema()).toEqual(expectedSchema); diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index 4f3e8106..6f548cb5 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -28,7 +28,13 @@ export { FragmentSummaryStats, Tags, TagContents, - MergeStats, + MergeResult, + AddResult, + AddColumnsResult, + 
AlterColumnsResult, + DeleteResult, + DropColumnsResult, + UpdateResult, } from "./native.js"; export { diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts index 19d03cb3..781ca177 100644 --- a/nodejs/lancedb/merge.ts +++ b/nodejs/lancedb/merge.ts @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors import { Data, Schema, fromDataToBuffer } from "./arrow"; -import { MergeStats, NativeMergeInsertBuilder } from "./native"; +import { MergeResult, NativeMergeInsertBuilder } from "./native"; /** A builder used to create and run a merge insert operation */ export class MergeInsertBuilder { @@ -73,9 +73,9 @@ export class MergeInsertBuilder { /** * Executes the merge insert operation * - * @returns Statistics about the merge operation: counts of inserted, updated, and deleted rows + * @returns {Promise} the merge result */ - async execute(data: Data): Promise { + async execute(data: Data): Promise { let schema: Schema; if (this.#schema instanceof Promise) { schema = await this.#schema; diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 7da220d9..e344a7f5 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -16,12 +16,18 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { IndexOptions } from "./indices"; import { MergeInsertBuilder } from "./merge"; import { + AddColumnsResult, AddColumnsSql, + AddResult, + AlterColumnsResult, + DeleteResult, + DropColumnsResult, IndexConfig, IndexStatistics, OptimizeStats, TableStatistics, Tags, + UpdateResult, Table as _NativeTable, } from "./native"; import { @@ -126,12 +132,19 @@ export abstract class Table { /** * Insert records into this Table. 
* @param {Data} data Records to be inserted into the Table + * @returns {Promise} A promise that resolves to an object + * containing the new version number of the table */ - abstract add(data: Data, options?: Partial): Promise; + abstract add( + data: Data, + options?: Partial, + ): Promise; /** * Update existing records in the Table * @param opts.values The values to update. The keys are the column names and the values * are the values to set. + * @returns {Promise} A promise that resolves to an object containing + * the number of rows updated and the new version number * @example * ```ts * table.update({where:"x = 2", values:{"vector": [10, 10]}}) @@ -141,11 +154,13 @@ export abstract class Table { opts: { values: Map | Record; } & Partial, - ): Promise; + ): Promise; /** * Update existing records in the Table * @param opts.valuesSql The values to update. The keys are the column names and the values * are the values to set. The values are SQL expressions. + * @returns {Promise} A promise that resolves to an object containing + * the number of rows updated and the new version number * @example * ```ts * table.update({where:"x = 2", valuesSql:{"x": "x + 1"}}) @@ -155,7 +170,7 @@ export abstract class Table { opts: { valuesSql: Map | Record; } & Partial, - ): Promise; + ): Promise; /** * Update existing records in the Table * @@ -173,6 +188,8 @@ export abstract class Table { * repeatedly calilng this method. * @param {Map | Record} updates - the * columns to update + * @returns {Promise} A promise that resolves to an object + * containing the number of rows updated and the new version number * * Keys in the map should specify the name of the column to update. * Values in the map provide the new value of the column. These can @@ -184,12 +201,16 @@ export abstract class Table { abstract update( updates: Map | Record, options?: Partial, - ): Promise; + ): Promise; /** Count the total number of rows in the dataset. 
*/ abstract countRows(filter?: string): Promise; - /** Delete the rows that satisfy the predicate. */ - abstract delete(predicate: string): Promise; + /** + * Delete the rows that satisfy the predicate. + * @returns {Promise} A promise that resolves to an object + * containing the new version number of the table + */ + abstract delete(predicate: string): Promise; /** * Create an index to speed up queries. * @@ -343,15 +364,23 @@ export abstract class Table { * the SQL expression to use to calculate the value of the new column. These * expressions will be evaluated for each row in the table, and can * reference existing columns in the table. + * @returns {Promise} A promise that resolves to an object + * containing the new version number of the table after adding the columns. */ - abstract addColumns(newColumnTransforms: AddColumnsSql[]): Promise; + abstract addColumns( + newColumnTransforms: AddColumnsSql[], + ): Promise; /** * Alter the name or nullability of columns. * @param {ColumnAlteration[]} columnAlterations One or more alterations to * apply to columns. + * @returns {Promise} A promise that resolves to an object + * containing the new version number of the table after altering the columns. */ - abstract alterColumns(columnAlterations: ColumnAlteration[]): Promise; + abstract alterColumns( + columnAlterations: ColumnAlteration[], + ): Promise; /** * Drop one or more columns from the dataset * @@ -362,8 +391,10 @@ export abstract class Table { * @param {string[]} columnNames The names of the columns to drop. These can * be nested column references (e.g. "a.b.c") or top-level column names * (e.g. "a"). + * @returns {Promise} A promise that resolves to an object + * containing the new version number of the table after dropping the columns. 
*/ - abstract dropColumns(columnNames: string[]): Promise; + abstract dropColumns(columnNames: string[]): Promise; /** Retrieve the version of the table */ abstract version(): Promise; @@ -529,12 +560,12 @@ export class LocalTable extends Table { return tbl.schema; } - async add(data: Data, options?: Partial): Promise { + async add(data: Data, options?: Partial): Promise { const mode = options?.mode ?? "append"; const schema = await this.schema(); const buffer = await fromDataToBuffer(data, undefined, schema); - await this.inner.add(buffer, mode); + return await this.inner.add(buffer, mode); } async update( @@ -547,7 +578,7 @@ export class LocalTable extends Table { valuesSql: Map | Record; } & Partial), options?: Partial, - ) { + ): Promise { const isValues = "values" in optsOrUpdates && typeof optsOrUpdates.values !== "string"; const isValuesSql = @@ -594,15 +625,15 @@ export class LocalTable extends Table { columns = Object.entries(optsOrUpdates as Record); predicate = options?.where; } - await this.inner.update(predicate, columns); + return await this.inner.update(predicate, columns); } async countRows(filter?: string): Promise { return await this.inner.countRows(filter); } - async delete(predicate: string): Promise { - await this.inner.delete(predicate); + async delete(predicate: string): Promise { + return await this.inner.delete(predicate); } async createIndex(column: string, options?: Partial) { @@ -690,11 +721,15 @@ export class LocalTable extends Table { // TODO: Support BatchUDF - async addColumns(newColumnTransforms: AddColumnsSql[]): Promise { - await this.inner.addColumns(newColumnTransforms); + async addColumns( + newColumnTransforms: AddColumnsSql[], + ): Promise { + return await this.inner.addColumns(newColumnTransforms); } - async alterColumns(columnAlterations: ColumnAlteration[]): Promise { + async alterColumns( + columnAlterations: ColumnAlteration[], + ): Promise { const processedAlterations = columnAlterations.map((alteration) => { if (typeof 
alteration.dataType === "string") { return { @@ -715,11 +750,11 @@ export class LocalTable extends Table { } }); - await this.inner.alterColumns(processedAlterations); + return await this.inner.alterColumns(processedAlterations); } - async dropColumns(columnNames: string[]): Promise { - await this.inner.dropColumns(columnNames); + async dropColumns(columnNames: string[]): Promise { + return await this.inner.dropColumns(columnNames); } async version(): Promise { diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs index 4f824034..38eb4883 100644 --- a/nodejs/src/merge.rs +++ b/nodejs/src/merge.rs @@ -5,7 +5,7 @@ use lancedb::{arrow::IntoArrow, ipc::ipc_file_to_batches, table::merge::MergeIns use napi::bindgen_prelude::*; use napi_derive::napi; -use crate::error::convert_error; +use crate::{error::convert_error, table::MergeResult}; #[napi] #[derive(Clone)] @@ -37,7 +37,7 @@ impl NativeMergeInsertBuilder { } #[napi(catch_unwind)] - pub async fn execute(&self, buf: Buffer) -> napi::Result { + pub async fn execute(&self, buf: Buffer) -> napi::Result { let data = ipc_file_to_batches(buf.to_vec()) .and_then(IntoArrow::into_arrow) .map_err(|e| { @@ -46,14 +46,13 @@ impl NativeMergeInsertBuilder { let this = self.clone(); - let stats = this.inner.execute(data).await.map_err(|e| { + let res = this.inner.execute(data).await.map_err(|e| { napi::Error::from_reason(format!( "Failed to execute merge insert: {}", convert_error(&e) )) })?; - - Ok(stats.into()) + Ok(res.into()) } } @@ -62,20 +61,3 @@ impl From for NativeMergeInsertBuilder { Self { inner } } } - -#[napi(object)] -pub struct MergeStats { - pub num_inserted_rows: BigInt, - pub num_updated_rows: BigInt, - pub num_deleted_rows: BigInt, -} - -impl From for MergeStats { - fn from(stats: lancedb::table::MergeStats) -> Self { - Self { - num_inserted_rows: stats.num_inserted_rows.into(), - num_updated_rows: stats.num_updated_rows.into(), - num_deleted_rows: stats.num_deleted_rows.into(), - } - } -} diff --git 
a/nodejs/src/table.rs b/nodejs/src/table.rs index 76ab724c..afc8203b 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -75,7 +75,7 @@ impl Table { } #[napi(catch_unwind)] - pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> { + pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result { let batches = ipc_file_to_batches(buf.to_vec()) .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; let mut op = self.inner_ref()?.add(batches); @@ -88,7 +88,8 @@ impl Table { return Err(napi::Error::from_reason(format!("Invalid mode: {}", mode))); }; - op.execute().await.default_error() + let res = op.execute().await.default_error()?; + Ok(res.into()) } #[napi(catch_unwind)] @@ -101,8 +102,9 @@ impl Table { } #[napi(catch_unwind)] - pub async fn delete(&self, predicate: String) -> napi::Result<()> { - self.inner_ref()?.delete(&predicate).await.default_error() + pub async fn delete(&self, predicate: String) -> napi::Result { + let res = self.inner_ref()?.delete(&predicate).await.default_error()?; + Ok(res.into()) } #[napi(catch_unwind)] @@ -168,7 +170,7 @@ impl Table { &self, only_if: Option, columns: Vec<(String, String)>, - ) -> napi::Result { + ) -> napi::Result { let mut op = self.inner_ref()?.update(); if let Some(only_if) = only_if { op = op.only_if(only_if); @@ -176,7 +178,8 @@ impl Table { for (column_name, value) in columns { op = op.column(column_name, value); } - op.execute().await.default_error() + let res = op.execute().await.default_error()?; + Ok(res.into()) } #[napi(catch_unwind)] @@ -190,21 +193,28 @@ impl Table { } #[napi(catch_unwind)] - pub async fn add_columns(&self, transforms: Vec) -> napi::Result<()> { + pub async fn add_columns( + &self, + transforms: Vec, + ) -> napi::Result { let transforms = transforms .into_iter() .map(|sql| (sql.name, sql.value_sql)) .collect::>(); let transforms = NewColumnTransform::SqlExpressions(transforms); - self.inner_ref()? + let res = self + .inner_ref()? 
.add_columns(transforms, None) .await .default_error()?; - Ok(()) + Ok(res.into()) } #[napi(catch_unwind)] - pub async fn alter_columns(&self, alterations: Vec) -> napi::Result<()> { + pub async fn alter_columns( + &self, + alterations: Vec, + ) -> napi::Result { for alteration in &alterations { if alteration.rename.is_none() && alteration.nullable.is_none() @@ -221,21 +231,23 @@ impl Table { .collect::, String>>() .map_err(napi::Error::from_reason)?; - self.inner_ref()? + let res = self + .inner_ref()? .alter_columns(&alterations) .await .default_error()?; - Ok(()) + Ok(res.into()) } #[napi(catch_unwind)] - pub async fn drop_columns(&self, columns: Vec) -> napi::Result<()> { + pub async fn drop_columns(&self, columns: Vec) -> napi::Result { let col_refs = columns.iter().map(String::as_str).collect::>(); - self.inner_ref()? + let res = self + .inner_ref()? .drop_columns(&col_refs) .await .default_error()?; - Ok(()) + Ok(res.into()) } #[napi(catch_unwind)] @@ -642,6 +654,105 @@ pub struct Version { pub metadata: HashMap, } +#[napi(object)] +pub struct UpdateResult { + pub rows_updated: i64, + pub version: i64, +} + +impl From for UpdateResult { + fn from(value: lancedb::table::UpdateResult) -> Self { + Self { + rows_updated: value.rows_updated as i64, + version: value.version as i64, + } + } +} + +#[napi(object)] +pub struct AddResult { + pub version: i64, +} + +impl From for AddResult { + fn from(value: lancedb::table::AddResult) -> Self { + Self { + version: value.version as i64, + } + } +} + +#[napi(object)] +pub struct DeleteResult { + pub version: i64, +} + +impl From for DeleteResult { + fn from(value: lancedb::table::DeleteResult) -> Self { + Self { + version: value.version as i64, + } + } +} + +#[napi(object)] +pub struct MergeResult { + pub version: i64, + pub num_inserted_rows: i64, + pub num_updated_rows: i64, + pub num_deleted_rows: i64, +} + +impl From for MergeResult { + fn from(value: lancedb::table::MergeResult) -> Self { + Self { + version: 
value.version as i64, + num_inserted_rows: value.num_inserted_rows as i64, + num_updated_rows: value.num_updated_rows as i64, + num_deleted_rows: value.num_deleted_rows as i64, + } + } +} + +#[napi(object)] +pub struct AddColumnsResult { + pub version: i64, +} + +impl From for AddColumnsResult { + fn from(value: lancedb::table::AddColumnsResult) -> Self { + Self { + version: value.version as i64, + } + } +} + +#[napi(object)] +pub struct AlterColumnsResult { + pub version: i64, +} + +impl From for AlterColumnsResult { + fn from(value: lancedb::table::AlterColumnsResult) -> Self { + Self { + version: value.version as i64, + } + } +} + +#[napi(object)] +pub struct DropColumnsResult { + pub version: i64, +} + +impl From for DropColumnsResult { + fn from(value: lancedb::table::DropColumnsResult) -> Self { + Self { + version: value.version as i64, + } + } +} + #[napi] pub struct TagContents { pub version: i64, diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index ee744f60..7e9934aa 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -36,8 +36,10 @@ class Table: async def schema(self) -> pa.Schema: ... async def add( self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"] - ) -> None: ... - async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ... + ) -> AddResult: ... + async def update( + self, updates: Dict[str, str], where: Optional[str] + ) -> UpdateResult: ... async def count_rows(self, filter: Optional[str]) -> int: ... async def create_index( self, @@ -51,10 +53,12 @@ class Table: async def checkout_latest(self): ... async def restore(self, version: Optional[int] = None): ... async def list_indices(self) -> list[IndexConfig]: ... - async def delete(self, filter: str): ... - async def add_columns(self, columns: list[tuple[str, str]]) -> None: ... - async def add_columns_with_schema(self, schema: pa.Schema) -> None: ... 
- async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ... + async def delete(self, filter: str) -> DeleteResult: ... + async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ... + async def add_columns_with_schema(self, schema: pa.Schema) -> AddColumnsResult: ... + async def alter_columns( + self, columns: list[dict[str, Any]] + ) -> AlterColumnsResult: ... async def optimize( self, *, @@ -208,3 +212,28 @@ class OptimizeStats: class Tag(TypedDict): version: int manifest_size: int + +class AddResult: + version: int + +class DeleteResult: + version: int + +class UpdateResult: + rows_updated: int + version: int + +class MergeResult: + version: int + num_updated_rows: int + num_inserted_rows: int + num_deleted_rows: int + +class AddColumnsResult: + version: int + +class AlterColumnsResult: + version: int + +class DropColumnsResult: + version: int diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index 575f1dba..419877bb 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from .common import DATA + from ._lancedb import ( + MergeResult, + ) class LanceMergeInsertBuilder(object): @@ -78,7 +81,7 @@ class LanceMergeInsertBuilder(object): new_data: DATA, on_bad_vectors: str = "error", fill_value: float = 0.0, - ): + ) -> MergeResult: """ Executes the merge insert operation @@ -95,5 +98,10 @@ One of "error", "drop", "fill". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + + Returns + ------- + MergeResult + version: the new version number of the table after doing merge insert. 
""" return self._table._do_merge(self, new_data, on_bad_vectors, fill_value) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index f68d0163..79f00517 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -415,6 +415,7 @@ class LanceModel(pydantic.BaseModel): >>> table.add([ ... TestModel(name="test", vector=[1.0, 2.0]) ... ]) + AddResult(version=2) >>> table.search([0., 0.]).limit(1).to_pydantic(TestModel) [TestModel(name='test', vector=FixedSizeList(dim=2))] """ diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index b70d05e9..ed6d14ea 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -7,7 +7,16 @@ from functools import cached_property from typing import Dict, Iterable, List, Optional, Union, Literal import warnings -from lancedb._lancedb import IndexConfig +from lancedb._lancedb import ( + AddColumnsResult, + AddResult, + AlterColumnsResult, + DeleteResult, + DropColumnsResult, + IndexConfig, + MergeResult, + UpdateResult, +) from lancedb.embeddings.base import EmbeddingFunctionConfig from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList from lancedb.remote.db import LOOP @@ -263,7 +272,7 @@ class RemoteTable(Table): mode: str = "append", on_bad_vectors: str = "error", fill_value: float = 0.0, - ) -> int: + ) -> AddResult: """Add more data to the [Table](Table). It has the same API signature as the OSS version. @@ -286,8 +295,12 @@ class RemoteTable(Table): fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + Returns + ------- + AddResult + An object containing the new version number of the table after adding data. 
""" - LOOP.run( + return LOOP.run( self._table.add( data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) @@ -413,10 +426,12 @@ class RemoteTable(Table): new_data: DATA, on_bad_vectors: str, fill_value: float, - ): - LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) + ) -> MergeResult: + return LOOP.run( + self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) + ) - def delete(self, predicate: str): + def delete(self, predicate: str) -> DeleteResult: """Delete rows from the table. This can be used to delete a single row, many rows, all rows, or @@ -431,6 +446,11 @@ class RemoteTable(Table): The filter must not be empty, or it will error. + Returns + ------- + DeleteResult + An object containing the new version number of the table after deletion. + Examples -------- >>> import lancedb @@ -463,7 +483,7 @@ class RemoteTable(Table): x vector _distance # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ - LOOP.run(self._table.delete(predicate)) + return LOOP.run(self._table.delete(predicate)) def update( self, @@ -471,7 +491,7 @@ class RemoteTable(Table): values: Optional[dict] = None, *, values_sql: Optional[Dict[str, str]] = None, - ): + ) -> UpdateResult: """ This can be used to update zero to all rows depending on how many rows match the where clause. @@ -489,6 +509,12 @@ class RemoteTable(Table): reference existing columns. For example, {"x": "x + 1"} will increment the x column by 1. 
+ Returns + ------- + UpdateResult + - rows_updated: The number of rows that were updated + - version: The new version number of the table after the update + Examples -------- >>> import lancedb @@ -513,7 +539,7 @@ class RemoteTable(Table): 2 2 [10.0, 10.0] # doctest: +SKIP """ - LOOP.run( + return LOOP.run( self._table.update(where=where, updates=values, updates_sql=values_sql) ) @@ -561,13 +587,15 @@ class RemoteTable(Table): def count_rows(self, filter: Optional[str] = None) -> int: return LOOP.run(self._table.count_rows(filter)) - def add_columns(self, transforms: Dict[str, str]): + def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult: return LOOP.run(self._table.add_columns(transforms)) - def alter_columns(self, *alterations: Iterable[Dict[str, str]]): + def alter_columns( + self, *alterations: Iterable[Dict[str, str]] + ) -> AlterColumnsResult: return LOOP.run(self._table.alter_columns(*alterations)) - def drop_columns(self, columns: Iterable[str]): + def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult: return LOOP.run(self._table.drop_columns(columns)) def drop_index(self, index_name: str): diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index df130d3c..91a8fea5 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -78,6 +78,13 @@ if TYPE_CHECKING: CleanupStats, CompactionStats, Tag, + AddColumnsResult, + AddResult, + AlterColumnsResult, + DeleteResult, + DropColumnsResult, + MergeResult, + UpdateResult, ) from .db import LanceDBConnection from .index import IndexConfig @@ -550,6 +557,7 @@ class Table(ABC): Can append new data with [Table.add()][lancedb.table.Table.add]. >>> table.add([{"vector": [0.5, 1.3], "b": 4}]) + AddResult(version=2) Can query the table with [Table.search][lancedb.table.Table.search]. 
@@ -894,7 +902,7 @@ class Table(ABC): mode: AddMode = "append", on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, - ): + ) -> AddResult: """Add more data to the [Table](Table). Parameters @@ -916,6 +924,10 @@ class Table(ABC): fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + Returns + ------- + AddResult + An object containing the new version number of the table after adding data. """ raise NotImplementedError @@ -962,12 +974,12 @@ class Table(ABC): >>> table = db.create_table("my_table", data) >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) >>> # Perform a "upsert" operation - >>> stats = table.merge_insert("a") \\ + >>> res = table.merge_insert("a") \\ ... .when_matched_update_all() \\ ... .when_not_matched_insert_all() \\ ... .execute(new_data) - >>> stats - {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} + >>> res + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -976,7 +988,7 @@ class Table(ABC): 1 2 x 2 3 y 3 4 z - """ + """ # noqa: E501 on = [on] if isinstance(on, str) else list(iter(on)) return LanceMergeInsertBuilder(self, on) @@ -1091,10 +1103,10 @@ class Table(ABC): new_data: DATA, on_bad_vectors: OnBadVectorsType, fill_value: float, - ): ... + ) -> MergeResult: ... @abstractmethod - def delete(self, where: str): + def delete(self, where: str) -> DeleteResult: """Delete rows from the table. This can be used to delete a single row, many rows, all rows, or @@ -1109,6 +1121,11 @@ class Table(ABC): The filter must not be empty, or it will error. + Returns + ------- + DeleteResult + An object containing the new version number of the table after deletion. 
+ Examples -------- >>> import lancedb @@ -1125,6 +1142,7 @@ class Table(ABC): 1 2 [3.0, 4.0] 2 3 [5.0, 6.0] >>> table.delete("x = 2") + DeleteResult(version=2) >>> table.to_pandas() x vector 0 1 [1.0, 2.0] @@ -1138,6 +1156,7 @@ class Table(ABC): >>> to_remove '1, 5' >>> table.delete(f"x IN ({to_remove})") + DeleteResult(version=3) >>> table.to_pandas() x vector 0 3 [5.0, 6.0] @@ -1151,7 +1170,7 @@ class Table(ABC): values: Optional[dict] = None, *, values_sql: Optional[Dict[str, str]] = None, - ): + ) -> UpdateResult: """ This can be used to update zero to all rows depending on how many rows match the where clause. If no where clause is provided, then @@ -1173,6 +1192,12 @@ class Table(ABC): reference existing columns. For example, {"x": "x + 1"} will increment the x column by 1. + Returns + ------- + UpdateResult + - rows_updated: The number of rows that were updated + - version: The new version number of the table after the update + Examples -------- >>> import lancedb @@ -1186,12 +1211,14 @@ class Table(ABC): 1 2 [3.0, 4.0] 2 3 [5.0, 6.0] >>> table.update(where="x = 2", values={"vector": [10.0, 10]}) + UpdateResult(rows_updated=1, version=2) >>> table.to_pandas() x vector 0 1 [1.0, 2.0] 1 3 [5.0, 6.0] 2 2 [10.0, 10.0] >>> table.update(values_sql={"x": "x + 1"}) + UpdateResult(rows_updated=3, version=3) >>> table.to_pandas() x vector 0 2 [1.0, 2.0] @@ -1354,6 +1381,11 @@ class Table(ABC): Alternatively, a pyarrow Field or Schema can be provided to add new columns with the specified data types. The new columns will be initialized with null values. + + Returns + ------- + AddColumnsResult + version: the new version number of the table after adding columns. """ @abstractmethod @@ -1379,10 +1411,15 @@ class Table(ABC): nullability is not changed. Only non-nullable columns can be changed to nullable. Currently, you cannot change a nullable column to non-nullable. 
+ + Returns + ------- + AlterColumnsResult + version: the new version number of the table after the alteration. """ @abstractmethod - def drop_columns(self, columns: Iterable[str]): + def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult: """ Drop columns from the table. @@ -1390,6 +1427,11 @@ class Table(ABC): ---------- columns : Iterable[str] The names of the columns to drop. + + Returns + ------- + DropColumnsResult + version: the new version number of the table dropping the columns. """ @abstractmethod @@ -1611,6 +1653,7 @@ class LanceTable(Table): ... [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table.tags.create("v1", table.version) >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) + AddResult(version=2) >>> tags = table.tags.list() >>> print(tags["v1"]["version"]) 1 @@ -1649,6 +1692,7 @@ class LanceTable(Table): vector type 0 [1.1, 0.9] vector >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) + AddResult(version=2) >>> table.version 2 >>> table.checkout(1) @@ -1691,6 +1735,7 @@ class LanceTable(Table): vector type 0 [1.1, 0.9] vector >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) + AddResult(version=2) >>> table.version 2 >>> table.restore(1) @@ -2055,7 +2100,7 @@ class LanceTable(Table): mode: AddMode = "append", on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, - ): + ) -> AddResult: """Add data to the table. If vector columns are missing and the table has embedding functions, then the vector columns @@ -2079,7 +2124,7 @@ class LanceTable(Table): int The number of vectors in the table. 
""" - LOOP.run( + return LOOP.run( self._table.add( data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) @@ -2409,8 +2454,8 @@ class LanceTable(Table): ) return self - def delete(self, where: str): - LOOP.run(self._table.delete(where)) + def delete(self, where: str) -> DeleteResult: + return LOOP.run(self._table.delete(where)) def update( self, @@ -2418,7 +2463,7 @@ class LanceTable(Table): values: Optional[dict] = None, *, values_sql: Optional[Dict[str, str]] = None, - ): + ) -> UpdateResult: """ This can be used to update zero to all rows depending on how many rows match the where clause. @@ -2436,6 +2481,12 @@ class LanceTable(Table): reference existing columns. For example, {"x": "x + 1"} will increment the x column by 1. + Returns + ------- + UpdateResult + - rows_updated: The number of rows that were updated + - version: The new version number of the table after the update + Examples -------- >>> import lancedb @@ -2449,6 +2500,7 @@ class LanceTable(Table): 1 2 [3.0, 4.0] 2 3 [5.0, 6.0] >>> table.update(where="x = 2", values={"vector": [10.0, 10]}) + UpdateResult(rows_updated=1, version=2) >>> table.to_pandas() x vector 0 1 [1.0, 2.0] @@ -2456,7 +2508,7 @@ class LanceTable(Table): 2 2 [10.0, 10.0] """ - LOOP.run(self._table.update(values, where=where, updates_sql=values_sql)) + return LOOP.run(self._table.update(values, where=where, updates_sql=values_sql)) def _execute_query( self, @@ -2490,7 +2542,7 @@ class LanceTable(Table): new_data: DATA, on_bad_vectors: OnBadVectorsType, fill_value: float, - ): + ) -> MergeResult: return LOOP.run( self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) ) @@ -2635,14 +2687,16 @@ class LanceTable(Table): def add_columns( self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema - ): - LOOP.run(self._table.add_columns(transforms)) + ) -> AddColumnsResult: + return LOOP.run(self._table.add_columns(transforms)) - def alter_columns(self, *alterations: Iterable[Dict[str, str]]): - 
LOOP.run(self._table.alter_columns(*alterations)) + def alter_columns( + self, *alterations: Iterable[Dict[str, str]] + ) -> AlterColumnsResult: + return LOOP.run(self._table.alter_columns(*alterations)) - def drop_columns(self, columns: Iterable[str]): - LOOP.run(self._table.drop_columns(columns)) + def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult: + return LOOP.run(self._table.drop_columns(columns)) def uses_v2_manifest_paths(self) -> bool: """ @@ -3197,7 +3251,7 @@ class AsyncTable: mode: Optional[Literal["append", "overwrite"]] = "append", on_bad_vectors: Optional[OnBadVectorsType] = None, fill_value: Optional[float] = None, - ): + ) -> AddResult: """Add more data to the [Table](Table). Parameters @@ -3236,7 +3290,7 @@ class AsyncTable: if isinstance(data, pa.Table): data = data.to_reader() - await self._inner.add(data, mode or "append") + return await self._inner.add(data, mode or "append") def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """ @@ -3281,12 +3335,12 @@ class AsyncTable: >>> table = db.create_table("my_table", data) >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) >>> # Perform a "upsert" operation - >>> stats = table.merge_insert("a") \\ + >>> res = table.merge_insert("a") \\ ... .when_matched_update_all() \\ ... .when_not_matched_insert_all() \\ ... 
.execute(new_data) - >>> stats - {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} + >>> res + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -3295,7 +3349,7 @@ class AsyncTable: 1 2 x 2 3 y 3 4 z - """ + """ # noqa: E501 on = [on] if isinstance(on, str) else list(iter(on)) return LanceMergeInsertBuilder(self, on) @@ -3626,7 +3680,7 @@ class AsyncTable: new_data: DATA, on_bad_vectors: OnBadVectorsType, fill_value: float, - ): + ) -> MergeResult: schema = await self.schema() if on_bad_vectors is None: on_bad_vectors = "error" @@ -3654,7 +3708,7 @@ class AsyncTable: ), ) - async def delete(self, where: str): + async def delete(self, where: str) -> DeleteResult: """Delete rows from the table. This can be used to delete a single row, many rows, all rows, or @@ -3685,6 +3739,7 @@ class AsyncTable: 1 2 [3.0, 4.0] 2 3 [5.0, 6.0] >>> table.delete("x = 2") + DeleteResult(version=2) >>> table.to_pandas() x vector 0 1 [1.0, 2.0] @@ -3698,6 +3753,7 @@ class AsyncTable: >>> to_remove '1, 5' >>> table.delete(f"x IN ({to_remove})") + DeleteResult(version=3) >>> table.to_pandas() x vector 0 3 [5.0, 6.0] @@ -3710,7 +3766,7 @@ class AsyncTable: *, where: Optional[str] = None, updates_sql: Optional[Dict[str, str]] = None, - ): + ) -> UpdateResult: """ This can be used to update zero to all rows in the table. @@ -3732,6 +3788,13 @@ class AsyncTable: literals (e.g. "7" or "'foo'") or they can be expressions based on the previous value of the row (e.g. 
"x + 1" to increment the x column by 1) + Returns + ------- + UpdateResult + An object containing: + - rows_updated: The number of rows that were updated + - version: The new version number of the table after the update + Examples -------- >>> import asyncio @@ -3760,7 +3823,7 @@ class AsyncTable: async def add_columns( self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema - ): + ) -> AddColumnsResult: """ Add new columns with defined values. @@ -3772,6 +3835,12 @@ class AsyncTable: each row in the table, and can reference existing columns. Alternatively, you can pass a pyarrow field or schema to add new columns with NULLs. + + Returns + ------- + AddColumnsResult + version: the new version number of the table after adding columns. + """ if isinstance(transforms, pa.Field): transforms = [transforms] @@ -3780,11 +3849,13 @@ class AsyncTable: ): transforms = pa.schema(transforms) if isinstance(transforms, pa.Schema): - await self._inner.add_columns_with_schema(transforms) + return await self._inner.add_columns_with_schema(transforms) else: - await self._inner.add_columns(list(transforms.items())) + return await self._inner.add_columns(list(transforms.items())) - async def alter_columns(self, *alterations: Iterable[dict[str, Any]]): + async def alter_columns( + self, *alterations: Iterable[dict[str, Any]] + ) -> AlterColumnsResult: """ Alter column names and nullability. @@ -3804,8 +3875,13 @@ class AsyncTable: nullability is not changed. Only non-nullable columns can be changed to nullable. Currently, you cannot change a nullable column to non-nullable. + + Returns + ------- + AlterColumnsResult + version: the new version number of the table after the alteration. """ - await self._inner.alter_columns(alterations) + return await self._inner.alter_columns(alterations) async def drop_columns(self, columns: Iterable[str]): """ @@ -3816,7 +3892,7 @@ class AsyncTable: columns : Iterable[str] The names of the columns to drop. 
""" - await self._inner.drop_columns(columns) + return await self._inner.drop_columns(columns) async def version(self) -> int: """ diff --git a/python/python/tests/docs/test_merge_insert.py b/python/python/tests/docs/test_merge_insert.py index 72e4ce4d..228faa31 100644 --- a/python/python/tests/docs/test_merge_insert.py +++ b/python/python/tests/docs/test_merge_insert.py @@ -18,19 +18,19 @@ def test_upsert(mem_db): {"id": 1, "name": "Bobby"}, {"id": 2, "name": "Charlie"}, ] - stats = ( + res = ( table.merge_insert("id") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_users) ) table.count_rows() # 3 - stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0} + res # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0} # --8<-- [end:upsert_basic] assert table.count_rows() == 3 - assert stats["num_inserted_rows"] == 1 - assert stats["num_updated_rows"] == 1 - assert stats["num_deleted_rows"] == 0 + assert res.num_inserted_rows == 1 + assert res.num_deleted_rows == 0 + assert res.num_updated_rows == 1 @pytest.mark.asyncio @@ -48,19 +48,22 @@ async def test_upsert_async(mem_db_async): {"id": 1, "name": "Bobby"}, {"id": 2, "name": "Charlie"}, ] - stats = await ( + res = await ( table.merge_insert("id") .when_matched_update_all() .when_not_matched_insert_all() .execute(new_users) ) await table.count_rows() # 3 - stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0} + res + # MergeResult(version=2, num_updated_rows=1, + # num_inserted_rows=1, num_deleted_rows=0) # --8<-- [end:upsert_basic_async] assert await table.count_rows() == 3 - assert stats["num_inserted_rows"] == 1 - assert stats["num_updated_rows"] == 1 - assert stats["num_deleted_rows"] == 0 + assert res.version == 2 + assert res.num_inserted_rows == 1 + assert res.num_deleted_rows == 0 + assert res.num_updated_rows == 1 def test_insert_if_not_exists(mem_db): @@ -77,16 +80,19 @@ def test_insert_if_not_exists(mem_db): {"domain": 
"google.com", "name": "Google"}, {"domain": "facebook.com", "name": "Facebook"}, ] - stats = ( + res = ( table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains) ) table.count_rows() # 3 - stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0} + res + # MergeResult(version=2, num_updated_rows=0, + # num_inserted_rows=1, num_deleted_rows=0) # --8<-- [end:insert_if_not_exists] assert table.count_rows() == 3 - assert stats["num_inserted_rows"] == 1 - assert stats["num_updated_rows"] == 0 - assert stats["num_deleted_rows"] == 0 + assert res.version == 2 + assert res.num_inserted_rows == 1 + assert res.num_deleted_rows == 0 + assert res.num_updated_rows == 0 @pytest.mark.asyncio @@ -104,16 +110,19 @@ async def test_insert_if_not_exists_async(mem_db_async): {"domain": "google.com", "name": "Google"}, {"domain": "facebook.com", "name": "Facebook"}, ] - stats = await ( + res = await ( table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains) ) await table.count_rows() # 3 - stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0} - # --8<-- [end:insert_if_not_exists_async] + res + # MergeResult(version=2, num_updated_rows=0, + # num_inserted_rows=1, num_deleted_rows=0) + # --8<-- [end:insert_if_not_exists] assert await table.count_rows() == 3 - assert stats["num_inserted_rows"] == 1 - assert stats["num_updated_rows"] == 0 - assert stats["num_deleted_rows"] == 0 + assert res.version == 2 + assert res.num_inserted_rows == 1 + assert res.num_deleted_rows == 0 + assert res.num_updated_rows == 0 def test_replace_range(mem_db): @@ -131,7 +140,7 @@ def test_replace_range(mem_db): new_chunks = [ {"doc_id": 1, "chunk_id": 0, "text": "Baz"}, ] - stats = ( + res = ( table.merge_insert(["doc_id", "chunk_id"]) .when_matched_update_all() .when_not_matched_insert_all() @@ -139,12 +148,15 @@ def test_replace_range(mem_db): .execute(new_chunks) ) table.count_rows("doc_id = 1") # 1 - stats # 
{'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
-    # --8<-- [end:replace_range]
+    res
+    # MergeResult(version=2, num_updated_rows=1,
+    #             num_inserted_rows=0, num_deleted_rows=1)
+    # --8<-- [end:replace_range]
     assert table.count_rows("doc_id = 1") == 1
-    assert stats["num_inserted_rows"] == 0
-    assert stats["num_updated_rows"] == 1
-    assert stats["num_deleted_rows"] == 1
+    assert res.version == 2
+    assert res.num_inserted_rows == 0
+    assert res.num_deleted_rows == 1
+    assert res.num_updated_rows == 1
 
 
 @pytest.mark.asyncio
@@ -163,7 +175,7 @@ async def test_replace_range_async(mem_db_async):
     new_chunks = [
         {"doc_id": 1, "chunk_id": 0, "text": "Baz"},
     ]
-    stats = await (
+    res = await (
         table.merge_insert(["doc_id", "chunk_id"])
         .when_matched_update_all()
         .when_not_matched_insert_all()
@@ -171,9 +183,12 @@ async def test_replace_range_async(mem_db_async):
         .execute(new_chunks)
     )
     await table.count_rows("doc_id = 1")  # 1
-    stats  # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
-    # --8<-- [end:replace_range_async]
+    res
+    # MergeResult(version=2, num_updated_rows=1,
+    #             num_inserted_rows=0, num_deleted_rows=1)
+    # --8<-- [end:replace_range_async]
     assert await table.count_rows("doc_id = 1") == 1
-    assert stats["num_inserted_rows"] == 0
-    assert stats["num_updated_rows"] == 1
-    assert stats["num_deleted_rows"] == 1
+    assert res.version == 2
+    assert res.num_inserted_rows == 0
+    assert res.num_deleted_rows == 1
+    assert res.num_updated_rows == 1
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index 594373cf..af412c5e 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -106,15 +106,22 @@ async def test_update_async(mem_db_async: AsyncConnection):
     table = await mem_db_async.create_table("some_table", data=[{"id": 0}])
     assert await table.count_rows("id == 0") == 1
     assert await table.count_rows("id == 7") == 0
-    await table.update({"id": 7})
+    update_res 
= await table.update({"id": 7}) + assert update_res.rows_updated == 1 + assert update_res.version == 2 assert await table.count_rows("id == 7") == 1 assert await table.count_rows("id == 0") == 0 - await table.add([{"id": 2}]) - await table.update(where="id % 2 == 0", updates_sql={"id": "5"}) + add_res = await table.add([{"id": 2}]) + assert add_res.version == 3 + update_res = await table.update(where="id % 2 == 0", updates_sql={"id": "5"}) + assert update_res.rows_updated == 1 + assert update_res.version == 4 assert await table.count_rows("id == 7") == 1 assert await table.count_rows("id == 2") == 0 assert await table.count_rows("id == 5") == 1 - await table.update({"id": 10}, where="id == 5") + update_res = await table.update({"id": 10}, where="id == 5") + assert update_res.rows_updated == 1 + assert update_res.version == 5 assert await table.count_rows("id == 10") == 1 @@ -437,7 +444,8 @@ def test_add_pydantic_model(mem_db: DBConnection): content="foo", meta=Metadata(source="bar", timestamp=datetime.now()) ), ) - tbl.add([expected]) + add_res = tbl.add([expected]) + assert add_res.version == 2 result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0] assert result == expected @@ -459,11 +467,12 @@ async def test_add_async(mem_db_async: AsyncConnection): ], ) assert await table.count_rows() == 2 - await table.add( + add_res = await table.add( data=[ {"vector": [10.0, 11.0], "item": "baz", "price": 30.0}, ], ) + assert add_res.version == 2 assert await table.count_rows() == 3 @@ -795,7 +804,8 @@ def test_delete(mem_db: DBConnection): ) assert len(table) == 2 assert len(table.list_versions()) == 1 - table.delete("id=0") + delete_res = table.delete("id=0") + assert delete_res.version == 2 assert len(table.list_versions()) == 2 assert table.version == 2 assert len(table) == 1 @@ -809,7 +819,9 @@ def test_update(mem_db: DBConnection): ) assert len(table) == 2 assert len(table.list_versions()) == 1 - table.update(where="id=0", values={"vector": [1.1, 1.1]}) + 
update_res = table.update(where="id=0", values={"vector": [1.1, 1.1]}) + assert update_res.version == 2 + assert update_res.rows_updated == 1 assert len(table.list_versions()) == 2 assert table.version == 2 assert len(table) == 2 @@ -898,9 +910,16 @@ def test_merge_insert(mem_db: DBConnection): new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) # upsert - table.merge_insert( - "a" - ).when_matched_update_all().when_not_matched_insert_all().execute(new_data) + merge_insert_res = ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(new_data) + ) + assert merge_insert_res.version == 2 + assert merge_insert_res.num_inserted_rows == 1 + assert merge_insert_res.num_updated_rows == 2 + assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) assert table.to_arrow().sort_by("a") == expected @@ -908,17 +927,28 @@ def test_merge_insert(mem_db: DBConnection): table.restore(version) # conditional update - table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute( - new_data + merge_insert_res = ( + table.merge_insert("a") + .when_matched_update_all(where="target.b = 'b'") + .execute(new_data) ) + assert merge_insert_res.version == 4 + assert merge_insert_res.num_inserted_rows == 0 + assert merge_insert_res.num_updated_rows == 1 + assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]}) assert table.to_arrow().sort_by("a") == expected table.restore(version) # insert-if-not-exists - table.merge_insert("a").when_not_matched_insert_all().execute(new_data) - + merge_insert_res = ( + table.merge_insert("a").when_not_matched_insert_all().execute(new_data) + ) + assert merge_insert_res.version == 6 + assert merge_insert_res.num_inserted_rows == 1 + assert merge_insert_res.num_updated_rows == 0 + assert merge_insert_res.num_deleted_rows == 0 expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", 
"z"]}) assert table.to_arrow().sort_by("a") == expected @@ -927,13 +957,17 @@ def test_merge_insert(mem_db: DBConnection): new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) # replace-range - ( + merge_insert_res = ( table.merge_insert("a") .when_matched_update_all() .when_not_matched_insert_all() .when_not_matched_by_source_delete("a > 2") .execute(new_data) ) + assert merge_insert_res.version == 8 + assert merge_insert_res.num_inserted_rows == 1 + assert merge_insert_res.num_updated_rows == 1 + assert merge_insert_res.num_deleted_rows == 1 expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert table.to_arrow().sort_by("a") == expected @@ -941,11 +975,17 @@ def test_merge_insert(mem_db: DBConnection): table.restore(version) # replace-range no condition - table.merge_insert( - "a" - ).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute( - new_data + merge_insert_res = ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete() + .execute(new_data) ) + assert merge_insert_res.version == 10 + assert merge_insert_res.num_inserted_rows == 1 + assert merge_insert_res.num_updated_rows == 1 + assert merge_insert_res.num_deleted_rows == 2 expected = pa.table({"a": [2, 4], "b": ["x", "z"]}) assert table.to_arrow().sort_by("a") == expected @@ -1478,11 +1518,13 @@ def test_restore_consistency(tmp_path): def test_add_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) table = LanceTable.create(mem_db, "my_table", data=data) - table.add_columns({"new_col": "id + 2"}) + add_columns_res = table.add_columns({"new_col": "id + 2"}) + assert add_columns_res.version == 2 assert table.to_arrow().column_names == ["id", "new_col"] assert table.to_arrow()["new_col"].to_pylist() == [2, 3] - table.add_columns({"null_int": "cast(null as bigint)"}) + add_columns_res = table.add_columns({"null_int": "cast(null as bigint)"}) + assert 
add_columns_res.version == 3 assert table.schema.field("null_int").type == pa.int64() @@ -1490,7 +1532,8 @@ def test_add_columns(mem_db: DBConnection): async def test_add_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) - await table.add_columns({"new_col": "id + 2"}) + add_columns_res = await table.add_columns({"new_col": "id + 2"}) + assert add_columns_res.version == 2 data = await table.to_arrow() assert data.column_names == ["id", "new_col"] assert data["new_col"].to_pylist() == [2, 3] @@ -1500,9 +1543,10 @@ async def test_add_columns_async(mem_db_async: AsyncConnection): async def test_add_columns_with_schema(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) - await table.add_columns( + add_columns_res = await table.add_columns( [pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))] ) + assert add_columns_res.version == 2 assert await table.schema() == pa.schema( [ @@ -1513,11 +1557,12 @@ async def test_add_columns_with_schema(mem_db_async: AsyncConnection): ) table = await mem_db_async.create_table("table2", data=data) - await table.add_columns( + add_columns_res = await table.add_columns( pa.schema( [pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))] ) ) + assert add_columns_res.version == 2 assert await table.schema() == pa.schema( [ pa.field("id", pa.int64()), @@ -1530,7 +1575,8 @@ async def test_add_columns_with_schema(mem_db_async: AsyncConnection): def test_alter_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) table = mem_db.create_table("my_table", data=data) - table.alter_columns({"path": "id", "rename": "new_id"}) + alter_columns_res = table.alter_columns({"path": "id", "rename": "new_id"}) + assert alter_columns_res.version == 2 assert table.to_arrow().column_names == ["new_id"] @@ -1538,9 +1584,13 @@ def 
test_alter_columns(mem_db: DBConnection): async def test_alter_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) table = await mem_db_async.create_table("my_table", data=data) - await table.alter_columns({"path": "id", "rename": "new_id"}) + alter_columns_res = await table.alter_columns({"path": "id", "rename": "new_id"}) + assert alter_columns_res.version == 2 assert (await table.to_arrow()).column_names == ["new_id"] - await table.alter_columns(dict(path="new_id", data_type=pa.int16(), nullable=True)) + alter_columns_res = await table.alter_columns( + dict(path="new_id", data_type=pa.int16(), nullable=True) + ) + assert alter_columns_res.version == 3 data = await table.to_arrow() assert data.column(0).type == pa.int16() assert data.schema.field(0).nullable @@ -1549,7 +1599,8 @@ async def test_alter_columns_async(mem_db_async: AsyncConnection): def test_drop_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) table = mem_db.create_table("my_table", data=data) - table.drop_columns(["category"]) + drop_columns_res = table.drop_columns(["category"]) + assert drop_columns_res.version == 2 assert table.to_arrow().column_names == ["id"] @@ -1557,7 +1608,8 @@ def test_drop_columns(mem_db: DBConnection): async def test_drop_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) table = await mem_db_async.create_table("my_table", data=data) - await table.drop_columns(["category"]) + drop_columns_res = await table.drop_columns(["category"]) + assert drop_columns_res.version == 2 assert (await table.to_arrow()).column_names == ["id"] diff --git a/python/src/lib.rs b/python/src/lib.rs index a1ed0d08..bd1b8d02 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -11,7 +11,10 @@ use pyo3::{ wrap_pyfunction, Bound, PyResult, Python, }; use query::{FTSQuery, HybridQuery, Query, VectorQuery}; -use table::Table; +use table::{ + AddColumnsResult, AddResult, 
AlterColumnsResult, DeleteResult, DropColumnsResult, MergeResult, + Table, UpdateResult, +}; pub mod arrow; pub mod connection; @@ -35,6 +38,13 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/python/src/table.rs b/python/src/table.rs index c8073e05..820b22bd 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -58,6 +58,170 @@ pub struct OptimizeStats { pub prune: RemovalStats, } +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct UpdateResult { + pub rows_updated: u64, + pub version: u64, +} + +#[pymethods] +impl UpdateResult { + pub fn __repr__(&self) -> String { + format!( + "UpdateResult(rows_updated={}, version={})", + self.rows_updated, self.version + ) + } +} + +impl From for UpdateResult { + fn from(result: lancedb::table::UpdateResult) -> Self { + Self { + rows_updated: result.rows_updated, + version: result.version, + } + } +} + +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct AddResult { + pub version: u64, +} + +#[pymethods] +impl AddResult { + pub fn __repr__(&self) -> String { + format!("AddResult(version={})", self.version) + } +} + +impl From for AddResult { + fn from(result: lancedb::table::AddResult) -> Self { + Self { + version: result.version, + } + } +} + +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct DeleteResult { + pub version: u64, +} + +#[pymethods] +impl DeleteResult { + pub fn __repr__(&self) -> String { + format!("DeleteResult(version={})", self.version) + } +} + +impl From for DeleteResult { + fn from(result: lancedb::table::DeleteResult) -> Self { + Self { + version: result.version, + } + } +} + 
+#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct MergeResult { + pub version: u64, + pub num_updated_rows: u64, + pub num_inserted_rows: u64, + pub num_deleted_rows: u64, +} + +#[pymethods] +impl MergeResult { + pub fn __repr__(&self) -> String { + format!( + "MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={})", + self.version, + self.num_updated_rows, + self.num_inserted_rows, + self.num_deleted_rows + ) + } +} + +impl From for MergeResult { + fn from(result: lancedb::table::MergeResult) -> Self { + Self { + version: result.version, + num_updated_rows: result.num_updated_rows, + num_inserted_rows: result.num_inserted_rows, + num_deleted_rows: result.num_deleted_rows, + } + } +} + +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct AddColumnsResult { + pub version: u64, +} + +#[pymethods] +impl AddColumnsResult { + pub fn __repr__(&self) -> String { + format!("AddColumnsResult(version={})", self.version) + } +} + +impl From for AddColumnsResult { + fn from(result: lancedb::table::AddColumnsResult) -> Self { + Self { + version: result.version, + } + } +} + +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct AlterColumnsResult { + pub version: u64, +} + +#[pymethods] +impl AlterColumnsResult { + pub fn __repr__(&self) -> String { + format!("AlterColumnsResult(version={})", self.version) + } +} + +impl From for AlterColumnsResult { + fn from(result: lancedb::table::AlterColumnsResult) -> Self { + Self { + version: result.version, + } + } +} + +#[pyclass(get_all)] +#[derive(Clone, Debug)] +pub struct DropColumnsResult { + pub version: u64, +} + +#[pymethods] +impl DropColumnsResult { + pub fn __repr__(&self) -> String { + format!("DropColumnsResult(version={})", self.version) + } +} + +impl From for DropColumnsResult { + fn from(result: lancedb::table::DropColumnsResult) -> Self { + Self { + version: result.version, + } + } +} + #[pyclass] pub struct Table { // We keep a copy of the name to use if the inner 
table is dropped @@ -132,15 +296,16 @@ impl Table { } future_into_py(self_.py(), async move { - op.execute().await.infer_error()?; - Ok(()) + let result = op.execute().await.infer_error()?; + Ok(AddResult::from(result)) }) } pub fn delete(self_: PyRef<'_, Self>, condition: String) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { - inner.delete(&condition).await.infer_error() + let result = inner.delete(&condition).await.infer_error()?; + Ok(DeleteResult::from(result)) }) } @@ -160,8 +325,8 @@ impl Table { op = op.column(column_name, value); } future_into_py(self_.py(), async move { - op.execute().await.infer_error()?; - Ok(()) + let result = op.execute().await.infer_error()?; + Ok(UpdateResult::from(result)) }) } @@ -489,14 +654,8 @@ impl Table { } future_into_py(self_.py(), async move { - let stats = builder.execute(Box::new(batches)).await.infer_error()?; - Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item("num_inserted_rows", stats.num_inserted_rows)?; - dict.set_item("num_updated_rows", stats.num_updated_rows)?; - dict.set_item("num_deleted_rows", stats.num_deleted_rows)?; - Ok(dict.unbind()) - }) + let res = builder.execute(Box::new(batches)).await.infer_error()?; + Ok(MergeResult::from(res)) }) } @@ -532,8 +691,8 @@ impl Table { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { - inner.add_columns(definitions, None).await.infer_error()?; - Ok(()) + let result = inner.add_columns(definitions, None).await.infer_error()?; + Ok(AddColumnsResult::from(result)) }) } @@ -546,8 +705,8 @@ impl Table { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { - inner.add_columns(transform, None).await.infer_error()?; - Ok(()) + let result = inner.add_columns(transform, None).await.infer_error()?; + Ok(AddColumnsResult::from(result)) }) } @@ -590,8 +749,8 @@ impl Table { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { - 
inner.alter_columns(&alterations).await.infer_error()?; - Ok(()) + let result = inner.alter_columns(&alterations).await.infer_error()?; + Ok(AlterColumnsResult::from(result)) }) } @@ -599,8 +758,8 @@ impl Table { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { let column_refs = columns.iter().map(String::as_str).collect::>(); - inner.drop_columns(&column_refs).await.infer_error()?; - Ok(()) + let result = inner.drop_columns(&column_refs).await.infer_error()?; + Ok(DropColumnsResult::from(result)) }) } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 2e089cdd..11344964 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -4,7 +4,14 @@ use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; +use crate::table::AddColumnsResult; +use crate::table::AddResult; +use crate::table::AlterColumnsResult; +use crate::table::DeleteResult; +use crate::table::DropColumnsResult; +use crate::table::MergeResult; use crate::table::Tags; +use crate::table::UpdateResult; use crate::table::{AddDataMode, AnyQuery, Filter, TableStatistics}; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{DistanceType, Error, Table}; @@ -47,7 +54,6 @@ use crate::{ TableDefinition, UpdateBuilder, }, }; -use lance::dataset::MergeStats; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); @@ -735,7 +741,7 @@ impl BaseTable for RemoteTable { &self, add: AddDataBuilder, data: Box, - ) -> Result<()> { + ) -> Result { self.check_mutable().await?; let mut request = self .client @@ -750,9 +756,21 @@ impl BaseTable for RemoteTable { } let (request_id, response) = self.send_streaming(request, data, true).await?; - self.check_table_response(&request_id, response).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = 
response.text().await.err_to_http(request_id.clone())?; - Ok(()) + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(AddResult { version }); + } + + let add_response: AddResult = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse add response: {}", e).into(), + request_id, + status_code: None, + })?; + Ok(add_response) } async fn create_plan( @@ -885,7 +903,7 @@ impl BaseTable for RemoteTable { Ok(final_analyze) } - async fn update(&self, update: UpdateBuilder) -> Result { + async fn update(&self, update: UpdateBuilder) -> Result { self.check_mutable().await?; let request = self .client @@ -902,13 +920,29 @@ impl BaseTable for RemoteTable { })); let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; - self.check_table_response(&request_id, response).await?; + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(UpdateResult { + rows_updated: 0, + version, + }); + } - Ok(0) // TODO: support returning number of modified rows once supported in SaaS. 
+ let update_response: UpdateResult = + serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse update response: {}", e).into(), + request_id, + status_code: None, + })?; + + Ok(update_response) } - async fn delete(&self, predicate: &str) -> Result<()> { + async fn delete(&self, predicate: &str) -> Result { self.check_mutable().await?; let body = serde_json::json!({ "predicate": predicate }); let request = self @@ -916,8 +950,22 @@ impl BaseTable for RemoteTable { .post(&format!("/v1/table/{}/delete/", self.name)) .json(&body); let (request_id, response) = self.send(request, true).await?; - self.check_table_response(&request_id, response).await?; - Ok(()) + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + if body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(DeleteResult { version }); + } + + let delete_response: DeleteResult = + serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse delete response: {}", e).into(), + request_id, + status_code: None, + })?; + Ok(delete_response) } async fn create_index(&self, mut index: IndexBuilder) -> Result<()> { @@ -1023,7 +1071,7 @@ impl BaseTable for RemoteTable { &self, params: MergeInsertBuilder, new_data: Box, - ) -> Result { + ) -> Result { self.check_mutable().await?; let query = MergeInsertRequest::try_from(params)?; @@ -1035,11 +1083,28 @@ impl BaseTable for RemoteTable { let (request_id, response) = self.send_streaming(request, new_data, true).await?; - // TODO: server can response with these stats in response body. - // We should test that we can handle both empty response from old server - // and response with stats from new server. 
- self.check_table_response(&request_id, response).await?; - Ok(MergeStats::default()) + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(MergeResult { + version, + num_deleted_rows: 0, + num_inserted_rows: 0, + num_updated_rows: 0, + }); + } + + let merge_insert_response: MergeResult = + serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse merge_insert response: {}", e).into(), + request_id, + status_code: None, + })?; + + Ok(merge_insert_response) } async fn tags(&self) -> Result> { @@ -1062,7 +1127,7 @@ impl BaseTable for RemoteTable { &self, transforms: NewColumnTransform, _read_columns: Option>, - ) -> Result<()> { + ) -> Result { self.check_mutable().await?; match transforms { NewColumnTransform::SqlExpressions(expressions) => { @@ -1080,9 +1145,24 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/add_columns/", self.name)) .json(&body); - let (request_id, response) = self.send(request, true).await?; // todo: - self.check_table_response(&request_id, response).await?; - Ok(()) + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(AddColumnsResult { version }); + } + + let result: AddColumnsResult = + serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse add_columns response: {}", e).into(), + request_id, + status_code: None, + })?; + + Ok(result) } _ => { return Err(Error::NotSupported { @@ -1092,7 +1172,7 @@ impl BaseTable for RemoteTable { } } - 
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()> { + async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result { self.check_mutable().await?; let body = alterations .iter() @@ -1120,11 +1200,25 @@ impl BaseTable for RemoteTable { .post(&format!("/v1/table/{}/alter_columns/", self.name)) .json(&body); let (request_id, response) = self.send(request, true).await?; - self.check_table_response(&request_id, response).await?; - Ok(()) + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(AlterColumnsResult { version }); + } + + let result: AlterColumnsResult = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse alter_columns response: {}", e).into(), + request_id, + status_code: None, + })?; + + Ok(result) } - async fn drop_columns(&self, columns: &[&str]) -> Result<()> { + async fn drop_columns(&self, columns: &[&str]) -> Result { self.check_mutable().await?; let body = serde_json::json!({ "columns": columns }); let request = self @@ -1132,8 +1226,22 @@ impl BaseTable for RemoteTable { .post(&format!("/v1/table/{}/drop_columns/", self.name)) .json(&body); let (request_id, response) = self.send(request, true).await?; - self.check_table_response(&request_id, response).await?; - Ok(()) + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + if body.trim().is_empty() || body == "{}" { + // Backward compatible with old servers + let version = self.version().await?; + return Ok(DropColumnsResult { version }); + } + + let result: DropColumnsResult = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse drop_columns response: {}", e).into(), + 
request_id, + status_code: None, + })?; + + Ok(result) } async fn list_indices(&self) -> Result> { @@ -1357,16 +1465,20 @@ mod tests { .execute(example_data()) .map_ok(|_| ()), ), - Box::pin(table.delete("false")), - Box::pin(table.add_columns( - NewColumnTransform::SqlExpressions(vec![("x".into(), "y".into())]), - None, - )), + Box::pin(table.delete("false").map_ok(|_| ())), + Box::pin( + table + .add_columns( + NewColumnTransform::SqlExpressions(vec![("x".into(), "y".into())]), + None, + ) + .map_ok(|_| ()), + ), Box::pin(async { let alterations = vec![ColumnAlteration::new("x".into()).rename("y".into())]; - table.alter_columns(&alterations).await + table.alter_columns(&alterations).await.map(|_| ()) }), - Box::pin(table.drop_columns(&["a"])), + Box::pin(table.drop_columns(&["a"]).map_ok(|_| ())), // TODO: other endpoints. ]; @@ -1497,6 +1609,60 @@ mod tests { body } + #[tokio::test] + async fn test_add_append_old_server() { + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let (sender, receiver) = std::sync::mpsc::channel(); + let table = Table::new_with_handler("my_table", move |mut request| { + if request.url().path() == "/v1/table/my_table/insert/" { + assert_eq!(request.method(), "POST"); + assert!(request + .url() + .query_pairs() + .filter(|(k, _)| k == "mode") + .all(|(_, v)| v == "append")); + + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + + let mut body_out = reqwest::Body::from(Vec::new()); + std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + + // Return empty JSON object for old server behavior + http::Response::builder().status(200).body("").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + // Handle describe call for backward compatibility + http::Response::builder() + .status(200) + 
.body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 42); + + let body = receiver.recv().unwrap(); + let body = collect_body(body).await; + let expected_body = write_ipc_stream(&data); + assert_eq!(&body, &expected_body); + } + #[tokio::test] async fn test_add_append() { let data = RecordBatch::try_new( @@ -1526,15 +1692,80 @@ mod tests { std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); sender.send(body_out).unwrap(); - http::Response::builder().status(200).body("").unwrap() + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() }); - table + let result = table .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) .execute() .await .unwrap(); + assert_eq!(result.version, 43); + + let body = receiver.recv().unwrap(); + let body = collect_body(body).await; + let expected_body = write_ipc_stream(&data); + assert_eq!(&body, &expected_body); + } + + #[tokio::test] + async fn test_add_overwrite_old_server() { + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let (sender, receiver) = std::sync::mpsc::channel(); + let table = Table::new_with_handler("my_table", move |mut request| { + if request.url().path() == "/v1/table/my_table/insert/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request + .url() + .query_pairs() + .find(|(k, _)| k == "mode") + .map(|kv| kv.1) + .as_deref(), + Some("overwrite"), + "Expected mode=overwrite" + ); + + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + + let mut body_out = reqwest::Body::from(Vec::new()); + 
std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + + // Return empty JSON object for old server behavior + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + // Handle describe call for backward compatibility + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .mode(AddDataMode::Overwrite) + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 42); + let body = receiver.recv().unwrap(); let body = collect_body(body).await; let expected_body = write_ipc_stream(&data); @@ -1573,22 +1804,83 @@ mod tests { std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); sender.send(body_out).unwrap(); - http::Response::builder().status(200).body("").unwrap() + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() }); - table + let result = table .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) .mode(AddDataMode::Overwrite) .execute() .await .unwrap(); + assert_eq!(result.version, 43); + let body = receiver.recv().unwrap(); let body = collect_body(body).await; let expected_body = write_ipc_stream(&data); assert_eq!(&body, &expected_body); } + #[tokio::test] + async fn test_update_old_server() { + let table = Table::new_with_handler("my_table", |request| { + if request.url().path() == "/v1/table/my_table/update/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + if let Some(body) = request.body().unwrap().as_bytes() { + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let updates = 
value.get("updates").unwrap().as_array().unwrap(); + assert!(updates.len() == 2); + + let col_name = updates[0][0].as_str().unwrap(); + let expression = updates[0][1].as_str().unwrap(); + assert_eq!(col_name, "a"); + assert_eq!(expression, "a + 1"); + + let col_name = updates[1][0].as_str().unwrap(); + let expression = updates[1][1].as_str().unwrap(); + assert_eq!(col_name, "b"); + assert_eq!(expression, "b - 1"); + + let only_if = value.get("predicate").unwrap().as_str().unwrap(); + assert_eq!(only_if, "b > 10"); + } + + // Return empty JSON object (old server behavior) + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + // Handle the describe request for version lookup + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .update() + .column("a", "a + 1") + .column("b", "b - 1") + .only_if("b > 10") + .execute() + .await + .unwrap(); + + assert_eq!(result.version, 42); + assert_eq!(result.rows_updated, 0); + } + #[tokio::test] async fn test_update() { let table = Table::new_with_handler("my_table", |request| { @@ -1619,10 +1911,14 @@ mod tests { assert_eq!(only_if, "b > 10"); } - http::Response::builder().status(200).body("{}").unwrap() + // Return structured response (new server behavior) + http::Response::builder() + .status(200) + .body(r#"{"rows_updated": 5, "version": 43}"#) + .unwrap() }); - table + let result = table .update() .column("a", "a + 1") .column("b", "b - 1") @@ -1630,6 +1926,157 @@ mod tests { .execute() .await .unwrap(); + + // Verify result for new behavior + assert_eq!(result.rows_updated, 5); // From structured response + assert_eq!(result.version, 43); // From structured response + } + + #[tokio::test] + async fn test_alter_columns_old_server() { + let table = Table::new_with_handler("my_table", 
|request| { + if request.url().path() == "/v1/table/my_table/alter_columns/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let alterations = value.get("alterations").unwrap().as_array().unwrap(); + assert!(alterations.len() == 2); + + let path = alterations[0]["path"].as_str().unwrap(); + let data_type = alterations[0]["data_type"]["type"].as_str().unwrap(); + assert_eq!(path, "b.c"); + assert_eq!(data_type, "int32"); + + let path = alterations[1]["path"].as_str().unwrap(); + let nullable = alterations[1]["nullable"].as_bool().unwrap(); + let rename = alterations[1]["rename"].as_str().unwrap(); + assert_eq!(path, "x"); + assert!(nullable); + assert_eq!(rename, "y"); + + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .alter_columns(&[ + ColumnAlteration::new("b.c".into()).cast_to(DataType::Int32), + ColumnAlteration::new("x".into()) + .rename("y".into()) + .set_nullable(true), + ]) + .await + .unwrap(); + + assert_eq!(result.version, 42); + } + + #[tokio::test] + async fn test_alter_columns() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/alter_columns/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = 
serde_json::from_str(body).unwrap(); + let alterations = value.get("alterations").unwrap().as_array().unwrap(); + assert!(alterations.len() == 2); + + let path = alterations[0]["path"].as_str().unwrap(); + let data_type = alterations[0]["data_type"]["type"].as_str().unwrap(); + assert_eq!(path, "b.c"); + assert_eq!(data_type, "int32"); + + let path = alterations[1]["path"].as_str().unwrap(); + let nullable = alterations[1]["nullable"].as_bool().unwrap(); + let rename = alterations[1]["rename"].as_str().unwrap(); + assert_eq!(path, "x"); + assert!(nullable); + assert_eq!(rename, "y"); + + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() + }); + + let result = table + .alter_columns(&[ + ColumnAlteration::new("b.c".into()).cast_to(DataType::Int32), + ColumnAlteration::new("x".into()) + .rename("y".into()) + .set_nullable(true), + ]) + .await + .unwrap(); + + assert_eq!(result.version, 43); + } + + #[tokio::test] + async fn test_merge_insert_old_server() { + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data = Box::new(RecordBatchIterator::new( + [Ok(batch.clone())], + batch.schema(), + )); + + // Default parameters + let table = Table::new_with_handler("my_table", |request| { + if request.url().path() == "/v1/table/my_table/merge_insert/" { + assert_eq!(request.method(), "POST"); + + let params = request.url().query_pairs().collect::>(); + assert_eq!(params["on"], "some_col"); + assert_eq!(params["when_matched_update_all"], "false"); + assert_eq!(params["when_not_matched_insert_all"], "false"); + assert_eq!(params["when_not_matched_by_source_delete"], "false"); + assert!(!params.contains_key("when_matched_update_all_filt")); + assert!(!params.contains_key("when_not_matched_by_source_delete_filt")); + + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == 
"/v1/table/my_table/describe/" { + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .merge_insert(&["some_col"]) + .execute(data) + .await + .unwrap(); + + assert_eq!(result.version, 42); + assert_eq!(result.num_deleted_rows, 0); + assert_eq!(result.num_inserted_rows, 0); + assert_eq!(result.num_updated_rows, 0); } #[tokio::test] @@ -1644,7 +2091,7 @@ mod tests { batch.schema(), )); - // Default parameters + // Default parameters with new server behavior let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/"); @@ -1657,53 +2104,22 @@ mod tests { assert!(!params.contains_key("when_matched_update_all_filt")); assert!(!params.contains_key("when_not_matched_by_source_delete_filt")); - http::Response::builder().status(200).body("").unwrap() + http::Response::builder() + .status(200) + .body(r#"{"version": 43, "num_deleted_rows": 0, "num_inserted_rows": 3, "num_updated_rows": 0}"#) + .unwrap() }); - table + let result = table .merge_insert(&["some_col"]) .execute(data) .await .unwrap(); - // All parameters specified - let (sender, receiver) = std::sync::mpsc::channel(); - let table = Table::new_with_handler("my_table", move |mut request| { - assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/"); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - ARROW_STREAM_CONTENT_TYPE - ); - - let params = request.url().query_pairs().collect::>(); - assert_eq!(params["on"], "some_col"); - assert_eq!(params["when_matched_update_all"], "true"); - assert_eq!(params["when_not_matched_insert_all"], "false"); - assert_eq!(params["when_not_matched_by_source_delete"], "true"); - assert_eq!(params["when_matched_update_all_filt"], "a = 1"); - 
assert_eq!(params["when_not_matched_by_source_delete_filt"], "b = 2"); - - let mut body_out = reqwest::Body::from(Vec::new()); - std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); - sender.send(body_out).unwrap(); - - http::Response::builder().status(200).body("").unwrap() - }); - let mut builder = table.merge_insert(&["some_col"]); - builder - .when_matched_update_all(Some("a = 1".into())) - .when_not_matched_by_source_delete(Some("b = 2".into())); - let data = Box::new(RecordBatchIterator::new( - [Ok(batch.clone())], - batch.schema(), - )); - builder.execute(data).await.unwrap(); - - let body = receiver.recv().unwrap(); - let body = collect_body(body).await; - let expected_body = write_ipc_stream(&batch); - assert_eq!(&body, &expected_body); + assert_eq!(result.version, 43); + assert_eq!(result.num_deleted_rows, 0); + assert_eq!(result.num_inserted_rows, 3); + assert_eq!(result.num_updated_rows, 0); } #[tokio::test] @@ -1742,6 +2158,36 @@ mod tests { assert!(e.to_string().contains("Hit retry limit")); } + #[tokio::test] + async fn test_delete_old_server() { + let table = Table::new_with_handler("my_table", |request| { + if request.url().path() == "/v1/table/my_table/delete/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let predicate = body.get("predicate").unwrap().as_str().unwrap(); + assert_eq!(predicate, "id in (1, 2, 3)"); + + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table.delete("id in (1, 2, 3)").await.unwrap(); + 
assert_eq!(result.version, 42); + } + #[tokio::test] async fn test_delete() { let table = Table::new_with_handler("my_table", |request| { @@ -1757,12 +2203,82 @@ mod tests { let predicate = body.get("predicate").unwrap().as_str().unwrap(); assert_eq!(predicate, "id in (1, 2, 3)"); - http::Response::builder().status(200).body("").unwrap() + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() }); - table.delete("id in (1, 2, 3)").await.unwrap(); + let result = table.delete("id in (1, 2, 3)").await.unwrap(); + assert_eq!(result.version, 43); } + #[tokio::test] + async fn test_drop_columns_old_server() { + let table = Table::new_with_handler("my_table", |request| { + if request.url().path() == "/v1/table/my_table/drop_columns/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let columns = value.get("columns").unwrap().as_array().unwrap(); + assert!(columns.len() == 2); + + let col1 = columns[0].as_str().unwrap(); + let col2 = columns[1].as_str().unwrap(); + assert_eq!(col1, "a"); + assert_eq!(col2, "b"); + + http::Response::builder().status(200).body("{}").unwrap() + } else if request.url().path() == "/v1/table/my_table/describe/" { + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table.drop_columns(&["a", "b"]).await.unwrap(); + assert_eq!(result.version, 42); + } + + #[tokio::test] + async fn test_drop_columns() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/drop_columns/"); + assert_eq!( + 
request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let columns = value.get("columns").unwrap().as_array().unwrap(); + assert!(columns.len() == 2); + + let col1 = columns[0].as_str().unwrap(); + let col2 = columns[1].as_str().unwrap(); + assert_eq!(col1, "a"); + assert_eq!(col2, "b"); + + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() + }); + + let result = table.drop_columns(&["a", "b"]).await.unwrap(); + assert_eq!(result.version, 43); + } #[tokio::test] async fn test_query_plain() { let expected_data = RecordBatch::try_new( @@ -2577,6 +3093,59 @@ mod tests { assert!(matches!(res, Err(Error::NotSupported { .. }))); } + #[tokio::test] + async fn test_add_columns_old_server() { + let table = Table::new_with_handler("my_table", |request| { + if request.url().path() == "/v1/table/my_table/add_columns/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body = std::str::from_utf8(body).unwrap(); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let new_columns = value.get("new_columns").unwrap().as_array().unwrap(); + assert!(new_columns.len() == 2); + + let col_name = new_columns[0]["name"].as_str().unwrap(); + let expression = new_columns[0]["expression"].as_str().unwrap(); + assert_eq!(col_name, "b"); + assert_eq!(expression, "a + 1"); + + let col_name = new_columns[1]["name"].as_str().unwrap(); + let expression = new_columns[1]["expression"].as_str().unwrap(); + assert_eq!(col_name, "x"); + assert_eq!(expression, "cast(NULL as int32)"); + + // Return empty JSON object for old server behavior + http::Response::builder().status(200).body("{}").unwrap() + } else if 
request.url().path() == "/v1/table/my_table/describe/" { + // Handle describe call for backward compatibility + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let result = table + .add_columns( + NewColumnTransform::SqlExpressions(vec![ + ("b".into(), "a + 1".into()), + ("x".into(), "cast(NULL as int32)".into()), + ]), + None, + ) + .await + .unwrap(); + + assert_eq!(result.version, 42); + } + #[tokio::test] async fn test_add_columns() { let table = Table::new_with_handler("my_table", |request| { @@ -2603,10 +3172,13 @@ mod tests { assert_eq!(col_name, "x"); assert_eq!(expression, "cast(NULL as int32)"); - http::Response::builder().status(200).body("{}").unwrap() + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#) + .unwrap() }); - table + let result = table .add_columns( NewColumnTransform::SqlExpressions(vec![ ("b".into(), "a + 1".into()), @@ -2616,75 +3188,8 @@ mod tests { ) .await .unwrap(); - } - #[tokio::test] - async fn test_alter_columns() { - let table = Table::new_with_handler("my_table", |request| { - assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/v1/table/my_table/alter_columns/"); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - JSON_CONTENT_TYPE - ); - - let body = request.body().unwrap().as_bytes().unwrap(); - let body = std::str::from_utf8(body).unwrap(); - let value: serde_json::Value = serde_json::from_str(body).unwrap(); - let alterations = value.get("alterations").unwrap().as_array().unwrap(); - assert!(alterations.len() == 2); - - let path = alterations[0]["path"].as_str().unwrap(); - let data_type = alterations[0]["data_type"]["type"].as_str().unwrap(); - assert_eq!(path, "b.c"); - assert_eq!(data_type, "int32"); - - let path = alterations[1]["path"].as_str().unwrap(); - let nullable = 
alterations[1]["nullable"].as_bool().unwrap(); - let rename = alterations[1]["rename"].as_str().unwrap(); - assert_eq!(path, "x"); - assert!(nullable); - assert_eq!(rename, "y"); - - http::Response::builder().status(200).body("{}").unwrap() - }); - - table - .alter_columns(&[ - ColumnAlteration::new("b.c".into()).cast_to(DataType::Int32), - ColumnAlteration::new("x".into()) - .rename("y".into()) - .set_nullable(true), - ]) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_drop_columns() { - let table = Table::new_with_handler("my_table", |request| { - assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/v1/table/my_table/drop_columns/"); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - JSON_CONTENT_TYPE - ); - - let body = request.body().unwrap().as_bytes().unwrap(); - let body = std::str::from_utf8(body).unwrap(); - let value: serde_json::Value = serde_json::from_str(body).unwrap(); - let columns = value.get("columns").unwrap().as_array().unwrap(); - assert!(columns.len() == 2); - - let col1 = columns[0].as_str().unwrap(); - let col2 = columns[1].as_str().unwrap(); - assert_eq!(col1, "a"); - assert_eq!(col2, "b"); - - http::Response::builder().status(200).body("{}").unwrap() - }); - - table.drop_columns(&["a", "b"]).await.unwrap(); + assert_eq!(result.version, 43); } #[tokio::test] diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index bb404e1c..704c48f1 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -20,7 +20,6 @@ use lance::dataset::cleanup::RemovalStats; use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions}; use lance::dataset::scanner::Scanner; pub use lance::dataset::ColumnAlteration; -pub use lance::dataset::MergeStats; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; pub use lance::dataset::Version; @@ -312,7 +311,7 @@ impl AddDataBuilder { self } - pub async fn execute(self) -> Result<()> { + pub 
async fn execute(self) -> Result<AddResult> { let parent = self.parent.clone(); let data = self.data.into_arrow()?; let without_data = AddDataBuilder::<NoData> { @@ -380,8 +379,8 @@ impl UpdateBuilder { } /// Executes the update operation. - /// Returns the number of rows that were updated. - pub async fn execute(self) -> Result<u64> { + /// Returns the update result + pub async fn execute(self) -> Result<UpdateResult> { if self.columns.is_empty() { Err(Error::InvalidInput { message: "at least one column must be specified in an update operation".to_string(), @@ -424,6 +423,50 @@ pub trait Tags: Send + Sync { async fn update(&mut self, tag: &str, version: u64) -> Result<()>; } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct UpdateResult { + pub rows_updated: u64, + pub version: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AddResult { + pub version: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DeleteResult { + pub version: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct MergeResult { + pub version: u64, + /// Number of inserted rows (for user statistics) + pub num_inserted_rows: u64, + /// Number of updated rows (for user statistics) + pub num_updated_rows: u64, + /// Number of deleted rows (for user statistics) + /// Note: This is different from internal references to 'deleted_rows', since we technically "delete" updated rows during processing. + /// However those rows are not shared with the user. + pub num_deleted_rows: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AddColumnsResult { + pub version: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AlterColumnsResult { + pub version: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DropColumnsResult { + pub version: u64, +} + /// A trait for anything "table-like".
This is used for both native tables (which target /// Lance datasets) and remote tables (which target LanceDB cloud) /// @@ -468,11 +511,11 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { &self, add: AddDataBuilder<NoData>, data: Box<dyn RecordBatchReader + Send>, - ) -> Result<()>; + ) -> Result<AddResult>; /// Delete rows from the table. - async fn delete(&self, predicate: &str) -> Result<()>; + async fn delete(&self, predicate: &str) -> Result<DeleteResult>; /// Update rows in the table. - async fn update(&self, update: UpdateBuilder) -> Result<u64>; + async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult>; /// Create an index on the provided column(s). async fn create_index(&self, index: IndexBuilder) -> Result<()>; /// List the indices on the table. @@ -488,7 +531,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { &self, params: MergeInsertBuilder, new_data: Box<dyn RecordBatchReader + Send>, - ) -> Result<MergeStats>; + ) -> Result<MergeResult>; /// Gets the table tag manager. async fn tags(&self) -> Result<Box<dyn Tags + '_>>; /// Optimize the dataset. @@ -498,11 +541,11 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { &self, transforms: NewColumnTransform, read_columns: Option<Vec<String>>, - ) -> Result<()>; + ) -> Result<AddColumnsResult>; /// Alter columns in the table. - async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>; + async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult>; /// Drop columns from the table. - async fn drop_columns(&self, columns: &[&str]) -> Result<()>; + async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult>; /// Get the version of the table. async fn version(&self) -> Result<u64>; /// Checkout a specific version of the table.
@@ -731,7 +774,7 @@ impl Table { /// tbl.delete("id > 5").await.unwrap(); /// # }); /// ``` - pub async fn delete(&self, predicate: &str) -> Result<()> { + pub async fn delete(&self, predicate: &str) -> Result<DeleteResult> { self.inner.delete(predicate).await } @@ -1046,17 +1089,20 @@ impl Table { &self, transforms: NewColumnTransform, read_columns: Option<Vec<String>>, - ) -> Result<()> { + ) -> Result<AddColumnsResult> { self.inner.add_columns(transforms, read_columns).await } /// Change a column's name or nullability. - pub async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()> { + pub async fn alter_columns( + &self, + alterations: &[ColumnAlteration], + ) -> Result<AlterColumnsResult> { self.inner.alter_columns(alterations).await } /// Remove columns from the table. - pub async fn drop_columns(&self, columns: &[&str]) -> Result<()> { + pub async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> { self.inner.drop_columns(columns).await } @@ -2089,7 +2135,7 @@ impl BaseTable for NativeTable { &self, add: AddDataBuilder<NoData>, data: Box<dyn RecordBatchReader + Send>, - ) -> Result<()> { + ) -> Result<AddResult> { let data = Box::new(MaybeEmbedded::try_new( data, self.table_definition().await?, @@ -2112,9 +2158,9 @@ impl BaseTable for NativeTable { .execute_stream(data) .await? }; - + let version = dataset.manifest().version; self.dataset.set_latest(dataset).await; - Ok(()) + Ok(AddResult { version }) } async fn create_index(&self, opts: IndexBuilder) -> Result<()> { @@ -2160,7 +2206,7 @@ impl BaseTable for NativeTable { Ok(dataset.prewarm_index(index_name).await?)
} - async fn update(&self, update: UpdateBuilder) -> Result<u64> { + async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> { let dataset = self.dataset.get().await?.clone(); let mut builder = LanceUpdateBuilder::new(Arc::new(dataset)); if let Some(predicate) = update.filter { @@ -2176,7 +2222,10 @@ impl BaseTable for NativeTable { self.dataset .set_latest(res.new_dataset.as_ref().clone()) .await; - Ok(res.rows_updated) + Ok(UpdateResult { + rows_updated: res.rows_updated, + version: res.new_dataset.version().version, + }) } async fn create_plan( @@ -2368,7 +2417,7 @@ impl BaseTable for NativeTable { &self, params: MergeInsertBuilder, new_data: Box<dyn RecordBatchReader + Send>, - ) -> Result<MergeStats> { + ) -> Result<MergeResult> { let dataset = Arc::new(self.dataset.get().await?.clone()); let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( @@ -2396,14 +2445,23 @@ impl BaseTable for NativeTable { } let job = builder.try_build()?; let (new_dataset, stats) = job.execute_reader(new_data).await?; + let version = new_dataset.manifest().version; self.dataset.set_latest(new_dataset.as_ref().clone()).await; - Ok(stats) + Ok(MergeResult { + version, + num_updated_rows: stats.num_updated_rows, + num_inserted_rows: stats.num_inserted_rows, + num_deleted_rows: stats.num_deleted_rows, + }) } /// Delete rows from the table - async fn delete(&self, predicate: &str) -> Result<()> { - self.dataset.get_mut().await?.delete(predicate).await?; - Ok(()) + async fn delete(&self, predicate: &str) -> Result<DeleteResult> { + let mut dataset = self.dataset.get_mut().await?; + dataset.delete(predicate).await?; + Ok(DeleteResult { + version: dataset.version().version, + }) } async fn tags(&self) -> Result<Box<dyn Tags + '_>> { @@ -2470,27 +2528,28 @@ impl BaseTable for NativeTable { &self, transforms: NewColumnTransform, read_columns: Option<Vec<String>>, - ) -> Result<()> { - self.dataset - .get_mut() - .await?
- .add_columns(transforms, read_columns, None) - .await?; - Ok(()) + ) -> Result<AddColumnsResult> { + let mut dataset = self.dataset.get_mut().await?; + dataset.add_columns(transforms, read_columns, None).await?; + Ok(AddColumnsResult { + version: dataset.version().version, + }) } - async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()> { - self.dataset - .get_mut() - .await? - .alter_columns(alterations) - .await?; - Ok(()) + async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult> { + let mut dataset = self.dataset.get_mut().await?; + dataset.alter_columns(alterations).await?; + Ok(AlterColumnsResult { + version: dataset.version().version, + }) } - async fn drop_columns(&self, columns: &[&str]) -> Result<()> { - self.dataset.get_mut().await?.drop_columns(columns).await?; - Ok(()) + async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> { + let mut dataset = self.dataset.get_mut().await?; + dataset.drop_columns(columns).await?; + Ok(DropColumnsResult { + version: dataset.version().version, + }) } async fn list_indices(&self) -> Result<Vec<IndexConfig>> { diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index f63b0218..130800fc 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -4,11 +4,10 @@ use std::sync::Arc; use arrow_array::RecordBatchReader; -use lance::dataset::MergeStats; use crate::Result; -use super::BaseTable; +use super::{BaseTable, MergeResult}; /// A builder used to create and run a merge insert operation /// @@ -87,9 +86,9 @@ impl MergeInsertBuilder { /// Executes the merge insert operation /// - /// Returns statistics about the merge operation including the number of rows + /// Returns version and statistics about the merge operation including the number of rows /// inserted, updated, and deleted. - pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<MergeStats> { + pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<MergeResult> { self.table.clone().merge_insert(self, new_data).await } }