feat(nodejs): add better typehints for registry (#1408)

previously the `registry` would return `undefined | EmbeddingFunction`
even for built in functions such as "openai"

now it'll return the correct type for `getRegistry().get("openai")

as well as pass in the correct options type to `create`

### before
```ts
const options: {model: 'not-a-real-model'}
// this'd compile just fine, but result in runtime error
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create(options)
// this'd also compile fine
const openai: EmbeddingFunction | undefined = getRegistry().get("openai").create({MODEL: ''})
```
### after
```ts
const options: {model: 'not-a-real-model'}

const openai: OpenAIEmbeddingFunction = getRegistry().get("openai").create(options)
// Type '"not-a-real-model"' is not assignable to type '"text-embedding-ada-002" | "text-embedding-3-large" | "text-embedding-3-small" | undefined'


```
This commit is contained in:
Cory Grinstead
2024-07-01 12:49:42 -05:00
committed by GitHub
parent e780b2f51c
commit 5c3a88b6b2
4 changed files with 57 additions and 28 deletions

View File

@@ -63,6 +63,7 @@ describe("Registry", () => {
return data.map(() => [1, 2, 3]);
}
}
const func = getRegistry()
.get<MockEmbeddingFunction>("mock-embedding")!
.create();

View File

@@ -35,6 +35,11 @@ export interface FunctionOptions {
[key: string]: any;
}
export interface EmbeddingFunctionConstructor<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: T["TOptions"]): T;
}
/**
* An embedding function that automatically creates vector representation for a given column.
*/
@@ -43,6 +48,12 @@ export abstract class EmbeddingFunction<
T = any,
M extends FunctionOptions = FunctionOptions,
> {
/**
* @ignore
* This is only used for associating the options type with the class for type checking
*/
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
readonly TOptions!: M;
/**
* Convert the embedding function to a JSON object
* It is used to serialize the embedding function to the schema

View File

@@ -13,24 +13,29 @@
// limitations under the License.
import type OpenAI from "openai";
import { type EmbeddingCreateParams } from "openai/resources";
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";
export type OpenAIOptions = {
apiKey?: string;
model?: string;
apiKey: string;
model: EmbeddingCreateParams["model"];
};
@register("openai")
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
string,
OpenAIOptions
Partial<OpenAIOptions>
> {
#openai: OpenAI;
#modelName: string;
#modelName: OpenAIOptions["model"];
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
constructor(
options: Partial<OpenAIOptions> = {
model: "text-embedding-ada-002",
},
) {
super();
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
if (!openAIKey) {
@@ -73,7 +78,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
case "text-embedding-3-small":
return 1536;
default:
return null as never;
throw new Error(`Unknown model: ${this.#modelName}`);
}
}

View File

@@ -12,21 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import type { EmbeddingFunction } from "./embedding_function";
import {
type EmbeddingFunction,
type EmbeddingFunctionConstructor,
} from "./embedding_function";
import "reflect-metadata";
export interface EmbeddingFunctionOptions {
[key: string]: unknown;
}
export interface EmbeddingFunctionFactory<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: EmbeddingFunctionOptions): T;
}
import { OpenAIEmbeddingFunction } from "./openai";
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
create(options?: EmbeddingFunctionOptions): T;
create(options?: T["TOptions"]): T;
}
/**
@@ -36,7 +30,7 @@ interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
* or TextEmbeddingFunction and registering it with the registry
*/
export class EmbeddingFunctionRegistry {
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
#functions = new Map<string, EmbeddingFunctionConstructor>();
/**
* Register an embedding function
@@ -44,7 +38,9 @@ export class EmbeddingFunctionRegistry {
* @param func The function to register
* @throws Error if the function is already registered
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
register<
T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
>(
this: EmbeddingFunctionRegistry,
alias?: string,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -69,18 +65,34 @@ export class EmbeddingFunctionRegistry {
* Fetch an embedding function by name
* @param name The name of the function
*/
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
name: string,
): EmbeddingFunctionCreate<T> | undefined {
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
name: Name extends "openai" ? "openai" : string,
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
// ex:
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
//
// the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined
// ```ts
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
// ```
): Name extends "openai"
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
: EmbeddingFunctionCreate<T> | undefined {
type Output = Name extends "openai"
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
: EmbeddingFunctionCreate<T> | undefined;
const factory = this.#functions.get(name);
if (!factory) {
return undefined;
return undefined as Output;
}
return {
create: function (options: EmbeddingFunctionOptions) {
return new factory(options) as unknown as T;
create: function (options?: T["TOptions"]) {
return new factory(options);
},
};
} as Output;
}
/**
@@ -104,7 +116,7 @@ export class EmbeddingFunctionRegistry {
name: string;
sourceColumn: string;
vectorColumn: string;
model: EmbeddingFunctionOptions;
model: EmbeddingFunction["TOptions"];
};
const functions = <FunctionConfig[]>(
JSON.parse(metadata.get("embedding_functions")!)