docs: custom embedding function for ts (#1479)

This commit is contained in:
Cory Grinstead
2024-07-30 18:19:55 -05:00
committed by GitHub
parent 277b753fd8
commit a062a92f6b
4 changed files with 295 additions and 170 deletions

View File

@@ -0,0 +1,64 @@
// --8<-- [start:imports]
import * as lancedb from "@lancedb/lancedb";
import {
LanceSchema,
TextEmbeddingFunction,
getRegistry,
register,
} from "@lancedb/lancedb/embedding";
import { pipeline } from "@xenova/transformers";
// --8<-- [end:imports]
// --8<-- [start:embedding_impl]
@register("sentence-transformers")
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
name = "Xenova/all-miniLM-L6-v2";
#ndims!: number;
extractor: any;
async init() {
this.extractor = await pipeline("feature-extraction", this.name);
this.#ndims = await this.generateEmbeddings(["hello"]).then(
(e) => e[0].length,
);
}
ndims() {
return this.#ndims;
}
toJSON() {
return {
name: this.name,
};
}
async generateEmbeddings(texts: string[]) {
const output = await this.extractor(texts, {
pooling: "mean",
normalize: true,
});
return output.tolist();
}
}
// -8<-- [end:embedding_impl]
// --8<-- [start:call_custom_function]
const registry = getRegistry();
const sentenceTransformer = await registry
.get<SentenceTransformersEmbeddings>("sentence-transformers")!
.create();
const schema = LanceSchema({
vector: sentenceTransformer.vectorField(),
text: sentenceTransformer.sourceField(),
});
const db = await lancedb.connect("/tmp/db");
const table = await db.createEmptyTable("table", schema, { mode: "overwrite" });
await table.add([{ text: "hello" }, { text: "world" }]);
const results = await table.search("greeting").limit(1).toArray();
console.log(results[0].text);
// -8<-- [end:call_custom_function]

View File

@@ -21,6 +21,7 @@ import {
Float32,
FloatLike,
type IntoVector,
Utf8,
isDataType,
isFixedSizeList,
isFloat,
@@ -192,6 +193,38 @@ export abstract class EmbeddingFunction<
}
}
/**
* an abstract class for implementing embedding functions that take text as input
*/
export abstract class TextEmbeddingFunction<
M extends FunctionOptions = FunctionOptions,
> extends EmbeddingFunction<string, M> {
//** Generate the embeddings for the given texts */
abstract generateEmbeddings(
texts: string[],
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
...args: any[]
): Promise<number[][] | Float32Array[] | Float64Array[]>;
async computeQueryEmbeddings(data: string): Promise<Awaited<IntoVector>> {
return this.generateEmbeddings([data]).then((data) => data[0]);
}
embeddingDataType(): FloatLike {
return new Float32();
}
override sourceField(): [DataTypeLike, Map<string, EmbeddingFunction>] {
return super.sourceField(new Utf8());
}
computeSourceEmbeddings(
data: string[],
): Promise<number[][] | Float32Array[] | Float64Array[]> {
return this.generateEmbeddings(data);
}
}
export interface FieldOptions<T extends DataType = DataType> {
datatype: T;
dims?: number;

View File

@@ -18,7 +18,7 @@ import { sanitizeType } from "../sanitize";
import { EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
export { EmbeddingFunction } from "./embedding_function";
export { EmbeddingFunction, TextEmbeddingFunction } from "./embedding_function";
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
export * from "./openai";