fix(nodejs): better support for f16 and f64 (#1343)

closes https://github.com/lancedb/lancedb/issues/1292
closes https://github.com/lancedb/lancedb/issues/1293
This commit is contained in:
Cory Grinstead
2024-06-04 13:41:21 -05:00
committed by GitHub
parent 56b4fd2bd9
commit d9fb6457e1
6 changed files with 393 additions and 178 deletions

View File

@@ -0,0 +1,314 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as tmp from "tmp";
import { connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float,
Float16,
Float32,
Float64,
Schema,
Utf8,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide manual embeddings with multiple float datatype",
async (floatType) => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const data = [{ text: "hello" }, { text: "hello world" }];
const schema = new Schema([
new Field("vector", new FixedSizeList(3, new Field("item", floatType))),
new Field("text", new Utf8()),
]);
const func = new MockEmbeddingFunction();
const name = "test";
const db = await connect(tmpDir.name);
const table = await db.createTable(name, data, {
schema,
embeddingFunction: {
sourceColumn: "text",
function: func,
},
});
const res = await table.query().toArray();
expect([...res[0].vector]).toEqual([1, 2, 3]);
},
);
test.only.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide auto embeddings with multiple float datatypes",
async (floatType) => {
@register("test1")
class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
@register("test")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("test")!.create();
const func2 = getRegistry()
.get<MockEmbeddingFunctionWithoutNDims>("test1")!
.create();
const schema = LanceSchema({
text: func.sourceField(new Utf8()),
vector: func.vectorField(floatType),
});
const schema2 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func2.vectorField({ datatype: floatType, dims: 3 }),
});
const schema3 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func.vectorField({
datatype: new FixedSizeList(3, new Field("item", floatType, true)),
dims: 3,
}),
});
const expectedSchema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", floatType, true)),
true,
),
]);
const stringSchema = JSON.stringify(schema, null, 2);
const stringSchema2 = JSON.stringify(schema2, null, 2);
const stringSchema3 = JSON.stringify(schema3, null, 2);
const stringExpectedSchema = JSON.stringify(expectedSchema, null, 2);
expect(stringSchema).toEqual(stringExpectedSchema);
expect(stringSchema2).toEqual(stringExpectedSchema);
expect(stringSchema3).toEqual(stringExpectedSchema);
},
);
});

View File

@@ -24,17 +24,13 @@ import {
Table as ArrowTable,
Field,
FixedSizeList,
Float,
Float32,
Float64,
Int32,
Int64,
Schema,
Utf8,
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { Index } from "../lancedb/indices";
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -45,6 +41,7 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
@@ -96,6 +93,38 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
expect(await table.countRows("id == 10")).toBe(1);
});
// https://github.com/lancedb/lancedb/issues/1293
test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])(
"can create empty table with non default float type: %s",
async (floatType) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello", vector: Array(512).fill(1.0) },
{ text: "hello world", vector: Array(512).fill(1.0) },
];
const f64Schema = new arrow.Schema([
new arrow.Field("text", new arrow.Utf8(), true),
new arrow.Field(
"vector",
new arrow.FixedSizeList(512, new arrow.Field("item", floatType)),
true,
),
]);
const f64Table = await db.createEmptyTable("f64", f64Schema, {
mode: "overwrite",
});
try {
await f64Table.add(data);
const res = await f64Table.query().toArray();
expect(res.length).toBe(2);
} catch (e) {
expect(e).toBeUndefined();
}
},
);
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
@@ -437,161 +466,6 @@ describe("when dealing with versioning", () => {
});
});
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
});
describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult;
let table: Table;

View File

@@ -31,7 +31,7 @@ import {
Schema,
Struct,
Utf8,
type Vector,
Vector,
makeBuilder,
makeData,
type makeTable,
@@ -182,6 +182,7 @@ export class MakeArrowTableOptions {
vector: new VectorColumnOptions(),
};
embeddings?: EmbeddingFunction<unknown>;
embeddingFunction?: EmbeddingFunctionConfig;
/**
* If true then string columns will be encoded with dictionary encoding
@@ -306,7 +307,11 @@ export function makeArrowTable(
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings);
opt.schema = validateSchemaEmbeddings(
opt.schema,
data,
options?.embeddingFunction,
);
}
const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns
@@ -545,7 +550,6 @@ async function applyEmbeddingsFromMetadata(
dtype,
);
}
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
@@ -835,7 +839,7 @@ export function createEmptyTable(schema: Schema): ArrowTable {
function validateSchemaEmbeddings(
schema: Schema,
data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunction<unknown> | undefined,
embeddings: EmbeddingFunctionConfig | undefined,
) {
const fields = [];
const missingEmbeddingFields = [];

View File

@@ -100,33 +100,55 @@ export abstract class EmbeddingFunction<
* @see {@link lancedb.LanceSchema}
*/
vectorField(
options?: Partial<FieldOptions>,
optionsOrDatatype?: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType;
const dims = this.ndims() ?? options?.dims;
if (!options?.datatype) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
let dtype: DataType | undefined;
let vectorType: DataType;
let dims: number | undefined = this.ndims();
// `func.vectorField(new Float32())`
if (isDataType(optionsOrDatatype)) {
dtype = optionsOrDatatype;
} else {
if (isFixedSizeList(options.datatype)) {
dtype = options.datatype;
} else if (isFloat(options.datatype)) {
// `func.vectorField({
// datatype: new Float32(),
// dims: 10
// })`
dims = dims ?? optionsOrDatatype?.dims;
dtype = optionsOrDatatype?.datatype;
}
if (dtype !== undefined) {
// `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))`
// or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})`
if (isFixedSizeList(dtype)) {
vectorType = dtype;
// `func.vectorField(new Float32())`
// or `func.vectorField({datatype: new Float32()})`
} else if (isFloat(dtype)) {
// No `ndims` impl and no `{dims: n}` provided;
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = newVectorType(dims, options.datatype);
vectorType = newVectorType(dims, dtype);
} else {
throw new Error(
"Expected FixedSizeList or Float as datatype for vector field",
);
}
} else {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = new FixedSizeList(
dims,
new Field("item", new Float32(), true),
);
}
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this);
return [dtype, metadata];
return [vectorType, metadata];
}
/** The number of dimensions of the embeddings */

View File

@@ -168,10 +168,10 @@ export class QueryBase<
}
/** Collect the results as an array of objects. */
async toArray(): Promise<unknown[]> {
// biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[]
async toArray(): Promise<any[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray();
}
}

View File

@@ -135,6 +135,7 @@ export class Table {
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
await this.inner.add(buffer, mode);
}