mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00
feat(python): support to_pydantic in async (#2438)
This pull request improves support for the `pydantic` integration by adding a `to_pydantic` method to asynchronous queries and by handling models that use `alias` in field definitions. Fixes #2436 and closes #2437. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for converting asynchronous query results to Pydantic models. - **Bug Fixes** - Simplified conversion of query results to Pydantic models for improved reliability. - Improved handling of field aliases and computed fields when mapping query results to Pydantic models. - **Tests** - Added tests to verify correct mapping of aliased and computed fields in both synchronous and asynchronous scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -786,10 +786,7 @@ class LanceQueryBuilder(ABC):
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow(timeout=timeout).to_pylist()
|
||||
]
|
||||
return [model(**row) for row in self.to_arrow(timeout=timeout).to_pylist()]
|
||||
|
||||
def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame":
|
||||
"""
|
||||
@@ -2400,6 +2397,28 @@ class AsyncQueryBase(object):
|
||||
|
||||
return pl.from_arrow(await self.to_arrow(timeout=timeout))
|
||||
|
||||
async def to_pydantic(
    self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
) -> List[LanceModel]:
    """
    Run the query and build one pydantic model instance per result row.

    Each row of the Arrow result is converted to a dict and passed to
    ``model`` as keyword arguments.

    Parameters
    ----------
    model : Type[LanceModel]
        The pydantic model class to instantiate for every row.
    timeout : timedelta, optional
        The maximum time to wait for the query to complete.
        If None, wait indefinitely.

    Returns
    -------
    list[LanceModel]
        The instantiated models, in the order the rows were returned.
    """
    arrow_table = await self.to_arrow(timeout=timeout)
    instances = []
    for record in arrow_table.to_pylist():
        instances.append(model(**record))
    return instances
|
||||
|
||||
async def explain_plan(self, verbose: Optional[bool] = False):
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
|
||||
@@ -412,3 +412,50 @@ def test_multi_vector_in_lance_model():
|
||||
|
||||
t = TestModel(id=1)
|
||||
assert t.vectors == [[0.0] * 16]
|
||||
|
||||
|
||||
def test_aliases_in_lance_model(mem_db):
    """Aliased model fields are filled from the matching result columns."""
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 6.5], "item": "bar", "price": 20.0},
    ]
    tbl = mem_db.create_table("items", data=data)

    class TestModel(LanceModel):
        # ``name`` and ``distance`` are fed from the "item" and "_distance"
        # result columns through their aliases; ``price`` maps by name.
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    model = (
        tbl.search([5.9, 6.5])
        .distance_type("cosine")
        .limit(1)
        .to_pydantic(TestModel)[0]
    )
    # The query vector is exactly the "bar" row, so that row must be the
    # single hit. Checking the values (rather than ``hasattr``, which is
    # always true for a declared field) proves the aliases were mapped.
    assert model.name == "bar"
    assert model.price == 20.0
    assert model.distance < 0.01
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aliases_in_lance_model_async(mem_db_async):
    """Async variant: aliased model fields are filled from result columns."""
    data = [
        {"vector": [8.3, 2.5], "item": "foo", "price": 12.0},
        {"vector": [7.7, 3.9], "item": "bar", "price": 11.2},
    ]
    tbl = await mem_db_async.create_table("items", data=data)

    class TestModel(LanceModel):
        # ``name`` and ``distance`` are fed from the "item" and "_distance"
        # result columns through their aliases; ``price`` maps by name.
        name: str = Field(alias="item")
        price: float
        distance: float = Field(alias="_distance")

    results = await (
        tbl.vector_search([7.7, 3.9])
        .distance_type("cosine")
        .limit(1)
        .to_pydantic(TestModel)
    )
    model = results[0]
    # The query vector is exactly the "bar" row, so that row must be the
    # single hit. Checking the values (rather than ``hasattr``, which is
    # always true for a declared field) proves the aliases were mapped.
    assert model.name == "bar"
    assert model.price == 11.2
    assert model.distance < 0.01
|
||||
|
||||
Reference in New Issue
Block a user