From 01c338e51c0b8d87e8ff0728f0ee0957eebbe091 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 3 Jun 2026 22:59:25 +0800 Subject: [PATCH] fix(python): route blob query pandas through scanner --- python/python/lancedb/query.py | 183 ++++++++++++++++++++++++++---- python/python/tests/test_query.py | 90 ++++++++++++++- 2 files changed, 247 insertions(+), 26 deletions(-) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 7fa018892..847ba54db 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -91,14 +91,14 @@ def _schema_has_blob_field(schema: pa.Schema) -> bool: def _blob_mode_requires_native_pandas(blob_mode: BlobMode, schema: pa.Schema) -> bool: - return blob_mode in ("lazy", "bytes") and _schema_has_blob_field(schema) + return blob_mode in _BLOB_MODE_TO_HANDLING and _schema_has_blob_field(schema) def _unsupported_blob_pandas_error(reason: str) -> RuntimeError: return RuntimeError( - "blob_mode='lazy' and blob_mode='bytes' require Lance native pandas " - f"conversion for queries that return blob columns, but {reason}. " - "Use blob_mode='descriptions' or remove blob columns from the projection." + "blob columns require Lance native scanner conversion for query " + f"to_pandas(), but {reason}. Use a plain scan query or remove blob " + "columns from the projection." ) @@ -149,19 +149,48 @@ def _projection_to_scanner_kwargs( return {"columns": projection} -def _scanner_kwargs_for_query(query: Query, blob_mode: BlobMode) -> Dict[str, Any]: +def _scanner_kwargs_for_query( + query: Query, blob_mode: BlobMode, dataset: Optional[Any] = None +) -> Dict[str, Any]: + fragments = _scanner_fragments_for_query(query, dataset) kwargs = { **_projection_to_scanner_kwargs(query.columns), "filter": _filter_to_sql(query.filter), "limit": query.limit, "offset": query.offset, "with_row_id": query.with_row_id, + "with_row_address": query.with_row_address, "fast_search": query.fast_search, "blob_handling": _BLOB_MODE_TO_HANDLING[blob_mode], + "fragments": fragments, } return {key: value for key, value in kwargs.items() if value is not None} +def _scanner_fragments_for_query(query: Query, dataset: Optional[Any]) -> Optional[Any]: + if query.fragments is not None and query.fragment_ids is not None: + raise ValueError("fragments and fragment_ids cannot both be set") + if query.fragments is not None: + return query.fragments + if query.fragment_ids is None: + return None + if dataset is None: + raise ValueError("fragment_ids require a Lance dataset") + + requested = set(query.fragment_ids) + fragments = [ + fragment + for fragment in dataset.get_fragments() + if fragment.fragment_id in requested + ] + found = {fragment.fragment_id for fragment in fragments} + missing = requested - found + if missing: + missing_ids = ", ".join(str(fragment_id) for fragment_id in sorted(missing)) + raise ValueError(f"fragment_ids not found in dataset: {missing_ids}") + return fragments + + def _ensure_lazy_blob_frame( df: "pd.DataFrame", schema: pa.Schema, blob_mode: BlobMode ) -> "pd.DataFrame": @@ -179,6 +208,16 @@ def _ensure_lazy_blob_frame( return df +def _scanner_to_table(scanner: Any) -> pa.Table: + if hasattr(scanner, "to_pyarrow"): + reader = scanner.to_pyarrow() + return reader.read_all() + if hasattr(scanner, "to_table"): + return scanner.to_table() + reader = scanner.to_reader() + return reader.read_all() + + def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataFrame": schema = getattr(scanner, "projected_schema", None) if schema is None: @@ -199,14 +238,7 @@ def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataF return _ensure_lazy_blob_frame(df, schema, blob_mode) return df - if hasattr(scanner, "to_pyarrow"): - reader = scanner.to_pyarrow() - tbl = reader.read_all() - elif hasattr(scanner, "to_table"): - tbl = scanner.to_table() - else: - reader = scanner.to_reader() - tbl = reader.read_all() + tbl = _scanner_to_table(scanner) if blob_mode == "lazy" and _schema_has_blob_field(tbl.schema): raise _unsupported_blob_pandas_error( "the Lance scanner does not expose to_pandas" @@ -648,6 +680,13 @@ class Query(pydantic.BaseModel): # if true, include the row id in the results with_row_id: Optional[bool] = None + # if true, include the row address in the results + with_row_address: Optional[bool] = None + + # Lance fragments or fragment ids to scan on scanner-backed plain queries + fragments: Optional[Any] = None + fragment_ids: Optional[List[int]] = None + # offset to start fetching results from offset: Optional[int] = None @@ -840,6 +879,9 @@ class LanceQueryBuilder(ABC): self._where = None self._postfilter = None self._with_row_id = None + self._with_row_address = None + self._fragments = None + self._fragment_ids = None self._vector = None self._text = None self._ef = None @@ -901,9 +943,11 @@ class LanceQueryBuilder(ABC): schema = output_schema() if _blob_mode_requires_native_pandas(blob_mode, schema): native_error = None - if flatten is None and timeout is None: + if (flatten is None or blob_mode == "descriptions") and timeout is None: try: - df = self._plain_scan_to_pandas(blob_mode, **kwargs) + df = self._plain_scan_to_pandas( + blob_mode, flatten=flatten, **kwargs + ) if df is not None: return df except Exception as err: @@ -1125,6 +1169,32 @@ class LanceQueryBuilder(ABC): self._with_row_id = with_row_id return self + def with_row_address(self, with_row_address: bool = True) -> Self: + """Set whether to return row addresses. + + Parameters + ---------- + with_row_address: bool, default True + If True, return the _rowaddr column in the results. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + self._with_row_address = with_row_address + return self + + def with_fragments(self, fragments: Any) -> Self: + """Set the Lance fragments to scan for plain scanner-backed queries.""" + self._fragments = fragments + return self + + def fragment_ids(self, fragment_ids: List[int]) -> Self: + """Set the Lance fragment ids to scan for plain scanner-backed queries.""" + self._fragment_ids = fragment_ids + return self + def explain_plan(self, verbose: Optional[bool] = False) -> str: """Return the execution plan for this query. @@ -1267,6 +1337,7 @@ class LanceQueryBuilder(ABC): def _plain_scan_to_pandas( self, blob_mode: BlobMode, + flatten: Optional[Union[int, bool]] = None, **kwargs, ) -> Optional["pd.DataFrame"]: query = self.to_query_object() @@ -1274,7 +1345,12 @@ class LanceQueryBuilder(ABC): return None dataset = self._table.to_lance() - scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode)) + scanner = dataset.scanner( + **_scanner_kwargs_for_query(query, blob_mode, dataset) + ) + if flatten is not None: + tbl = flatten_columns(_scanner_to_table(scanner), flatten) + return tbl.to_pandas(**kwargs) return _scanner_to_pandas(scanner, blob_mode, **kwargs) @abstractmethod @@ -1548,6 +1624,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): refine_factor=self._refine_factor, vector_column=self._vector_column, with_row_id=self._with_row_id, + with_row_address=self._with_row_address, + fragments=self._fragments, + fragment_ids=self._fragment_ids, offset=self._offset, fast_search=self._fast_search, ef=self._ef, @@ -1750,6 +1829,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): limit=self._limit, postfilter=self._postfilter, with_row_id=self._with_row_id, + with_row_address=self._with_row_address, + fragments=self._fragments, + fragment_ids=self._fragment_ids, full_text_query=FullTextSearchQuery( query=self._query, columns=self._fts_columns ), @@ -1820,6 +1902,9 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): filter=self._where, limit=self._limit, with_row_id=self._with_row_id, + with_row_address=self._with_row_address, + fragments=self._fragments, + fragment_ids=self._fragment_ids, offset=self._offset, order_by=self._order_by, ) @@ -2411,6 +2496,9 @@ class AsyncQueryBase(object): """ self._inner = inner self._table = table + self._with_row_address = None + self._fragments = None + self._fragment_ids = None def to_query_object(self) -> Query: """ @@ -2419,7 +2507,11 @@ class AsyncQueryBase(object): This is currently experimental but can be useful as the query object is pure python and more easily serializable. """ - return Query.from_inner(self._inner.to_query_request()) + query = Query.from_inner(self._inner.to_query_request()) + query.with_row_address = self._with_row_address + query.fragments = self._fragments + query.fragment_ids = self._fragment_ids + return query def select(self, columns: Union[List[str], dict[str, str]]) -> Self: """ @@ -2476,6 +2568,27 @@ class AsyncQueryBase(object): self._inner.with_row_id() return self + def with_row_address(self, with_row_address: bool = True) -> Self: + """ + Include the _rowaddr column in scanner-backed plain query results. + """ + self._with_row_address = with_row_address + return self + + def with_fragments(self, fragments: Any) -> Self: + """ + Restrict scanner-backed plain query results to the given Lance fragments. + """ + self._fragments = fragments + return self + + def fragment_ids(self, fragment_ids: List[int]) -> Self: + """ + Restrict scanner-backed plain query results to the given Lance fragment ids. + """ + self._fragment_ids = fragment_ids + return self + async def to_batches( self, *, @@ -2601,9 +2714,11 @@ class AsyncQueryBase(object): schema = await self.output_schema() if _blob_mode_requires_native_pandas(blob_mode, schema): native_error = None - if flatten is None and timeout is None: + if (flatten is None or blob_mode == "descriptions") and timeout is None: try: - df = await self._plain_scan_to_pandas(blob_mode, **kwargs) + df = await self._plain_scan_to_pandas( + blob_mode, flatten=flatten, **kwargs + ) if df is not None: return df except Exception as err: @@ -2625,6 +2740,7 @@ class AsyncQueryBase(object): async def _plain_scan_to_pandas( self, blob_mode: BlobMode, + flatten: Optional[Union[int, bool]] = None, **kwargs, ) -> Optional["pd.DataFrame"]: if self._table is None: @@ -2635,7 +2751,12 @@ class AsyncQueryBase(object): return None dataset = await self._table._to_lance() - scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode)) + scanner = dataset.scanner( + **_scanner_kwargs_for_query(query, blob_mode, dataset) + ) + if flatten is not None: + tbl = flatten_columns(_scanner_to_table(scanner), flatten) + return tbl.to_pandas(**kwargs) return _scanner_to_pandas(scanner, blob_mode, **kwargs) async def to_polars( @@ -3522,6 +3643,7 @@ class AsyncTakeQuery(AsyncQueryBase): async def _plain_scan_to_pandas( self, blob_mode: BlobMode, + flatten: Optional[Union[int, bool]] = None, **kwargs, ) -> Optional["pd.DataFrame"]: return None @@ -3576,6 +3698,27 @@ class BaseQueryBuilder(object): self._inner.with_row_id() return self + def with_row_address(self, with_row_address: bool = True) -> Self: + """ + Include the _rowaddr column in scanner-backed plain query results. + """ + self._inner.with_row_address(with_row_address) + return self + + def with_fragments(self, fragments: Any) -> Self: + """ + Restrict scanner-backed plain query results to the given Lance fragments. + """ + self._inner.with_fragments(fragments) + return self + + def fragment_ids(self, fragment_ids: List[int]) -> Self: + """ + Restrict scanner-backed plain query results to the given Lance fragment ids. + """ + self._inner.fragment_ids(fragment_ids) + return self + def output_schema(self) -> pa.Schema: """ Return the output schema for the query diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index aa9468120..8f977bf91 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -255,8 +255,9 @@ def test_plain_scan_query_to_pandas_blob_projection(tmp_db): assert df["double_id"].tolist() == [6, 8] +@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"]) def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow( - tmp_db, monkeypatch + tmp_db, monkeypatch, blob_mode ): pytest.importorskip("lance") table = tmp_db.create_table( @@ -269,10 +270,69 @@ def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow( monkeypatch.setattr(query, "to_arrow", fail_to_arrow) - df = query.to_pandas(blob_mode="bytes") + df = query.to_pandas(blob_mode=blob_mode) assert df["id"].tolist() == [1] - assert df["blob"].tolist() == [b"one"] + if blob_mode == "bytes": + assert df["blob"].tolist() == [b"one"] + else: + first = df["blob"].iloc[0] + assert first != b"one" + assert not hasattr(first, "readall") + + +def test_plain_scan_query_to_pandas_blob_descriptions_flatten_uses_scanner( + tmp_db, monkeypatch +): + pytest.importorskip("lance") + table = tmp_db.create_table( + "test_query_to_pandas_blob_desc_flatten", _blob_query_data() + ) + query = table.search().where("id = 1").select(["id", "blob"]) + + def fail_to_arrow(*args, **kwargs): + raise AssertionError("to_arrow should not be called before scanner pandas") + + monkeypatch.setattr(query, "to_arrow", fail_to_arrow) + + df = query.to_pandas(blob_mode="descriptions", flatten=True) + + assert df["id"].tolist() == [1] + assert any(column == "blob" or column.startswith("blob.") for column in df.columns) + + +def test_plain_scan_query_to_pandas_scanner_state(tmp_db): + pytest.importorskip("lance") + data = _blob_query_data() + table = tmp_db.create_table("test_query_to_pandas_scanner_state", data.slice(0, 2)) + table.add(data.slice(2, 2)) + + fragments = table.to_lance().get_fragments() + assert len(fragments) == 2 + + query = ( + table.search() + .select(["id", "blob"]) + .with_row_address() + .fragment_ids([fragments[1].fragment_id]) + ) + query_obj = query.to_query_object() + assert query_obj.with_row_address is True + assert query_obj.fragment_ids == [fragments[1].fragment_id] + + df = query.to_pandas(blob_mode="descriptions") + + assert df["id"].tolist() == [3, 4] + assert "_rowaddr" in df.columns + assert {rowaddr >> 32 for rowaddr in df["_rowaddr"]} == {fragments[1].fragment_id} + + df_by_fragment = ( + table.search() + .select(["id", "blob"]) + .with_fragments([fragments[0]]) + .to_pandas(blob_mode="descriptions") + ) + assert df_by_fragment["id"].tolist() == [1, 2] @pytest.mark.asyncio @@ -312,8 +372,9 @@ async def test_async_plain_scan_query_to_pandas_blob_projection(tmp_db_async): @pytest.mark.asyncio +@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"]) async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow( - tmp_db_async, monkeypatch + tmp_db_async, monkeypatch, blob_mode ): pytest.importorskip("lance") table = await tmp_db_async.create_table( @@ -326,10 +387,15 @@ async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow monkeypatch.setattr(query, "to_arrow", fail_to_arrow) - df = await query.to_pandas(blob_mode="bytes") + df = await query.to_pandas(blob_mode=blob_mode) assert df["id"].tolist() == [1] - assert df["blob"].tolist() == [b"one"] + if blob_mode == "bytes": + assert df["blob"].tolist() == [b"one"] + else: + first = df["blob"].iloc[0] + assert first != b"one" + assert not hasattr(first, "readall") def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db): @@ -342,6 +408,18 @@ def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db): ) +def test_vector_query_to_pandas_blob_descriptions_requires_plain_scan(tmp_db): + pytest.importorskip("lance") + table = tmp_db.create_table( + "test_vector_query_blob_descriptions", _blob_query_data() + ) + + with pytest.raises(RuntimeError, match="plain scan query"): + table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas( + blob_mode="descriptions" + ) + + def test_order_by_plain_query(mem_db): table = mem_db.create_table( "test_order_by",