mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 19:10:39 +00:00
feat(python): support to_pydantic in async (#2438)
This request improves support for `pydantic` integration by adding `to_pydantic` method to asynchronous queries and handling models that use `alias` in field definitions. Fixes #2436 and closes #2437 . <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for converting asynchronous query results to Pydantic models. - **Bug Fixes** - Simplified conversion of query results to Pydantic models for improved reliability. - Improved handling of field aliases and computed fields when mapping query results to Pydantic models. - **Tests** - Added tests to verify correct mapping of aliased and computed fields in both synchronous and asynchronous scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -412,3 +412,50 @@ def test_multi_vector_in_lance_model():
|
||||
|
||||
t = TestModel(id=1)
|
||||
assert t.vectors == [[0.0] * 16]
|
||||
|
||||
|
||||
def test_aliases_in_lance_model(mem_db):
|
||||
data = [
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 6.5], "item": "bar", "price": 20.0},
|
||||
]
|
||||
tbl = mem_db.create_table("items", data=data)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
name: str = Field(alias="item")
|
||||
price: float
|
||||
distance: float = Field(alias="_distance")
|
||||
|
||||
model = (
|
||||
tbl.search([5.9, 6.5])
|
||||
.distance_type("cosine")
|
||||
.limit(1)
|
||||
.to_pydantic(TestModel)[0]
|
||||
)
|
||||
assert hasattr(model, "name")
|
||||
assert hasattr(model, "distance")
|
||||
assert model.distance < 0.01
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aliases_in_lance_model_async(mem_db_async):
|
||||
data = [
|
||||
{"vector": [8.3, 2.5], "item": "foo", "price": 12.0},
|
||||
{"vector": [7.7, 3.9], "item": "bar", "price": 11.2},
|
||||
]
|
||||
tbl = await mem_db_async.create_table("items", data=data)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
name: str = Field(alias="item")
|
||||
price: float
|
||||
distance: float = Field(alias="_distance")
|
||||
|
||||
model = (
|
||||
await tbl.vector_search([7.7, 3.9])
|
||||
.distance_type("cosine")
|
||||
.limit(1)
|
||||
.to_pydantic(TestModel)
|
||||
)[0]
|
||||
assert hasattr(model, "name")
|
||||
assert hasattr(model, "distance")
|
||||
assert model.distance < 0.01
|
||||
|
||||
Reference in New Issue
Block a user