mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
multi-modal embedding-function (#484)
This commit is contained in:
@@ -22,8 +22,9 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -356,20 +357,23 @@ def test_create_with_embedding_function(db):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
func = MockTextEmbeddingFunction()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
embedding_functions=[conf],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
@@ -377,17 +381,13 @@ def test_create_with_embedding_function(db):
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user