diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 04b13add1..ee3e7c7cd 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -3,12 +3,14 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from enum import Enum from datetime import timedelta +from enum import Enum from typing import ( TYPE_CHECKING, + Any, Dict, List, Literal, @@ -17,41 +19,40 @@ from typing import ( Type, TypeVar, Union, - Any, ) -import asyncio import deprecation import numpy as np import pyarrow as pa import pyarrow.compute as pc import pydantic +from typing_extensions import Annotated -from lancedb.pydantic import PYDANTIC_VERSION +from lancedb._lancedb import fts_query_to_json from lancedb.background_loop import LOOP +from lancedb.pydantic import PYDANTIC_VERSION from . import __version__ from .arrow import AsyncRecordBatchReader from .dependencies import pandas as pd +from .expr import Expr from .rerankers.base import Reranker from .rerankers.rrf import RRFReranker from .rerankers.util import check_reranker_result from .util import flatten_columns -from .expr import Expr -from lancedb._lancedb import fts_query_to_json -from typing_extensions import Annotated if TYPE_CHECKING: import sys + import PIL import polars as pl - from ._lancedb import Query as LanceQuery from ._lancedb import FTSQuery as LanceFTSQuery from ._lancedb import HybridQuery as LanceHybridQuery - from ._lancedb import VectorQuery as LanceVectorQuery - from ._lancedb import TakeQuery as LanceTakeQuery from ._lancedb import PyQueryRequest + from ._lancedb import Query as LanceQuery + from ._lancedb import TakeQuery as LanceTakeQuery + from ._lancedb import VectorQuery as LanceVectorQuery from .common import VEC from .pydantic import LanceModel from .table import Table @@ -3348,16 +3349,18 @@ class BaseQueryBuilder(object): If not specified, no timeout is applied. If the query does not complete within the specified time, an error will be raised. """ - async_iter = LOOP.run(self._inner.execute(max_batch_length, timeout)) + async_reader = LOOP.run( + self._inner.to_batches(max_batch_length=max_batch_length, timeout=timeout) + ) def iter_sync(): try: while True: - yield LOOP.run(async_iter.__anext__()) + yield LOOP.run(async_reader.__anext__()) except StopAsyncIteration: return - return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) + return pa.RecordBatchReader.from_batches(async_reader.schema, iter_sync()) def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table: """ diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index febb7e784..891ed808f 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -1512,6 +1512,37 @@ def test_take_queries(tmp_path): ] +def test_take_queries_to_batches(tmp_path): + # Regression test for the sync take-query path: `to_batches` previously + # raised ``AttributeError: 'AsyncTakeQuery' object has no attribute + # 'execute'`` because the inherited ``BaseQueryBuilder.to_batches`` called + # ``execute`` on the async wrapper instead of the native query. + db = lancedb.connect(tmp_path) + data = pa.table({"idx": list(range(100)), "label": [str(i) for i in range(100)]}) + table = db.create_table("test", data) + + # Take by offset → to_batches + rs = list(table.take_offsets([5, 2, 17]).to_batches()) + assert all(isinstance(b, pa.RecordBatch) for b in rs) + assert sum(b.num_rows for b in rs) == 3 + assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17] + + # Take by row id → to_batches + rs = list(table.take_row_ids([5, 2, 17]).to_batches()) + assert all(isinstance(b, pa.RecordBatch) for b in rs) + assert sum(b.num_rows for b in rs) == 3 + assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17] + + # Take with select projection → to_batches preserves the projection + rs = list(table.take_row_ids([5, 2, 17]).select(["label"]).to_batches()) + assert all(b.schema.names == ["label"] for b in rs) + assert sorted(v for b in rs for v in b.column("label").to_pylist()) == [ + "17", + "2", + "5", + ] + + def test_getitems(tmp_path): db = lancedb.connect(tmp_path) data = pa.table(