feat: add a permutation reader that can read a permutation view (#2712)

This adds a Rust permutation reader. In the next PR I will add Python bindings and integration with PyTorch.
Author: Weston Pace
Date: 2025-10-17 05:00:23 -07:00
Committed by: GitHub
Parent: a70ff04bc9
Commit: 4cfcd95320
24 changed files with 974 additions and 546 deletions
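For reviewers, here is a minimal end-to-end sketch of how the refactored builder and the new reader fit together, assembled from the APIs in this diff. The connection URI, table name, and `open_table` call are illustrative assumptions, not part of this change:

```rust
use futures::TryStreamExt;
use lancedb::{
    connect,
    dataloader::permutation::{
        builder::{PermutationBuilder, ShuffleStrategy},
        reader::PermutationReader,
        split::{SplitSizes, SplitStrategy},
    },
    query::{QueryExecutionOptions, Select},
};

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    // Illustrative: assumes an existing database and table
    let db = connect("./data").execute().await?;
    let base_table = db.open_table("mytbl").execute().await?;

    // Build a lightweight permutation (row ids + split ids), kept in memory by default
    let permutation = PermutationBuilder::new(base_table.clone())
        .with_split_strategy(SplitStrategy::Random {
            seed: Some(42),
            sizes: SplitSizes::Percentages(vec![0.8, 0.2]),
        })
        .with_shuffle_strategy(ShuffleStrategy::Random {
            seed: Some(42),
            clump_size: None,
        })
        .build()
        .await?;

    // Read split 0 of the base table in the permuted order
    let reader = PermutationReader::try_new(
        base_table.base_table().clone(),
        permutation.base_table().clone(),
    )
    .await?;
    let mut stream = reader
        .read_split(0, Select::All, QueryExecutionOptions::default())
        .await?;
    while let Some(batch) = stream.try_next().await? {
        println!("read {} rows", batch.num_rows());
    }
    Ok(())
}
```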

Cargo.lock (generated)

@@ -4697,6 +4697,7 @@ dependencies = [
"arrow-ipc",
"arrow-ord",
"arrow-schema",
"arrow-select",
"async-openai",
"async-trait",
"aws-config",


@@ -35,6 +35,7 @@ arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-select = "56.2"
arrow-cast = "56.2"
async-trait = "0"
datafusion = { version = "50.1", default-features = false }


@@ -7,7 +7,7 @@
# Function: permutationBuilder()
```ts
function permutationBuilder(table, destTableName): PermutationBuilder
function permutationBuilder(table): PermutationBuilder
```
Create a permutation builder for the given table.
@@ -17,9 +17,6 @@ Create a permutation builder for the given table.
* **table**: [`Table`](../classes/Table.md)
The source table to create a permutation from
* **destTableName**: `string`
The name for the destination permutation table
## Returns
[`PermutationBuilder`](../classes/PermutationBuilder.md)


@@ -38,23 +38,22 @@ describe("PermutationBuilder", () => {
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
expect(permutationTable.name).toBe("permutation_table");
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
ratios: [1.0],
seed: 42,
});
@@ -65,7 +64,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
@@ -84,7 +83,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
counts: [3, 7],
seed: 42,
});
@@ -102,7 +101,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitHash({
const builder = permutationBuilder(table).splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
@@ -122,10 +121,9 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitSequential({ ratios: [0.5, 0.5] });
const builder = permutationBuilder(table).splitSequential({
ratios: [0.5, 0.5],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -140,10 +138,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitCalculated("id % 2");
const builder = permutationBuilder(table).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -159,7 +154,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
const builder = permutationBuilder(table).shuffle({
seed: 42,
});
@@ -169,7 +164,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
const builder = permutationBuilder(table).shuffle({
seed: 42,
clumpSize: 2,
});
@@ -180,9 +175,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table, "permutation_table").filter(
"value > 50",
);
const builder = permutationBuilder(table).filter("value > 50");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -190,7 +183,7 @@ describe("PermutationBuilder", () => {
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table, "permutation_table")
const builder = permutationBuilder(table)
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
@@ -209,7 +202,7 @@ describe("PermutationBuilder", () => {
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
@@ -223,7 +216,7 @@ describe("PermutationBuilder", () => {
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
// Execute once
await builder.execute();


@@ -161,7 +161,6 @@ export class PermutationBuilder {
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @param destTableName - The name for the destination permutation table
* @returns A PermutationBuilder instance
* @example
* ```ts
@@ -172,17 +171,13 @@ export class PermutationBuilder {
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(
table: Table,
destTableName: string,
): PermutationBuilder {
export function permutationBuilder(table: Table): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
destTableName,
);
return new PermutationBuilder(nativeBuilder);
}


@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
@@ -40,7 +40,6 @@ pub struct ShuffleOptions {
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
@@ -49,11 +48,10 @@ pub struct PermutationBuilder {
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
pub fn new(builder: LancePermutationBuilder) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
@@ -191,32 +189,26 @@ impl PermutationBuilder {
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let builder = {
let mut state = self.state.lock().unwrap();
let builder = state
state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
};
let table = builder.build(&dest_table_name).await.default_error()?;
let table = builder.build().await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
Ok(PermutationBuilder::new(inner_builder))
}
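Both this Node binding and the Python binding below guard `execute()` against double use with the same consume-once pattern: the inner builder sits in an `Arc<Mutex<Option<...>>>` and is `take()`n on first call. A distilled, self-contained sketch of the pattern (all names here are hypothetical):

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for the wrapped LancePermutationBuilder
struct InnerBuilder;

struct OnceBuilder {
    state: Arc<Mutex<Option<InnerBuilder>>>,
}

impl OnceBuilder {
    fn new(inner: InnerBuilder) -> Self {
        Self {
            state: Arc::new(Mutex::new(Some(inner))),
        }
    }

    // `take()` moves the builder out of the shared state, so a second call
    // observes `None` and fails with "Builder already consumed".
    fn execute(&self) -> Result<InnerBuilder, String> {
        self.state
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| "Builder already consumed".to_string())
    }
}

fn main() {
    let builder = OnceBuilder::new(InnerBuilder);
    assert!(builder.execute().is_ok());
    assert!(builder.execute().is_err()); // already consumed
}
```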


@@ -8,8 +8,8 @@ from typing import Optional
class PermutationBuilder:
def __init__(self, table: LanceTable, dest_table_name: str):
self._async = async_permutation_builder(table, dest_table_name)
def __init__(self, table: LanceTable):
self._async = async_permutation_builder(table)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
@@ -68,5 +68,5 @@ class PermutationBuilder:
return LOOP.run(do_execute())
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
return PermutationBuilder(table, dest_table_name)
def permutation_builder(table: LanceTable) -> PermutationBuilder:
return PermutationBuilder(table)


@@ -12,11 +12,7 @@ def test_split_random_ratios(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(ratios=[0.3, 0.7])
.execute()
)
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
# Check that the table was created and has data
assert permutation_tbl.count_rows() == 100
@@ -38,11 +34,7 @@ def test_split_random_counts(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(counts=[20, 30])
.execute()
)
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
# Check that we have exactly the requested counts
assert permutation_tbl.count_rows() == 50
@@ -58,9 +50,7 @@ def test_split_random_fixed(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
)
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
# Check that we have 4 splits with 25 rows each
assert permutation_tbl.count_rows() == 100
@@ -78,17 +68,9 @@ def test_split_random_with_seed(mem_db):
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
# Create two identical permutations with same seed
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
# Results should be identical
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -112,7 +94,7 @@ def test_split_hash(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
@@ -133,7 +115,7 @@ def test_split_hash(mem_db):
# Hash splits should be deterministic - same category should go to same split
# Let's verify by creating another permutation and checking consistency
perm2 = (
permutation_builder(tbl, "test_permutation2")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
@@ -150,7 +132,7 @@ def test_split_hash_with_discard(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
.execute()
)
@@ -168,9 +150,7 @@ def test_split_sequential(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[30, 40])
.execute()
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
)
assert permutation_tbl.count_rows() == 70
@@ -194,7 +174,7 @@ def test_split_calculated(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_calculated("id % 3") # Split based on id modulo 3
.execute()
)
@@ -216,23 +196,21 @@ def test_split_error_cases(mem_db):
# Test split_random with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error1").split_random().execute()
permutation_builder(tbl).split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error2").split_random(
permutation_builder(tbl).split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error3").split_sequential().execute()
permutation_builder(tbl).split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error4").split_sequential(
ratios=[0.5, 0.5], fixed=2
).execute()
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
def test_shuffle_no_seed(mem_db):
@@ -242,7 +220,7 @@ def test_shuffle_no_seed(mem_db):
)
# Create a permutation with shuffling (no seed)
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
permutation_tbl = permutation_builder(tbl).shuffle().execute()
assert permutation_tbl.count_rows() == 100
@@ -262,9 +240,9 @@ def test_shuffle_with_seed(mem_db):
)
# Create two identical permutations with same shuffle seed
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
# Results should be identical due to same seed
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -282,7 +260,7 @@ def test_shuffle_with_clump_size(mem_db):
# Create a permutation with shuffling using clumps
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.shuffle(clump_size=10) # 10-row clumps
.execute()
)
@@ -304,19 +282,9 @@ def test_shuffle_different_seeds(mem_db):
)
# Create two permutations with different shuffle seeds
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(fixed=2)
.shuffle(seed=42)
.execute()
)
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(fixed=2)
.shuffle(seed=123)
.execute()
)
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
# Results should be different due to different seeds
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -341,7 +309,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with random splits
perm_random = (
permutation_builder(tbl, "perm_random")
permutation_builder(tbl)
.split_random(ratios=[0.6, 0.4], seed=42)
.shuffle(seed=123, clump_size=None)
.execute()
@@ -349,7 +317,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with hash splits
perm_hash = (
permutation_builder(tbl, "perm_hash")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.shuffle(seed=456, clump_size=5)
.execute()
@@ -357,7 +325,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with sequential splits
perm_sequential = (
permutation_builder(tbl, "perm_sequential")
permutation_builder(tbl)
.split_sequential(counts=[40, 35])
.shuffle(seed=789, clump_size=None)
.execute()
@@ -384,7 +352,7 @@ def test_no_shuffle_maintains_order(mem_db):
# Create permutation without shuffle (should maintain some order)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_sequential(counts=[25, 25]) # Sequential maintains order
.execute()
)
@@ -405,9 +373,7 @@ def test_filter_basic(mem_db):
)
# Filter to only include rows where id < 50
permutation_tbl = (
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
)
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
assert permutation_tbl.count_rows() == 50
@@ -433,7 +399,7 @@ def test_filter_with_splits(mem_db):
# Filter to only category A and B, then split
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("category IN ('A', 'B')")
.split_random(ratios=[0.5, 0.5])
.execute()
@@ -465,7 +431,7 @@ def test_filter_with_shuffle(mem_db):
# Filter and shuffle
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("category IN ('A', 'C')")
.shuffle(seed=42)
.execute()
@@ -488,7 +454,7 @@ def test_filter_empty_result(mem_db):
# Filter that matches nothing
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("value > 100") # No values > 100 in our data
.execute()
)


@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
@@ -16,10 +16,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
table: Bound<'_, PyAny>,
dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
@@ -27,14 +24,12 @@ pub fn async_permutation_builder(
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
dest_table_name,
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
dest_table_name: String,
}
#[pyclass(name = "AsyncPermutationBuilder")]
@@ -167,10 +162,8 @@ impl PyAsyncPermutationBuilder {
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
future_into_py(slf.py(), async move {
let table = builder.build(&dest_table_name).await.infer_error()?;
let table = builder.build().await.infer_error()?;
Ok(Table::new(table))
})
}


@@ -16,6 +16,7 @@ arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true


@@ -1182,7 +1182,7 @@ mod tests {
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
use crate::test_connection::test_utils::new_test_connection;
use crate::test_utils::connection::new_test_connection;
use arrow::compute::concat_batches;
use arrow_array::RecordBatchReader;
use arrow_schema::{DataType, Field, Schema};


@@ -2,6 +2,3 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod permutation;
pub mod shuffle;
pub mod split;
pub mod util;


@@ -7,288 +7,12 @@
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
//! the underlying data and can be very lightweight.
//!
//! Building a permutation table should be fairly quick and memory efficient, even for billions or
//! trillions of rows.
//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
//! the number of rows in the base table) and memory efficient, even for billions or trillions
//! of rows.
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datafusion::exec::SessionContextExt;
use crate::{
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
Connection, Error, Result, Table,
};
/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
/// Splitting configuration
pub split_strategy: SplitStrategy,
/// Shuffle strategy
pub shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
pub filter: Option<String>,
/// Directory to use for temporary files
pub temp_dir: TemporaryDirectory,
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
/// A seed can be provided to make the shuffle deterministic.
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
///
/// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
/// means 16x fewer IOPS, but these 16 rows will always be close together, which can influence
/// the performance of the model. Note: shuffling within clumps can still be done at read time, but
/// this will only provide a local shuffle and not a global shuffle.
Random {
seed: Option<u64>,
clump_size: Option<u64>,
},
/// The data is not shuffled
///
/// This is useful for debugging and testing.
None,
}
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
pub struct PermutationBuilder {
config: PermutationConfig,
base_table: Table,
}
impl PermutationBuilder {
pub fn new(base_table: Table) -> Self {
Self {
config: PermutationConfig::default(),
base_table,
}
}
/// Configures the strategy for assigning rows to splits.
///
/// For example, it is common to create a test/train split of the data. Splits can also be used
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
/// create a single split with 10% of the data.
///
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
self.config.split_strategy = split_strategy;
self
}
/// Configures the strategy for shuffling the data.
///
/// The default is `ShuffleStrategy::None` (no shuffling).
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
/// Configures a filter to apply to the base table.
///
/// Only rows matching the filter will be included in the permutation.
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Configures the directory to use for temporary files.
///
/// The default is to use the operating system's default temporary directory.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
self.config.temp_dir = temp_dir;
self
}
async fn sort_by_split_id(
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
)
.build_arc()
.unwrap(),
);
let df = ctx
.read_one_shot(data.into_df_stream())
.map_err(|e| Error::Other {
message: format!("Failed to setup sort by split id: {}", e),
source: Some(e.into()),
})?;
let df_stream = df
.sort_by(vec![col(SPLIT_ID_COLUMN)])
.map_err(|e| Error::Other {
message: format!("Failed to plan sort by split id: {}", e),
source: Some(e.into()),
})?
.execute_stream()
.await
.map_err(|e| Error::Other {
message: format!("Failed to sort by split id: {}", e),
source: Some(e.into()),
})?;
let schema = df_stream.schema();
let stream = df_stream.map_err(|e| Error::Other {
message: format!("Failed to execute sort by split id: {}", e),
source: Some(e.into()),
});
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self, dest_table_name: &str) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
}
let splitter = Splitter::new(
self.config.temp_dir.clone(),
self.config.split_strategy.clone(),
);
let mut needs_sort = !splitter.orders_by_split_id();
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
// split id)
rows = splitter.project(rows);
let num_rows = self
.base_table
.count_rows(self.config.filter.clone())
.await? as u64;
// Apply splits
let rows = rows.execute().await?;
let split_data = splitter.apply(rows, num_rows).await?;
// Shuffle data if requested
let shuffled = match self.config.shuffle_strategy {
ShuffleStrategy::None => split_data,
ShuffleStrategy::Random { seed, clump_size } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed,
clump_size,
temp_dir: self.config.temp_dir.clone(),
max_rows_per_file: 10 * 1024 * 1024,
});
shuffler.shuffle(split_data, num_rows).await?
}
};
// We want the final permutation to be sorted by the split id. If we shuffled or if
// the split was not assigned sequentially then we need to sort the data.
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
let sorted = if needs_sort {
self.sort_by_split_id(shuffled).await?
} else {
shuffled
};
// Rename _rowid to row_id
let renamed = rename_column(sorted, "_rowid", "row_id")?;
// Create permutation table
let conn = Connection::new(
self.base_table.database().clone(),
self.base_table.embedding_registry().clone(),
);
conn.create_table_streaming(dest_table_name, renamed)
.execute()
.await
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
use super::*;
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("some_value", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("mytbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table)
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.build("permutation")
.await
.unwrap();
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 0".to_string()))
.await
.unwrap(),
47
);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 1".to_string()))
.await
.unwrap(),
283
);
}
}
pub mod builder;
pub mod reader;
pub mod shuffle;
pub mod split;
pub mod util;


@@ -0,0 +1,326 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_core::ROW_ID;
use lance_datafusion::exec::SessionContextExt;
use crate::{
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
connect,
database::{CreateTableData, CreateTableRequest, Database},
dataloader::permutation::{
shuffle::{Shuffler, ShufflerConfig},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
Error, Result, Table,
};
pub const SRC_ROW_ID_COL: &str = "row_id";
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
/// The permutation table is a temporary table in memory
#[default]
Temporary,
/// The permutation table is a permanent table in a database
Permanent(Arc<dyn Database>, String),
}
/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
/// Splitting configuration
split_strategy: SplitStrategy,
/// Shuffle strategy
shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
filter: Option<String>,
/// Directory to use for temporary files
temp_dir: TemporaryDirectory,
/// Destination
destination: PermutationDestination,
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
/// A seed can be provided to make the shuffle deterministic.
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
///
/// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
/// means 16x fewer IOPS, but these 16 rows will always be close together, which can influence
/// the performance of the model. Note: shuffling within clumps can still be done at read time, but
/// this will only provide a local shuffle and not a global shuffle.
Random {
seed: Option<u64>,
clump_size: Option<u64>,
},
/// The data is not shuffled
///
/// This is useful for debugging and testing.
None,
}
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
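To make the clump-size trade-off concrete, a small illustrative sketch using the builder defined below (the seed and clump size are arbitrary example values):

```rust
use lancedb::dataloader::permutation::builder::{PermutationBuilder, ShuffleStrategy};
use lancedb::Table;

// Shuffle 16-row clumps instead of single rows: ~16x fewer IOPS when
// reading from cloud storage, at the cost of only local order inside
// each clump.
fn clumped_shuffle(table: Table) -> PermutationBuilder {
    PermutationBuilder::new(table).with_shuffle_strategy(ShuffleStrategy::Random {
        seed: Some(42),       // deterministic shuffle
        clump_size: Some(16), // shuffle blocks of 16 contiguous rows
    })
}
```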
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
///
/// The permutation table is neither a view nor a materialized copy of the underlying data. It is
/// a separate, lightweight table that stores just the row id and split id for each row.
pub struct PermutationBuilder {
config: PermutationConfig,
base_table: Table,
}
impl PermutationBuilder {
pub fn new(base_table: Table) -> Self {
Self {
config: PermutationConfig::default(),
base_table,
}
}
/// Configures the strategy for assigning rows to splits.
///
/// For example, it is common to create a test/train split of the data. Splits can also be used
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
/// create a single split with 10% of the data.
///
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
self.config.split_strategy = split_strategy;
self
}
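For example, a sketch of the test/train case mentioned above, as an 80/20 random split (to use only 10% of the data, a single `SplitSizes::Percentages(vec![0.10])` split would do):

```rust
use lancedb::dataloader::permutation::{
    builder::PermutationBuilder,
    split::{SplitSizes, SplitStrategy},
};
use lancedb::Table;

// Assign each row to split 0 (train, 80%) or split 1 (test, 20%)
fn train_test_split(table: Table) -> PermutationBuilder {
    PermutationBuilder::new(table).with_split_strategy(SplitStrategy::Random {
        seed: Some(42), // fixed seed for a reproducible split
        sizes: SplitSizes::Percentages(vec![0.8, 0.2]),
    })
}
```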
/// Configures the strategy for shuffling the data.
///
/// The default is `ShuffleStrategy::None` (no shuffling).
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
/// Configures a filter to apply to the base table.
///
/// Only rows matching the filter will be included in the permutation.
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Configures the directory to use for temporary files.
///
/// The default is to use the operating system's default temporary directory.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
self.config.temp_dir = temp_dir;
self
}
/// Stores the permutation as a table in a database
///
/// By default, the permutation is stored in memory. If this method is called then
/// the permutation will be stored as a table in the given database.
pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
self.config.destination = PermutationDestination::Permanent(database, table_name);
self
}
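A short sketch of persisting the permutation rather than keeping it in the default in-memory destination; `Connection::database()` is the same accessor `build()` uses below, and the path and table name are illustrative:

```rust
use lancedb::dataloader::permutation::builder::PermutationBuilder;
use lancedb::{connect, Result, Table};

// Store the permutation as a named table in a database so it can be
// reused across processes, rather than kept in a temporary in-memory db.
async fn persist_permutation(table: Table) -> Result<Table> {
    let db = connect("./data").execute().await?;
    PermutationBuilder::new(table)
        .persist(db.database().clone(), "my_permutation".to_string())
        .build()
        .await
}
```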
async fn sort_by_split_id(
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
)
.build_arc()
.unwrap(),
);
let df = ctx
.read_one_shot(data.into_df_stream())
.map_err(|e| Error::Other {
message: format!("Failed to setup sort by split id: {}", e),
source: Some(e.into()),
})?;
let df_stream = df
.sort_by(vec![col(SPLIT_ID_COLUMN)])
.map_err(|e| Error::Other {
message: format!("Failed to plan sort by split id: {}", e),
source: Some(e.into()),
})?
.execute_stream()
.await
.map_err(|e| Error::Other {
message: format!("Failed to sort by split id: {}", e),
source: Some(e.into()),
})?;
let schema = df_stream.schema();
let stream = df_stream.map_err(|e| Error::Other {
message: format!("Failed to execute sort by split id: {}", e),
source: Some(e.into()),
});
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
/// Builds the permutation table and stores it at the configured destination
/// (an in-memory database by default, or the database given to `persist`).
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
}
let splitter = Splitter::new(
self.config.temp_dir.clone(),
self.config.split_strategy.clone(),
);
let mut needs_sort = !splitter.orders_by_split_id();
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
// split id)
rows = splitter.project(rows);
let num_rows = self
.base_table
.count_rows(self.config.filter.clone())
.await? as u64;
// Apply splits
let rows = rows.execute().await?;
let split_data = splitter.apply(rows, num_rows).await?;
// Shuffle data if requested
let shuffled = match self.config.shuffle_strategy {
ShuffleStrategy::None => split_data,
ShuffleStrategy::Random { seed, clump_size } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed,
clump_size,
temp_dir: self.config.temp_dir.clone(),
max_rows_per_file: 10 * 1024 * 1024,
});
shuffler.shuffle(split_data, num_rows).await?
}
};
// We want the final permutation to be sorted by the split id. If we shuffled or if
// the split was not assigned sequentially then we need to sort the data.
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
let sorted = if needs_sort {
self.sort_by_split_id(shuffled).await?
} else {
shuffled
};
// Rename _rowid to row_id
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
let (name, database) = match &self.config.destination {
PermutationDestination::Permanent(database, table_name) => {
(table_name.as_str(), database.clone())
}
PermutationDestination::Temporary => {
let conn = connect("memory:///").execute().await?;
("permutation", conn.database().clone())
}
};
let create_table_request =
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
let table = database.create_table(create_table_request).await?;
Ok(Table::new(table, database))
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
use super::*;
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("some_value", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("mytbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.build()
.await
.unwrap();
println!("permutation_table: {:?}", permutation_table);
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 0".to_string()))
.await
.unwrap(),
47
);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 1".to_string()))
.await
.unwrap(),
283
);
}
}


@@ -0,0 +1,384 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Row ID-based views for LanceDB tables
//!
//! This module provides functionality for creating views that are based on specific row IDs.
//! The `PermutationReader` allows you to read only the rows from a source table that
//! correspond to row IDs stored in a separate permutation table, in the order they appear there.
use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
use crate::error::Error;
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
use crate::table::{AnyQuery, BaseTable};
use crate::Result;
use arrow::array::AsArray;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::arrow::RecordBatchExt;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::error::LanceOptionExt;
use lance_core::ROW_ID;
use std::collections::HashMap;
use std::sync::Arc;
/// Reads a permutation of a source table based on row IDs stored in a separate table
pub struct PermutationReader {
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
}
impl std::fmt::Debug for PermutationReader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PermutationReader(base={}, permutation={})",
self.base_table.name(),
self.permutation_table.name(),
)
}
}
impl PermutationReader {
/// Create a new PermutationReader
pub async fn try_new(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
) -> Result<Self> {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
Ok(Self {
base_table,
permutation_table,
})
}
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
for (expected, idx) in iter.enumerate() {
if *idx != expected as u64 {
return false;
}
}
true
}
async fn load_batch(
base_table: &Arc<dyn BaseTable>,
row_ids: RecordBatch,
selection: Select,
has_row_id: bool,
) -> Result<RecordBatch> {
let num_rows = row_ids.num_rows();
let row_ids = row_ids
.column(0)
.as_primitive_opt::<UInt64Type>()
.expect_ok()?
.values();
let filter = format!(
"_rowid in ({})",
row_ids
.iter()
.map(|o| o.to_string())
.collect::<Vec<_>>()
.join(",")
);
let base_query = QueryRequest {
filter: Some(QueryFilter::Sql(filter)),
select: selection,
with_row_id: true,
..Default::default()
};
let mut data = base_table
.query(
&AnyQuery::Query(base_query),
QueryExecutionOptions {
max_batch_length: num_rows as u32,
..Default::default()
},
)
.await?;
let Some(batch) = data.try_next().await? else {
return Err(Error::InvalidInput {
message: "Base table returned no batches".to_string(),
});
};
if data.try_next().await?.is_some() {
return Err(Error::InvalidInput {
message: "Base table returned more than one batch".to_string(),
});
}
if batch.num_rows() != num_rows {
return Err(Error::InvalidInput {
message: "Base table returned different number of rows than the number of row IDs"
.to_string(),
});
}
// There is no guarantee the result order will match the order provided,
// so we may need to restore the requested order
let actual_row_ids = batch
.column_by_name(ROW_ID)
.expect_ok()?
.as_primitive_opt::<UInt64Type>()
.expect_ok()?
.values();
// Map from row id to order in batch, used to restore original ordering
let ordering = actual_row_ids
.iter()
.copied()
.enumerate()
.map(|(i, o)| (o, i as u64))
.collect::<HashMap<_, _>>();
let desired_idx_order = row_ids
.iter()
.map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
.collect::<Result<Vec<_>>>()?;
let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
// Fast path if already sorted, important as data may be large and
// re-ordering could be expensive
batch
} else {
let desired_idx_order = UInt64Array::from(desired_idx_order);
arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
};
if has_row_id {
Ok(ordered_batch)
} else {
// The user didn't ask for the row id; we only needed it to order the data, so drop it now
Ok(ordered_batch.drop_column(ROW_ID)?)
}
}
async fn row_ids_to_batches(
base_table: Arc<dyn BaseTable>,
row_ids: DatasetRecordBatchStream,
selection: Select,
) -> Result<SendableRecordBatchStream> {
let has_row_id = Self::has_row_id(&selection)?;
let mut stream = row_ids
.map_err(Error::from)
.try_filter_map(move |batch| {
let selection = selection.clone();
let base_table = base_table.clone();
async move {
Self::load_batch(&base_table, batch, selection, has_row_id)
.await
.map(Some)
}
})
.boxed();
// Need to read out first batch to get schema
let Some(first_batch) = stream.try_next().await? else {
return Err(Error::InvalidInput {
message: "Permutation was empty".to_string(),
});
};
let schema = first_batch.schema();
let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
}
fn has_row_id(selection: &Select) -> Result<bool> {
match selection {
Select::All => {
// _rowid is a system column and is not included in Select::All
Ok(false)
}
Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
Select::Dynamic(columns) => {
for column in columns {
if column.0 == ROW_ID {
if column.1 == ROW_ID {
return Ok(true);
} else {
return Err(Error::InvalidInput {
message: format!(
"Dynamic column {} cannot be used to select _rowid",
column.1
),
});
}
}
}
Ok(false)
}
}
}
pub async fn read_split(
&self,
split: u64,
selection: Select,
execution_options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let row_ids = self
.permutation_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
..Default::default()
}),
execution_options,
)
.await?;
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
}
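A usage sketch for `read_split`: drain one split as a stream of Arrow batches, with the batch size capped via `QueryExecutionOptions` (the 1024 cap is an arbitrary example). With `Select::All`, the `_rowid` column used internally for re-ordering has already been dropped:

```rust
use futures::TryStreamExt;
use lancedb::dataloader::permutation::reader::PermutationReader;
use lancedb::query::{QueryExecutionOptions, Select};
use lancedb::Result;

// Count the rows of split 0, reading in permutation order
async fn count_split_rows(reader: &PermutationReader) -> Result<usize> {
    let mut stream = reader
        .read_split(
            0,
            Select::All,
            QueryExecutionOptions {
                max_batch_length: 1024, // upper bound on rows per batch
                ..Default::default()
            },
        )
        .await?;
    let mut total = 0;
    while let Some(batch) = stream.try_next().await? {
        total += batch.num_rows();
    }
    Ok(total)
}
```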
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use lance_datagen::{BatchCount, RowCount};
use rand::seq::SliceRandom;
use crate::{
arrow::SendableRecordBatchStream,
query::{ExecutableQuery, QueryBase},
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
Table,
};
use super::*;
async fn collect_from_stream<T: ArrowPrimitiveType>(
mut stream: SendableRecordBatchStream,
column: &str,
) -> Vec<T::Native> {
let mut row_ids = Vec::new();
while let Some(batch) = stream.try_next().await.unwrap() {
let col_idx = batch.schema().index_of(column).unwrap();
row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
}
row_ids
}
async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
collect_from_stream::<T>(
table
.query()
.select(Select::Columns(vec![column.to_string()]))
.execute()
.await
.unwrap(),
column,
)
.await
}
#[tokio::test]
async fn test_permutation_reader() {
let base_table = lance_datagen::gen_batch()
.col("idx", lance_datagen::array::step::<Int32Type>())
.col("other_col", lance_datagen::array::step::<UInt64Type>())
.into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
.await;
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
row_ids.shuffle(&mut rand::rng());
// Put the last two rows in split 1
let split_ids = UInt64Array::from_iter_values(
std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
);
let permutation_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("row_id", DataType::UInt64, false),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
])),
vec![
Arc::new(UInt64Array::from(row_ids.clone())),
Arc::new(split_ids),
],
)
.unwrap();
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
let reader = PermutationReader::try_new(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
)
.await
.unwrap();
// Read split 0
let mut stream = reader
.read_split(
0,
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(stream.schema(), base_table.schema().await.unwrap());
let check_batch = async |stream: &mut SendableRecordBatchStream,
expected_values: &[u64]| {
let batch = stream.try_next().await.unwrap().unwrap();
assert_eq!(batch.num_rows(), expected_values.len());
assert_eq!(
batch.column(0).as_primitive::<Int32Type>().values(),
&expected_values
.iter()
.map(|o| *o as i32)
.collect::<Vec<_>>()
);
assert_eq!(
batch.column(1).as_primitive::<UInt64Type>().values(),
&expected_values
);
};
check_batch(&mut stream, &row_ids[0..3]).await;
check_batch(&mut stream, &row_ids[3..6]).await;
check_batch(&mut stream, &row_ids[6..7]).await;
assert!(stream.try_next().await.unwrap().is_none());
// Read split 1
let mut stream = reader
.read_split(
1,
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
..Default::default()
},
)
.await
.unwrap();
check_batch(&mut stream, &row_ids[7..9]).await;
assert!(stream.try_next().await.unwrap().is_none());
}
}


@@ -22,7 +22,7 @@ use rand::{seq::SliceRandom, Rng, RngCore};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::util::{non_crypto_rng, TemporaryDirectory},
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
Error, Result,
};


@@ -18,8 +18,8 @@ use lance::arrow::SchemaExt;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
util::TemporaryDirectory,
permutation::shuffle::{Shuffler, ShufflerConfig},
permutation::util::TemporaryDirectory,
},
query::{Query, QueryBase, Select},
Error, Result,


@@ -207,7 +207,8 @@ pub mod query;
pub mod remote;
pub mod rerankers;
pub mod table;
pub mod test_connection;
#[cfg(test)]
pub mod test_utils;
pub mod utils;
use std::fmt::Display;


@@ -511,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// Get the namespace of the table.
fn namespace(&self) -> &[String];
/// Get the id of the table
///
/// This is the namespace of the table concatenated with the table name,
/// separated by a dot (".").
fn id(&self) -> &str;
/// Get the arrow [Schema] of the table.
async fn schema(&self) -> Result<SchemaRef>;


@@ -1,126 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Functions for testing connections.
#[cfg(test)]
pub mod test_utils {
use regex::Regex;
use std::env;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};
use crate::{connect, Connection};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
pub connection: Connection,
_temp_dir: Option<TempDir>,
_process: Option<TestProcess>,
}
struct TestProcess {
child: Child,
}
impl Drop for TestProcess {
#[allow(unused_must_use)]
fn drop(&mut self) {
self.child.kill();
}
}
pub async fn new_test_connection() -> Result<TestConnection> {
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
Ok(script_path) => new_remote_connection(&script_path).await,
Err(_e) => new_local_connection().await,
}
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string();
let child_result = Command::new(script_path)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.arg(data_path.clone())
.spawn();
if child_result.is_err() {
bail!(format!(
"Unable to run {}: {:?}",
script_path,
child_result.err()
));
}
let mut process = TestProcess {
child: child_result.unwrap(),
};
let stdout = BufReader::new(process.child.stdout.take().unwrap());
let port = read_process_port(stdout)?;
let uri = "db://test";
let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: Some(process),
})
}
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")]
async fn create_new_connection(
uri: &str,
host_override: &str,
) -> crate::error::Result<Connection> {
connect(uri)
.region("us-east-1")
.api_key("sk_localtest")
.host_override(host_override)
.execute()
.await
}
#[cfg(not(feature = "remote"))]
async fn create_new_connection(
_uri: &str,
_host_override: &str,
) -> crate::error::Result<Connection> {
panic!("remote feature not supported");
}
async fn new_local_connection() -> Result<TestConnection> {
let temp_dir = tempdir()?;
let uri = temp_dir.path().to_str().unwrap();
let connection = connect(uri).execute().await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: None,
})
}
}


@@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod connection;
pub mod datagen;


@@ -0,0 +1,120 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Functions for testing connections.
use regex::Regex;
use std::env;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};
use crate::{connect, Connection};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
pub connection: Connection,
_temp_dir: Option<TempDir>,
_process: Option<TestProcess>,
}
struct TestProcess {
child: Child,
}
impl Drop for TestProcess {
#[allow(unused_must_use)]
fn drop(&mut self) {
self.child.kill();
}
}
pub async fn new_test_connection() -> Result<TestConnection> {
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
Ok(script_path) => new_remote_connection(&script_path).await,
Err(_e) => new_local_connection().await,
}
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string();
let child_result = Command::new(script_path)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.arg(data_path.clone())
.spawn();
if child_result.is_err() {
bail!(format!(
"Unable to run {}: {:?}",
script_path,
child_result.err()
));
}
let mut process = TestProcess {
child: child_result.unwrap(),
};
let stdout = BufReader::new(process.child.stdout.take().unwrap());
let port = read_process_port(stdout)?;
let uri = "db://test";
let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: Some(process),
})
}
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
connect(uri)
.region("us-east-1")
.api_key("sk_localtest")
.host_override(host_override)
.execute()
.await
}
#[cfg(not(feature = "remote"))]
async fn create_new_connection(
_uri: &str,
_host_override: &str,
) -> crate::error::Result<Connection> {
panic!("remote feature not supported");
}
async fn new_local_connection() -> Result<TestConnection> {
let temp_dir = tempdir()?;
let uri = temp_dir.path().to_str().unwrap();
let connection = connect(uri).execute().await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: None,
})
}


@@ -0,0 +1,55 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow_array::RecordBatch;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
connect, Error, Table,
};
#[async_trait::async_trait]
pub trait LanceDbDatagenExt {
async fn into_mem_table(
self,
table_name: &str,
rows_per_batch: RowCount,
num_batches: BatchCount,
) -> Table;
}
#[async_trait::async_trait]
impl LanceDbDatagenExt for BatchGeneratorBuilder {
async fn into_mem_table(
self,
table_name: &str,
rows_per_batch: RowCount,
num_batches: BatchCount,
) -> Table {
let (stream, schema) = self.into_reader_stream(rows_per_batch, num_batches);
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
stream.map_err(Error::from),
schema,
));
let db = connect("memory:///").execute().await.unwrap();
db.create_table_streaming(table_name, stream)
.execute()
.await
.unwrap()
}
}
pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
let schema = values.schema();
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
futures::stream::once(std::future::ready(Ok(values.clone()))),
schema,
));
let db = connect("memory:///").execute().await.unwrap();
db.create_table_streaming(name, stream)
.execute()
.await
.unwrap()
}