feat: support nested pydantic schema (#707)

This commit is contained in:
Chang She
2023-12-14 18:20:45 -08:00
committed by Weston Pace
parent e52f691420
commit 374a6f7e78
4 changed files with 49 additions and 15 deletions

View File

@@ -91,11 +91,7 @@ jobs:
pip install "pydantic<2"
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock black isort
- name: Black
run: black --check --diff --no-color --quiet .
- name: isort
run: isort --check --diff --quiet .
pip install pytest pytest-mock
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest

View File

@@ -348,3 +348,20 @@ def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
if PYDANTIC_VERSION.major < 2:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.dict()
else:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()

View File

@@ -28,7 +28,7 @@ from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql
@@ -52,8 +52,10 @@ def _sanitize_data(
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
data = [dict(d) for d in data]
data = pa.Table.from_pylist(data)
data = [model_to_dict(d) for d in data]
data = pa.Table.from_pylist(data, schema=schema)
else:
data = pa.Table.from_pylist(data)
elif isinstance(data, dict):
data = vec_to_table(data)
elif pd is not None and isinstance(data, pd.DataFrame):

View File

@@ -21,6 +21,7 @@ import lance
import numpy as np
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
@@ -141,14 +142,32 @@ def test_add(db):
def test_add_pydantic_model(db):
class TestModel(LanceModel):
vector: Vector(16)
li: List[int]
# https://github.com/lancedb/lancedb/issues/562
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
table = LanceTable.create(db, "test", data=[data])
assert len(table) == 1
assert table.schema == TestModel.to_arrow_schema()
class Document(BaseModel):
content: str
source: str
class LanceSchema(LanceModel):
id: str
vector: Vector(2)
li: List[int]
payload: Document
tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
assert tbl.schema == LanceSchema.to_arrow_schema()
# add works
expected = LanceSchema(
id="id",
vector=[0.0, 0.0],
li=[1, 2, 3],
payload=Document(content="foo", source="bar"),
)
tbl.add([expected])
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
assert result == expected
def _add(table, schema):