From 4c9fc3044b6aa6311794069d80fdc08df73151d6 Mon Sep 17 00:00:00 2001 From: Le Duc Manh <118125824+manhld0206@users.noreply.github.com> Date: Sat, 13 Sep 2025 05:07:32 +0900 Subject: [PATCH] fix: use create to resolve variables (#2640) # What - Use `create` to resolve variables values # Reference Fixes #2181 --------- Co-authored-by: Will Jones --- python/python/lancedb/embeddings/registry.py | 2 +- python/python/tests/test_embeddings.py | 57 ++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py index 10978486..4c91445a 100644 --- a/python/python/lancedb/embeddings/registry.py +++ b/python/python/lancedb/embeddings/registry.py @@ -122,7 +122,7 @@ class EmbeddingFunctionRegistry: obj["vector_column"]: EmbeddingFunctionConfig( vector_column=obj["vector_column"], source_column=obj["source_column"], - function=self.get(obj["name"])(**obj["model"]), + function=self.get(obj["name"]).create(**obj["model"]), ) for obj in raw_list } diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 3284f11f..2f01cf3b 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -114,6 +114,63 @@ def test_embedding_function_variables(): assert func.safe_model_dump()["secret_key"] == "$var:secret" +def test_parse_functions_with_variables(): + @register("variable-parsing-test") + class VariableParsingFunction(TextEmbeddingFunction): + api_key: str + base_url: Optional[str] = None + + @staticmethod + def sensitive_keys(): + return ["api_key"] + + def ndims(self): + return 10 + + def generate_embeddings(self, texts): + # Mock implementation that just returns random embeddings + # In real usage, this would use the api_key to call an API + return [np.random.rand(self.ndims()).tolist() for _ in texts] + + registry = EmbeddingFunctionRegistry.get_instance() + + registry.set_var("test_api_key", "sk-test-key-12345") + registry.set_var("test_base_url", "https://api.example.com") + + conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="vector", + function=registry.get("variable-parsing-test").create( + api_key="$var:test_api_key", base_url="$var:test_base_url" + ), + ) + + metadata = registry.get_table_metadata([conf]) + + # Create a mock arrow table with the metadata + schema = pa.schema( + [pa.field("text", pa.string()), pa.field("vector", pa.list_(pa.float32(), 10))] + ) + table = pa.table({"text": [], "vector": []}, schema=schema) + table = table.replace_schema_metadata(metadata) + + ds = lance.write_dataset(table, "memory://") + + configs = registry.parse_functions(ds.schema.metadata) + + assert "vector" in configs + parsed_func = configs["vector"].function + + assert parsed_func.api_key == "sk-test-key-12345" + assert parsed_func.base_url == "https://api.example.com" + + embeddings = parsed_func.generate_embeddings(["test text"]) + assert len(embeddings) == 1 + assert len(embeddings[0]) == 10 + + assert parsed_func.safe_model_dump()["api_key"] == "$var:test_api_key" + + def test_embedding_with_bad_results(tmp_path): @register("null-embedding") class NullEmbeddingFunction(TextEmbeddingFunction):