diff --git a/nodejs/__test__/registry.test.ts b/nodejs/__test__/registry.test.ts index e87a38e6..1c4e398b 100644 --- a/nodejs/__test__/registry.test.ts +++ b/nodejs/__test__/registry.test.ts @@ -63,6 +63,7 @@ describe("Registry", () => { return data.map(() => [1, 2, 3]); } } + const func = getRegistry() .get("mock-embedding")! .create(); diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index e2e098a3..1f98b8c9 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -35,6 +35,11 @@ export interface FunctionOptions { [key: string]: any; } +export interface EmbeddingFunctionConstructor< + T extends EmbeddingFunction = EmbeddingFunction, +> { + new (modelOptions?: T["TOptions"]): T; +} /** * An embedding function that automatically creates vector representation for a given column. */ @@ -43,6 +48,12 @@ export abstract class EmbeddingFunction< T = any, M extends FunctionOptions = FunctionOptions, > { + /** + * @ignore + * This is only used for associating the options type with the class for type checking + */ + // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is + readonly TOptions!: M; /** * Convert the embedding function to a JSON object * It is used to serialize the embedding function to the schema diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index e055b175..f5144d00 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -13,24 +13,29 @@ // limitations under the License. import type OpenAI from "openai"; +import { type EmbeddingCreateParams } from "openai/resources"; import { Float, Float32 } from "../arrow"; import { EmbeddingFunction } from "./embedding_function"; import { register } from "./registry"; export type OpenAIOptions = { - apiKey?: string; - model?: string; + apiKey: string; + model: EmbeddingCreateParams["model"]; }; @register("openai") export class OpenAIEmbeddingFunction extends EmbeddingFunction< string, - OpenAIOptions + Partial > { #openai: OpenAI; - #modelName: string; + #modelName: OpenAIOptions["model"]; - constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) { + constructor( + options: Partial = { + model: "text-embedding-ada-002", + }, + ) { super(); const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; if (!openAIKey) { @@ -73,7 +78,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction< case "text-embedding-3-small": return 1536; default: - return null as never; + throw new Error(`Unknown model: ${this.#modelName}`); } } diff --git a/nodejs/lancedb/embedding/registry.ts b/nodejs/lancedb/embedding/registry.ts index 47e52917..7d77df50 100644 --- a/nodejs/lancedb/embedding/registry.ts +++ b/nodejs/lancedb/embedding/registry.ts @@ -12,21 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -import type { EmbeddingFunction } from "./embedding_function"; +import { + type EmbeddingFunction, + type EmbeddingFunctionConstructor, +} from "./embedding_function"; import "reflect-metadata"; - -export interface EmbeddingFunctionOptions { - [key: string]: unknown; -} - -export interface EmbeddingFunctionFactory< - T extends EmbeddingFunction = EmbeddingFunction, -> { - new (modelOptions?: EmbeddingFunctionOptions): T; -} +import { OpenAIEmbeddingFunction } from "./openai"; interface EmbeddingFunctionCreate { - create(options?: EmbeddingFunctionOptions): T; + create(options?: T["TOptions"]): T; } /** @@ -36,7 +30,7 @@ interface EmbeddingFunctionCreate { * or TextEmbeddingFunction and registering it with the registry */ export class EmbeddingFunctionRegistry { - #functions: Map = new Map(); + #functions = new Map(); /** * Register an embedding function @@ -44,7 +38,9 @@ export class EmbeddingFunctionRegistry { * @param func The function to register * @throws Error if the function is already registered */ - register( + register< + T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor, + >( this: EmbeddingFunctionRegistry, alias?: string, // biome-ignore lint/suspicious/noExplicitAny: @@ -69,18 +65,34 @@ export class EmbeddingFunctionRegistry { * Fetch an embedding function by name * @param name The name of the function */ - get = EmbeddingFunction>( - name: string, - ): EmbeddingFunctionCreate | undefined { + get, Name extends string = "">( + name: Name extends "openai" ? "openai" : string, + //This makes it so that you can use string constants as "types", or use an explicitly supplied type + // ex: + // `registry.get("openai") -> EmbeddingFunctionCreate` + // `registry.get("my_func") -> EmbeddingFunctionCreate | undefined` + // + // the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined + // ```ts + // const openai: OpenAIEmbeddingFunction = registry.get("openai").create() + // ``` + ): Name extends "openai" + ? EmbeddingFunctionCreate + : EmbeddingFunctionCreate | undefined { + type Output = Name extends "openai" + ? EmbeddingFunctionCreate + : EmbeddingFunctionCreate | undefined; + const factory = this.#functions.get(name); if (!factory) { - return undefined; + return undefined as Output; } + return { - create: function (options: EmbeddingFunctionOptions) { - return new factory(options) as unknown as T; + create: function (options?: T["TOptions"]) { + return new factory(options); }, - }; + } as Output; } /** @@ -104,7 +116,7 @@ export class EmbeddingFunctionRegistry { name: string; sourceColumn: string; vectorColumn: string; - model: EmbeddingFunctionOptions; + model: EmbeddingFunction["TOptions"]; }; const functions = ( JSON.parse(metadata.get("embedding_functions")!)