mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
fix: use create to resolve variables (#2640)
# What - Use `create` to resolve variable values # Reference Fixes #2181 --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -122,7 +122,7 @@ class EmbeddingFunctionRegistry:
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
function=self.get(obj["name"]).create(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
@@ -114,6 +114,63 @@ def test_embedding_function_variables():
|
||||
assert func.safe_model_dump()["secret_key"] == "$var:secret"
|
||||
|
||||
|
||||
def test_parse_functions_with_variables():
    """Round-trip an embedding config with `$var:` fields through table metadata.

    A function is created with `$var:` references for `api_key` and
    `base_url`, its config is written into the schema metadata of an
    in-memory Lance dataset, and the metadata is parsed back with
    `parse_functions`. The parsed function must expose the values set on the
    registry, while `safe_model_dump` must still report the `$var:`
    reference for the sensitive key.
    """

    @register("variable-parsing-test")
    class VariableParsingFunction(TextEmbeddingFunction):
        api_key: str
        base_url: Optional[str] = None

        @staticmethod
        def sensitive_keys():
            return ["api_key"]

        def ndims(self):
            return 10

        def generate_embeddings(self, texts):
            # Mock implementation that just returns random embeddings
            # In real usage, this would use the api_key to call an API
            return [np.random.rand(self.ndims()).tolist() for _ in texts]

    reg = EmbeddingFunctionRegistry.get_instance()

    # Values the "$var:..." references below are expected to resolve to.
    reg.set_var("test_api_key", "sk-test-key-12345")
    reg.set_var("test_base_url", "https://api.example.com")

    embedding_fn = reg.get("variable-parsing-test").create(
        api_key="$var:test_api_key",
        base_url="$var:test_base_url",
    )
    config = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="vector",
        function=embedding_fn,
    )
    table_metadata = reg.get_table_metadata([config])

    # Create a mock arrow table with the metadata
    fields = [
        pa.field("text", pa.string()),
        pa.field("vector", pa.list_(pa.float32(), 10)),
    ]
    empty_table = pa.table({"text": [], "vector": []}, schema=pa.schema(fields))
    empty_table = empty_table.replace_schema_metadata(table_metadata)

    dataset = lance.write_dataset(empty_table, "memory://")

    # Parse the configs back from the metadata stored on the dataset schema.
    parsed_configs = reg.parse_functions(dataset.schema.metadata)

    assert "vector" in parsed_configs
    restored_fn = parsed_configs["vector"].function

    # The parsed function carries the resolved variable values.
    assert restored_fn.api_key == "sk-test-key-12345"
    assert restored_fn.base_url == "https://api.example.com"

    vectors = restored_fn.generate_embeddings(["test text"])
    assert len(vectors) == 1
    assert len(vectors[0]) == 10

    # The sensitive key is still dumped as its variable reference.
    assert restored_fn.safe_model_dump()["api_key"] == "$var:test_api_key"
|
||||
|
||||
|
||||
def test_embedding_with_bad_results(tmp_path):
|
||||
@register("null-embedding")
|
||||
class NullEmbeddingFunction(TextEmbeddingFunction):
|
||||
|
||||
Reference in New Issue
Block a user