diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index fa74b2733..e75a02a53 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -438,11 +438,15 @@ def test_filter_with_splits(mem_db): row_count = permutation_tbl.count_rows() assert row_count == 67 - data = permutation_tbl.search(None).to_arrow().to_pydict() + # Verify the permutation table only contains row_id and split_id + assert set(permutation_tbl.schema.names) == {"row_id", "split_id"} + + row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"] + data = tbl.take_row_ids(row_ids).to_arrow().to_pydict() categories = data["category"] # All categories should be A or B - assert all(cat in ["A", "B"] for cat in categories) + assert all(cat in ("A", "B") for cat in categories) def test_filter_with_shuffle(mem_db): diff --git a/rust/lancedb/src/dataloader/permutation/split.rs b/rust/lancedb/src/dataloader/permutation/split.rs index 2bcf1c1e0..12bc8f9b3 100644 --- a/rust/lancedb/src/dataloader/permutation/split.rs +++ b/rust/lancedb/src/dataloader/permutation/split.rs @@ -12,6 +12,8 @@ use datafusion_common::hash_utils::create_hashes; use futures::{StreamExt, TryStreamExt}; use lance_arrow::SchemaExt; +use lance_core::ROW_ID; + use crate::{ arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}, dataloader::{ @@ -360,11 +362,15 @@ impl Splitter { pub fn project(&self, query: Query) -> Query { match &self.strategy { - SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![( - SPLIT_ID_COLUMN.to_string(), - calculation.clone(), - )])), - SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())), + SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![ + (SPLIT_ID_COLUMN.to_string(), calculation.clone()), + (ROW_ID.to_string(), ROW_ID.to_string()), + ])), + SplitStrategy::Hash { columns, .. } => { + let mut cols = columns.clone(); + cols.push(ROW_ID.to_string()); + query.select(Select::Columns(cols)) + } _ => query, } }