mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-13 15:22:57 +00:00
Improve pydantic integration (#384)
This commit is contained in:
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, pydantic_to_schema, vector
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -163,3 +163,13 @@ def test_fixed_size_list_validation():
|
||||
TestModel(vec=range(7))
|
||||
|
||||
TestModel(vec=range(8))
|
||||
|
||||
|
||||
def test_lance_model():
|
||||
class TestModel(LanceModel):
|
||||
vec: vector(16)
|
||||
li: List[int]
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["vec", "li"]
|
||||
|
||||
@@ -20,6 +20,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.query import LanceQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -64,6 +65,24 @@ def table(tmp_path) -> MockTable:
|
||||
return MockTable(tmp_path)
|
||||
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(2)
|
||||
id: int
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
results = q.to_pydantic(TestModel)
|
||||
assert len(results) == 1
|
||||
r0 = results[0]
|
||||
assert isinstance(r0, TestModel)
|
||||
assert r0.id == 1
|
||||
assert r0.vector == [1, 2]
|
||||
assert r0.str_field == "a"
|
||||
assert r0.float_field == 1.0
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
assert df["id"].values[0] == 1
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -135,6 +136,17 @@ def test_add(db):
|
||||
_add(table, schema)
|
||||
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
table = LanceTable.create(db, "test", data=[data])
|
||||
assert len(table) == 1
|
||||
assert table.schema == TestModel.to_arrow_schema()
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
assert len(table) == 2
|
||||
|
||||
Reference in New Issue
Block a user