mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 13:52:58 +00:00
[python] Use pydantic for embedding function persistence (#467)
1. Support persistent embedding function so users can just search using query string 2. Add fixed size list conversion for multiple vector columns 3. Add support for empty query (just apply select/where/limit). 4. Refactor and simplify some of the data prep code --------- Co-authored-by: Chang She <chang@lancedb.com> Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -144,7 +144,7 @@ def test_ingest_iterator(tmp_path):
|
||||
tbl_len = len(tbl)
|
||||
tbl.add(make_batches())
|
||||
assert len(tbl) == tbl_len * 2
|
||||
assert len(tbl.list_versions()) == 2
|
||||
assert len(tbl.list_versions()) == 3
|
||||
db.drop_database()
|
||||
|
||||
run_tests(arrow_schema)
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.embeddings import with_embeddings
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -40,3 +42,37 @@ def test_with_embeddings():
|
||||
assert data.column_names == ["text", "price", "vector"]
|
||||
assert data.column("text").to_pylist() == ["foo", "bar"]
|
||||
assert data.column("price").to_pylist() == [10.0, 20.0]
|
||||
|
||||
|
||||
def test_embedding_function(tmp_path):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
# let's create a table
|
||||
table = pa.table(
|
||||
{
|
||||
"text": pa.array(["hello world", "goodbye world"]),
|
||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||
}
|
||||
)
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
metadata = registry.get_table_metadata([func])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
|
||||
# Write it to disk
|
||||
lance.write_dataset(table, tmp_path / "test.lance")
|
||||
|
||||
# Load this back
|
||||
ds = lance.dataset(tmp_path / "test.lance")
|
||||
|
||||
# can we get the serialized version back out?
|
||||
functions = registry.parse_functions(ds.schema.metadata)
|
||||
|
||||
func = functions["vector"]
|
||||
actual = func("hello world")
|
||||
|
||||
# We create an instance
|
||||
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
# And we make sure we can call it
|
||||
expected = expected_func("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
@@ -21,7 +21,7 @@ import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.query import LanceQueryBuilder, Query
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def test_cast(table):
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
results = q.to_pydantic(TestModel)
|
||||
assert len(results) == 1
|
||||
r0 = results[0]
|
||||
@@ -84,13 +84,15 @@ def test_cast(table):
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
)
|
||||
assert df["id"].values[0] == 1
|
||||
assert all(df["vector"].values[0] == [1, 2])
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
vector_column_name = "vector"
|
||||
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
)
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
|
||||
query = [4, 8]
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
|
||||
@@ -22,6 +22,7 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.table import LanceTable
|
||||
@@ -178,16 +179,16 @@ def test_versioning(db):
|
||||
],
|
||||
)
|
||||
|
||||
assert len(table.list_versions()) == 1
|
||||
assert table.version == 1
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 3
|
||||
|
||||
table.checkout(1)
|
||||
assert table.version == 1
|
||||
table.checkout(2)
|
||||
assert table.version == 2
|
||||
assert len(table) == 2
|
||||
|
||||
|
||||
@@ -278,21 +279,21 @@ def test_restore(db):
|
||||
data=[{"vector": [1.1, 0.9], "type": "vector"}],
|
||||
)
|
||||
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
table.restore(1)
|
||||
assert len(table.list_versions()) == 3
|
||||
table.restore(2)
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table) == 1
|
||||
|
||||
expected = table.to_arrow()
|
||||
table.checkout(1)
|
||||
table.checkout(2)
|
||||
table.restore()
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table.list_versions()) == 5
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
table.restore(4) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 4
|
||||
table.restore(5) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 5
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(5)
|
||||
table.restore(6)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(0)
|
||||
@@ -306,7 +307,7 @@ def test_merge(db, tmp_path):
|
||||
)
|
||||
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
|
||||
table.merge(other_table, left_on="id")
|
||||
assert len(table.list_versions()) == 2
|
||||
assert len(table.list_versions()) == 3
|
||||
expected = pa.table(
|
||||
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
|
||||
schema=table.schema,
|
||||
@@ -325,10 +326,10 @@ def test_delete(db):
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
table.delete("id=0")
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
table.delete("id=0")
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 1
|
||||
assert table.to_pandas()["id"].tolist() == [1]
|
||||
|
||||
@@ -340,11 +341,103 @@ def test_update(db):
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
assert len(table.list_versions()) == 2
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table.list_versions()) == 4
|
||||
assert table.version == 4
|
||||
assert len(table) == 2
|
||||
v = table.to_arrow()["vector"].combine_chunks()
|
||||
v = v.values.to_numpy().reshape(2, 2)
|
||||
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
table.add(df)
|
||||
|
||||
texts = ["the quick brown fox", "jumped over the lazy dog"]
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: vector(10)
|
||||
vector2: vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
|
||||
|
||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_empty_query(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
|
||||
val = df.id.iloc[0]
|
||||
assert val == 1
|
||||
|
||||
Reference in New Issue
Block a user