[python] Use pydantic for embedding function persistence (#467)

1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Chang She
2023-09-05 21:30:45 -07:00
committed by GitHub
parent 52fa7f5577
commit 9a9a73a65d
13 changed files with 815 additions and 192 deletions

View File

@@ -21,7 +21,7 @@ import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.query import LanceQueryBuilder, Query
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -72,7 +72,7 @@ def test_cast(table):
str_field: str
float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel)
assert len(results) == 1
r0 = results[0]
@@ -84,13 +84,15 @@ def test_cast(table):
def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
)
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
)
tm.assert_frame_equal(df_default, df_l2)
df_cosine = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
query = [4, 8]
vector_column_name = "foo_vector"
builder = (
LanceQueryBuilder(table, query, vector_column_name)
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.where("b < 10")
.select(["b"])