multi-modal embedding-function (#484)

This commit is contained in:
Chang She
2023-09-16 21:23:51 -04:00
committed by GitHub
parent 9585f550b3
commit 31dad71c94
13 changed files with 645 additions and 143 deletions

View File

@@ -16,8 +16,12 @@ import lance
import numpy as np
import pyarrow as pa
from lancedb.conftest import MockEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
def mock_embed_func(input_data):
@@ -54,8 +58,12 @@ def test_embedding_function(tmp_path):
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
metadata = registry.get_table_metadata([func])
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata)
# Write it to disk
@@ -65,14 +73,13 @@ def test_embedding_function(tmp_path):
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
functions = registry.parse_functions(ds.schema.metadata)
configs = registry.parse_functions(ds.schema.metadata)
func = functions["vector"]
actual = func("hello world")
conf = configs["vector"]
func = conf.function
actual = func.compute_query_embeddings("hello world")
# We create an instance
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
# And we make sure we can call it
expected = expected_func("hello world")
expected = func.compute_query_embeddings("hello world")
assert np.allclose(actual, expected)