Compare commits

..

2 Commits

Author SHA1 Message Date
Ayush Chaurasia
ff08a996fc feat(python): add LanceTorchDataset / LanceIterableTorchDataset wrappers
Provides first-class PyTorch `Dataset`/`IterableDataset` wrappers around a
LanceDB table or permutation. The wrapper:

* Captures only the URI / table name / connect kwargs needed to re-open
  the table — no Rust handles in pickle output. Works out of the box with
  `DataLoader(num_workers > 0)`, which would otherwise crash a
  hand-rolled subclass.
* Implements both `__getitem__` and PyTorch's `__getitems__` dunder so
  the underlying batched `Permutation.fetch` is used when DataLoader
  fetches a batch of indices.
* Forwards column selection / format / transform / batch_size to the
  underlying Permutation, so users do not have to hand-roll the
  `_ensure_open` boilerplate from the issue.

Builds on the public `Permutation.fetch` API (#3243).

Closes lancedb/lancedb#3242
2026-04-29 22:21:00 +05:30
Ayush Chaurasia
049a689a1c feat(python): add public Permutation.fetch(indices) API
Adds a public method that mirrors __getitems__ for batch index access,
so users do not have to call a dunder directly when implementing custom
torch datasets.

Closes lancedb/lancedb#3243
2026-04-29 22:13:42 +05:30
6 changed files with 415 additions and 67 deletions

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""
PyTorch integration for LanceDB.
Exposes ``LanceTorchDataset`` (map-style) and ``LanceIterableTorchDataset``
(iterable-style) wrappers that adapt a LanceDB table or permutation to the
PyTorch ``torch.utils.data`` API, while transparently handling the bits
that make a hand-rolled subclass tricky:
* The underlying Lance reader holds Rust state that is not picklable, but
``DataLoader(num_workers > 0)`` needs to fork the dataset to its workers.
These classes strip the reader on pickle and re-open it in the worker on
first read.
* Constructing a permutation from a table involves several steps
(``permutation_builder``/``Permutation.from_tables``/``select_columns``
/``with_format``/...). The wrapper takes those as constructor arguments
and applies them once the dataset is opened in the worker.
Example
-------
>>> import lancedb, torch # doctest: +SKIP
>>> from lancedb.integrations.torch import LanceTorchDataset
>>> db = lancedb.connect(uri) # doctest: +SKIP
>>> tbl = db.open_table("images_224") # doctest: +SKIP
>>> ds = LanceTorchDataset( # doctest: +SKIP
... tbl, columns=["image_bytes", "label"], format="torch"
... )
>>> loader = torch.utils.data.DataLoader( # doctest: +SKIP
... ds, batch_size=64, num_workers=4, shuffle=True,
... )
"""
from typing import Any, Callable, Dict, List, Optional, Union
import torch.utils.data as _torch_data
from ..permutation import Permutation
from ..table import LanceTable
def _capture_table_state(table: LanceTable) -> Dict[str, Any]:
"""Pull just enough state out of a LanceTable so we can re-open the same
table in a forked worker process where the Rust handle isn't valid."""
conn = table._conn
connect_kwargs: Dict[str, Any] = {}
storage_options = getattr(conn, "storage_options", None)
if storage_options is not None:
connect_kwargs["storage_options"] = storage_options
return {
"uri": conn.uri,
"table_name": table.name,
"connect_kwargs": connect_kwargs,
}
def _open_permutation(state: Dict[str, Any]) -> Permutation:
"""Reconstruct a Permutation from a captured state dict."""
import lancedb
db = lancedb.connect(state["uri"], **state["connect_kwargs"])
base = db.open_table(state["table_name"])
perm_table_name = state.get("perm_table_name")
if perm_table_name is not None:
perm_tbl = db.open_table(perm_table_name)
perm = Permutation.from_tables(base, perm_tbl, state.get("split"))
else:
perm = Permutation.identity(base)
columns = state.get("columns")
fmt = state.get("format")
transform = state.get("transform")
batch_size = state.get("batch_size")
if columns is not None:
perm = perm.select_columns(columns)
if fmt is not None:
perm = perm.with_format(fmt)
if transform is not None:
perm = perm.with_transform(transform)
if batch_size is not None:
perm = perm.with_batch_size(batch_size)
return perm
class LanceTorchDataset(_torch_data.Dataset):
"""
A PyTorch map-style ``Dataset`` backed by a LanceDB table or permutation.
Pass the same ``LanceTable`` you already opened (and, optionally, a
permutation table / split / column selection / output format) and use
the result anywhere a ``torch.utils.data.Dataset`` is expected.
The wrapper:
* Stores the URI / table name / storage options needed to re-open the
table, not the Rust reader handle. Pickling keeps only the rebuild
recipe, so ``DataLoader(num_workers > 0)`` works out of the box.
* Implements both ``__getitem__`` and PyTorch's ``__getitems__`` dunder
so the underlying batched ``Permutation.fetch`` is used when the
DataLoader fetches a batch of indices.
Parameters
----------
table : LanceTable, optional
The base table to read from. Either ``table`` or both ``uri`` and
``table_name`` must be provided.
uri : str, optional
Database URI to reconnect to. Required if ``table`` is not given.
table_name : str, optional
Name of the base table within ``uri``.
connect_kwargs : dict, optional
Extra keyword arguments forwarded to ``lancedb.connect`` when
re-opening the database in a worker.
permutation_table : LanceTable, optional
A pre-built permutation table (see ``permutation_builder``) used to
define the row ordering. If omitted, the identity permutation is
used (rows in physical order).
split : str or int, optional
Split selector when ``permutation_table`` defines splits.
columns : list[str], optional
Subset of columns to read.
format : str, optional
Output format, forwarded to ``Permutation.with_format`` (e.g.
``"torch"`` for HuggingFace-style ``dict[str, Tensor]`` batches).
transform : Callable, optional
Custom batch transform, forwarded to ``Permutation.with_transform``.
Must be picklable to work with ``num_workers > 0``.
batch_size : int, optional
Forwarded to ``Permutation.with_batch_size`` for direct iteration.
DataLoader controls its own batching, so this only matters if the
dataset is iterated directly.
"""
def __init__(
self,
table: Optional[LanceTable] = None,
*,
uri: Optional[str] = None,
table_name: Optional[str] = None,
connect_kwargs: Optional[Dict[str, Any]] = None,
permutation_table: Optional[LanceTable] = None,
split: Optional[Union[str, int]] = None,
columns: Optional[List[str]] = None,
format: Optional[str] = None,
transform: Optional[Callable] = None,
batch_size: Optional[int] = None,
):
if table is None and (uri is None or table_name is None):
raise ValueError(
"Provide either `table` or both `uri` and `table_name`."
)
if table is not None:
state = _capture_table_state(table)
if connect_kwargs is not None:
state["connect_kwargs"] = connect_kwargs
else:
state = {
"uri": uri,
"table_name": table_name,
"connect_kwargs": connect_kwargs or {},
}
state["perm_table_name"] = (
permutation_table.name if permutation_table is not None else None
)
state["split"] = split
state["columns"] = columns
state["format"] = format
state["transform"] = transform
state["batch_size"] = batch_size
self._state: Dict[str, Any] = state
self._perm: Optional[Permutation] = None
def __getstate__(self) -> Dict[str, Any]:
# Strip the Rust-backed reader so the dataset is picklable. Workers
# rebuild it on first read via _ensure_open().
d = self.__dict__.copy()
d["_perm"] = None
return d
def __setstate__(self, d: Dict[str, Any]) -> None:
self.__dict__.update(d)
def _ensure_open(self) -> None:
if self._perm is None:
self._perm = _open_permutation(self._state)
def __len__(self) -> int:
self._ensure_open()
return len(self._perm)
def __getitem__(self, index: int) -> Any:
self._ensure_open()
return self._perm[index]
def __getitems__(self, indices: List[int]) -> Any:
self._ensure_open()
return self._perm.fetch(indices)
class LanceIterableTorchDataset(_torch_data.IterableDataset):
"""
PyTorch iterable-style ``IterableDataset`` over a LanceDB permutation.
Yields batches in the order defined by the underlying ``Permutation``.
With ``num_workers > 1`` each worker iterates the permutation
independently — for sharded iteration use the map-style
``LanceTorchDataset`` together with a sampler.
Constructor arguments mirror ``LanceTorchDataset``.
"""
def __init__(self, *args, **kwargs):
self._inner = LanceTorchDataset(*args, **kwargs)
def __getstate__(self) -> Dict[str, Any]:
return {"_inner": self._inner.__getstate__()}
def __setstate__(self, d: Dict[str, Any]) -> None:
self._inner = LanceTorchDataset.__new__(LanceTorchDataset)
self._inner.__setstate__(d["_inner"])
def __iter__(self):
self._inner._ensure_open()
return iter(self._inner._perm)

View File

@@ -9,7 +9,7 @@ import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .util import batch_to_tensor, batch_to_tensor_dict, batch_to_tensor_rows
from .util import batch_to_tensor, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@@ -697,7 +697,6 @@ class Permutation:
"pandas",
"arrow",
"torch",
"torch_row",
"torch_col",
"polars",
],
@@ -713,17 +712,8 @@ class Permutation:
- "python_col" - the batch will be a dict of lists (one entry per column)
- "pandas" - the batch will be a pandas DataFrame
- "arrow" - the batch will be a pyarrow RecordBatch
- "torch" - a list of per-row dicts whose values are torch tensors. When
used with ``torch.utils.data.DataLoader`` (default collate), each
batch yielded by the loader is ``dict[str, Tensor]`` — one tensor per
column, with column names preserved. This matches HuggingFace
``dataset.set_format("torch")`` semantics.
- "torch_row" - a list of 1-D torch tensors, one per row. Each tensor
stacks all column values into a single row vector and column names
are not preserved. (This was the previous "torch" behavior.)
- "torch_col" - a 2-D torch tensor of shape ``(n_cols, n_rows)``. Column
names are not preserved. Requires ``collate_fn=lambda x: x`` if used
with ``DataLoader``.
- "torch" - the batch will be a list of tensors, one per row
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
- "polars" - the batch will be a polars DataFrame
Conversion may or may not involve a data copy. Lance uses Arrow internally
@@ -751,8 +741,6 @@ class Permutation:
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_dict)
elif format == "torch_row":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)
@@ -791,6 +779,25 @@ class Permutation:
batch = LOOP.run(do_getitems())
return self.transform_fn(batch)
def fetch(self, indices: list[int]) -> Any:
"""
Fetch rows from the permutation by offset.
This is the public batch-access API. It returns the rows for the given
offsets in the same shape as configured by
[with_format](#with_format) / [with_transform](#with_transform).
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("memory:///")
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(10)])
>>> perm = Permutation.identity(tbl)
>>> perm.fetch([0, 5, 9])
[{'x': 0}, {'x': 5}, {'x': 9}]
"""
return self.__getitems__(indices)
@deprecated(details="Use with_skip instead")
def skip(self, skip: int) -> "Permutation":
"""

View File

@@ -448,29 +448,3 @@ def batch_to_tensor_rows(batch: pa.RecordBatch):
stacked = torch.tensor(numpy.column_stack(columns))
rows = list(stacked.unbind(dim=0))
return rows
def batch_to_tensor_dict(batch: pa.RecordBatch):
"""
Convert a PyArrow RecordBatch into a list of per-row dicts whose values
are PyTorch tensors.
Each column is converted to a tensor in one shot (zero-copy via DLPack
when supported), then sliced per row. The result is shaped to work with
PyTorch's default DataLoader collate, which stacks the per-row dicts
into a single ``dict[str, Tensor]`` per batch — matching the
HuggingFace ``dataset.set_format("torch")`` convention.
Fails if torch is not installed.
Fails if a column's data type is not supported by PyTorch.
"""
torch = attempt_import_or_raise("torch", "torch")
columns: dict[str, "torch.Tensor"] = {}
for i, name in enumerate(batch.schema.names):
col = batch.column(i)
try:
columns[name] = torch.from_dlpack(col)
except Exception:
columns[name] = torch.tensor(col.to_numpy(zero_copy_only=False))
n = batch.num_rows
return [{name: t[i] for name, t in columns.items()} for i in range(n)]

View File

@@ -950,27 +950,14 @@ def test_transform_fn(mem_db):
try:
import torch
# "torch" format: list of per-row dicts of tensors (HF-compatible).
torch_result = list(
permutation.with_format("torch").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_result, list)
assert len(torch_result) == 10
assert isinstance(torch_result[0], dict)
assert set(torch_result[0].keys()) == {"id", "value"}
assert isinstance(torch_result[0]["id"], torch.Tensor)
assert torch_result[0]["id"].shape == ()
assert torch_result[0]["id"].dtype == torch.int64
# "torch_row" format: list of 1-D row tensors (previous "torch" behavior).
torch_row_result = list(
permutation.with_format("torch_row").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_row_result, list)
assert len(torch_row_result) == 10
assert isinstance(torch_row_result[0], torch.Tensor)
assert torch_row_result[0].shape == (2,)
assert torch_row_result[0].dtype == torch.int64
assert isinstance(torch_result[0], torch.Tensor)
assert torch_result[0].shape == (2,)
assert torch_result[0].dtype == torch.int64
except ImportError:
# Skip check if torch is not installed
pass
@@ -1108,3 +1095,23 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
"""Test __getitems__ with an out-of-range offset raises an error."""
with pytest.raises(Exception):
some_permutation.__getitems__([999999])
def test_fetch_matches_getitems(some_permutation: Permutation):
"""Public fetch() should be equivalent to __getitems__."""
indices = [0, 1, 2, 10, 100]
assert some_permutation.fetch(indices) == some_permutation.__getitems__(indices)
def test_fetch_respects_format(some_permutation: Permutation):
"""fetch() applies the configured format/transform."""
arrow_perm = some_permutation.with_format("arrow")
result = arrow_perm.fetch([0, 1, 2])
assert isinstance(result, pa.RecordBatch)
assert result.num_rows == 3
def test_fetch_invalid_offset(some_permutation: Permutation):
"""fetch() with an out-of-range offset raises an error."""
with pytest.raises(Exception):
some_permutation.fetch([999999])

View File

@@ -27,18 +27,8 @@ def test_permutation_dataloader(mem_db):
for batch in dataloader:
assert batch["a"].size(0) == 10
# New "torch" format: per-row dicts of tensors, default collate yields
# dict[str, Tensor] (HuggingFace style).
permutation = permutation.with_format("torch")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert isinstance(batch, dict)
assert "a" in batch
assert batch["a"].size() == (10,)
# Previous "torch" semantics is preserved under the "torch_row" name.
permutation = permutation.with_format("torch_row")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert batch.size(0) == 10
assert batch.size(1) == 1

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pickle
import pyarrow as pa
import pytest
from lancedb import connect
from lancedb.permutation import permutation_builder
torch = pytest.importorskip("torch")
from lancedb.integrations.torch import ( # noqa: E402
LanceIterableTorchDataset,
LanceTorchDataset,
)
@pytest.fixture
def db_path(tmp_path):
"""LanceTorchDataset needs a real, on-disk DB so workers can re-open it."""
return tmp_path
def _make_table(db_path, name="imgs", n=20):
db = connect(db_path)
return db.create_table(
name,
pa.table({"x": [float(i) for i in range(n)], "y": list(range(n))}),
)
def test_basic_len_and_getitem(db_path):
tbl = _make_table(db_path)
ds = LanceTorchDataset(tbl)
assert len(ds) == 20
row = ds[0]
# Default ("python") format = list of dicts; __getitem__ wraps a single index.
assert isinstance(row, list)
assert row[0] == {"x": 0.0, "y": 0}
def test_getitems_uses_fetch(db_path):
tbl = _make_table(db_path)
ds = LanceTorchDataset(tbl)
rows = ds.__getitems__([0, 2, 4])
assert rows == [
{"x": 0.0, "y": 0},
{"x": 2.0, "y": 2},
{"x": 4.0, "y": 4},
]
def test_dataloader_default_collate(db_path):
tbl = _make_table(db_path, n=40)
ds = LanceTorchDataset(tbl)
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False)
batch = next(iter(loader))
# default collate stacks list-of-dicts into dict-of-tensors
assert isinstance(batch, dict)
assert batch["x"].size() == (8,)
assert batch["y"].size() == (8,)
def test_picklable(db_path):
tbl = _make_table(db_path)
ds = LanceTorchDataset(tbl, columns=["x"])
# Force open then ensure pickle drops the Rust handle.
_ = len(ds)
blob = pickle.dumps(ds)
restored: LanceTorchDataset = pickle.loads(blob)
# Rust state should not survive pickling.
assert restored._perm is None
# …but the dataset must work after re-opening transparently.
assert len(restored) == 20
assert restored[0] == [{"x": 0.0}]
def test_dataloader_with_workers(db_path):
tbl = _make_table(db_path, n=32)
ds = LanceTorchDataset(tbl)
loader = torch.utils.data.DataLoader(
ds, batch_size=4, num_workers=2, shuffle=False
)
batches = list(loader)
seen = []
for b in batches:
seen.extend(b["x"].tolist())
assert sorted(seen) == [float(i) for i in range(32)]
def test_with_permutation_table(db_path):
tbl = _make_table(db_path, n=30)
db = connect(db_path)
perm_tbl = (
permutation_builder(tbl)
.split_random(ratios=[0.5, 0.5], seed=1, split_names=["train", "test"])
.persist(db, "imgs_perm")
.execute()
)
ds = LanceTorchDataset(tbl, permutation_table=perm_tbl, split="train")
# Should pickle/restore the permutation table reference too.
blob = pickle.dumps(ds)
restored = pickle.loads(blob)
assert len(restored) == 15
def test_format_passthrough_dataloader(db_path):
"""Custom `format` is forwarded to the underlying Permutation."""
tbl = _make_table(db_path, n=20)
ds = LanceTorchDataset(tbl, format="arrow")
# Arrow batches don't go through default_collate, so use a no-op collate.
loader = torch.utils.data.DataLoader(
ds, batch_size=5, shuffle=False, collate_fn=lambda x: x
)
batch = next(iter(loader))
assert isinstance(batch, pa.RecordBatch)
assert batch.num_rows == 5
def test_iterable_dataset(db_path):
tbl = _make_table(db_path, n=20)
ds = LanceIterableTorchDataset(tbl, batch_size=5)
batches = list(ds)
# default batch size + skip_last_batch=True yields full-size batches only
assert len(batches) == 4
assert all(len(b) == 5 for b in batches)
def test_uri_table_name_constructor(db_path):
_make_table(db_path)
ds = LanceTorchDataset(uri=str(db_path), table_name="imgs")
assert len(ds) == 20
assert ds[0] == [{"x": 0.0, "y": 0}]
def test_constructor_validates_args():
with pytest.raises(ValueError, match="table"):
LanceTorchDataset()