Improve pydantic integration (#384)

This commit is contained in:
Chang She
2023-07-31 12:16:44 -04:00
committed by GitHub
parent 2d25c263e9
commit cada35d5b7
9 changed files with 108 additions and 7 deletions

View File

@@ -17,11 +17,13 @@ import numpy as np
import pandas as pd
import pyarrow as pa
from .pydantic import LanceModel
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
# TODO support generator
DATA = Union[List[dict], dict, pd.DataFrame]
DATA = Union[List[dict], List[LanceModel], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector"

View File

@@ -279,7 +279,7 @@ class LanceDBConnection(DBConnection):
def create_table(
self,
name: str,
data: Optional[Union[List[dict], dict, pd.DataFrame]] = None,
data: Optional[DATA] = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",

View File

@@ -249,3 +249,36 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
"""
fields = _pydantic_model_to_fields(model)
return pa.schema(fields)
class LanceModel(pydantic.BaseModel):
"""
A Pydantic Model base class that can be converted to a LanceDB Table.
Examples
--------
>>> import lancedb
>>> from lancedb.pydantic import LanceModel, vector
>>>
>>> class TestModel(LanceModel):
... name: str
... vector: vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
>>> table.add([
... TestModel(name="test", vector=[1.0, 2.0])
... ])
>>> table.search([0., 0.]).limit(1).to_pydantic(TestModel)
[TestModel(name='test', vector=FixedSizeList(dim=2))]
"""
@classmethod
def to_arrow_schema(cls):
return pydantic_to_schema(cls)
@classmethod
def field_names(cls) -> List[str]:
if PYDANTIC_VERSION.major < 2:
return list(cls.__fields__.keys())
return list(cls.model_fields.keys())

View File

@@ -13,17 +13,18 @@
from __future__ import annotations
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Type, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel
import pydantic
from .common import VECTOR_COLUMN_NAME
from .pydantic import LanceModel
class Query(BaseModel):
class Query(pydantic.BaseModel):
"""A Query"""
vector_column: str = VECTOR_COLUMN_NAME
@@ -230,6 +231,23 @@ class LanceQueryBuilder:
)
return self._table._execute_query(query)
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
class LanceFtsQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pd.Table:

View File

@@ -27,12 +27,17 @@ from lance import LanceDataset
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
from .util import fs_from_uri
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
data = [dict(d) for d in data]
data = pa.Table.from_pylist(data)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value

View File

@@ -20,7 +20,7 @@ import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, pydantic_to_schema, vector
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
@pytest.mark.skipif(
@@ -163,3 +163,13 @@ def test_fixed_size_list_validation():
TestModel(vec=range(7))
TestModel(vec=range(8))
def test_lance_model():
class TestModel(LanceModel):
vec: vector(16)
li: List[int]
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vec", "li"]

View File

@@ -20,6 +20,7 @@ import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.query import LanceQueryBuilder, Query
from lancedb.table import LanceTable
@@ -64,6 +65,24 @@ def table(tmp_path) -> MockTable:
return MockTable(tmp_path)
def test_cast(table):
class TestModel(LanceModel):
vector: vector(2)
id: int
str_field: str
float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel)
assert len(results) == 1
r0 = results[0]
assert isinstance(r0, TestModel)
assert r0.id == 1
assert r0.vector == [1, 2]
assert r0.str_field == "a"
assert r0.float_field == 1.0
def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
assert df["id"].values[0] == 1

View File

@@ -13,15 +13,16 @@
import functools
from pathlib import Path
from typing import List
from unittest.mock import PropertyMock, patch
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lance.vector import vec_to_table
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.table import LanceTable
@@ -135,6 +136,17 @@ def test_add(db):
_add(table, schema)
def test_add_pydantic_model(db):
class TestModel(LanceModel):
vector: vector(16)
li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
table = LanceTable.create(db, "test", data=[data])
assert len(table) == 1
assert table.schema == TestModel.to_arrow_schema()
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2