From bc19a75f657bb3d6120ce2c367f51fed4a404e7c Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 11 Jun 2024 15:05:15 -0500 Subject: [PATCH] feat(nodejs): merge insert (#1351) closes https://github.com/lancedb/lancedb/issues/1349 --- nodejs/__test__/table.test.ts | 134 ++++++++++++++++++++++++++++++++ nodejs/lancedb/merge.ts | 70 +++++++++++++++++ nodejs/lancedb/table.ts | 5 ++ nodejs/package-lock.json | 4 +- nodejs/src/lib.rs | 1 + nodejs/src/merge.rs | 53 +++++++++++++ nodejs/src/table.rs | 7 ++ rust/lancedb/src/table/merge.rs | 1 + 8 files changed, 273 insertions(+), 2 deletions(-) create mode 100644 nodejs/lancedb/merge.ts create mode 100644 nodejs/src/merge.rs diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index aa0552a2..5948ea77 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -132,6 +132,140 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { }); }); +describe("merge insert", () => { + let tmpDir: tmp.DirResult; + let table: Table; + + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const conn = await connect(tmpDir.name); + + table = await conn.createTable("some_table", [ + { a: 1, b: "a" }, + { a: 2, b: "b" }, + { a: 3, b: "c" }, + ]); + }); + afterEach(() => tmpDir.removeCallback()); + + test("upsert", async () => { + const newData = [ + { a: 2, b: "x" }, + { a: 3, b: "y" }, + { a: 4, b: "z" }, + ]; + await table + .mergeInsert("a") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .execute(newData); + const expected = [ + { a: 1, b: "a" }, + { a: 2, b: "x" }, + { a: 3, b: "y" }, + { a: 4, b: "z" }, + ]; + + expect( + JSON.parse(JSON.stringify((await table.toArrow()).toArray())), + ).toEqual(expected); + }); + test("conditional update", async () => { + const newData = [ + { a: 2, b: "x" }, + { a: 3, b: "y" }, + { a: 4, b: "z" }, + ]; + await table + .mergeInsert("a") + .whenMatchedUpdateAll({ where: "target.b = 'b'" }) + 
.execute(newData); + + const expected = [ + { a: 1, b: "a" }, + { a: 2, b: "x" }, + { a: 3, b: "c" }, + ]; + // round trip to arrow and back to json to avoid comparing arrow objects to js object + // biome-ignore lint/suspicious/noExplicitAny: test + let res: any[] = JSON.parse( + JSON.stringify((await table.toArrow()).toArray()), + ); + res = res.sort((a, b) => a.a - b.a); + + expect(res).toEqual(expected); + }); + + test("insert if not exists", async () => { + const newData = [ + { a: 2, b: "x" }, + { a: 3, b: "y" }, + { a: 4, b: "z" }, + ]; + await table.mergeInsert("a").whenNotMatchedInsertAll().execute(newData); + const expected = [ + { a: 1, b: "a" }, + { a: 2, b: "b" }, + { a: 3, b: "c" }, + { a: 4, b: "z" }, + ]; + // biome-ignore lint/suspicious/noExplicitAny: + let res: any[] = JSON.parse( + JSON.stringify((await table.toArrow()).toArray()), + ); + res = res.sort((a, b) => a.a - b.a); + expect(res).toEqual(expected); + }); + test("replace range", async () => { + const newData = [ + { a: 2, b: "x" }, + { a: 4, b: "z" }, + ]; + await table + .mergeInsert("a") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .whenNotMatchedBySourceDelete({ where: "a > 2" }) + .execute(newData); + + const expected = [ + { a: 1, b: "a" }, + { a: 2, b: "x" }, + { a: 4, b: "z" }, + ]; + // biome-ignore lint/suspicious/noExplicitAny: + let res: any[] = JSON.parse( + JSON.stringify((await table.toArrow()).toArray()), + ); + res = res.sort((a, b) => a.a - b.a); + expect(res).toEqual(expected); + }); + test("replace range no condition", async () => { + const newData = [ + { a: 2, b: "x" }, + { a: 4, b: "z" }, + ]; + await table + .mergeInsert("a") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .whenNotMatchedBySourceDelete() + .execute(newData); + + const expected = [ + { a: 2, b: "x" }, + { a: 4, b: "z" }, + ]; + + // biome-ignore lint/suspicious/noExplicitAny: test + let res: any[] = JSON.parse( + JSON.stringify((await table.toArrow()).toArray()), + ); + res = 
res.sort((a, b) => a.a - b.a); + expect(res).toEqual(expected); + }); +}); + describe("When creating an index", () => { let tmpDir: tmp.DirResult; const schema = new Schema([ diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts new file mode 100644 index 00000000..83ca92b9 --- /dev/null +++ b/nodejs/lancedb/merge.ts @@ -0,0 +1,70 @@ +import { Data, fromDataToBuffer } from "./arrow"; +import { NativeMergeInsertBuilder } from "./native"; + +/** A builder used to create and run a merge insert operation */ +export class MergeInsertBuilder { + #native: NativeMergeInsertBuilder; + + /** Construct a MergeInsertBuilder. __Internal use only.__ */ + constructor(native: NativeMergeInsertBuilder) { + this.#native = native; + } + + /** + * Rows that exist in both the source table (new data) and + * the target table (old data) will be updated, replacing + * the old row with the corresponding matching row. + * + * If there are multiple matches then the behavior is undefined. + * Currently this causes multiple copies of the row to be created + * but that behavior is subject to change. + * + * An optional condition may be specified. If it is, then only + * matched rows that satisfy the condition will be updated. Any + * rows that do not satisfy the condition will be left as they + * are. Failing to satisfy the condition does not cause a + * "matched row" to become a "not matched" row. + * + * The condition should be an SQL string. Use the prefix + * target. to refer to rows in the target table (old data) + * and the prefix source. to refer to rows in the source + * table (new data). + * + * For example, "target.last_update < source.last_update" + */ + whenMatchedUpdateAll(options?: { where: string }): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.whenMatchedUpdateAll(options?.where), + ); + } + /** + * Rows that exist only in the source table (new data) should + * be inserted into the target table. 
+ */ + whenNotMatchedInsertAll(): MergeInsertBuilder { + return new MergeInsertBuilder(this.#native.whenNotMatchedInsertAll()); + } + /** + * Rows that exist only in the target table (old data) will be + * deleted. An optional condition can be provided to limit what + * data is deleted. + * + * @param options.where - An optional condition to limit what data is deleted + */ + whenNotMatchedBySourceDelete(options?: { + where: string; + }): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.whenNotMatchedBySourceDelete(options?.where), + ); + } + /** + * Executes the merge insert operation + * + * Nothing is returned but the `Table` is updated + */ + async execute(data: Data): Promise { + const buffer = await fromDataToBuffer(data); + await this.#native.execute(buffer); + } +} diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 9b1363aa..22e460d9 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -23,6 +23,7 @@ import { import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { IndexOptions } from "./indices"; +import { MergeInsertBuilder } from "./merge"; import { AddColumnsSql, ColumnAlteration, @@ -478,4 +479,8 @@ export class Table { async toArrow(): Promise { return await this.query().toArrow(); } + mergeInsert(on: string | string[]): MergeInsertBuilder { + on = Array.isArray(on) ? 
on : [on]; + return new MergeInsertBuilder(this.inner.mergeInsert(on)); + } } diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index e5bf7470..cc55fc28 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,12 +1,12 @@ { "name": "@lancedb/lancedb", - "version": "0.5.0", + "version": "0.5.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@lancedb/lancedb", - "version": "0.5.0", + "version": "0.5.1", "cpu": [ "x64", "arm64" diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index 423893f8..fda89660 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -20,6 +20,7 @@ mod connection; mod error; mod index; mod iterator; +pub mod merge; mod query; mod table; mod util; diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs new file mode 100644 index 00000000..18c6a2ec --- /dev/null +++ b/nodejs/src/merge.rs @@ -0,0 +1,53 @@ +use lancedb::{arrow::IntoArrow, ipc::ipc_file_to_batches, table::merge::MergeInsertBuilder}; +use napi::bindgen_prelude::*; +use napi_derive::napi; + +#[napi] +#[derive(Clone)] +/// A builder used to create and run a merge insert operation +pub struct NativeMergeInsertBuilder { + pub(crate) inner: MergeInsertBuilder, +} + +#[napi] +impl NativeMergeInsertBuilder { + #[napi] + pub fn when_matched_update_all(&self, condition: Option) -> Self { + let mut this = self.clone(); + this.inner.when_matched_update_all(condition); + this + } + + #[napi] + pub fn when_not_matched_insert_all(&self) -> Self { + let mut this = self.clone(); + this.inner.when_not_matched_insert_all(); + this + } + #[napi] + pub fn when_not_matched_by_source_delete(&self, filter: Option) -> Self { + let mut this = self.clone(); + this.inner.when_not_matched_by_source_delete(filter); + this + } + + #[napi] + pub async fn execute(&self, buf: Buffer) -> napi::Result<()> { + let data = ipc_file_to_batches(buf.to_vec()) + .and_then(IntoArrow::into_arrow) + .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: 
{}", e)))?; + + let this = self.clone(); + + this.inner + .execute(data) + .await + .map_err(|e| napi::Error::from_reason(format!("Failed to execute merge insert: {}", e))) + } +} + +impl From for NativeMergeInsertBuilder { + fn from(inner: MergeInsertBuilder) -> Self { + Self { inner } + } +} diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 124594d3..084cb576 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -23,6 +23,7 @@ use napi_derive::napi; use crate::error::NapiErrorExt; use crate::index::Index; +use crate::merge::NativeMergeInsertBuilder; use crate::query::{Query, VectorQuery}; #[napi] @@ -328,6 +329,12 @@ impl Table { .map(IndexConfig::from) .collect::>()) } + + #[napi] + pub fn merge_insert(&self, on: Vec) -> napi::Result { + let on: Vec<_> = on.iter().map(String::as_str).collect(); + Ok(self.inner_ref()?.merge_insert(on.as_slice()).into()) + } } #[napi(object)] diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index 1633160d..5c422b9d 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -23,6 +23,7 @@ use super::TableInternal; /// A builder used to create and run a merge insert operation /// /// See [`super::Table::merge_insert`] for more context +#[derive(Debug, Clone)] pub struct MergeInsertBuilder { table: Arc, pub(super) on: Vec,