From bd0034a157ba6acd78eab72580509280a7eea29a Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:20:45 -0800 Subject: [PATCH] feat: support nested pydantic schema (#707) --- .github/workflows/python.yml | 6 +----- python/lancedb/pydantic.py | 17 +++++++++++++++++ python/lancedb/table.py | 8 +++++--- python/tests/test_table.py | 33 ++++++++++++++++++++++++++------- 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 744668c5..10892841 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -91,11 +91,7 @@ jobs: pip install "pydantic<2" pip install -e .[tests] pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 - pip install pytest pytest-mock black isort - - name: Black - run: black --check --diff --no-color --quiet . - - name: isort - run: isort --check --diff --quiet . + pip install pytest pytest-mock - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 4ca93903..caa69405 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -348,3 +348,20 @@ def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any: if PYDANTIC_VERSION.major >= 2: return (field_info.json_schema_extra or {}).get(key) return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key) + + +if PYDANTIC_VERSION.major < 2: + + def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: + """ + Convert a Pydantic model to a dictionary. + """ + return model.dict() + +else: + + def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: + """ + Convert a Pydantic model to a dictionary. + """ + return model.model_dump() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 60aac438..95e83a24 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -28,7 +28,7 @@ from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry -from .pydantic import LanceModel +from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query from .util import fs_from_uri, safe_import_pandas, value_to_sql from .utils.events import register_event @@ -53,8 +53,10 @@ def _sanitize_data( # convert to list of dict if data is a bunch of LanceModels if isinstance(data[0], LanceModel): schema = data[0].__class__.to_arrow_schema() - data = [dict(d) for d in data] - data = pa.Table.from_pylist(data) + data = [model_to_dict(d) for d in data] + data = pa.Table.from_pylist(data, schema=schema) + else: + data = pa.Table.from_pylist(data) elif isinstance(data, dict): data = vec_to_table(data) elif pd is not None and isinstance(data, pd.DataFrame): diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 9b12d42b..cbe641cd 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -21,6 +21,7 @@ import lance import numpy as np import pandas as pd import pyarrow as pa +from pydantic import BaseModel import pytest from lancedb.conftest import MockTextEmbeddingFunction @@ -141,14 +142,32 @@ def test_add(db): def test_add_pydantic_model(db): - class TestModel(LanceModel): - vector: Vector(16) - li: List[int] + # https://github.com/lancedb/lancedb/issues/562 - data = TestModel(vector=list(range(16)), li=[1, 2, 3]) - table = LanceTable.create(db, "test", data=[data]) - assert len(table) == 1 - assert table.schema == TestModel.to_arrow_schema() + class Document(BaseModel): + content: str + source: str + + class LanceSchema(LanceModel): + id: str + vector: Vector(2) + li: List[int] + payload: Document + + tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite") + assert tbl.schema == LanceSchema.to_arrow_schema() + + # add works + expected = LanceSchema( + id="id", + vector=[0.0, 0.0], + li=[1, 2, 3], + payload=Document(content="foo", source="bar"), + ) + tbl.add([expected]) + + result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0] + assert result == expected def _add(table, schema):