[python] Use pydantic for embedding function persistence (#467)

1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Chang She
2023-09-05 21:30:45 -07:00
committed by GitHub
parent 52fa7f5577
commit 9a9a73a65d
13 changed files with 815 additions and 192 deletions

View File

@@ -144,7 +144,7 @@ def test_ingest_iterator(tmp_path):
tbl_len = len(tbl)
tbl.add(make_batches())
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 2
assert len(tbl.list_versions()) == 3
db.drop_database()
run_tests(arrow_schema)

View File

@@ -12,10 +12,12 @@
# limitations under the License.
import sys
import lance
import numpy as np
import pyarrow as pa
from lancedb.embeddings import with_embeddings
from lancedb.conftest import MockEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
def mock_embed_func(input_data):
@@ -40,3 +42,37 @@ def test_with_embeddings():
assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()
# let's create a table
table = pa.table(
{
"text": pa.array(["hello world", "goodbye world"]),
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
metadata = registry.get_table_metadata([func])
table = table.replace_schema_metadata(metadata)
# Write it to disk
lance.write_dataset(table, tmp_path / "test.lance")
# Load this back
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
functions = registry.parse_functions(ds.schema.metadata)
func = functions["vector"]
actual = func("hello world")
# We create an instance
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
# And we make sure we can call it
expected = expected_func("hello world")
assert np.allclose(actual, expected)

View File

@@ -21,7 +21,7 @@ import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.query import LanceQueryBuilder, Query
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -72,7 +72,7 @@ def test_cast(table):
str_field: str
float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel)
assert len(results) == 1
r0 = results[0]
@@ -84,13 +84,15 @@ def test_cast(table):
def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
)
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
)
tm.assert_frame_equal(df_default, df_l2)
df_cosine = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
query = [4, 8]
vector_column_name = "foo_vector"
builder = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.where("b < 10")
.select(["b"])

View File

@@ -22,6 +22,7 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.table import LanceTable
@@ -178,16 +179,16 @@ def test_versioning(db):
],
)
assert len(table.list_versions()) == 1
assert table.version == 1
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 2
assert table.version == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 3
table.checkout(1)
assert table.version == 1
table.checkout(2)
assert table.version == 2
assert len(table) == 2
@@ -278,21 +279,21 @@ def test_restore(db):
data=[{"vector": [1.1, 0.9], "type": "vector"}],
)
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(1)
assert len(table.list_versions()) == 3
table.restore(2)
assert len(table.list_versions()) == 4
assert len(table) == 1
expected = table.to_arrow()
table.checkout(1)
table.checkout(2)
table.restore()
assert len(table.list_versions()) == 4
assert len(table.list_versions()) == 5
assert table.to_arrow() == expected
table.restore(4) # latest version should be no-op
assert len(table.list_versions()) == 4
table.restore(5) # latest version should be no-op
assert len(table.list_versions()) == 5
with pytest.raises(ValueError):
table.restore(5)
table.restore(6)
with pytest.raises(ValueError):
table.restore(0)
@@ -306,7 +307,7 @@ def test_merge(db, tmp_path):
)
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
table.merge(other_table, left_on="id")
assert len(table.list_versions()) == 2
assert len(table.list_versions()) == 3
expected = pa.table(
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
schema=table.schema,
@@ -325,10 +326,10 @@ def test_delete(db):
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 1
table.delete("id=0")
assert len(table.list_versions()) == 2
assert table.version == 2
table.delete("id=0")
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
@@ -340,11 +341,103 @@ def test_update(db):
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 1
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table.list_versions()) == 4
assert table.version == 4
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func(texts)})
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: vector(10)
vector2: vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_empty_query(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
val = df.id.iloc[0]
assert val == 1