diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 2298a9473..b33f89e40 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -442,7 +442,7 @@ class AsyncPermutationBuilder: async def execute(self) -> Table: ... def async_permutation_builder( - table: Table, dest_table_name: str + table: Table, ) -> AsyncPermutationBuilder: ... def fts_query_to_json(query: Any) -> str: ... diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 724a0fd25..91532f0a7 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors -from deprecation import deprecated -from lancedb import AsyncConnection, DBConnection -import pyarrow as pa +import copy import json +from deprecation import deprecated +import pyarrow as pa + from ._lancedb import async_permutation_builder, PermutationReader from .table import LanceTable from .background_loop import LOOP @@ -36,10 +37,7 @@ class PermutationBuilder: be referenced by name in the future. If names are not provided then they can only be referenced by their ordinal index. There is no requirement to name every split. - By default, the permutation will be stored in memory and will be lost when the - program exits. To persist the permutation (for very large datasets or to share - the permutation across multiple workers) use the [persist](#persist) method to - create a permanent table. + The permutation is stored in memory and will be lost when the program exits. """ def __init__(self, table: LanceTable): @@ -51,15 +49,6 @@ class PermutationBuilder: """ self._async = async_permutation_builder(table) - def persist( - self, database: Union[DBConnection, AsyncConnection], table_name: str - ) -> "PermutationBuilder": - """ - Persist the permutation to the given database. - """ - self._async.persist(database, table_name) - return self - def split_random( self, *, @@ -380,20 +369,44 @@ class Permutation: def __init__( self, - reader: PermutationReader, + base_table: LanceTable, + permutation_table: Optional[LanceTable], + split: int, selection: dict[str, str], batch_size: int, transform_fn: Callable[pa.RecordBatch, Any], + offset: Optional[int] = None, + limit: Optional[int] = None, + connection_factory: Optional[Callable[[str], LanceTable]] = None, + _reader: Optional[PermutationReader] = None, ): """ Internal constructor. Use [from_tables](#from_tables) instead. """ - assert reader is not None, "reader is required" + assert base_table is not None, "base_table is required" assert selection is not None, "selection is required" - self.reader = reader + self.base_table = base_table + self.permutation_table = permutation_table + self.split = split self.selection = selection self.transform_fn = transform_fn self.batch_size = batch_size + self.offset = offset + self.limit = limit + self.connection_factory = connection_factory + if _reader is None: + _reader = LOOP.run(self._build_reader()) + self.reader: PermutationReader = _reader + + async def _build_reader(self) -> PermutationReader: + reader = await PermutationReader.from_tables( + self.base_table, self.permutation_table, self.split + ) + if self.offset is not None: + reader = await reader.with_offset(self.offset) + if self.limit is not None: + reader = await reader.with_limit(self.limit) + return reader def _with_selection(self, selection: dict[str, str]) -> "Permutation": """ @@ -402,21 +415,97 @@ class Permutation: Does not validation of the selection and it replaces it entirely. This is not intended for public use. """ - return Permutation(self.reader, selection, self.batch_size, self.transform_fn) - - def _with_reader(self, reader: PermutationReader) -> "Permutation": - """ - Creates a new permutation with the given reader - - This is an internal method and should not be used directly. - """ - return Permutation(reader, self.selection, self.batch_size, self.transform_fn) + new = copy.copy(self) + new.selection = selection + return new def with_batch_size(self, batch_size: int) -> "Permutation": """ Creates a new permutation with the given batch size """ - return Permutation(self.reader, self.selection, batch_size, self.transform_fn) + new = copy.copy(self) + new.batch_size = batch_size + return new + + def with_connection_factory( + self, connection_factory: Callable[[str], LanceTable] + ) -> "Permutation": + """ + Creates a new permutation that will use ``connection_factory`` to reopen + the base table when this permutation is unpickled in a worker process. + + The factory is a callable that takes a single argument — the base table + name — and returns a [LanceTable]. It must be picklable; the worker + will pickle it via standard ``pickle`` and call it to recover the base + table. Picklable callables in practice means top-level (module-level) + functions, ``functools.partial`` of such functions, or instances of + picklable classes implementing ``__call__``. Lambdas and closures over + local variables don't pickle with the default protocol. + + Setting a factory is necessary when the URI alone is not enough to + re-open the connection — most importantly for LanceDB Cloud (``db://``) + connections, where ``api_key`` and ``region`` aren't recoverable from + the connection object after construction. + + For local file or cloud-storage paths the factory is optional: if not + set, ``__getstate__`` falls back to capturing + ``(uri, storage_options, namespace_path)`` and re-opening via + ``lancedb.connect(uri, storage_options=...)``. + + Examples + -------- + Basic native (file-system path), parameterized via ``functools.partial``:: + + import functools, lancedb + from lancedb.permutation import Permutation + + def open_native_table(uri: str, table_name: str): + return lancedb.connect(uri).open_table(table_name) + + factory = functools.partial(open_native_table, "/data/lance_db") + permutation = Permutation.identity( + factory("training") + ).with_connection_factory(factory) + + Native via :func:`lancedb.connect_namespace` (e.g. a directory- or + REST-backed namespace client). The factory takes the + implementation name and properties dict as partial-bound args so + the worker can rebuild the same namespace connection:: + + def open_via_namespace( + impl: str, properties: dict[str, str], table_name: str, + ): + return lancedb.connect_namespace(impl, properties).open_table( + table_name, + ) + + factory = functools.partial( + open_via_namespace, + "dir", + {"root": "/data/lance_db"}, + ) + + LanceDB Cloud, reading credentials from env vars at worker startup + so secrets aren't pickled into the dataset:: + + import os, lancedb + + def open_remote_table(table_name: str): + db = lancedb.connect( + "db://my-database", + api_key=os.environ["LANCEDB_API_KEY"], + region=os.environ.get("LANCEDB_REGION", "us-east-1"), + ) + return db.open_table(table_name) + + permutation = Permutation.identity( + open_remote_table("training") + ).with_connection_factory(open_remote_table) + """ + assert connection_factory is not None, "connection_factory is required" + new = copy.copy(self) + new.connection_factory = connection_factory + return new @classmethod def identity(cls, table: LanceTable) -> "Permutation": @@ -489,11 +578,126 @@ class Permutation: schema = await reader.output_schema(None) initial_selection = {name: name for name in schema.names} return cls( - reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python + base_table, + permutation_table, + split, + initial_selection, + DEFAULT_BATCH_SIZE, + Transforms.arrow2python, + _reader=reader, ) return LOOP.run(do_from_tables()) + def __getstate__(self) -> dict[str, Any]: + """Build a picklable state dict for this permutation. + + The base table is captured either via a user-supplied + ``connection_factory`` (see [with_connection_factory]) or, as a + fallback, by introspecting ``(uri, storage_options, namespace_path)`` + on the connection. The permutation table — always an in-memory + LanceDB table — is captured as a pyarrow Table (which pickles via + Arrow IPC natively). The reader is dropped from the wire format; + ``__setstate__`` rebuilds it from the restored tables. + """ + permutation_data: Optional[pa.Table] = None + if self.permutation_table is not None: + permutation_data = self.permutation_table.to_arrow() + + common = { + "base_table_name": self.base_table.name, + "permutation_data": permutation_data, + "split": self.split, + "selection": self.selection, + "batch_size": self.batch_size, + "transform_fn": self.transform_fn, + "offset": self.offset, + "limit": self.limit, + "connection_factory": self.connection_factory, + } + + if self.connection_factory is not None: + # The factory carries enough state to recover the base table on + # its own; we don't need to capture the URI / storage options / + # namespace from the existing connection. + return common + + # URI-introspection fallback: only viable for native (OSS) connections + # where (uri, storage_options) is enough to reopen. Remote / cloud + # connections don't expose recoverable api_key / region — those users + # must call with_connection_factory(). + try: + base_uri = self.base_table._conn.uri + storage_options = self.base_table._conn.storage_options + except AttributeError as e: + raise ValueError( + "Cannot pickle this Permutation: the base table's connection " + "does not expose a uri/storage_options, which usually means it " + "is a remote (LanceDB Cloud) connection. Call " + "Permutation.with_connection_factory(...) first to provide a " + "picklable callable that re-opens the base table from a worker " + "process." + ) from e + + if base_uri.startswith("memory://"): + # In-memory base tables don't exist in any worker process by + # default, so dump the entire base table into the pickle. This + # can be expensive for large datasets — users with large + # in-memory base tables should either persist them or set a + # connection_factory. + return { + **common, + "base_table_data": self.base_table.to_arrow(), + } + + return { + **common, + "base_table_uri": base_uri, + "base_table_namespace": self.base_table._namespace_path, + "base_table_storage_options": storage_options, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + from . import connect + + connection_factory = state["connection_factory"] + if connection_factory is not None: + base_table = connection_factory(state["base_table_name"]) + elif "base_table_data" in state: + # In-memory base table inlined into the pickle; rebuild the same + # way we rebuild the in-memory permutation table. + mem_db = connect("memory://") + base_table = mem_db.create_table( + state["base_table_name"], state["base_table_data"] + ) + else: + base_db = connect( + state["base_table_uri"], + storage_options=state["base_table_storage_options"], + ) + base_table = base_db.open_table( + state["base_table_name"], + namespace_path=state["base_table_namespace"] or None, + ) + + permutation_table: Optional[LanceTable] = None + if state["permutation_data"] is not None: + mem_db = connect("memory://") + permutation_table = mem_db.create_table( + "permutation", state["permutation_data"] + ) + + self.base_table = base_table + self.permutation_table = permutation_table + self.split = state["split"] + self.selection = state["selection"] + self.batch_size = state["batch_size"] + self.transform_fn = state["transform_fn"] + self.offset = state["offset"] + self.limit = state["limit"] + self.connection_factory = connection_factory + self.reader = LOOP.run(self._build_reader()) + @property def schema(self) -> pa.Schema: async def do_output_schema(): @@ -760,7 +964,9 @@ class Permutation: for expensive operations such as image decoding. """ assert transform is not None, "transform is required" - return Permutation(self.reader, self.selection, self.batch_size, transform) + new = copy.copy(self) + new.transform_fn = transform + return new def __getitem__(self, index: int) -> Any: """ @@ -795,12 +1001,10 @@ class Permutation: """ Skip the first `skip` rows of the permutation """ - - async def do_with_skip(): - reader = await self.reader.with_offset(skip) - return self._with_reader(reader) - - return LOOP.run(do_with_skip()) + new = copy.copy(self) + new.offset = skip + new.reader = LOOP.run(new._build_reader()) + return new @deprecated(details="Use with_take instead") def take(self, limit: int) -> "Permutation": @@ -818,12 +1022,10 @@ class Permutation: """ Limit the permutation to `limit` rows (following any `skip`) """ - - async def do_with_take(): - reader = await self.reader.with_limit(limit) - return self._with_reader(reader) - - return LOOP.run(do_with_take()) + new = copy.copy(self) + new.limit = limit + new.reader = LOOP.run(new._build_reader()) + return new @deprecated(details="Use with_repeat instead") def repeat(self, times: int) -> "Permutation": diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index bb92ba0ba..96d77f9d1 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -9,21 +9,6 @@ from lancedb import DBConnection, Table, connect from lancedb.permutation import Permutation, Permutations, permutation_builder -def test_permutation_persistence(tmp_path): - db = connect(tmp_path) - tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)})) - - permutation_tbl = ( - permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute() - ) - assert permutation_tbl.count_rows() == 100 - - re_open = db.open_table("test_permutation") - assert re_open.count_rows() == 100 - - assert permutation_tbl.to_arrow() == re_open.to_arrow() - - def test_split_random_ratios(mem_db): """Test random splitting with ratios.""" tbl = mem_db.create_table( diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py index ef1c5e73b..0ca1de3e8 100644 --- a/python/python/tests/test_torch.py +++ b/python/python/tests/test_torch.py @@ -1,14 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +import functools +import pickle + +import lancedb import pyarrow as pa import pytest from lancedb.util import tbl_to_tensor -from lancedb.permutation import Permutation +from lancedb.permutation import Permutation, Permutations, permutation_builder torch = pytest.importorskip("torch") +def _open_native_table(uri: str, table_name: str): + """Top-level connection factory used by the explicit-factory pickle test. + + Defined at module scope so that pickle can resolve it by name in the + worker / unpickling process. + """ + return lancedb.connect(uri).open_table(table_name) + + def test_table_dataloader(mem_db): table = mem_db.create_table("test_table", pa.table({"a": range(1000)})) dataloader = torch.utils.data.DataLoader( @@ -40,3 +53,96 @@ def test_permutation_dataloader(mem_db): for batch in dataloader: assert batch.size(0) == 1 assert batch.size(1) == 10 + + +def test_permutation_is_picklable(tmp_db): + """A Permutation must be picklable so it can be used with PyTorch's + DataLoader when num_workers > 0 (which uses multiprocessing and pickles + the dataset to pass it to worker processes).""" + table = tmp_db.create_table("test_table", pa.table({"a": range(1000)})) + permutation = Permutation.identity(table) + + pickled = pickle.dumps(permutation) + restored = pickle.loads(pickled) + + assert len(restored) == 1000 + rows = restored.__getitems__([0, 1, 2]) + assert rows == [{"a": 0}, {"a": 1}, {"a": 2}] + + +def test_permutation_with_memory_base_is_picklable(mem_db): + """An in-memory base table is inlined into the pickle as Arrow IPC bytes + and rebuilt on the other side as an in-memory LanceTable, so the + Permutation round-trips even though the original database can't be + reopened across processes.""" + table = mem_db.create_table("test_table", pa.table({"a": range(50)})) + permutation = Permutation.identity(table) + + restored = pickle.loads(pickle.dumps(permutation)) + + assert len(restored) == 50 + assert restored.__getitems__([0, 10, 49]) == [{"a": 0}, {"a": 10}, {"a": 49}] + + +def test_permutation_dataloader_multiprocessing(tmp_db): + """Using a Permutation with a PyTorch DataLoader that has num_workers > 0 + must work end-to-end. Each worker process gets a pickled copy of the + dataset and reads batches from it.""" + table = tmp_db.create_table("test_table", pa.table({"a": range(1000)})) + permutation = Permutation.identity(table) + + dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=10, + shuffle=True, + num_workers=2, + multiprocessing_context="spawn", + ) + seen = 0 + for batch in dataloader: + assert batch["a"].size(0) == 10 + seen += batch["a"].size(0) + assert seen == 1000 + + +def test_permutation_pickle_with_connection_factory(tmp_path): + """When the user provides a connection_factory, pickling should round-trip + through that factory rather than introspecting the connection URI. Useful + for remote / cloud connections where the URI alone isn't reopenable.""" + db = lancedb.connect(tmp_path) + db.create_table("test_table", pa.table({"a": range(50)})) + + factory = functools.partial(_open_native_table, str(tmp_path)) + permutation = Permutation.identity(factory("test_table")).with_connection_factory( + factory + ) + + restored = pickle.loads(pickle.dumps(permutation)) + + assert len(restored) == 50 + # The factory survives pickling and is what powered base-table reopen. + assert restored.connection_factory is not None + assert restored.connection_factory.func is _open_native_table + assert restored.__getitems__([0, 1, 2]) == [{"a": 0}, {"a": 1}, {"a": 2}] + + +def test_permutation_with_builder_is_picklable(tmp_db): + """A Permutation built from a non-identity permutation table must round-trip + through pickle while preserving the row order defined by the permutation.""" + table = tmp_db.create_table("test_table", pa.table({"a": range(100)})) + perm_tbl = ( + permutation_builder(table) + .split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"]) + .shuffle(seed=42) + .execute() + ) + permutations = Permutations(table, perm_tbl) + permutation = permutations["train"] + + indices = list(range(len(permutation))) + expected = permutation.__getitems__(indices) + + restored = pickle.loads(pickle.dumps(permutation)) + + assert len(restored) == len(permutation) + assert restored.__getitems__(indices) == expected diff --git a/python/src/permutation.rs b/python/src/permutation.rs index ac20a2cc9..114825938 100644 --- a/python/src/permutation.rs +++ b/python/src/permutation.rs @@ -3,9 +3,7 @@ use std::sync::{Arc, Mutex}; -use crate::{ - arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table, -}; +use crate::{arrow::RecordBatchStream, error::PythonErrorExt, table::Table}; use arrow::pyarrow::{PyArrowType, ToPyArrow}; use lancedb::{ dataloader::permutation::{ @@ -80,24 +78,6 @@ impl PyAsyncPermutationBuilder { #[pymethods] impl PyAsyncPermutationBuilder { - #[pyo3(signature = (database, table_name))] - pub fn persist( - slf: PyRefMut<'_, Self>, - database: Bound<'_, PyAny>, - table_name: String, - ) -> PyResult { - let conn = if database.hasattr("_conn")? { - database - .getattr("_conn")? - .getattr("_inner")? - .cast_into::()? - } else { - database.getattr("_inner")?.cast_into::()? - }; - let database = conn.borrow().database()?; - slf.modify(|builder| builder.persist(database, table_name)) - } - #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))] pub fn split_random( slf: PyRefMut<'_, Self>,