feat: js embedding registry (#1308)

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Cory Grinstead
2024-05-29 13:12:19 -05:00
committed by GitHub
parent 3bb7c546d7
commit dbea3a7544
17 changed files with 8516 additions and 7988 deletions

View File

@@ -31,6 +31,7 @@ import {
Schema,
Struct,
type Table,
Type,
Utf8,
tableFromIPC,
} from "apache-arrow";
@@ -51,7 +52,12 @@ import {
makeArrowTable,
makeEmptyTable,
} from "../lancedb/arrow";
import { type EmbeddingFunction } from "../lancedb/embedding/embedding_function";
import {
EmbeddingFunction,
FieldOptions,
FunctionOptions,
} from "../lancedb/embedding/embedding_function";
import { EmbeddingFunctionConfig } from "../lancedb/embedding/registry";
// biome-ignore lint/suspicious/noExplicitAny: skip
function sampleRecords(): Array<Record<string, any>> {
@@ -280,23 +286,46 @@ describe("The function makeArrowTable", function () {
});
});
class DummyEmbedding implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
public readonly embeddingDimension = 2;
public readonly embeddingDataType = new Float16();
class DummyEmbedding extends EmbeddingFunction<string> {
toJSON(): Partial<FunctionOptions> {
return {};
}
async embed(data: string[]): Promise<number[][]> {
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
ndims(): number {
return 2;
}
embeddingDataType() {
return new Float16();
}
}
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
class DummyEmbeddingWithNoDimension extends EmbeddingFunction<string> {
toJSON(): Partial<FunctionOptions> {
return {};
}
async embed(data: string[]): Promise<number[][]> {
embeddingDataType(): Float {
return new Float16();
}
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
}
const dummyEmbeddingConfig: EmbeddingFunctionConfig = {
sourceColumn: "string",
function: new DummyEmbedding(),
};
const dummyEmbeddingConfigWithNoDimension: EmbeddingFunctionConfig = {
sourceColumn: "string",
function: new DummyEmbeddingWithNoDimension(),
};
describe("convertToTable", function () {
it("will infer data types correctly", async function () {
@@ -331,7 +360,7 @@ describe("convertToTable", function () {
it("will apply embeddings", async function () {
const records = sampleRecords();
const table = await convertToTable(records, new DummyEmbedding());
const table = await convertToTable(records, dummyEmbeddingConfig);
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
@@ -340,7 +369,7 @@ describe("convertToTable", function () {
it("will fail if missing the embedding source column", async function () {
await expect(
convertToTable([{ id: 1 }], new DummyEmbedding()),
convertToTable([{ id: 1 }], dummyEmbeddingConfig),
).rejects.toThrow("'string' was not present");
});
@@ -351,7 +380,7 @@ describe("convertToTable", function () {
const table = makeEmptyTable(schema);
// If the embedding specifies the dimension we are fine
await fromTableToBuffer(table, new DummyEmbedding());
await fromTableToBuffer(table, dummyEmbeddingConfig);
// We can also supply a schema and should be ok
const schemaWithEmbedding = new Schema([
@@ -364,13 +393,13 @@ describe("convertToTable", function () {
]);
await fromTableToBuffer(
table,
new DummyEmbeddingWithNoDimension(),
dummyEmbeddingConfigWithNoDimension,
schemaWithEmbedding,
);
// Otherwise we will get an error
await expect(
fromTableToBuffer(table, new DummyEmbeddingWithNoDimension()),
fromTableToBuffer(table, dummyEmbeddingConfigWithNoDimension),
).rejects.toThrow("does not specify `embeddingDimension`");
});
@@ -383,7 +412,7 @@ describe("convertToTable", function () {
false,
),
]);
const table = await convertToTable([], new DummyEmbedding(), { schema });
const table = await convertToTable([], dummyEmbeddingConfig, { schema });
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
@@ -393,16 +422,17 @@ describe("convertToTable", function () {
it("will complain if embeddings present but schema missing embedding column", async function () {
const schema = new Schema([new Field("string", new Utf8(), false)]);
await expect(
convertToTable([], new DummyEmbedding(), { schema }),
convertToTable([], dummyEmbeddingConfig, { schema }),
).rejects.toThrow("column vector was missing");
});
it("will provide a nice error if run twice", async function () {
const records = sampleRecords();
const table = await convertToTable(records, new DummyEmbedding());
const table = await convertToTable(records, dummyEmbeddingConfig);
// fromTableToBuffer will try and apply the embeddings again
await expect(
fromTableToBuffer(table, new DummyEmbedding()),
fromTableToBuffer(table, dummyEmbeddingConfig),
).rejects.toThrow("already existed");
});
});

View File

@@ -13,7 +13,6 @@
// limitations under the License.
import * as tmp from "tmp";
import { Connection, connect } from "../lancedb";
describe("when connecting", () => {

View File

@@ -0,0 +1,166 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32, Int32, Utf8, Vector } from "apache-arrow";
import * as tmp from "tmp";
import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("LanceSchema", () => {
  // The Arrow schema's field order is expected to follow the key order of the input object.
  test("should preserve input order", async () => {
    const declaredFields = {
      id: new Int32(),
      text: new Utf8(),
      vector: new Float32(),
    };
    const schema = LanceSchema(declaredFields);
    const fieldNames = schema.fields.map((field) => field.name);
    expect(fieldNames).toEqual(["id", "text", "vector"]);
  });
});
describe("Registry", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should register a new item to the registry", async () => {
@register("mock-embedding")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = getRegistry()
.get<MockEmbeddingFunction>("mock-embedding")!
.create();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{ schema },
);
const expected = [
[1, 2, 3],
[1, 2, 3],
];
const actual = await table.query().toArrow();
const vectors = actual
.getChild("vector")
?.toArray()
.map((x: unknown) => {
if (x instanceof Vector) {
return [...x];
} else {
return x;
}
});
expect(vectors).toEqual(expected);
});
test("should error if registering with the same name", async () => {
  // Minimal embedding function; every row embeds to the same 3-d vector.
  class MockEmbeddingFunction extends EmbeddingFunction<string> {
    constructor() {
      super();
    }
    toJSON(): object {
      return {
        someText: "hello",
      };
    }
    ndims() {
      return 3;
    }
    embeddingDataType(): Float {
      return new Float32();
    }
    async computeSourceEmbeddings(data: string[]) {
      return data.map(() => [1, 2, 3]);
    }
  }
  const registerMock = () => register("mock-embedding")(MockEmbeddingFunction);
  // The first registration claims the alias ...
  registerMock();
  // ... so registering the same alias again must be rejected.
  expect(registerMock).toThrow(
    'Embedding function with alias "mock-embedding" already exists',
  );
});
test("schema should contain correct metadata", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const expectedMetadata = new Map<string, string>([
[
"embedding_functions",
JSON.stringify([
{
sourceColumn: "text",
vectorColumn: "vector",
name: "MockEmbeddingFunction",
model: { someText: "hello" },
},
]),
],
]);
expect(schema.metadata).toEqual(expectedMetadata);
});
});

View File

@@ -19,14 +19,18 @@ import * as tmp from "tmp";
import {
Field,
FixedSizeList,
Float,
Float32,
Float64,
Int32,
Int64,
Schema,
Utf8,
} from "apache-arrow";
import { Table, connect } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { Index } from "../lancedb/indices";
describe("Given a table", () => {
@@ -420,6 +424,161 @@ describe("when dealing with versioning", () => {
});
});
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
  // Register the function under the alias "mock" so the schema metadata written
  // at table creation refers to it by that name.
  @register("mock")
  class MockEmbeddingFunction extends EmbeddingFunction<string> {
    toJSON(): object {
      return {};
    }
    ndims() {
      return 3;
    }
    embeddingDataType(): Float {
      return new Float32();
    }
    async computeQueryEmbeddings(_data: string) {
      return [1, 2, 3];
    }
    async computeSourceEmbeddings(data: string[]) {
      return Array.from({ length: data.length }).fill([
        1, 2, 3,
      ]) as number[][];
    }
  }
  const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
  const schema = LanceSchema({
    id: new Float64(),
    text: func.sourceField(new Utf8()),
    vector: func.vectorField(),
  });
  const db = await connect(tmpDir.name);
  await db.createTable(
    "test",
    [
      { id: 1, text: "hello" },
      { id: 2, text: "world" },
    ],
    {
      schema,
    },
  );

  // Drop every registration, then reopen the table through a fresh connection:
  // the stored metadata still names "mock", which can no longer be resolved.
  getRegistry().reset();
  const db2 = await connect(tmpDir.name);
  const tbl = await db2.openTable("test");

  // `.rejects` yields a promise; it must be awaited, otherwise the test
  // finishes before the assertion runs and can never fail.
  await expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
    `Function "mock" not found in registry`,
  );
});
});
describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult;
let table: Table;

View File

@@ -48,7 +48,7 @@
"noUnsafeFinally": "error",
"noUnsafeOptionalChaining": "error",
"noUnusedLabels": "error",
"noUnusedVariables": "error",
"noUnusedVariables": "warn",
"useIsNan": "error",
"useValidForDirection": "error",
"useYield": "error"
@@ -101,7 +101,13 @@
},
"overrides": [
{
"include": ["**/*.ts", "**/*.tsx", "**/*.mts", "**/*.cts"],
"include": [
"**/*.ts",
"**/*.tsx",
"**/*.mts",
"**/*.cts",
"__test__/*.test.ts"
],
"linter": {
"rules": {
"correctness": {

View File

@@ -34,6 +34,7 @@ import {
vectorFromArray,
} from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeSchema } from "./sanitize";
/** Data type accepted by NodeJS SDK */
@@ -198,6 +199,7 @@ export class MakeArrowTableOptions {
export function makeArrowTable(
data: Array<Record<string, unknown>>,
options?: Partial<MakeArrowTableOptions>,
metadata?: Map<string, string>,
): ArrowTable {
if (
data.length === 0 &&
@@ -290,20 +292,41 @@ export function makeArrowTable(
// `new ArrowTable(schema, batches)` which does not do any schema inference
const firstTable = new ArrowTable(columns);
const batchesFixed = firstTable.batches.map(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
(batch) => new RecordBatch(opt.schema!, batch.data),
);
return new ArrowTable(opt.schema, batchesFixed);
} else {
return new ArrowTable(columns);
let schema: Schema;
if (metadata !== undefined) {
  // Merge the caller-supplied metadata on top of a *copy* of the schema's own
  // metadata. Copying avoids mutating the map owned by `opt.schema`, and
  // iterating `metadata` (not the schema's map) is what actually brings the
  // new entries in; on a key clash the caller-supplied entry wins.
  const schemaMetadata = new Map<string, string>(opt.schema.metadata);
  for (const [key, entry] of metadata.entries()) {
    schemaMetadata.set(key, entry);
  }
  schema = new Schema(opt.schema.fields, schemaMetadata);
} else {
  // No extra metadata requested: use the provided schema unchanged.
  schema = opt.schema;
}
return new ArrowTable(schema, batchesFixed);
}
const tbl = new ArrowTable(columns);
if (metadata !== undefined) {
// biome-ignore lint/suspicious/noExplicitAny: cast to any so metadata can be assigned onto the inferred table's schema
(<any>tbl.schema).metadata = metadata;
}
return tbl;
}
/**
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable(schema: Schema): ArrowTable {
return makeArrowTable([], { schema });
export function makeEmptyTable(
schema: Schema,
metadata?: Map<string, string>,
): ArrowTable {
return makeArrowTable([], { schema }, metadata);
}
/**
@@ -375,13 +398,75 @@ function makeVector(
}
}
/**
 * Helper function to apply embeddings from metadata to an input table
 *
 * The embedding functions are resolved by passing `schema.metadata` to the
 * registry's `parseFunctions`. For each resolved entry the source column is
 * read from `table`, embedded with `computeSourceEmbeddings`, and stored under
 * the entry's vector column (default `"vector"`), typed as declared in `schema`.
 *
 * @param table - The input data; must not already contain the vector column(s).
 * @param schema - Schema whose metadata describes the embedding functions and
 *   whose fields declare the `FixedSizeList` type of each vector column.
 * @returns A new table containing the input columns plus the computed vector
 *   columns, aligned to `schema`.
 * @throws If a source column is missing from the data, a vector column already
 *   exists in the data, the table has more than one batch, the function returns
 *   a different number of embeddings than inputs, or the vector column is
 *   missing from `schema` or is not a `FixedSizeList`.
 */
async function applyEmbeddingsFromMetadata(
  table: ArrowTable,
  schema: Schema,
): Promise<ArrowTable> {
  const registry = getRegistry();
  const functions = registry.parseFunctions(schema.metadata);

  // Index the existing columns by name; vector columns are added to this record.
  const columns = Object.fromEntries(
    table.schema.fields.map((field) => [
      field.name,
      table.getChild(field.name)!,
    ]),
  );

  for (const functionEntry of functions.values()) {
    const sourceColumn = columns[functionEntry.sourceColumn];
    const destColumn = functionEntry.vectorColumn ?? "vector";
    if (sourceColumn === undefined) {
      throw new Error(
        `Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
      );
    }
    if (columns[destColumn] !== undefined) {
      throw new Error(
        `Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
      );
    }
    if (table.batches.length > 1) {
      throw new Error(
        "Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
      );
    }

    const values = sourceColumn.toArray();
    const vectors =
      await functionEntry.function.computeSourceEmbeddings(values);
    if (vectors.length !== values.length) {
      throw new Error(
        "Embedding function did not return an embedding for each input element",
      );
    }

    // The vector column's type comes from the schema. Report a clear error if
    // the metadata names a column the schema does not declare, instead of
    // dereferencing `undefined`.
    const destField = schema.fields.find((f) => f.name === destColumn);
    if (destField === undefined) {
      throw new Error(
        `Cannot apply embedding function because the vector column '${destColumn}' was not present in the schema`,
      );
    }
    const destType = destField.type;
    if (!(destType instanceof FixedSizeList)) {
      throw new Error(
        "Expected FixedSizeList as datatype for vector field, instead got: " +
          destType,
      );
    }
    columns[destColumn] = makeVector(vectors, destType);
  }

  // Rebuild the table with the added columns, then align it to the target schema.
  const newTable = new ArrowTable(columns);
  return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<ArrowTable> {
if (embeddings == null) {
if (schema?.metadata.has("embedding_functions")) {
return applyEmbeddingsFromMetadata(table, schema!);
} else if (embeddings == null || embeddings === undefined) {
return table;
}
@@ -399,8 +484,9 @@ async function applyEmbeddings<T>(
const newColumns = Object.fromEntries(colEntries);
const sourceColumn = newColumns[embeddings.sourceColumn];
const destColumn = embeddings.destColumn ?? "vector";
const innerDestType = embeddings.embeddingDataType ?? new Float32();
const destColumn = embeddings.vectorColumn ?? "vector";
const innerDestType =
embeddings.function.embeddingDataType() ?? new Float32();
if (sourceColumn === undefined) {
throw new Error(
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
@@ -414,11 +500,9 @@ async function applyEmbeddings<T>(
// if we call convertToTable with 0 records and a schema that includes the embedding
return table;
}
if (embeddings.embeddingDimension !== undefined) {
const destType = newVectorType(
embeddings.embeddingDimension,
innerDestType,
);
const dimensions = embeddings.function.ndims();
if (dimensions !== undefined) {
const destType = newVectorType(dimensions, innerDestType);
newColumns[destColumn] = makeVector([], destType);
} else if (schema != null) {
const destField = schema.fields.find((f) => f.name === destColumn);
@@ -446,7 +530,9 @@ async function applyEmbeddings<T>(
);
}
const values = sourceColumn.toArray();
const vectors = await embeddings.embed(values as T[]);
const vectors = await embeddings.function.computeSourceEmbeddings(
values as T[],
);
if (vectors.length !== values.length) {
throw new Error(
"Embedding function did not return an embedding for each input element",
@@ -486,9 +572,9 @@ async function applyEmbeddings<T>(
* embedding columns. If no schema is provided then embedding columns will
* be placed at the end of the table, after all of the input columns.
*/
export async function convertToTable<T>(
export async function convertToTable(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
makeTableOptions?: Partial<MakeArrowTableOptions>,
): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions);
@@ -496,7 +582,7 @@ export async function convertToTable<T>(
}
/** Creates the Arrow Type for a Vector column with dimension `dim` */
function newVectorType<T extends Float>(
export function newVectorType<T extends Float>(
dim: number,
innerType: T,
): FixedSizeList<T> {
@@ -513,9 +599,9 @@ function newVectorType<T extends Float>(
*
* `schema` is required if data is empty
*/
export async function fromRecordsToBuffer<T>(
export async function fromRecordsToBuffer(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -533,9 +619,9 @@ export async function fromRecordsToBuffer<T>(
*
* `schema` is required if data is empty
*/
export async function fromRecordsToStreamBuffer<T>(
export async function fromRecordsToStreamBuffer(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -554,9 +640,9 @@ export async function fromRecordsToStreamBuffer<T>(
*
* `schema` is required if the table is empty
*/
export async function fromTableToBuffer<T>(
export async function fromTableToBuffer(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -575,9 +661,9 @@ export async function fromTableToBuffer<T>(
*
* `schema` is required if the table is empty
*/
export async function fromDataToBuffer<T>(
export async function fromDataToBuffer(
data: Data,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
@@ -586,8 +672,8 @@ export async function fromDataToBuffer<T>(
if (data instanceof ArrowTable) {
return fromTableToBuffer(data, embeddings, schema);
} else {
const table = await convertToTable(data);
return fromTableToBuffer(table, embeddings, schema);
const table = await convertToTable(data, embeddings, { schema });
return fromTableToBuffer(table);
}
}
@@ -599,9 +685,9 @@ export async function fromDataToBuffer<T>(
*
* `schema` is required if the table is empty
*/
export async function fromTableToStreamBuffer<T>(
export async function fromTableToStreamBuffer(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
embeddings?: EmbeddingFunctionConfig,
schema?: Schema,
): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
@@ -667,7 +753,20 @@ function validateSchemaEmbeddings(
for (const field of schema.fields) {
if (field.type instanceof FixedSizeList) {
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
missingEmbeddingFields.push(field);
if (schema.metadata.has("embedding_functions")) {
const embeddings = JSON.parse(
schema.metadata.get("embedding_functions")!,
);
if (
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
undefined
) {
missingEmbeddingFields.push(field);
}
} else {
missingEmbeddingFields.push(field);
}
} else {
fields.push(field);
}

View File

@@ -14,6 +14,7 @@
import { Table as ArrowTable, Schema } from "apache-arrow";
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
import { Table } from "./table";
@@ -65,6 +66,8 @@ export interface CreateTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
schema?: Schema;
embeddingFunction?: EmbeddingFunctionConfig;
}
export interface OpenTableOptions {
@@ -174,6 +177,7 @@ export class Connection {
cleanseStorageOptions(options?.storageOptions),
options?.indexCacheSize,
);
return new Table(innerTable);
}
@@ -199,15 +203,21 @@ export class Connection {
if (data instanceof ArrowTable) {
table = data;
} else {
table = makeArrowTable(data);
table = makeArrowTable(data, options);
}
const buf = await fromTableToBuffer(table);
const buf = await fromTableToBuffer(
table,
options?.embeddingFunction,
options?.schema,
);
const innerTable = await this.inner.createTable(
name,
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
);
return new Table(innerTable);
}
@@ -227,8 +237,14 @@ export class Connection {
if (mode === "create" && existOk) {
mode = "exist_ok";
}
let metadata: Map<string, string> | undefined = undefined;
if (options?.embeddingFunction !== undefined) {
const embeddingFunction = options.embeddingFunction;
const registry = getRegistry();
metadata = registry.getTableMetadata([embeddingFunction]);
}
const table = makeEmptyTable(schema);
const table = makeEmptyTable(schema, metadata);
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createEmptyTable(
name,

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers.
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,67 +12,141 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { type Float } from "apache-arrow";
import { DataType, Field, FixedSizeList, Float, Float32 } from "apache-arrow";
import "reflect-metadata";
import { newVectorType } from "../arrow";
/**
* Options for a given embedding function
*/
export interface FunctionOptions {
// biome-ignore lint/suspicious/noExplicitAny: options can be anything
[key: string]: any;
}
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
export abstract class EmbeddingFunction<
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
T = any,
M extends FunctionOptions = FunctionOptions,
> {
/**
* The name of the column that will be used as input for the Embedding Function.
* Convert the embedding function to a JSON object
* It is used to serialize the embedding function to the schema
* It's important that any object returned by this method contains all the necessary
* information to recreate the embedding function
*
* It should return the same object that was passed to the constructor
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
*
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* constructor(options: {model: string, timeout: number}) {
* super();
* this.model = options.model;
* this.timeout = options.timeout;
* }
* toJSON() {
* return {
* model: this.model,
* timeout: this.timeout,
* };
* }
* ```
*/
sourceColumn: string;
abstract toJSON(): Partial<M>;
/**
* The data type of the embedding
* sourceField is used in combination with `LanceSchema` to provide a declarative data model
*
* The embedding function should return `number`. This will be converted into
* an Arrow float array. By default this will be Float32 but this property can
* be used to control the conversion.
* @param optionsOrDatatype - The options for the field or the datatype
*
* @see {@link lancedb.LanceSchema}
*/
embeddingDataType?: Float;
sourceField(
  optionsOrDatatype: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
  // Accept either a bare Arrow datatype or an options object carrying one.
  let datatype: DataType | undefined;
  if (optionsOrDatatype instanceof DataType) {
    datatype = optionsOrDatatype;
  } else {
    datatype = optionsOrDatatype?.datatype;
  }
  if (!datatype) {
    throw new Error("Datatype is required");
  }

  // Tag the field so `LanceSchema` can tell which function reads this column.
  const marker = new Map<string, EmbeddingFunction>();
  marker.set("source_column_for", this);
  return [datatype, marker];
}
/**
* The dimension of the embedding
* vectorField is used in combination with `LanceSchema` to provide a declarative data model
*
* This is optional, normally this can be determined by looking at the results of
* `embed`. If this is not specified, and there is an attempt to apply the embedding
* to an empty table, then that process will fail.
* @param options - The options for the field
*
* @see {@link lancedb.LanceSchema}
*/
embeddingDimension?: number;
vectorField(
options?: Partial<FieldOptions>,
): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType;
const dims = this.ndims() ?? options?.dims;
if (!options?.datatype) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
} else {
if (options.datatype instanceof FixedSizeList) {
dtype = options.datatype;
} else if (options.datatype instanceof Float) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = newVectorType(dims, options.datatype);
} else {
throw new Error(
"Expected FixedSizeList or Float as datatype for vector field",
);
}
}
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this);
/**
* The name of the column that will contain the embedding
*
* By default this is "vector"
*/
destColumn?: string;
return [dtype, metadata];
}
/**
* Should the source column be excluded from the resulting table
*
* By default the source column is included. Set this to true and
* only the embedding will be stored.
*/
excludeSource?: boolean;
/**
 * The number of dimensions of the embeddings
 *
 * This base implementation returns `undefined`; subclasses override it to
 * report a fixed dimension. `vectorField` falls back to `options.dims` when
 * this returns `undefined`, and throws if neither provides a value.
 */
ndims(): number | undefined {
return undefined;
}
/** The datatype of the embeddings */
abstract embeddingDataType(): Float;
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => Promise<number[][]>;
abstract computeSourceEmbeddings(
data: T[],
): Promise<number[][] | Float32Array[] | Float64Array[]>;
/**
 * Compute the embeddings for a single query
 *
 * The default implementation wraps `data` in a one-element batch, delegates
 * to `computeSourceEmbeddings`, and returns the first (only) result.
 */
async computeQueryEmbeddings(
  data: T,
): Promise<number[] | Float32Array | Float64Array> {
  const embeddings = await this.computeSourceEmbeddings([data]);
  return embeddings[0];
}
}
/** Test if the input seems to be an embedding function */
export function isEmbeddingFunction<T>(
value: unknown,
): value is EmbeddingFunction<T> {
if (typeof value !== "object" || value === null) {
return false;
}
if (!("sourceColumn" in value) || !("embed" in value)) {
return false;
}
return (
typeof value.sourceColumn === "string" && typeof value.embed === "function"
);
export interface FieldOptions<T extends DataType = DataType> {
datatype: T;
dims?: number;
}

View File

@@ -1,2 +1,105 @@
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function";
export { OpenAIEmbeddingFunction } from "./openai";
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { DataType, Field, Schema } from "apache-arrow";
import { EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
export { EmbeddingFunction } from "./embedding_function";
export * from "./openai";
/**
 * Build an Arrow schema whose metadata records the embedding functions
 * attached to its fields.
 *
 * Each entry of `fields` is either a plain Arrow `DataType`, or the
 * `[DataType, metadata]` pair produced by `EmbeddingFunction.sourceField` /
 * `EmbeddingFunction.vectorField`. Every field is created as nullable and the
 * field order follows the key order of `fields`.
 *
 * @param fields - Mapping from column name to datatype (or datatype + embedding marker)
 * @returns Schema
 * @example
 * ```ts
 * class MyEmbeddingFunction extends EmbeddingFunction {
 *   // ...
 * }
 * const func = new MyEmbeddingFunction();
 * const schema = LanceSchema({
 *   id: new Int32(),
 *   text: func.sourceField(new Utf8()),
 *   vector: func.vectorField(),
 *   // optional: specify the datatype and/or dimensions
 *   vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
 * });
 *
 * const table = await db.createTable("my_table", data, { schema });
 * ```
 */
export function LanceSchema(
  fields: Record<string, [DataType, Map<string, EmbeddingFunction>] | DataType>,
): Schema {
  const arrowFields: Field[] = [];
  // One (partial) config per embedding function instance, filled in as its
  // source and vector columns are encountered.
  const configsByFunction = new Map<
    EmbeddingFunction,
    Partial<EmbeddingFunctionConfig>
  >();

  for (const [columnName, spec] of Object.entries(fields)) {
    if (spec instanceof DataType) {
      arrowFields.push(new Field(columnName, spec, true));
      continue;
    }
    const [dtype, fieldMetadata] = spec;
    arrowFields.push(new Field(columnName, dtype, true));
    parseEmbeddingFunctions(configsByFunction, columnName, fieldMetadata);
  }

  const registry = getRegistry();
  const configs = Array.from(
    configsByFunction.values(),
  ) as EmbeddingFunctionConfig[];
  return new Schema(arrowFields, registry.getTableMetadata(configs));
}
/**
 * Record the role that field `key` plays for an embedding function.
 *
 * If `metadata` contains `"source_column_for"`, the field is stored as that
 * function's source column; otherwise, if it contains `"vector_column_for"`,
 * the field is stored as its vector column. Only the first matching marker is
 * applied. An existing entry for the same function is extended rather than
 * replaced.
 *
 * @param embeddingFunctions Accumulator of per-function configs; mutated in place.
 * @param key Name of the field being processed.
 * @param metadata Field metadata mapping a role marker to an embedding function.
 */
function parseEmbeddingFunctions(
  embeddingFunctions: Map<EmbeddingFunction, Partial<EmbeddingFunctionConfig>>,
  key: string,
  metadata: Map<string, EmbeddingFunction>,
): void {
  // Merge `columns` into the config of the function stored under `marker`.
  const assignColumn = (
    marker: string,
    columns: Partial<EmbeddingFunctionConfig>,
  ): void => {
    const fn = metadata.get(marker)!;
    const existing = embeddingFunctions.get(fn);
    embeddingFunctions.set(
      fn,
      existing === undefined
        ? { ...columns, function: fn }
        : { ...existing, ...columns },
    );
  };

  if (metadata.has("source_column_for")) {
    assignColumn("source_column_for", { sourceColumn: key });
  } else if (metadata.has("vector_column_for")) {
    assignColumn("vector_column_for", { vectorColumn: key });
  }
}

View File

@@ -12,18 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32 } from "apache-arrow";
import type OpenAI from "openai";
import { type EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
private readonly _openai: OpenAI;
private readonly _modelName: string;
/** Options accepted by the `OpenAIEmbeddingFunction` constructor. */
export type OpenAIOptions = {
/** OpenAI API key; when omitted the constructor falls back to `process.env.OPENAI_API_KEY`. */
apiKey?: string;
/** Embedding model name; the constructor uses "text-embedding-ada-002" when omitted. */
model?: string;
};
@register("openai")
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
string,
OpenAIOptions
> {
#openai: OpenAI;
#modelName: string;
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
super();
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
if (!openAIKey) {
throw new Error("OpenAI API key is required");
}
const modelName = options?.model ?? "text-embedding-ada-002";
constructor(
sourceColumn: string,
openAIKey: string,
modelName: string = "text-embedding-ada-002",
) {
/**
* @type {import("openai").default}
*/
@@ -36,18 +50,40 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
throw new Error("please install openai@^4.24.1 using npm install openai");
}
this.sourceColumn = sourceColumn;
const configuration = {
apiKey: openAIKey,
};
this._openai = new Openai(configuration);
this._modelName = modelName;
this.#openai = new Openai(configuration);
this.#modelName = modelName;
}
async embed(data: string[]): Promise<number[][]> {
const response = await this._openai.embeddings.create({
model: this._modelName,
/**
 * Serializable description of this function's options.
 *
 * Only the model name is included; the API key is not part of the result.
 */
toJSON() {
  const model = this.#modelName;
  return { model };
}
/**
 * Number of dimensions of the vectors produced by the configured model.
 *
 * @returns 1536 for "text-embedding-ada-002" and "text-embedding-3-small",
 *   3072 for "text-embedding-3-large".
 * @throws Error if the configured model is not one of the models listed
 *   above. Previously this case returned `null` disguised as `never`, which
 *   handed callers a non-number despite the declared return type.
 */
ndims(): number {
  switch (this.#modelName) {
    case "text-embedding-ada-002":
    case "text-embedding-3-small":
      return 1536;
    case "text-embedding-3-large":
      return 3072;
    default:
      throw new Error(`Unknown model: ${this.#modelName}`);
  }
}
/** Arrow data type of the embedding values: always a new `Float32`. */
embeddingDataType(): Float {
return new Float32();
}
async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
const response = await this.#openai.embeddings.create({
model: this.#modelName,
input: data,
});
@@ -58,5 +94,15 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
return embeddings;
}
sourceColumn: string;
/**
 * Compute the embedding of a single query string.
 *
 * @param data The query text to embed.
 * @returns The embedding of the first (and only) input returned by the API.
 * @throws Error if `data` is not a string at runtime.
 */
async computeQueryEmbeddings(data: string): Promise<number[]> {
  if (typeof data === "string") {
    const { data: results } = await this.#openai.embeddings.create({
      model: this.#modelName,
      input: data,
    });
    return results[0].embedding;
  }
  throw new Error("Data must be a string");
}
}

View File

@@ -0,0 +1,172 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import type { EmbeddingFunction } from "./embedding_function";
import "reflect-metadata";
/**
 * Free-form options for an embedding function, keyed by string.
 * Values are `unknown`, so consumers must narrow them before use.
 */
export interface EmbeddingFunctionOptions {
[key: string]: unknown;
}
/**
 * Constructor signature of an embedding function class: invoked with `new`
 * and an optional options object, it produces an instance of `T`.
 */
export interface EmbeddingFunctionFactory<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: EmbeddingFunctionOptions): T;
}
/**
 * Handle returned by `EmbeddingFunctionRegistry.get`: exposes a `create`
 * method that builds an embedding function from optional options.
 */
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
create(options?: EmbeddingFunctionOptions): T;
}
/**
 * Registry that maps aliases to embedding function classes so that they can
 * be fetched by name and serialized to / restored from table metadata.
 *
 * You can implement your own embedding function by subclassing
 * `EmbeddingFunction` and registering the subclass with the registry.
 */
export class EmbeddingFunctionRegistry {
  #functions: Map<string, EmbeddingFunctionFactory> = new Map();

  /**
   * Create a class decorator that registers an embedding function class.
   *
   * @param alias Name to register the class under. When omitted (or empty),
   *   the name of the decorated class is used.
   * @returns A decorator that registers the class it is applied to, tags it
   *   with the resolved alias, and returns the class unchanged.
   * @throws Error (when the decorator is applied) if the alias is already
   *   registered.
   */
  register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
    this: EmbeddingFunctionRegistry,
    alias?: string,
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
  ): (ctor: T) => any {
    return (ctor: T) => {
      // Resolve the name per decorated class. Overwriting the captured
      // `alias` would make a reused decorator keep the first class's name.
      const resolvedAlias = alias ? alias : ctor.name;
      if (this.#functions.has(resolvedAlias)) {
        throw new Error(
          `Embedding function with alias "${resolvedAlias}" already exists`,
        );
      }
      this.#functions.set(resolvedAlias, ctor);
      Reflect.defineMetadata("lancedb::embedding::name", resolvedAlias, ctor);
      return ctor;
    };
  }

  /**
   * Fetch an embedding function by name.
   *
   * @param name The alias the function was registered under.
   * @returns An object whose `create` method instantiates the registered
   *   class, or `undefined` when no class is registered under `name`.
   */
  get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
    name: string,
  ): EmbeddingFunctionCreate<T> | undefined {
    const factory = this.#functions.get(name);
    if (!factory) {
      return undefined;
    }
    return {
      create: (options?: EmbeddingFunctionOptions) =>
        new factory(options) as unknown as T,
    };
  }

  /**
   * Reset the registry to the initial state by removing every registration.
   */
  reset(this: EmbeddingFunctionRegistry) {
    this.#functions.clear();
  }

  /**
   * Rebuild embedding function configs from table metadata.
   *
   * @param metadata Table metadata; the `"embedding_functions"` entry, when
   *   present, must hold a JSON array of serialized function configs.
   * @returns A map from function name to its config. Empty when the metadata
   *   has no `"embedding_functions"` entry.
   * @throws Error if the entry is not a JSON array, or if it names a function
   *   that is not registered.
   */
  parseFunctions(
    this: EmbeddingFunctionRegistry,
    metadata: Map<string, string>,
  ): Map<string, EmbeddingFunctionConfig> {
    const serialized = metadata.get("embedding_functions");
    if (serialized === undefined) {
      return new Map();
    }

    type FunctionConfig = {
      name: string;
      sourceColumn: string;
      vectorColumn: string;
      model: EmbeddingFunctionOptions;
    };
    const parsed: unknown = JSON.parse(serialized);
    if (!Array.isArray(parsed)) {
      throw new Error(
        'Table metadata "embedding_functions" must be a JSON array',
      );
    }
    const functions = parsed as FunctionConfig[];

    // NOTE(review): entries are keyed by function name, so two configs that
    // share a name collapse into one (the later entry wins).
    return new Map(
      functions.map((f): [string, EmbeddingFunctionConfig] => {
        const factory = this.get(f.name);
        if (!factory) {
          throw new Error(`Function "${f.name}" not found in registry`);
        }
        return [
          f.name,
          {
            sourceColumn: f.sourceColumn,
            vectorColumn: f.vectorColumn,
            function: factory.create(f.model),
          },
        ];
      }),
    );
  }

  /**
   * Serialize one embedding function config into a plain object.
   *
   * @param conf The config to serialize.
   * @returns An object with `sourceColumn`, `vectorColumn` (defaulting to
   *   "vector"), `name` (the registered alias, or the class name when the
   *   class carries no alias) and `model` (the function's `toJSON()` output).
   */
  // biome-ignore lint/suspicious/noExplicitAny: <explanation>
  functionToMetadata(conf: EmbeddingFunctionConfig): Record<string, any> {
    const registeredName = Reflect.getMetadata(
      "lancedb::embedding::name",
      conf.function.constructor,
    );
    return {
      sourceColumn: conf.sourceColumn,
      vectorColumn: conf.vectorColumn ?? "vector",
      name: registeredName ?? conf.function.constructor.name,
      model: conf.function.toJSON(),
    };
  }

  /**
   * Build table metadata describing the given embedding functions.
   *
   * @param functions The configs to serialize.
   * @returns A map with a single `"embedding_functions"` entry holding the
   *   JSON-encoded array of serialized configs.
   */
  getTableMetadata(functions: EmbeddingFunctionConfig[]): Map<string, string> {
    const serialized = functions.map((conf) => this.functionToMetadata(conf));
    return new Map<string, string>([
      ["embedding_functions", JSON.stringify(serialized)],
    ]);
  }
}
// Module-level registry instance shared by `register` and `getRegistry`.
const _REGISTRY = new EmbeddingFunctionRegistry();
/**
 * Create a decorator that registers a class with the module-level registry.
 *
 * @param name Optional alias, forwarded to `EmbeddingFunctionRegistry.register`.
 * @returns The decorator produced by the registry.
 */
export function register(name?: string) {
return _REGISTRY.register(name);
}
/**
 * Utility function to get the global instance of the registry
 * @returns `EmbeddingFunctionRegistry` The global instance of the registry
 * @example
 * ```ts
 * const registry = getRegistry();
 * // `get` returns undefined when nothing is registered under the name.
 * const openai = registry.get("openai")?.create();
 * ```
 */
export function getRegistry(): EmbeddingFunctionRegistry {
return _REGISTRY;
}
/** Associates an embedding function with the columns it reads from and writes to. */
export interface EmbeddingFunctionConfig {
/** Name of the column holding the source data. */
sourceColumn: string;
/** Name of the vector column; serialized as "vector" by the registry when omitted. */
vectorColumn?: string;
/** The embedding function instance. */
function: EmbeddingFunction;
}

View File

@@ -170,6 +170,7 @@ export class QueryBase<
/**
 * Run the query and return its rows as an array of objects.
 *
 * The results are first collected with `toArrow()` and then converted with
 * the resulting table's `toArray()`.
 */
async toArray(): Promise<unknown[]> {
  // eslint-disable-next-line @typescript-eslint/no-unsafe-return
  return (await this.toArrow()).toArray();
}

View File

@@ -14,6 +14,7 @@
import { Schema, tableFromIPC } from "apache-arrow";
import { Data, fromDataToBuffer } from "./arrow";
import { getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {
AddColumnsSql,
@@ -122,8 +123,14 @@ export class Table {
*/
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
// Default to "append" when the caller does not specify a mode.
const mode = options?.mode ?? "append";
// Read the table schema and rebuild the embedding function configs that are
// recorded in its metadata.
const schema = await this.schema();
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
// NOTE(review): `buffer` is declared by both this statement and the next
// one — this looks like diff residue; confirm which call is intended.
const buffer = await fromDataToBuffer(data);
const buffer = await fromDataToBuffer(
data,
// Only the first parsed config is forwarded; the value is undefined when
// the metadata declares no embedding functions.
functions.values().next().value,
);
await this.inner.add(buffer, mode);
}

15383
nodejs/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -62,6 +62,7 @@
},
"dependencies": {
"apache-arrow": "^15.0.0",
"openai": "^4.29.2"
"openai": "^4.29.2",
"reflect-metadata": "^0.2.2"
}
}

View File

@@ -7,7 +7,9 @@
"outDir": "./dist",
"strict": true,
"allowJs": true,
"resolveJsonModule": true
"resolveJsonModule": true,
"emitDecoratorMetadata": true,
"experimentalDecorators": true
},
"exclude": ["./dist/*"],
"typedocOptions": {