diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 99bcce89..e9467b87 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -182,6 +182,7 @@ nav: - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - User-defined embedding functions: embeddings/custom_embedding_function.md + - Variables and secrets: embeddings/variables_and_secrets.md - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - 🔌 Integrations: @@ -315,6 +316,7 @@ nav: - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - User-defined embedding functions: embeddings/custom_embedding_function.md + - Variables and secrets: embeddings/variables_and_secrets.md - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - Integrations: diff --git a/docs/src/embeddings/custom_embedding_function.md b/docs/src/embeddings/custom_embedding_function.md index 619c5437..655c6904 100644 --- a/docs/src/embeddings/custom_embedding_function.md +++ b/docs/src/embeddings/custom_embedding_function.md @@ -55,6 +55,14 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings. +!!! 
danger "Use sensitive keys to prevent leaking secrets" + To prevent leaking secrets, such as API keys, you should add any sensitive + parameters of an embedding function to the output of the + [sensitive_keys()][lancedb.embeddings.base.EmbeddingFunction.sensitive_keys] / + [getSensitiveKeys()](../../js/namespaces/embedding/classes/EmbeddingFunction/#getsensitivekeys) + method. This prevents users from accidentally instantiating the embedding + function with hard-coded secrets. + Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs. === "Python" diff --git a/docs/src/embeddings/variables_and_secrets.md b/docs/src/embeddings/variables_and_secrets.md new file mode 100644 index 00000000..72388b24 --- /dev/null +++ b/docs/src/embeddings/variables_and_secrets.md @@ -0,0 +1,53 @@ +# Variables and Secrets + +Most embedding configuration options are saved in the table's metadata. However, +this isn't always appropriate. For example, API keys should never be stored in the +metadata. Additionally, other configuration options might be best set at runtime, +such as the `device` configuration that controls whether to use GPU or CPU for +inference. If you hardcoded this to GPU, you wouldn't be able to run the code on +a server without one. + +To handle these cases, you can set variables on the embedding registry and +reference them in the embedding configuration. These variables will be available +during the runtime of your program, but not saved in the table's metadata. When +the table is loaded from a different process, the variables must be set again. + +To set a variable, use the `set_var()` / `setVar()` method on the embedding registry. +To reference a variable, use the syntax `$var:VARIABLE_NAME`. If there is a default +value, you can use the syntax `$var:VARIABLE_NAME:DEFAULT_VALUE`. 
+ +## Using variables to set secrets + +Sensitive configuration, such as API keys, must be set either as environment +variables or as variables on the embedding registry. If you pass in a hardcoded +value, LanceDB will raise an error. Instead, if you want to set an API key via +configuration, use a variable: + +=== "Python" + + ```python + --8<-- "python/python/tests/docs/test_embeddings_optional.py:register_secret" + ``` + +=== "Typescript" + + ```typescript + --8<-- "nodejs/examples/embedding.test.ts:register_secret" + ``` + +## Using variables to set the device parameter + +Many embedding functions that run locally have a `device` parameter that controls +whether to use GPU or CPU for inference. Because not all computers have a GPU, +it's helpful to be able to set the `device` parameter at runtime, rather than +have it hard coded in the embedding configuration. To make it work even if the +variable isn't set, you could provide a default value of `cpu` in the embedding +configuration. + +Some embedding libraries even have a method to detect which devices are available, +which could be used to dynamically set the device at runtime. For example, in Python +you can check if a CUDA GPU is available using `torch.cuda.is_available()`. + +```python +--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_device" +``` diff --git a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md index 24c54915..66d6ee16 100644 --- a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md @@ -8,6 +8,23 @@ An embedding function that automatically creates vector representation for a given column. +It's important subclasses call the super constructor with no arguments and then +pass the **original** options to `resolveVariables` to resolve any variables before +using them. 
+ +## Example + +```ts +class MyEmbeddingFunction extends EmbeddingFunction { + constructor(optionsRaw: {model: string, timeout: number}) { + super(); + const options = this.resolveVariables(optionsRaw); + this.model = options.model; + this.timeout = options.timeout; + } +} +``` + ## Extended by - [`TextEmbeddingFunction`](TextEmbeddingFunction.md) @@ -82,12 +99,33 @@ The datatype of the embeddings *** +### getSensitiveKeys() + +```ts +protected getSensitiveKeys(): string[] +``` + +Provide a list of keys in the function options that should be treated as +sensitive. If users pass raw values for these keys, they will be rejected. + +#### Returns + +`string`[] + +*** + ### init()? ```ts optional init(): Promise ``` +Optionally load any resources needed for the embedding function. + +This method is called after the embedding function has been initialized +but before any embeddings are computed. It is useful for loading local models +or other resources that are needed for the embedding function to work. + #### Returns `Promise`<`void`> @@ -108,6 +146,24 @@ The number of dimensions of the embeddings *** +### resolveVariables() + +```ts +protected resolveVariables(config): Partial +``` + +Apply variables to the config. 
+ +#### Parameters + +* **config**: `Partial`<`M`> + +#### Returns + +`Partial`<`M`> + +*** + ### sourceField() ```ts @@ -134,37 +190,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d ### toJSON() ```ts -abstract toJSON(): Partial +toJSON(): Record ``` -Convert the embedding function to a JSON object -It is used to serialize the embedding function to the schema -It's important that any object returned by this method contains all the necessary -information to recreate the embedding function - -It should return the same object that was passed to the constructor -If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly +Get the original arguments to the constructor, to serialize them so they +can be used to recreate the embedding function later. #### Returns -`Partial`<`M`> - -#### Example - -```ts -class MyEmbeddingFunction extends EmbeddingFunction { - constructor(options: {model: string, timeout: number}) { - super(); - this.model = options.model; - this.timeout = options.timeout; - } - toJSON() { - return { - model: this.model, - timeout: this.timeout, - }; -} -``` +`Record`<`string`, `any`> *** diff --git a/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md b/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md index e01078ab..2186d696 100644 --- a/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md +++ b/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md @@ -80,6 +80,28 @@ getTableMetadata(functions): Map *** +### getVar() + +```ts +getVar(name): undefined | string +``` + +Get a variable. 
+ +#### Parameters + +* **name**: `string` + +#### Returns + +`undefined` \| `string` + +#### See + +[setVar](EmbeddingFunctionRegistry.md#setvar) + +*** + ### length() ```ts @@ -145,3 +167,31 @@ reset the registry to the initial state #### Returns `void` + +*** + +### setVar() + +```ts +setVar(name, value): void +``` + +Set a variable. These can be accessed in the embedding function +configuration using the syntax `$var:variable_name`. If they are not +set, an error will be thrown letting you know which key is unset. If you +want to supply a default value, you can add an additional part in the +configuration like so: `$var:variable_name:default_value`. Default values +can be used for runtime configurations that are not sensitive, such as +whether to use a GPU for inference. + +The name must not contain colons. The default value can contain colons. + +#### Parameters + +* **name**: `string` + +* **value**: `string` + +#### Returns + +`void` diff --git a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md index 2cc13bef..8aee4f44 100644 --- a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md @@ -114,12 +114,37 @@ abstract generateEmbeddings(texts, ...args): Promise ``` +Optionally load any resources needed for the embedding function. + +This method is called after the embedding function has been initialized +but before any embeddings are computed. It is useful for loading local models +or other resources that are needed for the embedding function to work. + #### Returns `Promise`<`void`> @@ -148,6 +173,28 @@ The number of dimensions of the embeddings *** +### resolveVariables() + +```ts +protected resolveVariables(config): Partial +``` + +Apply variables to the config. 
+ +#### Parameters + +* **config**: `Partial`<`M`> + +#### Returns + +`Partial`<`M`> + +#### Inherited from + +[`EmbeddingFunction`](EmbeddingFunction.md).[`resolveVariables`](EmbeddingFunction.md#resolvevariables) + +*** + ### sourceField() ```ts @@ -173,37 +220,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d ### toJSON() ```ts -abstract toJSON(): Partial +toJSON(): Record ``` -Convert the embedding function to a JSON object -It is used to serialize the embedding function to the schema -It's important that any object returned by this method contains all the necessary -information to recreate the embedding function - -It should return the same object that was passed to the constructor -If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly +Get the original arguments to the constructor, to serialize them so they +can be used to recreate the embedding function later. #### Returns -`Partial`<`M`> - -#### Example - -```ts -class MyEmbeddingFunction extends EmbeddingFunction { - constructor(options: {model: string, timeout: number}) { - super(); - this.model = options.model; - this.timeout = options.timeout; - } - toJSON() { - return { - model: this.model, - timeout: this.timeout, - }; -} -``` +`Record`<`string`, `any`> #### Inherited from diff --git a/nodejs/__test__/embedding.test.ts b/nodejs/__test__/embedding.test.ts index 03802dab..30d8b1fc 100644 --- a/nodejs/__test__/embedding.test.ts +++ b/nodejs/__test__/embedding.test.ts @@ -17,6 +17,8 @@ import { import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; +const testOpenAIInteg = process.env.OPENAI_API_KEY == null ? 
test.skip : test; + describe("embedding functions", () => { let tmpDir: tmp.DirResult; beforeEach(() => { @@ -29,9 +31,6 @@ describe("embedding functions", () => { it("should be able to create a table with an embedding function", async () => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -75,9 +74,6 @@ describe("embedding functions", () => { it("should be able to append and upsert using embedding function", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -143,9 +139,6 @@ describe("embedding functions", () => { it("should be able to create an empty table with an embedding function", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -194,9 +187,6 @@ describe("embedding functions", () => { it("should error when appending to a table with an unregistered embedding function", async () => { @register("mock") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -241,13 +231,35 @@ describe("embedding functions", () => { `Function "mock" not found in registry`, ); }); + + testOpenAIInteg("propagates variables through all methods", async () => { + delete process.env.OPENAI_API_KEY; + const registry = getRegistry(); + registry.setVar("openai_api_key", "sk-..."); + const func = registry.get("openai")?.create({ + model: "text-embedding-ada-002", + apiKey: "$var:openai_api_key", + }) as EmbeddingFunction; + + const db = await connect("memory://"); + const wordsSchema = LanceSchema({ + text: func.sourceField(new Utf8()), + vector: func.vectorField(), + }); + const tbl = await db.createEmptyTable("words", wordsSchema, { + mode: "overwrite", + }); + await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]); + + const query = "greetings"; + const 
actual = (await tbl.search(query).limit(1).toArray())[0]; + expect(actual).toHaveProperty("text"); + }); + test.each([new Float16(), new Float32(), new Float64()])( "should be able to provide manual embeddings with multiple float datatype", async (floatType) => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -292,10 +304,6 @@ describe("embedding functions", () => { async (floatType) => { @register("test1") class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction { - toJSON(): object { - return {}; - } - embeddingDataType(): Float { return floatType; } @@ -310,9 +318,6 @@ describe("embedding functions", () => { } @register("test") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } diff --git a/nodejs/__test__/registry.test.ts b/nodejs/__test__/registry.test.ts index f00121da..a5cf73e7 100644 --- a/nodejs/__test__/registry.test.ts +++ b/nodejs/__test__/registry.test.ts @@ -11,7 +11,11 @@ import * as arrow18 from "apache-arrow-18"; import * as tmp from "tmp"; import { connect } from "../lancedb"; -import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; +import { + EmbeddingFunction, + FunctionOptions, + LanceSchema, +} from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => { @@ -39,11 +43,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { it("should register a new item to the registry", async () => { @register("mock-embedding") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } constructor() { super(); } @@ -89,11 +88,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { }); test("should error if registering with the same name", async () => { class 
MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } constructor() { super(); } @@ -114,13 +108,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { }); test("schema should contain correct metadata", async () => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } - constructor() { + constructor(args: FunctionOptions = {}) { super(); + this.resolveVariables(args); } ndims() { return 3; @@ -132,7 +122,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { return data.map(() => [1, 2, 3]); } } - const func = new MockEmbeddingFunction(); + const func = new MockEmbeddingFunction({ someText: "hello" }); const schema = LanceSchema({ id: new arrow.Int32(), @@ -155,3 +145,79 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { expect(schema.metadata).toEqual(expectedMetadata); }); }); + +describe("Registry.setVar", () => { + const registry = getRegistry(); + + beforeEach(() => { + @register("mock-embedding") + // biome-ignore lint/correctness/noUnusedVariables : + class MockEmbeddingFunction extends EmbeddingFunction { + constructor(optionsRaw: FunctionOptions = {}) { + super(); + const options = this.resolveVariables(optionsRaw); + + expect(optionsRaw["someKey"].startsWith("$var:someName")).toBe(true); + expect(options["someKey"]).toBe("someValue"); + + if (options["secretKey"]) { + expect(optionsRaw["secretKey"]).toBe("$var:secretKey"); + expect(options["secretKey"]).toBe("mySecret"); + } + } + async computeSourceEmbeddings(data: string[]) { + return data.map(() => [1, 2, 3]); + } + embeddingDataType() { + return new arrow18.Float32() as apiArrow.Float; + } + protected getSensitiveKeys() { + return ["secretKey"]; + } + } + }); + afterEach(() => { + registry.reset(); + }); + + it("Should error if the variable is not set", () => { + 
console.log(registry.get("mock-embedding")); + expect(() => + registry.get("mock-embedding")!.create({ someKey: "$var:someName" }), + ).toThrow('Variable "someName" not found'); + }); + + it("should use default values if not set", () => { + registry + .get("mock-embedding")! + .create({ someKey: "$var:someName:someValue" }); + }); + + it("should set a variable that the embedding function understand", () => { + registry.setVar("someName", "someValue"); + registry.get("mock-embedding")!.create({ someKey: "$var:someName" }); + }); + + it("should reject secrets that aren't passed as variables", () => { + registry.setVar("someName", "someValue"); + expect(() => + registry + .get("mock-embedding")! + .create({ secretKey: "someValue", someKey: "$var:someName" }), + ).toThrow( + 'The key "secretKey" is sensitive and cannot be set directly. Please use the $var: syntax to set it.', + ); + }); + + it("should not serialize secrets", () => { + registry.setVar("someName", "someValue"); + registry.setVar("secretKey", "mySecret"); + const func = registry + .get("mock-embedding")! 
+ .create({ secretKey: "$var:secretKey", someKey: "$var:someName" }); + expect(func.toJSON()).toEqual({ + secretKey: "$var:secretKey", + someKey: "$var:someName", + }); + }); +}); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 50740309..8d495d64 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -1038,9 +1038,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( test("can search using a string", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 1; } diff --git a/nodejs/examples/embedding.test.ts b/nodejs/examples/embedding.test.ts index 2dc79c69..35af9e59 100644 --- a/nodejs/examples/embedding.test.ts +++ b/nodejs/examples/embedding.test.ts @@ -43,12 +43,17 @@ test("custom embedding function", async () => { @register("my_embedding") class MyEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; + constructor(optionsRaw = {}) { + super(); + const options = this.resolveVariables(optionsRaw); + // Initialize using options } ndims() { return 3; } + protected getSensitiveKeys(): string[] { + return []; + } embeddingDataType(): Float { return new Float32(); } @@ -94,3 +99,14 @@ test("custom embedding function", async () => { expect(await table2.countRows()).toBe(2); }); }); + +test("embedding function api_key", async () => { + // --8<-- [start:register_secret] + const registry = getRegistry(); + registry.setVar("api_key", "sk-..."); + + const func = registry.get("openai")!.create({ + apiKey: "$var:api_key", + }); + // --8<-- [end:register_secret] +}); diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 4d00eb29..45aa83f5 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -15,6 +15,7 @@ import { newVectorType, } from "../arrow"; import { sanitizeType } from 
"../sanitize"; +import { getRegistry } from "./registry"; /** * Options for a given embedding function @@ -32,6 +33,22 @@ export interface EmbeddingFunctionConstructor< /** * An embedding function that automatically creates vector representation for a given column. + * + * It's important subclasses pass the **original** options to the super constructor + * and then pass those options to `resolveVariables` to resolve any variables before + * using them. + * + * @example + * ```ts + * class MyEmbeddingFunction extends EmbeddingFunction { + * constructor(options: {model: string, timeout: number}) { + * super(optionsRaw); + * const options = this.resolveVariables(optionsRaw); + * this.model = options.model; + * this.timeout = options.timeout; + * } + * } + * ``` */ export abstract class EmbeddingFunction< // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do @@ -44,33 +61,74 @@ export abstract class EmbeddingFunction< */ // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is readonly TOptions!: M; - /** - * Convert the embedding function to a JSON object - * It is used to serialize the embedding function to the schema - * It's important that any object returned by this method contains all the necessary - * information to recreate the embedding function - * - * It should return the same object that was passed to the constructor - * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly - * - * @example - * ```ts - * class MyEmbeddingFunction extends EmbeddingFunction { - * constructor(options: {model: string, timeout: number}) { - * super(); - * this.model = options.model; - * this.timeout = options.timeout; - * } - * toJSON() { - * return { - * model: this.model, - * timeout: this.timeout, - * }; - * } - * ``` - */ - abstract toJSON(): Partial; + #config: Partial; + + /** + * Get the original arguments to the constructor, to serialize them so they + * can 
be used to recreate the embedding function later. + */ + // biome-ignore lint/suspicious/noExplicitAny : + toJSON(): Record { + return JSON.parse(JSON.stringify(this.#config)); + } + + constructor() { + this.#config = {}; + } + + /** + * Provide a list of keys in the function options that should be treated as + * sensitive. If users pass raw values for these keys, they will be rejected. + */ + protected getSensitiveKeys(): string[] { + return []; + } + + /** + * Apply variables to the config. + */ + protected resolveVariables(config: Partial): Partial { + this.#config = config; + const registry = getRegistry(); + const newConfig = { ...config }; + for (const [key_, value] of Object.entries(newConfig)) { + if ( + this.getSensitiveKeys().includes(key_) && + !value.startsWith("$var:") + ) { + throw new Error( + `The key "${key_}" is sensitive and cannot be set directly. Please use the $var: syntax to set it.`, + ); + } + // Makes TS happy (https://stackoverflow.com/a/78391854) + const key = key_ as keyof M; + if (typeof value === "string" && value.startsWith("$var:")) { + const [name, defaultValue] = value.slice(5).split(":", 2); + const variableValue = registry.getVar(name); + if (!variableValue) { + if (defaultValue) { + // biome-ignore lint/suspicious/noExplicitAny: + newConfig[key] = defaultValue as any; + } else { + throw new Error(`Variable "${name}" not found`); + } + } else { + // biome-ignore lint/suspicious/noExplicitAny: + newConfig[key] = variableValue as any; + } + } + } + return newConfig; + } + + /** + * Optionally load any resources needed for the embedding function. + * + * This method is called after the embedding function has been initialized + * but before any embeddings are computed. It is useful for loading local models + * or other resources that are needed for the embedding function to work. 
+ */ async init?(): Promise; /** diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index 1fe86668..5771cfeb 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -21,11 +21,13 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction< #modelName: OpenAIOptions["model"]; constructor( - options: Partial = { + optionsRaw: Partial = { model: "text-embedding-ada-002", }, ) { super(); + const options = this.resolveVariables(optionsRaw); + const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; if (!openAIKey) { throw new Error("OpenAI API key is required"); @@ -52,10 +54,8 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction< this.#modelName = modelName; } - toJSON() { - return { - model: this.#modelName, - }; + protected getSensitiveKeys(): string[] { + return ["apiKey"]; } ndims(): number { diff --git a/nodejs/lancedb/embedding/registry.ts b/nodejs/lancedb/embedding/registry.ts index 2f33ac91..2eae90ed 100644 --- a/nodejs/lancedb/embedding/registry.ts +++ b/nodejs/lancedb/embedding/registry.ts @@ -23,6 +23,7 @@ export interface EmbeddingFunctionCreate { */ export class EmbeddingFunctionRegistry { #functions = new Map(); + #variables = new Map(); /** * Get the number of registered functions @@ -82,10 +83,7 @@ export class EmbeddingFunctionRegistry { }; } else { // biome-ignore lint/suspicious/noExplicitAny: - create = function (options?: any) { - const instance = new factory(options); - return instance; - }; + create = (options?: any) => new factory(options); } return { @@ -164,6 +162,37 @@ export class EmbeddingFunctionRegistry { return metadata; } + + /** + * Set a variable. These can be accessed in the embedding function + * configuration using the syntax `$var:variable_name`. If they are not + * set, an error will be thrown letting you know which key is unset. 
If you + * want to supply a default value, you can add an additional part in the + * configuration like so: `$var:variable_name:default_value`. Default values + * can be used for runtime configurations that are not sensitive, such as + * whether to use a GPU for inference. + * + * The name must not contain colons. The default value can contain colons. + * + * @param name + * @param value + */ + setVar(name: string, value: string): void { + if (name.includes(":")) { + throw new Error("Variable names cannot contain colons"); + } + this.#variables.set(name, value); + } + + /** + * Get a variable. + * @param name + * @returns + * @see {@link setVar} + */ + getVar(name: string): string | undefined { + return this.#variables.get(name); + } } const _REGISTRY = new EmbeddingFunctionRegistry(); diff --git a/nodejs/lancedb/embedding/transformers.ts b/nodejs/lancedb/embedding/transformers.ts index 6b3e4f3a..16157528 100644 --- a/nodejs/lancedb/embedding/transformers.ts +++ b/nodejs/lancedb/embedding/transformers.ts @@ -44,11 +44,12 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction< #ndims?: number; constructor( - options: Partial = { + optionsRaw: Partial = { model: "Xenova/all-MiniLM-L6-v2", }, ) { super(); + const options = this.resolveVariables(optionsRaw); const modelName = options?.model ?? 
"Xenova/all-MiniLM-L6-v2"; this.#tokenizerOptions = { @@ -59,22 +60,6 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction< this.#ndims = options.ndims; this.#modelName = modelName; } - toJSON() { - // biome-ignore lint/suspicious/noExplicitAny: - const obj: Record = { - model: this.#modelName, - }; - if (this.#ndims) { - obj["ndims"] = this.#ndims; - } - if (this.#tokenizerOptions) { - obj["tokenizerOptions"] = this.#tokenizerOptions; - } - if (this.#tokenizer) { - obj["tokenizer"] = this.#tokenizer.name; - } - return obj; - } async init() { let transformers; diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py index cd68aa70..dc78e27a 100644 --- a/python/python/lancedb/embeddings/base.py +++ b/python/python/lancedb/embeddings/base.py @@ -2,8 +2,10 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors from abc import ABC, abstractmethod +import copy from typing import List, Union +from lancedb.util import add_note import numpy as np import pyarrow as pa from pydantic import BaseModel, Field, PrivateAttr @@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC): 7 # Setting 0 disables retires. 
Maybe this should not be enabled by default, ) _ndims: int = PrivateAttr() + _original_args: dict = PrivateAttr() @classmethod def create(cls, **kwargs): """ Create an instance of the embedding function """ - return cls(**kwargs) + resolved_kwargs = cls.__resolveVariables(kwargs) + instance = cls(**resolved_kwargs) + instance._original_args = kwargs + return instance + + @classmethod + def __resolveVariables(cls, args: dict) -> dict: + """ + Resolve variables in the args + """ + from .registry import EmbeddingFunctionRegistry + + new_args = copy.deepcopy(args) + + registry = EmbeddingFunctionRegistry.get_instance() + sensitive_keys = cls.sensitive_keys() + for k, v in new_args.items(): + if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys: + exc = ValueError( + f"Sensitive key '{k}' cannot be set to a hardcoded value" + ) + add_note(exc, "Help: Use $var: to set sensitive keys to variables") + raise exc + + if isinstance(v, str) and v.startswith("$var:"): + parts = v[5:].split(":", maxsplit=1) + if len(parts) == 1: + try: + new_args[k] = registry.get_var(parts[0]) + except KeyError: + exc = ValueError( + "Variable '{}' not found in registry".format(parts[0]) + ) + add_note( + exc, + "Help: Variables are reset in new Python sessions. " + "Use `registry.set_var` to set variables.", + ) + raise exc + else: + name, default = parts + try: + new_args[k] = registry.get_var(name) + except KeyError: + new_args[k] = default + return new_args + + @staticmethod + def sensitive_keys() -> List[str]: + """ + Return a list of keys that are sensitive and should not be allowed + to be set to hardcoded values in the config. For example, API keys. 
+ """ + return [] @abstractmethod def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]: @@ -103,17 +159,11 @@ class EmbeddingFunction(BaseModel, ABC): return texts def safe_model_dump(self): - from ..pydantic import PYDANTIC_VERSION - - if PYDANTIC_VERSION.major < 2: - return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} - return self.model_dump( - exclude={ - field_name - for field_name in self.model_fields - if field_name.startswith("_") - } - ) + if not hasattr(self, "_original_args"): + raise ValueError( + "EmbeddingFunction was not created with EmbeddingFunction.create()" + ) + return self._original_args @abstractmethod def ndims(self) -> int: diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py index 03ffbd63..3f3816eb 100644 --- a/python/python/lancedb/embeddings/jinaai.py +++ b/python/python/lancedb/embeddings/jinaai.py @@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction): # TODO: fix hardcoding return 768 + @staticmethod + def sensitive_keys() -> List[str]: + return ["api_key"] + def sanitize_input( self, inputs: Union[TEXT, IMAGES] ) -> Union[List[Any], np.ndarray]: diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index a061d00a..9b18e45c 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction): def ndims(self): return self._ndims + @staticmethod + def sensitive_keys(): + return ["api_key"] + @staticmethod def model_names(): return [ diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py index 78353e47..91424253 100644 --- a/python/python/lancedb/embeddings/registry.py +++ b/python/python/lancedb/embeddings/registry.py @@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry: def __init__(self): self._functions = {} + self._variables = {} def 
register(self, alias: str = None): """ @@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry: metadata = json.dumps(json_data, indent=2).encode("utf-8") return {"embedding_functions": metadata} + def set_var(self, name: str, value: str) -> None: + """ + Set a variable. These can be accessed in embedding configuration using + the syntax `$var:variable_name`. If they are not set, an error will be + thrown letting you know which variable is missing. If you want to supply + a default value, you can add an additional part in the configuration + like so: `$var:variable_name:default_value`. Default values can be + used for runtime configurations that are not sensitive, such as + whether to use a GPU for inference. + + The name must not contain a colon. Default values can contain colons. + """ + if ":" in name: + raise ValueError("Variable names cannot contain colons") + self._variables[name] = value + + def get_var(self, name: str) -> str: + """ + Get a variable. + """ + return self._variables[name] + # Global instance __REGISTRY__ = EmbeddingFunctionRegistry() diff --git a/python/python/lancedb/embeddings/watsonx.py b/python/python/lancedb/embeddings/watsonx.py index 039713c0..29d7d64a 100644 --- a/python/python/lancedb/embeddings/watsonx.py +++ b/python/python/lancedb/embeddings/watsonx.py @@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction): url: Optional[str] = None params: Optional[Dict] = None + @staticmethod + def sensitive_keys(): + return ["api_key"] + @staticmethod def model_names(): return [ diff --git a/python/python/tests/docs/test_embeddings_optional.py b/python/python/tests/docs/test_embeddings_optional.py index 5197a88a..4e28855c 100644 --- a/python/python/tests/docs/test_embeddings_optional.py +++ b/python/python/tests/docs/test_embeddings_optional.py @@ -49,3 +49,28 @@ async def test_embeddings_openai_async(): actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0] print(actual.text) # --8<-- [end:async_openai_embeddings] + 
+ +def test_embeddings_secret(): + # --8<-- [start:register_secret] + registry = get_registry() + registry.set_var("api_key", "sk-...") + + func = registry.get("openai").create(api_key="$var:api_key") + # --8<-- [end:register_secret] + + try: + import torch + except ImportError: + pytest.skip("torch not installed") + + # --8<-- [start:register_device] + import torch + + registry = get_registry() + if torch.cuda.is_available(): + registry.set_var("device", "cuda") + + func = registry.get("huggingface").create(device="$var:device:cpu") + # --8<-- [end:register_device] + assert func.device == "cuda" if torch.cuda.is_available() else "cpu" diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index fe794103..168189dc 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors -from typing import List, Union +import os +from typing import List, Optional, Union from unittest.mock import MagicMock, patch import lance @@ -56,7 +57,7 @@ def test_embedding_function(tmp_path): conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata = registry.get_table_metadata([conf]) table = table.replace_schema_metadata(metadata) @@ -80,6 +81,57 @@ def test_embedding_function(tmp_path): assert np.allclose(actual, expected) +def test_embedding_function_variables(): + @register("variable-testing") + class VariableTestingFunction(TextEmbeddingFunction): + key1: str + secret_key: Optional[str] = None + + @staticmethod + def sensitive_keys(): + return ["secret_key"] + + def ndims(): + pass + + def generate_embeddings(self, _texts): + pass + + registry = EmbeddingFunctionRegistry.get_instance() + + # Should error if variable is not set + with pytest.raises(ValueError, match="Variable 'test' not 
found"): + registry.get("variable-testing").create( + key1="$var:test", + ) + + # Should use default values if not set + func = registry.get("variable-testing").create(key1="$var:test:some_value") + assert func.key1 == "some_value" + + # Should set a variable that the embedding function understands + registry.set_var("test", "some_value") + func = registry.get("variable-testing").create(key1="$var:test") + assert func.key1 == "some_value" + + # Should reject secrets that aren't passed in as variables + with pytest.raises( + ValueError, + match="Sensitive key 'secret_key' cannot be set to a hardcoded value", + ): + registry.get("variable-testing").create( + key1="whatever", secret_key="some_value" + ) + + # Should not serialize secrets. + registry.set_var("secret", "secret_value") + func = registry.get("variable-testing").create( + key1="whatever", secret_key="$var:secret" + ) + assert func.secret_key == "secret_value" + assert func.safe_model_dump()["secret_key"] == "$var:secret" + + def test_embedding_with_bad_results(tmp_path): @register("null-embedding") class NullEmbeddingFunction(TextEmbeddingFunction): @@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path): ) -> list[Union[np.array, None]]: # Return None, which is bad if field is non-nullable a = [ - np.full(self.ndims(), np.nan) - if i % 2 == 0 - else np.random.randn(self.ndims()) + ( + np.full(self.ndims(), np.nan) + if i % 2 == 0 + else np.random.randn(self.ndims()) + ) for i in range(len(texts)) ] return a @@ -359,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type): # Note: Some embedding types might require specific parameters try: - model = registry.get(embedding_type).create() + model = registry.get(embedding_type).create({"max_retries": 1}) except Exception as e: pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}") @@ -392,3 +446,33 @@ def test_retry(mock_sleep): result = test_function() assert mock_sleep.call_count == 9 assert result == "result" + + 
+@pytest.mark.skipif( + os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set" +) +def test_openai_propagates_api_key(monkeypatch): + # Make sure that if we set it as a variable, the API key is propagated + api_key = os.environ["OPENAI_API_KEY"] + monkeypatch.delenv("OPENAI_API_KEY") + + uri = "memory://" + registry = get_registry() + registry.set_var("open_api_key", api_key) + func = registry.get("openai").create( + name="text-embedding-ada-002", + max_retries=0, + api_key="$var:open_api_key", + ) + + class Words(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + + db = lancedb.connect(uri) + table = db.create_table("words", schema=Words, mode="overwrite") + table.add([{"text": "hello world"}, {"text": "goodbye world"}]) + + query = "greetings" + actual = table.search(query).limit(1).to_pydantic(Words)[0] + assert len(actual.text) > 0 diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index fce74ac9..21d697a3 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts") def get_test_table(tmp_path, use_tantivy): db = lancedb.connect(tmp_path) # Create a LanceDB table schema with a vector and a text column - emb = EmbeddingFunctionRegistry.get_instance().get("test")() - meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")() + emb = EmbeddingFunctionRegistry.get_instance().get("test").create() + meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create() class MyTable(LanceModel): text: str = emb.SourceField() @@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy): @pytest.mark.skipif( - os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" + os.environ.get("OPENAI_API_KEY") is None + or os.environ.get("OPENAI_BASE_URL") is not None, + reason="OPENAI_API_KEY not set", ) @pytest.mark.parametrize("use_tantivy", 
[True, False]) def test_openai_reranker(tmp_path, use_tantivy): diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 6e810b97..b3117e83 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection): text: str vector: Vector(10) - func = MockTextEmbeddingFunction() + func = MockTextEmbeddingFunction.create() texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)}) @@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection): def test_add_with_embedding_function(mem_db: DBConnection): - emb = EmbeddingFunctionRegistry.get_instance().get("test")() + emb = EmbeddingFunctionRegistry.get_instance().get("test").create() class MyTable(LanceModel): text: str = emb.SourceField() @@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection): def setup_hybrid_search_table(db: DBConnection, embedding_func): # Create a LanceDB table schema with a vector and a text column - emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)() + emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create() class MyTable(LanceModel): text: str = emb.SourceField() diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index 84c2f560..4a00f948 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -127,7 +127,7 @@ def test_append_vector_columns(): conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata = registry.get_table_metadata([conf]) @@ -434,7 +434,7 @@ def test_sanitize_data( conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata 
= registry.get_table_metadata([conf]) else: