From 7ac5f74c80b571344f48555753a225c759aaa982 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 24 Feb 2025 15:52:19 -0800 Subject: [PATCH] feat!: add variable store to embeddings registry (#2112) BREAKING CHANGE: embedding function implementations in Node need to now call `resolveVariables()` in their constructors and should **not** implement `toJSON()`. This tries to address the handling of secrets. In Node, they are currently lost. In Python, they are currently leaked into the table schema metadata. This PR introduces an in-memory variable store on the function registry. It also allows embedding function definitions to label certain config values as "sensitive", and the preprocessing logic will raise an error if users try to pass in hard-coded values. Closes #2110 Closes #521 --------- Co-authored-by: Weston Pace --- docs/mkdocs.yml | 2 + .../embeddings/custom_embedding_function.md | 8 ++ docs/src/embeddings/variables_and_secrets.md | 53 +++++++++ .../embedding/classes/EmbeddingFunction.md | 86 +++++++++----- .../classes/EmbeddingFunctionRegistry.md | 50 ++++++++ .../classes/TextEmbeddingFunction.md | 77 +++++++----- nodejs/__test__/embedding.test.ts | 49 ++++---- nodejs/__test__/registry.test.ts | 102 +++++++++++++--- nodejs/__test__/table.test.ts | 3 - nodejs/examples/embedding.test.ts | 20 +++- .../lancedb/embedding/embedding_function.ts | 110 +++++++++++++----- nodejs/lancedb/embedding/openai.ts | 10 +- nodejs/lancedb/embedding/registry.ts | 37 +++++- nodejs/lancedb/embedding/transformers.ts | 19 +-- python/python/lancedb/embeddings/base.py | 74 ++++++++++-- python/python/lancedb/embeddings/jinaai.py | 4 + python/python/lancedb/embeddings/openai.py | 4 + python/python/lancedb/embeddings/registry.py | 23 ++++ python/python/lancedb/embeddings/watsonx.py | 4 + .../tests/docs/test_embeddings_optional.py | 25 ++++ python/python/tests/test_embeddings.py | 96 ++++++++++++++- python/python/tests/test_rerankers.py | 8 +- python/python/tests/test_table.py | 6 +- 
python/python/tests/test_util.py | 4 +- 24 files changed, 699 insertions(+), 175 deletions(-) create mode 100644 docs/src/embeddings/variables_and_secrets.md diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 99bcce89..e9467b87 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -182,6 +182,7 @@ nav: - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - User-defined embedding functions: embeddings/custom_embedding_function.md + - Variables and secrets: embeddings/variables_and_secrets.md - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - 🔌 Integrations: @@ -315,6 +316,7 @@ nav: - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - User-defined embedding functions: embeddings/custom_embedding_function.md + - Variables and secrets: embeddings/variables_and_secrets.md - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - Integrations: diff --git a/docs/src/embeddings/custom_embedding_function.md b/docs/src/embeddings/custom_embedding_function.md index 619c5437..655c6904 100644 --- a/docs/src/embeddings/custom_embedding_function.md +++ b/docs/src/embeddings/custom_embedding_function.md @@ -55,6 +55,14 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings. +!!! 
danger "Use sensitive keys to prevent leaking secrets" + To prevent leaking secrets, such as API keys, you should add any sensitive + parameters of an embedding function to the output of the + [sensitive_keys()][lancedb.embeddings.base.EmbeddingFunction.sensitive_keys] / + [getSensitiveKeys()](../../js/namespaces/embedding/classes/EmbeddingFunction/#getsensitivekeys) + method. This prevents users from accidentally instantiating the embedding + function with hard-coded secrets. + Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs. === "Python" diff --git a/docs/src/embeddings/variables_and_secrets.md b/docs/src/embeddings/variables_and_secrets.md new file mode 100644 index 00000000..72388b24 --- /dev/null +++ b/docs/src/embeddings/variables_and_secrets.md @@ -0,0 +1,53 @@ +# Variables and Secrets + +Most embedding configuration options are saved in the table's metadata. However, +this isn't always appropriate. For example, API keys should never be stored in the +metadata. Additionally, other configuration options might be best set at runtime, +such as the `device` configuration that controls whether to use GPU or CPU for +inference. If you hardcoded this to GPU, you wouldn't be able to run the code on +a server without one. + +To handle these cases, you can set variables on the embedding registry and +reference them in the embedding configuration. These variables will be available +during the runtime of your program, but not saved in the table's metadata. When +the table is loaded from a different process, the variables must be set again. + +To set a variable, use the `set_var()` / `setVar()` method on the embedding registry. +To reference a variable, use the syntax `$var:VARIABLE_NAME`. If there is a default +value, you can use the syntax `$var:VARIABLE_NAME:DEFAULT_VALUE`. 
+ +## Using variables to set secrets + +Sensitive configuration, such as API keys, must either be set as environment +variables or using variables on the embedding registry. If you pass in a hardcoded +value, LanceDB will raise an error. Instead, if you want to set an API key via +configuration, use a variable: + +=== "Python" + + ```python + --8<-- "python/python/tests/docs/test_embeddings_optional.py:register_secret" + ``` + +=== "Typescript" + + ```typescript + --8<-- "nodejs/examples/embedding.test.ts:register_secret" + ``` + +## Using variables to set the device parameter + +Many embedding functions that run locally have a `device` parameter that controls +whether to use GPU or CPU for inference. Because not all computers have a GPU, +it's helpful to be able to set the `device` parameter at runtime, rather than +have it hard coded in the embedding configuration. To make it work even if the +variable isn't set, you could provide a default value of `cpu` in the embedding +configuration. + +Some embedding libraries even have a method to detect which devices are available, +which could be used to dynamically set the device at runtime. For example, in Python +you can check if a CUDA GPU is available using `torch.cuda.is_available()`. + +```python +--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_device" +``` diff --git a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md index 24c54915..66d6ee16 100644 --- a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md @@ -8,6 +8,23 @@ An embedding function that automatically creates vector representation for a given column. +It's important that subclasses call `super()` with no arguments and then pass the +**original** options to `resolveVariables`, using only the resolved result. The +original options (not the resolved values) are what `toJSON()` serializes. 
+ +## Example + +```ts +class MyEmbeddingFunction extends EmbeddingFunction { + constructor(optionsRaw: {model: string, timeout: number}) { + super(); + const options = this.resolveVariables(optionsRaw); + this.model = options.model; + this.timeout = options.timeout; + } +} +``` + ## Extended by - [`TextEmbeddingFunction`](TextEmbeddingFunction.md) @@ -82,12 +99,33 @@ The datatype of the embeddings *** +### getSensitiveKeys() + +```ts +protected getSensitiveKeys(): string[] +``` + +Provide a list of keys in the function options that should be treated as +sensitive. If users pass raw values for these keys, they will be rejected. + +#### Returns + +`string`[] + +*** + ### init()? ```ts optional init(): Promise ``` +Optionally load any resources needed for the embedding function. + +This method is called after the embedding function has been initialized +but before any embeddings are computed. It is useful for loading local models +or other resources that are needed for the embedding function to work. + #### Returns `Promise`<`void`> @@ -108,6 +146,24 @@ The number of dimensions of the embeddings *** +### resolveVariables() + +```ts +protected resolveVariables(config): Partial +``` + +Apply variables to the config. 
+ +#### Parameters + +* **config**: `Partial`<`M`> + +#### Returns + +`Partial`<`M`> + +*** + ### sourceField() ```ts @@ -134,37 +190,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d ### toJSON() ```ts -abstract toJSON(): Partial +toJSON(): Record ``` -Convert the embedding function to a JSON object -It is used to serialize the embedding function to the schema -It's important that any object returned by this method contains all the necessary -information to recreate the embedding function - -It should return the same object that was passed to the constructor -If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly +Get the original arguments to the constructor, to serialize them so they +can be used to recreate the embedding function later. #### Returns -`Partial`<`M`> - -#### Example - -```ts -class MyEmbeddingFunction extends EmbeddingFunction { - constructor(options: {model: string, timeout: number}) { - super(); - this.model = options.model; - this.timeout = options.timeout; - } - toJSON() { - return { - model: this.model, - timeout: this.timeout, - }; -} -``` +`Record`<`string`, `any`> *** diff --git a/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md b/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md index e01078ab..2186d696 100644 --- a/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md +++ b/docs/src/js/namespaces/embedding/classes/EmbeddingFunctionRegistry.md @@ -80,6 +80,28 @@ getTableMetadata(functions): Map *** +### getVar() + +```ts +getVar(name): undefined | string +``` + +Get a variable. 
+ +#### Parameters + +* **name**: `string` + +#### Returns + +`undefined` \| `string` + +#### See + +[setVar](EmbeddingFunctionRegistry.md#setvar) + +*** + ### length() ```ts @@ -145,3 +167,31 @@ reset the registry to the initial state #### Returns `void` + +*** + +### setVar() + +```ts +setVar(name, value): void +``` + +Set a variable. These can be accessed in the embedding function +configuration using the syntax `$var:variable_name`. If they are not +set, an error will be thrown letting you know which key is unset. If you +want to supply a default value, you can add an additional part in the +configuration like so: `$var:variable_name:default_value`. Default values +can be used for runtime configurations that are not sensitive, such as +whether to use a GPU for inference. + +The name must not contain colons. The default value can contain colons. + +#### Parameters + +* **name**: `string` + +* **value**: `string` + +#### Returns + +`void` diff --git a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md index 2cc13bef..8aee4f44 100644 --- a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md @@ -114,12 +114,37 @@ abstract generateEmbeddings(texts, ...args): Promise ``` +Optionally load any resources needed for the embedding function. + +This method is called after the embedding function has been initialized +but before any embeddings are computed. It is useful for loading local models +or other resources that are needed for the embedding function to work. + #### Returns `Promise`<`void`> @@ -148,6 +173,28 @@ The number of dimensions of the embeddings *** +### resolveVariables() + +```ts +protected resolveVariables(config): Partial +``` + +Apply variables to the config. 
+ +#### Parameters + +* **config**: `Partial`<`M`> + +#### Returns + +`Partial`<`M`> + +#### Inherited from + +[`EmbeddingFunction`](EmbeddingFunction.md).[`resolveVariables`](EmbeddingFunction.md#resolvevariables) + +*** + ### sourceField() ```ts @@ -173,37 +220,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d ### toJSON() ```ts -abstract toJSON(): Partial +toJSON(): Record ``` -Convert the embedding function to a JSON object -It is used to serialize the embedding function to the schema -It's important that any object returned by this method contains all the necessary -information to recreate the embedding function - -It should return the same object that was passed to the constructor -If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly +Get the original arguments to the constructor, to serialize them so they +can be used to recreate the embedding function later. #### Returns -`Partial`<`M`> - -#### Example - -```ts -class MyEmbeddingFunction extends EmbeddingFunction { - constructor(options: {model: string, timeout: number}) { - super(); - this.model = options.model; - this.timeout = options.timeout; - } - toJSON() { - return { - model: this.model, - timeout: this.timeout, - }; -} -``` +`Record`<`string`, `any`> #### Inherited from diff --git a/nodejs/__test__/embedding.test.ts b/nodejs/__test__/embedding.test.ts index 03802dab..30d8b1fc 100644 --- a/nodejs/__test__/embedding.test.ts +++ b/nodejs/__test__/embedding.test.ts @@ -17,6 +17,8 @@ import { import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; +const testOpenAIInteg = process.env.OPENAI_API_KEY == null ? 
test.skip : test; + describe("embedding functions", () => { let tmpDir: tmp.DirResult; beforeEach(() => { @@ -29,9 +31,6 @@ describe("embedding functions", () => { it("should be able to create a table with an embedding function", async () => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -75,9 +74,6 @@ describe("embedding functions", () => { it("should be able to append and upsert using embedding function", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -143,9 +139,6 @@ describe("embedding functions", () => { it("should be able to create an empty table with an embedding function", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -194,9 +187,6 @@ describe("embedding functions", () => { it("should error when appending to a table with an unregistered embedding function", async () => { @register("mock") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -241,13 +231,35 @@ describe("embedding functions", () => { `Function "mock" not found in registry`, ); }); + + testOpenAIInteg("propagates variables through all methods", async () => { + delete process.env.OPENAI_API_KEY; + const registry = getRegistry(); + registry.setVar("openai_api_key", "sk-..."); + const func = registry.get("openai")?.create({ + model: "text-embedding-ada-002", + apiKey: "$var:openai_api_key", + }) as EmbeddingFunction; + + const db = await connect("memory://"); + const wordsSchema = LanceSchema({ + text: func.sourceField(new Utf8()), + vector: func.vectorField(), + }); + const tbl = await db.createEmptyTable("words", wordsSchema, { + mode: "overwrite", + }); + await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]); + + const query = "greetings"; + const 
actual = (await tbl.search(query).limit(1).toArray())[0]; + expect(actual).toHaveProperty("text"); + }); + test.each([new Float16(), new Float32(), new Float64()])( "should be able to provide manual embeddings with multiple float datatype", async (floatType) => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } @@ -292,10 +304,6 @@ describe("embedding functions", () => { async (floatType) => { @register("test1") class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction { - toJSON(): object { - return {}; - } - embeddingDataType(): Float { return floatType; } @@ -310,9 +318,6 @@ describe("embedding functions", () => { } @register("test") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 3; } diff --git a/nodejs/__test__/registry.test.ts b/nodejs/__test__/registry.test.ts index f00121da..a5cf73e7 100644 --- a/nodejs/__test__/registry.test.ts +++ b/nodejs/__test__/registry.test.ts @@ -11,7 +11,11 @@ import * as arrow18 from "apache-arrow-18"; import * as tmp from "tmp"; import { connect } from "../lancedb"; -import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; +import { + EmbeddingFunction, + FunctionOptions, + LanceSchema, +} from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => { @@ -39,11 +43,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { it("should register a new item to the registry", async () => { @register("mock-embedding") class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } constructor() { super(); } @@ -89,11 +88,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { }); test("should error if registering with the same name", async () => { class 
MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } constructor() { super(); } @@ -114,13 +108,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { }); test("schema should contain correct metadata", async () => { class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return { - someText: "hello", - }; - } - constructor() { + constructor(args: FunctionOptions = {}) { super(); + this.resolveVariables(args); } ndims() { return 3; @@ -132,7 +122,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { return data.map(() => [1, 2, 3]); } } - const func = new MockEmbeddingFunction(); + const func = new MockEmbeddingFunction({ someText: "hello" }); const schema = LanceSchema({ id: new arrow.Int32(), @@ -155,3 +145,79 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => { expect(schema.metadata).toEqual(expectedMetadata); }); }); + +describe("Registry.setVar", () => { + const registry = getRegistry(); + + beforeEach(() => { + @register("mock-embedding") + // biome-ignore lint/correctness/noUnusedVariables : + class MockEmbeddingFunction extends EmbeddingFunction { + constructor(optionsRaw: FunctionOptions = {}) { + super(); + const options = this.resolveVariables(optionsRaw); + + expect(optionsRaw["someKey"].startsWith("$var:someName")).toBe(true); + expect(options["someKey"]).toBe("someValue"); + + if (options["secretKey"]) { + expect(optionsRaw["secretKey"]).toBe("$var:secretKey"); + expect(options["secretKey"]).toBe("mySecret"); + } + } + async computeSourceEmbeddings(data: string[]) { + return data.map(() => [1, 2, 3]); + } + embeddingDataType() { + return new arrow18.Float32() as apiArrow.Float; + } + protected getSensitiveKeys() { + return ["secretKey"]; + } + } + }); + afterEach(() => { + registry.reset(); + }); + + it("Should error if the variable is not set", () => { + 
console.log(registry.get("mock-embedding")); + expect(() => + registry.get("mock-embedding")!.create({ someKey: "$var:someName" }), + ).toThrow('Variable "someName" not found'); + }); + + it("should use default values if not set", () => { + registry + .get("mock-embedding")! + .create({ someKey: "$var:someName:someValue" }); + }); + + it("should set a variable that the embedding function understand", () => { + registry.setVar("someName", "someValue"); + registry.get("mock-embedding")!.create({ someKey: "$var:someName" }); + }); + + it("should reject secrets that aren't passed as variables", () => { + registry.setVar("someName", "someValue"); + expect(() => + registry + .get("mock-embedding")! + .create({ secretKey: "someValue", someKey: "$var:someName" }), + ).toThrow( + 'The key "secretKey" is sensitive and cannot be set directly. Please use the $var: syntax to set it.', + ); + }); + + it("should not serialize secrets", () => { + registry.setVar("someName", "someValue"); + registry.setVar("secretKey", "mySecret"); + const func = registry + .get("mock-embedding")! 
+ .create({ secretKey: "$var:secretKey", someKey: "$var:someName" }); + expect(func.toJSON()).toEqual({ + secretKey: "$var:secretKey", + someKey: "$var:someName", + }); + }); +}); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 50740309..8d495d64 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -1038,9 +1038,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( test("can search using a string", async () => { @register() class MockEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; - } ndims() { return 1; } diff --git a/nodejs/examples/embedding.test.ts b/nodejs/examples/embedding.test.ts index 2dc79c69..35af9e59 100644 --- a/nodejs/examples/embedding.test.ts +++ b/nodejs/examples/embedding.test.ts @@ -43,12 +43,17 @@ test("custom embedding function", async () => { @register("my_embedding") class MyEmbeddingFunction extends EmbeddingFunction { - toJSON(): object { - return {}; + constructor(optionsRaw = {}) { + super(); + const options = this.resolveVariables(optionsRaw); + // Initialize using options } ndims() { return 3; } + protected getSensitiveKeys(): string[] { + return []; + } embeddingDataType(): Float { return new Float32(); } @@ -94,3 +99,14 @@ test("custom embedding function", async () => { expect(await table2.countRows()).toBe(2); }); }); + +test("embedding function api_key", async () => { + // --8<-- [start:register_secret] + const registry = getRegistry(); + registry.setVar("api_key", "sk-..."); + + const func = registry.get("openai")!.create({ + apiKey: "$var:api_key", + }); + // --8<-- [end:register_secret] +}); diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 4d00eb29..45aa83f5 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -15,6 +15,7 @@ import { newVectorType, } from "../arrow"; import { sanitizeType } from 
"../sanitize"; +import { getRegistry } from "./registry"; /** * Options for a given embedding function @@ -32,6 +33,22 @@ export interface EmbeddingFunctionConstructor< /** * An embedding function that automatically creates vector representation for a given column. + * + * It's important that subclasses call `super()` with no arguments and then pass the + * **original** options to `resolveVariables`, using only the resolved result. The + * original options (not the resolved values) are what `toJSON()` serializes. + * + * @example + * ```ts + * class MyEmbeddingFunction extends EmbeddingFunction { + * constructor(optionsRaw: {model: string, timeout: number}) { + * super(); + * const options = this.resolveVariables(optionsRaw); + * this.model = options.model; + * this.timeout = options.timeout; + * } + * } + * ``` */ export abstract class EmbeddingFunction< // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do @@ -44,33 +61,74 @@ export abstract class EmbeddingFunction< */ // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is readonly TOptions!: M; - /** - * Convert the embedding function to a JSON object - * It is used to serialize the embedding function to the schema - * It's important that any object returned by this method contains all the necessary - * information to recreate the embedding function - * - * It should return the same object that was passed to the constructor - * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly - * - * @example - * ```ts - * class MyEmbeddingFunction extends EmbeddingFunction { - * constructor(options: {model: string, timeout: number}) { - * super(); - * this.model = options.model; - * this.timeout = options.timeout; - * } - * toJSON() { - * return { - * model: this.model, - * timeout: this.timeout, - * }; - * } - * ``` - */ - abstract toJSON(): Partial; + #config: Partial; + + /** + * Get the original arguments to the 
constructor, to serialize them so they + * can 
be used to recreate the embedding function later. + */ + // biome-ignore lint/suspicious/noExplicitAny : + toJSON(): Record { + return JSON.parse(JSON.stringify(this.#config)); + } + + constructor() { + this.#config = {}; + } + + /** + * Provide a list of keys in the function options that should be treated as + * sensitive. If users pass raw values for these keys, they will be rejected. + */ + protected getSensitiveKeys(): string[] { + return []; + } + + /** + * Apply variables to the config. + */ + protected resolveVariables(config: Partial): Partial { + this.#config = config; + const registry = getRegistry(); + const newConfig = { ...config }; + for (const [key_, value] of Object.entries(newConfig)) { + if ( + this.getSensitiveKeys().includes(key_) && + !value.startsWith("$var:") + ) { + throw new Error( + `The key "${key_}" is sensitive and cannot be set directly. Please use the $var: syntax to set it.`, + ); + } + // Makes TS happy (https://stackoverflow.com/a/78391854) + const key = key_ as keyof M; + if (typeof value === "string" && value.startsWith("$var:")) { + const [name, defaultValue] = value.slice(5).split(":", 2); + const variableValue = registry.getVar(name); + if (!variableValue) { + if (defaultValue) { + // biome-ignore lint/suspicious/noExplicitAny: + newConfig[key] = defaultValue as any; + } else { + throw new Error(`Variable "${name}" not found`); + } + } else { + // biome-ignore lint/suspicious/noExplicitAny: + newConfig[key] = variableValue as any; + } + } + } + return newConfig; + } + + /** + * Optionally load any resources needed for the embedding function. + * + * This method is called after the embedding function has been initialized + * but before any embeddings are computed. It is useful for loading local models + * or other resources that are needed for the embedding function to work. 
+ */ async init?(): Promise; /** diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index 1fe86668..5771cfeb 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -21,11 +21,13 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction< #modelName: OpenAIOptions["model"]; constructor( - options: Partial = { + optionsRaw: Partial = { model: "text-embedding-ada-002", }, ) { super(); + const options = this.resolveVariables(optionsRaw); + const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; if (!openAIKey) { throw new Error("OpenAI API key is required"); @@ -52,10 +54,8 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction< this.#modelName = modelName; } - toJSON() { - return { - model: this.#modelName, - }; + protected getSensitiveKeys(): string[] { + return ["apiKey"]; } ndims(): number { diff --git a/nodejs/lancedb/embedding/registry.ts b/nodejs/lancedb/embedding/registry.ts index 2f33ac91..2eae90ed 100644 --- a/nodejs/lancedb/embedding/registry.ts +++ b/nodejs/lancedb/embedding/registry.ts @@ -23,6 +23,7 @@ export interface EmbeddingFunctionCreate { */ export class EmbeddingFunctionRegistry { #functions = new Map(); + #variables = new Map(); /** * Get the number of registered functions @@ -82,10 +83,7 @@ export class EmbeddingFunctionRegistry { }; } else { // biome-ignore lint/suspicious/noExplicitAny: - create = function (options?: any) { - const instance = new factory(options); - return instance; - }; + create = (options?: any) => new factory(options); } return { @@ -164,6 +162,37 @@ export class EmbeddingFunctionRegistry { return metadata; } + + /** + * Set a variable. These can be accessed in the embedding function + * configuration using the syntax `$var:variable_name`. If they are not + * set, an error will be thrown letting you know which key is unset. 
If you + * want to supply a default value, you can add an additional part in the + * configuration like so: `$var:variable_name:default_value`. Default values + * can be used for runtime configurations that are not sensitive, such as + * whether to use a GPU for inference. + * + * The name must not contain colons. The default value can contain colons. + * + * @param name + * @param value + */ + setVar(name: string, value: string): void { + if (name.includes(":")) { + throw new Error("Variable names cannot contain colons"); + } + this.#variables.set(name, value); + } + + /** + * Get a variable. + * @param name + * @returns + * @see {@link setVar} + */ + getVar(name: string): string | undefined { + return this.#variables.get(name); + } } const _REGISTRY = new EmbeddingFunctionRegistry(); diff --git a/nodejs/lancedb/embedding/transformers.ts b/nodejs/lancedb/embedding/transformers.ts index 6b3e4f3a..16157528 100644 --- a/nodejs/lancedb/embedding/transformers.ts +++ b/nodejs/lancedb/embedding/transformers.ts @@ -44,11 +44,12 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction< #ndims?: number; constructor( - options: Partial = { + optionsRaw: Partial = { model: "Xenova/all-MiniLM-L6-v2", }, ) { super(); + const options = this.resolveVariables(optionsRaw); const modelName = options?.model ?? 
"Xenova/all-MiniLM-L6-v2"; this.#tokenizerOptions = { @@ -59,22 +60,6 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction< this.#ndims = options.ndims; this.#modelName = modelName; } - toJSON() { - // biome-ignore lint/suspicious/noExplicitAny: - const obj: Record = { - model: this.#modelName, - }; - if (this.#ndims) { - obj["ndims"] = this.#ndims; - } - if (this.#tokenizerOptions) { - obj["tokenizerOptions"] = this.#tokenizerOptions; - } - if (this.#tokenizer) { - obj["tokenizer"] = this.#tokenizer.name; - } - return obj; - } async init() { let transformers; diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py index cd68aa70..dc78e27a 100644 --- a/python/python/lancedb/embeddings/base.py +++ b/python/python/lancedb/embeddings/base.py @@ -2,8 +2,10 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors from abc import ABC, abstractmethod +import copy from typing import List, Union +from lancedb.util import add_note import numpy as np import pyarrow as pa from pydantic import BaseModel, Field, PrivateAttr @@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC): 7 # Setting 0 disables retires. 
Maybe this should not be enabled by default, ) _ndims: int = PrivateAttr() + _original_args: dict = PrivateAttr() @classmethod def create(cls, **kwargs): """ Create an instance of the embedding function """ - return cls(**kwargs) + resolved_kwargs = cls.__resolveVariables(kwargs) + instance = cls(**resolved_kwargs) + instance._original_args = kwargs + return instance + + @classmethod + def __resolveVariables(cls, args: dict) -> dict: + """ + Resolve variables in the args + """ + from .registry import EmbeddingFunctionRegistry + + new_args = copy.deepcopy(args) + + registry = EmbeddingFunctionRegistry.get_instance() + sensitive_keys = cls.sensitive_keys() + for k, v in new_args.items(): + if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys: + exc = ValueError( + f"Sensitive key '{k}' cannot be set to a hardcoded value" + ) + add_note(exc, "Help: Use $var: to set sensitive keys to variables") + raise exc + + if isinstance(v, str) and v.startswith("$var:"): + parts = v[5:].split(":", maxsplit=1) + if len(parts) == 1: + try: + new_args[k] = registry.get_var(parts[0]) + except KeyError: + exc = ValueError( + "Variable '{}' not found in registry".format(parts[0]) + ) + add_note( + exc, + "Help: Variables are reset in new Python sessions. " + "Use `registry.set_var` to set variables.", + ) + raise exc + else: + name, default = parts + try: + new_args[k] = registry.get_var(name) + except KeyError: + new_args[k] = default + return new_args + + @staticmethod + def sensitive_keys() -> List[str]: + """ + Return a list of keys that are sensitive and should not be allowed + to be set to hardcoded values in the config. For example, API keys. 
+ """ + return [] @abstractmethod def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]: @@ -103,17 +159,11 @@ class EmbeddingFunction(BaseModel, ABC): return texts def safe_model_dump(self): - from ..pydantic import PYDANTIC_VERSION - - if PYDANTIC_VERSION.major < 2: - return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} - return self.model_dump( - exclude={ - field_name - for field_name in self.model_fields - if field_name.startswith("_") - } - ) + if not hasattr(self, "_original_args"): + raise ValueError( + "EmbeddingFunction was not created with EmbeddingFunction.create()" + ) + return self._original_args @abstractmethod def ndims(self) -> int: diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py index 03ffbd63..3f3816eb 100644 --- a/python/python/lancedb/embeddings/jinaai.py +++ b/python/python/lancedb/embeddings/jinaai.py @@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction): # TODO: fix hardcoding return 768 + @staticmethod + def sensitive_keys() -> List[str]: + return ["api_key"] + def sanitize_input( self, inputs: Union[TEXT, IMAGES] ) -> Union[List[Any], np.ndarray]: diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index a061d00a..9b18e45c 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction): def ndims(self): return self._ndims + @staticmethod + def sensitive_keys(): + return ["api_key"] + @staticmethod def model_names(): return [ diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py index 78353e47..91424253 100644 --- a/python/python/lancedb/embeddings/registry.py +++ b/python/python/lancedb/embeddings/registry.py @@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry: def __init__(self): self._functions = {} + self._variables = {} def 
register(self, alias: str = None): """ @@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry: metadata = json.dumps(json_data, indent=2).encode("utf-8") return {"embedding_functions": metadata} + def set_var(self, name: str, value: str) -> None: + """ + Set a variable. These can be accessed in embedding configuration using + the syntax `$var:variable_name`. If they are not set, an error will be + thrown letting you know which variable is missing. If you want to supply + a default value, you can add an additional part in the configuration + like so: `$var:variable_name:default_value`. Default values can be + used for runtime configurations that are not sensitive, such as + whether to use a GPU for inference. + + The name must not contain a colon. Default values can contain colons. + """ + if ":" in name: + raise ValueError("Variable names cannot contain colons") + self._variables[name] = value + + def get_var(self, name: str) -> str: + """ + Get a variable. + """ + return self._variables[name] + # Global instance __REGISTRY__ = EmbeddingFunctionRegistry() diff --git a/python/python/lancedb/embeddings/watsonx.py b/python/python/lancedb/embeddings/watsonx.py index 039713c0..29d7d64a 100644 --- a/python/python/lancedb/embeddings/watsonx.py +++ b/python/python/lancedb/embeddings/watsonx.py @@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction): url: Optional[str] = None params: Optional[Dict] = None + @staticmethod + def sensitive_keys(): + return ["api_key"] + @staticmethod def model_names(): return [ diff --git a/python/python/tests/docs/test_embeddings_optional.py b/python/python/tests/docs/test_embeddings_optional.py index 5197a88a..4e28855c 100644 --- a/python/python/tests/docs/test_embeddings_optional.py +++ b/python/python/tests/docs/test_embeddings_optional.py @@ -49,3 +49,28 @@ async def test_embeddings_openai_async(): actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0] print(actual.text) # --8<-- [end:async_openai_embeddings] + 
+ +def test_embeddings_secret(): + # --8<-- [start:register_secret] + registry = get_registry() + registry.set_var("api_key", "sk-...") + + func = registry.get("openai").create(api_key="$var:api_key") + # --8<-- [end:register_secret] + + try: + import torch + except ImportError: + pytest.skip("torch not installed") + + # --8<-- [start:register_device] + import torch + + registry = get_registry() + if torch.cuda.is_available(): + registry.set_var("device", "cuda") + + func = registry.get("huggingface").create(device="$var:device:cpu") + # --8<-- [end:register_device] + assert func.device == "cuda" if torch.cuda.is_available() else "cpu" diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index fe794103..168189dc 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors -from typing import List, Union +import os +from typing import List, Optional, Union from unittest.mock import MagicMock, patch import lance @@ -56,7 +57,7 @@ def test_embedding_function(tmp_path): conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata = registry.get_table_metadata([conf]) table = table.replace_schema_metadata(metadata) @@ -80,6 +81,57 @@ def test_embedding_function(tmp_path): assert np.allclose(actual, expected) +def test_embedding_function_variables(): + @register("variable-testing") + class VariableTestingFunction(TextEmbeddingFunction): + key1: str + secret_key: Optional[str] = None + + @staticmethod + def sensitive_keys(): + return ["secret_key"] + + def ndims(): + pass + + def generate_embeddings(self, _texts): + pass + + registry = EmbeddingFunctionRegistry.get_instance() + + # Should error if variable is not set + with pytest.raises(ValueError, match="Variable 'test' not 
found"): + registry.get("variable-testing").create( + key1="$var:test", + ) + + # Should use default values if not set + func = registry.get("variable-testing").create(key1="$var:test:some_value") + assert func.key1 == "some_value" + + # Should set a variable that the embedding function understands + registry.set_var("test", "some_value") + func = registry.get("variable-testing").create(key1="$var:test") + assert func.key1 == "some_value" + + # Should reject secrets that aren't passed in as variables + with pytest.raises( + ValueError, + match="Sensitive key 'secret_key' cannot be set to a hardcoded value", + ): + registry.get("variable-testing").create( + key1="whatever", secret_key="some_value" + ) + + # Should not serialize secrets. + registry.set_var("secret", "secret_value") + func = registry.get("variable-testing").create( + key1="whatever", secret_key="$var:secret" + ) + assert func.secret_key == "secret_value" + assert func.safe_model_dump()["secret_key"] == "$var:secret" + + def test_embedding_with_bad_results(tmp_path): @register("null-embedding") class NullEmbeddingFunction(TextEmbeddingFunction): @@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path): ) -> list[Union[np.array, None]]: # Return None, which is bad if field is non-nullable a = [ - np.full(self.ndims(), np.nan) - if i % 2 == 0 - else np.random.randn(self.ndims()) + ( + np.full(self.ndims(), np.nan) + if i % 2 == 0 + else np.random.randn(self.ndims()) + ) for i in range(len(texts)) ] return a @@ -359,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type): # Note: Some embedding types might require specific parameters try: - model = registry.get(embedding_type).create() + model = registry.get(embedding_type).create({"max_retries": 1}) except Exception as e: pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}") @@ -392,3 +446,33 @@ def test_retry(mock_sleep): result = test_function() assert mock_sleep.call_count == 9 assert result == "result" + + 
+@pytest.mark.skipif( + os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set" +) +def test_openai_propagates_api_key(monkeypatch): + # Make sure that if we set it as a variable, the API key is propagated + api_key = os.environ["OPENAI_API_KEY"] + monkeypatch.delenv("OPENAI_API_KEY") + + uri = "memory://" + registry = get_registry() + registry.set_var("open_api_key", api_key) + func = registry.get("openai").create( + name="text-embedding-ada-002", + max_retries=0, + api_key="$var:open_api_key", + ) + + class Words(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + + db = lancedb.connect(uri) + table = db.create_table("words", schema=Words, mode="overwrite") + table.add([{"text": "hello world"}, {"text": "goodbye world"}]) + + query = "greetings" + actual = table.search(query).limit(1).to_pydantic(Words)[0] + assert len(actual.text) > 0 diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index fce74ac9..21d697a3 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts") def get_test_table(tmp_path, use_tantivy): db = lancedb.connect(tmp_path) # Create a LanceDB table schema with a vector and a text column - emb = EmbeddingFunctionRegistry.get_instance().get("test")() - meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")() + emb = EmbeddingFunctionRegistry.get_instance().get("test").create() + meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create() class MyTable(LanceModel): text: str = emb.SourceField() @@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy): @pytest.mark.skipif( - os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" + os.environ.get("OPENAI_API_KEY") is None + or os.environ.get("OPENAI_BASE_URL") is not None, + reason="OPENAI_API_KEY not set", ) @pytest.mark.parametrize("use_tantivy", 
[True, False]) def test_openai_reranker(tmp_path, use_tantivy): diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 6e810b97..b3117e83 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection): text: str vector: Vector(10) - func = MockTextEmbeddingFunction() + func = MockTextEmbeddingFunction.create() texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)}) @@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection): def test_add_with_embedding_function(mem_db: DBConnection): - emb = EmbeddingFunctionRegistry.get_instance().get("test")() + emb = EmbeddingFunctionRegistry.get_instance().get("test").create() class MyTable(LanceModel): text: str = emb.SourceField() @@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection): def setup_hybrid_search_table(db: DBConnection, embedding_func): # Create a LanceDB table schema with a vector and a text column - emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)() + emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create() class MyTable(LanceModel): text: str = emb.SourceField() diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index 84c2f560..4a00f948 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -127,7 +127,7 @@ def test_append_vector_columns(): conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata = registry.get_table_metadata([conf]) @@ -434,7 +434,7 @@ def test_sanitize_data( conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", - function=MockTextEmbeddingFunction(), + function=MockTextEmbeddingFunction.create(), ) metadata 
= registry.get_table_metadata([conf]) else: