diff --git a/Cargo.lock b/Cargo.lock
index b3fc570b..23182c70 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4702,6 +4702,7 @@ dependencies = [
name = "lancedb"
version = "0.22.2"
dependencies = [
+ "ahash",
"anyhow",
"arrow",
"arrow-array",
@@ -4737,8 +4738,11 @@ dependencies = [
"http 1.3.1",
"http-body 1.0.1",
"lance",
+ "lance-core",
"lance-datafusion",
+ "lance-datagen",
"lance-encoding",
+ "lance-file",
"lance-index",
"lance-io",
"lance-linalg",
@@ -4764,6 +4768,7 @@ dependencies = [
"serde_with",
"snafu",
"tempfile",
+ "test-log",
"tokenizers",
"tokio",
"url",
@@ -8237,6 +8242,28 @@ dependencies = [
"windows-sys 0.61.2",
]
+[[package]]
+name = "test-log"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
+dependencies = [
+ "env_logger",
+ "test-log-macros",
+ "tracing-subscriber",
+]
+
+[[package]]
+name = "test-log-macros"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.106",
+]
+
[[package]]
name = "thiserror"
version = "1.0.69"
diff --git a/Cargo.toml b/Cargo.toml
index 825df5dc..2b1afe8e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,9 @@ rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"] }
+lance-core = "=0.38.2"
+lance-datagen = "=0.38.2"
+lance-file = "=0.38.2"
lance-io = { "version" = "=0.38.2", default-features = false }
lance-index = "=0.38.2"
lance-linalg = "=0.38.2"
@@ -24,6 +27,7 @@ lance-testing = "=0.38.2"
lance-datafusion = "=0.38.2"
lance-encoding = "=0.38.2"
lance-namespace = "0.0.18"
+ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
@@ -48,6 +52,7 @@ log = "0.4"
moka = { version = "0.12", features = ["future"] }
object_store = "0.12.0"
pin-project = "1.0.7"
+rand = "0.9"
snafu = "0.8"
url = "2"
num-traits = "0.2"
diff --git a/docs/src/js/classes/PermutationBuilder.md b/docs/src/js/classes/PermutationBuilder.md
new file mode 100644
index 00000000..aa2437c9
--- /dev/null
+++ b/docs/src/js/classes/PermutationBuilder.md
@@ -0,0 +1,220 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / PermutationBuilder
+
+# Class: PermutationBuilder
+
+A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
+
+This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
+offering methods to configure data splits, shuffling, and filtering before executing
+the permutation to create a new table.
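+
+## Example
+
+A minimal end-to-end sketch; the database path, table name, and filter column below
+are placeholders rather than part of the API:
+
+```ts
+import { connect, permutationBuilder } from "@lancedb/lancedb";
+
+const db = await connect("data/sample-lancedb");
+const events = await db.openTable("events");
+
+// Configuration calls return a builder; execute() materializes the new table.
+const permutation = await permutationBuilder(events, "events_permutation")
+  .filter("label IS NOT NULL")
+  .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
+  .shuffle({ seed: 7 })
+  .execute();
+
+console.log(await permutation.countRows());
+```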
+
+## Methods
+
+### execute()
+
+```ts
+execute(): Promise<Table>
+```
+
+Execute the permutation and create the destination table.
+
+#### Returns
+
+`Promise`<[`Table`](Table.md)>
+
+A Promise that resolves to the new Table instance
+
+#### Example
+
+```ts
+const permutationTable = await builder.execute();
+console.log(`Created table: ${permutationTable.name}`);
+```
+
+***
+
+### filter()
+
+```ts
+filter(filter): PermutationBuilder
+```
+
+Configure filtering for the permutation.
+
+#### Parameters
+
+* **filter**: `string`
+ SQL filter expression
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+builder.filter("age > 18 AND status = 'active'");
+```
+
+***
+
+### shuffle()
+
+```ts
+shuffle(options): PermutationBuilder
+```
+
+Configure shuffling for the permutation.
+
+#### Parameters
+
+* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
+ Configuration for shuffling
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+// Basic shuffle
+builder.shuffle({ seed: 42 });
+
+// Shuffle with clump size
+builder.shuffle({ seed: 42, clumpSize: 10 });
+```
+
+***
+
+### splitCalculated()
+
+```ts
+splitCalculated(calculation): PermutationBuilder
+```
+
+Configure calculated splits for the permutation.
+
+#### Parameters
+
+* **calculation**: `string`
+ SQL expression for calculating splits
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+builder.splitCalculated("user_id % 3");
+```
+
+***
+
+### splitHash()
+
+```ts
+splitHash(options): PermutationBuilder
+```
+
+Configure hash-based splits for the permutation.
+
+#### Parameters
+
+* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
+ Configuration for hash-based splitting
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+builder.splitHash({
+ columns: ["user_id"],
+ splitWeights: [70, 30],
+ discardWeight: 0
+});
+```
+
+***
+
+### splitRandom()
+
+```ts
+splitRandom(options): PermutationBuilder
+```
+
+Configure random splits for the permutation.
+
+#### Parameters
+
+* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
+ Configuration for random splitting
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+// Split by ratios
+builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
+
+// Split by counts
+builder.splitRandom({ counts: [1000, 500], seed: 42 });
+
+// Split with fixed size
+builder.splitRandom({ fixed: 100, seed: 42 });
+```
+
+***
+
+### splitSequential()
+
+```ts
+splitSequential(options): PermutationBuilder
+```
+
+Configure sequential splits for the permutation.
+
+#### Parameters
+
+* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
+ Configuration for sequential splitting
+
+#### Returns
+
+[`PermutationBuilder`](PermutationBuilder.md)
+
+A new PermutationBuilder instance
+
+#### Example
+
+```ts
+// Split by ratios
+builder.splitSequential({ ratios: [0.8, 0.2] });
+
+// Split by counts
+builder.splitSequential({ counts: [800, 200] });
+
+// Split with fixed size
+builder.splitSequential({ fixed: 1000 });
+```
diff --git a/docs/src/js/functions/permutationBuilder.md b/docs/src/js/functions/permutationBuilder.md
new file mode 100644
index 00000000..63226d66
--- /dev/null
+++ b/docs/src/js/functions/permutationBuilder.md
@@ -0,0 +1,37 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / permutationBuilder
+
+# Function: permutationBuilder()
+
+```ts
+function permutationBuilder(table, destTableName): PermutationBuilder
+```
+
+Create a permutation builder for the given table.
+
+## Parameters
+
+* **table**: [`Table`](../classes/Table.md)
+ The source table to create a permutation from
+
+* **destTableName**: `string`
+ The name for the destination permutation table
+
+## Returns
+
+[`PermutationBuilder`](../classes/PermutationBuilder.md)
+
+A PermutationBuilder instance
+
+## Example
+
+```ts
+const builder = permutationBuilder(sourceTable, "training_data")
+ .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
+ .shuffle({ seed: 123 });
+
+const trainingTable = await builder.execute();
+```
diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md
index 757e47e9..462e6a99 100644
--- a/docs/src/js/globals.md
+++ b/docs/src/js/globals.md
@@ -28,6 +28,7 @@
- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
+- [PermutationBuilder](classes/PermutationBuilder.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md)
@@ -76,6 +77,10 @@
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
+- [ShuffleOptions](interfaces/ShuffleOptions.md)
+- [SplitHashOptions](interfaces/SplitHashOptions.md)
+- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
+- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)
@@ -103,3 +108,4 @@
- [connect](functions/connect.md)
- [makeArrowTable](functions/makeArrowTable.md)
- [packBits](functions/packBits.md)
+- [permutationBuilder](functions/permutationBuilder.md)
diff --git a/docs/src/js/interfaces/ShuffleOptions.md b/docs/src/js/interfaces/ShuffleOptions.md
new file mode 100644
index 00000000..02d298c7
--- /dev/null
+++ b/docs/src/js/interfaces/ShuffleOptions.md
@@ -0,0 +1,23 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / ShuffleOptions
+
+# Interface: ShuffleOptions
+
+## Properties
+
+### clumpSize?
+
+```ts
+optional clumpSize: number;
+```
+
+***
+
+### seed?
+
+```ts
+optional seed: number;
+```
diff --git a/docs/src/js/interfaces/SplitHashOptions.md b/docs/src/js/interfaces/SplitHashOptions.md
new file mode 100644
index 00000000..53cbae8e
--- /dev/null
+++ b/docs/src/js/interfaces/SplitHashOptions.md
@@ -0,0 +1,31 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / SplitHashOptions
+
+# Interface: SplitHashOptions
+
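+Options for hash-based splitting. Rows are assigned to splits by hashing the values of
+`columns`, so rows that share the same values in `columns` always land in the same
+split. Each split receives a share of rows roughly proportional to its entry in
+`splitWeights`, and a non-zero `discardWeight` drops a corresponding share of rows.
+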
+## Properties
+
+### columns
+
+```ts
+columns: string[];
+```
+
+***
+
+### discardWeight?
+
+```ts
+optional discardWeight: number;
+```
+
+***
+
+### splitWeights
+
+```ts
+splitWeights: number[];
+```
diff --git a/docs/src/js/interfaces/SplitRandomOptions.md b/docs/src/js/interfaces/SplitRandomOptions.md
new file mode 100644
index 00000000..66430b6c
--- /dev/null
+++ b/docs/src/js/interfaces/SplitRandomOptions.md
@@ -0,0 +1,39 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / SplitRandomOptions
+
+# Interface: SplitRandomOptions
+
+## Properties
+
+### counts?
+
+```ts
+optional counts: number[];
+```
+
+***
+
+### fixed?
+
+```ts
+optional fixed: number;
+```
+
+***
+
+### ratios?
+
+```ts
+optional ratios: number[];
+```
+
+***
+
+### seed?
+
+```ts
+optional seed: number;
+```
diff --git a/docs/src/js/interfaces/SplitSequentialOptions.md b/docs/src/js/interfaces/SplitSequentialOptions.md
new file mode 100644
index 00000000..6397c191
--- /dev/null
+++ b/docs/src/js/interfaces/SplitSequentialOptions.md
@@ -0,0 +1,31 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / SplitSequentialOptions
+
+# Interface: SplitSequentialOptions
+
+## Properties
+
+### counts?
+
+```ts
+optional counts: number[];
+```
+
+***
+
+### fixed?
+
+```ts
+optional fixed: number;
+```
+
+***
+
+### ratios?
+
+```ts
+optional ratios: number[];
+```
diff --git a/nodejs/__test__/permutation.test.ts b/nodejs/__test__/permutation.test.ts
new file mode 100644
index 00000000..be57e5d9
--- /dev/null
+++ b/nodejs/__test__/permutation.test.ts
@@ -0,0 +1,234 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import * as tmp from "tmp";
+import { Table, connect, permutationBuilder } from "../lancedb";
+import { makeArrowTable } from "../lancedb/arrow";
+
+describe("PermutationBuilder", () => {
+ let tmpDir: tmp.DirResult;
+ let table: Table;
+
+ beforeEach(async () => {
+ tmpDir = tmp.dirSync({ unsafeCleanup: true });
+ const db = await connect(tmpDir.name);
+
+ // Create test data
+ const data = makeArrowTable(
+ [
+ { id: 1, value: 10 },
+ { id: 2, value: 20 },
+ { id: 3, value: 30 },
+ { id: 4, value: 40 },
+ { id: 5, value: 50 },
+ { id: 6, value: 60 },
+ { id: 7, value: 70 },
+ { id: 8, value: 80 },
+ { id: 9, value: 90 },
+ { id: 10, value: 100 },
+ ],
+ { vectorColumns: {} },
+ );
+
+ table = await db.createTable("test_table", data);
+ });
+
+ afterEach(() => {
+ tmpDir.removeCallback();
+ });
+
+ test("should create permutation builder", () => {
+ const builder = permutationBuilder(table, "permutation_table");
+ expect(builder).toBeDefined();
+ });
+
+ test("should execute basic permutation", async () => {
+ const builder = permutationBuilder(table, "permutation_table");
+ const permutationTable = await builder.execute();
+
+ expect(permutationTable).toBeDefined();
+ expect(permutationTable.name).toBe("permutation_table");
+
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+ });
+
+ test("should create permutation with random splits", async () => {
+ const builder = permutationBuilder(table, "permutation_table").splitRandom({
+ ratios: [1.0],
+ seed: 42,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+ });
+
+ test("should create permutation with percentage splits", async () => {
+ const builder = permutationBuilder(table, "permutation_table").splitRandom({
+ ratios: [0.3, 0.7],
+ seed: 42,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+
+ // Check split distribution
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBeGreaterThan(0);
+ expect(split1Count).toBeGreaterThan(0);
+ expect(split0Count + split1Count).toBe(10);
+ });
+
+ test("should create permutation with count splits", async () => {
+ const builder = permutationBuilder(table, "permutation_table").splitRandom({
+ counts: [3, 7],
+ seed: 42,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+
+ // Check split distribution
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBe(3);
+ expect(split1Count).toBe(7);
+ });
+
+ test("should create permutation with hash splits", async () => {
+ const builder = permutationBuilder(table, "permutation_table").splitHash({
+ columns: ["id"],
+ splitWeights: [50, 50],
+ discardWeight: 0,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+
+ // Check that splits exist
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBeGreaterThan(0);
+ expect(split1Count).toBeGreaterThan(0);
+ expect(split0Count + split1Count).toBe(10);
+ });
+
+ test("should create permutation with sequential splits", async () => {
+ const builder = permutationBuilder(
+ table,
+ "permutation_table",
+ ).splitSequential({ ratios: [0.5, 0.5] });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+
+ // Check split distribution - sequential should give exactly 5 and 5
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBe(5);
+ expect(split1Count).toBe(5);
+ });
+
+ test("should create permutation with calculated splits", async () => {
+ const builder = permutationBuilder(
+ table,
+ "permutation_table",
+ ).splitCalculated("id % 2");
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+
+ // Check split distribution
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBeGreaterThan(0);
+ expect(split1Count).toBeGreaterThan(0);
+ expect(split0Count + split1Count).toBe(10);
+ });
+
+ test("should create permutation with shuffle", async () => {
+ const builder = permutationBuilder(table, "permutation_table").shuffle({
+ seed: 42,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+ });
+
+ test("should create permutation with shuffle and clump size", async () => {
+ const builder = permutationBuilder(table, "permutation_table").shuffle({
+ seed: 42,
+ clumpSize: 2,
+ });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(10);
+ });
+
+ test("should create permutation with filter", async () => {
+ const builder = permutationBuilder(table, "permutation_table").filter(
+ "value > 50",
+ );
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
+ });
+
+ test("should chain multiple operations", async () => {
+ const builder = permutationBuilder(table, "permutation_table")
+ .filter("value <= 80")
+ .splitRandom({ ratios: [0.5, 0.5], seed: 42 })
+ .shuffle({ seed: 123 });
+
+ const permutationTable = await builder.execute();
+ const rowCount = await permutationTable.countRows();
+ expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
+
+ // Check split distribution
+ const split0Count = await permutationTable.countRows("split_id = 0");
+ const split1Count = await permutationTable.countRows("split_id = 1");
+
+ expect(split0Count).toBeGreaterThan(0);
+ expect(split1Count).toBeGreaterThan(0);
+ expect(split0Count + split1Count).toBe(8);
+ });
+
+ test("should throw error for invalid split arguments", () => {
+ const builder = permutationBuilder(table, "permutation_table");
+
+ // Test no arguments provided
+ expect(() => builder.splitRandom({})).toThrow(
+ "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
+ );
+
+ // Test multiple arguments provided
+ expect(() =>
+ builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
+ ).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
+ });
+
+ test("should throw error when builder is consumed", async () => {
+ const builder = permutationBuilder(table, "permutation_table");
+
+ // Execute once
+ await builder.execute();
+
+ // Should throw error on second execution
+ await expect(builder.execute()).rejects.toThrow("Builder already consumed");
+ });
+});
diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts
index 57069221..3ef1a76f 100644
--- a/nodejs/lancedb/index.ts
+++ b/nodejs/lancedb/index.ts
@@ -43,6 +43,10 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
+ SplitRandomOptions,
+ SplitHashOptions,
+ SplitSequentialOptions,
+ ShuffleOptions,
} from "./native.js";
export {
@@ -111,6 +115,7 @@ export {
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";
+export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
SchemaLike,
diff --git a/nodejs/lancedb/permutation.ts b/nodejs/lancedb/permutation.ts
new file mode 100644
index 00000000..98406505
--- /dev/null
+++ b/nodejs/lancedb/permutation.ts
@@ -0,0 +1,188 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import {
+ PermutationBuilder as NativePermutationBuilder,
+ Table as NativeTable,
+ ShuffleOptions,
+ SplitHashOptions,
+ SplitRandomOptions,
+ SplitSequentialOptions,
+ permutationBuilder as nativePermutationBuilder,
+} from "./native.js";
+import { LocalTable, Table } from "./table";
+
+/**
+ * A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
+ *
+ * This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
+ * offering methods to configure data splits, shuffling, and filtering before executing
+ * the permutation to create a new table.
+ */
+export class PermutationBuilder {
+ private inner: NativePermutationBuilder;
+
+ /**
+ * @hidden
+ */
+ constructor(inner: NativePermutationBuilder) {
+ this.inner = inner;
+ }
+
+ /**
+ * Configure random splits for the permutation.
+ *
+ * @param options - Configuration for random splitting
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * // Split by ratios
+ * builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
+ *
+ * // Split by counts
+ * builder.splitRandom({ counts: [1000, 500], seed: 42 });
+ *
+ * // Split with fixed size
+ * builder.splitRandom({ fixed: 100, seed: 42 });
+ * ```
+ */
+ splitRandom(options: SplitRandomOptions): PermutationBuilder {
+ const newInner = this.inner.splitRandom(options);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Configure hash-based splits for the permutation.
+ *
+ * @param options - Configuration for hash-based splitting
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * builder.splitHash({
+ * columns: ["user_id"],
+ * splitWeights: [70, 30],
+ * discardWeight: 0
+ * });
+ * ```
+ */
+ splitHash(options: SplitHashOptions): PermutationBuilder {
+ const newInner = this.inner.splitHash(options);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Configure sequential splits for the permutation.
+ *
+ * @param options - Configuration for sequential splitting
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * // Split by ratios
+ * builder.splitSequential({ ratios: [0.8, 0.2] });
+ *
+ * // Split by counts
+ * builder.splitSequential({ counts: [800, 200] });
+ *
+ * // Split with fixed size
+ * builder.splitSequential({ fixed: 1000 });
+ * ```
+ */
+ splitSequential(options: SplitSequentialOptions): PermutationBuilder {
+ const newInner = this.inner.splitSequential(options);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Configure calculated splits for the permutation.
+ *
+ * @param calculation - SQL expression for calculating splits
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * builder.splitCalculated("user_id % 3");
+ * ```
+ */
+ splitCalculated(calculation: string): PermutationBuilder {
+ const newInner = this.inner.splitCalculated(calculation);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Configure shuffling for the permutation.
+ *
+ * @param options - Configuration for shuffling
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * // Basic shuffle
+ * builder.shuffle({ seed: 42 });
+ *
+ * // Shuffle with clump size
+ * builder.shuffle({ seed: 42, clumpSize: 10 });
+ * ```
+ */
+ shuffle(options: ShuffleOptions): PermutationBuilder {
+ const newInner = this.inner.shuffle(options);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Configure filtering for the permutation.
+ *
+ * @param filter - SQL filter expression
+ * @returns A new PermutationBuilder instance
+ * @example
+ * ```ts
+ * builder.filter("age > 18 AND status = 'active'");
+ * ```
+ */
+ filter(filter: string): PermutationBuilder {
+ const newInner = this.inner.filter(filter);
+ return new PermutationBuilder(newInner);
+ }
+
+ /**
+ * Execute the permutation and create the destination table.
+ *
+ * @returns A Promise that resolves to the new Table instance
+ * @example
+ * ```ts
+ * const permutationTable = await builder.execute();
+ * console.log(`Created table: ${permutationTable.name}`);
+ * ```
+ */
+  async execute(): Promise<Table> {
+ const nativeTable: NativeTable = await this.inner.execute();
+ return new LocalTable(nativeTable);
+ }
+}
+
+/**
+ * Create a permutation builder for the given table.
+ *
+ * @param table - The source table to create a permutation from
+ * @param destTableName - The name for the destination permutation table
+ * @returns A PermutationBuilder instance
+ * @example
+ * ```ts
+ * const builder = permutationBuilder(sourceTable, "training_data")
+ * .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
+ * .shuffle({ seed: 123 });
+ *
+ * const trainingTable = await builder.execute();
+ * ```
+ */
+export function permutationBuilder(
+ table: Table,
+ destTableName: string,
+): PermutationBuilder {
+ // Extract the inner native table from the TypeScript wrapper
+ const localTable = table as LocalTable;
+ // Access inner through type assertion since it's private
+ const nativeBuilder = nativePermutationBuilder(
+ // biome-ignore lint/suspicious/noExplicitAny: need access to private variable
+ (localTable as any).inner,
+ destTableName,
+ );
+ return new PermutationBuilder(nativeBuilder);
+}
diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs
index e11e9278..df1898e2 100644
--- a/nodejs/src/lib.rs
+++ b/nodejs/src/lib.rs
@@ -12,6 +12,7 @@ mod header;
mod index;
mod iterator;
pub mod merge;
+pub mod permutation;
mod query;
pub mod remote;
mod rerankers;
diff --git a/nodejs/src/permutation.rs b/nodejs/src/permutation.rs
new file mode 100644
index 00000000..706a1de7
--- /dev/null
+++ b/nodejs/src/permutation.rs
@@ -0,0 +1,222 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::{Arc, Mutex};
+
+use crate::{error::NapiErrorExt, table::Table};
+use lancedb::dataloader::{
+ permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
+ split::{SplitSizes, SplitStrategy},
+};
+use napi_derive::napi;
+
+#[napi(object)]
+pub struct SplitRandomOptions {
+    pub ratios: Option<Vec<f64>>,
+    pub counts: Option<Vec<i64>>,
+    pub fixed: Option<i64>,
+    pub seed: Option<i64>,
+}
+
+#[napi(object)]
+pub struct SplitHashOptions {
+    pub columns: Vec<String>,
+    pub split_weights: Vec<i64>,
+    pub discard_weight: Option<i64>,
+}
+
+#[napi(object)]
+pub struct SplitSequentialOptions {
+    pub ratios: Option<Vec<f64>>,
+    pub counts: Option<Vec<i64>>,
+    pub fixed: Option<i64>,
+}
+
+#[napi(object)]
+pub struct ShuffleOptions {
+    pub seed: Option<i64>,
+    pub clump_size: Option<i64>,
+}
+
+pub struct PermutationBuilderState {
+    pub builder: Option<LancePermutationBuilder>,
+ pub dest_table_name: String,
+}
+
+#[napi]
+pub struct PermutationBuilder {
+    state: Arc<Mutex<PermutationBuilderState>>,
+}
+
+impl PermutationBuilder {
+ pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
+ Self {
+ state: Arc::new(Mutex::new(PermutationBuilderState {
+ builder: Some(builder),
+ dest_table_name,
+ })),
+ }
+ }
+}
+
+impl PermutationBuilder {
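+    /// Apply `func` to the wrapped Lance builder.
+    ///
+    /// The builder is taken out of the shared state, transformed, and stored back,
+    /// so the returned handle and the original one share the same underlying state.
+    /// Returns an error if the builder has already been consumed by `execute`.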
+ fn modify(
+ &self,
+ func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
+    ) -> napi::Result<Self> {
+ let mut state = self.state.lock().unwrap();
+ let builder = state
+ .builder
+ .take()
+ .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
+ state.builder = Some(func(builder));
+ Ok(Self {
+ state: self.state.clone(),
+ })
+ }
+}
+
+#[napi]
+impl PermutationBuilder {
+ /// Configure random splits
+ #[napi]
+    pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<PermutationBuilder> {
+ // Check that exactly one split type is provided
+ let split_args_count = [
+ options.ratios.is_some(),
+ options.counts.is_some(),
+ options.fixed.is_some(),
+ ]
+ .iter()
+ .filter(|&&x| x)
+ .count();
+
+ if split_args_count != 1 {
+ return Err(napi::Error::from_reason(
+ "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
+ ));
+ }
+
+ let sizes = if let Some(ratios) = options.ratios {
+ SplitSizes::Percentages(ratios)
+ } else if let Some(counts) = options.counts {
+ SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
+ } else if let Some(fixed) = options.fixed {
+ SplitSizes::Fixed(fixed as u64)
+ } else {
+ unreachable!("One of the split arguments must be provided");
+ };
+
+ let seed = options.seed.map(|s| s as u64);
+
+ self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
+ }
+
+ /// Configure hash-based splits
+ #[napi]
+    pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<PermutationBuilder> {
+ let split_weights = options
+ .split_weights
+ .into_iter()
+ .map(|w| w as u64)
+ .collect();
+ let discard_weight = options.discard_weight.unwrap_or(0) as u64;
+
+ self.modify(|builder| {
+ builder.with_split_strategy(SplitStrategy::Hash {
+ columns: options.columns,
+ split_weights,
+ discard_weight,
+ })
+ })
+ }
+
+ /// Configure sequential splits
+ #[napi]
+    pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<PermutationBuilder> {
+ // Check that exactly one split type is provided
+ let split_args_count = [
+ options.ratios.is_some(),
+ options.counts.is_some(),
+ options.fixed.is_some(),
+ ]
+ .iter()
+ .filter(|&&x| x)
+ .count();
+
+ if split_args_count != 1 {
+ return Err(napi::Error::from_reason(
+ "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
+ ));
+ }
+
+ let sizes = if let Some(ratios) = options.ratios {
+ SplitSizes::Percentages(ratios)
+ } else if let Some(counts) = options.counts {
+ SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
+ } else if let Some(fixed) = options.fixed {
+ SplitSizes::Fixed(fixed as u64)
+ } else {
+ unreachable!("One of the split arguments must be provided");
+ };
+
+ self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
+ }
+
+ /// Configure calculated splits
+ #[napi]
+    pub fn split_calculated(&self, calculation: String) -> napi::Result<PermutationBuilder> {
+ self.modify(|builder| {
+ builder.with_split_strategy(SplitStrategy::Calculated { calculation })
+ })
+ }
+
+ /// Configure shuffling
+ #[napi]
+    pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<PermutationBuilder> {
+ let seed = options.seed.map(|s| s as u64);
+ let clump_size = options.clump_size.map(|c| c as u64);
+
+ self.modify(|builder| {
+ builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
+ })
+ }
+
+ /// Configure filtering
+ #[napi]
+    pub fn filter(&self, filter: String) -> napi::Result<PermutationBuilder> {
+ self.modify(|builder| builder.with_filter(filter))
+ }
+
+ /// Execute the permutation builder and create the table
+ #[napi]
+    pub async fn execute(&self) -> napi::Result<Table> {
+ let (builder, dest_table_name) = {
+ let mut state = self.state.lock().unwrap();
+ let builder = state
+ .builder
+ .take()
+ .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
+
+ let dest_table_name = std::mem::take(&mut state.dest_table_name);
+ (builder, dest_table_name)
+ };
+
+ let table = builder.build(&dest_table_name).await.default_error()?;
+ Ok(Table::new(table))
+ }
+}
+
+/// Create a permutation builder for the given table
+#[napi]
+pub fn permutation_builder(
+ table: &crate::table::Table,
+ dest_table_name: String,
+) -> napi::Result<PermutationBuilder> {
+ use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
+
+ let inner_table = table.inner_ref()?.clone();
+ let inner_builder = LancePermutationBuilder::new(inner_table);
+
+ Ok(PermutationBuilder::new(inner_builder, dest_table_name))
+}
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index 1272b95c..b1f037fe 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
- fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
+ pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))
diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index 68900a92..378c4a09 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -296,3 +296,34 @@ class AlterColumnsResult:
class DropColumnsResult:
version: int
+
+class AsyncPermutationBuilder:
+ def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
+ def split_random(
+ self,
+ *,
+ ratios: Optional[List[float]] = None,
+ counts: Optional[List[int]] = None,
+ fixed: Optional[int] = None,
+ seed: Optional[int] = None,
+ ) -> "AsyncPermutationBuilder": ...
+ def split_hash(
+ self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
+ ) -> "AsyncPermutationBuilder": ...
+ def split_sequential(
+ self,
+ *,
+ ratios: Optional[List[float]] = None,
+ counts: Optional[List[int]] = None,
+ fixed: Optional[int] = None,
+ ) -> "AsyncPermutationBuilder": ...
+ def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
+ def shuffle(
+ self, seed: Optional[int], clump_size: Optional[int]
+ ) -> "AsyncPermutationBuilder": ...
+ def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
+ async def execute(self) -> Table: ...
+
+def async_permutation_builder(
+ table: Table, dest_table_name: str
+) -> AsyncPermutationBuilder: ...
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py
index 5d83c89e..dac5b1cd 100644
--- a/python/python/lancedb/db.py
+++ b/python/python/lancedb/db.py
@@ -5,6 +5,7 @@
from __future__ import annotations
from abc import abstractmethod
+from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
@@ -40,7 +41,6 @@ import deprecation
if TYPE_CHECKING:
import pyarrow as pa
from .pydantic import LanceModel
- from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
from .common import DATA, URI
@@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection):
read_consistency_interval: Optional[timedelta] = None,
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
+ _inner: Optional[LanceDbConnection] = None,
):
+ if _inner is not None:
+            self._conn = AsyncConnection(_inner)
+ return
+
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
- self._uri = str(uri)
- self._entered = False
- self.read_consistency_interval = read_consistency_interval
- self.storage_options = storage_options
- self.session = session
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection):
session,
)
+ # TODO: It would be nice if we didn't store self.storage_options but it is
+ # currently used by the LanceTable.to_lance method. This doesn't _really_
+ # work because some paths like LanceDBConnection.from_inner will lose the
+ # storage_options. Also, this class really shouldn't be holding any state
+ # beyond _conn.
+ self.storage_options = storage_options
self._conn = AsyncConnection(LOOP.run(do_connect()))
+ @property
+ def read_consistency_interval(self) -> Optional[timedelta]:
+ return LOOP.run(self._conn.get_read_consistency_interval())
+
+ @property
+ def session(self) -> Optional[Session]:
+ return self._conn.session
+
+ @property
+ def uri(self) -> str:
+ return self._conn.uri
+
+ @classmethod
+ def from_inner(cls, inner: LanceDbConnection):
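+        """Wrap an already-open native connection without re-opening the database."""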
+ return cls(None, _inner=inner)
+
def __repr__(self) -> str:
- val = f"{self.__class__.__name__}(uri={self._uri!r}"
+ val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
+ @property
+ def _inner(self) -> LanceDbConnection:
+ return self._conn._inner
+
@override
def list_namespaces(
self,
@@ -856,6 +882,13 @@ class AsyncConnection(object):
def uri(self) -> str:
return self._inner.uri
+ async def get_read_consistency_interval(self) -> Optional[timedelta]:
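+        """Return the read consistency interval as a ``timedelta``.
+
+        ``None`` means manual read consistency; strong consistency is reported as
+        ``timedelta(0)``.
+        """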
+ interval_secs = await self._inner.get_read_consistency_interval()
+ if interval_secs is not None:
+ return timedelta(seconds=interval_secs)
+ else:
+ return None
+
async def list_namespaces(
self,
namespace: List[str] = [],
diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py
new file mode 100644
index 00000000..bd8aa610
--- /dev/null
+++ b/python/python/lancedb/permutation.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+from ._lancedb import async_permutation_builder
+from .table import LanceTable
+from .background_loop import LOOP
+from typing import Optional
+
+
+class PermutationBuilder:
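+    """Builder for creating a permutation of an existing table.
+
+    Wraps the native ``AsyncPermutationBuilder``. The ``split_*``, ``shuffle``,
+    ``filter``, and ``select`` methods return ``self`` so calls can be chained;
+    ``execute`` runs the permutation and writes the result to a new table.
+    """
+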
+ def __init__(self, table: LanceTable, dest_table_name: str):
+ self._async = async_permutation_builder(table, dest_table_name)
+
+ def select(self, projections: dict[str, str]) -> "PermutationBuilder":
+ self._async.select(projections)
+ return self
+
+ def split_random(
+ self,
+ *,
+ ratios: Optional[list[float]] = None,
+ counts: Optional[list[int]] = None,
+ fixed: Optional[int] = None,
+ seed: Optional[int] = None,
+ ) -> "PermutationBuilder":
+ self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
+ return self
+
+ def split_hash(
+ self,
+ columns: list[str],
+ split_weights: list[int],
+ *,
+        discard_weight: int = 0,
+ ) -> "PermutationBuilder":
+ self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
+ return self
+
+ def split_sequential(
+ self,
+ *,
+ ratios: Optional[list[float]] = None,
+ counts: Optional[list[int]] = None,
+ fixed: Optional[int] = None,
+ ) -> "PermutationBuilder":
+ self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
+ return self
+
+ def split_calculated(self, calculation: str) -> "PermutationBuilder":
+ self._async.split_calculated(calculation)
+ return self
+
+ def shuffle(
+ self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
+ ) -> "PermutationBuilder":
+ self._async.shuffle(seed=seed, clump_size=clump_size)
+ return self
+
+ def filter(self, filter: str) -> "PermutationBuilder":
+ self._async.filter(filter)
+ return self
+
+ def execute(self) -> LanceTable:
+ async def do_execute():
+ inner_tbl = await self._async.execute()
+ return LanceTable.from_inner(inner_tbl)
+
+ return LOOP.run(do_execute())
+
+
+def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
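+    """Create a ``PermutationBuilder`` whose result will be written to a new table
+    named ``dest_table_name``."""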
+ return PermutationBuilder(table, dest_table_name)
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 7ee0bf01..6749064d 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -74,6 +74,7 @@ from .index import lang_mapping
if TYPE_CHECKING:
+ from .db import LanceDBConnection
from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
MergeResult,
UpdateResult,
)
- from .db import LanceDBConnection
from .index import IndexConfig
import pandas
import PIL
@@ -1707,22 +1707,38 @@ class LanceTable(Table):
namespace: List[str] = [],
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
+        _async: Optional[AsyncTable] = None,
):
self._conn = connection
self._namespace = namespace
- self._table = LOOP.run(
- connection._conn.open_table(
- name,
- namespace=namespace,
- storage_options=storage_options,
- index_cache_size=index_cache_size,
+ if _async is not None:
+ self._table = _async
+ else:
+ self._table = LOOP.run(
+ connection._conn.open_table(
+ name,
+ namespace=namespace,
+ storage_options=storage_options,
+ index_cache_size=index_cache_size,
+ )
)
- )
@property
def name(self) -> str:
return self._table.name
+ @classmethod
+ def from_inner(cls, tbl: LanceDBTable):
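+        """Build a ``LanceTable`` wrapper around an already-open native table."""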
+ from .db import LanceDBConnection
+
+ async_tbl = AsyncTable(tbl)
+ conn = LanceDBConnection.from_inner(tbl.database())
+ return cls(
+ conn,
+ async_tbl.name,
+ _async=async_tbl,
+ )
+
@classmethod
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -2756,6 +2772,10 @@ class LanceTable(Table):
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
+ @property
+ def _inner(self) -> LanceDBTable:
+ return self._table._inner
+
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py
new file mode 100644
index 00000000..95cd21c0
--- /dev/null
+++ b/python/python/tests/test_permutation.py
@@ -0,0 +1,496 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import pyarrow as pa
+import pytest
+
+from lancedb.permutation import permutation_builder
+
+
+def test_split_random_ratios(mem_db):
+ """Test random splitting with ratios."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"x": range(100), "y": range(100)})
+ )
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_random(ratios=[0.3, 0.7])
+ .execute()
+ )
+
+ # Check that the table was created and has data
+ assert permutation_tbl.count_rows() == 100
+
+ # Check that split_id column exists and has correct values
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ split_ids = data["split_id"]
+ assert set(split_ids) == {0, 1}
+
+ # Check approximate split sizes (allowing for rounding)
+ split_0_count = split_ids.count(0)
+ split_1_count = split_ids.count(1)
+ assert 25 <= split_0_count <= 35 # ~30% ± tolerance
+ assert 65 <= split_1_count <= 75 # ~70% ± tolerance
+
+
+def test_split_random_counts(mem_db):
+ """Test random splitting with absolute counts."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"x": range(100), "y": range(100)})
+ )
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_random(counts=[20, 30])
+ .execute()
+ )
+
+ # Check that we have exactly the requested counts
+ assert permutation_tbl.count_rows() == 50
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ split_ids = data["split_id"]
+ assert split_ids.count(0) == 20
+ assert split_ids.count(1) == 30
+
+
+def test_split_random_fixed(mem_db):
+ """Test random splitting with fixed number of splits."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"x": range(100), "y": range(100)})
+ )
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
+ )
+
+ # Check that we have 4 splits with 25 rows each
+ assert permutation_tbl.count_rows() == 100
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ split_ids = data["split_id"]
+ assert set(split_ids) == {0, 1, 2, 3}
+
+ for split_id in range(4):
+ assert split_ids.count(split_id) == 25
+
+
+def test_split_random_with_seed(mem_db):
+ """Test that seeded random splits are reproducible."""
+ tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
+
+ # Create two identical permutations with same seed
+ perm1 = (
+ permutation_builder(tbl, "perm1")
+ .split_random(ratios=[0.6, 0.4], seed=42)
+ .execute()
+ )
+
+ perm2 = (
+ permutation_builder(tbl, "perm2")
+ .split_random(ratios=[0.6, 0.4], seed=42)
+ .execute()
+ )
+
+ # Results should be identical
+ data1 = perm1.search(None).to_arrow().to_pydict()
+ data2 = perm2.search(None).to_arrow().to_pydict()
+
+ assert data1["row_id"] == data2["row_id"]
+ assert data1["split_id"] == data2["split_id"]
+
+
+def test_split_hash(mem_db):
+ """Test hash-based splitting."""
+ tbl = mem_db.create_table(
+ "test_table",
+ pa.table(
+ {
+ "id": range(100),
+ "category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
+ "value": range(100),
+ }
+ ),
+ )
+
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_hash(["category"], [1, 1], discard_weight=0)
+ .execute()
+ )
+
+ # Should have all 100 rows (no discard)
+ assert permutation_tbl.count_rows() == 100
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ split_ids = data["split_id"]
+ assert set(split_ids) == {0, 1}
+
+ # Verify that each split has roughly 50 rows (allowing for hash variance)
+ split_0_count = split_ids.count(0)
+ split_1_count = split_ids.count(1)
+ assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
+ assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
+
+ # Hash splits should be deterministic - same category should go to same split
+ # Let's verify by creating another permutation and checking consistency
+ perm2 = (
+ permutation_builder(tbl, "test_permutation2")
+ .split_hash(["category"], [1, 1], discard_weight=0)
+ .execute()
+ )
+
+ data2 = perm2.search(None).to_arrow().to_pydict()
+ assert data["split_id"] == data2["split_id"] # Should be identical
+
+
+def test_split_hash_with_discard(mem_db):
+ """Test hash-based splitting with discard weight."""
+ tbl = mem_db.create_table(
+ "test_table",
+ pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
+ )
+
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
+ .execute()
+ )
+
+ # Should have fewer than 100 rows due to discard
+ row_count = permutation_tbl.count_rows()
+ assert row_count < 100
+ assert row_count > 0 # But not empty
+
+
+def test_split_sequential(mem_db):
+ """Test sequential splitting."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"x": range(100), "y": range(100)})
+ )
+
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_sequential(counts=[30, 40])
+ .execute()
+ )
+
+ assert permutation_tbl.count_rows() == 70
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+ split_ids = data["split_id"]
+
+ # Sequential should maintain order
+ assert row_ids == sorted(row_ids)
+
+ # First 30 should be split 0, next 40 should be split 1
+ assert split_ids[:30] == [0] * 30
+ assert split_ids[30:] == [1] * 40
+
+
+def test_split_calculated(mem_db):
+ """Test calculated splitting."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(100), "value": range(100)})
+ )
+
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_calculated("id % 3") # Split based on id modulo 3
+ .execute()
+ )
+
+ assert permutation_tbl.count_rows() == 100
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+ split_ids = data["split_id"]
+
+ # Verify the calculation: each row's split_id should equal row_id % 3
+    for row_id, split_id in zip(row_ids, split_ids):
+ assert split_id == row_id % 3
+
+
+def test_split_error_cases(mem_db):
+ """Test error handling for invalid split parameters."""
+ tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
+
+ # Test split_random with no parameters
+ with pytest.raises(Exception):
+ permutation_builder(tbl, "error1").split_random().execute()
+
+ # Test split_random with multiple parameters
+ with pytest.raises(Exception):
+ permutation_builder(tbl, "error2").split_random(
+ ratios=[0.5, 0.5], counts=[5, 5]
+ ).execute()
+
+ # Test split_sequential with no parameters
+ with pytest.raises(Exception):
+ permutation_builder(tbl, "error3").split_sequential().execute()
+
+ # Test split_sequential with multiple parameters
+ with pytest.raises(Exception):
+ permutation_builder(tbl, "error4").split_sequential(
+ ratios=[0.5, 0.5], fixed=2
+ ).execute()
+
+
+def test_shuffle_no_seed(mem_db):
+ """Test shuffling without a seed."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(100), "value": range(100)})
+ )
+
+ # Create a permutation with shuffling (no seed)
+ permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
+
+ assert permutation_tbl.count_rows() == 100
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+
+ # Row IDs should not be in sequential order due to shuffling
+ # This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
+ # in order
+ assert row_ids != list(range(100))
+
+
+def test_shuffle_with_seed(mem_db):
+ """Test that shuffling with a seed is reproducible."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(50), "value": range(50)})
+ )
+
+ # Create two identical permutations with same shuffle seed
+ perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
+
+ perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
+
+ # Results should be identical due to same seed
+ data1 = perm1.search(None).to_arrow().to_pydict()
+ data2 = perm2.search(None).to_arrow().to_pydict()
+
+ assert data1["row_id"] == data2["row_id"]
+ assert data1["split_id"] == data2["split_id"]
+
+
+def test_shuffle_with_clump_size(mem_db):
+ """Test shuffling with clump size."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(100), "value": range(100)})
+ )
+
+ # Create a permutation with shuffling using clumps
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .shuffle(clump_size=10) # 10-row clumps
+ .execute()
+ )
+
+ assert permutation_tbl.count_rows() == 100
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+
+ for i in range(10):
+ start = row_ids[i * 10]
+ assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
+
+
+def test_shuffle_different_seeds(mem_db):
+ """Test that different seeds produce different shuffle orders."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(50), "value": range(50)})
+ )
+
+ # Create two permutations with different shuffle seeds
+ perm1 = (
+ permutation_builder(tbl, "perm1")
+ .split_random(fixed=2)
+ .shuffle(seed=42)
+ .execute()
+ )
+
+ perm2 = (
+ permutation_builder(tbl, "perm2")
+ .split_random(fixed=2)
+ .shuffle(seed=123)
+ .execute()
+ )
+
+ # Results should be different due to different seeds
+ data1 = perm1.search(None).to_arrow().to_pydict()
+ data2 = perm2.search(None).to_arrow().to_pydict()
+
+ # Row order should be different
+ assert data1["row_id"] != data2["row_id"]
+
+
+def test_shuffle_combined_with_splits(mem_db):
+ """Test shuffling combined with different split strategies."""
+ tbl = mem_db.create_table(
+ "test_table",
+ pa.table(
+ {
+ "id": range(100),
+ "category": (["A", "B", "C"] * 34)[:100],
+ "value": range(100),
+ }
+ ),
+ )
+
+ # Test shuffle with random splits
+ perm_random = (
+ permutation_builder(tbl, "perm_random")
+ .split_random(ratios=[0.6, 0.4], seed=42)
+ .shuffle(seed=123, clump_size=None)
+ .execute()
+ )
+
+ # Test shuffle with hash splits
+ perm_hash = (
+ permutation_builder(tbl, "perm_hash")
+ .split_hash(["category"], [1, 1], discard_weight=0)
+ .shuffle(seed=456, clump_size=5)
+ .execute()
+ )
+
+ # Test shuffle with sequential splits
+ perm_sequential = (
+ permutation_builder(tbl, "perm_sequential")
+ .split_sequential(counts=[40, 35])
+ .shuffle(seed=789, clump_size=None)
+ .execute()
+ )
+
+ # Verify all permutations work and have expected properties
+ assert perm_random.count_rows() == 100
+ assert perm_hash.count_rows() == 100
+ assert perm_sequential.count_rows() == 75
+
+ # Verify shuffle affected the order
+ data_random = perm_random.search(None).to_arrow().to_pydict()
+ data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
+
+ assert data_random["row_id"] != list(range(100))
+ assert data_sequential["row_id"] != list(range(75))
+
+
+def test_no_shuffle_maintains_order(mem_db):
+ """Test that not calling shuffle maintains the original order."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(50), "value": range(50)})
+ )
+
+ # Create permutation without shuffle (should maintain some order)
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .split_sequential(counts=[25, 25]) # Sequential maintains order
+ .execute()
+ )
+
+ assert permutation_tbl.count_rows() == 50
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+
+ # With sequential splits and no shuffle, should maintain order
+ assert row_ids == list(range(50))
+
+
+def test_filter_basic(mem_db):
+ """Test basic filtering functionality."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(100), "value": range(100, 200)})
+ )
+
+ # Filter to only include rows where id < 50
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
+ )
+
+ assert permutation_tbl.count_rows() == 50
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+
+ # All row_ids should be less than 50
+ assert all(row_id < 50 for row_id in row_ids)
+
+
+def test_filter_with_splits(mem_db):
+ """Test filtering combined with split strategies."""
+ tbl = mem_db.create_table(
+ "test_table",
+ pa.table(
+ {
+ "id": range(100),
+ "category": (["A", "B", "C"] * 34)[:100],
+ "value": range(100),
+ }
+ ),
+ )
+
+ # Filter to only category A and B, then split
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .filter("category IN ('A', 'B')")
+ .split_random(ratios=[0.5, 0.5])
+ .execute()
+ )
+
+ # Should have fewer than 100 rows due to filtering
+ row_count = permutation_tbl.count_rows()
+ assert row_count == 67
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ categories = data["category"]
+
+ # All categories should be A or B
+ assert all(cat in ["A", "B"] for cat in categories)
+
+
+def test_filter_with_shuffle(mem_db):
+ """Test filtering combined with shuffling."""
+ tbl = mem_db.create_table(
+ "test_table",
+ pa.table(
+ {
+ "id": range(100),
+ "category": (["A", "B", "C", "D"] * 25)[:100],
+ "value": range(100),
+ }
+ ),
+ )
+
+ # Filter and shuffle
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .filter("category IN ('A', 'C')")
+ .shuffle(seed=42)
+ .execute()
+ )
+
+ row_count = permutation_tbl.count_rows()
+ assert row_count == 50 # Should have 50 rows (A and C categories)
+
+ data = permutation_tbl.search(None).to_arrow().to_pydict()
+ row_ids = data["row_id"]
+
+ assert row_ids != sorted(row_ids)
+
+
+def test_filter_empty_result(mem_db):
+ """Test filtering that results in empty set."""
+ tbl = mem_db.create_table(
+ "test_table", pa.table({"id": range(10), "value": range(10)})
+ )
+
+ # Filter that matches nothing
+ permutation_tbl = (
+ permutation_builder(tbl, "test_permutation")
+ .filter("value > 100") # No values > 100 in our data
+ .execute()
+ )
+
+ assert permutation_tbl.count_rows() == 0
diff --git a/python/src/connection.rs b/python/src/connection.rs
index 13782dca..67553074 100644
--- a/python/src/connection.rs
+++ b/python/src/connection.rs
@@ -4,7 +4,10 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
-use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
+use lancedb::{
+ connection::Connection as LanceConnection,
+ database::{CreateTableMode, ReadConsistency},
+};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
Self { inner: Some(inner) }
}
- fn get_inner(&self) -> PyResult<&LanceConnection> {
+ pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
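+
+    /// Read consistency interval of this connection, in seconds.
+    ///
+    /// `None` means manual read consistency; strong consistency is reported as `0.0`.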
+ #[pyo3(signature = ())]
+    pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+ let inner = self_.get_inner()?.clone();
+ future_into_py(self_.py(), async move {
+ Ok(match inner.read_consistency().await.infer_error()? {
+ ReadConsistency::Manual => None,
+ ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
+ ReadConsistency::Strong => Some(0.0_f64),
+ })
+ })
+ }
+
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 636c176f..6f20c99e 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
+use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
+pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
+    m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
+ m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
diff --git a/python/src/permutation.rs b/python/src/permutation.rs
new file mode 100644
index 00000000..38da3f82
--- /dev/null
+++ b/python/src/permutation.rs
@@ -0,0 +1,177 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::{Arc, Mutex};
+
+use crate::{error::PythonErrorExt, table::Table};
+use lancedb::dataloader::{
+ permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
+ split::{SplitSizes, SplitStrategy},
+};
+use pyo3::{
+ exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
+ PyResult,
+};
+use pyo3_async_runtimes::tokio::future_into_py;
+
+/// Create a permutation builder for the given table
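+///
+/// Expects the Python `Table` wrapper; the native table is taken from its `_inner` attribute.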
+#[pyo3::pyfunction]
+pub fn async_permutation_builder(
+ table: Bound<'_, PyAny>,
+ dest_table_name: String,
+) -> PyResult<PyAsyncPermutationBuilder> {
+ let table = table.getattr("_inner")?.downcast_into::<Table>()?;
+ let inner_table = table.borrow().inner_ref()?.clone();
+ let inner_builder = LancePermutationBuilder::new(inner_table);
+
+ Ok(PyAsyncPermutationBuilder {
+ state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
+ builder: Some(inner_builder),
+ dest_table_name,
+ })),
+ })
+}
+
+struct PyAsyncPermutationBuilderState {
+ builder: Option<LancePermutationBuilder>,
+ dest_table_name: String,
+}
+
+#[pyclass(name = "AsyncPermutationBuilder")]
+pub struct PyAsyncPermutationBuilder {
+ state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
+}
+
+impl PyAsyncPermutationBuilder {
+ fn modify(
+ &self,
+ func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
+ ) -> PyResult<Self> {
+ let mut state = self.state.lock().unwrap();
+ let builder = state
+ .builder
+ .take()
+ .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
+ state.builder = Some(func(builder));
+ Ok(Self {
+ state: self.state.clone(),
+ })
+ }
+}
+
+#[pymethods]
+impl PyAsyncPermutationBuilder {
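+ /// Randomly assign rows to splits.
+ ///
+ /// Exactly one of `ratios`, `counts`, or `fixed` must be provided. A `seed` may be given to
+ /// make the assignment deterministic.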
+ #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
+ pub fn split_random(
+ slf: PyRefMut<'_, Self>,
+ ratios: Option<Vec<f64>>,
+ counts: Option<Vec<u64>>,
+ fixed: Option<u64>,
+ seed: Option<u64>,
+ ) -> PyResult<Self> {
+ // Check that exactly one split type is provided
+ let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
+ .iter()
+ .filter(|&&x| x)
+ .count();
+
+ if split_args_count != 1 {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
+ ));
+ }
+
+ let sizes = if let Some(ratios) = ratios {
+ SplitSizes::Percentages(ratios)
+ } else if let Some(counts) = counts {
+ SplitSizes::Counts(counts)
+ } else if let Some(fixed) = fixed {
+ SplitSizes::Fixed(fixed)
+ } else {
+ unreachable!("One of the split arguments must be provided");
+ };
+
+ slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
+ }
+
+ #[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
+ pub fn split_hash(
+ slf: PyRefMut<'_, Self>,
+ columns: Vec<String>,
+ split_weights: Vec<u64>,
+ discard_weight: u64,
+ ) -> PyResult<Self> {
+ slf.modify(|builder| {
+ builder.with_split_strategy(SplitStrategy::Hash {
+ columns,
+ split_weights,
+ discard_weight,
+ })
+ })
+ }
+
+ #[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
+ pub fn split_sequential(
+ slf: PyRefMut<'_, Self>,
+ ratios: Option<Vec<f64>>,
+ counts: Option<Vec<u64>>,
+ fixed: Option<u64>,
+ ) -> PyResult<Self> {
+ // Check that exactly one split type is provided
+ let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
+ .iter()
+ .filter(|&&x| x)
+ .count();
+
+ if split_args_count != 1 {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
+ ));
+ }
+
+ let sizes = if let Some(ratios) = ratios {
+ SplitSizes::Percentages(ratios)
+ } else if let Some(counts) = counts {
+ SplitSizes::Counts(counts)
+ } else if let Some(fixed) = fixed {
+ SplitSizes::Fixed(fixed)
+ } else {
+ unreachable!("One of the split arguments must be provided");
+ };
+
+ slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
+ }
+
+ pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
+ slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
+ }
+
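+ /// Shuffle rows randomly, optionally in clumps of `clump_size` contiguous rows, with an
+ /// optional `seed` for determinism.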
+ pub fn shuffle(
+ slf: PyRefMut<'_, Self>,
+ seed: Option<u64>,
+ clump_size: Option<u64>,
+ ) -> PyResult<Self> {
+ slf.modify(|builder| {
+ builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
+ })
+ }
+
+ pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
+ slf.modify(|builder| builder.with_filter(filter))
+ }
+
+ pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+ let mut state = slf.state.lock().unwrap();
+ let builder = state
+ .builder
+ .take()
+ .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
+
+ let dest_table_name = std::mem::take(&mut state.dest_table_name);
+
+ future_into_py(slf.py(), async move {
+ let table = builder.build(&dest_table_name).await.infer_error()?;
+ Ok(Table::new(table))
+ })
+ }
+}
diff --git a/python/src/table.rs b/python/src/table.rs
index f9f7f995..2097909a 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
+ connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
}
impl Table {
- fn inner_ref(&self) -> PyResult<&LanceDbTable> {
+ pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
self.inner.take();
}
+ pub fn database(&self) -> PyResult<Connection> {
+ let inner = self.inner_ref()?.clone();
+ let inner_connection =
+ lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
+ Ok(Connection::new(inner_connection))
+ }
+
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index 97934a34..aee78336 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -11,6 +11,7 @@ rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
@@ -24,12 +25,16 @@ datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
+datafusion.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
+lance-core = { workspace = true }
lance-datafusion.workspace = true
+lance-datagen = { workspace = true }
+lance-file = { workspace = true }
lance-io = { workspace = true }
lance-index = { workspace = true }
lance-table = { workspace = true }
@@ -46,11 +51,13 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
+rand.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
+tempfile = "3.5.0"
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
# For remote feature
reqwest = { version = "0.12.0", default-features = false, features = [
@@ -61,9 +68,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"macos-system-configuration",
"stream",
], optional = true }
-rand = { version = "0.9", features = ["small_rng"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
-uuid = { version = "1.7.0", features = ["v4"], optional = true }
+uuid = { version = "1.7.0", features = ["v4"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
@@ -84,7 +90,6 @@ bytemuck_derive.workspace = true
[dev-dependencies]
anyhow = "1"
tempfile = "3.5.0"
-rand = { version = "0.9", features = ["small_rng"] }
random_word = { version = "0.4.3", features = ["en"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
@@ -96,6 +101,7 @@ aws-smithy-runtime = { version = "1.9.1" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"
+test-log = "0.2"
[features]
@@ -105,7 +111,7 @@ oss = ["lance/oss", "lance-io/oss"]
gcs = ["lance/gcp", "lance-io/gcp"]
azure = ["lance/azure", "lance-io/azure"]
dynamodb = ["lance/dynamodb", "aws"]
-remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
+remote = ["dep:reqwest", "dep:http"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]
diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs
index 133f5d1f..df6e51a3 100644
--- a/rust/lancedb/src/arrow.rs
+++ b/rust/lancedb/src/arrow.rs
@@ -7,6 +7,7 @@ pub use arrow_schema;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt, TryStreamExt};
+use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
@@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
}
}
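+/// Extension trait for turning a `lance_datagen::BatchGeneratorBuilder` into a
+/// `SendableRecordBatchStream` usable with LanceDB APIs.
+///
+/// A minimal sketch (the column name is arbitrary):
+///
+/// ```no_run
+/// # use lancedb::arrow::LanceDbDatagenExt;
+/// use arrow::datatypes::Int32Type;
+/// use lance_datagen::{BatchCount, RowCount};
+///
+/// let stream = lance_datagen::gen_batch()
+///     .col("some_value", lance_datagen::array::step::<Int32Type>())
+///     .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
+/// # drop(stream);
+/// ```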
+pub trait LanceDbDatagenExt {
+ fn into_ldb_stream(
+ self,
+ batch_size: RowCount,
+ num_batches: BatchCount,
+ ) -> SendableRecordBatchStream;
+}
+
+impl LanceDbDatagenExt for BatchGeneratorBuilder {
+ fn into_ldb_stream(
+ self,
+ batch_size: RowCount,
+ num_batches: BatchCount,
+ ) -> SendableRecordBatchStream {
+ let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
+ let stream = stream.map_err(|err| Error::Arrow { source: err });
+ Box::pin(SimpleRecordBatchStream::new(stream, schema))
+ }
+}
+
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index 66e24161..c8e3496d 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -19,7 +19,7 @@ use crate::database::listing::{
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
- OpenTableRequest, TableNamesRequest,
+ OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
@@ -152,6 +152,7 @@ impl CreateTableBuilder {
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
+ parent,
embedding_registry,
))
}
@@ -211,9 +212,9 @@ impl CreateTableBuilder {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
- Ok(Table::new(
- self.parent.clone().create_table(self.request).await?,
- ))
+ let parent = self.parent.clone();
+ let table = parent.create_table(self.request).await?;
+ Ok(Table::new(table, parent))
}
}
@@ -462,8 +463,10 @@ impl OpenTableBuilder {
/// Open the table
pub async fn execute(self) -> Result<Table> {
+ let table = self.parent.open_table(self.request).await?;
Ok(Table::new_with_embedding_registry(
- self.parent.clone().open_table(self.request).await?,
+ table,
+ self.parent,
self.embedding_registry,
))
}
@@ -519,16 +522,15 @@ impl CloneTableBuilder {
/// Execute the clone operation
pub async fn execute(self) -> Result<Table> {
- Ok(Table::new(
- self.parent.clone().clone_table(self.request).await?,
- ))
+ let parent = self.parent.clone();
+ let table = parent.clone_table(self.request).await?;
+ Ok(Table::new(table, parent))
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
- uri: String,
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -540,9 +542,19 @@ impl std::fmt::Display for Connection {
}
impl Connection {
+ pub fn new(
+ internal: Arc<dyn Database>,
+ embedding_registry: Arc<dyn EmbeddingRegistry>,
+ ) -> Self {
+ Self {
+ internal,
+ embedding_registry,
+ }
+ }
+
/// Get the URI of the connection
pub fn uri(&self) -> &str {
- self.uri.as_str()
+ self.internal.uri()
}
/// Get access to the underlying database
@@ -675,6 +687,11 @@ impl Connection {
.await
}
+ /// Get the read consistency of the connection
+ pub async fn read_consistency(&self) -> Result<ReadConsistency> {
+ self.internal.read_consistency().await
+ }
+
/// Drop a table in the database.
///
/// # Arguments
@@ -973,7 +990,6 @@ impl ConnectBuilder {
)?);
Ok(Connection {
internal,
- uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -996,7 +1012,6 @@ impl ConnectBuilder {
let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
Ok(Connection {
internal,
- uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder {
Ok(Connection {
internal,
- uri: format!("namespace://{}", self.ns_impl),
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1139,7 +1153,6 @@ mod test_utils {
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
- uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1156,7 +1169,6 @@ mod test_utils {
));
Self {
internal,
- uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1187,7 +1199,7 @@ mod tests {
#[tokio::test]
async fn test_connect() {
let tc = new_test_connection().await.unwrap();
- assert_eq!(tc.connection.uri, tc.uri);
+ assert_eq!(tc.connection.uri(), tc.uri);
}
#[cfg(not(windows))]
@@ -1208,7 +1220,7 @@ mod tests {
.await
.unwrap();
- assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
+ assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
}
#[tokio::test]
diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs
index e5233920..7816518e 100644
--- a/rust/lancedb/src/database.rs
+++ b/rust/lancedb/src/database.rs
@@ -16,6 +16,7 @@
use std::collections::HashMap;
use std::sync::Arc;
+use std::time::Duration;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
@@ -213,6 +214,20 @@ impl CloneTableRequest {
}
}
+/// How long until a change is reflected from one Table instance to another
+///
+/// Tables are always internally consistent. If a write method is called on
+/// a table instance it will be immediately visible in that same table instance.
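+///
+/// A minimal sketch of inspecting a connection's read consistency (the URI is hypothetical):
+///
+/// ```no_run
+/// # use lancedb::database::ReadConsistency;
+/// # async fn example() -> lancedb::Result<()> {
+/// let conn = lancedb::connect("data/sample-lancedb").execute().await?;
+/// match conn.read_consistency().await? {
+///     ReadConsistency::Manual => println!("call checkout_latest to see new writes"),
+///     ReadConsistency::Eventual(d) => println!("new writes visible within {:?}", d),
+///     ReadConsistency::Strong => println!("new writes immediately visible"),
+/// }
+/// # Ok(())
+/// # }
+/// ```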
+pub enum ReadConsistency {
+ /// Changes will not be automatically propagated until the checkout_latest
+ /// method is called on the target table
+ Manual,
+ /// Changes will be propagated automatically within the given duration
+ Eventual(Duration),
+ /// Changes are immediately visible in target tables
+ Strong,
+}
+
/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
@@ -220,6 +235,10 @@ impl CloneTableRequest {
pub trait Database:
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
+ /// Get the uri of the database
+ fn uri(&self) -> &str;
+ /// Get the read consistency of the database
+ async fn read_consistency(&self) -> Result<ReadConsistency>;
/// List immediate child namespace names in the given namespace
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
/// Create a new namespace
diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs
index 5aa7ac7b..bc1af880 100644
--- a/rust/lancedb/src/database/listing.rs
+++ b/rust/lancedb/src/database/listing.rs
@@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem;
use snafu::ResultExt;
use crate::connection::ConnectRequest;
+use crate::database::ReadConsistency;
use crate::error::{CreateDirSnafu, Error, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::NativeTable;
@@ -598,6 +599,22 @@ impl Database for ListingDatabase {
Ok(Vec::new())
}
+ fn uri(&self) -> &str {
+ &self.uri
+ }
+
+ async fn read_consistency(&self) -> Result<ReadConsistency> {
+ if let Some(read_consistency_interval) = self.read_consistency_interval {
+ if read_consistency_interval.is_zero() {
+ Ok(ReadConsistency::Strong)
+ } else {
+ Ok(ReadConsistency::Eventual(read_consistency_interval))
+ }
+ } else {
+ Ok(ReadConsistency::Manual)
+ }
+ }
+
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
Err(Error::NotSupported {
message: "Namespace operations are not supported for listing database".into(),
@@ -1249,7 +1266,8 @@ mod tests {
)
.unwrap();
- let source_table_obj = Table::new(source_table.clone());
+ let db = Arc::new(db);
+ let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1320,7 +1338,8 @@ mod tests {
.unwrap();
// Create a tag for the current version
- let source_table_obj = Table::new(source_table.clone());
+ let db = Arc::new(db);
+ let source_table_obj = Table::new(source_table.clone(), db.clone());
let mut tags = source_table_obj.tags().await.unwrap();
tags.create("v1.0", source_table.version().await.unwrap())
.await
@@ -1336,7 +1355,7 @@ mod tests {
)
.unwrap();
- let source_table_obj = Table::new(source_table.clone());
+ let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1432,7 +1451,8 @@ mod tests {
)
.unwrap();
- let cloned_table_obj = Table::new(cloned_table.clone());
+ let db = Arc::new(db);
+ let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
cloned_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_clone)],
@@ -1452,7 +1472,7 @@ mod tests {
)
.unwrap();
- let source_table_obj = Table::new(source_table.clone());
+ let source_table_obj = Table::new(source_table.clone(), db);
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_source)],
@@ -1495,6 +1515,7 @@ mod tests {
.unwrap();
// Add more data to create new versions
+ let db = Arc::new(db);
for i in 0..3 {
let batch = RecordBatch::try_new(
schema.clone(),
@@ -1502,7 +1523,7 @@ mod tests {
)
.unwrap();
- let source_table_obj = Table::new(source_table.clone());
+ let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch)],
diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs
index 97f8be39..37d928f8 100644
--- a/rust/lancedb/src/database/namespace.rs
+++ b/rust/lancedb/src/database/namespace.rs
@@ -16,9 +16,9 @@ use lance_namespace::{
LanceNamespace,
};
-use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::error::{Error, Result};
+use crate::{connection::ConnectRequest, database::ReadConsistency};
use super::{
BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
@@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase {
read_consistency_interval: Option<Duration>,
// Optional session for object stores and caching
session: Option<Arc<Session>>,
+ // database URI
+ uri: String,
}
impl LanceNamespaceDatabase {
@@ -57,6 +59,7 @@ impl LanceNamespaceDatabase {
storage_options,
read_consistency_interval,
session,
+ uri: format!("namespace://{}", ns_impl),
})
}
@@ -130,6 +133,22 @@ impl std::fmt::Display for LanceNamespaceDatabase {
#[async_trait]
impl Database for LanceNamespaceDatabase {
+ fn uri(&self) -> &str {
+ &self.uri
+ }
+
+ async fn read_consistency(&self) -> Result<ReadConsistency> {
+ if let Some(read_consistency_interval) = self.read_consistency_interval {
+ if read_consistency_interval.is_zero() {
+ Ok(ReadConsistency::Strong)
+ } else {
+ Ok(ReadConsistency::Eventual(read_consistency_interval))
+ }
+ } else {
+ Ok(ReadConsistency::Manual)
+ }
+ }
+
+ async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
let ns_request = ListNamespacesRequest {
id: if request.namespace.is_empty() {
diff --git a/rust/lancedb/src/dataloader.rs b/rust/lancedb/src/dataloader.rs
new file mode 100644
index 00000000..cbb7f037
--- /dev/null
+++ b/rust/lancedb/src/dataloader.rs
@@ -0,0 +1,7 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+pub mod permutation;
+pub mod shuffle;
+pub mod split;
+pub mod util;
diff --git a/rust/lancedb/src/dataloader/permutation.rs b/rust/lancedb/src/dataloader/permutation.rs
new file mode 100644
index 00000000..09a39d93
--- /dev/null
+++ b/rust/lancedb/src/dataloader/permutation.rs
@@ -0,0 +1,294 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
+//!
+//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
+//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
+//! the underlying data and can be very lightweight.
+//!
+//! Building a permutation table should be fairly quick and memory efficient, even for billions or
+//! trillions of rows.
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
+use datafusion_expr::col;
+use futures::TryStreamExt;
+use lance_datafusion::exec::SessionContextExt;
+
+use crate::{
+ arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
+ dataloader::{
+ shuffle::{Shuffler, ShufflerConfig},
+ split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
+ util::{rename_column, TemporaryDirectory},
+ },
+ query::{ExecutableQuery, QueryBase},
+ Connection, Error, Result, Table,
+};
+
+/// Configuration for creating a permutation table
+#[derive(Debug, Default)]
+pub struct PermutationConfig {
+ /// Splitting configuration
+ pub split_strategy: SplitStrategy,
+ /// Shuffle strategy
+ pub shuffle_strategy: ShuffleStrategy,
+ /// Optional filter to apply to the base table
+ pub filter: Option,
+ /// Directory to use for temporary files
+ pub temp_dir: TemporaryDirectory,
+}
+
+/// Strategy for shuffling the data.
+#[derive(Debug, Clone)]
+pub enum ShuffleStrategy {
+ /// The data is randomly shuffled
+ ///
+ /// A seed can be provided to make the shuffle deterministic.
+ ///
+ /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
+ /// This decreases the overall randomization but can improve I/O performance when reading from
+ /// cloud storage.
+ ///
+ /// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
+ /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
+ /// the performance of the model. Note: shuffling within clumps can still be done at read time but
+ /// this will only provide a local shuffle and not a global shuffle.
+ Random {
+ seed: Option,
+ clump_size: Option,
+ },
+ /// The data is not shuffled
+ ///
+ /// This is useful for debugging and testing.
+ None,
+}
+
+impl Default for ShuffleStrategy {
+ fn default() -> Self {
+ Self::None
+ }
+}
+
+/// Builder for creating a permutation table.
+///
+/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
+/// can be used to create a lightweight, reordered "view" of an existing table without materializing
+/// a copy of the underlying data.
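+///
+/// A minimal sketch of building a permutation (the filter, seed, and table name are illustrative):
+///
+/// ```no_run
+/// # use lancedb::dataloader::permutation::{PermutationBuilder, ShuffleStrategy};
+/// # use lancedb::dataloader::split::{SplitSizes, SplitStrategy};
+/// # async fn example(base_table: lancedb::Table) -> lancedb::Result<()> {
+/// let permutation = PermutationBuilder::new(base_table)
+///     .with_filter("label IS NOT NULL".to_string())
+///     .with_split_strategy(SplitStrategy::Random {
+///         seed: Some(42),
+///         sizes: SplitSizes::Percentages(vec![0.8, 0.2]),
+///     })
+///     .with_shuffle_strategy(ShuffleStrategy::Random { seed: Some(42), clump_size: None })
+///     .build("my_permutation")
+///     .await?;
+/// # drop(permutation);
+/// # Ok(())
+/// # }
+/// ```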
+pub struct PermutationBuilder {
+ config: PermutationConfig,
+ base_table: Table,
+}
+
+impl PermutationBuilder {
+ pub fn new(base_table: Table) -> Self {
+ Self {
+ config: PermutationConfig::default(),
+ base_table,
+ }
+ }
+
+ /// Configures the strategy for assigning rows to splits.
+ ///
+ /// For example, it is common to create a test/train split of the data. Splits can also be used
+ /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
+ /// create a single split with 10% of the data.
+ ///
+ /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
+ /// multiple processes and multiple nodes.
+ ///
+ /// The default is a single split that contains all rows.
+ pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
+ self.config.split_strategy = split_strategy;
+ self
+ }
+
+ /// Configures the strategy for shuffling the data.
+ ///
+ /// The default is `ShuffleStrategy::None`, meaning the data is not shuffled.
+ pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
+ self.config.shuffle_strategy = shuffle_strategy;
+ self
+ }
+
+ /// Configures a filter to apply to the base table.
+ ///
+ /// Only rows matching the filter will be included in the permutation.
+ pub fn with_filter(mut self, filter: String) -> Self {
+ self.config.filter = Some(filter);
+ self
+ }
+
+ /// Configures the directory to use for temporary files.
+ ///
+ /// The default is to use the operating system's default temporary directory.
+ pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
+ self.config.temp_dir = temp_dir;
+ self
+ }
+
+ async fn sort_by_split_id(
+ &self,
+ data: SendableRecordBatchStream,
+ ) -> Result<SendableRecordBatchStream> {
+ let ctx = SessionContext::new_with_config_rt(
+ SessionConfig::default(),
+ RuntimeEnvBuilder::new()
+ .with_memory_limit(100 * 1024 * 1024, 1.0)
+ .with_disk_manager_builder(
+ DiskManagerBuilder::default()
+ .with_mode(self.config.temp_dir.to_disk_manager_mode()),
+ )
+ .build_arc()
+ .unwrap(),
+ );
+ let df = ctx
+ .read_one_shot(data.into_df_stream())
+ .map_err(|e| Error::Other {
+ message: format!("Failed to setup sort by split id: {}", e),
+ source: Some(e.into()),
+ })?;
+ let df_stream = df
+ .sort_by(vec![col(SPLIT_ID_COLUMN)])
+ .map_err(|e| Error::Other {
+ message: format!("Failed to plan sort by split id: {}", e),
+ source: Some(e.into()),
+ })?
+ .execute_stream()
+ .await
+ .map_err(|e| Error::Other {
+ message: format!("Failed to sort by split id: {}", e),
+ source: Some(e.into()),
+ })?;
+
+ let schema = df_stream.schema();
+ let stream = df_stream.map_err(|e| Error::Other {
+ message: format!("Failed to execute sort by split id: {}", e),
+ source: Some(e.into()),
+ });
+ Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
+ }
+
+ /// Builds the permutation table and stores it in the given database.
+ pub async fn build(self, dest_table_name: &str) -> Result<Table> {
+ // First pass, apply filter and load row ids
+ let mut rows = self.base_table.query().with_row_id();
+
+ if let Some(filter) = &self.config.filter {
+ rows = rows.only_if(filter);
+ }
+
+ let splitter = Splitter::new(
+ self.config.temp_dir.clone(),
+ self.config.split_strategy.clone(),
+ );
+
+ let mut needs_sort = !splitter.orders_by_split_id();
+
+ // Might need to load additional columns to calculate splits (e.g. hash columns or calculated
+ // split id)
+ rows = splitter.project(rows);
+
+ let num_rows = self
+ .base_table
+ .count_rows(self.config.filter.clone())
+ .await? as u64;
+
+ // Apply splits
+ let rows = rows.execute().await?;
+ let split_data = splitter.apply(rows, num_rows).await?;
+
+ // Shuffle data if requested
+ let shuffled = match self.config.shuffle_strategy {
+ ShuffleStrategy::None => split_data,
+ ShuffleStrategy::Random { seed, clump_size } => {
+ let shuffler = Shuffler::new(ShufflerConfig {
+ seed,
+ clump_size,
+ temp_dir: self.config.temp_dir.clone(),
+ max_rows_per_file: 10 * 1024 * 1024,
+ });
+ shuffler.shuffle(split_data, num_rows).await?
+ }
+ };
+
+ // We want the final permutation to be sorted by the split id. If we shuffled or if
+ // the split was not assigned sequentially then we need to sort the data.
+ needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
+
+ let sorted = if needs_sort {
+ self.sort_by_split_id(shuffled).await?
+ } else {
+ shuffled
+ };
+
+ // Rename _rowid to row_id
+ let renamed = rename_column(sorted, "_rowid", "row_id")?;
+
+ // Create permutation table
+ let conn = Connection::new(
+ self.base_table.database().clone(),
+ self.base_table.embedding_registry().clone(),
+ );
+ conn.create_table_streaming(dest_table_name, renamed)
+ .execute()
+ .await
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow::datatypes::Int32Type;
+ use lance_datagen::{BatchCount, RowCount};
+
+ use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
+
+ use super::*;
+
+ #[tokio::test]
+ async fn test_permutation_builder() {
+ let temp_dir = tempfile::tempdir().unwrap();
+
+ let db = connect(temp_dir.path().to_str().unwrap())
+ .execute()
+ .await
+ .unwrap();
+
+ let initial_data = lance_datagen::gen_batch()
+ .col("some_value", lance_datagen::array::step::())
+ .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
+ let data_table = db
+ .create_table_streaming("mytbl", initial_data)
+ .execute()
+ .await
+ .unwrap();
+
+ let permutation_table = PermutationBuilder::new(data_table)
+ .with_filter("some_value > 57".to_string())
+ .with_split_strategy(SplitStrategy::Random {
+ seed: Some(42),
+ sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
+ })
+ .build("permutation")
+ .await
+ .unwrap();
+
+ // Potentially brittle seed-dependent values below
+ assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
+ assert_eq!(
+ permutation_table
+ .count_rows(Some("split_id = 0".to_string()))
+ .await
+ .unwrap(),
+ 47
+ );
+ assert_eq!(
+ permutation_table
+ .count_rows(Some("split_id = 1".to_string()))
+ .await
+ .unwrap(),
+ 283
+ );
+ }
+}
diff --git a/rust/lancedb/src/dataloader/shuffle.rs b/rust/lancedb/src/dataloader/shuffle.rs
new file mode 100644
index 00000000..e06affc7
--- /dev/null
+++ b/rust/lancedb/src/dataloader/shuffle.rs
@@ -0,0 +1,475 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::{Arc, Mutex};
+
+use arrow::compute::concat_batches;
+use arrow_array::{RecordBatch, UInt64Array};
+use futures::{StreamExt, TryStreamExt};
+use lance::io::ObjectStore;
+use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt};
+use lance_encoding::decoder::DecoderPlugins;
+use lance_file::v2::{
+ reader::{FileReader, FileReaderOptions},
+ writer::{FileWriter, FileWriterOptions},
+};
+use lance_index::scalar::IndexReader;
+use lance_io::{
+ scheduler::{ScanScheduler, SchedulerConfig},
+ utils::CachedFileSize,
+};
+use rand::{seq::SliceRandom, Rng, RngCore};
+
+use crate::{
+ arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+ dataloader::util::{non_crypto_rng, TemporaryDirectory},
+ Error, Result,
+};
+
+#[derive(Debug, Clone)]
+pub struct ShufflerConfig {
+ /// An optional seed to make the shuffle deterministic
+ pub seed: Option,
+ /// The maximum number of rows to write to a single file
+ ///
+ /// The shuffler will need to hold at least this many rows in memory. Setting this value
+ /// extremely large could cause the shuffler to use a lot of memory (depending on row size).
+ ///
+ /// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file
+ /// writers in memory. Each of these will consume some amount of data for column write buffers.
+ /// So setting this value too small could _also_ cause the shuffler to use a lot of memory and
+ /// open file handles.
+ pub max_rows_per_file: u64,
+ /// The temporary directory to use for writing files
+ pub temp_dir: TemporaryDirectory,
+ /// The size of the clumps to shuffle within
+ ///
+ /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
+ /// This decreases the overall randomization but can improve I/O performance when reading from
+ /// cloud storage.
+ pub clump_size: Option,
+}
+
+impl Default for ShufflerConfig {
+ fn default() -> Self {
+ Self {
+ max_rows_per_file: 1024 * 1024,
+ seed: Option::default(),
+ temp_dir: TemporaryDirectory::default(),
+ clump_size: None,
+ }
+ }
+}
+
+/// A shuffler that can shuffle a stream of record batches
+///
+/// To do this the stream is consumed and written to temporary files. A new stream is returned
+/// which returns the shuffled data from the temporary files.
+///
+/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not
+/// write any files and will instead perform an in-memory shuffle.
+///
+/// The number of rows in the input stream must be known in advance.
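+///
+/// A minimal usage sketch (the seed and row count are illustrative):
+///
+/// ```no_run
+/// # use lancedb::arrow::SendableRecordBatchStream;
+/// # use lancedb::dataloader::shuffle::{Shuffler, ShufflerConfig};
+/// # async fn example(rows: SendableRecordBatchStream) -> lancedb::Result<()> {
+/// let shuffler = Shuffler::new(ShufflerConfig {
+///     seed: Some(42),
+///     ..Default::default()
+/// });
+/// // The exact number of rows in the stream must be known up front.
+/// let shuffled = shuffler.shuffle(rows, 1_000_000).await?;
+/// # drop(shuffled);
+/// # Ok(())
+/// # }
+/// ```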
+pub struct Shuffler {
+ config: ShufflerConfig,
+ id: String,
+}
+
+impl Shuffler {
+ pub fn new(config: ShufflerConfig) -> Self {
+ let id = uuid::Uuid::new_v4().to_string();
+ Self { config, id }
+ }
+
+ /// Shuffles a single batch of data in memory
+ fn shuffle_batch(
+ batch: &RecordBatch,
+ rng: &mut dyn RngCore,
+ clump_size: u64,
+ ) -> Result<RecordBatch> {
+ let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
+ let mut indices = (0..num_clumps).collect::<Vec<u64>>();
+ indices.shuffle(rng);
+ let indices = if clump_size == 1 {
+ UInt64Array::from(indices)
+ } else {
+ UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| {
+ if clump_index == num_clumps - 1 {
+ clump_index * clump_size..batch.num_rows() as u64
+ } else {
+ clump_index * clump_size..(clump_index + 1) * clump_size
+ }
+ }))
+ };
+ Ok(arrow::compute::take_record_batch(batch, &indices)?)
+ }
+
+ async fn in_memory_shuffle(
+ &self,
+ data: SendableRecordBatchStream,
+ mut rng: Box<dyn RngCore>,
+ ) -> Result<SendableRecordBatchStream> {
+ let schema = data.schema();
+ let batches = data.try_collect::<Vec<_>>().await?;
+ let batch = concat_batches(&schema, &batches)?;
+ let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?;
+ log::debug!("Shuffle job {}: in-memory shuffle complete", self.id);
+ Ok(Box::pin(SimpleRecordBatchStream::new(
+ futures::stream::once(async move { Ok(shuffled) }),
+ schema,
+ )))
+ }
+
+ async fn do_shuffle(
+ &self,
+ mut data: SendableRecordBatchStream,
+ num_rows: u64,
+ mut rng: Box<dyn RngCore>,
+ ) -> Result<SendableRecordBatchStream> {
+ let num_files = num_rows.div_ceil(self.config.max_rows_per_file);
+
+ let temp_dir = self.config.temp_dir.create_temp_dir()?;
+ let tmp_dir = temp_dir.path().to_path_buf();
+
+ let clump_size = self.config.clump_size.unwrap_or(1);
+ if clump_size == 0 {
+ return Err(Error::InvalidInput {
+ message: "clump size must be greater than 0".to_string(),
+ });
+ }
+
+ let object_store = ObjectStore::local();
+ let arrow_schema = data.schema();
+ let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?;
+
+ // Create file writers
+ let mut file_writers = Vec::with_capacity(num_files as usize);
+ for file_index in 0..num_files {
+ let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id));
+ let path =
+ object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other {
+ message: format!("Failed to create temporary file: {}", err),
+ source: None,
+ })?;
+ let object_writer = object_store.create(&path).await?;
+ let writer =
+ FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?;
+ file_writers.push(writer);
+ }
+
+ let mut num_rows_seen = 0;
+
+ // Randomly distribute clumps to files
+ while let Some(batch) = data.try_next().await? {
+ num_rows_seen += batch.num_rows() as u64;
+ let is_last = num_rows_seen == num_rows;
+ if num_rows_seen > num_rows {
+ return Err(Error::Runtime {
+ message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen),
+ });
+ }
+ // This is kind of an annoying limitation but if we allow runt clumps from batches then
+ // clumps will get unaligned and we will mess up the clumps when we do the in-memory
+ // shuffle step. If this is a problem we can probably figure out a better way to do this.
+ if !is_last && batch.num_rows() as u64 % clump_size != 0 {
+ return Err(Error::Runtime {
+ message: format!(
+ "Expected batch size ({}) to be divisible by clump size ({})",
+ batch.num_rows(),
+ clump_size
+ ),
+ });
+ }
+ let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
+ let mut batch_offsets_for_files =
+ vec![Vec::<u64>::with_capacity(batch.num_rows()); num_files as usize];
+ // Partition the batch randomly and write to the appropriate accumulator
+ for clump_offset in 0..num_clumps {
+ let clump_start = clump_offset * clump_size;
+ let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start);
+ let clump_end = clump_start + num_rows_in_clump;
+ let file_index = rng.random_range(0..num_files);
+ batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end);
+ }
+ for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() {
+ if batch_offsets.is_empty() {
+ continue;
+ }
+ let indices = UInt64Array::from(batch_offsets);
+ let partition = arrow::compute::take_record_batch(&batch, &indices)?;
+ file_writers[file_index].write_batch(&partition).await?;
+ }
+ }
+
+ // Finish writing files
+ for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
+ let num_written = writer.finish().await?;
+ log::debug!(
+ "Shuffle job {}: wrote {} rows to file {}",
+ self.id,
+ num_written,
+ file_idx
+ );
+ }
+
+ let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
+ let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
+ let job_id = self.id.clone();
+ let rng = Arc::new(Mutex::new(rng));
+
+ // Second pass, read each file as a single batch and shuffle
+ let stream = futures::stream::iter(0..num_files)
+ .then(move |file_index| {
+ let scan_scheduler = scan_scheduler.clone();
+ let rng = rng.clone();
+ let tmp_dir = tmp_dir.clone();
+ let job_id = job_id.clone();
+ async move {
+ let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id));
+ let path = object_store::path::Path::from_absolute_path(path).unwrap();
+ let file_scheduler = scan_scheduler
+ .open_file(&path, &CachedFileSize::unknown())
+ .await?;
+ let reader = FileReader::try_open(
+ file_scheduler,
+ None,
+ Arc::<DecoderPlugins>::default(),
+ &LanceCache::no_cache(),
+ FileReaderOptions::default(),
+ )
+ .await?;
+ // Need to read the entire file in a single batch for in-memory shuffling
+ let batch = reader.read_record_batch(0, reader.num_rows()).await?;
+ let mut rng = rng.lock().unwrap();
+ Self::shuffle_batch(&batch, &mut rng, clump_size)
+ }
+ })
+ .finally(move || drop(temp_dir))
+ .boxed();
+
+ Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema)))
+ }
+
+ pub async fn shuffle(
+ self,
+ data: SendableRecordBatchStream,
+ num_rows: u64,
+ ) -> Result<SendableRecordBatchStream> {
+ log::debug!(
+ "Shuffle job {}: shuffling {} rows and {} columns",
+ self.id,
+ num_rows,
+ data.schema().fields.len()
+ );
+ let rng = non_crypto_rng(&self.config.seed);
+
+ if num_rows < self.config.max_rows_per_file {
+ return self.in_memory_shuffle(data, rng).await;
+ }
+
+ self.do_shuffle(data, num_rows, rng).await
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::arrow::LanceDbDatagenExt;
+
+ use super::*;
+ use arrow::{array::AsArray, datatypes::Int32Type};
+ use datafusion::prelude::SessionContext;
+ use datafusion_expr::col;
+ use futures::TryStreamExt;
+ use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
+ use rand::{rngs::SmallRng, SeedableRng};
+
+ fn test_gen() -> BatchGeneratorBuilder {
+ lance_datagen::gen_batch()
+ .with_seed(Seed::from(42))
+ .col("id", lance_datagen::array::step::())
+ .col(
+ "name",
+ lance_datagen::array::rand_utf8(ByteCount::from(10), false),
+ )
+ }
+
+ fn create_test_batch(size: RowCount) -> RecordBatch {
+ test_gen().into_batch_rows(size).unwrap()
+ }
+
+ fn create_test_stream(
+ num_batches: BatchCount,
+ batch_size: RowCount,
+ ) -> SendableRecordBatchStream {
+ test_gen().into_ldb_stream(batch_size, num_batches)
+ }
+
+ #[test]
+ fn test_shuffle_batch_deterministic() {
+ let batch = create_test_batch(RowCount::from(10));
+ let mut rng1 = SmallRng::seed_from_u64(42);
+ let mut rng2 = SmallRng::seed_from_u64(42);
+
+ let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap();
+ let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap();
+
+ // Same seed should produce same shuffle
+ assert_eq!(shuffled1, shuffled2);
+ }
+
+ #[test]
+ fn test_shuffle_with_clumps() {
+ let batch = create_test_batch(RowCount::from(10));
+ let mut rng = SmallRng::seed_from_u64(42);
+ let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap();
+ let values = shuffled.column(0).as_primitive::<Int32Type>();
+
+ let mut iter = values.into_iter().map(|o| o.unwrap());
+ let mut frag_seen = false;
+ let mut clumps_seen = 0;
+ while let Some(first) = iter.next() {
+ // 9 is the last value and not a full clump
+ if first != 9 {
+ // Otherwise we should have a full clump
+ let second = iter.next().unwrap();
+ let third = iter.next().unwrap();
+ assert_eq!(first + 1, second);
+ assert_eq!(first + 2, third);
+ clumps_seen += 1;
+ } else {
+ frag_seen = true;
+ }
+ }
+ assert_eq!(clumps_seen, 3);
+ assert!(frag_seen);
+ }
+
+ async fn sort_batch(batch: RecordBatch) -> RecordBatch {
+ let ctx = SessionContext::new();
+ let df = ctx.read_batch(batch).unwrap();
+ let sorted = df.sort_by(vec![col("id")]).unwrap();
+ let batches = sorted.collect().await.unwrap();
+ let schema = batches[0].schema();
+ concat_batches(&schema, &batches).unwrap()
+ }
+
+ #[tokio::test]
+ async fn test_shuffle_batch_preserves_data() {
+ let batch = create_test_batch(RowCount::from(100));
+ let mut rng = SmallRng::seed_from_u64(42);
+
+ let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
+
+ assert_ne!(shuffled, batch);
+
+ let sorted = sort_batch(shuffled).await;
+
+ assert_eq!(sorted, batch);
+ }
+
+ #[test]
+ fn test_shuffle_batch_empty() {
+ let batch = create_test_batch(RowCount::from(0));
+ let mut rng = SmallRng::seed_from_u64(42);
+
+ let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
+ assert_eq!(shuffled.num_rows(), 0);
+ }
+
+ #[tokio::test]
+ async fn test_in_memory_shuffle() {
+ let config = ShufflerConfig {
+ temp_dir: TemporaryDirectory::None,
+ ..Default::default()
+ };
+ let shuffler = Shuffler::new(config);
+
+ let stream = create_test_stream(BatchCount::from(5), RowCount::from(20));
+
+ let result_stream = shuffler.shuffle(stream, 100).await.unwrap();
+ let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
+
+ assert_eq!(result_batches.len(), 1);
+ let result_batch = result_batches.into_iter().next().unwrap();
+
+ let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20))
+ .try_collect::<Vec<_>>()
+ .await
+ .unwrap();
+ let schema = unshuffled_batches[0].schema();
+ let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
+
+ let sorted = sort_batch(result_batch).await;
+
+ assert_eq!(unshuffled_batch, sorted);
+ }
+
+ #[tokio::test]
+ async fn test_external_shuffle() {
+ let config = ShufflerConfig {
+ max_rows_per_file: 100,
+ ..Default::default()
+ };
+ let shuffler = Shuffler::new(config);
+
+ let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000));
+
+ let result_stream = shuffler.shuffle(stream, 5000).await.unwrap();
+ let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
+
+ let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000))
+ .try_collect::<Vec<_>>()
+ .await
+ .unwrap();
+ let schema = unshuffled_batches[0].schema();
+ let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
+
+ assert_eq!(result_batches.len(), 50);
+ let result_batch = concat_batches(&schema, &result_batches).unwrap();
+
+ let sorted = sort_batch(result_batch).await;
+
+ assert_eq!(unshuffled_batch, sorted);
+ }
+
+ #[test_log::test(tokio::test)]
+ async fn test_external_clump_shuffle() {
+ let config = ShufflerConfig {
+ max_rows_per_file: 100,
+ clump_size: Some(30),
+ ..Default::default()
+ };
+ let shuffler = Shuffler::new(config);
+
+ // Batch size (900) must be multiple of clump size (30)
+ let stream = create_test_stream(BatchCount::from(5), RowCount::from(900));
+ let schema = stream.schema();
+
+ // Remove 10 rows from the last batch to simulate ending with partial clump
+ let mut batches = stream.try_collect::<Vec<_>>().await.unwrap();
+ let last_index = batches.len() - 1;
+ let sliced_last = batches[last_index].slice(0, 890);
+ batches[last_index] = sliced_last;
+
+ let stream = Box::pin(SimpleRecordBatchStream::new(
+ futures::stream::iter(batches).map(Ok).boxed(),
+ schema.clone(),
+ ));
+
+ let result_stream = shuffler.shuffle(stream, 4490).await.unwrap();
+ let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
+ let result_batch = concat_batches(&schema, &result_batches).unwrap();
+
+ let ids = result_batch.column(0).as_primitive::<Int32Type>();
+ let mut iter = ids.into_iter().map(|o| o.unwrap());
+ while let Some(first) = iter.next() {
+ let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
+ let mut expected_next = first + 1;
+ for _ in 0..rows_left_in_clump {
+ let next = iter.next().unwrap();
+ assert_eq!(next, expected_next);
+ expected_next += 1;
+ }
+ }
+ }
+}
diff --git a/rust/lancedb/src/dataloader/split.rs b/rust/lancedb/src/dataloader/split.rs
new file mode 100644
index 00000000..fd9abfb8
--- /dev/null
+++ b/rust/lancedb/src/dataloader/split.rs
@@ -0,0 +1,804 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::{
+ iter,
+ sync::{
+ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
+ Arc,
+ },
+};
+
+use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion_common::hash_utils::create_hashes;
+use futures::{StreamExt, TryStreamExt};
+use lance::arrow::SchemaExt;
+
+use crate::{
+ arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+ dataloader::{
+ shuffle::{Shuffler, ShufflerConfig},
+ util::TemporaryDirectory,
+ },
+ query::{Query, QueryBase, Select},
+ Error, Result,
+};
+
+pub const SPLIT_ID_COLUMN: &str = "split_id";
+
+/// Strategy for assigning rows to splits
+#[derive(Debug, Clone)]
+pub enum SplitStrategy {
+ /// All rows will have split id 0
+ NoSplit,
+ /// Rows will be randomly assigned to splits
+ ///
+ /// A seed can be provided to make the assignment deterministic.
+ Random {
+ seed: Option,
+ sizes: SplitSizes,
+ },
+ /// Rows will be assigned to splits based on the values in the specified columns.
+ ///
+ /// This will ensure rows are always assigned to the same split if the given columns do not change.
+ ///
+ /// The `split_weights` are used to determine the approximate number of rows in each split. This
+ /// controls how we divide up the u64 hash space. However, it does not guarantee any particular division
+ /// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split
+ /// regardless of the weights.
+ ///
+ /// The `discard_weight` controls what percentage of rows should be thrown away. For example, if you want your
+ /// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set
+ /// split_weights to [1, 2] and discard weight to 17 (or you could set split_weights to [5, 10] and discard_weight
+ /// to 85). If you set discard_weight to 0 then all rows will be assigned to a split.
+ Hash {
+ columns: Vec,
+ split_weights: Vec,
+ discard_weight: u64,
+ },
+ /// Rows will be assigned to splits sequentially.
+ ///
+ /// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc.
+ ///
+ /// This is mainly useful for debugging and testing.
+ Sequential { sizes: SplitSizes },
+ /// Rows will be assigned to splits based on a calculation of one or more columns.
+ ///
+ /// This is useful when the splits already exist in the base table.
+ ///
+ /// The provided `calculation` should be an SQL statement that returns an integer value between
+ /// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration).
+ ///
+ /// If this strategy is used then the counts/percentages in the SplitSizes are ignored.
+ Calculated { calculation: String },
+}
+
+// The default is not to split the data
+//
+// All data will be assigned to a single split.
+impl Default for SplitStrategy {
+ fn default() -> Self {
+ Self::NoSplit
+ }
+}
+
+impl SplitStrategy {
+ pub fn validate(&self, num_rows: u64) -> Result<()> {
+ match self {
+ Self::NoSplit => Ok(()),
+ Self::Random { sizes, .. } => sizes.validate(num_rows),
+ Self::Hash {
+ split_weights,
+ columns,
+ ..
+ } => {
+ if columns.is_empty() {
+ return Err(Error::InvalidInput {
+ message: "Hash strategy requires at least one column".to_string(),
+ });
+ }
+ if split_weights.is_empty() {
+ return Err(Error::InvalidInput {
+ message: "Hash strategy requires at least one split weight".to_string(),
+ });
+ }
+ if split_weights.iter().any(|w| *w == 0) {
+ return Err(Error::InvalidInput {
+ message: "Split weights must be greater than 0".to_string(),
+ });
+ }
+ Ok(())
+ }
+ Self::Sequential { sizes } => sizes.validate(num_rows),
+ Self::Calculated { .. } => Ok(()),
+ }
+ }
+}
+
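+/// Assigns a split id to every row of a stream according to a `SplitStrategy`.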
+pub struct Splitter {
+ temp_dir: TemporaryDirectory,
+ strategy: SplitStrategy,
+}
+
+impl Splitter {
+ pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self {
+ Self { temp_dir, strategy }
+ }
+
+ fn sequential_split_id(
+ num_rows: u64,
+ split_sizes: &[u64],
+ split_index: &AtomicUsize,
+ counter_in_split: &AtomicU64,
+ exhausted: &AtomicBool,
+ ) -> UInt64Array {
+ let mut split_ids = Vec::<u64>::with_capacity(num_rows as usize);
+
+ while split_ids.len() < num_rows as usize {
+ let split_id = split_index.load(Ordering::Relaxed);
+ let counter = counter_in_split.load(Ordering::Relaxed);
+
+ let split_size = split_sizes[split_id];
+ let remaining_in_split = split_size - counter;
+
+ let remaining_in_batch = num_rows - split_ids.len() as u64;
+
+ let mut done = false;
+ let rows_to_add = if remaining_in_batch < remaining_in_split {
+ counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed);
+ remaining_in_batch
+ } else {
+ split_index.fetch_add(1, Ordering::Relaxed);
+ counter_in_split.store(0, Ordering::Relaxed);
+ if split_id == split_sizes.len() - 1 {
+ exhausted.store(true, Ordering::Relaxed);
+ done = true;
+ }
+ remaining_in_split
+ };
+
+ split_ids.extend(iter::repeat_n(split_id as u64, rows_to_add as usize));
+ if done {
+ // Quit early if we've run out of splits
+ break;
+ }
+ }
+
+ UInt64Array::from(split_ids)
+ }
+
+ async fn apply_sequential(
+ &self,
+ source: SendableRecordBatchStream,
+ num_rows: u64,
+ sizes: &SplitSizes,
+ ) -> Result<SendableRecordBatchStream> {
+ let split_sizes = sizes.to_counts(num_rows);
+
+ let split_index = AtomicUsize::new(0);
+ let counter_in_split = AtomicU64::new(0);
+ let exhausted = AtomicBool::new(false);
+
+ let schema = source.schema();
+
+ let new_schema = Arc::new(schema.try_with_column(Field::new(
+ SPLIT_ID_COLUMN,
+ DataType::UInt64,
+ false,
+ ))?);
+
+ let new_schema_clone = new_schema.clone();
+ let stream = source.filter_map(move |batch| {
+ let batch = match batch {
+ Ok(batch) => batch,
+ Err(e) => {
+ return std::future::ready(Some(Err(e)));
+ }
+ };
+
+ if exhausted.load(Ordering::Relaxed) {
+ return std::future::ready(None);
+ }
+
+ let split_ids = Self::sequential_split_id(
+ batch.num_rows() as u64,
+ &split_sizes,
+ &split_index,
+ &counter_in_split,
+ &exhausted,
+ );
+
+ let mut arrays = batch.columns().to_vec();
+ // This can happen if we exhaust all splits in the middle of a batch
+ if split_ids.len() < batch.num_rows() {
+ arrays = arrays
+ .iter()
+ .map(|arr| arr.slice(0, split_ids.len()))
+ .collect();
+ }
+ arrays.push(Arc::new(split_ids));
+
+ std::future::ready(Some(Ok(
+ RecordBatch::try_new(new_schema.clone(), arrays).unwrap()
+ )))
+ });
+
+ Ok(Box::pin(SimpleRecordBatchStream::new(
+ stream,
+ new_schema_clone,
+ )))
+ }
+
+ fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array {
+ let arrays = batch
+ .columns()
+ .iter()
+ // Don't hash the last column which should always be the row id
+ .take(batch.columns().len() - 1)
+ .cloned()
+ .collect::<Vec<_>>();
+ let mut hashes = vec![0; batch.num_rows()];
+ let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
+ create_hashes(&arrays, &random_state, &mut hashes).unwrap();
+ // As an example, let's assume the weights are 1, 2. Our total weight is 3.
+ //
+ // Our thresholds are [1, 3]
+ // Our modulo output will be 0, 1, or 2.
+ //
+ // thresholds.binary_search(0) => Err(0) => 0
+ // thresholds.binary_search(1) => Ok(0) => 1
+ // thresholds.binary_search(2) => Err(1) => 1
+ let split_ids = hashes
+ .iter()
+ .map(|h| {
+ let h = h % total_weight;
+ let split_id = match thresholds.binary_search(&h) {
+ Ok(i) => (i + 1) as u64,
+ Err(i) => i as u64,
+ };
+ if split_id == thresholds.len() as u64 {
+ // If we're at the last threshold then we discard the row (indicated by setting
+ // the split_id to null)
+ None
+ } else {
+ Some(split_id)
+ }
+ })
+ .collect::<Vec<_>>();
+ UInt64Array::from(split_ids)
+ }
+
+ async fn apply_hash(
+ &self,
+ source: SendableRecordBatchStream,
+ weights: &[u64],
+ discard_weight: u64,
+ ) -> Result<SendableRecordBatchStream> {
+ let row_id_index = source.schema().fields.len() - 1;
+ let new_schema = Arc::new(Schema::new(vec![
+ source.schema().field(row_id_index).clone(),
+ Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
+ ]));
+
+ let total_weight = weights.iter().sum::<u64>() + discard_weight;
+ // Thresholds are the cumulative sum of the weights
+ let mut offset = 0;
+ let thresholds = weights
+ .iter()
+ .map(|w| {
+ let value = offset + w;
+ offset = value;
+ value
+ })
+ .collect::<Vec<_>>();
+
+ let new_schema_clone = new_schema.clone();
+ let stream = source.map_ok(move |batch| {
+ let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
+
+ if split_ids.null_count() > 0 {
+ let is_valid = split_ids.nulls().unwrap().inner();
+ let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
+ let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
+ let row_ids = batch.column(row_id_index);
+ let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
+ RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
+ } else {
+ RecordBatch::try_new(
+ new_schema.clone(),
+ vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
+ )
+ .unwrap()
+ }
+ });
+
+ Ok(Box::pin(SimpleRecordBatchStream::new(
+ stream,
+ new_schema_clone,
+ )))
+ }
+
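+ /// Assign a split id to every row in `source` according to the configured strategy.
+ ///
+ /// `num_rows` must be the exact number of rows in `source`; it is used to validate the
+ /// split configuration and, for the random strategy, to size the external shuffle.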
+ pub async fn apply(
+ &self,
+ source: SendableRecordBatchStream,
+ num_rows: u64,
+ ) -> Result<SendableRecordBatchStream> {
+ self.strategy.validate(num_rows)?;
+
+ match &self.strategy {
+ // For consistency, even if no-split, we still give a split id column of all 0s
+ SplitStrategy::NoSplit => {
+ self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
+ .await
+ }
+ SplitStrategy::Random { seed, sizes } => {
+ let shuffler = Shuffler::new(ShufflerConfig {
+ seed: *seed,
+ // In this case we are only shuffling row ids so we can use a large max_rows_per_file
+ max_rows_per_file: 10 * 1024 * 1024,
+ temp_dir: self.temp_dir.clone(),
+ clump_size: None,
+ });
+
+ let shuffled = shuffler.shuffle(source, num_rows).await?;
+
+ self.apply_sequential(shuffled, num_rows, sizes).await
+ }
+ SplitStrategy::Sequential { sizes } => {
+ self.apply_sequential(source, num_rows, sizes).await
+ }
+ // Nothing to do, split is calculated in projection
+ SplitStrategy::Calculated { .. } => Ok(source),
+ SplitStrategy::Hash {
+ split_weights,
+ discard_weight,
+ ..
+ } => {
+ self.apply_hash(source, split_weights, *discard_weight)
+ .await
+ }
+ }
+ }
+
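+ /// Configure the query projection needed by the split strategy.
+ ///
+ /// The hash strategy selects its hash columns and the calculated strategy selects the
+ /// expression that produces the split id; other strategies leave the query unchanged.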
+ pub fn project(&self, query: Query) -> Query {
+ match &self.strategy {
+ SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
+ SPLIT_ID_COLUMN.to_string(),
+ calculation.clone(),
+ )])),
+ SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
+ _ => query,
+ }
+ }
+
+ pub fn orders_by_split_id(&self) -> bool {
+ match &self.strategy {
+ SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => true,
+ SplitStrategy::NoSplit
+ | SplitStrategy::Sequential { .. }
+ // It may be strange but for random we shuffle and then assign splits so the result is
+ // sorted by split id
+ | SplitStrategy::Random { .. } => false,
+ }
+ }
+}
+
+/// Split configuration - either percentages or absolute counts
+///
+/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
+/// the remaining rows will not be included in the permutation.
+///
+/// The default implementation assigns all rows to a single split.
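+///
+/// A small sketch of how each variant maps to per-split row counts (numbers are illustrative):
+///
+/// ```no_run
+/// # use lancedb::dataloader::split::SplitSizes;
+/// // 80/20 split of 1000 rows -> [800, 200]
+/// assert_eq!(SplitSizes::Percentages(vec![0.8, 0.2]).to_counts(1000), vec![800, 200]);
+/// // Explicit per-split row counts are used as-is
+/// assert_eq!(SplitSizes::Counts(vec![800, 200]).to_counts(1000), vec![800, 200]);
+/// // Four equal splits of 1000 rows -> 250 rows each
+/// assert_eq!(SplitSizes::Fixed(4).to_counts(1000), vec![250; 4]);
+/// ```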
+#[derive(Debug, Clone)]
+pub enum SplitSizes {
+ /// Percentage splits (must sum to <= 1.0)
+ ///
+ /// The number of rows in each split is the nearest integer to the percentage multiplied by
+ /// the total number of rows.
+ Percentages(Vec<f64>),
+ /// Absolute row counts per split
+ ///
+ /// If the dataset doesn't contain enough matching rows to fill all splits then an error
+ /// will be raised.
+ Counts(Vec<u64>),
+ /// Divides data into a fixed number of splits
+ ///
+ /// Will divide the data evenly.
+ ///
+ /// If the number of rows is not divisible by the number of splits then the rows per split
+ /// is rounded down.
+ Fixed(u64),
+}
+
+impl Default for SplitSizes {
+ fn default() -> Self {
+ Self::Percentages(vec![1.0])
+ }
+}
+
+impl SplitSizes {
+ pub fn validate(&self, num_rows: u64) -> Result<()> {
+ match self {
+ Self::Percentages(percentages) => {
+ for percentage in percentages {
+ if *percentage < 0.0 || *percentage > 1.0 {
+ return Err(Error::InvalidInput {
+ message: "Split percentages must be between 0.0 and 1.0".to_string(),
+ });
+ }
+ if percentage * (num_rows as f64) < 1.0 {
+ return Err(Error::InvalidInput {
+ message: format!(
+ "One of the splits has {}% of {} rows which rounds to 0 rows",
+ percentage * 100.0,
+ num_rows
+ ),
+ });
+ }
+ }
+ if percentages.iter().sum::<f64>() > 1.0 {
+ return Err(Error::InvalidInput {
+ message: "Split percentages must sum to 1.0 or less".to_string(),
+ });
+ }
+ }
+ Self::Counts(counts) => {
+ if counts.iter().sum::<u64>() > num_rows {
+ return Err(Error::InvalidInput {
+ message: format!(
+ "Split counts specified {} rows but only {} are available",
+ counts.iter().sum::<u64>(),
+ num_rows
+ ),
+ });
+ }
+ if counts.iter().any(|c| *c == 0) {
+ return Err(Error::InvalidInput {
+ message: "Split counts must be greater than 0".to_string(),
+ });
+ }
+ }
+ Self::Fixed(num_splits) => {
+ if *num_splits > num_rows {
+ return Err(Error::InvalidInput {
+ message: format!(
+ "Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
+ *num_splits, num_rows
+ ),
+ });
+ }
+ if (num_rows / num_splits) == 0 {
+ return Err(Error::InvalidInput {
+ message: format!(
+ "Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
+ *num_splits, num_rows
+ ),
+ });
+ }
+ }
+ }
+ Ok(())
+ }
+
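+ /// Convert the configured sizes into absolute row counts for `num_rows` rows.
+ ///
+ /// # Example (illustrative; values taken from the tests below)
+ ///
+ /// ```ignore
+ /// assert_eq!(SplitSizes::Percentages(vec![0.3, 0.7]).to_counts(100), vec![30, 70]);
+ /// assert_eq!(SplitSizes::Fixed(3).to_counts(10), vec![3, 3, 3]);
+ /// ```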
+ pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
+ let sizes = match self {
+ Self::Percentages(percentages) => {
+ let mut percentage_sum = 0.0_f64;
+ let mut counts = percentages
+ .iter()
+ .map(|p| {
+ let count = (p * (num_rows as f64)).round() as u64;
+ percentage_sum += p;
+ count
+ })
+ .collect::<Vec<_>>();
+ let sum = counts.iter().sum::<u64>();
+
+ let is_basically_one =
+ (num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
+
+ // If the sum of percentages is close to 1.0 then rounding errors can add up
+ // to more or less than num_rows
+ //
+ // Drop items from buckets until we have the correct number of rows
+ let mut excess = sum as i64 - num_rows as i64;
+ let mut drop_idx = 0;
+ while excess > 0 {
+ if counts[drop_idx] > 0 {
+ counts[drop_idx] -= 1;
+ excess -= 1;
+ }
+ drop_idx += 1;
+ if drop_idx == counts.len() {
+ drop_idx = 0;
+ }
+ }
+
+ // On the other hand, if the percentages sum to ~1.0 then we also shouldn't _lose_
+ // rows due to rounding errors
+ let mut add_idx = 0;
+ while is_basically_one && excess < 0 {
+ counts[add_idx] += 1;
+ add_idx += 1;
+ excess += 1;
+ if add_idx == counts.len() {
+ add_idx = 0;
+ }
+ }
+
+ counts
+ }
+ Self::Counts(counts) => counts.clone(),
+ Self::Fixed(num_splits) => {
+ let rows_per_split = num_rows / *num_splits;
+ vec![rows_per_split; *num_splits as usize]
+ }
+ };
+
+ assert!(sizes.iter().sum::<u64>() <= num_rows);
+
+ sizes
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::arrow::LanceDbDatagenExt;
+
+ use super::*;
+ use arrow::{
+ array::AsArray,
+ compute::concat_batches,
+ datatypes::{Int32Type, UInt64Type},
+ };
+ use arrow_array::Int32Array;
+ use futures::TryStreamExt;
+ use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
+ use std::{iter, sync::Arc};
+
+ const ID_COLUMN: &str = "id";
+
+ #[test]
+ fn test_split_sizes_percentages_validation() {
+ // Valid percentages
+ let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
+ assert!(sizes.validate(100).is_ok());
+
+ // Sum > 1.0
+ let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
+ assert!(sizes.validate(100).is_err());
+
+ // Negative percentage
+ let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
+ assert!(sizes.validate(100).is_err());
+
+ // Percentage > 1.0
+ let sizes = SplitSizes::Percentages(vec![1.5]);
+ assert!(sizes.validate(100).is_err());
+
+ // Percentage rounds to 0 rows
+ let sizes = SplitSizes::Percentages(vec![0.001]);
+ assert!(sizes.validate(100).is_err());
+ }
+
+ #[test]
+ fn test_split_sizes_counts_validation() {
+ // Valid counts
+ let sizes = SplitSizes::Counts(vec![30, 70]);
+ assert!(sizes.validate(100).is_ok());
+
+ // Sum > num_rows
+ let sizes = SplitSizes::Counts(vec![60, 50]);
+ assert!(sizes.validate(100).is_err());
+
+ // Counts are 0
+ let sizes = SplitSizes::Counts(vec![0, 100]);
+ assert!(sizes.validate(100).is_err());
+ }
+
+ #[test]
+ fn test_split_sizes_fixed_validation() {
+ // Valid fixed splits
+ let sizes = SplitSizes::Fixed(5);
+ assert!(sizes.validate(100).is_ok());
+
+ // More splits than rows
+ let sizes = SplitSizes::Fixed(150);
+ assert!(sizes.validate(100).is_err());
+ }
+
+ #[test]
+ fn test_split_sizes_to_sizes_percentages() {
+ let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
+ let result = sizes.to_counts(100);
+ assert_eq!(result, vec![30, 70]);
+
+ // Test rounding
+ let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
+ let result = sizes.to_counts(70);
+ assert_eq!(result, vec![21, 29]);
+ }
+
+ #[test]
+ fn test_split_sizes_to_sizes_fixed() {
+ let sizes = SplitSizes::Fixed(4);
+ let result = sizes.to_counts(100);
+ assert_eq!(result, vec![25, 25, 25, 25]);
+
+ // Test with remainder
+ let sizes = SplitSizes::Fixed(3);
+ let result = sizes.to_counts(10);
+ assert_eq!(result, vec![3, 3, 3]);
+ }
+
+ fn test_data() -> SendableRecordBatchStream {
+ lance_datagen::gen_batch()
+ .with_seed(Seed::from(42))
+ .col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
+ .into_ldb_stream(RowCount::from(10), BatchCount::from(5))
+ }
+
+ async fn verify_splitter(
+ splitter: Splitter,
+ data: SendableRecordBatchStream,
+ num_rows: u64,
+ expected_split_sizes: &[u64],
+ row_ids_in_order: bool,
+ ) {
+ let split_batches = splitter
+ .apply(data, num_rows)
+ .await
+ .unwrap()
+ .try_collect::<Vec<_>>()
+ .await
+ .unwrap();
+
+ let schema = split_batches[0].schema();
+ let split_batch = concat_batches(&schema, &split_batches).unwrap();
+
+ let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
+
+ assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
+ let mut expected = Vec::with_capacity(total_split_sizes as usize);
+ for (i, size) in expected_split_sizes.iter().enumerate() {
+ expected.extend(iter::repeat_n(i as u64, *size as usize));
+ }
+ let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn arrow_array::Array>;
+
+ assert_eq!(&expected, split_batch.column(1));
+
+ let expected_row_ids =
+ Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn arrow_array::Array>;
+ if row_ids_in_order {
+ assert_eq!(&expected_row_ids, split_batch.column(0));
+ } else {
+ assert_ne!(&expected_row_ids, split_batch.column(0));
+ }
+ }
+
+ #[tokio::test]
+ async fn test_fixed_sequential_split() {
+ let splitter = Splitter::new(
+ // Sequential splitting doesn't need a temp dir
+ TemporaryDirectory::None,
+ SplitStrategy::Sequential {
+ sizes: SplitSizes::Fixed(3),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
+ }
+
+ #[tokio::test]
+ async fn test_fixed_random_split() {
+ let splitter = Splitter::new(
+ TemporaryDirectory::None,
+ SplitStrategy::Random {
+ seed: Some(42),
+ sizes: SplitSizes::Fixed(3),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
+ }
+
+ #[tokio::test]
+ async fn test_counts_sequential_split() {
+ let splitter = Splitter::new(
+ // Sequential splitting doesn't need a temp dir
+ TemporaryDirectory::None,
+ SplitStrategy::Sequential {
+ sizes: SplitSizes::Counts(vec![5, 15, 10]),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
+ }
+
+ #[tokio::test]
+ async fn test_counts_random_split() {
+ let splitter = Splitter::new(
+ TemporaryDirectory::None,
+ SplitStrategy::Random {
+ seed: Some(42),
+ sizes: SplitSizes::Counts(vec![5, 15, 10]),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
+ }
+
+ #[tokio::test]
+ async fn test_percentages_sequential_split() {
+ let splitter = Splitter::new(
+ // Sequential splitting doesn't need a temp dir
+ TemporaryDirectory::None,
+ SplitStrategy::Sequential {
+ sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
+ }
+
+ #[tokio::test]
+ async fn test_percentages_random_split() {
+ let splitter = Splitter::new(
+ TemporaryDirectory::None,
+ SplitStrategy::Random {
+ seed: Some(42),
+ sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
+ },
+ );
+
+ verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
+ }
+
+ #[tokio::test]
+ async fn test_hash_split() {
+ let data = lance_datagen::gen_batch()
+ .with_seed(Seed::from(42))
+ .col(
+ "hash1",
+ lance_datagen::array::rand_utf8(ByteCount::from(10), false),
+ )
+ .col("hash2", lance_datagen::array::step::())
+ .col(ID_COLUMN, lance_datagen::array::step::())
+ .into_ldb_stream(RowCount::from(10), BatchCount::from(5));
+
+ let splitter = Splitter::new(
+ TemporaryDirectory::None,
+ SplitStrategy::Hash {
+ columns: vec!["hash1".to_string(), "hash2".to_string()],
+ split_weights: vec![1, 2],
+ discard_weight: 1,
+ },
+ );
+
+ let split_batches = splitter
+ .apply(data, 10)
+ .await
+ .unwrap()
+ .try_collect::>()
+ .await
+ .unwrap();
+
+ let schema = split_batches[0].schema();
+ let split_batch = concat_batches(&schema, &split_batches).unwrap();
+
+ // These assertions are all based on the fixed seed used in data generation, but they
+ // roughly match what we expect (25% discarded, 25% in split 0, 50% in split 1)
+
+ // 14 rows (28%) are discarded because discard_weight is 1
+ assert_eq!(split_batch.num_rows(), 36);
+ assert_eq!(split_batch.num_columns(), 2);
+
+ let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
+ let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
+ let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
+
+ assert_eq!(num_in_split_0, 11); // 22%
+ assert_eq!(num_in_split_1, 25); // 50%
+ }
+}
diff --git a/rust/lancedb/src/dataloader/util.rs b/rust/lancedb/src/dataloader/util.rs
new file mode 100644
index 00000000..9e457112
--- /dev/null
+++ b/rust/lancedb/src/dataloader/util.rs
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::{path::PathBuf, sync::Arc};
+
+use arrow_array::RecordBatch;
+use arrow_schema::{Fields, Schema};
+use datafusion_execution::disk_manager::DiskManagerMode;
+use futures::TryStreamExt;
+use rand::{rngs::SmallRng, RngCore, SeedableRng};
+use tempfile::TempDir;
+
+use crate::{
+ arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+ Error, Result,
+};
+
+/// Directory to use for temporary files
+#[derive(Debug, Clone, Default)]
+pub enum TemporaryDirectory {
+ /// Use the operating system's default temporary directory (e.g. /tmp)
+ #[default]
+ OsDefault,
+ /// Use the specified directory (must be an absolute path)
+ Specific(PathBuf),
+ /// If spilling is required, then error out
+ None,
+}
+
+impl TemporaryDirectory {
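+ /// Create a new temporary directory according to this configuration.
+ ///
+ /// A minimal sketch of the expected usage (illustrative; the path is hypothetical):
+ ///
+ /// ```ignore
+ /// let tmp = TemporaryDirectory::Specific(PathBuf::from("/abs/scratch")).create_temp_dir()?;
+ /// // `TemporaryDirectory::None` instead returns an error if spilling would be required.
+ /// ```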
+ pub fn create_temp_dir(&self) -> Result<TempDir> {
+ match self {
+ Self::OsDefault => tempfile::tempdir(),
+ Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
+ Self::None => {
+ return Err(Error::Runtime {
+ message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
+ });
+ }
+ }
+ .map_err(|err| Error::Other {
+ message: "Failed to create temporary directory".to_string(),
+ source: Some(err.into()),
+ })
+ }
+
+ pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
+ match self {
+ Self::OsDefault => DiskManagerMode::OsTmpDirectory,
+ Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
+ Self::None => DiskManagerMode::Disabled,
+ }
+ }
+}
+
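+ /// Build a fast, non-cryptographic RNG; a fixed seed gives reproducible output,
+ /// otherwise the RNG is seeded from the operating system.
+ ///
+ /// ```ignore
+ /// // Illustrative: deterministic for a given seed
+ /// let mut rng = non_crypto_rng(&Some(42));
+ /// let value = rng.next_u64();
+ /// ```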
+pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore> {
+ Box::new(
+ seed.as_ref()
+ .map(|seed| SmallRng::seed_from_u64(*seed))
+ .unwrap_or_else(SmallRng::from_os_rng),
+ )
+}
+
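+/// Rename a single column in `stream` from `old_name` to `new_name` without copying data.
+///
+/// Illustrative sketch (the column names here are hypothetical):
+///
+/// ```ignore
+/// let renamed = rename_column(stream, "old", "new")?;
+/// ```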
+pub fn rename_column(
+ stream: SendableRecordBatchStream,
+ old_name: &str,
+ new_name: &str,
+) -> Result<SendableRecordBatchStream> {
+ let schema = stream.schema();
+ let field_index = schema.index_of(old_name)?;
+
+ let new_fields = schema
+ .fields
+ .iter()
+ .cloned()
+ .enumerate()
+ .map(|(idx, f)| {
+ if idx == field_index {
+ Arc::new(f.as_ref().clone().with_name(new_name))
+ } else {
+ f
+ }
+ })
+ .collect::<Fields>();
+ let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
+ let new_schema_clone = new_schema.clone();
+
+ let renamed_stream = stream.and_then(move |batch| {
+ let renamed_batch =
+ RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
+ std::future::ready(renamed_batch)
+ });
+
+ Ok(Box::pin(SimpleRecordBatchStream::new(
+ renamed_stream,
+ new_schema_clone,
+ )))
+}
diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 9637cf39..75159e04 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -194,6 +194,7 @@ pub mod arrow;
pub mod connection;
pub mod data;
pub mod database;
+pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod index;
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index e69fd54e..2f56a592 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
- OpenTableRequest, TableNamesRequest,
+ OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::error::Result;
use crate::table::BaseTable;
@@ -189,6 +189,7 @@ struct ListTablesResponse {
pub struct RemoteDatabase {
client: RestfulLanceDbClient,
table_cache: Cache>>,
+ uri: String,
}
impl RemoteDatabase {
@@ -217,6 +218,7 @@ impl RemoteDatabase {
Ok(Self {
client,
table_cache,
+ uri: uri.to_owned(),
})
}
}
@@ -238,6 +240,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
+ uri: "http://localhost".to_string(),
}
}
@@ -250,6 +253,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
+ uri: "http://localhost".to_string(),
}
}
}
@@ -315,6 +319,17 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
#[async_trait]
impl Database for RemoteDatabase {
+ fn uri(&self) -> &str {
+ &self.uri
+ }
+
+ async fn read_consistency(&self) -> Result {
+ Err(Error::NotSupported {
+ message: "Getting the read consistency of a remote database is not yet supported"
+ .to_string(),
+ })
+ }
+
async fn table_names(&self, request: TableNamesRequest) -> Result> {
let mut req = if !request.namespace.is_empty() {
let namespace_id =
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index 60c601f1..d632644c 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -50,6 +50,7 @@ use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
+use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
@@ -611,9 +612,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// A Table is a collection of strongly typed Rows.
///
/// The type of each row is defined in Apache Arrow [Schema].
-#[derive(Clone)]
+#[derive(Clone, Debug)]
pub struct Table {
inner: Arc,
+ database: Arc,
embedding_registry: Arc,
}
@@ -631,11 +633,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
- handler,
+ handler.clone(),
None,
));
+ let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
+ database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -651,11 +655,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
- handler,
+ handler.clone(),
Some(version),
));
+ let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
+ database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -670,9 +676,10 @@ impl std::fmt::Display for Table {
}
impl Table {
- pub fn new(inner: Arc) -> Self {
+ pub fn new(inner: Arc, database: Arc) -> Self {
Self {
inner,
+ database,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -681,12 +688,22 @@ impl Table {
&self.inner
}
+ pub fn database(&self) -> &Arc {
+ &self.database
+ }
+
+ pub fn embedding_registry(&self) -> &Arc {
+ &self.embedding_registry
+ }
+
pub(crate) fn new_with_embedding_registry(
inner: Arc,
+ database: Arc,
embedding_registry: Arc,
) -> Self {
Self {
inner,
+ database,
embedding_registry,
}
}
@@ -1416,12 +1433,6 @@ impl Tags for NativeTags {
}
}
-impl From for Table {
- fn from(table: NativeTable) -> Self {
- Self::new(Arc::new(table))
- }
-}
-
pub trait NativeTableExt {
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;