Files
lancedb/nodejs/lancedb/embedding/registry.ts
Cory Grinstead 5c3a88b6b2 feat(nodejs): add better typehints for registry (#1408)
previously the `registry` would return `undefined | EmbeddingFunction`
even for built in functions such as "openai"

now it'll return the correct type for `getRegistry().get("openai")`

as well as pass in the correct options type to `create`

### before
```ts
const options: {model: 'not-a-real-model'}
// this'd compile just fine, but result in runtime error
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create(options)
// this'd also compile fine
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create({MODEL: ''})
```
### after
```ts
const options: {model: 'not-a-real-model'}

const openai: OpenAIEmbeddingFunction = getRegistry().get("openai").create(options)
// Type '"not-a-real-model"' is not assignable to type '"text-embedding-ada-002" | "text-embedding-3-large" | "text-embedding-3-small" | undefined'


```
2024-07-01 12:49:42 -05:00

189 lines
5.9 KiB
TypeScript

// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
type EmbeddingFunction,
type EmbeddingFunctionConstructor,
} from "./embedding_function";
import "reflect-metadata";
import { OpenAIEmbeddingFunction } from "./openai";
/**
 * Factory handle for an embedding function of type `T`.
 *
 * This is the shape returned by `EmbeddingFunctionRegistry.get`: rather than
 * handing out the constructor itself, the registry returns an object whose
 * `create` method builds an instance.
 */
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
/** Build a `T`, optionally passing options typed as `T["TOptions"]`. */
create(options?: T["TOptions"]): T;
}
/**
 * Registry that maps aliases to embedding function constructors so they can be
 * registered once and later fetched by name. It also converts embedding
 * function configurations to and from the JSON stored in table metadata.
 *
 * A shared module-level instance is exposed through `getRegistry()`.
 * You can implement your own embedding function by subclassing EmbeddingFunction
 * or TextEmbeddingFunction and registering it with the registry.
 */
export class EmbeddingFunctionRegistry {
  #functions = new Map<string, EmbeddingFunctionConstructor>();

  /**
   * Create a decorator that registers an embedding function constructor.
   *
   * @param alias The name to register under. When omitted (or empty), each
   *   constructor passed to the returned decorator is registered under its own
   *   `name`.
   * @returns A function that registers the constructor it receives, records the
   *   alias on it as `lancedb::embedding::name` metadata, and returns the
   *   constructor unchanged.
   * @throws Error when the returned decorator is applied and the resolved alias
   *   is already registered
   */
  register<
    T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
  >(
    this: EmbeddingFunctionRegistry,
    alias?: string,
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
  ): (ctor: T) => any {
    return (ctor: T) => {
      // Resolve the alias per decorated constructor. Writing the fallback back
      // into the captured `alias` parameter would make a reused decorator
      // register every later class under the first class's name.
      const resolvedAlias = alias || ctor.name;
      if (this.#functions.has(resolvedAlias)) {
        throw new Error(
          `Embedding function with alias "${resolvedAlias}" already exists`,
        );
      }
      this.#functions.set(resolvedAlias, ctor);
      Reflect.defineMetadata("lancedb::embedding::name", resolvedAlias, ctor);
      return ctor;
    };
  }

  /**
   * Fetch an embedding function factory by name.
   *
   * String constants act as "types": `registry.get("openai")` is typed as
   * `EmbeddingFunctionCreate<OpenAIEmbeddingFunction>` (no `undefined`), so the
   * caller is not forced into a non-null check for the built-in function:
   *
   * ```ts
   * const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
   * ```
   *
   * Any other name yields `EmbeddingFunctionCreate<T> | undefined`, where `T`
   * may be supplied explicitly:
   * `registry.get<MyCustomEmbeddingFunction>("my_func")`.
   *
   * NOTE: the non-undefined typing for `"openai"` is a compile-time promise
   * only. At runtime this method returns `undefined` for any name that is not
   * currently in the registry, e.g. after `reset()` has been called.
   *
   * @param name The alias the function was registered under
   * @returns An object whose `create(options?)` instantiates the registered
   *   constructor, or `undefined` if nothing is registered under `name`
   */
  get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
    name: Name extends "openai" ? "openai" : string,
  ): Name extends "openai"
    ? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
    : EmbeddingFunctionCreate<T> | undefined {
    type Output = Name extends "openai"
      ? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
      : EmbeddingFunctionCreate<T> | undefined;

    const factory = this.#functions.get(name);
    if (!factory) {
      return undefined as Output;
    }
    return {
      create: (options?: T["TOptions"]) => new factory(options),
    } as Output;
  }

  /**
   * Reset the registry by removing every registered function.
   */
  reset(this: EmbeddingFunctionRegistry) {
    this.#functions.clear();
  }

  /**
   * Rebuild embedding function configurations from table metadata.
   *
   * Reads the `embedding_functions` entry (a JSON array as produced by
   * `getTableMetadata`), looks up each entry's `name` in the registry and
   * instantiates it with the stored `model` options.
   *
   * @param metadata Table metadata key/value pairs
   * @returns A map from function name to its configuration; empty when the
   *   metadata has no `embedding_functions` entry
   * @throws Error if an entry names a function that is not registered
   * @ignore
   */
  parseFunctions(
    this: EmbeddingFunctionRegistry,
    metadata: Map<string, string>,
  ): Map<string, EmbeddingFunctionConfig> {
    const serialized = metadata.get("embedding_functions");
    if (serialized === undefined) {
      return new Map();
    }

    type FunctionConfig = {
      name: string;
      sourceColumn: string;
      vectorColumn: string;
      model: EmbeddingFunction["TOptions"];
    };
    // The parsed JSON is trusted to have the shape written by getTableMetadata.
    const functions = JSON.parse(serialized) as FunctionConfig[];

    return new Map(
      functions.map((f) => {
        const fn = this.get(f.name);
        if (!fn) {
          throw new Error(`Function "${f.name}" not found in registry`);
        }
        return [
          f.name,
          {
            sourceColumn: f.sourceColumn,
            vectorColumn: f.vectorColumn,
            // Reuse the factory that was just checked instead of looking it
            // up a second time behind a non-null assertion.
            function: fn.create(f.model),
          },
        ];
      }),
    );
  }

  /**
   * Convert one embedding function configuration into a JSON-serializable
   * record with the keys `sourceColumn`, `vectorColumn`, `name` and `model`.
   *
   * `name` is the alias recorded at registration time, falling back to the
   * constructor name when no such metadata exists; `vectorColumn` defaults to
   * `"vector"`; `model` is whatever the function's `toJSON()` returns.
   *
   * @param conf The configuration to serialize
   * @returns The serializable record
   */
  // biome-ignore lint/suspicious/noExplicitAny: <explanation>
  functionToMetadata(conf: EmbeddingFunctionConfig): Record<string, any> {
    const name = Reflect.getMetadata(
      "lancedb::embedding::name",
      conf.function.constructor,
    );
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
    const metadata: Record<string, any> = {
      sourceColumn: conf.sourceColumn,
      vectorColumn: conf.vectorColumn ?? "vector",
      name: name ?? conf.function.constructor.name,
      model: conf.function.toJSON(),
    };
    return metadata;
  }

  /**
   * Serialize a list of embedding function configurations into table metadata.
   *
   * @param functions The configurations to serialize
   * @returns A map with a single `embedding_functions` key holding the JSON
   *   array of serialized configurations
   */
  getTableMetadata(functions: EmbeddingFunctionConfig[]): Map<string, string> {
    const jsonData = functions.map((conf) => this.functionToMetadata(conf));
    const metadata = new Map<string, string>();
    metadata.set("embedding_functions", JSON.stringify(jsonData));
    return metadata;
  }
}
// Module-level registry instance; `register` and `getRegistry` in this file both operate on it.
const _REGISTRY = new EmbeddingFunctionRegistry();
/**
 * Shorthand for `EmbeddingFunctionRegistry.register` on the module-level
 * registry instance.
 *
 * @param name Alias to register under; forwarded unchanged to the registry,
 *   which uses the constructor's own name when no alias is given
 * @returns The function produced by the registry's `register`, which registers
 *   the constructor passed to it (throwing if the alias is already taken)
 */
export function register(name?: string) {
return _REGISTRY.register(name);
}
/**
 * Utility function to get the global instance of the registry
 * @returns `EmbeddingFunctionRegistry` The global instance of the registry
 * @example
 * ```ts
 * const registry = getRegistry();
 * const openai = registry.get("openai").create();
 * ```
 */
export function getRegistry(): EmbeddingFunctionRegistry {
return _REGISTRY;
}
/**
 * Pairs an embedding function instance with the column names it is configured
 * for. This is the unit that `EmbeddingFunctionRegistry` serializes into, and
 * parses back out of, the `embedding_functions` table metadata entry.
 */
export interface EmbeddingFunctionConfig {
// Serialized as `sourceColumn`; presumably the column whose values are embedded — confirm against callers.
sourceColumn: string;
// Serialized as `vectorColumn`; `functionToMetadata` writes "vector" when this is unset.
vectorColumn?: string;
// The embedding function instance; its `toJSON()` result is stored as `model` in the metadata.
function: EmbeddingFunction;
}