Automatically convert pydantic model (#400)

Saves users from having to call `LanceModel.to_arrow_schema()` explicitly
when creating an empty table: a `LanceModel` subclass passed as `schema`
is now converted to an Arrow schema automatically.
See the new docs for full details.

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-08-06 14:50:03 -07:00
committed by GitHub
parent 8f7264f81d
commit a54d1e5618
4 changed files with 51 additions and 14 deletions

View File

@@ -22,6 +22,7 @@ import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
@@ -39,7 +40,7 @@ class DBConnection(ABC):
self,
name: str,
data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None,
schema: Optional[pa.Schema, LanceModel] = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
@@ -52,7 +53,7 @@ class DBConnection(ABC):
The name of the table.
data: list, tuple, dict, pd.DataFrame; optional
The data to initialize the table. User must provide at least one of `data` or `schema`.
schema: pyarrow.Schema; optional
schema: pyarrow.Schema or LanceModel; optional
The schema of the table.
mode: str; default "create"
The mode to use when creating the table. Can be either "create" or "overwrite".
@@ -277,7 +278,7 @@ class LanceDBConnection(DBConnection):
self,
name: str,
data: Optional[DATA] = None,
schema: pa.Schema = None,
schema: Optional[pa.Schema, LanceModel] = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,

View File

@@ -13,6 +13,7 @@
from __future__ import annotations
import inspect
import os
from abc import ABC, abstractmethod
from functools import cached_property
@@ -506,7 +507,7 @@ class LanceTable(Table):
data: list-of-dict, dict, pd.DataFrame, default None
The data to insert into the table.
At least one of `data` or `schema` must be provided.
schema: dict, optional
schema: pa.Schema or LanceModel, optional
The schema of the table. If not provided, the schema is inferred from the data.
At least one of `data` or `schema` must be provided.
mode: str, default "create"
@@ -519,6 +520,8 @@ class LanceTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
schema = schema.to_arrow_schema()
if data is not None:
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value