fix(node): createTable() should save embeddings, and mergeInsert should use them (#2065)

* `createTable()` now saves embeddings in the schema metadata.
Previously, it would drop them. (`createEmptyTable()` was already tested
and worked.)
* `mergeInsert()` now uses embeddings.

Fixes #2066
This commit is contained in:
Will Jones
2025-01-28 12:38:50 -08:00
committed by GitHub
parent d999d72c8d
commit 0a9e1eab75
4 changed files with 109 additions and 14 deletions

View File

@@ -83,6 +83,74 @@ describe("embedding functions", () => {
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to append and upsert using embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
const schema = await table.schema();
expect(schema.metadata.get("embedding_functions")).toBeDefined();
// Append some new data
const data1 = [
{ id: 3, text: "forest" },
{ id: 4, text: "mountain" },
];
await table.add(data1);
// Upsert some data
const data2 = [
{ id: 5, text: "river" },
{ id: 2, text: "canyon" },
];
await table
.mergeInsert("id")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.execute(data2);
const rows = await table.query().toArray();
rows.sort((a, b) => a.id - b.id);
const texts = rows.map((row) => row.text);
expect(texts).toEqual(["hello", "canyon", "forest", "mountain", "river"]);
const vectorsDefined = rows.map(
(row) => row.vector !== undefined && row.vector !== null,
);
expect(vectorsDefined).toEqual(new Array(5).fill(true));
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {

View File

@@ -609,6 +609,14 @@ async function applyEmbeddings<T>(
return table;
}
let schemaMetadata = schema?.metadata || new Map<string, string>();
if (!(embeddings == null || embeddings === undefined)) {
const registry = getRegistry();
const embeddingMetadata = registry.getTableMetadata([embeddings]);
schemaMetadata = new Map([...schemaMetadata, ...embeddingMetadata]);
}
// Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
const name = table.schema.fields[idx].name;
@@ -677,15 +685,21 @@ async function applyEmbeddings<T>(
newColumns[destColumn] = makeVector(vectors, destType);
}
const newTable = new ArrowTable(newColumns);
let newTable = new ArrowTable(newColumns);
if (schema != null) {
if (schema.fields.find((f) => f.name === destColumn) === undefined) {
throw new Error(
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
);
}
return alignTable(newTable, schema as Schema);
newTable = alignTable(newTable, schema as Schema);
}
newTable = new ArrowTable(
new Schema(newTable.schema.fields, schemaMetadata),
newTable.batches,
);
return newTable;
}

View File

@@ -1,13 +1,20 @@
import { Data, fromDataToBuffer } from "./arrow";
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { Data, Schema, fromDataToBuffer } from "./arrow";
import { NativeMergeInsertBuilder } from "./native";
/** A builder used to create and run a merge insert operation */
export class MergeInsertBuilder {
#native: NativeMergeInsertBuilder;
#schema: Schema | Promise<Schema>;
/** Construct a MergeInsertBuilder. __Internal use only.__ */
constructor(native: NativeMergeInsertBuilder) {
constructor(
native: NativeMergeInsertBuilder,
schema: Schema | Promise<Schema>,
) {
this.#native = native;
this.#schema = schema;
}
/**
@@ -35,6 +42,7 @@ export class MergeInsertBuilder {
whenMatchedUpdateAll(options?: { where: string }): MergeInsertBuilder {
return new MergeInsertBuilder(
this.#native.whenMatchedUpdateAll(options?.where),
this.#schema,
);
}
/**
@@ -42,7 +50,10 @@ export class MergeInsertBuilder {
* be inserted into the target table.
*/
whenNotMatchedInsertAll(): MergeInsertBuilder {
return new MergeInsertBuilder(this.#native.whenNotMatchedInsertAll());
return new MergeInsertBuilder(
this.#native.whenNotMatchedInsertAll(),
this.#schema,
);
}
/**
* Rows that exist only in the target table (old data) will be
@@ -56,6 +67,7 @@ export class MergeInsertBuilder {
}): MergeInsertBuilder {
return new MergeInsertBuilder(
this.#native.whenNotMatchedBySourceDelete(options?.where),
this.#schema,
);
}
/**
@@ -64,7 +76,14 @@ export class MergeInsertBuilder {
* Nothing is returned but the `Table` is updated
*/
async execute(data: Data): Promise<void> {
const buffer = await fromDataToBuffer(data);
let schema: Schema;
if (this.#schema instanceof Promise) {
schema = await this.#schema;
this.#schema = schema; // In case of future calls
} else {
schema = this.#schema;
}
const buffer = await fromDataToBuffer(data, undefined, schema);
await this.#native.execute(buffer);
}
}

View File

@@ -520,14 +520,8 @@ export class LocalTable extends Table {
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
const mode = options?.mode ?? "append";
const schema = await this.schema();
const registry = getRegistry();
const functions = await registry.parseFunctions(schema.metadata);
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
const buffer = await fromDataToBuffer(data, undefined, schema);
await this.inner.add(buffer, mode);
}
@@ -733,7 +727,7 @@ export class LocalTable extends Table {
}
mergeInsert(on: string | string[]): MergeInsertBuilder {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on));
return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema());
}
/**