mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
feat: support nested pydantic schema (#707)
This commit is contained in:
@@ -21,6 +21,7 @@ import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
@@ -141,14 +142,32 @@ def test_add(db):
|
||||
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(16)
|
||||
li: List[int]
|
||||
# https://github.com/lancedb/lancedb/issues/562
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
table = LanceTable.create(db, "test", data=[data])
|
||||
assert len(table) == 1
|
||||
assert table.schema == TestModel.to_arrow_schema()
|
||||
class Document(BaseModel):
|
||||
content: str
|
||||
source: str
|
||||
|
||||
class LanceSchema(LanceModel):
|
||||
id: str
|
||||
vector: Vector(2)
|
||||
li: List[int]
|
||||
payload: Document
|
||||
|
||||
tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
|
||||
assert tbl.schema == LanceSchema.to_arrow_schema()
|
||||
|
||||
# add works
|
||||
expected = LanceSchema(
|
||||
id="id",
|
||||
vector=[0.0, 0.0],
|
||||
li=[1, 2, 3],
|
||||
payload=Document(content="foo", source="bar"),
|
||||
)
|
||||
tbl.add([expected])
|
||||
|
||||
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
|
||||
Reference in New Issue
Block a user