diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 94d9d138..5e752f38 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -96,6 +96,50 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { expect(await table.countRows("id == 10")).toBe(1); }); + it("should let me update values with `values`", async () => { + await table.add([{ id: 1 }]); + expect(await table.countRows("id == 1")).toBe(1); + expect(await table.countRows("id == 7")).toBe(0); + await table.update({ values: { id: 7 } }); + expect(await table.countRows("id == 1")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + await table.add([{ id: 2 }]); + // Test Map as input + await table.update({ + values: { + id: "10", + }, + where: "id % 2 == 0", + }); + expect(await table.countRows("id == 2")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + expect(await table.countRows("id == 10")).toBe(1); + }); + + it("should let me update values with `valuesSql`", async () => { + await table.add([{ id: 1 }]); + expect(await table.countRows("id == 1")).toBe(1); + expect(await table.countRows("id == 7")).toBe(0); + await table.update({ + valuesSql: { + id: "7", + }, + }); + expect(await table.countRows("id == 1")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + await table.add([{ id: 2 }]); + // Test Map as input + await table.update({ + valuesSql: { + id: "10", + }, + where: "id % 2 == 0", + }); + expect(await table.countRows("id == 2")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + expect(await table.countRows("id == 10")).toBe(1); + }); + // https://github.com/lancedb/lancedb/issues/1293 test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])( "can create empty table with non default float type: %s", diff --git a/nodejs/__test__/util.test.ts b/nodejs/__test__/util.test.ts new file mode 100644 index 00000000..49654ad5 --- /dev/null +++ b/nodejs/__test__/util.test.ts @@ -0,0 +1,28 @@ +import { IntoSql, toSQL } from "../lancedb/util"; +test.each([ + ["string", "'string'"], + [123, "123"], + [1.11, "1.11"], + [true, "TRUE"], + [false, "FALSE"], + [null, "NULL"], + [new Date("2021-01-01T00:00:00.000Z"), "'2021-01-01T00:00:00.000Z'"], + [[1, 2, 3], "[1, 2, 3]"], + [new ArrayBuffer(8), "X'0000000000000000'"], + [Buffer.from("hello"), "X'68656c6c6f'"], + ["Hello 'world'", "'Hello ''world'''"], +])("toSQL(%p) === %p", (value, expected) => { + expect(toSQL(value)).toBe(expected); +}); + +test("toSQL({}) throws on unsupported value type", () => { + expect(() => toSQL({} as unknown as IntoSql)).toThrow( + "Unsupported value type: object value: ([object Object])", + ); +}); +test("toSQL() throws on unsupported value type", () => { + // biome-ignore lint/suspicious/noExplicitAny: + expect(() => (toSQL)()).toThrow( + "Unsupported value type: undefined value: (undefined)", + ); +}); diff --git a/nodejs/lancedb/remote/table.ts b/nodejs/lancedb/remote/table.ts index f06b0c84..c1712415 100644 --- a/nodejs/lancedb/remote/table.ts +++ b/nodejs/lancedb/remote/table.ts @@ -22,6 +22,7 @@ import { IndexOptions } from "../indices"; import { MergeInsertBuilder } from "../merge"; import { VectorQuery } from "../query"; import { AddDataOptions, Table, UpdateOptions } from "../table"; +import { IntoSql, toSQL } from "../util"; import { RestfulLanceDBClient } from "./client"; export class RemoteTable extends Table { @@ -84,12 +85,66 @@ export class RemoteTable extends Table { } async update( - updates: Map | Record, + optsOrUpdates: + | (Map | Record) + | ({ + values: Map | Record; + } & Partial) + | ({ + valuesSql: Map | Record; + } & Partial), options?: Partial, ): Promise { + const isValues = + "values" in optsOrUpdates && typeof optsOrUpdates.values !== "string"; + const isValuesSql = + "valuesSql" in optsOrUpdates && + typeof optsOrUpdates.valuesSql !== "string"; + const isMap = (obj: unknown): obj is Map => { + return obj instanceof Map; + }; + + let predicate; + let columns: [string, string][]; + switch (true) { + case isMap(optsOrUpdates): + columns = Array.from(optsOrUpdates.entries()); + predicate = options?.where; + break; + case isValues && isMap(optsOrUpdates.values): + columns = Array.from(optsOrUpdates.values.entries()).map(([k, v]) => [ + k, + toSQL(v), + ]); + predicate = optsOrUpdates.where; + break; + case isValues && !isMap(optsOrUpdates.values): + columns = Object.entries(optsOrUpdates.values).map(([k, v]) => [ + k, + toSQL(v), + ]); + predicate = optsOrUpdates.where; + break; + + case isValuesSql && isMap(optsOrUpdates.valuesSql): + columns = Array.from(optsOrUpdates.valuesSql.entries()); + predicate = optsOrUpdates.where; + break; + case isValuesSql && !isMap(optsOrUpdates.valuesSql): + columns = Object.entries(optsOrUpdates.valuesSql).map(([k, v]) => [ + k, + v, + ]); + predicate = optsOrUpdates.where; + break; + default: + columns = Object.entries(optsOrUpdates as Record); + predicate = options?.where; + } + await this.#client.post(`${this.#tablePrefix}/update/`, { - predicate: options?.where ?? null, - updates: Object.entries(updates).map(([key, value]) => [key, value]), + predicate: predicate ?? null, + updates: columns, }); } async countRows(filter?: unknown): Promise { diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index de8f2733..557c62c0 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -40,6 +40,7 @@ import { } from "./native"; import { Query, VectorQuery } from "./query"; import { sanitizeTable } from "./sanitize"; +import { IntoSql, toSQL } from "./util"; export { IndexConfig } from "./native"; /** @@ -123,6 +124,34 @@ export abstract class Table { * @param {Data} data Records to be inserted into the Table */ abstract add(data: Data, options?: Partial): Promise; + /** + * Update existing records in the Table + * @param opts.values The values to update. The keys are the column names and the values + * are the values to set. + * @example + * ```ts + * table.update({where:"x = 2", values:{"vector": [10, 10]}}) + * ``` + */ + abstract update( + opts: { + values: Map | Record; + } & Partial, + ): Promise; + /** + * Update existing records in the Table + * @param opts.valuesSql The values to update. The keys are the column names and the values + * are the values to set. The values are SQL expressions. + * @example + * ```ts + * table.update({where:"x = 2", valuesSql:{"x": "x + 1"}}) + * ``` + */ + abstract update( + opts: { + valuesSql: Map | Record; + } & Partial, + ): Promise; /** * Update existing records in the Table * @@ -152,6 +181,7 @@ export abstract class Table { updates: Map | Record, options?: Partial, ): Promise; + /** Count the total number of rows in the dataset. */ abstract countRows(filter?: string): Promise; /** Delete the rows that satisfy the predicate. */ @@ -471,17 +501,63 @@ export class LocalTable extends Table { } async update( - updates: Map | Record, + optsOrUpdates: + | (Map | Record) + | ({ + values: Map | Record; + } & Partial) + | ({ + valuesSql: Map | Record; + } & Partial), options?: Partial, ) { - const onlyIf = options?.where; + const isValues = + "values" in optsOrUpdates && typeof optsOrUpdates.values !== "string"; + const isValuesSql = + "valuesSql" in optsOrUpdates && + typeof optsOrUpdates.valuesSql !== "string"; + const isMap = (obj: unknown): obj is Map => { + return obj instanceof Map; + }; + + let predicate; let columns: [string, string][]; - if (updates instanceof Map) { - columns = Array.from(updates.entries()); - } else { - columns = Object.entries(updates); + switch (true) { + case isMap(optsOrUpdates): + columns = Array.from(optsOrUpdates.entries()); + predicate = options?.where; + break; + case isValues && isMap(optsOrUpdates.values): + columns = Array.from(optsOrUpdates.values.entries()).map(([k, v]) => [ + k, + toSQL(v), + ]); + predicate = optsOrUpdates.where; + break; + case isValues && !isMap(optsOrUpdates.values): + columns = Object.entries(optsOrUpdates.values).map(([k, v]) => [ + k, + toSQL(v), + ]); + predicate = optsOrUpdates.where; + break; + + case isValuesSql && isMap(optsOrUpdates.valuesSql): + columns = Array.from(optsOrUpdates.valuesSql.entries()); + predicate = optsOrUpdates.where; + break; + case isValuesSql && !isMap(optsOrUpdates.valuesSql): + columns = Object.entries(optsOrUpdates.valuesSql).map(([k, v]) => [ + k, + v, + ]); + predicate = optsOrUpdates.where; + break; + default: + columns = Object.entries(optsOrUpdates as Record); + predicate = options?.where; } - await this.inner.update(onlyIf, columns); + await this.inner.update(predicate, columns); } async countRows(filter?: string): Promise { diff --git a/nodejs/lancedb/util.ts b/nodejs/lancedb/util.ts index 6e9e696a..20b84ae1 100644 --- a/nodejs/lancedb/util.ts +++ b/nodejs/lancedb/util.ts @@ -1,3 +1,37 @@ +export type IntoSql = + | string + | number + | boolean + | null + | Date + | ArrayBufferLike + | Buffer + | IntoSql[]; + +export function toSQL(value: IntoSql): string { + if (typeof value === "string") { + return `'${value.replace(/'/g, "''")}'`; + } else if (typeof value === "number") { + return value.toString(); + } else if (typeof value === "boolean") { + return value ? "TRUE" : "FALSE"; + } else if (value === null) { + return "NULL"; + } else if (value instanceof Date) { + return `'${value.toISOString()}'`; + } else if (Array.isArray(value)) { + return `[${value.map(toSQL).join(", ")}]`; + } else if (Buffer.isBuffer(value)) { + return `X'${value.toString("hex")}'`; + } else if (value instanceof ArrayBuffer) { + return `X'${Buffer.from(value).toString("hex")}'`; + } else { + throw new Error( + `Unsupported value type: ${typeof value} value: (${value})`, + ); + } +} + export class TTLCache { // biome-ignore lint/suspicious/noExplicitAny: private readonly cache: Map;