# lancedb/python/python/tests/test_permutation.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import math
import pytest
from lancedb import DBConnection, Table, connect
from lancedb.permutation import Permutation, Permutations, permutation_builder
def test_permutation_persistence(tmp_path):
    """A persisted permutation survives reopening from the database."""
    db = connect(tmp_path)
    source = db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    built = (
        permutation_builder(source).shuffle().persist(db, "test_permutation").execute()
    )
    assert built.count_rows() == 100
    # Reopening by name must yield the exact same contents.
    reopened = db.open_table("test_permutation")
    assert reopened.count_rows() == 100
    assert built.to_arrow() == reopened.to_arrow()
def test_split_random_ratios(mem_db):
    """Test random splitting with ratios."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    perm = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
    # The permutation still covers every source row.
    assert perm.count_rows() == 100
    # Exactly two splits should exist.
    split_ids = perm.search(None).to_arrow().to_pydict()["split_id"]
    assert set(split_ids) == {0, 1}
    # Split sizes should be close to the requested ratios (rounding tolerance).
    assert 25 <= split_ids.count(0) <= 35  # ~30% ± tolerance
    assert 65 <= split_ids.count(1) <= 75  # ~70% ± tolerance
def test_split_random_counts(mem_db):
    """Test random splitting with absolute counts."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    perm = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
    # Only 20 + 30 rows are selected out of the 100 available.
    assert perm.count_rows() == 50
    split_ids = perm.search(None).to_arrow().to_pydict()["split_id"]
    # Each split receives exactly the requested number of rows.
    assert split_ids.count(0) == 20
    assert split_ids.count(1) == 30
def test_split_random_fixed(mem_db):
    """Test random splitting with a fixed number of equal splits."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    perm = permutation_builder(tbl).split_random(fixed=4).execute()
    assert perm.count_rows() == 100
    split_ids = perm.search(None).to_arrow().to_pydict()["split_id"]
    assert set(split_ids) == {0, 1, 2, 3}
    # 100 rows divided evenly across 4 splits -> 25 rows apiece.
    assert all(split_ids.count(split) == 25 for split in range(4))
def test_split_random_with_seed(mem_db):
    """Test that seeded random splits are reproducible."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))

    def build():
        # Same ratios and seed every time.
        return permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()

    first = build().search(None).to_arrow().to_pydict()
    second = build().search(None).to_arrow().to_pydict()
    # Identical seeds must yield identical row order and split assignment.
    assert first["row_id"] == second["row_id"]
    assert first["split_id"] == second["split_id"]
def test_split_hash(mem_db):
    """Test hash-based splitting."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],  # repeating pattern
                "value": range(100),
            }
        ),
    )

    def build():
        return (
            permutation_builder(tbl)
            .split_hash(["category"], [1, 1], discard_weight=0)
            .execute()
        )

    perm = build()
    # With discard_weight=0 every row is assigned to one of the splits.
    assert perm.count_rows() == 100
    split_ids = perm.search(None).to_arrow().to_pydict()["split_id"]
    assert set(split_ids) == {0, 1}
    # Roughly even distribution; wide tolerance for hash variance.
    assert 30 <= split_ids.count(0) <= 70
    assert 30 <= split_ids.count(1) <= 70
    # Hash splits are deterministic: rebuilding yields identical assignments.
    split_ids_again = build().search(None).to_arrow().to_pydict()["split_id"]
    assert split_ids == split_ids_again
def test_split_hash_with_discard(mem_db):
    """Test hash-based splitting with discard weight."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
    )
    perm = (
        permutation_builder(tbl)
        # discard_weight=2 against split weights [1, 1] should drop ~50% of rows
        .split_hash(["category"], [1, 1], discard_weight=2)
        .execute()
    )
    # Some rows are discarded, but not all of them.
    assert 0 < perm.count_rows() < 100
def test_split_sequential(mem_db):
    """Test sequential splitting."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    perm = permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
    # Only 30 + 40 of the 100 rows are kept.
    assert perm.count_rows() == 70
    data = perm.search(None).to_arrow().to_pydict()
    # Sequential splitting preserves the original row order.
    assert data["row_id"] == sorted(data["row_id"])
    # The first 30 rows land in split 0, the following 40 in split 1.
    assert data["split_id"] == [0] * 30 + [1] * 40
def test_split_calculated(mem_db):
    """Test calculated splitting driven by a SQL expression."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl)
        .split_calculated("id % 3")  # Split based on id modulo 3
        .execute()
    )
    assert permutation_tbl.count_rows() == 100
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    # Verify the calculation: each row's split_id should equal row_id % 3.
    # (The original loop wrapped zip in enumerate but never used the index.)
    for row_id, split_id in zip(data["row_id"], data["split_id"]):
        assert split_id == row_id % 3
def test_split_error_cases(mem_db):
    """Test error handling for invalid split parameters."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
    # Both split methods share the same argument-validation message.
    err = "Exactly one of 'ratios', 'counts', or 'fixed' must be provided"
    # split_random: no selection argument at all.
    with pytest.raises(ValueError, match=err):
        permutation_builder(tbl).split_random().execute()
    # split_random: more than one selection argument.
    with pytest.raises(ValueError, match=err):
        permutation_builder(tbl).split_random(ratios=[0.5, 0.5], counts=[5, 5]).execute()
    # split_sequential: no selection argument at all.
    with pytest.raises(ValueError, match=err):
        permutation_builder(tbl).split_sequential().execute()
    # split_sequential: more than one selection argument.
    with pytest.raises(ValueError, match=err):
        permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
def test_shuffle_no_seed(mem_db):
    """Test shuffling without a seed."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    perm = permutation_builder(tbl).shuffle().execute()
    assert perm.count_rows() == 100
    row_ids = perm.search(None).to_arrow().to_pydict()["row_id"]
    # Probabilistic check: it is astronomically unlikely that 100 shuffled
    # rows remain in their original sequential order.
    assert row_ids != list(range(100))
def test_shuffle_with_seed(mem_db):
    """Test that shuffling with a seed is reproducible."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )

    def build():
        return permutation_builder(tbl).shuffle(seed=42).execute()

    first = build().search(None).to_arrow().to_pydict()
    second = build().search(None).to_arrow().to_pydict()
    # Same seed -> identical shuffle order and split assignment.
    assert first["row_id"] == second["row_id"]
    assert first["split_id"] == second["split_id"]
def test_shuffle_with_clump_size(mem_db):
    """Test shuffling with clump size."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )
    perm = permutation_builder(tbl).shuffle(clump_size=10).execute()  # 10-row clumps
    assert perm.count_rows() == 100
    row_ids = perm.search(None).to_arrow().to_pydict()["row_id"]
    # Rows move in contiguous 10-row clumps: each clump of output row ids
    # is a consecutive ascending run starting at the clump's first id.
    for clump_start in range(0, 100, 10):
        clump = row_ids[clump_start : clump_start + 10]
        assert clump == list(range(clump[0], clump[0] + 10))
def test_shuffle_different_seeds(mem_db):
    """Test that different seeds produce different shuffle orders."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )

    def build(seed):
        return permutation_builder(tbl).split_random(fixed=2).shuffle(seed=seed).execute()

    order_a = build(42).search(None).to_arrow().to_pydict()["row_id"]
    order_b = build(123).search(None).to_arrow().to_pydict()["row_id"]
    # Different seeds should yield different row orders.
    assert order_a != order_b
def test_shuffle_combined_with_splits(mem_db):
    """Test shuffling combined with different split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )
    # Shuffle layered on top of a random split.
    perm_random = (
        permutation_builder(tbl)
        .split_random(ratios=[0.6, 0.4], seed=42)
        .shuffle(seed=123, clump_size=None)
        .execute()
    )
    # Shuffle layered on top of a hash split, with clumping.
    perm_hash = (
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .shuffle(seed=456, clump_size=5)
        .execute()
    )
    # Shuffle layered on top of a sequential split.
    perm_sequential = (
        permutation_builder(tbl)
        .split_sequential(counts=[40, 35])
        .shuffle(seed=789, clump_size=None)
        .execute()
    )
    # Row counts reflect each strategy (sequential keeps only 40 + 35 rows).
    assert perm_random.count_rows() == 100
    assert perm_hash.count_rows() == 100
    assert perm_sequential.count_rows() == 75
    # Shuffling must perturb the natural row order.
    random_ids = perm_random.search(None).to_arrow().to_pydict()["row_id"]
    sequential_ids = perm_sequential.search(None).to_arrow().to_pydict()["row_id"]
    assert random_ids != list(range(100))
    assert sequential_ids != list(range(75))
def test_no_shuffle_maintains_order(mem_db):
    """Test that not calling shuffle maintains the original order."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )
    # Sequential splits preserve order; with no shuffle the output is sorted.
    perm = permutation_builder(tbl).split_sequential(counts=[25, 25]).execute()
    assert perm.count_rows() == 50
    row_ids = perm.search(None).to_arrow().to_pydict()["row_id"]
    assert row_ids == list(range(50))
def test_filter_basic(mem_db):
    """Test basic filtering functionality."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100, 200)})
    )
    # Keep only rows where id < 50.
    perm = permutation_builder(tbl).filter("id < 50").execute()
    assert perm.count_rows() == 50
    row_ids = perm.search(None).to_arrow().to_pydict()["row_id"]
    # Every surviving row id satisfies the predicate.
    assert all(rid < 50 for rid in row_ids)
def test_filter_with_splits(mem_db):
    """Test filtering combined with split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )
    # Keep only categories A and B before splitting.
    perm = (
        permutation_builder(tbl)
        .filter("category IN ('A', 'B')")
        .split_random(ratios=[0.5, 0.5])
        .execute()
    )
    # The repeating A/B/C pattern leaves 67 matching rows.
    assert perm.count_rows() == 67
    categories = perm.search(None).to_arrow().to_pydict()["category"]
    # Only the filtered categories may appear.
    assert set(categories) <= {"A", "B"}
def test_filter_with_shuffle(mem_db):
    """Test filtering combined with shuffling."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C", "D"] * 25)[:100],
                "value": range(100),
            }
        ),
    )
    perm = (
        permutation_builder(tbl)
        .filter("category IN ('A', 'C')")
        .shuffle(seed=42)
        .execute()
    )
    # Categories A and C each cover a quarter of the 100 rows -> 50 total.
    assert perm.count_rows() == 50
    row_ids = perm.search(None).to_arrow().to_pydict()["row_id"]
    # The shuffle must disturb the ascending order of the filtered rows.
    assert row_ids != sorted(row_ids)
def test_filter_empty_result(mem_db):
    """Test filtering that results in empty set."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    # No value exceeds 100, so the filter matches nothing.
    perm = permutation_builder(tbl).filter("value > 100").execute()
    assert perm.count_rows() == 0
@pytest.fixture
def mem_db() -> DBConnection:
    # Fresh in-memory LanceDB connection for each test.
    return connect("memory:///")
@pytest.fixture
def some_table(mem_db: DBConnection) -> Table:
    """A 1000-row table with two int64 columns, `id` and `value`."""
    return mem_db.create_table(
        "some_table",
        pa.table(
            {
                "id": range(1000),
                "value": range(1000),
            }
        ),
    )
def test_no_split_names(some_table: Table):
    """Splits created without names are only addressable by ordinal."""
    perm_tbl = (
        permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
    )
    permutations = Permutations(some_table, perm_tbl)
    # No names were assigned, so the name-lookup structures are empty.
    assert permutations.split_names == []
    assert permutations.split_dict == {}
    # Ordinal access still reaches both 500-row splits.
    assert permutations[0].num_rows == 500
    assert permutations[1].num_rows == 500
@pytest.fixture
def some_perm_table(some_table: Table) -> Table:
    # Shuffled 95/5 permutation over `some_table` with named splits
    # "train" (950 rows) and "test" (50 rows); seeded for reproducibility.
    return (
        permutation_builder(some_table)
        .split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
        .shuffle(seed=42)
        .execute()
    )
def test_nonexistent_split(some_table: Table, some_perm_table: Table):
    """Looking up a missing split fails, whether by name or by ordinal."""
    # Unknown split name.
    with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
        Permutation.from_tables(some_table, some_perm_table, "nonexistent")
    # Ordinal beyond the defined splits selects no rows.
    with pytest.raises(ValueError, match="No rows found"):
        Permutation.from_tables(some_table, some_perm_table, 5)
def test_permutations(some_table: Table, some_perm_table: Table):
    """Named splits are accessible both by name and by ordinal."""
    permutations = Permutations(some_table, some_perm_table)
    assert permutations.split_names == ["train", "test"]
    assert permutations.split_dict == {"train": 0, "test": 1}
    # Name and ordinal access agree on the split sizes (950 / 50).
    assert permutations["train"].num_rows == 950
    assert permutations[0].num_rows == 950
    assert permutations["test"].num_rows == 50
    assert permutations[1].num_rows == 50
    # Invalid lookups raise with descriptive messages.
    with pytest.raises(ValueError, match="No split named `nonexistent` found"):
        permutations["nonexistent"]
    with pytest.raises(ValueError, match="No rows found"):
        permutations[5]
@pytest.fixture
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
    # Default split selection of the fixture permutation (950-row split).
    return Permutation.from_tables(some_table, some_perm_table)
def test_num_rows(some_permutation: Permutation):
    # The 95% "train" split of 1000 rows contains 950 rows.
    assert some_permutation.num_rows == 950
def test_num_columns(some_permutation: Permutation):
    # The source table has two columns: `id` and `value`.
    assert some_permutation.num_columns == 2
def test_column_names(some_permutation: Permutation):
    # Column names come through from the source table in order.
    assert some_permutation.column_names == ["id", "value"]
def test_shape(some_permutation: Permutation):
    # Shape is (num_rows, num_columns).
    assert some_permutation.shape == (950, 2)
def test_schema(some_permutation: Permutation):
    # The Arrow schema of the source table is preserved.
    assert some_permutation.schema == pa.schema(
        [("id", pa.int64()), ("value", pa.int64())]
    )
def test_limit_offset(some_permutation: Permutation):
    """with_take/with_skip window the 950 rows; overruns raise."""
    assert some_permutation.with_take(100).num_rows == 100
    assert some_permutation.with_skip(100).num_rows == 850
    assert some_permutation.with_take(100).with_skip(100).num_rows == 100
    # Windows that run past the available rows are rejected.
    overruns = (
        lambda: some_permutation.with_take(1000000).num_rows,
        lambda: some_permutation.with_skip(1000000).num_rows,
        lambda: some_permutation.with_take(500).with_skip(500).num_rows,
        lambda: some_permutation.with_skip(500).with_take(500).num_rows,
    )
    for overrun in overruns:
        with pytest.raises(Exception):
            overrun()
def test_remove_columns(some_permutation: Permutation):
    """remove_columns drops columns non-destructively."""
    trimmed = some_permutation.remove_columns(["value"])
    assert trimmed.schema == pa.schema([("id", pa.int64())])
    # The original permutation is left untouched.
    assert some_permutation.schema.names == ["id", "value"]
    # At least one column must remain.
    with pytest.raises(ValueError, match="Cannot remove all columns"):
        some_permutation.remove_columns(["id", "value"])
def test_rename_column(some_permutation: Permutation):
    """rename_column renames non-destructively and validates arguments."""
    renamed = some_permutation.rename_column("value", "new_value")
    assert renamed.schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # The original permutation is left untouched.
    assert some_permutation.schema.names == ["id", "value"]
    # Renaming onto an existing column name is rejected.
    with pytest.raises(ValueError, match="a column with that name already exists"):
        some_permutation.rename_column("value", "id")
    # Renaming a missing column is rejected.
    with pytest.raises(ValueError, match="does not exist"):
        some_permutation.rename_column("non_existent", "new_value")
def test_rename_columns(some_permutation: Permutation):
    """rename_columns applies a name mapping non-destructively."""
    renamed = some_permutation.rename_columns({"value": "new_value"})
    assert renamed.schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # The original permutation is left untouched.
    assert some_permutation.schema.names == ["id", "value"]
    # Renaming onto an existing column name is rejected.
    with pytest.raises(ValueError, match="a column with that name already exists"):
        some_permutation.rename_columns({"value": "id"})
def test_select_columns(some_permutation: Permutation):
    """select_columns projects non-destructively and validates its input."""
    projected = some_permutation.select_columns(["id"])
    assert projected.schema == pa.schema([("id", pa.int64())])
    # The original permutation is left untouched.
    assert some_permutation.schema.names == ["id", "value"]
    # Selecting a missing column is rejected.
    with pytest.raises(ValueError, match="does not exist"):
        some_permutation.select_columns(["non_existent"])
    # An empty projection is rejected.
    with pytest.raises(ValueError, match="select at least one column"):
        some_permutation.select_columns([])
def test_iter_basic(some_permutation: Permutation):
    """Test basic iteration with a custom batch size."""
    batch_size = 100
    batches = list(some_permutation.iter(batch_size, skip_last_batch=False))
    # Ceiling division over the 950 rows gives the batch count.
    assert len(batches) == math.ceil(950 / batch_size)
    # Default "python" format yields dicts mapping column name -> list.
    for batch in batches:
        assert isinstance(batch, dict)
        assert "id" in batch
        assert "value" in batch
        assert isinstance(batch["id"], list)
        assert isinstance(batch["value"], list)
    # Every batch except possibly the last is exactly full.
    for batch in batches[:-1]:
        assert len(batch["id"]) == batch_size
        assert len(batch["value"]) == batch_size
    # The final batch may be partial but never oversized.
    assert len(batches[-1]["id"]) <= batch_size
def test_iter_skip_last_batch(some_permutation: Permutation):
    """Test iteration with skip_last_batch=True."""
    batch_size = 300
    with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
    without_skip = list(some_permutation.iter(batch_size, skip_last_batch=False))
    # Skipping drops the trailing partial batch, leaving only full ones.
    full_batches = 950 // batch_size
    assert len(with_skip) == full_batches
    remainder = 950 % batch_size
    if remainder != 0:
        # Without skipping there is one extra, smaller batch at the end.
        assert len(without_skip) == full_batches + 1
        assert len(without_skip[-1]["id"]) == remainder
    # Every retained batch is exactly full size.
    for batch in with_skip:
        assert len(batch["id"]) == batch_size
def test_iter_different_batch_sizes(some_permutation: Permutation):
    """Test iteration with several batch sizes over the 950 rows."""
    # Small batches: ceil(950 / 100) == 10.
    assert len(list(some_permutation.iter(100, skip_last_batch=False))) == 10
    # Large batches: ceil(950 / 400) == 3.
    assert len(list(some_permutation.iter(400, skip_last_batch=False))) == 3
    # Batch size exactly equal to the row count -> one full batch.
    exact = list(some_permutation.iter(950, skip_last_batch=False))
    assert len(exact) == 1
    assert len(exact[0]["id"]) == 950
    # Batch size larger than the row count -> still a single batch.
    oversized = list(some_permutation.iter(10000, skip_last_batch=False))
    assert len(oversized) == 1
    assert len(oversized[0]["id"]) == 950
def test_dunder_iter(some_permutation: Permutation):
    """Test the __iter__ method."""
    # __iter__ uses DEFAULT_BATCH_SIZE (100) with skip_last_batch=True, so
    # the trailing 50-row partial batch is dropped: floor(950 / 100) == 9.
    # (The original comment said "ceiling", which contradicts the assert.)
    batches = list(some_permutation)
    assert len(batches) == 9
    # All retained batches are full size.
    for batch in batches:
        assert len(batch["id"]) == 100
        assert len(batch["value"]) == 100
    # with_batch_size changes the size used by __iter__.
    some_permutation = some_permutation.with_batch_size(400)
    batches = list(some_permutation)
    assert len(batches) == 2  # floor(950 / 400) since skip_last_batch=True
    for batch in batches:
        assert len(batch["id"]) == 400
        assert len(batch["value"]) == 400
def test_iter_with_different_formats(some_permutation: Permutation):
    """Test iteration with different output formats."""
    import pandas as pd

    batch_size = 100
    # "arrow" format yields RecordBatches.
    arrow_batches = list(
        some_permutation.with_format("arrow").iter(batch_size, skip_last_batch=False)
    )
    assert all(isinstance(b, pa.RecordBatch) for b in arrow_batches)
    # "python" format (the default) yields plain dicts.
    python_batches = list(
        some_permutation.with_format("python").iter(batch_size, skip_last_batch=False)
    )
    assert all(isinstance(b, dict) for b in python_batches)
    # "pandas" format yields DataFrames.
    pandas_batches = list(
        some_permutation.with_format("pandas").iter(batch_size, skip_last_batch=False)
    )
    assert all(isinstance(b, pd.DataFrame) for b in pandas_batches)
def test_iter_with_column_selection(some_permutation: Permutation):
    """Test iteration after column selection."""
    id_only = some_permutation.select_columns(["id"])
    # Only the selected column should appear in each emitted batch.
    for batch in id_only.iter(100, skip_last_batch=False):
        assert "id" in batch
        assert "value" not in batch
def test_iter_with_column_rename(some_permutation: Permutation):
    """Test iteration after renaming columns."""
    renamed = some_permutation.rename_column("value", "data")
    # The renamed key replaces the original in every emitted batch.
    for batch in renamed.iter(100, skip_last_batch=False):
        assert "id" in batch
        assert "data" in batch
        assert "value" not in batch
def test_iter_with_limit_offset(some_permutation: Permutation):
    """Test iteration with limit and offset."""
    # Offset only: 950 - 100 = 850 rows remain.
    offset_batches = list(
        some_permutation.with_skip(100).iter(100, skip_last_batch=False)
    )
    assert len(offset_batches) == math.ceil(850 / 100)
    # Limit only: exactly 500 rows -> 5 full batches.
    limit_batches = list(
        some_permutation.with_take(500).iter(100, skip_last_batch=False)
    )
    assert len(limit_batches) == 5
    # Capture the id at index 100 from an unwindowed iteration...
    first_batch = next(some_permutation.iter(101, skip_last_batch=False))
    row_100 = first_batch["id"][100]
    # ...then confirm skip(100) + take(300) starts precisely at that row.
    limited = some_permutation.with_skip(100).with_take(300)
    limited_batches = list(limited.iter(100, skip_last_batch=False))
    assert len(limited_batches) == 3  # 300 rows / 100 per batch
    assert limited_batches[0]["id"][0] == row_100
def test_iter_empty_permutation(mem_db):
    """Test iteration over an empty permutation."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    # A filter matching nothing leaves an empty permutation table, and an
    # empty permutation cannot be materialized at all.
    empty_perm_tbl = permutation_builder(tbl).filter("value > 100").execute()
    with pytest.raises(ValueError, match="No rows found"):
        Permutation.from_tables(tbl, empty_perm_tbl)
def test_iter_single_row(mem_db):
    """Test iteration over a permutation with a single row."""
    tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
    perm = Permutation.from_tables(tbl, permutation_builder(tbl).execute())
    # Without skipping, the single row comes back as one partial batch.
    batches = list(perm.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 1
    # With skipping, the lone partial batch (< batch_size) is dropped.
    assert list(perm.iter(10, skip_last_batch=True)) == []
def test_identity_permutation(mem_db):
    """The identity permutation exposes the source table unchanged."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)
    assert permutation.num_rows == 10
    assert permutation.num_columns == 2
    # All ten rows arrive in a single batch.
    batches = list(permutation.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 10
    assert len(batches[0]["value"]) == 10
    # Column operations compose with the identity permutation.
    permutation = permutation.remove_columns(["value"])
    assert permutation.num_columns == 1
    assert permutation.schema == pa.schema([("id", pa.int64())])
    assert permutation.column_names == ["id"]
    assert permutation.shape == (10, 1)
def test_transform_fn(mem_db):
    """Each built-in output format yields the matching container type."""
    import numpy as np
    import pandas as pd
    import polars as pl

    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    def first_batch(fmt):
        # Materialize the single 10-row batch in the requested format.
        return list(permutation.with_format(fmt).iter(10, skip_last_batch=False))[0]

    np_result = first_batch("numpy")
    assert isinstance(np_result, np.ndarray)
    assert np_result.shape == (10, 2)
    assert np_result.dtype == np.int64

    pd_result = first_batch("pandas")
    assert isinstance(pd_result, pd.DataFrame)
    assert pd_result.shape == (10, 2)
    assert pd_result.dtypes.tolist() == [np.int64, np.int64]

    pl_result = first_batch("polars")
    assert isinstance(pl_result, pl.DataFrame)
    assert pl_result.shape == (10, 2)
    assert pl_result.dtypes == [pl.Int64, pl.Int64]

    py_result = first_batch("python")
    assert isinstance(py_result, dict)
    assert len(py_result) == 2
    assert len(py_result["id"]) == 10
    assert len(py_result["value"]) == 10

    try:
        import torch

        torch_result = first_batch("torch")
        assert isinstance(torch_result, torch.Tensor)
        # Note: the torch tensor is (columns, rows), transposed vs the table.
        assert torch_result.shape == (2, 10)
        assert torch_result.dtype == torch.int64
    except ImportError:
        # Skip the torch checks if torch is not installed
        pass

    arrow_result = first_batch("arrow")
    assert isinstance(arrow_result, pa.RecordBatch)
    assert arrow_result.shape == (10, 2)
    assert arrow_result.schema == pa.schema(
        [("id", pa.int64()), ("value", pa.int64())]
    )
def test_custom_transform(mem_db):
    """A user-supplied transform is applied to every emitted batch."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    def drop_value(batch: pa.RecordBatch) -> pa.RecordBatch:
        # Keep only the `id` column.
        return batch.select(["id"])

    transformed = permutation.with_transform(drop_value)
    batches = list(transformed.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    # The transform stripped `value`, leaving just the id column.
    assert batches[0] == pa.record_batch([range(10)], ["id"])