mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 03:42:57 +00:00
fix(node): createTable() should save embeddings, and mergeInsert should use them (#2065)
* `createTable()` now saves embeddings in the schema metadata. Previously, it would drop them. (`createEmptyTable()` was already tested and worked.)
* `mergeInsert()` now uses embeddings.

Fixes #2066
This commit is contained in:
@@ -83,6 +83,74 @@ describe("embedding functions", () => {
|
||||
expect(vector0).toEqual([1, 2, 3]);
|
||||
});
|
||||
|
||||
it("should be able to append and upsert using embedding function", async () => {
|
||||
@register()
|
||||
class MockEmbeddingFunction extends EmbeddingFunction<string> {
|
||||
toJSON(): object {
|
||||
return {};
|
||||
}
|
||||
ndims() {
|
||||
return 3;
|
||||
}
|
||||
embeddingDataType(): Float {
|
||||
return new Float32();
|
||||
}
|
||||
async computeQueryEmbeddings(_data: string) {
|
||||
return [1, 2, 3];
|
||||
}
|
||||
async computeSourceEmbeddings(data: string[]) {
|
||||
return Array.from({ length: data.length }).fill([
|
||||
1, 2, 3,
|
||||
]) as number[][];
|
||||
}
|
||||
}
|
||||
const func = new MockEmbeddingFunction();
|
||||
const db = await connect(tmpDir.name);
|
||||
const table = await db.createTable(
|
||||
"test",
|
||||
[
|
||||
{ id: 1, text: "hello" },
|
||||
{ id: 2, text: "world" },
|
||||
],
|
||||
{
|
||||
embeddingFunction: {
|
||||
function: func,
|
||||
sourceColumn: "text",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const schema = await table.schema();
|
||||
expect(schema.metadata.get("embedding_functions")).toBeDefined();
|
||||
|
||||
// Append some new data
|
||||
const data1 = [
|
||||
{ id: 3, text: "forest" },
|
||||
{ id: 4, text: "mountain" },
|
||||
];
|
||||
await table.add(data1);
|
||||
|
||||
// Upsert some data
|
||||
const data2 = [
|
||||
{ id: 5, text: "river" },
|
||||
{ id: 2, text: "canyon" },
|
||||
];
|
||||
await table
|
||||
.mergeInsert("id")
|
||||
.whenMatchedUpdateAll()
|
||||
.whenNotMatchedInsertAll()
|
||||
.execute(data2);
|
||||
|
||||
const rows = await table.query().toArray();
|
||||
rows.sort((a, b) => a.id - b.id);
|
||||
const texts = rows.map((row) => row.text);
|
||||
expect(texts).toEqual(["hello", "canyon", "forest", "mountain", "river"]);
|
||||
const vectorsDefined = rows.map(
|
||||
(row) => row.vector !== undefined && row.vector !== null,
|
||||
);
|
||||
expect(vectorsDefined).toEqual(new Array(5).fill(true));
|
||||
});
|
||||
|
||||
it("should be able to create an empty table with an embedding function", async () => {
|
||||
@register()
|
||||
class MockEmbeddingFunction extends EmbeddingFunction<string> {
|
||||
|
||||
@@ -609,6 +609,14 @@ async function applyEmbeddings<T>(
|
||||
return table;
|
||||
}
|
||||
|
||||
let schemaMetadata = schema?.metadata || new Map<string, string>();
|
||||
|
||||
if (!(embeddings == null || embeddings === undefined)) {
|
||||
const registry = getRegistry();
|
||||
const embeddingMetadata = registry.getTableMetadata([embeddings]);
|
||||
schemaMetadata = new Map([...schemaMetadata, ...embeddingMetadata]);
|
||||
}
|
||||
|
||||
// Convert from ArrowTable to Record<String, Vector>
|
||||
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||
const name = table.schema.fields[idx].name;
|
||||
@@ -677,15 +685,21 @@ async function applyEmbeddings<T>(
|
||||
newColumns[destColumn] = makeVector(vectors, destType);
|
||||
}
|
||||
|
||||
const newTable = new ArrowTable(newColumns);
|
||||
let newTable = new ArrowTable(newColumns);
|
||||
if (schema != null) {
|
||||
if (schema.fields.find((f) => f.name === destColumn) === undefined) {
|
||||
throw new Error(
|
||||
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
|
||||
);
|
||||
}
|
||||
return alignTable(newTable, schema as Schema);
|
||||
newTable = alignTable(newTable, schema as Schema);
|
||||
}
|
||||
|
||||
newTable = new ArrowTable(
|
||||
new Schema(newTable.schema.fields, schemaMetadata),
|
||||
newTable.batches,
|
||||
);
|
||||
|
||||
return newTable;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
import { Data, fromDataToBuffer } from "./arrow";
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import { Data, Schema, fromDataToBuffer } from "./arrow";
|
||||
import { NativeMergeInsertBuilder } from "./native";
|
||||
|
||||
/** A builder used to create and run a merge insert operation */
|
||||
export class MergeInsertBuilder {
|
||||
#native: NativeMergeInsertBuilder;
|
||||
#schema: Schema | Promise<Schema>;
|
||||
|
||||
/** Construct a MergeInsertBuilder. __Internal use only.__ */
|
||||
constructor(native: NativeMergeInsertBuilder) {
|
||||
constructor(
|
||||
native: NativeMergeInsertBuilder,
|
||||
schema: Schema | Promise<Schema>,
|
||||
) {
|
||||
this.#native = native;
|
||||
this.#schema = schema;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -35,6 +42,7 @@ export class MergeInsertBuilder {
|
||||
whenMatchedUpdateAll(options?: { where: string }): MergeInsertBuilder {
|
||||
return new MergeInsertBuilder(
|
||||
this.#native.whenMatchedUpdateAll(options?.where),
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
@@ -42,7 +50,10 @@ export class MergeInsertBuilder {
|
||||
* be inserted into the target table.
|
||||
*/
|
||||
whenNotMatchedInsertAll(): MergeInsertBuilder {
|
||||
return new MergeInsertBuilder(this.#native.whenNotMatchedInsertAll());
|
||||
return new MergeInsertBuilder(
|
||||
this.#native.whenNotMatchedInsertAll(),
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
* Rows that exist only in the target table (old data) will be
|
||||
@@ -56,6 +67,7 @@ export class MergeInsertBuilder {
|
||||
}): MergeInsertBuilder {
|
||||
return new MergeInsertBuilder(
|
||||
this.#native.whenNotMatchedBySourceDelete(options?.where),
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
@@ -64,7 +76,14 @@ export class MergeInsertBuilder {
|
||||
* Nothing is returned but the `Table` is updated
|
||||
*/
|
||||
async execute(data: Data): Promise<void> {
|
||||
const buffer = await fromDataToBuffer(data);
|
||||
let schema: Schema;
|
||||
if (this.#schema instanceof Promise) {
|
||||
schema = await this.#schema;
|
||||
this.#schema = schema; // In case of future calls
|
||||
} else {
|
||||
schema = this.#schema;
|
||||
}
|
||||
const buffer = await fromDataToBuffer(data, undefined, schema);
|
||||
await this.#native.execute(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -520,14 +520,8 @@ export class LocalTable extends Table {
|
||||
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
|
||||
const mode = options?.mode ?? "append";
|
||||
const schema = await this.schema();
|
||||
const registry = getRegistry();
|
||||
const functions = await registry.parseFunctions(schema.metadata);
|
||||
|
||||
const buffer = await fromDataToBuffer(
|
||||
data,
|
||||
functions.values().next().value,
|
||||
schema,
|
||||
);
|
||||
const buffer = await fromDataToBuffer(data, undefined, schema);
|
||||
await this.inner.add(buffer, mode);
|
||||
}
|
||||
|
||||
@@ -733,7 +727,7 @@ export class LocalTable extends Table {
|
||||
}
|
||||
mergeInsert(on: string | string[]): MergeInsertBuilder {
|
||||
on = Array.isArray(on) ? on : [on];
|
||||
return new MergeInsertBuilder(this.inner.mergeInsert(on));
|
||||
return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user