mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-15 08:12:58 +00:00
docs: custom embedding function for ts (#1479)
This commit is contained in:
64
nodejs/examples/custom_embedding_function.ts
Normal file
64
nodejs/examples/custom_embedding_function.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
// --8<-- [start:imports]
|
||||
import * as lancedb from "@lancedb/lancedb";
|
||||
import {
|
||||
LanceSchema,
|
||||
TextEmbeddingFunction,
|
||||
getRegistry,
|
||||
register,
|
||||
} from "@lancedb/lancedb/embedding";
|
||||
import { pipeline } from "@xenova/transformers";
|
||||
// --8<-- [end:imports]
|
||||
|
||||
// --8<-- [start:embedding_impl]
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
|
||||
name = "Xenova/all-miniLM-L6-v2";
|
||||
#ndims!: number;
|
||||
extractor: any;
|
||||
|
||||
async init() {
|
||||
this.extractor = await pipeline("feature-extraction", this.name);
|
||||
this.#ndims = await this.generateEmbeddings(["hello"]).then(
|
||||
(e) => e[0].length,
|
||||
);
|
||||
}
|
||||
|
||||
ndims() {
|
||||
return this.#ndims;
|
||||
}
|
||||
|
||||
toJSON() {
|
||||
return {
|
||||
name: this.name,
|
||||
};
|
||||
}
|
||||
async generateEmbeddings(texts: string[]) {
|
||||
const output = await this.extractor(texts, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
return output.tolist();
|
||||
}
|
||||
}
|
||||
// -8<-- [end:embedding_impl]
|
||||
|
||||
// --8<-- [start:call_custom_function]
|
||||
const registry = getRegistry();
|
||||
|
||||
const sentenceTransformer = await registry
|
||||
.get<SentenceTransformersEmbeddings>("sentence-transformers")!
|
||||
.create();
|
||||
|
||||
const schema = LanceSchema({
|
||||
vector: sentenceTransformer.vectorField(),
|
||||
text: sentenceTransformer.sourceField(),
|
||||
});
|
||||
|
||||
const db = await lancedb.connect("/tmp/db");
|
||||
const table = await db.createEmptyTable("table", schema, { mode: "overwrite" });
|
||||
|
||||
await table.add([{ text: "hello" }, { text: "world" }]);
|
||||
|
||||
const results = await table.search("greeting").limit(1).toArray();
|
||||
console.log(results[0].text);
|
||||
// -8<-- [end:call_custom_function]
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
Float32,
|
||||
FloatLike,
|
||||
type IntoVector,
|
||||
Utf8,
|
||||
isDataType,
|
||||
isFixedSizeList,
|
||||
isFloat,
|
||||
@@ -192,6 +193,38 @@ export abstract class EmbeddingFunction<
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* an abstract class for implementing embedding functions that take text as input
|
||||
*/
|
||||
export abstract class TextEmbeddingFunction<
|
||||
M extends FunctionOptions = FunctionOptions,
|
||||
> extends EmbeddingFunction<string, M> {
|
||||
//** Generate the embeddings for the given texts */
|
||||
abstract generateEmbeddings(
|
||||
texts: string[],
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
|
||||
...args: any[]
|
||||
): Promise<number[][] | Float32Array[] | Float64Array[]>;
|
||||
|
||||
async computeQueryEmbeddings(data: string): Promise<Awaited<IntoVector>> {
|
||||
return this.generateEmbeddings([data]).then((data) => data[0]);
|
||||
}
|
||||
|
||||
embeddingDataType(): FloatLike {
|
||||
return new Float32();
|
||||
}
|
||||
|
||||
override sourceField(): [DataTypeLike, Map<string, EmbeddingFunction>] {
|
||||
return super.sourceField(new Utf8());
|
||||
}
|
||||
|
||||
computeSourceEmbeddings(
|
||||
data: string[],
|
||||
): Promise<number[][] | Float32Array[] | Float64Array[]> {
|
||||
return this.generateEmbeddings(data);
|
||||
}
|
||||
}
|
||||
|
||||
export interface FieldOptions<T extends DataType = DataType> {
|
||||
datatype: T;
|
||||
dims?: number;
|
||||
|
||||
@@ -18,7 +18,7 @@ import { sanitizeType } from "../sanitize";
|
||||
import { EmbeddingFunction } from "./embedding_function";
|
||||
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
|
||||
|
||||
export { EmbeddingFunction } from "./embedding_function";
|
||||
export { EmbeddingFunction, TextEmbeddingFunction } from "./embedding_function";
|
||||
|
||||
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
|
||||
export * from "./openai";
|
||||
|
||||
Reference in New Issue
Block a user