diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 66d515ec..0367e9dc 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -786,10 +786,7 @@ class LanceQueryBuilder(ABC):
         -------
         List[LanceModel]
         """
-        return [
-            model(**{k: v for k, v in row.items() if k in model.field_names()})
-            for row in self.to_arrow(timeout=timeout).to_pylist()
-        ]
+        return [model(**row) for row in self.to_arrow(timeout=timeout).to_pylist()]
 
     def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame":
         """
@@ -2400,6 +2397,28 @@ class AsyncQueryBase(object):
 
         return pl.from_arrow(await self.to_arrow(timeout=timeout))
 
+    async def to_pydantic(
+        self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
+    ) -> List[LanceModel]:
+        """
+        Convert results to a list of pydantic models.
+
+        Parameters
+        ----------
+        model : Type[LanceModel]
+            The pydantic model to use.
+        timeout : timedelta, optional
+            The maximum time to wait for the query to complete.
+            If None, wait indefinitely.
+
+        Returns
+        -------
+        List[LanceModel]
+        """
+        return [
+            model(**row) for row in (await self.to_arrow(timeout=timeout)).to_pylist()
+        ]
+
     async def explain_plan(self, verbose: Optional[bool] = False):
         """Return the execution plan for this query.
 
diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py
index 514871cc..41dbb6ec 100644
--- a/python/python/tests/test_pydantic.py
+++ b/python/python/tests/test_pydantic.py
@@ -412,3 +412,50 @@ def test_multi_vector_in_lance_model():
 
     t = TestModel(id=1)
     assert t.vectors == [[0.0] * 16]
+
+
+def test_aliases_in_lance_model(mem_db):
+    data = [
+        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+        {"vector": [5.9, 6.5], "item": "bar", "price": 20.0},
+    ]
+    tbl = mem_db.create_table("items", data=data)
+
+    class TestModel(LanceModel):
+        name: str = Field(alias="item")
+        price: float
+        distance: float = Field(alias="_distance")
+
+    model = (
+        tbl.search([5.9, 6.5])
+        .distance_type("cosine")
+        .limit(1)
+        .to_pydantic(TestModel)[0]
+    )
+    assert model.name == "bar"  # value arrives via the "item" column alias
+    assert model.price == 20.0  # un-aliased field of the matched row
+    assert model.distance < 0.01
+
+
+@pytest.mark.asyncio
+async def test_aliases_in_lance_model_async(mem_db_async):
+    data = [
+        {"vector": [8.3, 2.5], "item": "foo", "price": 12.0},
+        {"vector": [7.7, 3.9], "item": "bar", "price": 11.2},
+    ]
+    tbl = await mem_db_async.create_table("items", data=data)
+
+    class TestModel(LanceModel):
+        name: str = Field(alias="item")
+        price: float
+        distance: float = Field(alias="_distance")
+
+    model = (
+        await tbl.vector_search([7.7, 3.9])
+        .distance_type("cosine")
+        .limit(1)
+        .to_pydantic(TestModel)
+    )[0]
+    assert model.name == "bar"  # value arrives via the "item" column alias
+    assert model.price == pytest.approx(11.2)  # un-aliased field of the match
+    assert model.distance < 0.01