mirror of https://github.com/lancedb/lancedb.git (synced 2026-05-13 18:10:41 +00:00)
fix(python): make Permutation picklable for PyTorch multiprocessing (#3335)
## Summary

When PyTorch is used with multiprocessing and the mp mode is spawn, the Permutation needs to be pickled. It could not be pickled because `Table` and `Connection` are not serializable. This PR adds pickle support to Permutation without adding general pickle support to `Table` or `Connection`. To add general support we probably need to start by adding serialization in the namespace client. In the meantime, this PR enables pickling by adding special cases for:

* In-memory tables (just serialize as Arrow IPC)
* Native tables (serialize the URI)

If a user is not using one of the above cases (e.g. using a remote connection) then they will need to provide a connection factory that can be pickled.

## Breaking change

`PermutationBuilder.persist(...)` is removed from the Python bindings; the permutation table is now always in-memory. The underlying Rust `PermutationBuilder::persist` API is untouched and can be re-exposed later if needed. It probably won't make sense to do that until we have a way to serialize `Table` and `Connection`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
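Below is a minimal sketch of the pickle round-trip this enables, adapted from the tests in the diff that follows. The database path `/data/lance_db`, the table name `train`, and the `open_native_table` helper are illustrative assumptions, not part of the shipped API:

```python
import functools
import pickle

import lancedb
import pyarrow as pa
from lancedb.permutation import Permutation


# Module-level factory: pickle resolves it by name in the spawned worker,
# so it must not be a lambda or a closure over local state.
def open_native_table(uri: str, table_name: str):
    return lancedb.connect(uri).open_table(table_name)


db = lancedb.connect("/data/lance_db")  # placeholder path
db.create_table("train", pa.table({"a": range(1000)}))

# Bind the URI so the factory takes only the table name.
factory = functools.partial(open_native_table, "/data/lance_db")
permutation = Permutation.identity(factory("train")).with_connection_factory(factory)

# The same round-trip a DataLoader performs when shipping the dataset
# to spawned worker processes.
restored = pickle.loads(pickle.dumps(permutation))
assert len(restored) == 1000
```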
@@ -442,7 +442,7 @@ class AsyncPermutationBuilder:
     async def execute(self) -> Table: ...

 def async_permutation_builder(
-    table: Table, dest_table_name: str
+    table: Table,
 ) -> AsyncPermutationBuilder: ...
 def fts_query_to_json(query: Any) -> str: ...
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

-from deprecation import deprecated
-from lancedb import AsyncConnection, DBConnection
-import pyarrow as pa
+import copy
+import json
+
+from deprecation import deprecated
+import pyarrow as pa

 from ._lancedb import async_permutation_builder, PermutationReader
 from .table import LanceTable
 from .background_loop import LOOP
@@ -36,10 +37,7 @@ class PermutationBuilder:
     be referenced by name in the future. If names are not provided then they can only
     be referenced by their ordinal index. There is no requirement to name every split.

-    By default, the permutation will be stored in memory and will be lost when the
-    program exits. To persist the permutation (for very large datasets or to share
-    the permutation across multiple workers) use the [persist](#persist) method to
-    create a permanent table.
+    The permutation is stored in memory and will be lost when the program exits.
     """

     def __init__(self, table: LanceTable):
@@ -51,15 +49,6 @@ class PermutationBuilder:
         """
         self._async = async_permutation_builder(table)

-    def persist(
-        self, database: Union[DBConnection, AsyncConnection], table_name: str
-    ) -> "PermutationBuilder":
-        """
-        Persist the permutation to the given database.
-        """
-        self._async.persist(database, table_name)
-        return self
-
     def split_random(
         self,
         *,
@@ -380,20 +369,44 @@ class Permutation:
     def __init__(
         self,
-        reader: PermutationReader,
+        base_table: LanceTable,
+        permutation_table: Optional[LanceTable],
+        split: int,
         selection: dict[str, str],
         batch_size: int,
         transform_fn: Callable[pa.RecordBatch, Any],
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        connection_factory: Optional[Callable[[str], LanceTable]] = None,
+        _reader: Optional[PermutationReader] = None,
     ):
         """
         Internal constructor. Use [from_tables](#from_tables) instead.
         """
-        assert reader is not None, "reader is required"
+        assert base_table is not None, "base_table is required"
         assert selection is not None, "selection is required"
-        self.reader = reader
+        self.base_table = base_table
+        self.permutation_table = permutation_table
+        self.split = split
         self.selection = selection
         self.transform_fn = transform_fn
         self.batch_size = batch_size
+        self.offset = offset
+        self.limit = limit
+        self.connection_factory = connection_factory
+        if _reader is None:
+            _reader = LOOP.run(self._build_reader())
+        self.reader: PermutationReader = _reader
+
+    async def _build_reader(self) -> PermutationReader:
+        reader = await PermutationReader.from_tables(
+            self.base_table, self.permutation_table, self.split
+        )
+        if self.offset is not None:
+            reader = await reader.with_offset(self.offset)
+        if self.limit is not None:
+            reader = await reader.with_limit(self.limit)
+        return reader

     def _with_selection(self, selection: dict[str, str]) -> "Permutation":
         """
@@ -402,21 +415,97 @@ class Permutation:
         Does not validation of the selection and it replaces it entirely. This is not
         intended for public use.
         """
-        return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
-
-    def _with_reader(self, reader: PermutationReader) -> "Permutation":
-        """
-        Creates a new permutation with the given reader
-
-        This is an internal method and should not be used directly.
-        """
-        return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
+        new = copy.copy(self)
+        new.selection = selection
+        return new

     def with_batch_size(self, batch_size: int) -> "Permutation":
         """
         Creates a new permutation with the given batch size
         """
-        return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
+        new = copy.copy(self)
+        new.batch_size = batch_size
+        return new
+
+    def with_connection_factory(
+        self, connection_factory: Callable[[str], LanceTable]
+    ) -> "Permutation":
+        """
+        Creates a new permutation that will use ``connection_factory`` to reopen
+        the base table when this permutation is unpickled in a worker process.
+
+        The factory is a callable that takes a single argument — the base table
+        name — and returns a [LanceTable]. It must be picklable; the worker
+        will pickle it via standard ``pickle`` and call it to recover the base
+        table. Picklable callables in practice means top-level (module-level)
+        functions, ``functools.partial`` of such functions, or instances of
+        picklable classes implementing ``__call__``. Lambdas and closures over
+        local variables don't pickle with the default protocol.
+
+        Setting a factory is necessary when the URI alone is not enough to
+        re-open the connection — most importantly for LanceDB Cloud (``db://``)
+        connections, where ``api_key`` and ``region`` aren't recoverable from
+        the connection object after construction.
+
+        For local file or cloud-storage paths the factory is optional: if not
+        set, ``__getstate__`` falls back to capturing
+        ``(uri, storage_options, namespace_path)`` and re-opening via
+        ``lancedb.connect(uri, storage_options=...)``.
+
+        Examples
+        --------
+        Basic native (file-system path), parameterized via ``functools.partial``::
+
+            import functools, lancedb
+            from lancedb.permutation import Permutation
+
+            def open_native_table(uri: str, table_name: str):
+                return lancedb.connect(uri).open_table(table_name)
+
+            factory = functools.partial(open_native_table, "/data/lance_db")
+            permutation = Permutation.identity(
+                factory("training")
+            ).with_connection_factory(factory)
+
+        Native via :func:`lancedb.connect_namespace` (e.g. a directory- or
+        REST-backed namespace client). The factory takes the
+        implementation name and properties dict as partial-bound args so
+        the worker can rebuild the same namespace connection::
+
+            def open_via_namespace(
+                impl: str, properties: dict[str, str], table_name: str,
+            ):
+                return lancedb.connect_namespace(impl, properties).open_table(
+                    table_name,
+                )
+
+            factory = functools.partial(
+                open_via_namespace,
+                "dir",
+                {"root": "/data/lance_db"},
+            )
+
+        LanceDB Cloud, reading credentials from env vars at worker startup
+        so secrets aren't pickled into the dataset::
+
+            import os, lancedb
+
+            def open_remote_table(table_name: str):
+                db = lancedb.connect(
+                    "db://my-database",
+                    api_key=os.environ["LANCEDB_API_KEY"],
+                    region=os.environ.get("LANCEDB_REGION", "us-east-1"),
+                )
+                return db.open_table(table_name)
+
+            permutation = Permutation.identity(
+                open_remote_table("training")
+            ).with_connection_factory(open_remote_table)
+        """
+        assert connection_factory is not None, "connection_factory is required"
+        new = copy.copy(self)
+        new.connection_factory = connection_factory
+        return new

     @classmethod
     def identity(cls, table: LanceTable) -> "Permutation":
@@ -489,11 +578,126 @@ class Permutation:
             schema = await reader.output_schema(None)
             initial_selection = {name: name for name in schema.names}
             return cls(
-                reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
+                base_table,
+                permutation_table,
+                split,
+                initial_selection,
+                DEFAULT_BATCH_SIZE,
+                Transforms.arrow2python,
+                _reader=reader,
             )

         return LOOP.run(do_from_tables())

+    def __getstate__(self) -> dict[str, Any]:
+        """Build a picklable state dict for this permutation.
+
+        The base table is captured either via a user-supplied
+        ``connection_factory`` (see [with_connection_factory]) or, as a
+        fallback, by introspecting ``(uri, storage_options, namespace_path)``
+        on the connection. The permutation table — always an in-memory
+        LanceDB table — is captured as a pyarrow Table (which pickles via
+        Arrow IPC natively). The reader is dropped from the wire format;
+        ``__setstate__`` rebuilds it from the restored tables.
+        """
+        permutation_data: Optional[pa.Table] = None
+        if self.permutation_table is not None:
+            permutation_data = self.permutation_table.to_arrow()
+
+        common = {
+            "base_table_name": self.base_table.name,
+            "permutation_data": permutation_data,
+            "split": self.split,
+            "selection": self.selection,
+            "batch_size": self.batch_size,
+            "transform_fn": self.transform_fn,
+            "offset": self.offset,
+            "limit": self.limit,
+            "connection_factory": self.connection_factory,
+        }
+
+        if self.connection_factory is not None:
+            # The factory carries enough state to recover the base table on
+            # its own; we don't need to capture the URI / storage options /
+            # namespace from the existing connection.
+            return common
+
+        # URI-introspection fallback: only viable for native (OSS) connections
+        # where (uri, storage_options) is enough to reopen. Remote / cloud
+        # connections don't expose recoverable api_key / region — those users
+        # must call with_connection_factory().
+        try:
+            base_uri = self.base_table._conn.uri
+            storage_options = self.base_table._conn.storage_options
+        except AttributeError as e:
+            raise ValueError(
+                "Cannot pickle this Permutation: the base table's connection "
+                "does not expose a uri/storage_options, which usually means it "
+                "is a remote (LanceDB Cloud) connection. Call "
+                "Permutation.with_connection_factory(...) first to provide a "
+                "picklable callable that re-opens the base table from a worker "
+                "process."
+            ) from e
+
+        if base_uri.startswith("memory://"):
+            # In-memory base tables don't exist in any worker process by
+            # default, so dump the entire base table into the pickle. This
+            # can be expensive for large datasets — users with large
+            # in-memory base tables should either persist them or set a
+            # connection_factory.
+            return {
+                **common,
+                "base_table_data": self.base_table.to_arrow(),
+            }
+
+        return {
+            **common,
+            "base_table_uri": base_uri,
+            "base_table_namespace": self.base_table._namespace_path,
+            "base_table_storage_options": storage_options,
+        }
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        from . import connect
+
+        connection_factory = state["connection_factory"]
+        if connection_factory is not None:
+            base_table = connection_factory(state["base_table_name"])
+        elif "base_table_data" in state:
+            # In-memory base table inlined into the pickle; rebuild the same
+            # way we rebuild the in-memory permutation table.
+            mem_db = connect("memory://")
+            base_table = mem_db.create_table(
+                state["base_table_name"], state["base_table_data"]
+            )
+        else:
+            base_db = connect(
+                state["base_table_uri"],
+                storage_options=state["base_table_storage_options"],
+            )
+            base_table = base_db.open_table(
+                state["base_table_name"],
+                namespace_path=state["base_table_namespace"] or None,
+            )
+
+        permutation_table: Optional[LanceTable] = None
+        if state["permutation_data"] is not None:
+            mem_db = connect("memory://")
+            permutation_table = mem_db.create_table(
+                "permutation", state["permutation_data"]
+            )
+
+        self.base_table = base_table
+        self.permutation_table = permutation_table
+        self.split = state["split"]
+        self.selection = state["selection"]
+        self.batch_size = state["batch_size"]
+        self.transform_fn = state["transform_fn"]
+        self.offset = state["offset"]
+        self.limit = state["limit"]
+        self.connection_factory = connection_factory
+        self.reader = LOOP.run(self._build_reader())
+
     @property
     def schema(self) -> pa.Schema:
         async def do_output_schema():
@@ -760,7 +964,9 @@ class Permutation:
         for expensive operations such as image decoding.
         """
         assert transform is not None, "transform is required"
-        return Permutation(self.reader, self.selection, self.batch_size, transform)
+        new = copy.copy(self)
+        new.transform_fn = transform
+        return new

     def __getitem__(self, index: int) -> Any:
         """
@@ -795,12 +1001,10 @@ class Permutation:
         """
         Skip the first `skip` rows of the permutation
         """
-
-        async def do_with_skip():
-            reader = await self.reader.with_offset(skip)
-            return self._with_reader(reader)
-
-        return LOOP.run(do_with_skip())
+        new = copy.copy(self)
+        new.offset = skip
+        new.reader = LOOP.run(new._build_reader())
+        return new

     @deprecated(details="Use with_take instead")
     def take(self, limit: int) -> "Permutation":
@@ -818,12 +1022,10 @@ class Permutation:
         """
        Limit the permutation to `limit` rows (following any `skip`)
         """
-
-        async def do_with_take():
-            reader = await self.reader.with_limit(limit)
-            return self._with_reader(reader)
-
-        return LOOP.run(do_with_take())
+        new = copy.copy(self)
+        new.limit = limit
+        new.reader = LOOP.run(new._build_reader())
+        return new

     @deprecated(details="Use with_repeat instead")
     def repeat(self, times: int) -> "Permutation":
@@ -9,21 +9,6 @@ from lancedb import DBConnection, Table, connect
 from lancedb.permutation import Permutation, Permutations, permutation_builder


-def test_permutation_persistence(tmp_path):
-    db = connect(tmp_path)
-    tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
-
-    permutation_tbl = (
-        permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
-    )
-    assert permutation_tbl.count_rows() == 100
-
-    re_open = db.open_table("test_permutation")
-    assert re_open.count_rows() == 100
-
-    assert permutation_tbl.to_arrow() == re_open.to_arrow()
-
-
 def test_split_random_ratios(mem_db):
     """Test random splitting with ratios."""
     tbl = mem_db.create_table(
@@ -1,14 +1,27 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

+import functools
+import pickle
+
+import lancedb
 import pyarrow as pa
 import pytest
 from lancedb.util import tbl_to_tensor
-from lancedb.permutation import Permutation
+from lancedb.permutation import Permutation, Permutations, permutation_builder

 torch = pytest.importorskip("torch")


+def _open_native_table(uri: str, table_name: str):
+    """Top-level connection factory used by the explicit-factory pickle test.
+
+    Defined at module scope so that pickle can resolve it by name in the
+    worker / unpickling process.
+    """
+    return lancedb.connect(uri).open_table(table_name)
+
+
 def test_table_dataloader(mem_db):
     table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
     dataloader = torch.utils.data.DataLoader(
@@ -40,3 +53,96 @@ def test_permutation_dataloader(mem_db):
     for batch in dataloader:
         assert batch.size(0) == 1
         assert batch.size(1) == 10
+
+
+def test_permutation_is_picklable(tmp_db):
+    """A Permutation must be picklable so it can be used with PyTorch's
+    DataLoader when num_workers > 0 (which uses multiprocessing and pickles
+    the dataset to pass it to worker processes)."""
+    table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
+    permutation = Permutation.identity(table)
+
+    pickled = pickle.dumps(permutation)
+    restored = pickle.loads(pickled)
+
+    assert len(restored) == 1000
+    rows = restored.__getitems__([0, 1, 2])
+    assert rows == [{"a": 0}, {"a": 1}, {"a": 2}]
+
+
+def test_permutation_with_memory_base_is_picklable(mem_db):
+    """An in-memory base table is inlined into the pickle as Arrow IPC bytes
+    and rebuilt on the other side as an in-memory LanceTable, so the
+    Permutation round-trips even though the original database can't be
+    reopened across processes."""
+    table = mem_db.create_table("test_table", pa.table({"a": range(50)}))
+    permutation = Permutation.identity(table)
+
+    restored = pickle.loads(pickle.dumps(permutation))
+
+    assert len(restored) == 50
+    assert restored.__getitems__([0, 10, 49]) == [{"a": 0}, {"a": 10}, {"a": 49}]
+
+
+def test_permutation_dataloader_multiprocessing(tmp_db):
+    """Using a Permutation with a PyTorch DataLoader that has num_workers > 0
+    must work end-to-end. Each worker process gets a pickled copy of the
+    dataset and reads batches from it."""
+    table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
+    permutation = Permutation.identity(table)
+
+    dataloader = torch.utils.data.DataLoader(
+        permutation,
+        batch_size=10,
+        shuffle=True,
+        num_workers=2,
+        multiprocessing_context="spawn",
+    )
+    seen = 0
+    for batch in dataloader:
+        assert batch["a"].size(0) == 10
+        seen += batch["a"].size(0)
+    assert seen == 1000
+
+
+def test_permutation_pickle_with_connection_factory(tmp_path):
+    """When the user provides a connection_factory, pickling should round-trip
+    through that factory rather than introspecting the connection URI. Useful
+    for remote / cloud connections where the URI alone isn't reopenable."""
+    db = lancedb.connect(tmp_path)
+    db.create_table("test_table", pa.table({"a": range(50)}))
+
+    factory = functools.partial(_open_native_table, str(tmp_path))
+    permutation = Permutation.identity(factory("test_table")).with_connection_factory(
+        factory
+    )
+
+    restored = pickle.loads(pickle.dumps(permutation))
+
+    assert len(restored) == 50
+    # The factory survives pickling and is what powered base-table reopen.
+    assert restored.connection_factory is not None
+    assert restored.connection_factory.func is _open_native_table
+    assert restored.__getitems__([0, 1, 2]) == [{"a": 0}, {"a": 1}, {"a": 2}]
+
+
+def test_permutation_with_builder_is_picklable(tmp_db):
+    """A Permutation built from a non-identity permutation table must round-trip
+    through pickle while preserving the row order defined by the permutation."""
+    table = tmp_db.create_table("test_table", pa.table({"a": range(100)}))
+    perm_tbl = (
+        permutation_builder(table)
+        .split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"])
+        .shuffle(seed=42)
+        .execute()
+    )
+    permutations = Permutations(table, perm_tbl)
+    permutation = permutations["train"]
+
+    indices = list(range(len(permutation)))
+    expected = permutation.__getitems__(indices)
+
+    restored = pickle.loads(pickle.dumps(permutation))
+
+    assert len(restored) == len(permutation)
+    assert restored.__getitems__(indices) == expected
@@ -3,9 +3,7 @@

 use std::sync::{Arc, Mutex};

-use crate::{
-    arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
-};
+use crate::{arrow::RecordBatchStream, error::PythonErrorExt, table::Table};
 use arrow::pyarrow::{PyArrowType, ToPyArrow};
 use lancedb::{
     dataloader::permutation::{
@@ -80,24 +78,6 @@ impl PyAsyncPermutationBuilder {

 #[pymethods]
 impl PyAsyncPermutationBuilder {
-    #[pyo3(signature = (database, table_name))]
-    pub fn persist(
-        slf: PyRefMut<'_, Self>,
-        database: Bound<'_, PyAny>,
-        table_name: String,
-    ) -> PyResult<Self> {
-        let conn = if database.hasattr("_conn")? {
-            database
-                .getattr("_conn")?
-                .getattr("_inner")?
-                .cast_into::<Connection>()?
-        } else {
-            database.getattr("_inner")?.cast_into::<Connection>()?
-        };
-        let database = conn.borrow().database()?;
-        slf.modify(|builder| builder.persist(database, table_name))
-    }
-
     #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
     pub fn split_random(
         slf: PyRefMut<'_, Self>,
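To close the loop, here is a hedged sketch of the end-to-end consumer of this pickling path, mirroring `test_permutation_dataloader_multiprocessing` from the diff above; the `/tmp/lance_demo` path and `train` table name are placeholders:

```python
import lancedb
import pyarrow as pa
import torch
from lancedb.permutation import Permutation


def main():
    # Placeholder path; any native (file-system) LanceDB database works.
    db = lancedb.connect("/tmp/lance_demo")
    table = db.create_table("train", pa.table({"a": range(1000)}))

    # Each spawned worker receives a pickled copy of the Permutation and
    # reopens the base table from its captured URI (the native-table case).
    loader = torch.utils.data.DataLoader(
        Permutation.identity(table),
        batch_size=10,
        num_workers=2,
        multiprocessing_context="spawn",
    )
    for batch in loader:
        assert batch["a"].size(0) == 10


if __name__ == "__main__":  # guard is required with spawn multiprocessing
    main()
```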