feat!: add variable store to embeddings registry (#2112)

BREAKING CHANGE: embedding function implementations in Node need to now
call `resolveVariables()` in their constructors and should **not**
implement `toJSON()`.

This tries to address the handling of secrets. In Node, they are
currently lost. In Python, they are currently leaked into the table
schema metadata.

This PR introduces an in-memory variable store on the function registry.
It also allows embedding function definitions to label certain config
values as "sensitive", and the preprocessing logic will raise an error
if users try to pass in hard-coded values.

Closes #2110
Closes #521

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Will Jones
2025-02-24 15:52:19 -08:00
committed by GitHub
parent ecdee4d2b1
commit 7ac5f74c80
24 changed files with 699 additions and 175 deletions

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
import os
from typing import List, Optional, Union
from unittest.mock import MagicMock, patch
import lance
@@ -56,7 +57,7 @@ def test_embedding_function(tmp_path):
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata)
@@ -80,6 +81,57 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected)
def test_embedding_function_variables():
@register("variable-testing")
class VariableTestingFunction(TextEmbeddingFunction):
key1: str
secret_key: Optional[str] = None
@staticmethod
def sensitive_keys():
return ["secret_key"]
def ndims():
pass
def generate_embeddings(self, _texts):
pass
registry = EmbeddingFunctionRegistry.get_instance()
# Should error if variable is not set
with pytest.raises(ValueError, match="Variable 'test' not found"):
registry.get("variable-testing").create(
key1="$var:test",
)
# Should use default values if not set
func = registry.get("variable-testing").create(key1="$var:test:some_value")
assert func.key1 == "some_value"
# Should set a variable that the embedding function understands
registry.set_var("test", "some_value")
func = registry.get("variable-testing").create(key1="$var:test")
assert func.key1 == "some_value"
# Should reject secrets that aren't passed in as variables
with pytest.raises(
ValueError,
match="Sensitive key 'secret_key' cannot be set to a hardcoded value",
):
registry.get("variable-testing").create(
key1="whatever", secret_key="some_value"
)
# Should not serialize secrets.
registry.set_var("secret", "secret_value")
func = registry.get("variable-testing").create(
key1="whatever", secret_key="$var:secret"
)
assert func.secret_key == "secret_value"
assert func.safe_model_dump()["secret_key"] == "$var:secret"
def test_embedding_with_bad_results(tmp_path):
@register("null-embedding")
class NullEmbeddingFunction(TextEmbeddingFunction):
@@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path):
) -> list[Union[np.array, None]]:
# Return None, which is bad if field is non-nullable
a = [
np.full(self.ndims(), np.nan)
if i % 2 == 0
else np.random.randn(self.ndims())
(
np.full(self.ndims(), np.nan)
if i % 2 == 0
else np.random.randn(self.ndims())
)
for i in range(len(texts))
]
return a
@@ -359,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type):
# Note: Some embedding types might require specific parameters
try:
model = registry.get(embedding_type).create()
model = registry.get(embedding_type).create({"max_retries": 1})
except Exception as e:
pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")
@@ -392,3 +446,33 @@ def test_retry(mock_sleep):
result = test_function()
assert mock_sleep.call_count == 9
assert result == "result"
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set"
)
def test_openai_propagates_api_key(monkeypatch):
# Make sure that if we set it as a variable, the API key is propagated
api_key = os.environ["OPENAI_API_KEY"]
monkeypatch.delenv("OPENAI_API_KEY")
uri = "memory://"
registry = get_registry()
registry.set_var("open_api_key", api_key)
func = registry.get("openai").create(
name="text-embedding-ada-002",
max_retries=0,
api_key="$var:open_api_key",
)
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
db = lancedb.connect(uri)
table = db.create_table("words", schema=Words, mode="overwrite")
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
assert len(actual.text) > 0