mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
This pull request improves support for the `pydantic` integration by adding a `to_pydantic` method to asynchronous queries and by handling models that use `alias` in field definitions. Fixes #2436 and closes #2437.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Added support for converting asynchronous query results to Pydantic models.
- **Bug Fixes**
  - Simplified conversion of query results to Pydantic models for improved reliability.
  - Improved handling of field aliases and computed fields when mapping query results to Pydantic models.
- **Tests**
  - Added tests to verify correct mapping of aliased and computed fields in both synchronous and asynchronous scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
462 lines
13 KiB
Python
462 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
import json
|
|
import sys
|
|
from datetime import date, datetime
|
|
from typing import List, Optional, Tuple
|
|
|
|
import pyarrow as pa
|
|
import pydantic
|
|
import pytest
|
|
from lancedb.pydantic import (
|
|
PYDANTIC_VERSION,
|
|
LanceModel,
|
|
Vector,
|
|
pydantic_to_schema,
|
|
MultiVector,
|
|
)
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.version_info < (3, 9),
    reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
    """Convert a model annotated with built-in generics and check its Arrow schema.

    Skipped on interpreters older than Python 3.9, where annotations such as
    ``list[float]`` are unavailable.
    """

    # Nested model; the expected schema below represents it as a struct column.
    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: list[float]
        li: list[int]
        lili: list[list[float]]
        litu: list[tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model into data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    # Name the recurring Arrow types once so the field list reads as a table.
    float_list = pa.list_(pa.float64())
    nested_float_list = pa.list_(pa.list_(pa.float64()))
    struct_type = pa.struct(
        [
            pa.field("a", pa.utf8(), nullable=False),
            pa.field("b", pa.float64(), nullable=True),
        ]
    )

    expected = pa.schema(
        [
            pa.field("id", pa.int64(), nullable=False),
            pa.field("s", pa.utf8(), nullable=False),
            pa.field("vec", float_list, nullable=False),
            pa.field("li", pa.list_(pa.int64()), nullable=False),
            pa.field("lili", nested_float_list, nullable=False),
            # Tuples of floats are expected as a nested float list as well.
            pa.field("litu", nested_float_list, nullable=False),
            pa.field("opt", pa.utf8(), nullable=True),
            pa.field("st", struct_type, nullable=False),
            pa.field("dt", pa.date32(), nullable=False),
            pa.field("dtt", pa.timestamp("us"), nullable=False),
            pa.field(
                "dt_with_tz",
                pa.timestamp("us", tz="Asia/Shanghai"),
                nullable=False,
            ),
        ]
    )

    assert pydantic_to_schema(TestModel) == expected
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
    """Check that every spelling of an optional string yields a nullable column."""

    class TestModel(pydantic.BaseModel):
        a: str | None
        b: None | str
        c: Optional[str]

    # All three fields are expected to map to the same nullable utf8 column.
    expected = pa.schema(
        [pa.field(column, pa.utf8(), nullable=True) for column in ("a", "b", "c")]
    )

    assert pydantic_to_schema(TestModel) == expected
|
|
|
|
|
|
# NOTE: ``sys.version_info`` is a 5-tuple, so ``sys.version_info > (3, 8)`` is
# true even on 3.8.x (``(3, 8, 0, ...) > (3, 8)``) and would skip this test on
# every interpreter. Comparing against ``(3, 9)`` keeps it running on 3.8.
@pytest.mark.skipif(
    sys.version_info >= (3, 9),
    reason="covers the typing.List/Tuple annotations used on python3.8",
)
def test_pydantic_to_arrow_py38():
    """Convert a model annotated with ``typing`` generics and check its Arrow schema.

    Counterpart of the built-in-generics test: the same fields are declared
    with ``typing.List`` / ``typing.Tuple`` and the same Arrow schema is
    expected. Only runs on Python 3.8.
    """

    # Nested model; the expected schema below represents it as a struct column.
    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: List[float]
        li: List[int]
        lili: List[List[float]]
        litu: List[Tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model to Arrow data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
            # Tuples of floats are expected as a nested float list as well.
            pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
            pa.field("opt", pa.utf8(), True),
            pa.field(
                "st",
                pa.struct(
                    [
                        pa.field("a", pa.utf8(), False),
                        pa.field("b", pa.float64(), True),
                    ]
                ),
                False,
            ),
            pa.field("dt", pa.date32(), False),
            pa.field("dtt", pa.timestamp("us"), False),
            pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
        ]
    )
    assert schema == expect_schema
|
|
|
|
|
|
def test_nullable_vector():
    """Check how the ``nullable`` argument of ``Vector`` reaches the Arrow schema.

    Three cases are covered: ``nullable=False``, the default (no argument),
    and an explicit ``nullable=True``. The last two are expected to produce
    the same nullable fixed-size-list column.
    """

    def expected_schema(nullable):
        # Single "vec" column: fixed-size list of 16 float32 values.
        return pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), nullable)])

    class NotNullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=False)

    assert pydantic_to_schema(NotNullableModel) == expected_schema(False)

    class DefaultModel(pydantic.BaseModel):
        vec: Vector(16)

    assert pydantic_to_schema(DefaultModel) == expected_schema(True)

    # Explicitly request nullability so this case is distinct from the default.
    class NullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=True)

    assert pydantic_to_schema(NullableModel) == expected_schema(True)
|
|
|
|
|
|
def test_fixed_size_list_field():
    """Exercise a ``Vector`` field: serialization, Arrow schema and JSON schema."""

    # NOTE: the class name ends up as the JSON schema "title" asserted below,
    # so it must stay "TestModel" and must not gain a docstring.
    class TestModel(pydantic.BaseModel):
        vec: Vector(16)
        li: List[int]

    is_pydantic_v2 = PYDANTIC_VERSION.major >= 2
    data = TestModel(vec=list(range(16)), li=[1, 2, 3])

    # Round-trip the instance through the serializer of the installed pydantic.
    if is_pydantic_v2:
        dumped = json.loads(data.model_dump_json())
    else:
        dumped = data.dict()
    assert dumped == {
        "vec": list(range(16)),
        "li": [1, 2, 3],
    }

    expected_arrow = pa.schema(
        [
            pa.field("vec", pa.list_(pa.float32(), 16)),
            pa.field("li", pa.list_(pa.int64()), False),
        ]
    )
    assert pydantic_to_schema(TestModel) == expected_arrow

    json_schema = (
        TestModel.model_json_schema() if is_pydantic_v2 else TestModel.schema()
    )

    vec_property = {
        "items": {"type": "number"},
        "maxItems": 16,
        "minItems": 16,
        "title": "Vec",
        "type": "array",
    }
    li_property = {"items": {"type": "integer"}, "title": "Li", "type": "array"}
    assert json_schema == {
        "properties": {"vec": vec_property, "li": li_property},
        "required": ["vec", "li"],
        "title": "TestModel",
        "type": "object",
    }
|
|
|
|
|
|
def test_fixed_size_list_validation():
    """Check that a ``Vector(8)`` field only accepts exactly eight values."""

    class TestModel(pydantic.BaseModel):
        vec: Vector(8)

    # One element too many, then one too few: both must be rejected.
    for wrong_length in (9, 7):
        with pytest.raises(pydantic.ValidationError):
            TestModel(vec=range(wrong_length))

    # The exact length validates without raising.
    TestModel(vec=range(8))
|
|
|
|
|
|
def test_lance_model():
    """Check schema helpers and field defaults on a ``LanceModel`` subclass."""

    class TestModel(LanceModel):
        vector: Vector(16) = Field(default=[0.0] * 16)
        li: List[int] = Field(default=[1, 2, 3])

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["vector", "li"]

    t = TestModel()
    # Use the real field name: an unknown keyword such as ``vec`` is ignored
    # under pydantic's default ``extra`` handling, which would reduce this to
    # comparing two all-default instances.
    assert t == TestModel(vector=[0.0] * 16, li=[1, 2, 3])
|
|
|
|
|
|
def test_optional_nested_model():
    """Check the Arrow struct produced for a model with optional nested models."""

    class WAMedia(BaseModel):
        url: str
        mimetype: str
        filename: Optional[str]
        error: Optional[str]
        data: bytes

    class WALocation(BaseModel):
        description: Optional[str]
        latitude: str
        longitude: str

    class ReplyToMessage(BaseModel):
        id: str
        participant: str
        body: str

    class Message(BaseModel):
        id: str
        timestamp: int
        from_: str
        fromMe: bool
        to: str
        body: str
        hasMedia: Optional[bool]
        media: WAMedia
        mediaUrl: Optional[str]
        ack: Optional[int]
        ackName: Optional[str]
        author: Optional[str]
        location: Optional[WALocation]
        vCards: Optional[List[str]]
        replyTo: Optional[ReplyToMessage]

    class AnyEvent(LanceModel):
        id: str
        session: str
        metadata: Optional[str] = None
        engine: str
        event: str

    class MessageEvent(AnyEvent):
        payload: Message

    # Expected struct types for each nested model, built bottom-up.
    media_struct = pa.struct(
        [
            pa.field("url", pa.utf8(), nullable=False),
            pa.field("mimetype", pa.utf8(), nullable=False),
            pa.field("filename", pa.utf8(), nullable=True),
            pa.field("error", pa.utf8(), nullable=True),
            pa.field("data", pa.binary(), nullable=False),
        ]
    )
    location_struct = pa.struct(
        [
            pa.field("description", pa.utf8(), nullable=True),
            pa.field("latitude", pa.utf8(), nullable=False),
            pa.field("longitude", pa.utf8(), nullable=False),
        ]
    )
    reply_struct = pa.struct(
        [
            pa.field("id", pa.utf8(), nullable=False),
            pa.field("participant", pa.utf8(), nullable=False),
            pa.field("body", pa.utf8(), nullable=False),
        ]
    )

    expected_payload_type = pa.struct(
        [
            pa.field("id", pa.utf8(), nullable=False),
            pa.field("timestamp", pa.int64(), nullable=False),
            pa.field("from_", pa.utf8(), nullable=False),
            pa.field("fromMe", pa.bool_(), nullable=False),
            pa.field("to", pa.utf8(), nullable=False),
            pa.field("body", pa.utf8(), nullable=False),
            pa.field("hasMedia", pa.bool_(), nullable=True),
            pa.field("media", media_struct, nullable=False),
            pa.field("mediaUrl", pa.utf8(), nullable=True),
            pa.field("ack", pa.int64(), nullable=True),
            pa.field("ackName", pa.utf8(), nullable=True),
            pa.field("author", pa.utf8(), nullable=True),
            # Optional nested models are expected as nullable structs.
            pa.field("location", location_struct, nullable=True),
            pa.field("vCards", pa.list_(pa.utf8()), nullable=True),
            pa.field("replyTo", reply_struct, nullable=True),
        ]
    )

    payload_field = pydantic_to_schema(MessageEvent).field("payload")
    assert payload_field.type == expected_payload_type
|
|
|
|
|
|
def test_multi_vector():
    """Check the Arrow schema and per-row length validation of ``MultiVector``."""

    class TestModel(pydantic.BaseModel):
        vec: MultiVector(8)

    # A list of fixed-size (8) float32 lists, nullable by default.
    expected = pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
    )
    assert pydantic_to_schema(TestModel) == expected

    # Inner vectors that are too short or too long must be rejected.
    for wrong_width in (7, 9):
        with pytest.raises(pydantic.ValidationError):
            TestModel(vec=[[1.0] * wrong_width])

    # One row, several rows, and no rows at all are accepted.
    TestModel(vec=[[1.0] * 8])
    TestModel(vec=[[1.0] * 8, [2.0] * 8])
    TestModel(vec=[])
|
|
|
|
|
|
def test_multi_vector_nullable():
    """Check how ``MultiVector``'s ``nullable`` argument reaches the Arrow schema."""

    def expected_schema(nullable):
        # Single "vec" column: list of fixed-size (16) float32 lists.
        return pa.schema(
            [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), nullable)]
        )

    class NonNullModel(pydantic.BaseModel):
        vec: MultiVector(16, nullable=False)

    assert pydantic_to_schema(NonNullModel) == expected_schema(False)

    # Without the argument the column is expected to be nullable.
    class DefaultModel(pydantic.BaseModel):
        vec: MultiVector(16)

    assert pydantic_to_schema(DefaultModel) == expected_schema(True)
|
|
|
|
|
|
def test_multi_vector_in_lance_model():
    """Check schema helpers and the default of a ``MultiVector`` field."""

    class TestModel(LanceModel):
        id: int
        vectors: MultiVector(16) = Field(default=[[0.0] * 16])

    # Both ways of deriving the Arrow schema must agree.
    assert pydantic_to_schema(TestModel) == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["id", "vectors"]

    # Omitting ``vectors`` falls back to the declared default.
    instance = TestModel(id=1)
    assert instance.vectors == [[0.0] * 16]
|
|
|
|
|
|
def test_aliases_in_lance_model(mem_db):
    """Map synchronous search results onto a model whose fields use aliases."""
    rows = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 6.5], "item": "bar", "price": 20.0},
    ]
    tbl = mem_db.create_table("items", data=rows)

    # ``name`` and ``distance`` are read from the "item" and "_distance" columns.
    class TestModel(LanceModel):
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    query = tbl.search([5.9, 6.5]).distance_type("cosine").limit(1)
    results = query.to_pydantic(TestModel)
    model = results[0]

    assert hasattr(model, "name")
    assert hasattr(model, "distance")
    # The query vector equals a stored row, so the top hit should be very close.
    assert model.distance < 0.01
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_aliases_in_lance_model_async(mem_db_async):
    """Map asynchronous search results onto a model whose fields use aliases."""
    rows = [
        {"vector": [8.3, 2.5], "item": "foo", "price": 12.0},
        {"vector": [7.7, 3.9], "item": "bar", "price": 11.2},
    ]
    tbl = await mem_db_async.create_table("items", data=rows)

    # ``name`` and ``distance`` are read from the "item" and "_distance" columns.
    class TestModel(LanceModel):
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    # Only the final ``to_pydantic`` call is awaited; the builder calls are sync.
    query = tbl.vector_search([7.7, 3.9]).distance_type("cosine").limit(1)
    results = await query.to_pydantic(TestModel)
    model = results[0]

    assert hasattr(model, "name")
    assert hasattr(model, "distance")
    # The query vector equals a stored row, so the top hit should be very close.
    assert model.distance < 0.01
|