From d9fb6457e1375567b832e0e55ed02486f878f184 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 4 Jun 2024 13:41:21 -0500 Subject: [PATCH] fix(nodejs): better support for f16 and f64 (#1343) closes https://github.com/lancedb/lancedb/issues/1292 closes https://github.com/lancedb/lancedb/issues/1293 --- nodejs/__test__/embedding.test.ts | 314 ++++++++++++++++++ nodejs/__test__/table.test.ts | 192 ++--------- nodejs/lancedb/arrow.ts | 12 +- .../lancedb/embedding/embedding_function.ts | 48 ++- nodejs/lancedb/query.ts | 4 +- nodejs/lancedb/table.ts | 1 + 6 files changed, 393 insertions(+), 178 deletions(-) create mode 100644 nodejs/__test__/embedding.test.ts diff --git a/nodejs/__test__/embedding.test.ts b/nodejs/__test__/embedding.test.ts new file mode 100644 index 00000000..bc03bc1c --- /dev/null +++ b/nodejs/__test__/embedding.test.ts @@ -0,0 +1,314 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import * as tmp from "tmp"; + +import { connect } from "../lancedb"; +import { + Field, + FixedSizeList, + Float, + Float16, + Float32, + Float64, + Schema, + Utf8, +} from "../lancedb/arrow"; +import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; +import { getRegistry, register } from "../lancedb/embedding/registry"; + +describe("embedding functions", () => { + let tmpDir: tmp.DirResult; + beforeEach(() => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + }); + afterEach(() => { + tmpDir.removeCallback(); + getRegistry().reset(); + }); + + it("should be able to create a table with an embedding function", async () => { + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return new Float32(); + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const func = new MockEmbeddingFunction(); + const db = await connect(tmpDir.name); + const table = await db.createTable( + "test", + [ + { id: 1, text: "hello" }, + { id: 2, text: "world" }, + ], + { + embeddingFunction: { + function: func, + sourceColumn: "text", + }, + }, + ); + // biome-ignore lint/suspicious/noExplicitAny: test + const arr = (await table.query().toArray()) as any; + expect(arr[0].vector).toBeDefined(); + + // we round trip through JSON to make sure the vector properly gets converted to an array + // otherwise it'll be a TypedArray or Vector + const vector0 = JSON.parse(JSON.stringify(arr[0].vector)); + expect(vector0).toEqual([1, 2, 3]); + }); + + it("should be able to create an empty table with an embedding function", async () => { + @register() + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return new Float32(); + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const schema = new Schema([ + new Field("text", new Utf8(), true), + new Field( + "vector", + new FixedSizeList(3, new Field("item", new Float32(), true)), + true, + ), + ]); + + const func = new MockEmbeddingFunction(); + const db = await connect(tmpDir.name); + const table = await db.createEmptyTable("test", schema, { + embeddingFunction: { + function: func, + sourceColumn: "text", + }, + }); + const outSchema = await table.schema(); + expect(outSchema.metadata.get("embedding_functions")).toBeDefined(); + await table.add([{ text: "hello world" }]); + + // biome-ignore lint/suspicious/noExplicitAny: test + const arr = (await table.query().toArray()) as any; + expect(arr[0].vector).toBeDefined(); + + // we round trip through JSON to make sure the vector properly gets converted to an array + // otherwise it'll be a TypedArray or Vector + const vector0 = JSON.parse(JSON.stringify(arr[0].vector)); + expect(vector0).toEqual([1, 2, 3]); + }); + it("should error when appending to a table with an unregistered embedding function", async () => { + @register("mock") + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return new Float32(); + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const func = getRegistry().get("mock")!.create(); + + const schema = LanceSchema({ + id: new Float64(), + text: func.sourceField(new Utf8()), + vector: func.vectorField(), + }); + + const db = await connect(tmpDir.name); + await db.createTable( + "test", + [ + { id: 1, text: "hello" }, + { id: 2, text: "world" }, + ], + { + schema, + }, + ); + + getRegistry().reset(); + const db2 = await connect(tmpDir.name); + + const tbl = await db2.openTable("test"); + + expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow( + `Function "mock" not found in registry`, + ); + }); + test.each([new Float16(), new Float32(), new Float64()])( + "should be able to provide manual embeddings with multiple float datatype", + async (floatType) => { + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return floatType; + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const data = [{ text: "hello" }, { text: "hello world" }]; + + const schema = new Schema([ + new Field("vector", new FixedSizeList(3, new Field("item", floatType))), + new Field("text", new Utf8()), + ]); + const func = new MockEmbeddingFunction(); + + const name = "test"; + const db = await connect(tmpDir.name); + + const table = await db.createTable(name, data, { + schema, + embeddingFunction: { + sourceColumn: "text", + function: func, + }, + }); + const res = await table.query().toArray(); + + expect([...res[0].vector]).toEqual([1, 2, 3]); + }, + ); + + test.only.each([new Float16(), new Float32(), new Float64()])( + "should be able to provide auto embeddings with multiple float datatypes", + async (floatType) => { + @register("test1") + class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction { + toJSON(): object { + return {}; + } + + embeddingDataType(): Float { + return floatType; + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + @register("test") + class MockEmbeddingFunction extends EmbeddingFunction { + toJSON(): object { + return {}; + } + ndims() { + return 3; + } + embeddingDataType(): Float { + return floatType; + } + async computeQueryEmbeddings(_data: string) { + return [1, 2, 3]; + } + async computeSourceEmbeddings(data: string[]) { + return Array.from({ length: data.length }).fill([ + 1, 2, 3, + ]) as number[][]; + } + } + const func = getRegistry().get("test")!.create(); + const func2 = getRegistry() + .get("test1")! + .create(); + + const schema = LanceSchema({ + text: func.sourceField(new Utf8()), + vector: func.vectorField(floatType), + }); + + const schema2 = LanceSchema({ + text: func2.sourceField(new Utf8()), + vector: func2.vectorField({ datatype: floatType, dims: 3 }), + }); + const schema3 = LanceSchema({ + text: func2.sourceField(new Utf8()), + vector: func.vectorField({ + datatype: new FixedSizeList(3, new Field("item", floatType, true)), + dims: 3, + }), + }); + + const expectedSchema = new Schema([ + new Field("text", new Utf8(), true), + new Field( + "vector", + new FixedSizeList(3, new Field("item", floatType, true)), + true, + ), + ]); + const stringSchema = JSON.stringify(schema, null, 2); + const stringSchema2 = JSON.stringify(schema2, null, 2); + const stringSchema3 = JSON.stringify(schema3, null, 2); + const stringExpectedSchema = JSON.stringify(expectedSchema, null, 2); + + expect(stringSchema).toEqual(stringExpectedSchema); + expect(stringSchema2).toEqual(stringExpectedSchema); + expect(stringSchema3).toEqual(stringExpectedSchema); + }, + ); +}); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index f6e168b9..47e33de0 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -24,17 +24,13 @@ import { Table as ArrowTable, Field, FixedSizeList, - Float, Float32, Float64, Int32, Int64, Schema, - Utf8, makeArrowTable, } from "../lancedb/arrow"; -import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; -import { getRegistry, register } from "../lancedb/embedding/registry"; import { Index } from "../lancedb/indices"; // biome-ignore lint/suspicious/noExplicitAny: @@ -45,6 +41,7 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { const schema = new arrow.Schema([ new arrow.Field("id", new arrow.Float64(), true), ]); + beforeEach(async () => { tmpDir = tmp.dirSync({ unsafeCleanup: true }); const conn = await connect(tmpDir.name); @@ -96,6 +93,38 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { expect(await table.countRows("id == 10")).toBe(1); }); + // https://github.com/lancedb/lancedb/issues/1293 + test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])( + "can create empty table with non default float type: %s", + async (floatType) => { + const db = await connect(tmpDir.name); + + const data = [ + { text: "hello", vector: Array(512).fill(1.0) }, + { text: "hello world", vector: Array(512).fill(1.0) }, + ]; + const f64Schema = new arrow.Schema([ + new arrow.Field("text", new arrow.Utf8(), true), + new arrow.Field( + "vector", + new arrow.FixedSizeList(512, new arrow.Field("item", floatType)), + true, + ), + ]); + + const f64Table = await db.createEmptyTable("f64", f64Schema, { + mode: "overwrite", + }); + try { + await f64Table.add(data); + const res = await f64Table.query().toArray(); + expect(res.length).toBe(2); + } catch (e) { + expect(e).toBeUndefined(); + } + }, + ); + it("should return the table as an instance of an arrow table", async () => { const arrowTbl = await table.toArrow(); expect(arrowTbl).toBeInstanceOf(ArrowTable); @@ -437,161 +466,6 @@ describe("when dealing with versioning", () => { }); }); -describe("embedding functions", () => { - let tmpDir: tmp.DirResult; - beforeEach(() => { - tmpDir = tmp.dirSync({ unsafeCleanup: true }); - }); - afterEach(() => tmpDir.removeCallback()); - - it("should be able to create a table with an embedding function", async () => { - class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } - ndims() { - return 3; - } - embeddingDataType(): Float { - return new Float32(); - } - async computeQueryEmbeddings(_data: string) { - return [1, 2, 3]; - } - async computeSourceEmbeddings(data: string[]) { - return Array.from({ length: data.length }).fill([ - 1, 2, 3, - ]) as number[][]; - } - } - const func = new MockEmbeddingFunction(); - const db = await connect(tmpDir.name); - const table = await db.createTable( - "test", - [ - { id: 1, text: "hello" }, - { id: 2, text: "world" }, - ], - { - embeddingFunction: { - function: func, - sourceColumn: "text", - }, - }, - ); - // biome-ignore lint/suspicious/noExplicitAny: test - const arr = (await table.query().toArray()) as any; - expect(arr[0].vector).toBeDefined(); - - // we round trip through JSON to make sure the vector properly gets converted to an array - // otherwise it'll be a TypedArray or Vector - const vector0 = JSON.parse(JSON.stringify(arr[0].vector)); - expect(vector0).toEqual([1, 2, 3]); - }); - - it("should be able to create an empty table with an embedding function", async () => { - @register() - class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } - ndims() { - return 3; - } - embeddingDataType(): Float { - return new Float32(); - } - async computeQueryEmbeddings(_data: string) { - return [1, 2, 3]; - } - async computeSourceEmbeddings(data: string[]) { - return Array.from({ length: data.length }).fill([ - 1, 2, 3, - ]) as number[][]; - } - } - const schema = new Schema([ - new Field("text", new Utf8(), true), - new Field( - "vector", - new FixedSizeList(3, new Field("item", new Float32(), true)), - true, - ), - ]); - - const func = new MockEmbeddingFunction(); - const db = await connect(tmpDir.name); - const table = await db.createEmptyTable("test", schema, { - embeddingFunction: { - function: func, - sourceColumn: "text", - }, - }); - const outSchema = await table.schema(); - expect(outSchema.metadata.get("embedding_functions")).toBeDefined(); - await table.add([{ text: "hello world" }]); - - // biome-ignore lint/suspicious/noExplicitAny: test - const arr = (await table.query().toArray()) as any; - expect(arr[0].vector).toBeDefined(); - - // we round trip through JSON to make sure the vector properly gets converted to an array - // otherwise it'll be a TypedArray or Vector - const vector0 = JSON.parse(JSON.stringify(arr[0].vector)); - expect(vector0).toEqual([1, 2, 3]); - }); - it("should error when appending to a table with an unregistered embedding function", async () => { - @register("mock") - class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } - ndims() { - return 3; - } - embeddingDataType(): Float { - return new Float32(); - } - async computeQueryEmbeddings(_data: string) { - return [1, 2, 3]; - } - async computeSourceEmbeddings(data: string[]) { - return Array.from({ length: data.length }).fill([ - 1, 2, 3, - ]) as number[][]; - } - } - const func = getRegistry().get("mock")!.create(); - - const schema = LanceSchema({ - id: new arrow.Float64(), - text: func.sourceField(new Utf8()), - vector: func.vectorField(), - }); - - const db = await connect(tmpDir.name); - await db.createTable( - "test", - [ - { id: 1, text: "hello" }, - { id: 2, text: "world" }, - ], - { - schema, - }, - ); - - getRegistry().reset(); - const db2 = await connect(tmpDir.name); - - const tbl = await db2.openTable("test"); - - expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow( - `Function "mock" not found in registry`, - ); - }); -}); - describe("when optimizing a dataset", () => { let tmpDir: tmp.DirResult; let table: Table; diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 57175ca5..1836dc30 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -31,7 +31,7 @@ import { Schema, Struct, Utf8, - type Vector, + Vector, makeBuilder, makeData, type makeTable, @@ -182,6 +182,7 @@ export class MakeArrowTableOptions { vector: new VectorColumnOptions(), }; embeddings?: EmbeddingFunction; + embeddingFunction?: EmbeddingFunctionConfig; /** * If true then string columns will be encoded with dictionary encoding @@ -306,7 +307,11 @@ export function makeArrowTable( const opt = new MakeArrowTableOptions(options !== undefined ? options : {}); if (opt.schema !== undefined && opt.schema !== null) { opt.schema = sanitizeSchema(opt.schema); - opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings); + opt.schema = validateSchemaEmbeddings( + opt.schema, + data, + options?.embeddingFunction, + ); } const columns: Record = {}; // TODO: sample dataset to find missing columns @@ -545,7 +550,6 @@ async function applyEmbeddingsFromMetadata( dtype, ); } - const vector = makeVector(vectors, destType); columns[destColumn] = vector; } @@ -835,7 +839,7 @@ export function createEmptyTable(schema: Schema): ArrowTable { function validateSchemaEmbeddings( schema: Schema, data: Array>, - embeddings: EmbeddingFunction | undefined, + embeddings: EmbeddingFunctionConfig | undefined, ) { const fields = []; const missingEmbeddingFields = []; diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 4342139d..8e752a8f 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -100,33 +100,55 @@ export abstract class EmbeddingFunction< * @see {@link lancedb.LanceSchema} */ vectorField( - options?: Partial, + optionsOrDatatype?: Partial | DataType, ): [DataType, Map] { - let dtype: DataType; - const dims = this.ndims() ?? options?.dims; - if (!options?.datatype) { - if (dims === undefined) { - throw new Error("ndims is required for vector field"); - } - dtype = new FixedSizeList(dims, new Field("item", new Float32(), true)); + let dtype: DataType | undefined; + let vectorType: DataType; + let dims: number | undefined = this.ndims(); + + // `func.vectorField(new Float32())` + if (isDataType(optionsOrDatatype)) { + dtype = optionsOrDatatype; } else { - if (isFixedSizeList(options.datatype)) { - dtype = options.datatype; - } else if (isFloat(options.datatype)) { + // `func.vectorField({ + // datatype: new Float32(), + // dims: 10 + // })` + dims = dims ?? optionsOrDatatype?.dims; + dtype = optionsOrDatatype?.datatype; + } + + if (dtype !== undefined) { + // `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))` + // or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})` + if (isFixedSizeList(dtype)) { + vectorType = dtype; + // `func.vectorField(new Float32())` + // or `func.vectorField({datatype: new Float32()})` + } else if (isFloat(dtype)) { + // No `ndims` impl and no `{dims: n}` provided; if (dims === undefined) { throw new Error("ndims is required for vector field"); } - dtype = newVectorType(dims, options.datatype); + vectorType = newVectorType(dims, dtype); } else { throw new Error( "Expected FixedSizeList or Float as datatype for vector field", ); } + } else { + if (dims === undefined) { + throw new Error("ndims is required for vector field"); + } + vectorType = new FixedSizeList( + dims, + new Field("item", new Float32(), true), + ); } const metadata = new Map(); metadata.set("vector_column_for", this); - return [dtype, metadata]; + return [vectorType, metadata]; } /** The number of dimensions of the embeddings */ diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index d3566959..0ac40378 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -168,10 +168,10 @@ export class QueryBase< } /** Collect the results as an array of objects. */ - async toArray(): Promise { + // biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[] + async toArray(): Promise { const tbl = await this.toArrow(); - // eslint-disable-next-line @typescript-eslint/no-unsafe-return return tbl.toArray(); } } diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index eda6c6da..9d0f8adf 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -135,6 +135,7 @@ export class Table { const buffer = await fromDataToBuffer( data, functions.values().next().value, + schema, ); await this.inner.add(buffer, mode); }