diff --git a/python/python/lancedb/integrations/torch.py b/python/python/lancedb/integrations/torch.py
new file mode 100644
index 000000000..11aac3626
--- /dev/null
+++ b/python/python/lancedb/integrations/torch.py
@@ -0,0 +1,230 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+"""
+PyTorch integration for LanceDB.
+
+Exposes ``LanceTorchDataset`` (map-style) and ``LanceIterableTorchDataset``
+(iterable-style) wrappers that adapt a LanceDB table or permutation to the
+PyTorch ``torch.utils.data`` API, while transparently handling the bits
+that make a hand-rolled subclass tricky:
+
+* The underlying Lance reader holds Rust state that is not picklable, but
+  ``DataLoader(num_workers > 0)`` needs to fork the dataset to its workers.
+  These classes strip the reader on pickle and re-open it in the worker on
+  first read.
+* Constructing a permutation from a table involves several steps
+  (``permutation_builder``/``Permutation.from_tables``/``select_columns``
+  /``with_format``/...). The wrapper takes those as constructor arguments
+  and applies them once the dataset is opened in the worker.
+
+Example
+-------
+>>> import lancedb, torch  # doctest: +SKIP
+>>> from lancedb.integrations.torch import LanceTorchDataset
+>>> db = lancedb.connect(uri)  # doctest: +SKIP
+>>> tbl = db.open_table("images_224")  # doctest: +SKIP
+>>> ds = LanceTorchDataset(  # doctest: +SKIP
+...     tbl, columns=["image_bytes", "label"], format="torch"
+... )
+>>> loader = torch.utils.data.DataLoader(  # doctest: +SKIP
+...     ds, batch_size=64, num_workers=4, shuffle=True,
+... )
+"""
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch.utils.data as _torch_data
+
+from ..permutation import Permutation
+from ..table import LanceTable
+
+
+def _capture_table_state(table: LanceTable) -> Dict[str, Any]:
+    """Pull just enough state out of a LanceTable so we can re-open the same
+    table in a forked worker process where the Rust handle isn't valid."""
+    conn = table._conn
+    connect_kwargs: Dict[str, Any] = {}
+    storage_options = getattr(conn, "storage_options", None)
+    if storage_options is not None:
+        connect_kwargs["storage_options"] = storage_options
+    return {
+        "uri": conn.uri,
+        "table_name": table.name,
+        "connect_kwargs": connect_kwargs,
+    }
+
+
+def _open_permutation(state: Dict[str, Any]) -> Permutation:
+    """Reconstruct a Permutation from a captured state dict."""
+    import lancedb
+
+    db = lancedb.connect(state["uri"], **state["connect_kwargs"])
+    base = db.open_table(state["table_name"])
+
+    perm_table_name = state.get("perm_table_name")
+    if perm_table_name is not None:
+        perm_tbl = db.open_table(perm_table_name)
+        perm = Permutation.from_tables(base, perm_tbl, state.get("split"))
+    else:
+        perm = Permutation.identity(base)
+
+    columns = state.get("columns")
+    fmt = state.get("format")
+    transform = state.get("transform")
+    batch_size = state.get("batch_size")
+
+    if columns is not None:
+        perm = perm.select_columns(columns)
+    if fmt is not None:
+        perm = perm.with_format(fmt)
+    if transform is not None:
+        perm = perm.with_transform(transform)
+    if batch_size is not None:
+        perm = perm.with_batch_size(batch_size)
+    return perm
+
+
+class LanceTorchDataset(_torch_data.Dataset):
+    """
+    A PyTorch map-style ``Dataset`` backed by a LanceDB table or permutation.
+
+    Pass the same ``LanceTable`` you already opened (and, optionally, a
+    permutation table / split / column selection / output format) and use
+    the result anywhere a ``torch.utils.data.Dataset`` is expected.
+
+    The wrapper:
+
+    * Stores the URI / table name / storage options needed to re-open the
+      table, not the Rust reader handle. Pickling keeps only the rebuild
+      recipe, so ``DataLoader(num_workers > 0)`` works out of the box.
+    * Implements both ``__getitem__`` and PyTorch's ``__getitems__`` dunder
+      so the underlying batched ``Permutation.fetch`` is used when the
+      DataLoader fetches a batch of indices.
+
+    Parameters
+    ----------
+    table : LanceTable, optional
+        The base table to read from. Either ``table`` or both ``uri`` and
+        ``table_name`` must be provided.
+    uri : str, optional
+        Database URI to reconnect to. Required if ``table`` is not given.
+    table_name : str, optional
+        Name of the base table within ``uri``.
+    connect_kwargs : dict, optional
+        Extra keyword arguments forwarded to ``lancedb.connect`` when
+        re-opening the database in a worker.
+    permutation_table : LanceTable, optional
+        A pre-built permutation table (see ``permutation_builder``) used to
+        define the row ordering. If omitted, the identity permutation is
+        used (rows in physical order).
+    split : str or int, optional
+        Split selector when ``permutation_table`` defines splits.
+    columns : list[str], optional
+        Subset of columns to read.
+    format : str, optional
+        Output format, forwarded to ``Permutation.with_format`` (e.g.
+        ``"torch"`` for HuggingFace-style ``dict[str, Tensor]`` batches).
+    transform : Callable, optional
+        Custom batch transform, forwarded to ``Permutation.with_transform``.
+        Must be picklable to work with ``num_workers > 0``.
+    batch_size : int, optional
+        Forwarded to ``Permutation.with_batch_size`` for direct iteration.
+        DataLoader controls its own batching, so this only matters if the
+        dataset is iterated directly.
+    """
+
+    def __init__(
+        self,
+        table: Optional[LanceTable] = None,
+        *,
+        uri: Optional[str] = None,
+        table_name: Optional[str] = None,
+        connect_kwargs: Optional[Dict[str, Any]] = None,
+        permutation_table: Optional[LanceTable] = None,
+        split: Optional[Union[str, int]] = None,
+        columns: Optional[List[str]] = None,
+        format: Optional[str] = None,
+        transform: Optional[Callable] = None,
+        batch_size: Optional[int] = None,
+    ):
+        if table is None and (uri is None or table_name is None):
+            raise ValueError(
+                "Provide either `table` or both `uri` and `table_name`."
+            )
+
+        if table is not None:
+            state = _capture_table_state(table)
+            if connect_kwargs is not None:
+                state["connect_kwargs"] = connect_kwargs
+        else:
+            state = {
+                "uri": uri,
+                "table_name": table_name,
+                "connect_kwargs": connect_kwargs or {},
+            }
+
+        state["perm_table_name"] = (
+            permutation_table.name if permutation_table is not None else None
+        )
+        state["split"] = split
+        state["columns"] = columns
+        state["format"] = format
+        state["transform"] = transform
+        state["batch_size"] = batch_size
+
+        self._state: Dict[str, Any] = state
+        self._perm: Optional[Permutation] = None
+
+    def __getstate__(self) -> Dict[str, Any]:
+        # Strip the Rust-backed reader so the dataset is picklable. Workers
+        # rebuild it on first read via _ensure_open().
+        d = self.__dict__.copy()
+        d["_perm"] = None
+        return d
+
+    def __setstate__(self, d: Dict[str, Any]) -> None:
+        self.__dict__.update(d)
+
+    def _ensure_open(self) -> None:
+        if self._perm is None:
+            self._perm = _open_permutation(self._state)
+
+    def __len__(self) -> int:
+        self._ensure_open()
+        return len(self._perm)
+
+    def __getitem__(self, index: int) -> Any:
+        self._ensure_open()
+        return self._perm[index]
+
+    def __getitems__(self, indices: List[int]) -> Any:
+        self._ensure_open()
+        return self._perm.fetch(indices)
+
+
+class LanceIterableTorchDataset(_torch_data.IterableDataset):
+    """
+    PyTorch iterable-style ``IterableDataset`` over a LanceDB permutation.
+
+    Yields batches in the order defined by the underlying ``Permutation``.
+    With ``num_workers > 1`` each worker iterates the permutation
+    independently — for sharded iteration use the map-style
+    ``LanceTorchDataset`` together with a sampler.
+
+    Constructor arguments mirror ``LanceTorchDataset``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._inner = LanceTorchDataset(*args, **kwargs)
+
+    def __getstate__(self) -> Dict[str, Any]:
+        return {"_inner": self._inner.__getstate__()}
+
+    def __setstate__(self, d: Dict[str, Any]) -> None:
+        self._inner = LanceTorchDataset.__new__(LanceTorchDataset)
+        self._inner.__setstate__(d["_inner"])
+
+    def __iter__(self):
+        self._inner._ensure_open()
+        return iter(self._inner._perm)
diff --git a/python/python/tests/test_torch_dataset.py b/python/python/tests/test_torch_dataset.py
new file mode 100644
index 000000000..be53a61cb
--- /dev/null
+++ b/python/python/tests/test_torch_dataset.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import pickle
+
+import pyarrow as pa
+import pytest
+
+from lancedb import connect
+from lancedb.permutation import permutation_builder
+
+torch = pytest.importorskip("torch")
+from lancedb.integrations.torch import (  # noqa: E402
+    LanceIterableTorchDataset,
+    LanceTorchDataset,
+)
+
+
+@pytest.fixture
+def db_path(tmp_path):
+    """LanceTorchDataset needs a real, on-disk DB so workers can re-open it."""
+    return tmp_path
+
+
+def _make_table(db_path, name="imgs", n=20):
+    db = connect(db_path)
+    return db.create_table(
+        name,
+        pa.table({"x": [float(i) for i in range(n)], "y": list(range(n))}),
+    )
+
+
+def test_basic_len_and_getitem(db_path):
+    tbl = _make_table(db_path)
+    ds = LanceTorchDataset(tbl)
+    assert len(ds) == 20
+    row = ds[0]
+    # Default ("python") format = list of dicts; __getitem__ wraps a single index.
+    assert isinstance(row, list)
+    assert row[0] == {"x": 0.0, "y": 0}
+
+
+def test_getitems_uses_fetch(db_path):
+    tbl = _make_table(db_path)
+    ds = LanceTorchDataset(tbl)
+    rows = ds.__getitems__([0, 2, 4])
+    assert rows == [
+        {"x": 0.0, "y": 0},
+        {"x": 2.0, "y": 2},
+        {"x": 4.0, "y": 4},
+    ]
+
+
+def test_dataloader_default_collate(db_path):
+    tbl = _make_table(db_path, n=40)
+    ds = LanceTorchDataset(tbl)
+    loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False)
+    batch = next(iter(loader))
+    # default collate stacks list-of-dicts into dict-of-tensors
+    assert isinstance(batch, dict)
+    assert batch["x"].size() == (8,)
+    assert batch["y"].size() == (8,)
+
+
+def test_picklable(db_path):
+    tbl = _make_table(db_path)
+    ds = LanceTorchDataset(tbl, columns=["x"])
+
+    # Force open then ensure pickle drops the Rust handle.
+    _ = len(ds)
+    blob = pickle.dumps(ds)
+    restored: LanceTorchDataset = pickle.loads(blob)
+    # Rust state should not survive pickling.
+    assert restored._perm is None
+    # …but the dataset must work after re-opening transparently.
+    assert len(restored) == 20
+    assert restored[0] == [{"x": 0.0}]
+
+
+def test_dataloader_with_workers(db_path):
+    tbl = _make_table(db_path, n=32)
+    ds = LanceTorchDataset(tbl)
+    loader = torch.utils.data.DataLoader(
+        ds, batch_size=4, num_workers=2, shuffle=False
+    )
+    batches = list(loader)
+    seen = []
+    for b in batches:
+        seen.extend(b["x"].tolist())
+    assert sorted(seen) == [float(i) for i in range(32)]
+
+
+def test_with_permutation_table(db_path):
+    tbl = _make_table(db_path, n=30)
+    db = connect(db_path)
+    perm_tbl = (
+        permutation_builder(tbl)
+        .split_random(ratios=[0.5, 0.5], seed=1, split_names=["train", "test"])
+        .persist(db, "imgs_perm")
+        .execute()
+    )
+    ds = LanceTorchDataset(tbl, permutation_table=perm_tbl, split="train")
+    # Should pickle/restore the permutation table reference too.
+    blob = pickle.dumps(ds)
+    restored = pickle.loads(blob)
+    assert len(restored) == 15
+
+
+def test_format_passthrough_dataloader(db_path):
+    """Custom `format` is forwarded to the underlying Permutation."""
+    tbl = _make_table(db_path, n=20)
+    ds = LanceTorchDataset(tbl, format="arrow")
+    # Arrow batches don't go through default_collate, so use a no-op collate.
+    loader = torch.utils.data.DataLoader(
+        ds, batch_size=5, shuffle=False, collate_fn=lambda x: x
+    )
+    batch = next(iter(loader))
+    assert isinstance(batch, pa.RecordBatch)
+    assert batch.num_rows == 5
+
+
+def test_iterable_dataset(db_path):
+    tbl = _make_table(db_path, n=20)
+    ds = LanceIterableTorchDataset(tbl, batch_size=5)
+    batches = list(ds)
+    # default batch size + skip_last_batch=True yields full-size batches only
+    assert len(batches) == 4
+    assert all(len(b) == 5 for b in batches)
+
+
+def test_uri_table_name_constructor(db_path):
+    _make_table(db_path)
+    ds = LanceTorchDataset(uri=str(db_path), table_name="imgs")
+    assert len(ds) == 20
+    assert ds[0] == [{"x": 0.0, "y": 0}]
+
+
+def test_constructor_validates_args():
+    with pytest.raises(ValueError, match="table"):
+        LanceTorchDataset()