From 4bccb43e569232886d112fc9dfe74db364ccc454 Mon Sep 17 00:00:00 2001 From: Justin Miller Date: Thu, 21 May 2026 12:11:13 -0700 Subject: [PATCH] fix(python): route sync BaseQueryBuilder.to_batches through async path (#3425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes #3424. `LanceTakeQueryBuilder.to_batches()` raised `AttributeError: 'AsyncTakeQuery' object has no attribute 'execute'`. The inherited `BaseQueryBuilder.to_batches` called `self._inner.execute(...)`, but `self._inner` is an `AsyncQueryBase` (Python wrapper) — only its native inner exposes `execute`. Every other sync builder overrides `to_batches`, so the bug only surfaced on take-query builders, which inherit the base unchanged. `take_offsets(...).to_batches()` is broken for the same reason. Route the sync wrapper through the async `to_batches` on the background event loop, so the native `execute` is invoked from inside an awaiting context (matching how the async path works correctly). ## Repro ```python import lancedb, pyarrow as pa, tempfile db = lancedb.connect(tempfile.mkdtemp()) tbl = db.create_table("t", data=pa.table({"a": list(range(100))})) tbl.take_row_ids([0, 1, 2]).to_arrow() # works tbl.search().to_batches() # works list(tbl.take_row_ids([0, 1, 2]).to_batches()) # AttributeError (before) ``` ## Test plan - [x] New regression test `test_take_queries_to_batches` covers `take_offsets(...).to_batches()`, `take_row_ids(...).to_batches()`, and the `select(...)` projection — all fail on `main` with the patch reverted, all pass with the fix. - [x] `test_take_queries`, `test_query_builder_batches`, and `test_query_schema` still pass. - [x] `ruff format --check` and `ruff check` clean on changed files. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- python/python/lancedb/query.py | 29 ++++++++++++++++------------- python/python/tests/test_query.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 04b13add1..ee3e7c7cd 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -3,12 +3,14 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from enum import Enum from datetime import timedelta +from enum import Enum from typing import ( TYPE_CHECKING, + Any, Dict, List, Literal, @@ -17,41 +19,40 @@ from typing import ( Type, TypeVar, Union, - Any, ) -import asyncio import deprecation import numpy as np import pyarrow as pa import pyarrow.compute as pc import pydantic +from typing_extensions import Annotated -from lancedb.pydantic import PYDANTIC_VERSION +from lancedb._lancedb import fts_query_to_json from lancedb.background_loop import LOOP +from lancedb.pydantic import PYDANTIC_VERSION from . import __version__ from .arrow import AsyncRecordBatchReader from .dependencies import pandas as pd +from .expr import Expr from .rerankers.base import Reranker from .rerankers.rrf import RRFReranker from .rerankers.util import check_reranker_result from .util import flatten_columns -from .expr import Expr -from lancedb._lancedb import fts_query_to_json -from typing_extensions import Annotated if TYPE_CHECKING: import sys + import PIL import polars as pl - from ._lancedb import Query as LanceQuery from ._lancedb import FTSQuery as LanceFTSQuery from ._lancedb import HybridQuery as LanceHybridQuery - from ._lancedb import VectorQuery as LanceVectorQuery - from ._lancedb import TakeQuery as LanceTakeQuery from ._lancedb import PyQueryRequest + from ._lancedb import Query as LanceQuery + from ._lancedb import TakeQuery as LanceTakeQuery + from ._lancedb import VectorQuery as LanceVectorQuery from .common import VEC from .pydantic import LanceModel from .table import Table @@ -3348,16 +3349,18 @@ class BaseQueryBuilder(object): If not specified, no timeout is applied. If the query does not complete within the specified time, an error will be raised. """ - async_iter = LOOP.run(self._inner.execute(max_batch_length, timeout)) + async_reader = LOOP.run( + self._inner.to_batches(max_batch_length=max_batch_length, timeout=timeout) + ) def iter_sync(): try: while True: - yield LOOP.run(async_iter.__anext__()) + yield LOOP.run(async_reader.__anext__()) except StopAsyncIteration: return - return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) + return pa.RecordBatchReader.from_batches(async_reader.schema, iter_sync()) def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table: """ diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index febb7e784..891ed808f 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -1512,6 +1512,37 @@ def test_take_queries(tmp_path): ] +def test_take_queries_to_batches(tmp_path): + # Regression test for the sync take-query path: `to_batches` previously + # raised ``AttributeError: 'AsyncTakeQuery' object has no attribute + # 'execute'`` because the inherited ``BaseQueryBuilder.to_batches`` called + # ``execute`` on the async wrapper instead of the native query. + db = lancedb.connect(tmp_path) + data = pa.table({"idx": list(range(100)), "label": [str(i) for i in range(100)]}) + table = db.create_table("test", data) + + # Take by offset → to_batches + rs = list(table.take_offsets([5, 2, 17]).to_batches()) + assert all(isinstance(b, pa.RecordBatch) for b in rs) + assert sum(b.num_rows for b in rs) == 3 + assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17] + + # Take by row id → to_batches + rs = list(table.take_row_ids([5, 2, 17]).to_batches()) + assert all(isinstance(b, pa.RecordBatch) for b in rs) + assert sum(b.num_rows for b in rs) == 3 + assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17] + + # Take with select projection → to_batches preserves the projection + rs = list(table.take_row_ids([5, 2, 17]).select(["label"]).to_batches()) + assert all(b.schema.names == ["label"] for b in rs) + assert sorted(v for b in rs for v in b.column("label").to_pylist()) == [ + "17", + "2", + "5", + ] + + def test_getitems(tmp_path): db = lancedb.connect(tmp_path) data = pa.table(