Files
lancedb/nodejs/lancedb/embedding/transformers.ts
Will Jones 0fd8a50bd7 ci(node): run examples in CI (#1796)
This is done as setup for a PR that will fix the OpenAI dependency
issue.

 * [x] FTS examples
 * [x] Setup mock openai
 * [x] Ran `npm audit fix`
 * [x] sentences embeddings test
 * [x] Double check formatting of docs examples
2024-11-13 11:10:56 -08:00

196 lines
6.0 KiB
TypeScript

// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";
/**
 * Configuration accepted by {@link TransformersEmbeddingFunction}.
 */
export type XenovaTransformerOptions = {
  /** Id of the wasm compatible model to load. */
  model: string;
  /**
   * Size of each embedding vector.
   *
   * When omitted, the value is read from the model config's `hidden_size`.
   * That field is not standardized across models, so set this explicitly for
   * models whose config does not provide it.
   */
  ndims?: number;
  /**
   * Wasm compatible tokenizer to use.
   *
   * NOTE: `TransformersEmbeddingFunction` does not currently read this option;
   * the tokenizer is always loaded from `model`.
   */
  tokenizer?: string;
  /**
   * Options applied when tokenizing input text. When constructed through
   * `TransformersEmbeddingFunction`, `padding` defaults to `true`.
   */
  tokenizerOptions?: {
    padding?: boolean | "max_length";
    truncation?: boolean;
    maxLength?: number;
    addSpecialTokens?: boolean;
    textPair?: string | string[];
  };
};
/**
 * Embedding function backed by `@huggingface/transformers` (transformers.js),
 * registered under the name "huggingface".
 *
 * Inputs are tokenized and run through the model. The token-level outputs are
 * then mean-pooled into one vector per input. Padding positions (attention
 * mask 0) are excluded from the mean, so an input's embedding does not depend
 * on which other inputs it was batched with.
 *
 * `init()` must have completed before embeddings are computed.
 *
 * NOTE: the `tokenizer` option is not read; the tokenizer is always loaded
 * from the model id.
 */
@register("huggingface")
export class TransformersEmbeddingFunction extends EmbeddingFunction<
  string,
  Partial<XenovaTransformerOptions>
> {
  #model?: import("@huggingface/transformers").PreTrainedModel;
  #tokenizer?: import("@huggingface/transformers").PreTrainedTokenizer;
  #modelName: XenovaTransformerOptions["model"];
  #initialized = false;
  #tokenizerOptions: XenovaTransformerOptions["tokenizerOptions"];
  #ndims?: number;

  /**
   * @param options model id, optional `ndims`, and tokenizer options. The
   *   model defaults to "Xenova/all-MiniLM-L6-v2" and tokenizer `padding`
   *   defaults to `true`.
   */
  constructor(
    options: Partial<XenovaTransformerOptions> = {
      model: "Xenova/all-MiniLM-L6-v2",
    },
  ) {
    super();
    const modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2";
    this.#tokenizerOptions = {
      padding: true,
      ...options.tokenizerOptions,
    };
    this.#ndims = options.ndims;
    this.#modelName = modelName;
  }

  /** Serializable form of this function's configuration. */
  toJSON() {
    // biome-ignore lint/suspicious/noExplicitAny: option values have mixed types
    const obj: Record<string, any> = {
      model: this.#modelName,
    };
    if (this.#ndims) {
      obj["ndims"] = this.#ndims;
    }
    if (this.#tokenizerOptions) {
      obj["tokenizerOptions"] = this.#tokenizerOptions;
    }
    if (this.#tokenizer) {
      // init() always loads the tokenizer from the model id, so record that id
      // rather than a property of the runtime tokenizer object.
      obj["tokenizer"] = this.#modelName;
    }
    return obj;
  }

  /**
   * Load `@huggingface/transformers`, then the model (fp32 weights) and its
   * tokenizer, both from the configured model id.
   *
   * @throws Error if the library, the model, or the tokenizer cannot be loaded.
   */
  async init() {
    let transformers;
    try {
      // SAFETY:
      // since typescript transpiles `import` to `require`, we need to do this in an unsafe way
      // We can't use `require` because `@huggingface/transformers` is an ESM module
      // and we can't use `import` directly because typescript will transpile it to `require`.
      // and we want to remain compatible with both ESM and CJS modules
      // so we use `eval` to bypass typescript for this specific import.
      transformers = await eval('import("@huggingface/transformers")');
    } catch (e) {
      throw new Error(`error loading @huggingface/transformers\nReason: ${e}`);
    }
    try {
      this.#model = await transformers.AutoModel.from_pretrained(
        this.#modelName,
        { dtype: "fp32" },
      );
    } catch (e) {
      throw new Error(
        `error loading model ${this.#modelName}. Make sure you are using a wasm compatible model.\nReason: ${e}`,
      );
    }
    try {
      this.#tokenizer = await transformers.AutoTokenizer.from_pretrained(
        this.#modelName,
      );
    } catch (e) {
      throw new Error(
        `error loading tokenizer for ${this.#modelName}. Make sure you are using a wasm compatible model:\nReason: ${e}`,
      );
    }
    this.#initialized = true;
  }

  /**
   * Embedding dimensionality: the explicit `ndims` option if given, otherwise
   * the loaded model config's `hidden_size`.
   *
   * @throws Error if `ndims` was not given and the model is not loaded yet, or
   *   if the model config has no `hidden_size`.
   */
  ndims(): number {
    if (this.#ndims) {
      return this.#ndims;
    }
    if (this.#model === undefined) {
      throw new Error(
        "embedding dimensions are unknown: pass `ndims` or call init() before ndims()",
      );
    }
    const config = this.#model.config;
    // biome-ignore lint/style/useNamingConvention: we don't control this name.
    const ndims = (config as unknown as { hidden_size: number }).hidden_size;
    if (!ndims) {
      throw new Error(
        "hidden_size not found in model config, you may need to manually specify the embedding dimensions. ",
      );
    }
    return ndims;
  }

  embeddingDataType(): Float {
    return new Float32();
  }

  /**
   * Embed a batch of strings, one vector per input, in input order.
   *
   * Rejects if `init()` has not completed.
   */
  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
    // this should only happen if the user is trying to use the function directly.
    // Anything going through the registry should already be initialized.
    if (!this.#initialized) {
      return Promise.reject(
        new Error(
          "something went wrong: embedding function not initialized. Please call init()",
        ),
      );
    }
    // Nothing to embed; don't hand an empty batch to the tokenizer/model.
    if (data.length === 0) {
      return [];
    }
    const tokenizer = this.#tokenizer!;
    const model = this.#model!;
    const inputs = await tokenizer(
      data,
      toTokenizerCallOptions(this.#tokenizerOptions),
    );
    const outputs = await model.forward(inputs);
    // Prefer the token-level hidden states by name; otherwise fall back to the
    // first output, as earlier versions did.
    const hidden = outputs.last_hidden_state ?? outputs[Object.keys(outputs)[0]];
    // The row stride comes from the tensor itself, so a mismatched `ndims`
    // option cannot silently misalign the output vectors.
    return meanPoolTokens(hidden.data, hidden.dims, inputs.attention_mask?.data);
  }

  /** Embed a single query string. */
  async computeQueryEmbeddings(data: string): Promise<number[]> {
    return (await this.computeSourceEmbeddings([data]))[0];
  }
}

/**
 * Build the options object passed to a transformers.js tokenizer call.
 *
 * transformers.js reads snake_case option names (`text_pair`,
 * `add_special_tokens`, `max_length`), while this module exposes camelCase
 * names. The original object is copied through unchanged, which preserves the
 * previous behavior of forwarding it as-is. The snake_case spellings of the
 * camelCase options are then added so that those options actually take effect.
 */
function toTokenizerCallOptions(
  options: XenovaTransformerOptions["tokenizerOptions"],
): Record<string, unknown> {
  const callOptions: Record<string, unknown> = { ...options };
  if (options?.textPair !== undefined) {
    callOptions["text_pair"] = options.textPair;
  }
  if (options?.addSpecialTokens !== undefined) {
    callOptions["add_special_tokens"] = options.addSpecialTokens;
  }
  if (options?.maxLength !== undefined) {
    callOptions["max_length"] = options.maxLength;
  }
  return callOptions;
}

/**
 * Mean-pool token-level hidden states into one vector per batch item.
 *
 * @param hidden flat, row-major hidden states whose shape is `dims`
 * @param dims `[nItems, nTokens, hiddenSize]`
 * @param mask optional flat attention mask of shape `[nItems, nTokens]`.
 *   Tokens whose mask value is 0 (padding) are left out of the mean. Values
 *   may be numbers or bigints. A mask of the wrong length is ignored, and then
 *   every token is averaged.
 * @returns `nItems` arrays of length `hiddenSize`. An item with no unmasked
 *   tokens pools to a zero vector.
 * @throws Error if `dims` does not have exactly three entries.
 */
function meanPoolTokens(
  hidden: ArrayLike<number>,
  dims: readonly number[],
  mask?: ArrayLike<number | bigint>,
): number[][] {
  if (dims.length !== 3) {
    throw new Error(
      `expected token embeddings with 3 dimensions [items, tokens, hidden], got [${dims.join(", ")}]`,
    );
  }
  const [nItems = 0, nTokens = 0, hiddenSize = 0] = dims;
  const useMask = mask !== undefined && mask.length === nItems * nTokens;
  const pooled: number[][] = [];
  for (let item = 0; item < nItems; item++) {
    const sums = new Array<number>(hiddenSize).fill(0);
    let count = 0;
    for (let tok = 0; tok < nTokens; tok++) {
      const row = item * nTokens + tok;
      if (useMask && Number(mask[row]) === 0) {
        continue; // padding position
      }
      count++;
      const base = row * hiddenSize;
      for (let d = 0; d < hiddenSize; d++) {
        sums[d] += hidden[base + d];
      }
    }
    // Guard against division by zero for a fully masked item.
    const divisor = count > 0 ? count : 1;
    pooled.push(sums.map((s) => s / divisor));
  }
  return pooled;
}