feat: a utility for creating "permutation views" (#2552)

I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything, I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and allows
permutations to be persistent. In particular, I hope to be able to:

* Create a persistent permutation
  * This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
  splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually, I'd like to create integrations for other data loading
libraries, including rust & node.
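
To make the intended workflow concrete, here is a minimal sketch using the synchronous `permutation_builder` API added in this PR. The database path and table contents are hypothetical; only the builder chain is the new API.

```python
import lancedb
import pyarrow as pa

from lancedb.permutation import permutation_builder

# Hypothetical setup: any existing lancedb table works here.
db = lancedb.connect("./example-db")
tbl = db.create_table("events", pa.table({"id": range(100), "value": range(100)}))

# Build a persistent permutation: filter, split 80/20, then shuffle.
# The result is a regular table (with row_id and split_id columns),
# so the permutation survives across processes.
permutation_tbl = (
    permutation_builder(tbl, "events_permutation")
    .filter("value >= 10")
    .split_random(ratios=[0.8, 0.2], seed=42)
    .shuffle(seed=7)
    .execute()
)
print(permutation_tbl.count_rows())
```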
Commit 5a19cf15a6 (parent 3dcec724b7), authored by Weston Pace on 2025-10-09 18:07:31 -07:00, committed by GitHub.
38 changed files with 3786 additions and 58 deletions.


@@ -296,3 +296,34 @@ class AlterColumnsResult:
class DropColumnsResult:
    version: int

class AsyncPermutationBuilder:
    def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
    def split_random(
        self,
        *,
        ratios: Optional[List[float]] = None,
        counts: Optional[List[int]] = None,
        fixed: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> "AsyncPermutationBuilder": ...
    def split_hash(
        self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
    ) -> "AsyncPermutationBuilder": ...
    def split_sequential(
        self,
        *,
        ratios: Optional[List[float]] = None,
        counts: Optional[List[int]] = None,
        fixed: Optional[int] = None,
    ) -> "AsyncPermutationBuilder": ...
    def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
    def shuffle(
        self, seed: Optional[int], clump_size: Optional[int]
    ) -> "AsyncPermutationBuilder": ...
    def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
    async def execute(self) -> Table: ...

def async_permutation_builder(
    table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...
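
For reference, a hedged sketch of driving the native async builder directly; most callers should use the synchronous wrapper in `lancedb.permutation` instead. The `tbl` argument is assumed to be a table object whose `_inner` attribute is the native `Table`, which is what the Rust binding extracts.

```python
from lancedb._lancedb import async_permutation_builder


async def build_permutation(tbl):
    # The binding reads tbl._inner (the native Table) under the hood.
    builder = (
        async_permutation_builder(tbl, "events_permutation")
        .split_sequential(counts=[30, 40])
        .shuffle(None, None)  # seed and clump_size are positional in this stub
    )
    # execute() resolves to a native Table; LanceTable.from_inner can wrap it.
    return await builder.execute()
```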


@@ -5,6 +5,7 @@
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
@@ -40,7 +41,6 @@ import deprecation
if TYPE_CHECKING:
import pyarrow as pa
from .pydantic import LanceModel
from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
from .common import DATA, URI
@@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection):
read_consistency_interval: Optional[timedelta] = None,
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
_inner: Optional[LanceDbConnection] = None,
):
if _inner is not None:
self._conn = _inner
return
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
self.storage_options = storage_options
self.session = session
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection):
session,
)
# TODO: It would be nice if we didn't store self.storage_options but it is
# currently used by the LanceTable.to_lance method. This doesn't _really_
# work because some paths like LanceDBConnection.from_inner will lose the
# storage_options. Also, this class really shouldn't be holding any state
# beyond _conn.
self.storage_options = storage_options
self._conn = AsyncConnection(LOOP.run(do_connect()))
@property
def read_consistency_interval(self) -> Optional[timedelta]:
return LOOP.run(self._conn.get_read_consistency_interval())
@property
def session(self) -> Optional[Session]:
return self._conn.session
@property
def uri(self) -> str:
return self._conn.uri
@classmethod
def from_inner(cls, inner: LanceDbConnection):
return cls(None, _inner=inner)
def __repr__(self) -> str:
val = f"{self.__class__.__name__}(uri={self._uri!r}"
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@property
def _inner(self) -> LanceDbConnection:
return self._conn._inner
@override
def list_namespaces(
self,
@@ -856,6 +882,13 @@ class AsyncConnection(object):
def uri(self) -> str:
return self._inner.uri
async def get_read_consistency_interval(self) -> Optional[timedelta]:
interval_secs = await self._inner.get_read_consistency_interval()
if interval_secs is not None:
return timedelta(seconds=interval_secs)
else:
return None
async def list_namespaces(
self,
namespace: List[str] = [],
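
A short sketch of what the new plumbing means for callers, assuming a hypothetical local path: `read_consistency_interval` is now derived from the native connection (seconds round-tripped into a `timedelta`) instead of being stored on the Python wrapper.

```python
from datetime import timedelta

import lancedb

db = lancedb.connect("./example-db", read_consistency_interval=timedelta(seconds=5))

# The property now queries the Rust layer: eventual consistency comes back as a
# timedelta, strong consistency as timedelta(0), and manual consistency as None.
assert db.read_consistency_interval == timedelta(seconds=5)
```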


@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import Optional

from ._lancedb import async_permutation_builder
from .background_loop import LOOP
from .table import LanceTable


class PermutationBuilder:
    def __init__(self, table: LanceTable, dest_table_name: str):
        self._async = async_permutation_builder(table, dest_table_name)

    def select(self, projections: dict[str, str]) -> "PermutationBuilder":
        self._async.select(projections)
        return self

    def split_random(
        self,
        *,
        ratios: Optional[list[float]] = None,
        counts: Optional[list[int]] = None,
        fixed: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> "PermutationBuilder":
        self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
        return self

    def split_hash(
        self,
        columns: list[str],
        split_weights: list[int],
        *,
        discard_weight: int = 0,
    ) -> "PermutationBuilder":
        self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
        return self

    def split_sequential(
        self,
        *,
        ratios: Optional[list[float]] = None,
        counts: Optional[list[int]] = None,
        fixed: Optional[int] = None,
    ) -> "PermutationBuilder":
        self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
        return self

    def split_calculated(self, calculation: str) -> "PermutationBuilder":
        self._async.split_calculated(calculation)
        return self

    def shuffle(
        self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
    ) -> "PermutationBuilder":
        self._async.shuffle(seed=seed, clump_size=clump_size)
        return self

    def filter(self, filter: str) -> "PermutationBuilder":
        self._async.filter(filter)
        return self

    def execute(self) -> LanceTable:
        async def do_execute():
            inner_tbl = await self._async.execute()
            return LanceTable.from_inner(inner_tbl)

        return LOOP.run(do_execute())


def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
    return PermutationBuilder(table, dest_table_name)
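
One contract worth calling out (enforced by the Rust binding shown later): the random and sequential splitters accept exactly one of `ratios`, `counts`, or `fixed`. A hedged sketch of how that surfaces to Python callers, with a hypothetical `tbl`:

```python
from lancedb.permutation import permutation_builder


def demo_split_contract(tbl):
    # OK: exactly one sizing argument.
    permutation_builder(tbl, "by_ratio").split_random(ratios=[0.8, 0.2]).execute()

    # Rejected: two sizing arguments at once raise ValueError from the binding.
    try:
        permutation_builder(tbl, "bad").split_random(ratios=[0.5, 0.5], counts=[5, 5])
    except ValueError as e:
        print(f"rejected as expected: {e}")
```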


@@ -74,6 +74,7 @@ from .index import lang_mapping
if TYPE_CHECKING:
from .db import LanceDBConnection
from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
MergeResult,
UpdateResult,
)
from .db import LanceDBConnection
from .index import IndexConfig
import pandas
import PIL
@@ -1707,22 +1707,38 @@ class LanceTable(Table):
namespace: List[str] = [],
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
_async: AsyncTable = None,
):
self._conn = connection
self._namespace = namespace
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
if _async is not None:
self._table = _async
else:
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
)
)
@property
def name(self) -> str:
return self._table.name
@classmethod
def from_inner(cls, tbl: LanceDBTable):
from .db import LanceDBConnection
async_tbl = AsyncTable(tbl)
conn = LanceDBConnection.from_inner(tbl.database())
return cls(
conn,
async_tbl.name,
_async=async_tbl,
)
@classmethod
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -2756,6 +2772,10 @@ class LanceTable(Table):
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
@property
def _inner(self) -> LanceDBTable:
return self._table._inner
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
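
A hedged sketch of what `from_inner` enables: reconstructing a full synchronous `LanceTable` (plus its connection) from a native table handle, which is how the permutation builder hands its result back to sync callers. `inner_tbl` is assumed to be a native `_lancedb.Table`, e.g. the value awaited from `AsyncPermutationBuilder.execute()`.

```python
from lancedb.table import LanceTable


def wrap_native(inner_tbl):
    tbl = LanceTable.from_inner(inner_tbl)
    # The wrapper rebuilds its connection from the table's database handle,
    # so ordinary sync-table operations work on the result.
    return tbl.count_rows()
```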


@@ -0,0 +1,496 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import pyarrow as pa
import pytest
from lancedb.permutation import permutation_builder


def test_split_random_ratios(mem_db):
    """Test random splitting with ratios."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_random(ratios=[0.3, 0.7])
        .execute()
    )
    # Check that the table was created and has data
    assert permutation_tbl.count_rows() == 100
    # Check that split_id column exists and has correct values
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1}
    # Check approximate split sizes (allowing for rounding)
    split_0_count = split_ids.count(0)
    split_1_count = split_ids.count(1)
    assert 25 <= split_0_count <= 35  # ~30% ± tolerance
    assert 65 <= split_1_count <= 75  # ~70% ± tolerance


def test_split_random_counts(mem_db):
    """Test random splitting with absolute counts."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_random(counts=[20, 30])
        .execute()
    )
    # Check that we have exactly the requested counts
    assert permutation_tbl.count_rows() == 50
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert split_ids.count(0) == 20
    assert split_ids.count(1) == 30


def test_split_random_fixed(mem_db):
    """Test random splitting with fixed number of splits."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
    )
    # Check that we have 4 splits with 25 rows each
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1, 2, 3}
    for split_id in range(4):
        assert split_ids.count(split_id) == 25


def test_split_random_with_seed(mem_db):
    """Test that seeded random splits are reproducible."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
    # Create two identical permutations with same seed
    perm1 = (
        permutation_builder(tbl, "perm1")
        .split_random(ratios=[0.6, 0.4], seed=42)
        .execute()
    )
    perm2 = (
        permutation_builder(tbl, "perm2")
        .split_random(ratios=[0.6, 0.4], seed=42)
        .execute()
    )
    # Results should be identical
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()
    assert data1["row_id"] == data2["row_id"]
    assert data1["split_id"] == data2["split_id"]


def test_split_hash(mem_db):
    """Test hash-based splitting."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],  # Repeating pattern
                "value": range(100),
            }
        ),
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )
    # Should have all 100 rows (no discard)
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1}
    # Verify that each split has roughly 50 rows (allowing for hash variance)
    split_0_count = split_ids.count(0)
    split_1_count = split_ids.count(1)
    assert 30 <= split_0_count <= 70  # ~50 ± 20 tolerance for hash distribution
    assert 30 <= split_1_count <= 70  # ~50 ± 20 tolerance for hash distribution
    # Hash splits should be deterministic - same category should go to same split
    # Let's verify by creating another permutation and checking consistency
    perm2 = (
        permutation_builder(tbl, "test_permutation2")
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )
    data2 = perm2.search(None).to_arrow().to_pydict()
    assert data["split_id"] == data2["split_id"]  # Should be identical


def test_split_hash_with_discard(mem_db):
    """Test hash-based splitting with discard weight."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_hash(["category"], [1, 1], discard_weight=2)  # Should discard ~50%
        .execute()
    )
    # Should have fewer than 100 rows due to discard
    row_count = permutation_tbl.count_rows()
    assert row_count < 100
    assert row_count > 0  # But not empty


def test_split_sequential(mem_db):
    """Test sequential splitting."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_sequential(counts=[30, 40])
        .execute()
    )
    assert permutation_tbl.count_rows() == 70
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    split_ids = data["split_id"]
    # Sequential should maintain order
    assert row_ids == sorted(row_ids)
    # First 30 should be split 0, next 40 should be split 1
    assert split_ids[:30] == [0] * 30
    assert split_ids[30:] == [1] * 40


def test_split_calculated(mem_db):
    """Test calculated splitting."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_calculated("id % 3")  # Split based on id modulo 3
        .execute()
    )
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    split_ids = data["split_id"]
    # Verify the calculation: each row's split_id should equal row_id % 3
    for row_id, split_id in zip(row_ids, split_ids):
        assert split_id == row_id % 3


def test_split_error_cases(mem_db):
    """Test error handling for invalid split parameters."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
    # Test split_random with no parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error1").split_random().execute()
    # Test split_random with multiple parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error2").split_random(
            ratios=[0.5, 0.5], counts=[5, 5]
        ).execute()
    # Test split_sequential with no parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error3").split_sequential().execute()
    # Test split_sequential with multiple parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error4").split_sequential(
            ratios=[0.5, 0.5], fixed=2
        ).execute()


def test_shuffle_no_seed(mem_db):
    """Test shuffling without a seed."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    # Create a permutation with shuffling (no seed)
    permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    # Row IDs should not be in sequential order due to shuffling
    # This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
    # in order
    assert row_ids != list(range(100))


def test_shuffle_with_seed(mem_db):
    """Test that shuffling with a seed is reproducible."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )
    # Create two identical permutations with same shuffle seed
    perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
    perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
    # Results should be identical due to same seed
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()
    assert data1["row_id"] == data2["row_id"]
    assert data1["split_id"] == data2["split_id"]


def test_shuffle_with_clump_size(mem_db):
    """Test shuffling with clump size."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    # Create a permutation with shuffling using clumps
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .shuffle(clump_size=10)  # 10-row clumps
        .execute()
    )
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    for i in range(10):
        start = row_ids[i * 10]
        assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))


def test_shuffle_different_seeds(mem_db):
    """Test that different seeds produce different shuffle orders."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )
    # Create two permutations with different shuffle seeds
    perm1 = (
        permutation_builder(tbl, "perm1")
        .split_random(fixed=2)
        .shuffle(seed=42)
        .execute()
    )
    perm2 = (
        permutation_builder(tbl, "perm2")
        .split_random(fixed=2)
        .shuffle(seed=123)
        .execute()
    )
    # Results should be different due to different seeds
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()
    # Row order should be different
    assert data1["row_id"] != data2["row_id"]


def test_shuffle_combined_with_splits(mem_db):
    """Test shuffling combined with different split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )
    # Test shuffle with random splits
    perm_random = (
        permutation_builder(tbl, "perm_random")
        .split_random(ratios=[0.6, 0.4], seed=42)
        .shuffle(seed=123, clump_size=None)
        .execute()
    )
    # Test shuffle with hash splits
    perm_hash = (
        permutation_builder(tbl, "perm_hash")
        .split_hash(["category"], [1, 1], discard_weight=0)
        .shuffle(seed=456, clump_size=5)
        .execute()
    )
    # Test shuffle with sequential splits
    perm_sequential = (
        permutation_builder(tbl, "perm_sequential")
        .split_sequential(counts=[40, 35])
        .shuffle(seed=789, clump_size=None)
        .execute()
    )
    # Verify all permutations work and have expected properties
    assert perm_random.count_rows() == 100
    assert perm_hash.count_rows() == 100
    assert perm_sequential.count_rows() == 75
    # Verify shuffle affected the order
    data_random = perm_random.search(None).to_arrow().to_pydict()
    data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
    assert data_random["row_id"] != list(range(100))
    assert data_sequential["row_id"] != list(range(75))


def test_no_shuffle_maintains_order(mem_db):
    """Test that not calling shuffle maintains the original order."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )
    # Create permutation without shuffle (should maintain some order)
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_sequential(counts=[25, 25])  # Sequential maintains order
        .execute()
    )
    assert permutation_tbl.count_rows() == 50
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    # With sequential splits and no shuffle, should maintain order
    assert row_ids == list(range(50))


def test_filter_basic(mem_db):
    """Test basic filtering functionality."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100, 200)})
    )
    # Filter to only include rows where id < 50
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
    )
    assert permutation_tbl.count_rows() == 50
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    # All row_ids should be less than 50
    assert all(row_id < 50 for row_id in row_ids)


def test_filter_with_splits(mem_db):
    """Test filtering combined with split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )
    # Filter to only category A and B, then split
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .filter("category IN ('A', 'B')")
        .split_random(ratios=[0.5, 0.5])
        .execute()
    )
    # Should have fewer than 100 rows due to filtering
    row_count = permutation_tbl.count_rows()
    assert row_count == 67
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    categories = data["category"]
    # All categories should be A or B
    assert all(cat in ["A", "B"] for cat in categories)


def test_filter_with_shuffle(mem_db):
    """Test filtering combined with shuffling."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C", "D"] * 25)[:100],
                "value": range(100),
            }
        ),
    )
    # Filter and shuffle
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .filter("category IN ('A', 'C')")
        .shuffle(seed=42)
        .execute()
    )
    row_count = permutation_tbl.count_rows()
    assert row_count == 50  # Should have 50 rows (A and C categories)
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    assert row_ids != sorted(row_ids)


def test_filter_empty_result(mem_db):
    """Test filtering that results in empty set."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    # Filter that matches nothing
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .filter("value > 100")  # No values > 100 in our data
        .execute()
    )
    assert permutation_tbl.count_rows() == 0
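
Beyond the assertions, the tests above illustrate the shape consumers rely on: each permutation row carries `row_id` (the position in the source table) and `split_id`. A hedged sketch of turning that into per-split row-id lists, reusing the same read pattern as the tests (`permutation_tbl` is hypothetical):

```python
def split_row_ids(permutation_tbl, split_id):
    # Same read pattern the tests use: materialize the permutation table.
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    return [
        rid for rid, sid in zip(data["row_id"], data["split_id"]) if sid == split_id
    ]


# e.g. train/validation views from a two-way split:
# train_ids = split_row_ids(permutation_tbl, 0)
# val_ids = split_row_ids(permutation_tbl, 1)
```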


@@ -4,7 +4,10 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
Self { inner: Some(inner) }
}
fn get_inner(&self) -> PyResult<&LanceConnection> {
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = ())]
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
Ok(match inner.read_consistency().await.infer_error()? {
ReadConsistency::Manual => None,
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
ReadConsistency::Strong => Some(0.0_f64),
})
})
}
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,


@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DeleteResult>()?;
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

python/src/permutation.rs (new file, 177 lines)

@@ -0,0 +1,177 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
table: Bound<'_, PyAny>,
dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
dest_table_name,
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
dest_table_name: String,
}
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
impl PyAsyncPermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> PyResult<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
}
pub fn shuffle(
slf: PyRefMut<'_, Self>,
seed: Option<u64>,
clump_size: Option<u64>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_filter(filter))
}
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let mut state = slf.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
future_into_py(slf.py(), async move {
let table = builder.build(&dest_table_name).await.infer_error()?;
Ok(Table::new(table))
})
}
}
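
Because `execute` takes the inner builder out of the shared state, each builder is single-use. A hedged sketch of how that surfaces in Python, with a hypothetical `tbl`:

```python
from lancedb.permutation import permutation_builder


def demo_single_use(tbl):
    builder = permutation_builder(tbl, "perm").split_random(fixed=2)
    builder.execute()  # consumes the inner Rust builder

    try:
        builder.execute()  # second use fails
    except RuntimeError as e:
        print(e)  # "Builder already consumed"
```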


@@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
}
impl Table {
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
self.inner.take();
}
pub fn database(&self) -> PyResult<Connection> {
let inner = self.inner_ref()?.clone();
let inner_connection =
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
Ok(Connection::new(inner_connection))
}
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {