mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
feat: a utility for creating "permutation views" (#2552)
I'm working on a lancedb version of pytorch data loading (and hopefully addressing https://github.com/lancedb/lance/issues/3727). However, rather than rely on pytorch for everything I'm moving some of the things that pytorch does into rust. This gives us more control over data loading (e.g. using shards or a hash-based split) and it allows permutations to be persistent. In particular I hope to be able to: * Create a persistent permutation * This permutation can handle splits, filtering, shuffling, and sharding * Create a rust data loader that can read a permutation (one or more splits), or a subset of a permutation (for DDP) * Create a python data loader that delegates to the rust data loader Eventually create integrations for other data loading libraries, including rust & node
This commit is contained in:
@@ -296,3 +296,34 @@ class AlterColumnsResult:
|
||||
|
||||
class DropColumnsResult:
|
||||
version: int
|
||||
|
||||
class AsyncPermutationBuilder:
|
||||
def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
|
||||
def split_random(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[List[float]] = None,
|
||||
counts: Optional[List[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_hash(
|
||||
self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_sequential(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[List[float]] = None,
|
||||
counts: Optional[List[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
|
||||
def shuffle(
|
||||
self, seed: Optional[int], clump_size: Optional[int]
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
|
||||
async def execute(self) -> Table: ...
|
||||
|
||||
def async_permutation_builder(
|
||||
table: Table, dest_table_name: str
|
||||
) -> AsyncPermutationBuilder: ...
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
|
||||
@@ -40,7 +41,6 @@ import deprecation
|
||||
if TYPE_CHECKING:
|
||||
import pyarrow as pa
|
||||
from .pydantic import LanceModel
|
||||
from datetime import timedelta
|
||||
|
||||
from ._lancedb import Connection as LanceDbConnection
|
||||
from .common import DATA, URI
|
||||
@@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection):
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
_inner: Optional[LanceDbConnection] = None,
|
||||
):
|
||||
if _inner is not None:
|
||||
self._conn = _inner
|
||||
return
|
||||
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
@@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection):
|
||||
uri = Path(uri)
|
||||
uri = uri.expanduser().absolute()
|
||||
Path(uri).mkdir(parents=True, exist_ok=True)
|
||||
self._uri = str(uri)
|
||||
self._entered = False
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
self.storage_options = storage_options
|
||||
self.session = session
|
||||
|
||||
if read_consistency_interval is not None:
|
||||
read_consistency_interval_secs = read_consistency_interval.total_seconds()
|
||||
@@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection):
|
||||
session,
|
||||
)
|
||||
|
||||
# TODO: It would be nice if we didn't store self.storage_options but it is
|
||||
# currently used by the LanceTable.to_lance method. This doesn't _really_
|
||||
# work because some paths like LanceDBConnection.from_inner will lose the
|
||||
# storage_options. Also, this class really shouldn't be holding any state
|
||||
# beyond _conn.
|
||||
self.storage_options = storage_options
|
||||
self._conn = AsyncConnection(LOOP.run(do_connect()))
|
||||
|
||||
@property
|
||||
def read_consistency_interval(self) -> Optional[timedelta]:
|
||||
return LOOP.run(self._conn.get_read_consistency_interval())
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[Session]:
|
||||
return self._conn.session
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._conn.uri
|
||||
|
||||
@classmethod
|
||||
def from_inner(cls, inner: LanceDbConnection):
|
||||
return cls(None, _inner=inner)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}(uri={self._uri!r}"
|
||||
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
|
||||
if self.read_consistency_interval is not None:
|
||||
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||
val += ")"
|
||||
@@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection):
|
||||
conn = AsyncConnection(await lancedb_connect(self.uri))
|
||||
return await conn.table_names(start_after=start_after, limit=limit)
|
||||
|
||||
@property
|
||||
def _inner(self) -> LanceDbConnection:
|
||||
return self._conn._inner
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
@@ -856,6 +882,13 @@ class AsyncConnection(object):
|
||||
def uri(self) -> str:
|
||||
return self._inner.uri
|
||||
|
||||
async def get_read_consistency_interval(self) -> Optional[timedelta]:
|
||||
interval_secs = await self._inner.get_read_consistency_interval()
|
||||
if interval_secs is not None:
|
||||
return timedelta(seconds=interval_secs)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
|
||||
72
python/python/lancedb/permutation.py
Normal file
72
python/python/lancedb/permutation.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from ._lancedb import async_permutation_builder
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
def __init__(self, table: LanceTable, dest_table_name: str):
|
||||
self._async = async_permutation_builder(table, dest_table_name)
|
||||
|
||||
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
|
||||
self._async.select(projections)
|
||||
return self
|
||||
|
||||
def split_random(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[list[float]] = None,
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
|
||||
return self
|
||||
|
||||
def split_hash(
|
||||
self,
|
||||
columns: list[str],
|
||||
split_weights: list[int],
|
||||
*,
|
||||
discard_weight: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
|
||||
return self
|
||||
|
||||
def split_sequential(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[list[float]] = None,
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
|
||||
return self
|
||||
|
||||
def split_calculated(self, calculation: str) -> "PermutationBuilder":
|
||||
self._async.split_calculated(calculation)
|
||||
return self
|
||||
|
||||
def shuffle(
|
||||
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
||||
) -> "PermutationBuilder":
|
||||
self._async.shuffle(seed=seed, clump_size=clump_size)
|
||||
return self
|
||||
|
||||
def filter(self, filter: str) -> "PermutationBuilder":
|
||||
self._async.filter(filter)
|
||||
return self
|
||||
|
||||
def execute(self) -> LanceTable:
|
||||
async def do_execute():
|
||||
inner_tbl = await self._async.execute()
|
||||
return LanceTable.from_inner(inner_tbl)
|
||||
|
||||
return LOOP.run(do_execute())
|
||||
|
||||
|
||||
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
|
||||
return PermutationBuilder(table, dest_table_name)
|
||||
@@ -74,6 +74,7 @@ from .index import lang_mapping
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import LanceDBConnection
|
||||
from ._lancedb import (
|
||||
Table as LanceDBTable,
|
||||
OptimizeStats,
|
||||
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
|
||||
MergeResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from .db import LanceDBConnection
|
||||
from .index import IndexConfig
|
||||
import pandas
|
||||
import PIL
|
||||
@@ -1707,22 +1707,38 @@ class LanceTable(Table):
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
_async: AsyncTable = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self._namespace = namespace
|
||||
self._table = LOOP.run(
|
||||
connection._conn.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
self._table = LOOP.run(
|
||||
connection._conn.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._table.name
|
||||
|
||||
@classmethod
|
||||
def from_inner(cls, tbl: LanceDBTable):
|
||||
from .db import LanceDBConnection
|
||||
|
||||
async_tbl = AsyncTable(tbl)
|
||||
conn = LanceDBConnection.from_inner(tbl.database())
|
||||
return cls(
|
||||
conn,
|
||||
async_tbl.name,
|
||||
_async=async_tbl,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
|
||||
tbl = cls(db, name, namespace=namespace, **kwargs)
|
||||
@@ -2756,6 +2772,10 @@ class LanceTable(Table):
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
|
||||
@property
|
||||
def _inner(self) -> LanceDBTable:
|
||||
return self._table._inner
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.21.0",
|
||||
current_version=__version__,
|
||||
|
||||
496
python/python/tests/test_permutation.py
Normal file
496
python/python/tests/test_permutation.py
Normal file
@@ -0,0 +1,496 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.permutation import permutation_builder
|
||||
|
||||
|
||||
def test_split_random_ratios(mem_db):
|
||||
"""Test random splitting with ratios."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(ratios=[0.3, 0.7])
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Check that the table was created and has data
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
# Check that split_id column exists and has correct values
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1}
|
||||
|
||||
# Check approximate split sizes (allowing for rounding)
|
||||
split_0_count = split_ids.count(0)
|
||||
split_1_count = split_ids.count(1)
|
||||
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
|
||||
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
|
||||
|
||||
|
||||
def test_split_random_counts(mem_db):
|
||||
"""Test random splitting with absolute counts."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(counts=[20, 30])
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Check that we have exactly the requested counts
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert split_ids.count(0) == 20
|
||||
assert split_ids.count(1) == 30
|
||||
|
||||
|
||||
def test_split_random_fixed(mem_db):
|
||||
"""Test random splitting with fixed number of splits."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
|
||||
)
|
||||
|
||||
# Check that we have 4 splits with 25 rows each
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1, 2, 3}
|
||||
|
||||
for split_id in range(4):
|
||||
assert split_ids.count(split_id) == 25
|
||||
|
||||
|
||||
def test_split_random_with_seed(mem_db):
|
||||
"""Test that seeded random splits are reproducible."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
|
||||
|
||||
# Create two identical permutations with same seed
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Results should be identical
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data1["row_id"] == data2["row_id"]
|
||||
assert data1["split_id"] == data2["split_id"]
|
||||
|
||||
|
||||
def test_split_hash(mem_db):
|
||||
"""Test hash-based splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have all 100 rows (no discard)
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1}
|
||||
|
||||
# Verify that each split has roughly 50 rows (allowing for hash variance)
|
||||
split_0_count = split_ids.count(0)
|
||||
split_1_count = split_ids.count(1)
|
||||
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||
|
||||
# Hash splits should be deterministic - same category should go to same split
|
||||
# Let's verify by creating another permutation and checking consistency
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "test_permutation2")
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
assert data["split_id"] == data2["split_id"] # Should be identical
|
||||
|
||||
|
||||
def test_split_hash_with_discard(mem_db):
|
||||
"""Test hash-based splitting with discard weight."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have fewer than 100 rows due to discard
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count < 100
|
||||
assert row_count > 0 # But not empty
|
||||
|
||||
|
||||
def test_split_sequential(mem_db):
|
||||
"""Test sequential splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_sequential(counts=[30, 40])
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 70
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
split_ids = data["split_id"]
|
||||
|
||||
# Sequential should maintain order
|
||||
assert row_ids == sorted(row_ids)
|
||||
|
||||
# First 30 should be split 0, next 40 should be split 1
|
||||
assert split_ids[:30] == [0] * 30
|
||||
assert split_ids[30:] == [1] * 40
|
||||
|
||||
|
||||
def test_split_calculated(mem_db):
|
||||
"""Test calculated splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_calculated("id % 3") # Split based on id modulo 3
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
split_ids = data["split_id"]
|
||||
|
||||
# Verify the calculation: each row's split_id should equal row_id % 3
|
||||
for i, (row_id, split_id) in enumerate(zip(row_ids, split_ids)):
|
||||
assert split_id == row_id % 3
|
||||
|
||||
|
||||
def test_split_error_cases(mem_db):
|
||||
"""Test error handling for invalid split parameters."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
||||
|
||||
# Test split_random with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error1").split_random().execute()
|
||||
|
||||
# Test split_random with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error2").split_random(
|
||||
ratios=[0.5, 0.5], counts=[5, 5]
|
||||
).execute()
|
||||
|
||||
# Test split_sequential with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error3").split_sequential().execute()
|
||||
|
||||
# Test split_sequential with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error4").split_sequential(
|
||||
ratios=[0.5, 0.5], fixed=2
|
||||
).execute()
|
||||
|
||||
|
||||
def test_shuffle_no_seed(mem_db):
|
||||
"""Test shuffling without a seed."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling (no seed)
|
||||
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# Row IDs should not be in sequential order due to shuffling
|
||||
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
|
||||
# in order
|
||||
assert row_ids != list(range(100))
|
||||
|
||||
|
||||
def test_shuffle_with_seed(mem_db):
|
||||
"""Test that shuffling with a seed is reproducible."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create two identical permutations with same shuffle seed
|
||||
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
|
||||
|
||||
# Results should be identical due to same seed
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data1["row_id"] == data2["row_id"]
|
||||
assert data1["split_id"] == data2["split_id"]
|
||||
|
||||
|
||||
def test_shuffle_with_clump_size(mem_db):
|
||||
"""Test shuffling with clump size."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling using clumps
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.shuffle(clump_size=10) # 10-row clumps
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
for i in range(10):
|
||||
start = row_ids[i * 10]
|
||||
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
|
||||
|
||||
|
||||
def test_shuffle_different_seeds(mem_db):
|
||||
"""Test that different seeds produce different shuffle orders."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create two permutations with different shuffle seeds
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=123)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Results should be different due to different seeds
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
# Row order should be different
|
||||
assert data1["row_id"] != data2["row_id"]
|
||||
|
||||
|
||||
def test_shuffle_combined_with_splits(mem_db):
|
||||
"""Test shuffling combined with different split strategies."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Test shuffle with random splits
|
||||
perm_random = (
|
||||
permutation_builder(tbl, "perm_random")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.shuffle(seed=123, clump_size=None)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Test shuffle with hash splits
|
||||
perm_hash = (
|
||||
permutation_builder(tbl, "perm_hash")
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.shuffle(seed=456, clump_size=5)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Test shuffle with sequential splits
|
||||
perm_sequential = (
|
||||
permutation_builder(tbl, "perm_sequential")
|
||||
.split_sequential(counts=[40, 35])
|
||||
.shuffle(seed=789, clump_size=None)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Verify all permutations work and have expected properties
|
||||
assert perm_random.count_rows() == 100
|
||||
assert perm_hash.count_rows() == 100
|
||||
assert perm_sequential.count_rows() == 75
|
||||
|
||||
# Verify shuffle affected the order
|
||||
data_random = perm_random.search(None).to_arrow().to_pydict()
|
||||
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data_random["row_id"] != list(range(100))
|
||||
assert data_sequential["row_id"] != list(range(75))
|
||||
|
||||
|
||||
def test_no_shuffle_maintains_order(mem_db):
|
||||
"""Test that not calling shuffle maintains the original order."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create permutation without shuffle (should maintain some order)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_sequential(counts=[25, 25]) # Sequential maintains order
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# With sequential splits and no shuffle, should maintain order
|
||||
assert row_ids == list(range(50))
|
||||
|
||||
|
||||
def test_filter_basic(mem_db):
|
||||
"""Test basic filtering functionality."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
|
||||
)
|
||||
|
||||
# Filter to only include rows where id < 50
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# All row_ids should be less than 50
|
||||
assert all(row_id < 50 for row_id in row_ids)
|
||||
|
||||
|
||||
def test_filter_with_splits(mem_db):
|
||||
"""Test filtering combined with split strategies."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Filter to only category A and B, then split
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.filter("category IN ('A', 'B')")
|
||||
.split_random(ratios=[0.5, 0.5])
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have fewer than 100 rows due to filtering
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count == 67
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
categories = data["category"]
|
||||
|
||||
# All categories should be A or B
|
||||
assert all(cat in ["A", "B"] for cat in categories)
|
||||
|
||||
|
||||
def test_filter_with_shuffle(mem_db):
|
||||
"""Test filtering combined with shuffling."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C", "D"] * 25)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Filter and shuffle
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.filter("category IN ('A', 'C')")
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count == 50 # Should have 50 rows (A and C categories)
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
assert row_ids != sorted(row_ids)
|
||||
|
||||
|
||||
def test_filter_empty_result(mem_db):
|
||||
"""Test filtering that results in empty set."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
|
||||
# Filter that matches nothing
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.filter("value > 100") # No values > 100 in our data
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 0
|
||||
@@ -4,7 +4,10 @@
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
database::{CreateTableMode, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -23,7 +26,7 @@ impl Connection {
|
||||
Self { inner: Some(inner) }
|
||||
}
|
||||
|
||||
fn get_inner(&self) -> PyResult<&LanceConnection> {
|
||||
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
|
||||
@@ -63,6 +66,18 @@ impl Connection {
|
||||
self.get_inner().map(|inner| inner.uri().to_string())
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
Ok(match inner.read_consistency().await.infer_error()? {
|
||||
ReadConsistency::Manual => None,
|
||||
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
|
||||
ReadConsistency::Strong => Some(0.0_f64),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
|
||||
pub fn table_names(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::IndexConfig;
|
||||
use permutation::PyAsyncPermutationBuilder;
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
@@ -22,6 +23,7 @@ pub mod connection;
|
||||
pub mod error;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod session;
|
||||
pub mod table;
|
||||
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<DeleteResult>()?;
|
||||
m.add_class::<DropColumnsResult>()?;
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
|
||||
177
python/src/permutation.rs
Normal file
177
python/src/permutation.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
||||
PyResult,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[pyo3::pyfunction]
|
||||
pub fn async_permutation_builder(
|
||||
table: Bound<'_, PyAny>,
|
||||
dest_table_name: String,
|
||||
) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let inner_table = table.borrow().inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
|
||||
Ok(PyAsyncPermutationBuilder {
|
||||
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
|
||||
builder: Some(inner_builder),
|
||||
dest_table_name,
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
struct PyAsyncPermutationBuilderState {
|
||||
builder: Option<LancePermutationBuilder>,
|
||||
dest_table_name: String,
|
||||
}
|
||||
|
||||
#[pyclass(name = "AsyncPermutationBuilder")]
|
||||
pub struct PyAsyncPermutationBuilder {
|
||||
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
|
||||
}
|
||||
|
||||
impl PyAsyncPermutationBuilder {
|
||||
fn modify(
|
||||
&self,
|
||||
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
|
||||
) -> PyResult<Self> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let builder = state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
state.builder = Some(func(builder));
|
||||
Ok(Self {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyAsyncPermutationBuilder {
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
|
||||
pub fn split_random(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
seed: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = counts {
|
||||
SplitSizes::Counts(counts)
|
||||
} else if let Some(fixed) = fixed {
|
||||
SplitSizes::Fixed(fixed)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
|
||||
pub fn split_hash(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
split_weights: Vec<u64>,
|
||||
discard_weight: u64,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
|
||||
pub fn split_sequential(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = counts {
|
||||
SplitSizes::Counts(counts)
|
||||
} else if let Some(fixed) = fixed {
|
||||
SplitSizes::Fixed(fixed)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
}
|
||||
|
||||
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
|
||||
}
|
||||
|
||||
pub fn shuffle(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
seed: Option<u64>,
|
||||
clump_size: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
|
||||
})
|
||||
}
|
||||
|
||||
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_filter(filter))
|
||||
}
|
||||
|
||||
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let mut state = slf.state.lock().unwrap();
|
||||
let builder = state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
|
||||
let dest_table_name = std::mem::take(&mut state.dest_table_name);
|
||||
|
||||
future_into_py(slf.py(), async move {
|
||||
let table = builder.build(&dest_table_name).await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
error::PythonErrorExt,
|
||||
index::{extract_index_params, IndexConfig},
|
||||
query::{Query, TakeQuery},
|
||||
@@ -249,7 +250,7 @@ impl Table {
|
||||
}
|
||||
|
||||
impl Table {
|
||||
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
|
||||
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
|
||||
@@ -272,6 +273,13 @@ impl Table {
|
||||
self.inner.take();
|
||||
}
|
||||
|
||||
pub fn database(&self) -> PyResult<Connection> {
|
||||
let inner = self.inner_ref()?.clone();
|
||||
let inner_connection =
|
||||
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
|
||||
Ok(Connection::new(inner_connection))
|
||||
}
|
||||
|
||||
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
|
||||
Reference in New Issue
Block a user