mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
feat!: add variable store to embeddings registry (#2112)
BREAKING CHANGE: embedding function implementations in Node need to now call `resolveVariables()` in their constructors and should **not** implement `toJSON()`. This tries to address the handling of secrets. In Node, they are currently lost. In Python, they are currently leaked into the table schema metadata. This PR introduces an in-memory variable store on the function registry. It also allows embedding function definitions to label certain config values as "sensitive", and the preprocessing logic will raise an error if users try to pass in hard-coded values. Closes #2110 Closes #521 --------- Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import copy
|
||||
from typing import List, Union
|
||||
|
||||
from lancedb.util import add_note
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
7 # Setting 0 disables retires. Maybe this should not be enabled by default,
|
||||
)
|
||||
_ndims: int = PrivateAttr()
|
||||
_original_args: dict = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
resolved_kwargs = cls.__resolveVariables(kwargs)
|
||||
instance = cls(**resolved_kwargs)
|
||||
instance._original_args = kwargs
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def __resolveVariables(cls, args: dict) -> dict:
|
||||
"""
|
||||
Resolve variables in the args
|
||||
"""
|
||||
from .registry import EmbeddingFunctionRegistry
|
||||
|
||||
new_args = copy.deepcopy(args)
|
||||
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
sensitive_keys = cls.sensitive_keys()
|
||||
for k, v in new_args.items():
|
||||
if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys:
|
||||
exc = ValueError(
|
||||
f"Sensitive key '{k}' cannot be set to a hardcoded value"
|
||||
)
|
||||
add_note(exc, "Help: Use $var: to set sensitive keys to variables")
|
||||
raise exc
|
||||
|
||||
if isinstance(v, str) and v.startswith("$var:"):
|
||||
parts = v[5:].split(":", maxsplit=1)
|
||||
if len(parts) == 1:
|
||||
try:
|
||||
new_args[k] = registry.get_var(parts[0])
|
||||
except KeyError:
|
||||
exc = ValueError(
|
||||
"Variable '{}' not found in registry".format(parts[0])
|
||||
)
|
||||
add_note(
|
||||
exc,
|
||||
"Help: Variables are reset in new Python sessions. "
|
||||
"Use `registry.set_var` to set variables.",
|
||||
)
|
||||
raise exc
|
||||
else:
|
||||
name, default = parts
|
||||
try:
|
||||
new_args[k] = registry.get_var(name)
|
||||
except KeyError:
|
||||
new_args[k] = default
|
||||
return new_args
|
||||
|
||||
@staticmethod
|
||||
def sensitive_keys() -> List[str]:
|
||||
"""
|
||||
Return a list of keys that are sensitive and should not be allowed
|
||||
to be set to hardcoded values in the config. For example, API keys.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
|
||||
@@ -103,17 +159,11 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
return texts
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
||||
return self.model_dump(
|
||||
exclude={
|
||||
field_name
|
||||
for field_name in self.model_fields
|
||||
if field_name.startswith("_")
|
||||
}
|
||||
)
|
||||
if not hasattr(self, "_original_args"):
|
||||
raise ValueError(
|
||||
"EmbeddingFunction was not created with EmbeddingFunction.create()"
|
||||
)
|
||||
return self._original_args
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self) -> int:
|
||||
|
||||
@@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
# TODO: fix hardcoding
|
||||
return 768
|
||||
|
||||
@staticmethod
|
||||
def sensitive_keys() -> List[str]:
|
||||
return ["api_key"]
|
||||
|
||||
def sanitize_input(
|
||||
self, inputs: Union[TEXT, IMAGES]
|
||||
) -> Union[List[Any], np.ndarray]:
|
||||
|
||||
@@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return self._ndims
|
||||
|
||||
@staticmethod
|
||||
def sensitive_keys():
|
||||
return ["api_key"]
|
||||
|
||||
@staticmethod
|
||||
def model_names():
|
||||
return [
|
||||
|
||||
@@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry:
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
self._variables = {}
|
||||
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
@@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry:
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
def set_var(self, name: str, value: str) -> None:
|
||||
"""
|
||||
Set a variable. These can be accessed in embedding configuration using
|
||||
the syntax `$var:variable_name`. If they are not set, an error will be
|
||||
thrown letting you know which variable is missing. If you want to supply
|
||||
a default value, you can add an additional part in the configuration
|
||||
like so: `$var:variable_name:default_value`. Default values can be
|
||||
used for runtime configurations that are not sensitive, such as
|
||||
whether to use a GPU for inference.
|
||||
|
||||
The name must not contain a colon. Default values can contain colons.
|
||||
"""
|
||||
if ":" in name:
|
||||
raise ValueError("Variable names cannot contain colons")
|
||||
self._variables[name] = value
|
||||
|
||||
def get_var(self, name: str) -> str:
|
||||
"""
|
||||
Get a variable.
|
||||
"""
|
||||
return self._variables[name]
|
||||
|
||||
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
@@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction):
|
||||
url: Optional[str] = None
|
||||
params: Optional[Dict] = None
|
||||
|
||||
@staticmethod
|
||||
def sensitive_keys():
|
||||
return ["api_key"]
|
||||
|
||||
@staticmethod
|
||||
def model_names():
|
||||
return [
|
||||
|
||||
@@ -49,3 +49,28 @@ async def test_embeddings_openai_async():
|
||||
actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
# --8<-- [end:async_openai_embeddings]
|
||||
|
||||
|
||||
def test_embeddings_secret():
|
||||
# --8<-- [start:register_secret]
|
||||
registry = get_registry()
|
||||
registry.set_var("api_key", "sk-...")
|
||||
|
||||
func = registry.get("openai").create(api_key="$var:api_key")
|
||||
# --8<-- [end:register_secret]
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pytest.skip("torch not installed")
|
||||
|
||||
# --8<-- [start:register_device]
|
||||
import torch
|
||||
|
||||
registry = get_registry()
|
||||
if torch.cuda.is_available():
|
||||
registry.set_var("device", "cuda")
|
||||
|
||||
func = registry.get("huggingface").create(device="$var:device:cpu")
|
||||
# --8<-- [end:register_device]
|
||||
assert func.device == "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from typing import List, Union
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import lance
|
||||
@@ -56,7 +57,7 @@ def test_embedding_function(tmp_path):
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
@@ -80,6 +81,57 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
def test_embedding_function_variables():
|
||||
@register("variable-testing")
|
||||
class VariableTestingFunction(TextEmbeddingFunction):
|
||||
key1: str
|
||||
secret_key: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def sensitive_keys():
|
||||
return ["secret_key"]
|
||||
|
||||
def ndims():
|
||||
pass
|
||||
|
||||
def generate_embeddings(self, _texts):
|
||||
pass
|
||||
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
# Should error if variable is not set
|
||||
with pytest.raises(ValueError, match="Variable 'test' not found"):
|
||||
registry.get("variable-testing").create(
|
||||
key1="$var:test",
|
||||
)
|
||||
|
||||
# Should use default values if not set
|
||||
func = registry.get("variable-testing").create(key1="$var:test:some_value")
|
||||
assert func.key1 == "some_value"
|
||||
|
||||
# Should set a variable that the embedding function understands
|
||||
registry.set_var("test", "some_value")
|
||||
func = registry.get("variable-testing").create(key1="$var:test")
|
||||
assert func.key1 == "some_value"
|
||||
|
||||
# Should reject secrets that aren't passed in as variables
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Sensitive key 'secret_key' cannot be set to a hardcoded value",
|
||||
):
|
||||
registry.get("variable-testing").create(
|
||||
key1="whatever", secret_key="some_value"
|
||||
)
|
||||
|
||||
# Should not serialize secrets.
|
||||
registry.set_var("secret", "secret_value")
|
||||
func = registry.get("variable-testing").create(
|
||||
key1="whatever", secret_key="$var:secret"
|
||||
)
|
||||
assert func.secret_key == "secret_value"
|
||||
assert func.safe_model_dump()["secret_key"] == "$var:secret"
|
||||
|
||||
|
||||
def test_embedding_with_bad_results(tmp_path):
|
||||
@register("null-embedding")
|
||||
class NullEmbeddingFunction(TextEmbeddingFunction):
|
||||
@@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path):
|
||||
) -> list[Union[np.array, None]]:
|
||||
# Return None, which is bad if field is non-nullable
|
||||
a = [
|
||||
np.full(self.ndims(), np.nan)
|
||||
if i % 2 == 0
|
||||
else np.random.randn(self.ndims())
|
||||
(
|
||||
np.full(self.ndims(), np.nan)
|
||||
if i % 2 == 0
|
||||
else np.random.randn(self.ndims())
|
||||
)
|
||||
for i in range(len(texts))
|
||||
]
|
||||
return a
|
||||
@@ -359,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type):
|
||||
|
||||
# Note: Some embedding types might require specific parameters
|
||||
try:
|
||||
model = registry.get(embedding_type).create()
|
||||
model = registry.get(embedding_type).create({"max_retries": 1})
|
||||
except Exception as e:
|
||||
pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")
|
||||
|
||||
@@ -392,3 +446,33 @@ def test_retry(mock_sleep):
|
||||
result = test_function()
|
||||
assert mock_sleep.call_count == 9
|
||||
assert result == "result"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set"
|
||||
)
|
||||
def test_openai_propagates_api_key(monkeypatch):
|
||||
# Make sure that if we set it as a variable, the API key is propagated
|
||||
api_key = os.environ["OPENAI_API_KEY"]
|
||||
monkeypatch.delenv("OPENAI_API_KEY")
|
||||
|
||||
uri = "memory://"
|
||||
registry = get_registry()
|
||||
registry.set_var("open_api_key", api_key)
|
||||
func = registry.get("openai").create(
|
||||
name="text-embedding-ada-002",
|
||||
max_retries=0,
|
||||
api_key="$var:open_api_key",
|
||||
)
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
db = lancedb.connect(uri)
|
||||
table = db.create_table("words", schema=Words, mode="overwrite")
|
||||
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
assert len(actual.text) > 0
|
||||
|
||||
@@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts")
|
||||
def get_test_table(tmp_path, use_tantivy):
|
||||
db = lancedb.connect(tmp_path)
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
|
||||
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
@@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy):
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
os.environ.get("OPENAI_API_KEY") is None
|
||||
or os.environ.get("OPENAI_BASE_URL") is not None,
|
||||
reason="OPENAI_API_KEY not set",
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_openai_reranker(tmp_path, use_tantivy):
|
||||
|
||||
@@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockTextEmbeddingFunction()
|
||||
func = MockTextEmbeddingFunction.create()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
@@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection):
|
||||
|
||||
|
||||
def test_add_with_embedding_function(mem_db: DBConnection):
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
@@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection):
|
||||
|
||||
def setup_hybrid_search_table(db: DBConnection, embedding_func):
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
|
||||
@@ -127,7 +127,7 @@ def test_append_vector_columns():
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
|
||||
@@ -434,7 +434,7 @@ def test_sanitize_data(
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user