mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-15 08:12:58 +00:00
feat: js embedding registry (#1308)
--------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,67 +12,141 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { type Float } from "apache-arrow";
|
||||
import { DataType, Field, FixedSizeList, Float, Float32 } from "apache-arrow";
|
||||
import "reflect-metadata";
|
||||
import { newVectorType } from "../arrow";
|
||||
|
||||
/**
|
||||
* Options for a given embedding function
|
||||
*/
|
||||
export interface FunctionOptions {
|
||||
// biome-ignore lint/suspicious/noExplicitAny: options can be anything
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
export interface EmbeddingFunction<T> {
|
||||
export abstract class EmbeddingFunction<
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
|
||||
T = any,
|
||||
M extends FunctionOptions = FunctionOptions,
|
||||
> {
|
||||
/**
|
||||
* The name of the column that will be used as input for the Embedding Function.
|
||||
* Convert the embedding function to a JSON object
|
||||
* It is used to serialize the embedding function to the schema
|
||||
* It's important that any object returned by this method contains all the necessary
|
||||
* information to recreate the embedding function
|
||||
*
|
||||
* It should return the same object that was passed to the constructor
|
||||
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* class MyEmbeddingFunction extends EmbeddingFunction {
|
||||
* constructor(options: {model: string, timeout: number}) {
|
||||
* super();
|
||||
* this.model = options.model;
|
||||
* this.timeout = options.timeout;
|
||||
* }
|
||||
* toJSON() {
|
||||
* return {
|
||||
* model: this.model,
|
||||
* timeout: this.timeout,
|
||||
* };
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
sourceColumn: string;
|
||||
abstract toJSON(): Partial<M>;
|
||||
|
||||
/**
|
||||
* The data type of the embedding
|
||||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model
|
||||
*
|
||||
* The embedding function should return `number`. This will be converted into
|
||||
* an Arrow float array. By default this will be Float32 but this property can
|
||||
* be used to control the conversion.
|
||||
* @param optionsOrDatatype - The options for the field or the datatype
|
||||
*
|
||||
* @see {@link lancedb.LanceSchema}
|
||||
*/
|
||||
embeddingDataType?: Float;
|
||||
sourceField(
|
||||
optionsOrDatatype: Partial<FieldOptions> | DataType,
|
||||
): [DataType, Map<string, EmbeddingFunction>] {
|
||||
const datatype =
|
||||
optionsOrDatatype instanceof DataType
|
||||
? optionsOrDatatype
|
||||
: optionsOrDatatype?.datatype;
|
||||
if (!datatype) {
|
||||
throw new Error("Datatype is required");
|
||||
}
|
||||
const metadata = new Map<string, EmbeddingFunction>();
|
||||
metadata.set("source_column_for", this);
|
||||
|
||||
return [datatype, metadata];
|
||||
}
|
||||
|
||||
/**
|
||||
* The dimension of the embedding
|
||||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model
|
||||
*
|
||||
* This is optional, normally this can be determined by looking at the results of
|
||||
* `embed`. If this is not specified, and there is an attempt to apply the embedding
|
||||
* to an empty table, then that process will fail.
|
||||
* @param options - The options for the field
|
||||
*
|
||||
* @see {@link lancedb.LanceSchema}
|
||||
*/
|
||||
embeddingDimension?: number;
|
||||
vectorField(
|
||||
options?: Partial<FieldOptions>,
|
||||
): [DataType, Map<string, EmbeddingFunction>] {
|
||||
let dtype: DataType;
|
||||
const dims = this.ndims() ?? options?.dims;
|
||||
if (!options?.datatype) {
|
||||
if (dims === undefined) {
|
||||
throw new Error("ndims is required for vector field");
|
||||
}
|
||||
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
|
||||
} else {
|
||||
if (options.datatype instanceof FixedSizeList) {
|
||||
dtype = options.datatype;
|
||||
} else if (options.datatype instanceof Float) {
|
||||
if (dims === undefined) {
|
||||
throw new Error("ndims is required for vector field");
|
||||
}
|
||||
dtype = newVectorType(dims, options.datatype);
|
||||
} else {
|
||||
throw new Error(
|
||||
"Expected FixedSizeList or Float as datatype for vector field",
|
||||
);
|
||||
}
|
||||
}
|
||||
const metadata = new Map<string, EmbeddingFunction>();
|
||||
metadata.set("vector_column_for", this);
|
||||
|
||||
/**
|
||||
* The name of the column that will contain the embedding
|
||||
*
|
||||
* By default this is "vector"
|
||||
*/
|
||||
destColumn?: string;
|
||||
return [dtype, metadata];
|
||||
}
|
||||
|
||||
/**
|
||||
* Should the source column be excluded from the resulting table
|
||||
*
|
||||
* By default the source column is included. Set this to true and
|
||||
* only the embedding will be stored.
|
||||
*/
|
||||
excludeSource?: boolean;
|
||||
/** The number of dimensions of the embeddings */
|
||||
ndims(): number | undefined {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** The datatype of the embeddings */
|
||||
abstract embeddingDataType(): Float;
|
||||
|
||||
/**
|
||||
* Creates a vector representation for the given values.
|
||||
*/
|
||||
embed: (data: T[]) => Promise<number[][]>;
|
||||
abstract computeSourceEmbeddings(
|
||||
data: T[],
|
||||
): Promise<number[][] | Float32Array[] | Float64Array[]>;
|
||||
|
||||
/**
|
||||
Compute the embeddings for a single query
|
||||
*/
|
||||
async computeQueryEmbeddings(
|
||||
data: T,
|
||||
): Promise<number[] | Float32Array | Float64Array> {
|
||||
return this.computeSourceEmbeddings([data]).then(
|
||||
(embeddings) => embeddings[0],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Test if the input seems to be an embedding function */
|
||||
export function isEmbeddingFunction<T>(
|
||||
value: unknown,
|
||||
): value is EmbeddingFunction<T> {
|
||||
if (typeof value !== "object" || value === null) {
|
||||
return false;
|
||||
}
|
||||
if (!("sourceColumn" in value) || !("embed" in value)) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
typeof value.sourceColumn === "string" && typeof value.embed === "function"
|
||||
);
|
||||
export interface FieldOptions<T extends DataType = DataType> {
|
||||
datatype: T;
|
||||
dims?: number;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user