# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

"""Tests for the permutation builder and the Permutation/Permutations readers."""

import math

import pyarrow as pa
import pytest

from lancedb import DBConnection, Table, connect
from lancedb.permutation import Permutation, Permutations, permutation_builder


def test_permutation_persistence(tmp_path):
    """Test that a persisted permutation can be reopened unchanged."""
    db = connect(tmp_path)
    tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))

    permutation_tbl = (
        permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
    )
    assert permutation_tbl.count_rows() == 100

    re_open = db.open_table("test_permutation")
    assert re_open.count_rows() == 100

    assert permutation_tbl.to_arrow() == re_open.to_arrow()


def test_split_random_ratios(mem_db):
    """Test random splitting with ratios."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()

    # Check that the table was created and has data
    assert permutation_tbl.count_rows() == 100

    # Check that split_id column exists and has correct values
    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1}

    # Check approximate split sizes (allowing for rounding)
    split_0_count = split_ids.count(0)
    split_1_count = split_ids.count(1)
    assert 25 <= split_0_count <= 35  # ~30% ± tolerance
    assert 65 <= split_1_count <= 75  # ~70% ± tolerance


def test_split_random_counts(mem_db):
    """Test random splitting with absolute counts."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()

    # Check that we have exactly the requested counts
    assert permutation_tbl.count_rows() == 50

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert split_ids.count(0) == 20
    assert split_ids.count(1) == 30


def test_split_random_fixed(mem_db):
    """Test random splitting with fixed number of splits."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()

    # Check that we have 4 splits with 25 rows each
    assert permutation_tbl.count_rows() == 100

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1, 2, 3}

    for split_id in range(4):
        assert split_ids.count(split_id) == 25


def test_split_random_with_seed(mem_db):
    """Test that seeded random splits are reproducible."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))

    # Create two identical permutations with same seed
    perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
    perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()

    # Results should be identical
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()

    assert data1["row_id"] == data2["row_id"]
    assert data1["split_id"] == data2["split_id"]


def test_split_hash(mem_db):
    """Test hash-based splitting."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],  # Repeating pattern
                "value": range(100),
            }
        ),
    )

    permutation_tbl = (
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )

    # Should have all 100 rows (no discard)
    assert permutation_tbl.count_rows() == 100

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    split_ids = data["split_id"]
    assert set(split_ids) == {0, 1}

    # Verify that each split has roughly 50 rows (allowing for hash variance)
    split_0_count = split_ids.count(0)
    split_1_count = split_ids.count(1)
    assert 30 <= split_0_count <= 70  # ~50 ± 20 tolerance for hash distribution
    assert 30 <= split_1_count <= 70  # ~50 ± 20 tolerance for hash distribution

    # Hash splits should be deterministic - same category should go to same split
    # Let's verify by creating another permutation and checking consistency
    perm2 = (
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )

    data2 = perm2.search(None).to_arrow().to_pydict()
    assert data["split_id"] == data2["split_id"]  # Should be identical


def test_split_hash_with_discard(mem_db):
    """Test hash-based splitting with discard weight."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
    )

    permutation_tbl = (
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=2)  # Should discard ~50%
        .execute()
    )

    # Should have fewer than 100 rows due to discard
    row_count = permutation_tbl.count_rows()
    assert row_count < 100
    assert row_count > 0  # But not empty


def test_split_sequential(mem_db):
    """Test sequential splitting."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )

    permutation_tbl = (
        permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
    )

    assert permutation_tbl.count_rows() == 70

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    split_ids = data["split_id"]

    # Sequential should maintain order
    assert row_ids == sorted(row_ids)

    # First 30 should be split 0, next 40 should be split 1
    assert split_ids[:30] == [0] * 30
    assert split_ids[30:] == [1] * 40


def test_split_calculated(mem_db):
    """Test calculated splitting."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )

    permutation_tbl = (
        permutation_builder(tbl)
        .split_calculated("id % 3")  # Split based on id modulo 3
        .execute()
    )

    assert permutation_tbl.count_rows() == 100

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]
    split_ids = data["split_id"]

    # Verify the calculation: each row's split_id should equal row_id % 3
    for row_id, split_id in zip(row_ids, split_ids):
        assert split_id == row_id % 3


def test_split_error_cases(mem_db):
    """Test error handling for invalid split parameters."""
    tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))

    # Test split_random with no parameters
    with pytest.raises(
        ValueError,
        match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
    ):
        permutation_builder(tbl).split_random().execute()

    # Test split_random with multiple parameters
    with pytest.raises(
        ValueError,
        match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
    ):
        permutation_builder(tbl).split_random(
            ratios=[0.5, 0.5], counts=[5, 5]
        ).execute()

    # Test split_sequential with no parameters
    with pytest.raises(
        ValueError,
        match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
    ):
        permutation_builder(tbl).split_sequential().execute()

    # Test split_sequential with multiple parameters
    with pytest.raises(
        ValueError,
        match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
    ):
        permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()


def test_shuffle_no_seed(mem_db):
    """Test shuffling without a seed."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )

    # Create a permutation with shuffling (no seed)
    permutation_tbl = permutation_builder(tbl).shuffle().execute()

    assert permutation_tbl.count_rows() == 100

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]

    # Row IDs should not be in sequential order due to shuffling
    # This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
    # in order
    assert row_ids != list(range(100))


def test_shuffle_with_seed(mem_db):
    """Test that shuffling with a seed is reproducible."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )

    # Create two identical permutations with same shuffle seed
    perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
    perm2 = permutation_builder(tbl).shuffle(seed=42).execute()

    # Results should be identical due to same seed
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()

    assert data1["row_id"] == data2["row_id"]
    assert data1["split_id"] == data2["split_id"]


def test_shuffle_with_clump_size(mem_db):
    """Test shuffling with clump size."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100)})
    )

    # Create a permutation with shuffling using clumps
    permutation_tbl = (
        permutation_builder(tbl)
        .shuffle(clump_size=10)  # 10-row clumps
        .execute()
    )

    assert permutation_tbl.count_rows() == 100

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]

    # Each consecutive group of 10 output rows must be a run of consecutive row ids
    for i in range(10):
        start = row_ids[i * 10]
        assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))


def test_shuffle_different_seeds(mem_db):
    """Test that different seeds produce different shuffle orders."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )

    # Create two permutations with different shuffle seeds
    perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
    perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()

    # Results should be different due to different seeds
    data1 = perm1.search(None).to_arrow().to_pydict()
    data2 = perm2.search(None).to_arrow().to_pydict()

    # Row order should be different
    assert data1["row_id"] != data2["row_id"]


def test_shuffle_combined_with_splits(mem_db):
    """Test shuffling combined with different split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )

    # Test shuffle with random splits
    perm_random = (
        permutation_builder(tbl)
        .split_random(ratios=[0.6, 0.4], seed=42)
        .shuffle(seed=123, clump_size=None)
        .execute()
    )

    # Test shuffle with hash splits
    perm_hash = (
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .shuffle(seed=456, clump_size=5)
        .execute()
    )

    # Test shuffle with sequential splits
    perm_sequential = (
        permutation_builder(tbl)
        .split_sequential(counts=[40, 35])
        .shuffle(seed=789, clump_size=None)
        .execute()
    )

    # Verify all permutations work and have expected properties
    assert perm_random.count_rows() == 100
    assert perm_hash.count_rows() == 100
    assert perm_sequential.count_rows() == 75

    # Verify shuffle affected the order
    data_random = perm_random.search(None).to_arrow().to_pydict()
    data_sequential = perm_sequential.search(None).to_arrow().to_pydict()

    assert data_random["row_id"] != list(range(100))
    assert data_sequential["row_id"] != list(range(75))


def test_no_shuffle_maintains_order(mem_db):
    """Test that not calling shuffle maintains the original order."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(50), "value": range(50)})
    )

    # Create permutation without shuffle (should maintain some order)
    permutation_tbl = (
        permutation_builder(tbl)
        .split_sequential(counts=[25, 25])  # Sequential maintains order
        .execute()
    )

    assert permutation_tbl.count_rows() == 50

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]

    # With sequential splits and no shuffle, should maintain order
    assert row_ids == list(range(50))


def test_filter_basic(mem_db):
    """Test basic filtering functionality."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "value": range(100, 200)})
    )

    # Filter to only include rows where id < 50
    permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()

    assert permutation_tbl.count_rows() == 50

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]

    # All row_ids should be less than 50
    assert all(row_id < 50 for row_id in row_ids)


def test_filter_with_splits(mem_db):
    """Test filtering combined with split strategies."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C"] * 34)[:100],
                "value": range(100),
            }
        ),
    )

    # Filter to only category A and B, then split
    permutation_tbl = (
        permutation_builder(tbl)
        .filter("category IN ('A', 'B')")
        .split_random(ratios=[0.5, 0.5])
        .execute()
    )

    # The 100-row A/B/C pattern holds 34 "A" rows and 33 "B" rows, so 67 remain
    row_count = permutation_tbl.count_rows()
    assert row_count == 67

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    categories = data["category"]

    # All categories should be A or B
    assert all(cat in ["A", "B"] for cat in categories)


def test_filter_with_shuffle(mem_db):
    """Test filtering combined with shuffling."""
    tbl = mem_db.create_table(
        "test_table",
        pa.table(
            {
                "id": range(100),
                "category": (["A", "B", "C", "D"] * 25)[:100],
                "value": range(100),
            }
        ),
    )

    # Filter and shuffle
    permutation_tbl = (
        permutation_builder(tbl)
        .filter("category IN ('A', 'C')")
        .shuffle(seed=42)
        .execute()
    )

    row_count = permutation_tbl.count_rows()
    assert row_count == 50  # Should have 50 rows (A and C categories)

    data = permutation_tbl.search(None).to_arrow().to_pydict()
    row_ids = data["row_id"]

    # Shuffling should have moved the surviving rows out of sorted order
    assert row_ids != sorted(row_ids)


def test_filter_empty_result(mem_db):
    """Test filtering that results in empty set."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )

    # Filter that matches nothing
    permutation_tbl = (
        permutation_builder(tbl)
        .filter("value > 100")  # No values > 100 in our data
        .execute()
    )

    assert permutation_tbl.count_rows() == 0


@pytest.fixture
def mem_db() -> DBConnection:
    """Return a connection opened on the ``memory:///`` URI."""
    return connect("memory:///")


@pytest.fixture
def some_table(mem_db: DBConnection) -> Table:
    """Return a 1000-row table with integer ``id`` and ``value`` columns."""
    data = pa.table(
        {
            "id": range(1000),
            "value": range(1000),
        }
    )
    return mem_db.create_table("some_table", data)


def test_no_split_names(some_table: Table):
    """Test that unnamed splits expose no names and are reachable by ordinal."""
    perm_tbl = (
        permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
    )
    permutations = Permutations(some_table, perm_tbl)
    assert permutations.split_names == []
    assert permutations.split_dict == {}
    assert permutations[0].num_rows == 500
    assert permutations[1].num_rows == 500


@pytest.fixture
def some_perm_table(some_table: Table) -> Table:
    """Return a seeded, shuffled 95/5 random split named ``train`` / ``test``."""
    return (
        permutation_builder(some_table)
        .split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
        .shuffle(seed=42)
        .execute()
    )


def test_nonexistent_split(some_table: Table, some_perm_table: Table):
    """Test that unknown split names and empty split ordinals are rejected."""
    # Reference by name and name does not exist
    with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
        Permutation.from_tables(some_table, some_perm_table, "nonexistent")

    # Reference by ordinal and there are no rows
    with pytest.raises(ValueError, match="No rows found"):
        Permutation.from_tables(some_table, some_perm_table, 5)


def test_permutations(some_table: Table, some_perm_table: Table):
    """Test split lookup by name and by ordinal through ``Permutations``."""
    permutations = Permutations(some_table, some_perm_table)
    assert permutations.split_names == ["train", "test"]
    assert permutations.split_dict == {"train": 0, "test": 1}
    assert permutations["train"].num_rows == 950
    assert permutations[0].num_rows == 950
    assert permutations["test"].num_rows == 50
    assert permutations[1].num_rows == 50

    with pytest.raises(ValueError, match="No split named `nonexistent` found"):
        permutations["nonexistent"]
    with pytest.raises(ValueError, match="No rows found"):
        permutations[5]


@pytest.fixture
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
    """Return a ``Permutation`` built from the fixtures with no split argument.

    The tests using this fixture expect 950 rows, which matches the size of the
    ``train`` split of ``some_perm_table``.
    """
    return Permutation.from_tables(some_table, some_perm_table)


def test_num_rows(some_permutation: Permutation):
    """Test ``num_rows`` of the fixture permutation."""
    assert some_permutation.num_rows == 950


def test_num_columns(some_permutation: Permutation):
    """Test ``num_columns`` of the fixture permutation."""
    assert some_permutation.num_columns == 2


def test_column_names(some_permutation: Permutation):
    """Test ``column_names`` of the fixture permutation."""
    assert some_permutation.column_names == ["id", "value"]


def test_shape(some_permutation: Permutation):
    """Test that ``shape`` is ``(num_rows, num_columns)``."""
    assert some_permutation.shape == (950, 2)


def test_schema(some_permutation: Permutation):
    """Test that ``schema`` exposes the base table's columns."""
    assert some_permutation.schema == pa.schema(
        [("id", pa.int64()), ("value", pa.int64())]
    )


def test_limit_offset(some_permutation: Permutation):
    """Test ``with_take`` / ``with_skip`` row counts and out-of-range requests."""
    assert some_permutation.with_take(100).num_rows == 100
    assert some_permutation.with_skip(100).num_rows == 850
    assert some_permutation.with_take(100).with_skip(100).num_rows == 100

    # NOTE(review): the exception type for out-of-range take/skip is not pinned
    # down here, hence the broad ``Exception``; tighten once the type is settled.
    with pytest.raises(Exception):
        some_permutation.with_take(1000000).num_rows
    with pytest.raises(Exception):
        some_permutation.with_skip(1000000).num_rows
    with pytest.raises(Exception):
        some_permutation.with_take(500).with_skip(500).num_rows
    with pytest.raises(Exception):
        some_permutation.with_skip(500).with_take(500).num_rows


def test_remove_columns(some_permutation: Permutation):
    """Test ``remove_columns`` and its refusal to drop every column."""
    assert some_permutation.remove_columns(["value"]).schema == pa.schema(
        [("id", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot remove all columns
    with pytest.raises(ValueError, match="Cannot remove all columns"):
        some_permutation.remove_columns(["id", "value"])


def test_rename_column(some_permutation: Permutation):
    """Test ``rename_column`` including name-clash and missing-column errors."""
    assert some_permutation.rename_column("value", "new_value").schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot rename to an existing column
    with pytest.raises(
        ValueError,
        match="a column with that name already exists",
    ):
        some_permutation.rename_column("value", "id")
    # Cannot rename a non-existent column
    with pytest.raises(
        ValueError,
        match="does not exist",
    ):
        some_permutation.rename_column("non_existent", "new_value")


def test_rename_columns(some_permutation: Permutation):
    """Test ``rename_columns`` with a mapping, including the name-clash error."""
    assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot rename to an existing column
    with pytest.raises(ValueError, match="a column with that name already exists"):
        some_permutation.rename_columns({"value": "id"})


def test_select_columns(some_permutation: Permutation):
    """Test ``select_columns`` including missing-column and empty selections."""
    assert some_permutation.select_columns(["id"]).schema == pa.schema(
        [("id", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot select a non-existent column
    with pytest.raises(ValueError, match="does not exist"):
        some_permutation.select_columns(["non_existent"])
    # Empty selection is not allowed
    with pytest.raises(ValueError, match="select at least one column"):
        some_permutation.select_columns([])


def test_iter_basic(some_permutation: Permutation):
    """Test basic iteration with custom batch size."""
    batch_size = 100
    batches = list(some_permutation.iter(batch_size, skip_last_batch=False))

    # Check that we got the expected number of batches
    expected_batches = (950 + batch_size - 1) // batch_size  # ceiling division
    assert len(batches) == expected_batches

    # Check that all batches are dicts (default python format)
    assert all(isinstance(batch, dict) for batch in batches)

    # Check that batches have the correct structure
    for batch in batches:
        assert "id" in batch
        assert "value" in batch
        assert isinstance(batch["id"], list)
        assert isinstance(batch["value"], list)

    # Check that all batches except the last have the correct size
    for batch in batches[:-1]:
        assert len(batch["id"]) == batch_size
        assert len(batch["value"]) == batch_size

    # Last batch might be smaller
    assert len(batches[-1]["id"]) <= batch_size


def test_iter_skip_last_batch(some_permutation: Permutation):
    """Test iteration with skip_last_batch=True."""
    batch_size = 300
    batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
    batches_without_skip = list(
        some_permutation.iter(batch_size, skip_last_batch=False)
    )

    # With skip_last_batch=True, we should have fewer batches if the last one is partial
    num_full_batches = 950 // batch_size
    assert len(batches_with_skip) == num_full_batches

    # Without skip_last_batch, we should have one more batch if there's a remainder
    if 950 % batch_size != 0:
        assert len(batches_without_skip) == num_full_batches + 1
        # Last batch should be smaller
        assert len(batches_without_skip[-1]["id"]) == 950 % batch_size

    # All batches with skip_last_batch should be full size
    for batch in batches_with_skip:
        assert len(batch["id"]) == batch_size


def test_iter_different_batch_sizes(some_permutation: Permutation):
    """Test iteration with different batch sizes."""
    # Test with small batch size
    small_batches = list(some_permutation.iter(100, skip_last_batch=False))
    assert len(small_batches) == 10  # ceiling(950 / 100)

    # Test with large batch size
    large_batches = list(some_permutation.iter(400, skip_last_batch=False))
    assert len(large_batches) == 3  # ceiling(950 / 400)

    # Test with batch size equal to total rows
    single_batch = list(some_permutation.iter(950, skip_last_batch=False))
    assert len(single_batch) == 1
    assert len(single_batch[0]["id"]) == 950

    # Test with batch size larger than total rows
    oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
    assert len(oversized_batch) == 1
    assert len(oversized_batch[0]["id"]) == 950


def test_dunder_iter(some_permutation: Permutation):
    """Test the __iter__ method."""
    # __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True
    batches = list(some_permutation)

    # With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches
    assert len(batches) == 9  # floor(950 / 100) since skip_last_batch=True

    # All batches should be full size
    for batch in batches:
        assert len(batch["id"]) == 100
        assert len(batch["value"]) == 100

    some_permutation = some_permutation.with_batch_size(400)
    batches = list(some_permutation)
    assert len(batches) == 2  # floor(950 / 400) since skip_last_batch=True
    for batch in batches:
        assert len(batch["id"]) == 400
        assert len(batch["value"]) == 400


def test_iter_with_different_formats(some_permutation: Permutation):
    """Test iteration with different output formats."""
    batch_size = 100

    # Test with arrow format
    arrow_perm = some_permutation.with_format("arrow")
    arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False))
    assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches)

    # Test with python format (default)
    python_perm = some_permutation.with_format("python")
    python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
    assert all(isinstance(batch, dict) for batch in python_batches)

    # Test with pandas format
    pandas_perm = some_permutation.with_format("pandas")
    pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False))
    # Import pandas to check the type
    import pandas as pd

    assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches)


def test_iter_with_column_selection(some_permutation: Permutation):
    """Test iteration after column selection."""
    # Select only the id column
    id_only = some_permutation.select_columns(["id"])
    batches = list(id_only.iter(100, skip_last_batch=False))

    # Check that batches only contain the id column
    for batch in batches:
        assert "id" in batch
        assert "value" not in batch


def test_iter_with_column_rename(some_permutation: Permutation):
    """Test iteration after renaming columns."""
    renamed = some_permutation.rename_column("value", "data")
    batches = list(renamed.iter(100, skip_last_batch=False))

    # Check that batches have the renamed column
    for batch in batches:
        assert "id" in batch
        assert "data" in batch
        assert "value" not in batch


def test_iter_with_limit_offset(some_permutation: Permutation):
    """Test iteration with limit and offset."""
    # Test with offset
    offset_perm = some_permutation.with_skip(100)
    offset_batches = list(offset_perm.iter(100, skip_last_batch=False))
    # Should have 850 rows (950 - 100)
    expected_batches = math.ceil(850 / 100)
    assert len(offset_batches) == expected_batches

    # Test with limit
    limit_perm = some_permutation.with_take(500)
    limit_batches = list(limit_perm.iter(100, skip_last_batch=False))
    # Should have 5 batches (500 / 100)
    assert len(limit_batches) == 5

    # Capture the id at position 100 of the unskipped permutation so the
    # skip(100) result below can be checked against it.
    no_skip = some_permutation.iter(101, skip_last_batch=False)
    row_100 = next(no_skip)["id"][100]

    # Test with both limit and offset
    limited_perm = some_permutation.with_skip(100).with_take(300)
    limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
    # Should have 3 batches (300 / 100)
    assert len(limited_batches) == 3
    assert limited_batches[0]["id"][0] == row_100


def test_iter_empty_permutation(mem_db):
    """Test iteration over an empty permutation."""
    # Create a table and filter it to be empty
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
    with pytest.raises(ValueError, match="No rows found"):
        Permutation.from_tables(tbl, permutation_tbl)


def test_iter_single_row(mem_db):
    """Test iteration over a permutation with a single row."""
    tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
    permutation_tbl = permutation_builder(tbl).execute()
    perm = Permutation.from_tables(tbl, permutation_tbl)

    # With skip_last_batch=False, should get one batch
    batches = list(perm.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 1

    # With skip_last_batch=True, should skip the single row (since it's < batch_size)
    batches_skip = list(perm.iter(10, skip_last_batch=True))
    assert len(batches_skip) == 0


def test_identity_permutation(mem_db):
    """Test ``Permutation.identity`` and column removal on its result."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    assert permutation.num_rows == 10
    assert permutation.num_columns == 2

    batches = list(permutation.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 10
    assert len(batches[0]["value"]) == 10

    permutation = permutation.remove_columns(["value"])
    assert permutation.num_columns == 1
    assert permutation.schema == pa.schema([("id", pa.int64())])
    assert permutation.column_names == ["id"]
    assert permutation.shape == (10, 1)


def test_transform_fn(mem_db):
    """Test the batch type, shape and dtype produced by each ``with_format``."""
    import numpy as np
    import pandas as pd
    import polars as pl

    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    def first_batch(fmt: str):
        """Return the first batch of the identity permutation in format *fmt*."""
        return list(permutation.with_format(fmt).iter(10, skip_last_batch=False))[0]

    np_result = first_batch("numpy")
    assert np_result.shape == (10, 2)
    assert np_result.dtype == np.int64
    assert isinstance(np_result, np.ndarray)

    pd_result = first_batch("pandas")
    assert pd_result.shape == (10, 2)
    assert pd_result.dtypes.tolist() == [np.int64, np.int64]
    assert isinstance(pd_result, pd.DataFrame)

    pl_result = first_batch("polars")
    assert pl_result.shape == (10, 2)
    assert pl_result.dtypes == [pl.Int64, pl.Int64]
    assert isinstance(pl_result, pl.DataFrame)

    py_result = first_batch("python")
    assert len(py_result) == 2
    assert len(py_result["id"]) == 10
    assert len(py_result["value"]) == 10
    assert isinstance(py_result, dict)

    # Only the import is guarded: an ImportError raised by the checks themselves
    # must fail the test rather than be mistaken for "torch is not installed".
    try:
        import torch
    except ImportError:
        # Skip check if torch is not installed
        pass
    else:
        torch_result = first_batch("torch")
        assert torch_result.shape == (2, 10)
        assert torch_result.dtype == torch.int64
        assert isinstance(torch_result, torch.Tensor)

    arrow_result = first_batch("arrow")
    assert arrow_result.shape == (10, 2)
    assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
    assert isinstance(arrow_result, pa.RecordBatch)


def test_custom_transform(mem_db):
    """Test that ``with_transform`` applies a user function to each batch."""
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
        return batch.select(["id"])

    transformed = permutation.with_transform(transform)
    batches = list(transformed.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    batch = batches[0]

    assert batch == pa.record_batch([range(10)], ["id"])