From 4cfcd9532085453ab397218c28e56dc61f40a936 Mon Sep 17 00:00:00 2001
From: Weston Pace
Date: Fri, 17 Oct 2025 05:00:23 -0700
Subject: [PATCH] feat: add a permutation reader that can read a permutation
 view (#2712)

This adds a Rust permutation reader. The next PR will add Python bindings
and integration with PyTorch.
---
 Cargo.lock                                    |   1 +
 Cargo.toml                                    |   1 +
 docs/src/js/functions/permutationBuilder.md   |   5 +-
 nodejs/__test__/permutation.test.ts           |  39 +-
 nodejs/lancedb/permutation.ts                 |   7 +-
 nodejs/src/permutation.rs                     |  28 +-
 python/python/lancedb/permutation.py          |   8 +-
 python/python/tests/test_permutation.py       |  90 ++--
 python/src/permutation.rs                     |  15 +-
 rust/lancedb/Cargo.toml                       |   1 +
 rust/lancedb/src/connection.rs                |   2 +-
 rust/lancedb/src/dataloader.rs                |   3 -
 rust/lancedb/src/dataloader/permutation.rs    | 292 +------------
 .../src/dataloader/permutation/builder.rs     | 326 +++++++++++++++
 .../src/dataloader/permutation/reader.rs      | 384 ++++++++++++++++++
 .../dataloader/{ => permutation}/shuffle.rs   |   2 +-
 .../src/dataloader/{ => permutation}/split.rs |   4 +-
 .../src/dataloader/{ => permutation}/util.rs  |   0
 rust/lancedb/src/lib.rs                       |   3 +-
 rust/lancedb/src/table.rs                     |   3 +
 rust/lancedb/src/test_connection.rs           | 126 ------
 rust/lancedb/src/test_utils.rs                |   5 +
 rust/lancedb/src/test_utils/connection.rs     | 120 ++++++
 rust/lancedb/src/test_utils/datagen.rs        |  55 +++
 24 files changed, 974 insertions(+), 546 deletions(-)
 create mode 100644 rust/lancedb/src/dataloader/permutation/builder.rs
 create mode 100644 rust/lancedb/src/dataloader/permutation/reader.rs
 rename rust/lancedb/src/dataloader/{ => permutation}/shuffle.rs (99%)
 rename rust/lancedb/src/dataloader/{ => permutation}/split.rs (99%)
 rename rust/lancedb/src/dataloader/{ => permutation}/util.rs (100%)
 delete mode 100644 rust/lancedb/src/test_connection.rs
 create mode 100644 rust/lancedb/src/test_utils.rs
 create mode 100644 rust/lancedb/src/test_utils/connection.rs
 create mode 100644 rust/lancedb/src/test_utils/datagen.rs

diff --git a/Cargo.lock b/Cargo.lock
index 58a3a41a..37229f05 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4697,6 +4697,7 @@ dependencies = [
  "arrow-ipc",
  "arrow-ord",
  "arrow-schema",
+ "arrow-select",
  "async-openai",
  "async-trait",
  "aws-config",
diff --git a/Cargo.toml b/Cargo.toml
index 70ee3f2a..0f33c5b7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,7 @@ arrow-data = "56.2"
 arrow-ipc = "56.2"
 arrow-ord = "56.2"
 arrow-schema = "56.2"
+arrow-select = "56.2"
 arrow-cast = "56.2"
 async-trait = "0"
 datafusion = { version = "50.1", default-features = false }
diff --git a/docs/src/js/functions/permutationBuilder.md b/docs/src/js/functions/permutationBuilder.md
index 63226d66..b7592d41 100644
--- a/docs/src/js/functions/permutationBuilder.md
+++ b/docs/src/js/functions/permutationBuilder.md
@@ -7,7 +7,7 @@
 # Function: permutationBuilder()
 
 ```ts
-function permutationBuilder(table, destTableName): PermutationBuilder
+function permutationBuilder(table): PermutationBuilder
 ```
 
 Create a permutation builder for the given table.
 
 ## Parameters
* **table**: [`Table`](../classes/Table.md) The source table to create a permutation from -* **destTableName**: `string` - The name for the destination permutation table - ## Returns [`PermutationBuilder`](../classes/PermutationBuilder.md) diff --git a/nodejs/__test__/permutation.test.ts b/nodejs/__test__/permutation.test.ts index be57e5d9..7a6db57b 100644 --- a/nodejs/__test__/permutation.test.ts +++ b/nodejs/__test__/permutation.test.ts @@ -38,23 +38,22 @@ describe("PermutationBuilder", () => { }); test("should create permutation builder", () => { - const builder = permutationBuilder(table, "permutation_table"); + const builder = permutationBuilder(table); expect(builder).toBeDefined(); }); test("should execute basic permutation", async () => { - const builder = permutationBuilder(table, "permutation_table"); + const builder = permutationBuilder(table); const permutationTable = await builder.execute(); expect(permutationTable).toBeDefined(); - expect(permutationTable.name).toBe("permutation_table"); const rowCount = await permutationTable.countRows(); expect(rowCount).toBe(10); }); test("should create permutation with random splits", async () => { - const builder = permutationBuilder(table, "permutation_table").splitRandom({ + const builder = permutationBuilder(table).splitRandom({ ratios: [1.0], seed: 42, }); @@ -65,7 +64,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with percentage splits", async () => { - const builder = permutationBuilder(table, "permutation_table").splitRandom({ + const builder = permutationBuilder(table).splitRandom({ ratios: [0.3, 0.7], seed: 42, }); @@ -84,7 +83,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with count splits", async () => { - const builder = permutationBuilder(table, "permutation_table").splitRandom({ + const builder = permutationBuilder(table).splitRandom({ counts: [3, 7], seed: 42, }); @@ -102,7 +101,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with hash splits", async () => { - const builder = permutationBuilder(table, "permutation_table").splitHash({ + const builder = permutationBuilder(table).splitHash({ columns: ["id"], splitWeights: [50, 50], discardWeight: 0, @@ -122,10 +121,9 @@ describe("PermutationBuilder", () => { }); test("should create permutation with sequential splits", async () => { - const builder = permutationBuilder( - table, - "permutation_table", - ).splitSequential({ ratios: [0.5, 0.5] }); + const builder = permutationBuilder(table).splitSequential({ + ratios: [0.5, 0.5], + }); const permutationTable = await builder.execute(); const rowCount = await permutationTable.countRows(); @@ -140,10 +138,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with calculated splits", async () => { - const builder = permutationBuilder( - table, - "permutation_table", - ).splitCalculated("id % 2"); + const builder = permutationBuilder(table).splitCalculated("id % 2"); const permutationTable = await builder.execute(); const rowCount = await permutationTable.countRows(); @@ -159,7 +154,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with shuffle", async () => { - const builder = permutationBuilder(table, "permutation_table").shuffle({ + const builder = permutationBuilder(table).shuffle({ seed: 42, }); @@ -169,7 +164,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with shuffle and clump size", async () => { - const builder = permutationBuilder(table, 
"permutation_table").shuffle({ + const builder = permutationBuilder(table).shuffle({ seed: 42, clumpSize: 2, }); @@ -180,9 +175,7 @@ describe("PermutationBuilder", () => { }); test("should create permutation with filter", async () => { - const builder = permutationBuilder(table, "permutation_table").filter( - "value > 50", - ); + const builder = permutationBuilder(table).filter("value > 50"); const permutationTable = await builder.execute(); const rowCount = await permutationTable.countRows(); @@ -190,7 +183,7 @@ describe("PermutationBuilder", () => { }); test("should chain multiple operations", async () => { - const builder = permutationBuilder(table, "permutation_table") + const builder = permutationBuilder(table) .filter("value <= 80") .splitRandom({ ratios: [0.5, 0.5], seed: 42 }) .shuffle({ seed: 123 }); @@ -209,7 +202,7 @@ describe("PermutationBuilder", () => { }); test("should throw error for invalid split arguments", () => { - const builder = permutationBuilder(table, "permutation_table"); + const builder = permutationBuilder(table); // Test no arguments provided expect(() => builder.splitRandom({})).toThrow( @@ -223,7 +216,7 @@ describe("PermutationBuilder", () => { }); test("should throw error when builder is consumed", async () => { - const builder = permutationBuilder(table, "permutation_table"); + const builder = permutationBuilder(table); // Execute once await builder.execute(); diff --git a/nodejs/lancedb/permutation.ts b/nodejs/lancedb/permutation.ts index 98406505..8fb8b508 100644 --- a/nodejs/lancedb/permutation.ts +++ b/nodejs/lancedb/permutation.ts @@ -161,7 +161,6 @@ export class PermutationBuilder { * Create a permutation builder for the given table. * * @param table - The source table to create a permutation from - * @param destTableName - The name for the destination permutation table * @returns A PermutationBuilder instance * @example * ```ts @@ -172,17 +171,13 @@ export class PermutationBuilder { * const trainingTable = await builder.execute(); * ``` */ -export function permutationBuilder( - table: Table, - destTableName: string, -): PermutationBuilder { +export function permutationBuilder(table: Table): PermutationBuilder { // Extract the inner native table from the TypeScript wrapper const localTable = table as LocalTable; // Access inner through type assertion since it's private const nativeBuilder = nativePermutationBuilder( // biome-ignore lint/suspicious/noExplicitAny: need access to private variable (localTable as any).inner, - destTableName, ); return new PermutationBuilder(nativeBuilder); } diff --git a/nodejs/src/permutation.rs b/nodejs/src/permutation.rs index 706a1de7..c569020b 100644 --- a/nodejs/src/permutation.rs +++ b/nodejs/src/permutation.rs @@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex}; use crate::{error::NapiErrorExt, table::Table}; use lancedb::dataloader::{ - permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, - split::{SplitSizes, SplitStrategy}, + permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, + permutation::split::{SplitSizes, SplitStrategy}, }; use napi_derive::napi; @@ -40,7 +40,6 @@ pub struct ShuffleOptions { pub struct PermutationBuilderState { pub builder: Option, - pub dest_table_name: String, } #[napi] @@ -49,11 +48,10 @@ pub struct PermutationBuilder { } impl PermutationBuilder { - pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self { + pub fn new(builder: LancePermutationBuilder) -> Self { Self { state: 
Arc::new(Mutex::new(PermutationBuilderState { builder: Some(builder), - dest_table_name, })), } } @@ -191,32 +189,26 @@ impl PermutationBuilder { /// Execute the permutation builder and create the table #[napi] pub async fn execute(&self) -> napi::Result { - let (builder, dest_table_name) = { + let builder = { let mut state = self.state.lock().unwrap(); - let builder = state + state .builder .take() - .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?; - - let dest_table_name = std::mem::take(&mut state.dest_table_name); - (builder, dest_table_name) + .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))? }; - let table = builder.build(&dest_table_name).await.default_error()?; + let table = builder.build().await.default_error()?; Ok(Table::new(table)) } } /// Create a permutation builder for the given table #[napi] -pub fn permutation_builder( - table: &crate::table::Table, - dest_table_name: String, -) -> napi::Result { - use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder; +pub fn permutation_builder(table: &crate::table::Table) -> napi::Result { + use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder; let inner_table = table.inner_ref()?.clone(); let inner_builder = LancePermutationBuilder::new(inner_table); - Ok(PermutationBuilder::new(inner_builder, dest_table_name)) + Ok(PermutationBuilder::new(inner_builder)) } diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index bd8aa610..bafaa0eb 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -8,8 +8,8 @@ from typing import Optional class PermutationBuilder: - def __init__(self, table: LanceTable, dest_table_name: str): - self._async = async_permutation_builder(table, dest_table_name) + def __init__(self, table: LanceTable): + self._async = async_permutation_builder(table) def select(self, projections: dict[str, str]) -> "PermutationBuilder": self._async.select(projections) @@ -68,5 +68,5 @@ class PermutationBuilder: return LOOP.run(do_execute()) -def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder: - return PermutationBuilder(table, dest_table_name) +def permutation_builder(table: LanceTable) -> PermutationBuilder: + return PermutationBuilder(table) diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 95cd21c0..7fbf2cc4 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -12,11 +12,7 @@ def test_split_random_ratios(mem_db): tbl = mem_db.create_table( "test_table", pa.table({"x": range(100), "y": range(100)}) ) - permutation_tbl = ( - permutation_builder(tbl, "test_permutation") - .split_random(ratios=[0.3, 0.7]) - .execute() - ) + permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute() # Check that the table was created and has data assert permutation_tbl.count_rows() == 100 @@ -38,11 +34,7 @@ def test_split_random_counts(mem_db): tbl = mem_db.create_table( "test_table", pa.table({"x": range(100), "y": range(100)}) ) - permutation_tbl = ( - permutation_builder(tbl, "test_permutation") - .split_random(counts=[20, 30]) - .execute() - ) + permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute() # Check that we have exactly the requested counts assert permutation_tbl.count_rows() == 50 @@ -58,9 +50,7 @@ def test_split_random_fixed(mem_db): tbl = mem_db.create_table( 
"test_table", pa.table({"x": range(100), "y": range(100)}) ) - permutation_tbl = ( - permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute() - ) + permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute() # Check that we have 4 splits with 25 rows each assert permutation_tbl.count_rows() == 100 @@ -78,17 +68,9 @@ def test_split_random_with_seed(mem_db): tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)})) # Create two identical permutations with same seed - perm1 = ( - permutation_builder(tbl, "perm1") - .split_random(ratios=[0.6, 0.4], seed=42) - .execute() - ) + perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute() - perm2 = ( - permutation_builder(tbl, "perm2") - .split_random(ratios=[0.6, 0.4], seed=42) - .execute() - ) + perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute() # Results should be identical data1 = perm1.search(None).to_arrow().to_pydict() @@ -112,7 +94,7 @@ def test_split_hash(mem_db): ) permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .split_hash(["category"], [1, 1], discard_weight=0) .execute() ) @@ -133,7 +115,7 @@ def test_split_hash(mem_db): # Hash splits should be deterministic - same category should go to same split # Let's verify by creating another permutation and checking consistency perm2 = ( - permutation_builder(tbl, "test_permutation2") + permutation_builder(tbl) .split_hash(["category"], [1, 1], discard_weight=0) .execute() ) @@ -150,7 +132,7 @@ def test_split_hash_with_discard(mem_db): ) permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50% .execute() ) @@ -168,9 +150,7 @@ def test_split_sequential(mem_db): ) permutation_tbl = ( - permutation_builder(tbl, "test_permutation") - .split_sequential(counts=[30, 40]) - .execute() + permutation_builder(tbl).split_sequential(counts=[30, 40]).execute() ) assert permutation_tbl.count_rows() == 70 @@ -194,7 +174,7 @@ def test_split_calculated(mem_db): ) permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .split_calculated("id % 3") # Split based on id modulo 3 .execute() ) @@ -216,23 +196,21 @@ def test_split_error_cases(mem_db): # Test split_random with no parameters with pytest.raises(Exception): - permutation_builder(tbl, "error1").split_random().execute() + permutation_builder(tbl).split_random().execute() # Test split_random with multiple parameters with pytest.raises(Exception): - permutation_builder(tbl, "error2").split_random( + permutation_builder(tbl).split_random( ratios=[0.5, 0.5], counts=[5, 5] ).execute() # Test split_sequential with no parameters with pytest.raises(Exception): - permutation_builder(tbl, "error3").split_sequential().execute() + permutation_builder(tbl).split_sequential().execute() # Test split_sequential with multiple parameters with pytest.raises(Exception): - permutation_builder(tbl, "error4").split_sequential( - ratios=[0.5, 0.5], fixed=2 - ).execute() + permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute() def test_shuffle_no_seed(mem_db): @@ -242,7 +220,7 @@ def test_shuffle_no_seed(mem_db): ) # Create a permutation with shuffling (no seed) - permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute() + permutation_tbl = permutation_builder(tbl).shuffle().execute() assert permutation_tbl.count_rows() == 100 
@@ -262,9 +240,9 @@ def test_shuffle_with_seed(mem_db): ) # Create two identical permutations with same shuffle seed - perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute() + perm1 = permutation_builder(tbl).shuffle(seed=42).execute() - perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute() + perm2 = permutation_builder(tbl).shuffle(seed=42).execute() # Results should be identical due to same seed data1 = perm1.search(None).to_arrow().to_pydict() @@ -282,7 +260,7 @@ def test_shuffle_with_clump_size(mem_db): # Create a permutation with shuffling using clumps permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .shuffle(clump_size=10) # 10-row clumps .execute() ) @@ -304,19 +282,9 @@ def test_shuffle_different_seeds(mem_db): ) # Create two permutations with different shuffle seeds - perm1 = ( - permutation_builder(tbl, "perm1") - .split_random(fixed=2) - .shuffle(seed=42) - .execute() - ) + perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute() - perm2 = ( - permutation_builder(tbl, "perm2") - .split_random(fixed=2) - .shuffle(seed=123) - .execute() - ) + perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute() # Results should be different due to different seeds data1 = perm1.search(None).to_arrow().to_pydict() @@ -341,7 +309,7 @@ def test_shuffle_combined_with_splits(mem_db): # Test shuffle with random splits perm_random = ( - permutation_builder(tbl, "perm_random") + permutation_builder(tbl) .split_random(ratios=[0.6, 0.4], seed=42) .shuffle(seed=123, clump_size=None) .execute() @@ -349,7 +317,7 @@ def test_shuffle_combined_with_splits(mem_db): # Test shuffle with hash splits perm_hash = ( - permutation_builder(tbl, "perm_hash") + permutation_builder(tbl) .split_hash(["category"], [1, 1], discard_weight=0) .shuffle(seed=456, clump_size=5) .execute() @@ -357,7 +325,7 @@ def test_shuffle_combined_with_splits(mem_db): # Test shuffle with sequential splits perm_sequential = ( - permutation_builder(tbl, "perm_sequential") + permutation_builder(tbl) .split_sequential(counts=[40, 35]) .shuffle(seed=789, clump_size=None) .execute() @@ -384,7 +352,7 @@ def test_no_shuffle_maintains_order(mem_db): # Create permutation without shuffle (should maintain some order) permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .split_sequential(counts=[25, 25]) # Sequential maintains order .execute() ) @@ -405,9 +373,7 @@ def test_filter_basic(mem_db): ) # Filter to only include rows where id < 50 - permutation_tbl = ( - permutation_builder(tbl, "test_permutation").filter("id < 50").execute() - ) + permutation_tbl = permutation_builder(tbl).filter("id < 50").execute() assert permutation_tbl.count_rows() == 50 @@ -433,7 +399,7 @@ def test_filter_with_splits(mem_db): # Filter to only category A and B, then split permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .filter("category IN ('A', 'B')") .split_random(ratios=[0.5, 0.5]) .execute() @@ -465,7 +431,7 @@ def test_filter_with_shuffle(mem_db): # Filter and shuffle permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .filter("category IN ('A', 'C')") .shuffle(seed=42) .execute() @@ -488,7 +454,7 @@ def test_filter_empty_result(mem_db): # Filter that matches nothing permutation_tbl = ( - permutation_builder(tbl, "test_permutation") + permutation_builder(tbl) .filter("value > 100") # No values > 100 in our data 
        .execute()
    )
diff --git a/python/src/permutation.rs b/python/src/permutation.rs
index 38da3f82..a8d6b4ee 100644
--- a/python/src/permutation.rs
+++ b/python/src/permutation.rs
@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
 
 use crate::{error::PythonErrorExt, table::Table};
 use lancedb::dataloader::{
-    permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
-    split::{SplitSizes, SplitStrategy},
+    permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
+    permutation::split::{SplitSizes, SplitStrategy},
 };
 use pyo3::{
     exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
@@ -16,10 +16,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
 
 /// Create a permutation builder for the given table
 #[pyo3::pyfunction]
-pub fn async_permutation_builder(
-    table: Bound<'_, PyAny>,
-    dest_table_name: String,
-) -> PyResult<PyAsyncPermutationBuilder> {
+pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
     let table = table.getattr("_inner")?.downcast_into::<Table>()?;
     let inner_table = table.borrow().inner_ref()?.clone();
     let inner_builder = LancePermutationBuilder::new(inner_table);
@@ -27,14 +24,12 @@
     Ok(PyAsyncPermutationBuilder {
         state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
             builder: Some(inner_builder),
-            dest_table_name,
         })),
     })
 }
 
 struct PyAsyncPermutationBuilderState {
     builder: Option<LancePermutationBuilder>,
-    dest_table_name: String,
 }
 
 #[pyclass(name = "AsyncPermutationBuilder")]
@@ -167,10 +162,8 @@ impl PyAsyncPermutationBuilder {
             .take()
             .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
 
-        let dest_table_name = std::mem::take(&mut state.dest_table_name);
-
         future_into_py(slf.py(), async move {
-            let table = builder.build(&dest_table_name).await.infer_error()?;
+            let table = builder.build().await.infer_error()?;
             Ok(Table::new(table))
         })
     }
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index 3f83e496..159d1619 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -16,6 +16,7 @@ arrow = { workspace = true }
 arrow-array = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
+arrow-select = { workspace = true }
 arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index c8e3496d..decaa05d 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -1182,7 +1182,7 @@ mod tests {
     use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
     use crate::query::QueryBase;
     use crate::query::{ExecutableQuery, QueryExecutionOptions};
-    use crate::test_connection::test_utils::new_test_connection;
+    use crate::test_utils::connection::new_test_connection;
     use arrow::compute::concat_batches;
     use arrow_array::RecordBatchReader;
     use arrow_schema::{DataType, Field, Schema};
diff --git a/rust/lancedb/src/dataloader.rs b/rust/lancedb/src/dataloader.rs
index cbb7f037..dd3e4a7a 100644
--- a/rust/lancedb/src/dataloader.rs
+++ b/rust/lancedb/src/dataloader.rs
@@ -2,6 +2,3 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
 pub mod permutation;
-pub mod shuffle;
-pub mod split;
-pub mod util;
diff --git a/rust/lancedb/src/dataloader/permutation.rs b/rust/lancedb/src/dataloader/permutation.rs
index 09a39d93..a6973d1a 100644
--- a/rust/lancedb/src/dataloader/permutation.rs
+++ b/rust/lancedb/src/dataloader/permutation.rs
@@ -7,288 +7,12 @@
 //! The permutation table only stores the split ids and row ids. It is not a materialized copy of
 //! the underlying data and can be very lightweight.
 //!
-//! Building a permutation table should be fairly quick and memory efficient, even for billions or
-//! trillions of rows.
+//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
+//! the number of rows in the base table) and memory efficient, even for billions or trillions
+//! of rows.
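+//!
+//! A rough usage sketch (not part of this diff; assumes an existing `table: Table`,
+//! an async context, and error handling elided):
+//!
+//! ```ignore
+//! use lancedb::dataloader::permutation::{builder::PermutationBuilder, reader::PermutationReader};
+//! use lancedb::query::{QueryExecutionOptions, Select};
+//!
+//! // Build a lightweight permutation table (just `row_id` and `split_id` columns).
+//! let permutation = PermutationBuilder::new(table.clone()).build().await?;
+//!
+//! // Stream split 0 of the base table back in the permuted order.
+//! let reader = PermutationReader::try_new(
+//!     table.base_table().clone(),
+//!     permutation.base_table().clone(),
+//! )
+//! .await?;
+//! let mut batches = reader
+//!     .read_split(0, Select::All, QueryExecutionOptions::default())
+//!     .await?;
+//! ```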
-use datafusion::prelude::{SessionConfig, SessionContext};
-use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
-use datafusion_expr::col;
-use futures::TryStreamExt;
-use lance_datafusion::exec::SessionContextExt;
-
-use crate::{
-    arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
-    dataloader::{
-        shuffle::{Shuffler, ShufflerConfig},
-        split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
-        util::{rename_column, TemporaryDirectory},
-    },
-    query::{ExecutableQuery, QueryBase},
-    Connection, Error, Result, Table,
-};
-
-/// Configuration for creating a permutation table
-#[derive(Debug, Default)]
-pub struct PermutationConfig {
-    /// Splitting configuration
-    pub split_strategy: SplitStrategy,
-    /// Shuffle strategy
-    pub shuffle_strategy: ShuffleStrategy,
-    /// Optional filter to apply to the base table
-    pub filter: Option<String>,
-    /// Directory to use for temporary files
-    pub temp_dir: TemporaryDirectory,
-}
-
-/// Strategy for shuffling the data.
-#[derive(Debug, Clone)]
-pub enum ShuffleStrategy {
-    /// The data is randomly shuffled
-    ///
-    /// A seed can be provided to make the shuffle deterministic.
-    ///
-    /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
-    /// This decreases the overall randomization but can improve I/O performance when reading from
-    /// cloud storage.
-    ///
-    /// For example, a clump size of 16 will means we will shuffle blocks of 16 contiguous rows. This
-    /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
-    /// the performance of the model. Note: shuffling within clumps can still be done at read time but
-    /// this will only provide a local shuffle and not a global shuffle.
-    Random {
-        seed: Option<u64>,
-        clump_size: Option<u64>,
-    },
-    /// The data is not shuffled
-    ///
-    /// This is useful for debugging and testing.
-    None,
-}
-
-impl Default for ShuffleStrategy {
-    fn default() -> Self {
-        Self::None
-    }
-}
-
-/// Builder for creating a permutation table.
-///
-/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
-/// can be used to create a
-pub struct PermutationBuilder {
-    config: PermutationConfig,
-    base_table: Table,
-}
-
-impl PermutationBuilder {
-    pub fn new(base_table: Table) -> Self {
-        Self {
-            config: PermutationConfig::default(),
-            base_table,
-        }
-    }
-
-    /// Configures the strategy for assigning rows to splits.
-    ///
-    /// For example, it is common to create a test/train split of the data. Splits can also be used
-    /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
-    /// create a single split with 10% of the data.
-    ///
-    /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
-    /// multiple processes and multiple nodes.
-    ///
-    /// The default is a single split that contains all rows.
-    pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
-        self.config.split_strategy = split_strategy;
-        self
-    }
-
-    /// Configures the strategy for shuffling the data.
-    ///
-    /// The default is to shuffle the data randomly at row-level granularity (no shard size) and
-    /// with a random seed.
-    pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
-        self.config.shuffle_strategy = shuffle_strategy;
-        self
-    }
-
-    /// Configures a filter to apply to the base table.
-    ///
-    /// Only rows matching the filter will be included in the permutation.
-    pub fn with_filter(mut self, filter: String) -> Self {
-        self.config.filter = Some(filter);
-        self
-    }
-
-    /// Configures the directory to use for temporary files.
-    ///
-    /// The default is to use the operating system's default temporary directory.
-    pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
-        self.config.temp_dir = temp_dir;
-        self
-    }
-
-    async fn sort_by_split_id(
-        &self,
-        data: SendableRecordBatchStream,
-    ) -> Result<SendableRecordBatchStream> {
-        let ctx = SessionContext::new_with_config_rt(
-            SessionConfig::default(),
-            RuntimeEnvBuilder::new()
-                .with_memory_limit(100 * 1024 * 1024, 1.0)
-                .with_disk_manager_builder(
-                    DiskManagerBuilder::default()
-                        .with_mode(self.config.temp_dir.to_disk_manager_mode()),
-                )
-                .build_arc()
-                .unwrap(),
-        );
-        let df = ctx
-            .read_one_shot(data.into_df_stream())
-            .map_err(|e| Error::Other {
-                message: format!("Failed to setup sort by split id: {}", e),
-                source: Some(e.into()),
-            })?;
-        let df_stream = df
-            .sort_by(vec![col(SPLIT_ID_COLUMN)])
-            .map_err(|e| Error::Other {
-                message: format!("Failed to plan sort by split id: {}", e),
-                source: Some(e.into()),
-            })?
-            .execute_stream()
-            .await
-            .map_err(|e| Error::Other {
-                message: format!("Failed to sort by split id: {}", e),
-                source: Some(e.into()),
-            })?;
-
-        let schema = df_stream.schema();
-        let stream = df_stream.map_err(|e| Error::Other {
-            message: format!("Failed to execute sort by split id: {}", e),
-            source: Some(e.into()),
-        });
-        Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
-    }
-
-    /// Builds the permutation table and stores it in the given database.
-    pub async fn build(self, dest_table_name: &str) -> Result<Table> {
-        // First pass, apply filter and load row ids
-        let mut rows = self.base_table.query().with_row_id();
-
-        if let Some(filter) = &self.config.filter {
-            rows = rows.only_if(filter);
-        }
-
-        let splitter = Splitter::new(
-            self.config.temp_dir.clone(),
-            self.config.split_strategy.clone(),
-        );
-
-        let mut needs_sort = !splitter.orders_by_split_id();
-
-        // Might need to load additional columns to calculate splits (e.g. hash columns or calculated
-        // split id)
-        rows = splitter.project(rows);
-
-        let num_rows = self
-            .base_table
-            .count_rows(self.config.filter.clone())
-            .await? as u64;
-
-        // Apply splits
-        let rows = rows.execute().await?;
-        let split_data = splitter.apply(rows, num_rows).await?;
-
-        // Shuffle data if requested
-        let shuffled = match self.config.shuffle_strategy {
-            ShuffleStrategy::None => split_data,
-            ShuffleStrategy::Random { seed, clump_size } => {
-                let shuffler = Shuffler::new(ShufflerConfig {
-                    seed,
-                    clump_size,
-                    temp_dir: self.config.temp_dir.clone(),
-                    max_rows_per_file: 10 * 1024 * 1024,
-                });
-                shuffler.shuffle(split_data, num_rows).await?
-            }
-        };
-
-        // We want the final permutation to be sorted by the split id. If we shuffled or if
-        // the split was not assigned sequentially then we need to sort the data.
-        needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
-
-        let sorted = if needs_sort {
-            self.sort_by_split_id(shuffled).await?
-        } else {
-            shuffled
-        };
-
-        // Rename _rowid to row_id
-        let renamed = rename_column(sorted, "_rowid", "row_id")?;
-
-        // Create permutation table
-        let conn = Connection::new(
-            self.base_table.database().clone(),
-            self.base_table.embedding_registry().clone(),
-        );
-        conn.create_table_streaming(dest_table_name, renamed)
-            .execute()
-            .await
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use arrow::datatypes::Int32Type;
-    use lance_datagen::{BatchCount, RowCount};
-
-    use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
-
-    use super::*;
-
-    #[tokio::test]
-    async fn test_permutation_builder() {
-        let temp_dir = tempfile::tempdir().unwrap();
-
-        let db = connect(temp_dir.path().to_str().unwrap())
-            .execute()
-            .await
-            .unwrap();
-
-        let initial_data = lance_datagen::gen_batch()
-            .col("some_value", lance_datagen::array::step::<Int32Type>())
-            .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
-        let data_table = db
-            .create_table_streaming("mytbl", initial_data)
-            .execute()
-            .await
-            .unwrap();
-
-        let permutation_table = PermutationBuilder::new(data_table)
-            .with_filter("some_value > 57".to_string())
-            .with_split_strategy(SplitStrategy::Random {
-                seed: Some(42),
-                sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
-            })
-            .build("permutation")
-            .await
-            .unwrap();
-
-        // Potentially brittle seed-dependent values below
-        assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
-        assert_eq!(
-            permutation_table
-                .count_rows(Some("split_id = 0".to_string()))
-                .await
-                .unwrap(),
-            47
-        );
-        assert_eq!(
-            permutation_table
-                .count_rows(Some("split_id = 1".to_string()))
-                .await
-                .unwrap(),
-            283
-        );
-    }
-}
+pub mod builder;
+pub mod reader;
+pub mod shuffle;
+pub mod split;
+pub mod util;
diff --git a/rust/lancedb/src/dataloader/permutation/builder.rs b/rust/lancedb/src/dataloader/permutation/builder.rs
new file mode 100644
index 00000000..ea1cad62
--- /dev/null
+++ b/rust/lancedb/src/dataloader/permutation/builder.rs
@@ -0,0 +1,326 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::Arc;
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
+use datafusion_expr::col;
+use futures::TryStreamExt;
+use lance_core::ROW_ID;
+use lance_datafusion::exec::SessionContextExt;
+
+use crate::{
+    arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
+    connect,
+    database::{CreateTableData, CreateTableRequest, Database},
+    dataloader::permutation::{
+        shuffle::{Shuffler, ShufflerConfig},
+        split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
+        util::{rename_column, TemporaryDirectory},
+    },
+    query::{ExecutableQuery, QueryBase},
+    Error, Result, Table,
+};
+
+pub const SRC_ROW_ID_COL: &str = "row_id";
+
+/// Where to store the permutation table
+#[derive(Debug, Clone, Default)]
+enum PermutationDestination {
+    /// The permutation table is a temporary table in memory
+    #[default]
+    Temporary,
+    /// The permutation table is a permanent table in a database
+    Permanent(Arc<dyn Database>, String),
+}
+
+/// Configuration for creating a permutation table
+#[derive(Debug, Default)]
+pub struct PermutationConfig {
+    /// Splitting configuration
+    split_strategy: SplitStrategy,
+    /// Shuffle strategy
+    shuffle_strategy: ShuffleStrategy,
+    /// Optional filter to apply to the base table
+    filter: Option<String>,
+    /// Directory to use for temporary files
+    temp_dir: TemporaryDirectory,
+    /// Destination
+    destination: PermutationDestination,
+}
+
+/// Strategy for shuffling the data.
+#[derive(Debug, Clone)]
+pub enum ShuffleStrategy {
+    /// The data is randomly shuffled
+    ///
+    /// A seed can be provided to make the shuffle deterministic.
+    ///
+    /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
+    /// This decreases the overall randomization but can improve I/O performance when reading from
+    /// cloud storage.
+    ///
+    /// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
+    /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
+    /// the performance of the model. Note: shuffling within clumps can still be done at read time but
+    /// this will only provide a local shuffle and not a global shuffle.
+    Random {
+        seed: Option<u64>,
+        clump_size: Option<u64>,
+    },
+    /// The data is not shuffled
+    ///
+    /// This is useful for debugging and testing.
+    None,
+}
+
+impl Default for ShuffleStrategy {
+    fn default() -> Self {
+        Self::None
+    }
+}
+
+/// Builder for creating a permutation table.
+///
+/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
+/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
+///
+/// The permutation table is not a materialized copy or a view of the underlying data and can be
+/// very lightweight. It is a separate table that stores just the row id and split id.
+pub struct PermutationBuilder {
+    config: PermutationConfig,
+    base_table: Table,
+}
+
+impl PermutationBuilder {
+    pub fn new(base_table: Table) -> Self {
+        Self {
+            config: PermutationConfig::default(),
+            base_table,
+        }
+    }
+
+    /// Configures the strategy for assigning rows to splits.
+    ///
+    /// For example, it is common to create a test/train split of the data. Splits can also be used
+    /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
+    /// create a single split with 10% of the data.
+    ///
+    /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
+    /// multiple processes and multiple nodes.
+    ///
+    /// The default is a single split that contains all rows.
+    pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
+        self.config.split_strategy = split_strategy;
+        self
+    }
+
+    /// Configures the strategy for shuffling the data.
+    ///
+    /// The default is no shuffling ([`ShuffleStrategy::None`]).
+    pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
+        self.config.shuffle_strategy = shuffle_strategy;
+        self
+    }
+
+    /// Configures a filter to apply to the base table.
+    ///
+    /// Only rows matching the filter will be included in the permutation.
+    pub fn with_filter(mut self, filter: String) -> Self {
+        self.config.filter = Some(filter);
+        self
+    }
+
+    /// Configures the directory to use for temporary files.
+    ///
+    /// The default is to use the operating system's default temporary directory.
+    pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
+        self.config.temp_dir = temp_dir;
+        self
+    }
+
+    /// Stores the permutation as a table in a database
+    ///
+    /// By default, the permutation is stored in memory. If this method is called then
+    /// the permutation will be stored as a table in the given database.
+    pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
+        self.config.destination = PermutationDestination::Permanent(database, table_name);
+        self
+    }
+
+    async fn sort_by_split_id(
+        &self,
+        data: SendableRecordBatchStream,
+    ) -> Result<SendableRecordBatchStream> {
+        let ctx = SessionContext::new_with_config_rt(
+            SessionConfig::default(),
+            RuntimeEnvBuilder::new()
+                .with_memory_limit(100 * 1024 * 1024, 1.0)
+                .with_disk_manager_builder(
+                    DiskManagerBuilder::default()
+                        .with_mode(self.config.temp_dir.to_disk_manager_mode()),
+                )
+                .build_arc()
+                .unwrap(),
+        );
+        let df = ctx
+            .read_one_shot(data.into_df_stream())
+            .map_err(|e| Error::Other {
+                message: format!("Failed to setup sort by split id: {}", e),
+                source: Some(e.into()),
+            })?;
+        let df_stream = df
+            .sort_by(vec![col(SPLIT_ID_COLUMN)])
+            .map_err(|e| Error::Other {
+                message: format!("Failed to plan sort by split id: {}", e),
+                source: Some(e.into()),
+            })?
+            .execute_stream()
+            .await
+            .map_err(|e| Error::Other {
+                message: format!("Failed to sort by split id: {}", e),
+                source: Some(e.into()),
+            })?;
+
+        let schema = df_stream.schema();
+        let stream = df_stream.map_err(|e| Error::Other {
+            message: format!("Failed to execute sort by split id: {}", e),
+            source: Some(e.into()),
+        });
+        Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
+    }
+
+    /// Builds the permutation table, storing it in the configured destination.
+    pub async fn build(self) -> Result<Table> {
+        // First pass, apply filter and load row ids
+        let mut rows = self.base_table.query().with_row_id();
+
+        if let Some(filter) = &self.config.filter {
+            rows = rows.only_if(filter);
+        }
+
+        let splitter = Splitter::new(
+            self.config.temp_dir.clone(),
+            self.config.split_strategy.clone(),
+        );
+
+        let mut needs_sort = !splitter.orders_by_split_id();
+
+        // Might need to load additional columns to calculate splits (e.g. hash columns or calculated
+        // split id)
+        rows = splitter.project(rows);
+
+        let num_rows = self
+            .base_table
+            .count_rows(self.config.filter.clone())
+            .await? as u64;
+
+        // Apply splits
+        let rows = rows.execute().await?;
+        let split_data = splitter.apply(rows, num_rows).await?;
+
+        // Shuffle data if requested
+        let shuffled = match self.config.shuffle_strategy {
+            ShuffleStrategy::None => split_data,
+            ShuffleStrategy::Random { seed, clump_size } => {
+                let shuffler = Shuffler::new(ShufflerConfig {
+                    seed,
+                    clump_size,
+                    temp_dir: self.config.temp_dir.clone(),
+                    max_rows_per_file: 10 * 1024 * 1024,
+                });
+                shuffler.shuffle(split_data, num_rows).await?
+            }
+        };
+
+        // We want the final permutation to be sorted by the split id. If we shuffled or if
+        // the split was not assigned sequentially then we need to sort the data.
+        needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
+
+        let sorted = if needs_sort {
+            self.sort_by_split_id(shuffled).await?
+        } else {
+            shuffled
+        };
+
+        // Rename _rowid to row_id
+        let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
+
+        let (name, database) = match &self.config.destination {
+            PermutationDestination::Permanent(database, table_name) => {
+                (table_name.as_str(), database.clone())
+            }
+            PermutationDestination::Temporary => {
+                let conn = connect("memory:///").execute().await?;
+                ("permutation", conn.database().clone())
+            }
+        };
+
+        let create_table_request =
+            CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
+
+        let table = database.create_table(create_table_request).await?;
+        Ok(Table::new(table, database))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::Int32Type;
+    use lance_datagen::{BatchCount, RowCount};
+
+    use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_permutation_builder() {
+        let temp_dir = tempfile::tempdir().unwrap();
+
+        let db = connect(temp_dir.path().to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let initial_data = lance_datagen::gen_batch()
+            .col("some_value", lance_datagen::array::step::<Int32Type>())
+            .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
+        let data_table = db
+            .create_table_streaming("mytbl", initial_data)
+            .execute()
+            .await
+            .unwrap();
+
+        let permutation_table = PermutationBuilder::new(data_table.clone())
+            .with_filter("some_value > 57".to_string())
+            .with_split_strategy(SplitStrategy::Random {
+                seed: Some(42),
+                sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
+            })
+            .build()
+            .await
+            .unwrap();
+
+        println!("permutation_table: {:?}", permutation_table);
+
+        // Potentially brittle seed-dependent values below
+        assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
+        assert_eq!(
+            permutation_table
+                .count_rows(Some("split_id = 0".to_string()))
+                .await
+                .unwrap(),
+            47
+        );
+        assert_eq!(
+            permutation_table
+                .count_rows(Some("split_id = 1".to_string()))
+                .await
+                .unwrap(),
+            283
+        );
+    }
+}
diff --git a/rust/lancedb/src/dataloader/permutation/reader.rs b/rust/lancedb/src/dataloader/permutation/reader.rs
new file mode 100644
index 00000000..d68d88a6
--- /dev/null
+++ b/rust/lancedb/src/dataloader/permutation/reader.rs
@@ -0,0 +1,384 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! Row ID-based views for LanceDB tables
+//!
+//! This module provides functionality for creating views that are based on specific row IDs.
+//! The `PermutationReader` allows you to create a virtual table that contains only
+//! the rows from a source table that correspond to row IDs stored in a separate table.
+
+use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
+use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
+use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
+use crate::error::Error;
+use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
+use crate::table::{AnyQuery, BaseTable};
+use crate::Result;
+use arrow::array::AsArray;
+use arrow::datatypes::UInt64Type;
+use arrow_array::{RecordBatch, UInt64Array};
+use futures::{StreamExt, TryStreamExt};
+use lance::arrow::RecordBatchExt;
+use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance::error::LanceOptionExt;
+use lance_core::ROW_ID;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+/// Reads a permutation of a source table based on row IDs stored in a separate table
+pub struct PermutationReader {
+    base_table: Arc<dyn BaseTable>,
+    permutation_table: Arc<dyn BaseTable>,
+}
+
+impl std::fmt::Debug for PermutationReader {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "PermutationReader(base={}, permutation={})",
+            self.base_table.name(),
+            self.permutation_table.name(),
+        )
+    }
+}
+
+impl PermutationReader {
+    /// Create a new PermutationReader
+    pub async fn try_new(
+        base_table: Arc<dyn BaseTable>,
+        permutation_table: Arc<dyn BaseTable>,
+    ) -> Result<Self> {
+        let schema = permutation_table.schema().await?;
+        if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
+            return Err(Error::InvalidInput {
+                message: "Permutation table must contain a column named row_id".to_string(),
+            });
+        }
+        if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
+            return Err(Error::InvalidInput {
+                message: "Permutation table must contain a column named split_id".to_string(),
+            });
+        }
+        Ok(Self {
+            base_table,
+            permutation_table,
+        })
+    }
+
+    fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
+        for (expected, idx) in iter.enumerate() {
+            if *idx != expected as u64 {
+                return false;
+            }
+        }
+        true
+    }
+
+    async fn load_batch(
+        base_table: &Arc<dyn BaseTable>,
+        row_ids: RecordBatch,
+        selection: Select,
+        has_row_id: bool,
+    ) -> Result<RecordBatch> {
+        let num_rows = row_ids.num_rows();
+        let row_ids = row_ids
+            .column(0)
+            .as_primitive_opt::<UInt64Type>()
+            .expect_ok()?
+            .values();
+
+        let filter = format!(
+            "_rowid in ({})",
+            row_ids
+                .iter()
+                .map(|o| o.to_string())
+                .collect::<Vec<_>>()
+                .join(",")
+        );
+
+        let base_query = QueryRequest {
+            filter: Some(QueryFilter::Sql(filter)),
+            select: selection,
+            with_row_id: true,
+            ..Default::default()
+        };
+
+        let mut data = base_table
+            .query(
+                &AnyQuery::Query(base_query),
+                QueryExecutionOptions {
+                    max_batch_length: num_rows as u32,
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let Some(batch) = data.try_next().await? else {
+            return Err(Error::InvalidInput {
+                message: "Base table returned no batches".to_string(),
+            });
+        };
+        if data.try_next().await?.is_some() {
+            return Err(Error::InvalidInput {
+                message: "Base table returned more than one batch".to_string(),
+            });
+        }
+
+        if batch.num_rows() != num_rows {
+            return Err(Error::InvalidInput {
+                message: "Base table returned different number of rows than the number of row IDs"
+                    .to_string(),
+            });
+        }
+
+        // There is no guarantee the result order will match the order provided
+        // so may need to restore order
+        let actual_row_ids = batch
+            .column_by_name(ROW_ID)
+            .expect_ok()?
+            .as_primitive_opt::<UInt64Type>()
+            .expect_ok()?
+            .values();
+
+        // Map from row id to order in batch, used to restore original ordering
+        let ordering = actual_row_ids
+            .iter()
+            .copied()
+            .enumerate()
+            .map(|(i, o)| (o, i as u64))
+            .collect::<HashMap<_, _>>();
+
+        let desired_idx_order = row_ids
+            .iter()
+            .map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
+            .collect::<Result<Vec<_>>>()?;
+
+        let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
+            // Fast path if already sorted, important as data may be large and
+            // re-ordering could be expensive
+            batch
+        } else {
+            let desired_idx_order = UInt64Array::from(desired_idx_order);
+
+            arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
+        };
+
+        if has_row_id {
+            Ok(ordered_batch)
+        } else {
+            // The user didn't ask for row id, we needed it for ordering the data, but now we drop it
+            Ok(ordered_batch.drop_column(ROW_ID)?)
+        }
+    }
+
+    async fn row_ids_to_batches(
+        base_table: Arc<dyn BaseTable>,
+        row_ids: DatasetRecordBatchStream,
+        selection: Select,
+    ) -> Result<SendableRecordBatchStream> {
+        let has_row_id = Self::has_row_id(&selection)?;
+        let mut stream = row_ids
+            .map_err(Error::from)
+            .try_filter_map(move |batch| {
+                let selection = selection.clone();
+                let base_table = base_table.clone();
+                async move {
+                    Self::load_batch(&base_table, batch, selection, has_row_id)
+                        .await
+                        .map(Some)
+                }
+            })
+            .boxed();
+
+        // Need to read out first batch to get schema
+        let Some(first_batch) = stream.try_next().await? else {
+            return Err(Error::InvalidInput {
+                message: "Permutation was empty".to_string(),
+            });
+        };
+        let schema = first_batch.schema();
+
+        let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);
+
+        Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
+    }
+
+    fn has_row_id(selection: &Select) -> Result<bool> {
+        match selection {
+            Select::All => {
+                // _rowid is a system column and is not included in Select::All
+                Ok(false)
+            }
+            Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
+            Select::Dynamic(columns) => {
+                for column in columns {
+                    if column.0 == ROW_ID {
+                        if column.1 == ROW_ID {
+                            return Ok(true);
+                        } else {
+                            return Err(Error::InvalidInput {
+                                message: format!(
+                                    "Dynamic column {} cannot be used to select _rowid",
+                                    column.1
+                                ),
+                            });
+                        }
+                    }
+                }
+                Ok(false)
+            }
+        }
+    }
+
+    pub async fn read_split(
+        &self,
+        split: u64,
+        selection: Select,
+        execution_options: QueryExecutionOptions,
+    ) -> Result<SendableRecordBatchStream> {
+        let row_ids = self
+            .permutation_table
+            .query(
+                &AnyQuery::Query(QueryRequest {
+                    select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
+                    filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
+                    ..Default::default()
+                }),
+                execution_options,
+            )
+            .await?;
+
+        Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::Int32Type;
+    use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
+    use arrow_schema::{DataType, Field, Schema};
+    use lance_datagen::{BatchCount, RowCount};
+    use rand::seq::SliceRandom;
+
+    use crate::{
+        arrow::SendableRecordBatchStream,
+        query::{ExecutableQuery, QueryBase},
+        test_utils::datagen::{virtual_table, LanceDbDatagenExt},
+        Table,
+    };
+
+    use super::*;
+
+    async fn collect_from_stream<T: ArrowPrimitiveType>(
+        mut stream: SendableRecordBatchStream,
+        column: &str,
+    ) -> Vec<T::Native> {
+        let mut row_ids = Vec::new();
+        while let Some(batch) = stream.try_next().await.unwrap() {
+            let col_idx = batch.schema().index_of(column).unwrap();
+            row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
+        }
+        row_ids
+    }
+
+    async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
+        collect_from_stream::<T>(
+            table
+                .query()
+                .select(Select::Columns(vec![column.to_string()]))
+                .execute()
+                .await
+                .unwrap(),
+            column,
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn test_permutation_reader() {
+        let base_table = lance_datagen::gen_batch()
+            .col("idx", lance_datagen::array::step::<Int32Type>())
+            .col("other_col", lance_datagen::array::step::<UInt64Type>())
+            .into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
+            .await;
+
+        let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
+        row_ids.shuffle(&mut rand::rng());
+        // Put the last two rows in split 1
+        let split_ids = UInt64Array::from_iter_values(
+            std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
+        );
+        let permutation_batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![
+                Field::new("row_id", DataType::UInt64, false),
+                Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
+            ])),
+            vec![
+                Arc::new(UInt64Array::from(row_ids.clone())),
+                Arc::new(split_ids),
+            ],
+        )
+        .unwrap();
+        let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
+
+        let reader = PermutationReader::try_new(
+            base_table.base_table().clone(),
+            row_ids_table.base_table().clone(),
+        )
+        .await
+        .unwrap();
+
+        // Read split 0
+        let mut stream = reader
+            .read_split(
+                0,
+                Select::All,
+                QueryExecutionOptions {
+                    max_batch_length: 3,
+                    ..Default::default()
+                },
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(stream.schema(), base_table.schema().await.unwrap());
+
+        let check_batch = async |stream: &mut SendableRecordBatchStream,
+                                 expected_values: &[u64]| {
+            let batch = stream.try_next().await.unwrap().unwrap();
+            assert_eq!(batch.num_rows(), expected_values.len());
+            assert_eq!(
+                batch.column(0).as_primitive::<Int32Type>().values(),
+                &expected_values
+                    .iter()
+                    .map(|o| *o as i32)
+                    .collect::<Vec<_>>()
+            );
+            assert_eq!(
+                batch.column(1).as_primitive::<UInt64Type>().values(),
+                &expected_values
+            );
+        };
+
+        check_batch(&mut stream, &row_ids[0..3]).await;
+        check_batch(&mut stream, &row_ids[3..6]).await;
+        check_batch(&mut stream, &row_ids[6..7]).await;
+        assert!(stream.try_next().await.unwrap().is_none());
+
+        // Read split 1
+        let mut stream = reader
+            .read_split(
+                1,
+                Select::All,
+                QueryExecutionOptions {
+                    max_batch_length: 3,
+                    ..Default::default()
+                },
+            )
+            .await
+            .unwrap();
+
+        check_batch(&mut stream, &row_ids[7..9]).await;
+        assert!(stream.try_next().await.unwrap().is_none());
+    }
+}
diff --git a/rust/lancedb/src/dataloader/shuffle.rs b/rust/lancedb/src/dataloader/permutation/shuffle.rs
similarity index 99%
rename from rust/lancedb/src/dataloader/shuffle.rs
rename to rust/lancedb/src/dataloader/permutation/shuffle.rs
index e06affc7..2c2b10fc 100644
--- a/rust/lancedb/src/dataloader/shuffle.rs
+++ b/rust/lancedb/src/dataloader/permutation/shuffle.rs
@@ -22,7 +22,7 @@ use rand::{seq::SliceRandom, Rng, RngCore};
 
 use crate::{
     arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
-    dataloader::util::{non_crypto_rng, TemporaryDirectory},
+    dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
     Error, Result,
 };
diff --git a/rust/lancedb/src/dataloader/split.rs b/rust/lancedb/src/dataloader/permutation/split.rs
similarity index 99%
rename from rust/lancedb/src/dataloader/split.rs
rename to rust/lancedb/src/dataloader/permutation/split.rs
index 32775488..d01db4c4 100644
--- a/rust/lancedb/src/dataloader/split.rs
+++ b/rust/lancedb/src/dataloader/permutation/split.rs
@@ -18,8 +18,8 @@ use lance::arrow::SchemaExt;
 use crate::{
     arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
     dataloader::{
-        shuffle::{Shuffler, ShufflerConfig},
-        util::TemporaryDirectory,
+        permutation::shuffle::{Shuffler, ShufflerConfig},
+        permutation::util::TemporaryDirectory,
     },
     query::{Query, QueryBase, Select},
     Error, Result,
diff --git a/rust/lancedb/src/dataloader/util.rs b/rust/lancedb/src/dataloader/permutation/util.rs
similarity index 100%
rename from rust/lancedb/src/dataloader/util.rs
rename to rust/lancedb/src/dataloader/permutation/util.rs
diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 75159e04..5d5272a2 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -207,7 +207,8 @@ pub mod query;
 pub mod remote;
 pub mod rerankers;
 pub mod table;
-pub mod test_connection;
+#[cfg(test)]
+pub mod test_utils;
 pub mod utils;
 
 use std::fmt::Display;
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index d632644c..6cfea7dd 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -511,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
     /// Get the namespace of the table.
     fn namespace(&self) -> &[String];
     /// Get the id of the table
+    ///
+    /// This is the namespace of the table concatenated with the name
+    /// separated by a dot (".")
     fn id(&self) -> &str;
     /// Get the arrow [Schema] of the table.
     async fn schema(&self) -> Result<SchemaRef>;
diff --git a/rust/lancedb/src/test_connection.rs b/rust/lancedb/src/test_connection.rs
deleted file mode 100644
index 2afd41ca..00000000
--- a/rust/lancedb/src/test_connection.rs
+++ /dev/null
@@ -1,126 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright The LanceDB Authors
-
-//! Functions for testing connections.
-
-#[cfg(test)]
-pub mod test_utils {
-    use regex::Regex;
-    use std::env;
-    use std::io::{BufRead, BufReader};
-    use std::process::{Child, ChildStdout, Command, Stdio};
-
-    use crate::{connect, Connection};
-    use anyhow::{bail, Result};
-    use tempfile::{tempdir, TempDir};
-
-    pub struct TestConnection {
-        pub uri: String,
-        pub connection: Connection,
-        _temp_dir: Option<TempDir>,
-        _process: Option<TestProcess>,
-    }
-
-    struct TestProcess {
-        child: Child,
-    }
-
-    impl Drop for TestProcess {
-        #[allow(unused_must_use)]
-        fn drop(&mut self) {
-            self.child.kill();
-        }
-    }
-
-    pub async fn new_test_connection() -> Result<TestConnection> {
-        match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
-            Ok(script_path) => new_remote_connection(&script_path).await,
-            Err(_e) => new_local_connection().await,
-        }
-    }
-
-    async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
-        let temp_dir = tempdir()?;
-        let data_path = temp_dir.path().to_str().unwrap().to_string();
-        let child_result = Command::new(script_path)
-            .stdin(Stdio::null())
-            .stdout(Stdio::piped())
-            .stderr(Stdio::piped())
-            .arg(data_path.clone())
-            .spawn();
-        if child_result.is_err() {
-            bail!(format!(
-                "Unable to run {}: {:?}",
-                script_path,
-                child_result.err()
-            ));
-        }
-        let mut process = TestProcess {
-            child: child_result.unwrap(),
-        };
-        let stdout = BufReader::new(process.child.stdout.take().unwrap());
-        let port = read_process_port(stdout)?;
-        let uri = "db://test";
-        let host_override = format!("http://localhost:{}", port);
-        let connection = create_new_connection(uri, &host_override).await?;
-        Ok(TestConnection {
-            uri: uri.to_string(),
-            connection,
-            _temp_dir: Some(temp_dir),
-            _process: Some(process),
-        })
-    }
-
-    fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
-        let mut line = String::new();
-        let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
-        loop {
-            let result = stdout.read_line(&mut line);
-            if let Err(err) = result {
-                bail!(format!(
-                    "read_process_port: error while reading from process output: {}",
-                    err
-                ));
-            } else if result.unwrap() == 0 {
-                bail!("read_process_port: hit EOF before reading port from process output.");
-            }
-            if re.is_match(&line) {
-                let caps = re.captures(&line).unwrap();
-                return Ok(caps[1].to_string());
-            }
-        }
-    }
-
-    #[cfg(feature = "remote")]
-    async fn create_new_connection(
-        uri: &str,
-        host_override: &str,
-    ) -> crate::error::Result<Connection> {
-        connect(uri)
-            .region("us-east-1")
-            .api_key("sk_localtest")
-            .host_override(host_override)
-            .execute()
-            .await
-    }
-
-    #[cfg(not(feature = "remote"))]
-    async fn create_new_connection(
-        _uri: &str,
-        _host_override: &str,
-    ) -> crate::error::Result<Connection> {
-        panic!("remote feature not supported");
-    }
-
-    async fn new_local_connection() -> Result<TestConnection> {
-        let temp_dir = tempdir()?;
-        let uri = temp_dir.path().to_str().unwrap();
-        let connection = connect(uri).execute().await?;
-        Ok(TestConnection {
-            uri: uri.to_string(),
-            connection,
-            _temp_dir: Some(temp_dir),
-            _process: None,
-        })
-    }
-}
diff --git a/rust/lancedb/src/test_utils.rs b/rust/lancedb/src/test_utils.rs
new file mode 100644
index 00000000..daf749bc
--- /dev/null
+++ b/rust/lancedb/src/test_utils.rs
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+pub mod connection;
+pub mod datagen;
diff --git a/rust/lancedb/src/test_utils/connection.rs b/rust/lancedb/src/test_utils/connection.rs
new file mode 100644
index 00000000..2811f05f
--- /dev/null
+++ b/rust/lancedb/src/test_utils/connection.rs
@@ -0,0 +1,120 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! Functions for testing connections.
+
+use regex::Regex;
+use std::env;
+use std::io::{BufRead, BufReader};
+use std::process::{Child, ChildStdout, Command, Stdio};
+
+use crate::{connect, Connection};
+use anyhow::{bail, Result};
+use tempfile::{tempdir, TempDir};
+
+pub struct TestConnection {
+    pub uri: String,
+    pub connection: Connection,
+    _temp_dir: Option<TempDir>,
+    _process: Option<TestProcess>,
+}
+
+struct TestProcess {
+    child: Child,
+}
+
+impl Drop for TestProcess {
+    #[allow(unused_must_use)]
+    fn drop(&mut self) {
+        self.child.kill();
+    }
+}
+
+pub async fn new_test_connection() -> Result<TestConnection> {
+    match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
+        Ok(script_path) => new_remote_connection(&script_path).await,
+        Err(_e) => new_local_connection().await,
+    }
+}
+
+async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
+    let temp_dir = tempdir()?;
+    let data_path = temp_dir.path().to_str().unwrap().to_string();
+    let child_result = Command::new(script_path)
+        .stdin(Stdio::null())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        .arg(data_path.clone())
+        .spawn();
+    if child_result.is_err() {
+        bail!(format!(
+            "Unable to run {}: {:?}",
+            script_path,
+            child_result.err()
+        ));
+    }
+    let mut process = TestProcess {
+        child: child_result.unwrap(),
+    };
+    let stdout = BufReader::new(process.child.stdout.take().unwrap());
+    let port = read_process_port(stdout)?;
+    let uri = "db://test";
+    let host_override = format!("http://localhost:{}", port);
+    let connection = create_new_connection(uri, &host_override).await?;
+    Ok(TestConnection {
+        uri: uri.to_string(),
+        connection,
+        _temp_dir: Some(temp_dir),
+        _process: Some(process),
+    })
+}
+
+fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
+    let mut line = String::new();
+    let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
+    loop {
+        let result = stdout.read_line(&mut line);
+        if let Err(err) = result {
+            bail!(format!(
+                "read_process_port: error while reading from process output: {}",
+                err
+            ));
+        } else if result.unwrap() == 0 {
+            bail!("read_process_port: hit EOF before reading port from process output.");
+        }
+        if re.is_match(&line) {
+            let caps = re.captures(&line).unwrap();
+            return Ok(caps[1].to_string());
+        }
+    }
+}
+
+#[cfg(feature = "remote")]
+async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
+    connect(uri)
+        .region("us-east-1")
+        .api_key("sk_localtest")
+        .host_override(host_override)
+        .execute()
+        .await
+}
+
+#[cfg(not(feature = "remote"))]
+async fn create_new_connection(
+    _uri: &str,
+    _host_override: &str,
+) -> crate::error::Result<Connection> {
+    panic!("remote feature not supported");
+}
+
+async fn new_local_connection() -> Result<TestConnection> {
+    let temp_dir = tempdir()?;
+    let uri = temp_dir.path().to_str().unwrap();
+    let connection = connect(uri).execute().await?;
+    Ok(TestConnection {
+        uri: uri.to_string(),
+        connection,
+        _temp_dir: Some(temp_dir),
+        _process: None,
+    })
+}
diff --git a/rust/lancedb/src/test_utils/datagen.rs b/rust/lancedb/src/test_utils/datagen.rs
new file mode 100644
index 00000000..15b79e5d
--- /dev/null
+++ b/rust/lancedb/src/test_utils/datagen.rs
@@ -0,0 +1,55 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use futures::TryStreamExt;
+use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
+
+use crate::{
+    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+    connect, Error, Table,
+};
+
+#[async_trait::async_trait]
+pub trait LanceDbDatagenExt {
+    async fn into_mem_table(
+        self,
+        table_name: &str,
+        rows_per_batch: RowCount,
+        num_batches: BatchCount,
+    ) -> Table;
+}
+
+#[async_trait::async_trait]
+impl LanceDbDatagenExt for BatchGeneratorBuilder {
+    async fn into_mem_table(
+        self,
+        table_name: &str,
+        rows_per_batch: RowCount,
+        num_batches: BatchCount,
+    ) -> Table {
+        let (stream, schema) = self.into_reader_stream(rows_per_batch, num_batches);
+        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
+            stream.map_err(Error::from),
+            schema,
+        ));
+        let db = connect("memory:///").execute().await.unwrap();
+        db.create_table_streaming(table_name, stream)
+            .execute()
+            .await
+            .unwrap()
+    }
+}
+
+pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
+    let schema = values.schema();
+    let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
+        futures::stream::once(std::future::ready(Ok(values.clone()))),
+        schema,
+    ));
+    let db = connect("memory:///").execute().await.unwrap();
+    db.create_table_streaming(name, stream)
+        .execute()
+        .await
+        .unwrap()
+}
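
Usage sketch (illustrative only, not part of this patch): a test might combine the two new helpers as below. `new_test_connection` yields a local tempdir connection, or a remote one when `CREATE_LANCEDB_TEST_CONNECTION_SCRIPT` is set, and `into_mem_table` materializes generated batches as an in-memory table. The `lance_datagen::gen()` and `array::step` calls are assumptions based on the lance_datagen builder API and may differ in the version pinned here.

#[cfg(test)]
mod sketch {
    use arrow_array::types::Int32Type;
    use lance_datagen::{array, BatchCount, RowCount};

    use crate::test_utils::connection::new_test_connection;
    use crate::test_utils::datagen::LanceDbDatagenExt;

    #[tokio::test]
    async fn helpers_smoke_test() -> anyhow::Result<()> {
        // Local connection by default; remote if the env script is set.
        let test_conn = new_test_connection().await?;
        assert!(test_conn
            .connection
            .table_names()
            .execute()
            .await?
            .is_empty());

        // Two generated batches of five rows each -> a 10-row mem table.
        let table = lance_datagen::gen()
            .col("id", array::step::<Int32Type>())
            .into_mem_table("ids", RowCount::from(5), BatchCount::from(2))
            .await;
        assert_eq!(table.count_rows(None).await?, 10);
        Ok(())
    }
}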