# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Tests for converting pydantic models into Arrow schemas via ``lancedb.pydantic``."""

import json
from datetime import date, datetime
from enum import Enum
from typing import List, Optional, Tuple

import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import (
    PYDANTIC_VERSION,
    LanceModel,
    Vector,
    pydantic_to_schema,
    MultiVector,
)
from pydantic import BaseModel
from pydantic import Field


def test_pydantic_to_arrow():
    """A model using builtin generics (``list[...]``) converts to the expected schema."""

    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: list[float]
        li: list[int]
        lili: list[list[float]]
        litu: list[tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model into data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
            pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
            pa.field("opt", pa.utf8(), True),
            pa.field(
                "st",
                pa.struct(
                    [pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
                ),
                False,
            ),
            pa.field("dt", pa.date32(), False),
            pa.field("dtt", pa.timestamp("us"), False),
            pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
        ]
    )
    assert schema == expect_schema


def test_optional_types_py310():
    """``X | None``, ``None | X`` and ``Optional[X]`` all produce nullable fields."""

    class TestModel(pydantic.BaseModel):
        a: str | None
        b: None | str
        c: Optional[str]

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("a", pa.utf8(), True),
            pa.field("b", pa.utf8(), True),
            pa.field("c", pa.utf8(), True),
        ]
    )
    assert schema == expect_schema


def test_optional_structs():
    """An optional nested model becomes a nullable struct field."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        split: SplitInfo | None = None

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "split",
                pa.struct(
                    [
                        pa.field("start_frame", pa.int64(), False),
                        pa.field("end_frame", pa.int64(), False),
                    ]
                ),
                True,
            ),
        ]
    )
    assert schema == expect_schema


def test_optional_struct_list_py310():
    """``list[Model] | None`` becomes a nullable list-of-struct field."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: list[SplitInfo] | None = None

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.struct(
                        [
                            pa.field("start_frame", pa.int64(), False),
                            pa.field("end_frame", pa.int64(), False),
                        ]
                    )
                ),
                True,
            ),
        ]
    )
    assert schema == expect_schema


def test_nested_struct_list():
    """A required ``list[Model]`` becomes a non-nullable list-of-struct field."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: list[SplitInfo]

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.struct(
                        [
                            pa.field("start_frame", pa.int64(), False),
                            pa.field("end_frame", pa.int64(), False),
                        ]
                    )
                ),
                False,
            ),
        ]
    )
    assert schema == expect_schema


def test_nested_struct_list_optional():
    """``Optional[list[Model]]`` becomes a nullable list-of-struct field."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: Optional[list[SplitInfo]] = None

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.struct(
                        [
                            pa.field("start_frame", pa.int64(), False),
                            pa.field("end_frame", pa.int64(), False),
                        ]
                    )
                ),
                True,
            ),
        ]
    )
    assert schema == expect_schema


def test_nested_struct_list_optional_items():
    """``list[Optional[Model]]`` keeps the list required but its items nullable."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: list[Optional[SplitInfo]]

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.field(
                        "item",
                        pa.struct(
                            [
                                pa.field("start_frame", pa.int64(), False),
                                pa.field("end_frame", pa.int64(), False),
                            ]
                        ),
                        True,
                    )
                ),
                False,
            ),
        ]
    )
    assert schema == expect_schema


def test_nested_struct_list_optional_container_and_items():
    """``Optional[list[Optional[Model]]]`` makes both the list and its items nullable."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: Optional[list[Optional[SplitInfo]]] = None

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.field(
                        "item",
                        pa.struct(
                            [
                                pa.field("start_frame", pa.int64(), False),
                                pa.field("end_frame", pa.int64(), False),
                            ]
                        ),
                        True,
                    )
                ),
                True,
            ),
        ]
    )
    assert schema == expect_schema


def test_nested_struct_list_optional_items_pep604():
    """``list[Model | None]`` keeps the list required but its items nullable."""

    class SplitInfo(pydantic.BaseModel):
        start_frame: int
        end_frame: int

    class TestModel(pydantic.BaseModel):
        id: str
        splits: list[SplitInfo | None]

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.utf8(), False),
            pa.field(
                "splits",
                pa.list_(
                    pa.field(
                        "item",
                        pa.struct(
                            [
                                pa.field("start_frame", pa.int64(), False),
                                pa.field("end_frame", pa.int64(), False),
                            ]
                        ),
                        True,
                    )
                ),
                False,
            ),
        ]
    )
    assert schema == expect_schema


def test_pydantic_to_arrow_py38():
    """A model using ``typing`` generics (``List[...]``) converts to the expected schema."""

    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: List[float]
        li: List[int]
        lili: List[List[float]]
        litu: List[Tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model to Arrow data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
            pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
            pa.field("opt", pa.utf8(), True),
            pa.field(
                "st",
                pa.struct(
                    [pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
                ),
                False,
            ),
            pa.field("dt", pa.date32(), False),
            pa.field("dtt", pa.timestamp("us"), False),
            pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
        ]
    )
    assert schema == expect_schema


def test_nullable_vector():
    """``Vector`` nullability follows its ``nullable`` argument (nullable when omitted)."""

    class NotNullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=False)

    schema = pydantic_to_schema(NotNullableModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), False)])

    class DefaultModel(pydantic.BaseModel):
        vec: Vector(16)

    schema = pydantic_to_schema(DefaultModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])

    # Explicit nullable=True must match the default rather than merely repeat it.
    class NullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=True)

    schema = pydantic_to_schema(NullableModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])


def test_fixed_size_list_field():
    """A ``Vector`` field round-trips values and reports fixed-size Arrow/JSON schemas."""

    class TestModel(pydantic.BaseModel):
        vec: Vector(16)
        li: List[int]

    data = TestModel(vec=list(range(16)), li=[1, 2, 3])
    if PYDANTIC_VERSION.major >= 2:
        assert json.loads(data.model_dump_json()) == {
            "vec": list(range(16)),
            "li": [1, 2, 3],
        }
    else:
        assert data.dict() == {
            "vec": list(range(16)),
            "li": [1, 2, 3],
        }

    schema = pydantic_to_schema(TestModel)
    assert schema == pa.schema(
        [
            pa.field("vec", pa.list_(pa.float32(), 16)),
            pa.field("li", pa.list_(pa.int64()), False),
        ]
    )

    if PYDANTIC_VERSION.major >= 2:
        json_schema = TestModel.model_json_schema()
    else:
        json_schema = TestModel.schema()

    assert json_schema == {
        "properties": {
            "vec": {
                "items": {"type": "number"},
                "maxItems": 16,
                "minItems": 16,
                "title": "Vec",
                "type": "array",
            },
            "li": {"items": {"type": "integer"}, "title": "Li", "type": "array"},
        },
        "required": ["vec", "li"],
        "title": "TestModel",
        "type": "object",
    }


def test_fixed_size_list_validation():
    """A ``Vector(8)`` field rejects inputs that are longer or shorter than 8."""

    class TestModel(pydantic.BaseModel):
        vec: Vector(8)

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=range(9))

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=range(7))

    TestModel(vec=range(8))


def test_lance_model():
    """``LanceModel`` exposes its Arrow schema and field names, and honours defaults."""

    class TestModel(LanceModel):
        vector: Vector(16) = Field(default=[0.0] * 16)
        li: List[int] = Field(default=[1, 2, 3])

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["vector", "li"]

    t = TestModel()
    # Use the real field name: an unknown keyword such as ``vec`` would be
    # dropped silently and the comparison would only exercise the default.
    assert t == TestModel(vector=[0.0] * 16, li=[1, 2, 3])


def test_optional_nested_model():
    """Optional nested models inside an inherited ``LanceModel`` become nullable structs."""

    class WAMedia(BaseModel):
        url: str
        mimetype: str
        filename: Optional[str]
        error: Optional[str]
        data: bytes

    class WALocation(BaseModel):
        description: Optional[str]
        latitude: str
        longitude: str

    class ReplyToMessage(BaseModel):
        id: str
        participant: str
        body: str

    class Message(BaseModel):
        id: str
        timestamp: int
        from_: str
        fromMe: bool
        to: str
        body: str
        hasMedia: Optional[bool]
        media: WAMedia
        mediaUrl: Optional[str]
        ack: Optional[int]
        ackName: Optional[str]
        author: Optional[str]
        location: Optional[WALocation]
        vCards: Optional[List[str]]
        replyTo: Optional[ReplyToMessage]

    class AnyEvent(LanceModel):
        id: str
        session: str
        metadata: Optional[str] = None
        engine: str
        event: str

    class MessageEvent(AnyEvent):
        payload: Message

    schema = pydantic_to_schema(MessageEvent)

    payload = schema.field("payload")
    assert payload.type == pa.struct(
        [
            pa.field("id", pa.utf8(), False),
            pa.field("timestamp", pa.int64(), False),
            pa.field("from_", pa.utf8(), False),
            pa.field("fromMe", pa.bool_(), False),
            pa.field("to", pa.utf8(), False),
            pa.field("body", pa.utf8(), False),
            pa.field("hasMedia", pa.bool_(), True),
            pa.field(
                "media",
                pa.struct(
                    [
                        pa.field("url", pa.utf8(), False),
                        pa.field("mimetype", pa.utf8(), False),
                        pa.field("filename", pa.utf8(), True),
                        pa.field("error", pa.utf8(), True),
                        pa.field("data", pa.binary(), False),
                    ]
                ),
                False,
            ),
            pa.field("mediaUrl", pa.utf8(), True),
            pa.field("ack", pa.int64(), True),
            pa.field("ackName", pa.utf8(), True),
            pa.field("author", pa.utf8(), True),
            pa.field(
                "location",
                pa.struct(
                    [
                        pa.field("description", pa.utf8(), True),
                        pa.field("latitude", pa.utf8(), False),
                        pa.field("longitude", pa.utf8(), False),
                    ]
                ),
                True,  # Optional
            ),
            pa.field("vCards", pa.list_(pa.utf8()), True),
            pa.field(
                "replyTo",
                pa.struct(
                    [
                        pa.field("id", pa.utf8(), False),
                        pa.field("participant", pa.utf8(), False),
                        pa.field("body", pa.utf8(), False),
                    ]
                ),
                True,
            ),
        ]
    )


def test_multi_vector():
    """``MultiVector(8)`` maps to list-of-fixed-size-list and validates inner lengths."""

    class TestModel(pydantic.BaseModel):
        vec: MultiVector(8)

    schema = pydantic_to_schema(TestModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
    )

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=[[1.0] * 7])

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=[[1.0] * 9])

    TestModel(vec=[[1.0] * 8])
    TestModel(vec=[[1.0] * 8, [2.0] * 8])

    TestModel(vec=[])


def test_multi_vector_nullable():
    """``MultiVector`` nullability follows its ``nullable`` argument (nullable when omitted)."""

    class NullableModel(pydantic.BaseModel):
        vec: MultiVector(16, nullable=False)

    schema = pydantic_to_schema(NullableModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
    )

    class DefaultModel(pydantic.BaseModel):
        vec: MultiVector(16)

    schema = pydantic_to_schema(DefaultModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
    )


def test_multi_vector_in_lance_model():
    """A ``MultiVector`` field with a default works inside a ``LanceModel``."""

    class TestModel(LanceModel):
        id: int
        vectors: MultiVector(16) = Field(default=[[0.0] * 16])

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["id", "vectors"]

    t = TestModel(id=1)
    assert t.vectors == [[0.0] * 16]


def test_aliases_in_lance_model(mem_db):
    """Aliased fields are populated when search results are converted with ``to_pydantic``."""
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 6.5], "item": "bar", "price": 20.0},
    ]
    tbl = mem_db.create_table("items", data=data)

    class TestModel(LanceModel):
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    model = (
        tbl.search([5.9, 6.5])
        .distance_type("cosine")
        .limit(1)
        .to_pydantic(TestModel)[0]
    )
    assert hasattr(model, "name")
    assert hasattr(model, "distance")
    assert model.distance < 0.01


@pytest.mark.asyncio
async def test_aliases_in_lance_model_async(mem_db_async):
    """Async variant: aliased fields are populated by ``to_pydantic`` on a vector search."""
    data = [
        {"vector": [8.3, 2.5], "item": "foo", "price": 12.0},
        {"vector": [7.7, 3.9], "item": "bar", "price": 11.2},
    ]
    tbl = await mem_db_async.create_table("items", data=data)

    class TestModel(LanceModel):
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    model = (
        await tbl.vector_search([7.7, 3.9])
        .distance_type("cosine")
        .limit(1)
        .to_pydantic(TestModel)
    )[0]
    assert hasattr(model, "name")
    assert hasattr(model, "distance")
    assert model.distance < 0.01


def test_enum_types():
    """Enum fields convert to Arrow types (issue #1846).

    As asserted below, ``str`` enums are expected to become dictionary-encoded
    utf8 (int32 indices), ``int`` enums become int64, and an ``Optional`` enum
    field is nullable.
    """

    class StrStatus(str, Enum):
        PENDING = "pending"
        RUNNING = "running"
        DONE = "done"

    class IntPriority(int, Enum):
        LOW = 1
        MEDIUM = 2
        HIGH = 3

    class TestModel(pydantic.BaseModel):
        status: StrStatus
        priority: IntPriority
        opt_status: Optional[StrStatus] = None

    schema = pydantic_to_schema(TestModel)

    assert schema.field("status").type == pa.dictionary(pa.int32(), pa.utf8())
    assert schema.field("priority").type == pa.int64()
    assert schema.field("opt_status").type == pa.dictionary(pa.int32(), pa.utf8())
    assert schema.field("opt_status").nullable