multi-modal embedding-function (#484)

This commit is contained in:
Chang She
2023-09-16 21:23:51 -04:00
committed by GitHub
parent 9585f550b3
commit 31dad71c94
13 changed files with 645 additions and 143 deletions

View File

@@ -22,8 +22,9 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
@@ -356,20 +357,23 @@ def test_create_with_embedding_function(db):
text: str
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
func = MockTextEmbeddingFunction()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func(texts)})
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
embedding_functions=[conf],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func(query_str)[0]
query_vector = func.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
@@ -377,17 +381,13 @@ def test_create_with_embedding_function(db):
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims) = emb.VectorField()
table = LanceTable.create(db, "my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = func(query_str)[0]
query_vector = emb.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()