mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
feat(nodejs): add better typehints for registry (#1408)
previously the `registry` would return `undefined | EmbeddingFunction`
even for built in functions such as "openai"
now it'll return the correct type for `getRegistry().get("openai")
as well as pass in the correct options type to `create`
### before
```ts
const options: {model: 'not-a-real-model'}
// this'd compile just fine, but result in runtime error
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create(options)
// this'd also compile fine
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create({MODEL: ''})
```
### after
```ts
const options: {model: 'not-a-real-model'}
const openai: OpenAIEmbeddingFunction = getRegistry().get("openai").create(options)
// Type '"not-a-real-model"' is not assignable to type '"text-embedding-ada-002" | "text-embedding-3-large" | "text-embedding-3-small" | undefined'
```
This commit is contained in:
@@ -63,6 +63,7 @@ describe("Registry", () => {
|
||||
return data.map(() => [1, 2, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
const func = getRegistry()
|
||||
.get<MockEmbeddingFunction>("mock-embedding")!
|
||||
.create();
|
||||
|
||||
@@ -35,6 +35,11 @@ export interface FunctionOptions {
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
export interface EmbeddingFunctionConstructor<
|
||||
T extends EmbeddingFunction = EmbeddingFunction,
|
||||
> {
|
||||
new (modelOptions?: T["TOptions"]): T;
|
||||
}
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
@@ -43,6 +48,12 @@ export abstract class EmbeddingFunction<
|
||||
T = any,
|
||||
M extends FunctionOptions = FunctionOptions,
|
||||
> {
|
||||
/**
|
||||
* @ignore
|
||||
* This is only used for associating the options type with the class for type checking
|
||||
*/
|
||||
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
|
||||
readonly TOptions!: M;
|
||||
/**
|
||||
* Convert the embedding function to a JSON object
|
||||
* It is used to serialize the embedding function to the schema
|
||||
|
||||
@@ -13,24 +13,29 @@
|
||||
// limitations under the License.
|
||||
|
||||
import type OpenAI from "openai";
|
||||
import { type EmbeddingCreateParams } from "openai/resources";
|
||||
import { Float, Float32 } from "../arrow";
|
||||
import { EmbeddingFunction } from "./embedding_function";
|
||||
import { register } from "./registry";
|
||||
|
||||
export type OpenAIOptions = {
|
||||
apiKey?: string;
|
||||
model?: string;
|
||||
apiKey: string;
|
||||
model: EmbeddingCreateParams["model"];
|
||||
};
|
||||
|
||||
@register("openai")
|
||||
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||
string,
|
||||
OpenAIOptions
|
||||
Partial<OpenAIOptions>
|
||||
> {
|
||||
#openai: OpenAI;
|
||||
#modelName: string;
|
||||
#modelName: OpenAIOptions["model"];
|
||||
|
||||
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
|
||||
constructor(
|
||||
options: Partial<OpenAIOptions> = {
|
||||
model: "text-embedding-ada-002",
|
||||
},
|
||||
) {
|
||||
super();
|
||||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
|
||||
if (!openAIKey) {
|
||||
@@ -73,7 +78,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||
case "text-embedding-3-small":
|
||||
return 1536;
|
||||
default:
|
||||
return null as never;
|
||||
throw new Error(`Unknown model: ${this.#modelName}`);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,21 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import type { EmbeddingFunction } from "./embedding_function";
|
||||
import {
|
||||
type EmbeddingFunction,
|
||||
type EmbeddingFunctionConstructor,
|
||||
} from "./embedding_function";
|
||||
import "reflect-metadata";
|
||||
|
||||
export interface EmbeddingFunctionOptions {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface EmbeddingFunctionFactory<
|
||||
T extends EmbeddingFunction = EmbeddingFunction,
|
||||
> {
|
||||
new (modelOptions?: EmbeddingFunctionOptions): T;
|
||||
}
|
||||
import { OpenAIEmbeddingFunction } from "./openai";
|
||||
|
||||
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||
create(options?: EmbeddingFunctionOptions): T;
|
||||
create(options?: T["TOptions"]): T;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,7 +30,7 @@ interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||
* or TextEmbeddingFunction and registering it with the registry
|
||||
*/
|
||||
export class EmbeddingFunctionRegistry {
|
||||
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
|
||||
#functions = new Map<string, EmbeddingFunctionConstructor>();
|
||||
|
||||
/**
|
||||
* Register an embedding function
|
||||
@@ -44,7 +38,9 @@ export class EmbeddingFunctionRegistry {
|
||||
* @param func The function to register
|
||||
* @throws Error if the function is already registered
|
||||
*/
|
||||
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
|
||||
register<
|
||||
T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
|
||||
>(
|
||||
this: EmbeddingFunctionRegistry,
|
||||
alias?: string,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
@@ -69,18 +65,34 @@ export class EmbeddingFunctionRegistry {
|
||||
* Fetch an embedding function by name
|
||||
* @param name The name of the function
|
||||
*/
|
||||
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
|
||||
name: string,
|
||||
): EmbeddingFunctionCreate<T> | undefined {
|
||||
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
|
||||
name: Name extends "openai" ? "openai" : string,
|
||||
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
|
||||
// ex:
|
||||
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
|
||||
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
|
||||
//
|
||||
// the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined
|
||||
// ```ts
|
||||
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
|
||||
// ```
|
||||
): Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined {
|
||||
type Output = Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined;
|
||||
|
||||
const factory = this.#functions.get(name);
|
||||
if (!factory) {
|
||||
return undefined;
|
||||
return undefined as Output;
|
||||
}
|
||||
|
||||
return {
|
||||
create: function (options: EmbeddingFunctionOptions) {
|
||||
return new factory(options) as unknown as T;
|
||||
create: function (options?: T["TOptions"]) {
|
||||
return new factory(options);
|
||||
},
|
||||
};
|
||||
} as Output;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,7 +116,7 @@ export class EmbeddingFunctionRegistry {
|
||||
name: string;
|
||||
sourceColumn: string;
|
||||
vectorColumn: string;
|
||||
model: EmbeddingFunctionOptions;
|
||||
model: EmbeddingFunction["TOptions"];
|
||||
};
|
||||
const functions = <FunctionConfig[]>(
|
||||
JSON.parse(metadata.get("embedding_functions")!)
|
||||
|
||||
Reference in New Issue
Block a user