mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 22:02:58 +00:00
docs: custom embedding function for ts (#1479)
This commit is contained in:
64
nodejs/examples/custom_embedding_function.ts
Normal file
64
nodejs/examples/custom_embedding_function.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
// --8<-- [start:imports]
|
||||
import * as lancedb from "@lancedb/lancedb";
|
||||
import {
|
||||
LanceSchema,
|
||||
TextEmbeddingFunction,
|
||||
getRegistry,
|
||||
register,
|
||||
} from "@lancedb/lancedb/embedding";
|
||||
import { pipeline } from "@xenova/transformers";
|
||||
// --8<-- [end:imports]
|
||||
|
||||
// --8<-- [start:embedding_impl]
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
|
||||
name = "Xenova/all-miniLM-L6-v2";
|
||||
#ndims!: number;
|
||||
extractor: any;
|
||||
|
||||
async init() {
|
||||
this.extractor = await pipeline("feature-extraction", this.name);
|
||||
this.#ndims = await this.generateEmbeddings(["hello"]).then(
|
||||
(e) => e[0].length,
|
||||
);
|
||||
}
|
||||
|
||||
ndims() {
|
||||
return this.#ndims;
|
||||
}
|
||||
|
||||
toJSON() {
|
||||
return {
|
||||
name: this.name,
|
||||
};
|
||||
}
|
||||
async generateEmbeddings(texts: string[]) {
|
||||
const output = await this.extractor(texts, {
|
||||
pooling: "mean",
|
||||
normalize: true,
|
||||
});
|
||||
return output.tolist();
|
||||
}
|
||||
}
|
||||
// -8<-- [end:embedding_impl]
|
||||
|
||||
// --8<-- [start:call_custom_function]
|
||||
const registry = getRegistry();
|
||||
|
||||
const sentenceTransformer = await registry
|
||||
.get<SentenceTransformersEmbeddings>("sentence-transformers")!
|
||||
.create();
|
||||
|
||||
const schema = LanceSchema({
|
||||
vector: sentenceTransformer.vectorField(),
|
||||
text: sentenceTransformer.sourceField(),
|
||||
});
|
||||
|
||||
const db = await lancedb.connect("/tmp/db");
|
||||
const table = await db.createEmptyTable("table", schema, { mode: "overwrite" });
|
||||
|
||||
await table.add([{ text: "hello" }, { text: "world" }]);
|
||||
|
||||
const results = await table.search("greeting").limit(1).toArray();
|
||||
console.log(results[0].text);
|
||||
// -8<-- [end:call_custom_function]
|
||||
Reference in New Issue
Block a user