From 0a9e1eab75640b468ffae7a539a940579d1f5410 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 28 Jan 2025 12:38:50 -0800 Subject: [PATCH] fix(node): `createTable()` should save embeddings, and `mergeInsert` should use them (#2065) * `createTable()` now saves embeddings in the schema metadata. Previously, it would drop them. (`createEmptyTable()` was already tested and worked.) * `mergeInsert()` now uses embeddings. Fixes #2066 --- nodejs/__test__/embedding.test.ts | 68 +++++++++++++++++++++++++++++++ nodejs/lancedb/arrow.ts | 18 +++++++- nodejs/lancedb/merge.ts | 27 ++++++++++-- nodejs/lancedb/table.ts | 10 +---- 4 files changed, 109 insertions(+), 14 deletions(-) diff --git a/nodejs/__test__/embedding.test.ts b/nodejs/__test__/embedding.test.ts index 2200aed9..e1904cf7 100644 --- a/nodejs/__test__/embedding.test.ts +++ b/nodejs/__test__/embedding.test.ts @@ -83,6 +83,74 @@ describe("embedding functions", () => { expect(vector0).toEqual([1, 2, 3]); }); + it("should be able to append and upsert using embedding function", async () => { + @register() + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return new Float32(); + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const func = new MockEmbeddingFunction(); + const db = await connect(tmpDir.name); + const table = await db.createTable( + "test", + [ + { id: 1, text: "hello" }, + { id: 2, text: "world" }, + ], + { + embeddingFunction: { + function: func, + sourceColumn: "text", + }, + }, + ); + + const schema = await table.schema(); + expect(schema.metadata.get("embedding_functions")).toBeDefined(); + + // Append some new data + const data1 = [ + { id: 3, text: "forest" }, + { id: 4, text: "mountain" }, + ]; + await table.add(data1); + + // Upsert some data + const data2 = [ + { id: 5, text: "river" }, + { id: 2, text: "canyon" }, + ]; + await table + .mergeInsert("id") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .execute(data2); + + const rows = await table.query().toArray(); + rows.sort((a, b) => a.id - b.id); + const texts = rows.map((row) => row.text); + expect(texts).toEqual(["hello", "canyon", "forest", "mountain", "river"]); + const vectorsDefined = rows.map( + (row) => row.vector !== undefined && row.vector !== null, + ); + expect(vectorsDefined).toEqual(new Array(5).fill(true)); + }); + it("should be able to create an empty table with an embedding function", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 6de51ca8..258f77a4 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -609,6 +609,14 @@ async function applyEmbeddings( return table; } + let schemaMetadata = schema?.metadata || new Map(); + + if (!(embeddings == null || embeddings === undefined)) { + const registry = getRegistry(); + const embeddingMetadata = registry.getTableMetadata([embeddings]); + schemaMetadata = new Map([...schemaMetadata, ...embeddingMetadata]); + } + // Convert from ArrowTable to Record const colEntries = [...Array(table.numCols).keys()].map((_, idx) => { const name = table.schema.fields[idx].name; @@ -677,15 +685,21 @@ async function applyEmbeddings( newColumns[destColumn] = makeVector(vectors, destType); } - const newTable = new ArrowTable(newColumns); + let newTable = new ArrowTable(newColumns); if (schema != null) { if (schema.fields.find((f) => f.name === destColumn) === undefined) { throw new Error( `When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`, ); } - return alignTable(newTable, schema as Schema); + newTable = alignTable(newTable, schema as Schema); } + + newTable = new ArrowTable( + new Schema(newTable.schema.fields, schemaMetadata), + newTable.batches, + ); + return newTable; } diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts index 83ca92b9..407dca94 100644 --- a/nodejs/lancedb/merge.ts +++ b/nodejs/lancedb/merge.ts @@ -1,13 +1,20 @@ -import { Data, fromDataToBuffer } from "./arrow"; +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors +import { Data, Schema, fromDataToBuffer } from "./arrow"; import { NativeMergeInsertBuilder } from "./native"; /** A builder used to create and run a merge insert operation */ export class MergeInsertBuilder { #native: NativeMergeInsertBuilder; + #schema: Schema | Promise; /** Construct a MergeInsertBuilder. __Internal use only.__ */ - constructor(native: NativeMergeInsertBuilder) { + constructor( + native: NativeMergeInsertBuilder, + schema: Schema | Promise, + ) { this.#native = native; + this.#schema = schema; } /** @@ -35,6 +42,7 @@ export class MergeInsertBuilder { whenMatchedUpdateAll(options?: { where: string }): MergeInsertBuilder { return new MergeInsertBuilder( this.#native.whenMatchedUpdateAll(options?.where), + this.#schema, ); } /** @@ -42,7 +50,10 @@ export class MergeInsertBuilder { * be inserted into the target table. */ whenNotMatchedInsertAll(): MergeInsertBuilder { - return new MergeInsertBuilder(this.#native.whenNotMatchedInsertAll()); + return new MergeInsertBuilder( + this.#native.whenNotMatchedInsertAll(), + this.#schema, + ); } /** * Rows that exist only in the target table (old data) will be @@ -56,6 +67,7 @@ export class MergeInsertBuilder { }): MergeInsertBuilder { return new MergeInsertBuilder( this.#native.whenNotMatchedBySourceDelete(options?.where), + this.#schema, ); } /** @@ -64,7 +76,14 @@ export class MergeInsertBuilder { * Nothing is returned but the `Table` is updated */ async execute(data: Data): Promise { - const buffer = await fromDataToBuffer(data); + let schema: Schema; + if (this.#schema instanceof Promise) { + schema = await this.#schema; + this.#schema = schema; // In case of future calls + } else { + schema = this.#schema; + } + const buffer = await fromDataToBuffer(data, undefined, schema); await this.#native.execute(buffer); } } diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index b581ea30..3d413207 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -520,14 +520,8 @@ export class LocalTable extends Table { async add(data: Data, options?: Partial): Promise { const mode = options?.mode ?? "append"; const schema = await this.schema(); - const registry = getRegistry(); - const functions = await registry.parseFunctions(schema.metadata); - const buffer = await fromDataToBuffer( - data, - functions.values().next().value, - schema, - ); + const buffer = await fromDataToBuffer(data, undefined, schema); await this.inner.add(buffer, mode); } @@ -733,7 +727,7 @@ export class LocalTable extends Table { } mergeInsert(on: string | string[]): MergeInsertBuilder { on = Array.isArray(on) ? on : [on]; - return new MergeInsertBuilder(this.inner.mergeInsert(on)); + return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema()); } /**