[Python]Embeddings API refactor (#580)

Sets things up for this -> https://github.com/lancedb/lancedb/issues/579
- Just separates out the registry/ingestion code from the function
implementation code
- adds a `get_registry` util
- package name "open-clip" -> "open-clip-torch"
This commit is contained in:
Ayush Chaurasia
2023-10-18 11:02:19 +05:30
committed by GitHub
parent 7372656369
commit 0293bbe142
12 changed files with 636 additions and 602 deletions

View File

@@ -19,7 +19,7 @@ import pytest
import requests
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
# These are integration tests for embedding functions.
@@ -31,7 +31,7 @@ from lancedb.pydantic import LanceModel, Vector
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
registry = get_registry()
func = registry.get(alias).create()
class Words(LanceModel):
@@ -69,7 +69,7 @@ def test_openclip(tmp_path):
from PIL import Image
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
registry = get_registry()
func = registry.get("open-clip").create()
class Images(LanceModel):
@@ -131,11 +131,7 @@ def test_openclip(tmp_path):
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
) # also skip if cohere not installed
def test_cohere_embedding_function():
cohere = (
EmbeddingFunctionRegistry.get_instance()
.get("cohere")
.create(name="embed-multilingual-v2.0")
)
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
class TextModel(LanceModel):
text: str = cohere.SourceField()