Files
lancedb/python/python/tests/docs/test_embeddings_optional.py
Will Jones 7ac5f74c80 feat!: add variable store to embeddings registry (#2112)
BREAKING CHANGE: embedding function implementations in Node now need to
call `resolveVariables()` in their constructors and should **not**
implement `toJSON()`.

This tries to address the handling of secrets. In Node, they are
currently lost. In Python, they are currently leaked into the table
schema metadata.

This PR introduces an in-memory variable store on the function registry.
It also allows embedding function definitions to label certain config
values as "sensitive"; the preprocessing logic will raise an error
if users try to pass in hard-coded values for those sensitive settings.

Closes #2110
Closes #521

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2025-02-24 15:52:19 -08:00

77 lines
2.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import lancedb
# --8<-- [start:imports]
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
# --8<-- [end:imports]
import pytest
@pytest.mark.slow
def test_embeddings_openai():
    """Exercise the synchronous OpenAI embedding example.

    Connects to a local database, defines a ``Words`` schema whose source
    and vector fields come from the registry's "openai" embedding function,
    overwrites the ``words`` table with two rows, and prints the text of
    the single best match for a query string.

    The code between the ``--8<--`` markers is presumably extracted
    verbatim as a documentation snippet, so the markers must stay in place.
    """
    # --8<-- [start:openai_embeddings]
    db = lancedb.connect("/tmp/db")
    openai = get_registry().get("openai")
    func = openai.create(name="text-embedding-ada-002")

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    table = db.create_table("words", schema=Words, mode="overwrite")
    rows = [{"text": "hello world"}, {"text": "goodbye world"}]
    table.add(rows)

    query = "greetings"
    matches = table.search(query).limit(1).to_pydantic(Words)
    actual = matches[0]
    print(actual.text)
    # --8<-- [end:openai_embeddings]
@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
    """Exercise the asynchronous OpenAI embedding example.

    Mirrors the synchronous example using ``lancedb.connect_async`` against
    an in-memory URI: defines a ``Words`` schema backed by the registry's
    "openai" embedding function, overwrites the ``words`` table with two
    rows, and prints the text of the single best match for a query string.

    The code between the ``--8<--`` markers is presumably extracted
    verbatim as a documentation snippet, so the markers must stay in place.
    """
    uri = "memory://"
    # --8<-- [start:async_openai_embeddings]
    db = await lancedb.connect_async(uri)
    func = get_registry().get("openai").create(name="text-embedding-ada-002")

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    table = await db.create_table("words", schema=Words, mode="overwrite")
    await table.add([{"text": "hello world"}, {"text": "goodbye world"}])

    query = "greetings"
    # Await the query result first, then index it. `await expr[0]` would
    # subscript the un-awaited object, because `await` applies to the whole
    # primary expression (including the trailing subscript).
    # NOTE(review): assumes `to_pydantic` on the async query is awaitable,
    # as the original `await` implies -- confirm against the async query API.
    results = await (await table.search(query)).limit(1).to_pydantic(Words)
    actual = results[0]
    print(actual.text)
    # --8<-- [end:async_openai_embeddings]
def test_embeddings_secret():
    """Exercise the registry variable examples (secret and device).

    First registers an ``api_key`` variable on the embedding registry and
    creates an "openai" function that refers to it via ``$var:api_key``
    instead of a hard-coded value. Then, if ``torch`` is importable, sets a
    ``device`` variable to ``"cuda"`` when CUDA is available and creates a
    "huggingface" function with ``device="$var:device:cpu"``; the test is
    skipped when ``torch`` is not installed.

    The code between the ``--8<--`` markers is presumably extracted
    verbatim as documentation snippets, so the markers must stay in place.
    """
    # --8<-- [start:register_secret]
    registry = get_registry()
    registry.set_var("api_key", "sk-...")

    func = registry.get("openai").create(api_key="$var:api_key")
    # --8<-- [end:register_secret]

    try:
        import torch
    except ImportError:
        pytest.skip("torch not installed")

    # --8<-- [start:register_device]
    import torch

    registry = get_registry()
    if torch.cuda.is_available():
        registry.set_var("device", "cuda")

    func = registry.get("huggingface").create(device="$var:device:cpu")
    # --8<-- [end:register_device]
    # Compute the expected value first: `assert a == "cuda" if c else "cpu"`
    # parses as `assert ((a == "cuda") if c else "cpu")`, which asserts the
    # truthy string "cpu" on machines without CUDA and can never fail.
    # NOTE(review): presumably the trailing ":cpu" is the fallback used when
    # the "device" variable is unset -- confirm against the registry docs.
    expected_device = "cuda" if torch.cuda.is_available() else "cpu"
    assert func.device == expected_device