diff --git a/docs/src/python/pydantic.md b/docs/src/python/pydantic.md index 92beaf76..7e44724d 100644 --- a/docs/src/python/pydantic.md +++ b/docs/src/python/pydantic.md @@ -1,6 +1,8 @@ # Pydantic [Pydantic](https://docs.pydantic.dev/latest/) is a data validation library in Python. +LanceDB integrates with Pydantic for schema inference, data ingestion, and query result casting. + ## Schema diff --git a/python/lancedb/common.py b/python/lancedb/common.py index 47d0bc43..f50451a5 100644 --- a/python/lancedb/common.py +++ b/python/lancedb/common.py @@ -17,11 +17,13 @@ import numpy as np import pandas as pd import pyarrow as pa +from .pydantic import LanceModel + VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] URI = Union[str, Path] # TODO support generator -DATA = Union[List[dict], dict, pd.DataFrame] +DATA = Union[List[dict], List[LanceModel], dict, pd.DataFrame] VECTOR_COLUMN_NAME = "vector" diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 85564d3a..548a1c2d 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -279,7 +279,7 @@ class LanceDBConnection(DBConnection): def create_table( self, name: str, - data: Optional[Union[List[dict], dict, pd.DataFrame]] = None, + data: Optional[DATA] = None, schema: pa.Schema = None, mode: str = "create", on_bad_vectors: str = "error", diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index ed55e677..0b6e1fbd 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -249,3 +249,36 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema: """ fields = _pydantic_model_to_fields(model) return pa.schema(fields) + + +class LanceModel(pydantic.BaseModel): + """ + A Pydantic Model base class that can be converted to a LanceDB Table. + + Examples + -------- + >>> import lancedb + >>> from lancedb.pydantic import LanceModel, vector + >>> + >>> class TestModel(LanceModel): + ... name: str + ... vector: vector(2) + ... + >>> db = lancedb.connect("/tmp") + >>> table = db.create_table("test", schema=TestModel.to_arrow_schema()) + >>> table.add([ + ... TestModel(name="test", vector=[1.0, 2.0]) + ... ]) + >>> table.search([0., 0.]).limit(1).to_pydantic(TestModel) + [TestModel(name='test', vector=FixedSizeList(dim=2))] + """ + + @classmethod + def to_arrow_schema(cls): + return pydantic_to_schema(cls) + + @classmethod + def field_names(cls) -> List[str]: + if PYDANTIC_VERSION.major < 2: + return list(cls.__fields__.keys()) + return list(cls.model_fields.keys()) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index a96f6682..d2fd5afd 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -13,17 +13,18 @@ from __future__ import annotations -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Type, Union import numpy as np import pandas as pd import pyarrow as pa -from pydantic import BaseModel +import pydantic from .common import VECTOR_COLUMN_NAME +from .pydantic import LanceModel -class Query(BaseModel): +class Query(pydantic.BaseModel): """A Query""" vector_column: str = VECTOR_COLUMN_NAME @@ -230,6 +231,23 @@ class LanceQueryBuilder: ) return self._table._execute_query(query) + def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]: + """Return the table as a list of pydantic models. + + Parameters + ---------- + model: Type[LanceModel] + The pydantic model to use. + + Returns + ------- + List[LanceModel] + """ + return [ + model(**{k: v for k, v in row.items() if k in model.field_names()}) + for row in self.to_arrow().to_pylist() + ] + class LanceFtsQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pd.Table: diff --git a/python/lancedb/table.py b/python/lancedb/table.py index bb6e2854..5be962b9 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -27,12 +27,17 @@ from lance import LanceDataset from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME +from .pydantic import LanceModel from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query from .util import fs_from_uri def _sanitize_data(data, schema, on_bad_vectors, fill_value): if isinstance(data, list): + # convert to list of dict if data is a bunch of LanceModels + if isinstance(data[0], LanceModel): + schema = data[0].__class__.to_arrow_schema() + data = [dict(d) for d in data] data = pa.Table.from_pylist(data) data = _sanitize_schema( data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index c38262db..2e33dc22 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -20,7 +20,7 @@ import pyarrow as pa import pydantic import pytest -from lancedb.pydantic import PYDANTIC_VERSION, pydantic_to_schema, vector +from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector @pytest.mark.skipif( @@ -163,3 +163,13 @@ def test_fixed_size_list_validation(): TestModel(vec=range(7)) TestModel(vec=range(8)) + + +def test_lance_model(): + class TestModel(LanceModel): + vec: vector(16) + li: List[int] + + schema = pydantic_to_schema(TestModel) + assert schema == TestModel.to_arrow_schema() + assert TestModel.field_names() == ["vec", "li"] diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 8e4678e1..21646111 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -20,6 +20,7 @@ import pyarrow as pa import pytest from lancedb.db import LanceDBConnection +from lancedb.pydantic import LanceModel, vector from lancedb.query import LanceQueryBuilder, Query from lancedb.table import LanceTable @@ -64,6 +65,24 @@ def table(tmp_path) -> MockTable: return MockTable(tmp_path) +def test_cast(table): + class TestModel(LanceModel): + vector: vector(2) + id: int + str_field: str + float_field: float + + q = LanceQueryBuilder(table, [0, 0], "vector").limit(1) + results = q.to_pydantic(TestModel) + assert len(results) == 1 + r0 = results[0] + assert isinstance(r0, TestModel) + assert r0.id == 1 + assert r0.vector == [1, 2] + assert r0.str_field == "a" + assert r0.float_field == 1.0 + + def test_query_builder(table): df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df() assert df["id"].values[0] == 1 diff --git a/python/tests/test_table.py b/python/tests/test_table.py index f7239590..8b892f75 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -13,15 +13,16 @@ import functools from pathlib import Path +from typing import List from unittest.mock import PropertyMock, patch import numpy as np import pandas as pd import pyarrow as pa import pytest -from lance.vector import vec_to_table from lancedb.db import LanceDBConnection +from lancedb.pydantic import LanceModel, vector from lancedb.table import LanceTable @@ -135,6 +136,17 @@ def test_add(db): _add(table, schema) +def test_add_pydantic_model(db): + class TestModel(LanceModel): + vector: vector(16) + li: List[int] + + data = TestModel(vector=list(range(16)), li=[1, 2, 3]) + table = LanceTable.create(db, "test", data=[data]) + assert len(table) == 1 + assert table.schema == TestModel.to_arrow_schema() + + def _add(table, schema): # table = LanceTable(db, "test") assert len(table) == 2