mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-27 00:40:43 +00:00
feat(nodejs): huggingface compatible transformers (#1462)
This commit is contained in:
@@ -41,6 +41,7 @@ export interface EmbeddingFunctionConstructor<
|
||||
> {
|
||||
new (modelOptions?: T["TOptions"]): T;
|
||||
}
|
||||
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
@@ -82,6 +83,8 @@ export abstract class EmbeddingFunction<
|
||||
*/
|
||||
abstract toJSON(): Partial<M>;
|
||||
|
||||
async init?(): Promise<void>;
|
||||
|
||||
/**
|
||||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model
|
||||
*
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { DataType, Field, Schema } from "../arrow";
|
||||
import { Field, Schema } from "../arrow";
|
||||
import { isDataType } from "../arrow";
|
||||
import { sanitizeType } from "../sanitize";
|
||||
import { EmbeddingFunction } from "./embedding_function";
|
||||
@@ -22,6 +22,7 @@ export { EmbeddingFunction } from "./embedding_function";
|
||||
|
||||
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
|
||||
export * from "./openai";
|
||||
export * from "./transformers";
|
||||
export * from "./registry";
|
||||
|
||||
/**
|
||||
|
||||
@@ -18,9 +18,14 @@ import {
|
||||
} from "./embedding_function";
|
||||
import "reflect-metadata";
|
||||
import { OpenAIEmbeddingFunction } from "./openai";
|
||||
import { TransformersEmbeddingFunction } from "./transformers";
|
||||
|
||||
type CreateReturnType<T> = T extends { init: () => Promise<void> }
|
||||
? Promise<T>
|
||||
: T;
|
||||
|
||||
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||
create(options?: T["TOptions"]): T;
|
||||
create(options?: T["TOptions"]): CreateReturnType<T>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -61,38 +66,43 @@ export class EmbeddingFunctionRegistry {
|
||||
};
|
||||
}
|
||||
|
||||
get(name: "openai"): EmbeddingFunctionCreate<OpenAIEmbeddingFunction>;
|
||||
get(
|
||||
name: "huggingface",
|
||||
): EmbeddingFunctionCreate<TransformersEmbeddingFunction>;
|
||||
get<T extends EmbeddingFunction<unknown>>(
|
||||
name: string,
|
||||
): EmbeddingFunctionCreate<T> | undefined;
|
||||
/**
|
||||
* Fetch an embedding function by name
|
||||
* @param name The name of the function
|
||||
*/
|
||||
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
|
||||
name: Name extends "openai" ? "openai" : string,
|
||||
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
|
||||
// ex:
|
||||
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
|
||||
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
|
||||
//
|
||||
// the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined
|
||||
// ```ts
|
||||
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
|
||||
// ```
|
||||
): Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined {
|
||||
type Output = Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined;
|
||||
|
||||
get(name: string) {
|
||||
const factory = this.#functions.get(name);
|
||||
if (!factory) {
|
||||
return undefined as Output;
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
return undefined as any;
|
||||
}
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
let create: any;
|
||||
if (factory.prototype.init) {
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
create = async function (options?: any) {
|
||||
const instance = new factory(options);
|
||||
await instance.init!();
|
||||
return instance;
|
||||
};
|
||||
} else {
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
create = function (options?: any) {
|
||||
const instance = new factory(options);
|
||||
return instance;
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
create: function (options?: T["TOptions"]) {
|
||||
return new factory(options);
|
||||
},
|
||||
} as Output;
|
||||
create,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -105,10 +115,10 @@ export class EmbeddingFunctionRegistry {
|
||||
/**
|
||||
* @ignore
|
||||
*/
|
||||
parseFunctions(
|
||||
async parseFunctions(
|
||||
this: EmbeddingFunctionRegistry,
|
||||
metadata: Map<string, string>,
|
||||
): Map<string, EmbeddingFunctionConfig> {
|
||||
): Promise<Map<string, EmbeddingFunctionConfig>> {
|
||||
if (!metadata.has("embedding_functions")) {
|
||||
return new Map();
|
||||
} else {
|
||||
@@ -118,25 +128,30 @@ export class EmbeddingFunctionRegistry {
|
||||
vectorColumn: string;
|
||||
model: EmbeddingFunction["TOptions"];
|
||||
};
|
||||
|
||||
const functions = <FunctionConfig[]>(
|
||||
JSON.parse(metadata.get("embedding_functions")!)
|
||||
);
|
||||
return new Map(
|
||||
functions.map((f) => {
|
||||
|
||||
const items: [string, EmbeddingFunctionConfig][] = await Promise.all(
|
||||
functions.map(async (f) => {
|
||||
const fn = this.get(f.name);
|
||||
if (!fn) {
|
||||
throw new Error(`Function "${f.name}" not found in registry`);
|
||||
}
|
||||
const func = await this.get(f.name)!.create(f.model);
|
||||
return [
|
||||
f.name,
|
||||
{
|
||||
sourceColumn: f.sourceColumn,
|
||||
vectorColumn: f.vectorColumn,
|
||||
function: this.get(f.name)!.create(f.model),
|
||||
function: func,
|
||||
},
|
||||
];
|
||||
}),
|
||||
);
|
||||
|
||||
return new Map(items);
|
||||
}
|
||||
}
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
|
||||
193
nodejs/lancedb/embedding/transformers.ts
Normal file
193
nodejs/lancedb/embedding/transformers.ts
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { Float, Float32 } from "../arrow";
|
||||
import { EmbeddingFunction } from "./embedding_function";
|
||||
import { register } from "./registry";
|
||||
|
||||
/**
 * Configuration accepted by the `@xenova/transformers` backed embedding function.
 */
export type XenovaTransformerOptions = {
  /** Name of the model to load; it must be wasm compatible. */
  model: string;

  /**
   * Name of the wasm compatible tokenizer to load.
   * When omitted, the default tokenizer for the model is used.
   */
  tokenizer?: string;

  /**
   * Dimensionality of the produced embeddings.
   *
   * When omitted, an attempt is made to infer it from the model config.
   * There is no standard way to obtain this from a model, so it may have to be
   * given explicitly for models whose config has no 'hidden_size' entry.
   */
  ndims?: number;

  /** Settings forwarded to the tokenizer. */
  tokenizerOptions?: {
    textPair?: string | string[];
    padding?: boolean | "max_length";
    addSpecialTokens?: boolean;
    truncation?: boolean;
    maxLength?: number;
  };
};
|
||||
|
||||
/**
 * Embedding function backed by `@xenova/transformers`, registered in the
 * embedding registry under the name "huggingface".
 *
 * The model and tokenizer are loaded lazily by {@link init}; the registry calls
 * `init()` for you, but code that constructs this class directly must await
 * `init()` before computing embeddings.
 */
@register("huggingface")
export class TransformersEmbeddingFunction extends EmbeddingFunction<
  string,
  Partial<XenovaTransformerOptions>
> {
  #model?: import("@xenova/transformers").PreTrainedModel;
  #tokenizer?: import("@xenova/transformers").PreTrainedTokenizer;
  #modelName: XenovaTransformerOptions["model"];
  // Explicitly configured tokenizer name; undefined means "use the model's tokenizer".
  #tokenizerName?: XenovaTransformerOptions["tokenizer"];
  #initialized = false;
  #tokenizerOptions: XenovaTransformerOptions["tokenizerOptions"];
  #ndims?: number;

  /**
   * @param options Model configuration. Defaults to the "Xenova/all-MiniLM-L6-v2"
   * model with tokenizer padding enabled.
   */
  constructor(
    options: Partial<XenovaTransformerOptions> = {
      model: "Xenova/all-MiniLM-L6-v2",
    },
  ) {
    super();

    // `options` may come straight from parsed JSON metadata, so every access is
    // guarded rather than only the first one.
    this.#modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2";
    this.#tokenizerName = options?.tokenizer;
    this.#tokenizerOptions = {
      padding: true,
      ...options?.tokenizerOptions,
    };
    this.#ndims = options?.ndims;
  }

  /**
   * Serialize the configuration so that an equivalent function can be re-created
   * by passing the result back to the constructor.
   */
  toJSON() {
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
    const obj: Record<string, any> = {
      model: this.#modelName,
    };
    if (this.#ndims) {
      obj["ndims"] = this.#ndims;
    }
    if (this.#tokenizerOptions) {
      obj["tokenizerOptions"] = this.#tokenizerOptions;
    }
    // Emit the configured tokenizer name (the value the constructor accepts),
    // not a property of the loaded tokenizer object, so the output round-trips
    // and does not depend on whether init() has run.
    if (this.#tokenizerName) {
      obj["tokenizer"] = this.#tokenizerName;
    }
    return obj;
  }

  /**
   * Load `@xenova/transformers`, the model and the tokenizer.
   * Calling it again after a successful load is a no-op.
   *
   * @throws Error if the library, the model or the tokenizer cannot be loaded.
   */
  async init() {
    if (this.#initialized) {
      return;
    }

    let transformers;
    try {
      // SAFETY:
      // since typescript transpiles `import` to `require`, we need to do this in an unsafe way
      // We can't use `require` because `@xenova/transformers` is an ESM module
      // and we can't use `import` directly because typescript will transpile it to `require`.
      // and we want to remain compatible with both ESM and CJS modules
      // so we use `eval` to bypass typescript for this specific import.
      // The evaluated string is a constant; no external input reaches `eval`.
      transformers = await eval('import("@xenova/transformers")');
    } catch (e) {
      throw new Error(`error loading @xenova/transformers\nReason: ${e}`);
    }

    try {
      this.#model = await transformers.AutoModel.from_pretrained(
        this.#modelName,
      );
    } catch (e) {
      throw new Error(
        `error loading model ${this.#modelName}. Make sure you are using a wasm compatible model.\nReason: ${e}`,
      );
    }

    // Honor the `tokenizer` option; fall back to the model's own tokenizer.
    const tokenizerName = this.#tokenizerName ?? this.#modelName;
    try {
      this.#tokenizer =
        await transformers.AutoTokenizer.from_pretrained(tokenizerName);
    } catch (e) {
      throw new Error(
        `error loading tokenizer for ${tokenizerName}. Make sure you are using a wasm compatible model:\nReason: ${e}`,
      );
    }
    this.#initialized = true;
  }

  /**
   * Number of dimensions of the produced embeddings: the configured `ndims` if
   * given, otherwise the `hidden_size` entry of the loaded model's config.
   *
   * @throws Error if `ndims` was not configured and the model is not loaded yet
   * or its config has no `hidden_size`.
   */
  ndims(): number {
    if (this.#ndims) {
      return this.#ndims;
    }

    const model = this.#model;
    if (!model) {
      throw new Error(
        "cannot infer embedding dimensions: embedding function not initialized. Please call init() or specify ndims",
      );
    }

    const ndims = model.config["hidden_size"];
    if (!ndims) {
      throw new Error(
        "hidden_size not found in model config, you may need to manually specify the embedding dimensions. ",
      );
    }
    return ndims;
  }

  /** Embeddings are produced as 32-bit floats. */
  embeddingDataType(): Float {
    return new Float32();
  }

  /**
   * Compute one embedding per input string.
   *
   * @param data The texts to embed.
   * @returns One vector per input, in input order.
   * @throws Error (as a rejected promise) if `init()` has not completed.
   */
  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
    // this should only happen if the user is trying to use the function directly.
    // Anything going through the registry should already be initialized.
    if (!this.#initialized) {
      throw new Error(
        "something went wrong: embedding function not initialized. Please call init()",
      );
    }
    // Nothing to embed; avoid handing an empty batch to the tokenizer.
    if (data.length === 0) {
      return [];
    }
    const tokenizer = this.#tokenizer!;
    const model = this.#model!;

    // The public options are camelCase, while the Transformers.js tokenizer reads
    // snake_case keys (`text_pair`, `add_special_tokens`, `max_length`). Pass both
    // spellings so the documented options take effect; `undefined` values fall
    // back to the tokenizer's own defaults.
    const tokenizerOptions = this.#tokenizerOptions ?? {};
    const inputs = await tokenizer(data, {
      ...tokenizerOptions,
      text_pair: tokenizerOptions.textPair,
      add_special_tokens: tokenizerOptions.addSpecialTokens,
      max_length: tokenizerOptions.maxLength,
    });

    let tokens = await model.forward(inputs);
    // Use the first output of the model (presumably the per-token hidden
    // states - TODO confirm for models with several outputs).
    tokens = tokens[Object.keys(tokens)[0]];

    const [nItems, nTokens] = tokens.dims;

    // Mean over the token axis.
    // NOTE(review): this divides by the padded sequence length and does not
    // consult the attention mask, so padding positions appear to take part in
    // the mean for batches with unequal lengths - confirm this is intended.
    tokens = tensorDiv(tokens.sum(1), nTokens);

    // TODO: support other data types
    const tokenData = tokens.data;
    const stride = this.ndims();

    const embeddings: number[][] = [];
    for (let i = 0; i < nItems; i++) {
      const start = i * stride;
      const slice = tokenData.slice(start, start + stride);
      embeddings.push(Array.from(slice) as number[]); // TODO: Avoid copy here
    }
    return embeddings;
  }

  /**
   * Compute the embedding of a single query string.
   *
   * @param data The query text.
   */
  async computeQueryEmbeddings(data: string): Promise<number[]> {
    return (await this.computeSourceEmbeddings([data]))[0];
  }
}
|
||||
|
||||
/**
 * Divide every element of a tensor's data buffer by `divBy`.
 *
 * The tensor is modified in place and the same tensor object is returned.
 */
const tensorDiv = (
  src: import("@xenova/transformers").Tensor,
  divBy: number,
) => {
  const values = src.data;
  let remaining = values.length;
  while (remaining-- > 0) {
    values[remaining] /= divBy;
  }
  return src;
};
|
||||
Reference in New Issue
Block a user