diff --git a/Cargo.lock b/Cargo.lock index b3fc570b..23182c70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4702,6 +4702,7 @@ dependencies = [ name = "lancedb" version = "0.22.2" dependencies = [ + "ahash", "anyhow", "arrow", "arrow-array", @@ -4737,8 +4738,11 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "lance", + "lance-core", "lance-datafusion", + "lance-datagen", "lance-encoding", + "lance-file", "lance-index", "lance-io", "lance-linalg", @@ -4764,6 +4768,7 @@ dependencies = [ "serde_with", "snafu", "tempfile", + "test-log", "tokenizers", "tokio", "url", @@ -8237,6 +8242,28 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 825df5dc..2b1afe8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,9 @@ rust-version = "1.78.0" [workspace.dependencies] lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"] } +lance-core = "=0.38.2" +lance-datagen = "=0.38.2" +lance-file = "=0.38.2" lance-io = { "version" = "=0.38.2", default-features = false } lance-index = "=0.38.2" lance-linalg = "=0.38.2" @@ -24,6 +27,7 @@ lance-testing = "=0.38.2" lance-datafusion = "=0.38.2" lance-encoding = "=0.38.2" lance-namespace = "0.0.18" +ahash = "0.8" # Note that this one does not include pyarrow arrow = { version = "56.2", optional = false } arrow-array = "56.2" @@ -48,6 +52,7 @@ log = "0.4" moka = { version = "0.12", features = ["future"] } object_store = "0.12.0" pin-project = "1.0.7" +rand = "0.9" snafu = "0.8" url = "2" num-traits = "0.2" diff --git a/docs/src/js/classes/PermutationBuilder.md b/docs/src/js/classes/PermutationBuilder.md new file mode 100644 index 00000000..aa2437c9 --- /dev/null +++ b/docs/src/js/classes/PermutationBuilder.md @@ -0,0 +1,220 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / PermutationBuilder + +# Class: PermutationBuilder + +A PermutationBuilder for creating data permutations with splits, shuffling, and filtering. + +This class provides a TypeScript wrapper around the native Rust PermutationBuilder, +offering methods to configure data splits, shuffling, and filtering before executing +the permutation to create a new table. + +## Methods + +### execute() + +```ts +execute(): Promise +``` + +Execute the permutation and create the destination table. + +#### Returns + +`Promise`<[`Table`](Table.md)> + +A Promise that resolves to the new Table instance + +#### Example + +```ts +const permutationTable = await builder.execute(); +console.log(`Created table: ${permutationTable.name}`); +``` + +*** + +### filter() + +```ts +filter(filter): PermutationBuilder +``` + +Configure filtering for the permutation. 
+ +#### Parameters + +* **filter**: `string` + SQL filter expression + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +builder.filter("age > 18 AND status = 'active'"); +``` + +*** + +### shuffle() + +```ts +shuffle(options): PermutationBuilder +``` + +Configure shuffling for the permutation. + +#### Parameters + +* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md) + Configuration for shuffling + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +// Basic shuffle +builder.shuffle({ seed: 42 }); + +// Shuffle with clump size +builder.shuffle({ seed: 42, clumpSize: 10 }); +``` + +*** + +### splitCalculated() + +```ts +splitCalculated(calculation): PermutationBuilder +``` + +Configure calculated splits for the permutation. + +#### Parameters + +* **calculation**: `string` + SQL expression for calculating splits + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +builder.splitCalculated("user_id % 3"); +``` + +*** + +### splitHash() + +```ts +splitHash(options): PermutationBuilder +``` + +Configure hash-based splits for the permutation. + +#### Parameters + +* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md) + Configuration for hash-based splitting + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +builder.splitHash({ + columns: ["user_id"], + splitWeights: [70, 30], + discardWeight: 0 +}); +``` + +*** + +### splitRandom() + +```ts +splitRandom(options): PermutationBuilder +``` + +Configure random splits for the permutation. + +#### Parameters + +* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md) + Configuration for random splitting + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +// Split by ratios +builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 }); + +// Split by counts +builder.splitRandom({ counts: [1000, 500], seed: 42 }); + +// Split with fixed size +builder.splitRandom({ fixed: 100, seed: 42 }); +``` + +*** + +### splitSequential() + +```ts +splitSequential(options): PermutationBuilder +``` + +Configure sequential splits for the permutation. + +#### Parameters + +* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md) + Configuration for sequential splitting + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +// Split by ratios +builder.splitSequential({ ratios: [0.8, 0.2] }); + +// Split by counts +builder.splitSequential({ counts: [800, 200] }); + +// Split with fixed size +builder.splitSequential({ fixed: 1000 }); +``` diff --git a/docs/src/js/functions/permutationBuilder.md b/docs/src/js/functions/permutationBuilder.md new file mode 100644 index 00000000..63226d66 --- /dev/null +++ b/docs/src/js/functions/permutationBuilder.md @@ -0,0 +1,37 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / permutationBuilder + +# Function: permutationBuilder() + +```ts +function permutationBuilder(table, destTableName): PermutationBuilder +``` + +Create a permutation builder for the given table. 
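+The builder is lazy: each configuration method only records its settings, and nothing is written until `execute()` creates the destination table.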
+ +## Parameters + +* **table**: [`Table`](../classes/Table.md) + The source table to create a permutation from + +* **destTableName**: `string` + The name for the destination permutation table + +## Returns + +[`PermutationBuilder`](../classes/PermutationBuilder.md) + +A PermutationBuilder instance + +## Example + +```ts +const builder = permutationBuilder(sourceTable, "training_data") + .splitRandom({ ratios: [0.8, 0.2], seed: 42 }) + .shuffle({ seed: 123 }); + +const trainingTable = await builder.execute(); +``` diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 757e47e9..462e6a99 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -28,6 +28,7 @@ - [MultiMatchQuery](classes/MultiMatchQuery.md) - [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md) - [OAuthHeaderProvider](classes/OAuthHeaderProvider.md) +- [PermutationBuilder](classes/PermutationBuilder.md) - [PhraseQuery](classes/PhraseQuery.md) - [Query](classes/Query.md) - [QueryBase](classes/QueryBase.md) @@ -76,6 +77,10 @@ - [QueryExecutionOptions](interfaces/QueryExecutionOptions.md) - [RemovalStats](interfaces/RemovalStats.md) - [RetryConfig](interfaces/RetryConfig.md) +- [ShuffleOptions](interfaces/ShuffleOptions.md) +- [SplitHashOptions](interfaces/SplitHashOptions.md) +- [SplitRandomOptions](interfaces/SplitRandomOptions.md) +- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md) - [TableNamesOptions](interfaces/TableNamesOptions.md) - [TableStatistics](interfaces/TableStatistics.md) - [TimeoutConfig](interfaces/TimeoutConfig.md) @@ -103,3 +108,4 @@ - [connect](functions/connect.md) - [makeArrowTable](functions/makeArrowTable.md) - [packBits](functions/packBits.md) +- [permutationBuilder](functions/permutationBuilder.md) diff --git a/docs/src/js/interfaces/ShuffleOptions.md b/docs/src/js/interfaces/ShuffleOptions.md new file mode 100644 index 00000000..02d298c7 --- /dev/null +++ b/docs/src/js/interfaces/ShuffleOptions.md @@ -0,0 +1,23 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / ShuffleOptions + +# Interface: ShuffleOptions + +## Properties + +### clumpSize? + +```ts +optional clumpSize: number; +``` + +*** + +### seed? + +```ts +optional seed: number; +``` diff --git a/docs/src/js/interfaces/SplitHashOptions.md b/docs/src/js/interfaces/SplitHashOptions.md new file mode 100644 index 00000000..53cbae8e --- /dev/null +++ b/docs/src/js/interfaces/SplitHashOptions.md @@ -0,0 +1,31 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / SplitHashOptions + +# Interface: SplitHashOptions + +## Properties + +### columns + +```ts +columns: string[]; +``` + +*** + +### discardWeight? + +```ts +optional discardWeight: number; +``` + +*** + +### splitWeights + +```ts +splitWeights: number[]; +``` diff --git a/docs/src/js/interfaces/SplitRandomOptions.md b/docs/src/js/interfaces/SplitRandomOptions.md new file mode 100644 index 00000000..66430b6c --- /dev/null +++ b/docs/src/js/interfaces/SplitRandomOptions.md @@ -0,0 +1,39 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / SplitRandomOptions + +# Interface: SplitRandomOptions + +## Properties + +### counts? + +```ts +optional counts: number[]; +``` + +*** + +### fixed? + +```ts +optional fixed: number; +``` + +*** + +### ratios? + +```ts +optional ratios: number[]; +``` + +*** + +### seed? 
+ +```ts +optional seed: number; +``` diff --git a/docs/src/js/interfaces/SplitSequentialOptions.md b/docs/src/js/interfaces/SplitSequentialOptions.md new file mode 100644 index 00000000..6397c191 --- /dev/null +++ b/docs/src/js/interfaces/SplitSequentialOptions.md @@ -0,0 +1,31 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / SplitSequentialOptions + +# Interface: SplitSequentialOptions + +## Properties + +### counts? + +```ts +optional counts: number[]; +``` + +*** + +### fixed? + +```ts +optional fixed: number; +``` + +*** + +### ratios? + +```ts +optional ratios: number[]; +``` diff --git a/nodejs/__test__/permutation.test.ts b/nodejs/__test__/permutation.test.ts new file mode 100644 index 00000000..be57e5d9 --- /dev/null +++ b/nodejs/__test__/permutation.test.ts @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import * as tmp from "tmp"; +import { Table, connect, permutationBuilder } from "../lancedb"; +import { makeArrowTable } from "../lancedb/arrow"; + +describe("PermutationBuilder", () => { + let tmpDir: tmp.DirResult; + let table: Table; + + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const db = await connect(tmpDir.name); + + // Create test data + const data = makeArrowTable( + [ + { id: 1, value: 10 }, + { id: 2, value: 20 }, + { id: 3, value: 30 }, + { id: 4, value: 40 }, + { id: 5, value: 50 }, + { id: 6, value: 60 }, + { id: 7, value: 70 }, + { id: 8, value: 80 }, + { id: 9, value: 90 }, + { id: 10, value: 100 }, + ], + { vectorColumns: {} }, + ); + + table = await db.createTable("test_table", data); + }); + + afterEach(() => { + tmpDir.removeCallback(); + }); + + test("should create permutation builder", () => { + const builder = permutationBuilder(table, "permutation_table"); + expect(builder).toBeDefined(); + }); + + test("should execute basic permutation", async () => { + const builder = permutationBuilder(table, "permutation_table"); + const permutationTable = await builder.execute(); + + expect(permutationTable).toBeDefined(); + expect(permutationTable.name).toBe("permutation_table"); + + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + }); + + test("should create permutation with random splits", async () => { + const builder = permutationBuilder(table, "permutation_table").splitRandom({ + ratios: [1.0], + seed: 42, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + }); + + test("should create permutation with percentage splits", async () => { + const builder = permutationBuilder(table, "permutation_table").splitRandom({ + ratios: [0.3, 0.7], + seed: 42, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Check split distribution + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should create permutation with count splits", async () => { + const builder = permutationBuilder(table, "permutation_table").splitRandom({ + counts: [3, 7], + seed: 42, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + 
expect(rowCount).toBe(10); + + // Check split distribution + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBe(3); + expect(split1Count).toBe(7); + }); + + test("should create permutation with hash splits", async () => { + const builder = permutationBuilder(table, "permutation_table").splitHash({ + columns: ["id"], + splitWeights: [50, 50], + discardWeight: 0, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Check that splits exist + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should create permutation with sequential splits", async () => { + const builder = permutationBuilder( + table, + "permutation_table", + ).splitSequential({ ratios: [0.5, 0.5] }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Check split distribution - sequential should give exactly 5 and 5 + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBe(5); + expect(split1Count).toBe(5); + }); + + test("should create permutation with calculated splits", async () => { + const builder = permutationBuilder( + table, + "permutation_table", + ).splitCalculated("id % 2"); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Check split distribution + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should create permutation with shuffle", async () => { + const builder = permutationBuilder(table, "permutation_table").shuffle({ + seed: 42, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + }); + + test("should create permutation with shuffle and clump size", async () => { + const builder = permutationBuilder(table, "permutation_table").shuffle({ + seed: 42, + clumpSize: 2, + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + }); + + test("should create permutation with filter", async () => { + const builder = permutationBuilder(table, "permutation_table").filter( + "value > 50", + ); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100 + }); + + test("should chain multiple operations", async () => { + const builder = permutationBuilder(table, "permutation_table") + .filter("value <= 80") + .splitRandom({ ratios: [0.5, 0.5], seed: 42 }) + .shuffle({ seed: 123 }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80 + 
+ // Check split distribution + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(8); + }); + + test("should throw error for invalid split arguments", () => { + const builder = permutationBuilder(table, "permutation_table"); + + // Test no arguments provided + expect(() => builder.splitRandom({})).toThrow( + "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + ); + + // Test multiple arguments provided + expect(() => + builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }), + ).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided"); + }); + + test("should throw error when builder is consumed", async () => { + const builder = permutationBuilder(table, "permutation_table"); + + // Execute once + await builder.execute(); + + // Should throw error on second execution + await expect(builder.execute()).rejects.toThrow("Builder already consumed"); + }); +}); diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index 57069221..3ef1a76f 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -43,6 +43,10 @@ export { DeleteResult, DropColumnsResult, UpdateResult, + SplitRandomOptions, + SplitHashOptions, + SplitSequentialOptions, + ShuffleOptions, } from "./native.js"; export { @@ -111,6 +115,7 @@ export { export { MergeInsertBuilder, WriteExecutionOptions } from "./merge"; export * as embedding from "./embedding"; +export { permutationBuilder, PermutationBuilder } from "./permutation"; export * as rerankers from "./rerankers"; export { SchemaLike, diff --git a/nodejs/lancedb/permutation.ts b/nodejs/lancedb/permutation.ts new file mode 100644 index 00000000..98406505 --- /dev/null +++ b/nodejs/lancedb/permutation.ts @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import { + PermutationBuilder as NativePermutationBuilder, + Table as NativeTable, + ShuffleOptions, + SplitHashOptions, + SplitRandomOptions, + SplitSequentialOptions, + permutationBuilder as nativePermutationBuilder, +} from "./native.js"; +import { LocalTable, Table } from "./table"; + +/** + * A PermutationBuilder for creating data permutations with splits, shuffling, and filtering. + * + * This class provides a TypeScript wrapper around the native Rust PermutationBuilder, + * offering methods to configure data splits, shuffling, and filtering before executing + * the permutation to create a new table. + */ +export class PermutationBuilder { + private inner: NativePermutationBuilder; + + /** + * @hidden + */ + constructor(inner: NativePermutationBuilder) { + this.inner = inner; + } + + /** + * Configure random splits for the permutation. + * + * @param options - Configuration for random splitting + * @returns A new PermutationBuilder instance + * @example + * ```ts + * // Split by ratios + * builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 }); + * + * // Split by counts + * builder.splitRandom({ counts: [1000, 500], seed: 42 }); + * + * // Split with fixed size + * builder.splitRandom({ fixed: 100, seed: 42 }); + * ``` + */ + splitRandom(options: SplitRandomOptions): PermutationBuilder { + const newInner = this.inner.splitRandom(options); + return new PermutationBuilder(newInner); + } + + /** + * Configure hash-based splits for the permutation. 
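+ * Rows are assigned by hashing the values of the given columns, so rows with the same key values always land in the same split; `splitWeights` sets the relative size of each split and `discardWeight` drops a proportional share of rows.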
+ * + * @param options - Configuration for hash-based splitting + * @returns A new PermutationBuilder instance + * @example + * ```ts + * builder.splitHash({ + * columns: ["user_id"], + * splitWeights: [70, 30], + * discardWeight: 0 + * }); + * ``` + */ + splitHash(options: SplitHashOptions): PermutationBuilder { + const newInner = this.inner.splitHash(options); + return new PermutationBuilder(newInner); + } + + /** + * Configure sequential splits for the permutation. + * + * @param options - Configuration for sequential splitting + * @returns A new PermutationBuilder instance + * @example + * ```ts + * // Split by ratios + * builder.splitSequential({ ratios: [0.8, 0.2] }); + * + * // Split by counts + * builder.splitSequential({ counts: [800, 200] }); + * + * // Split with fixed size + * builder.splitSequential({ fixed: 1000 }); + * ``` + */ + splitSequential(options: SplitSequentialOptions): PermutationBuilder { + const newInner = this.inner.splitSequential(options); + return new PermutationBuilder(newInner); + } + + /** + * Configure calculated splits for the permutation. + * + * @param calculation - SQL expression for calculating splits + * @returns A new PermutationBuilder instance + * @example + * ```ts + * builder.splitCalculated("user_id % 3"); + * ``` + */ + splitCalculated(calculation: string): PermutationBuilder { + const newInner = this.inner.splitCalculated(calculation); + return new PermutationBuilder(newInner); + } + + /** + * Configure shuffling for the permutation. + * + * @param options - Configuration for shuffling + * @returns A new PermutationBuilder instance + * @example + * ```ts + * // Basic shuffle + * builder.shuffle({ seed: 42 }); + * + * // Shuffle with clump size + * builder.shuffle({ seed: 42, clumpSize: 10 }); + * ``` + */ + shuffle(options: ShuffleOptions): PermutationBuilder { + const newInner = this.inner.shuffle(options); + return new PermutationBuilder(newInner); + } + + /** + * Configure filtering for the permutation. + * + * @param filter - SQL filter expression + * @returns A new PermutationBuilder instance + * @example + * ```ts + * builder.filter("age > 18 AND status = 'active'"); + * ``` + */ + filter(filter: string): PermutationBuilder { + const newInner = this.inner.filter(filter); + return new PermutationBuilder(newInner); + } + + /** + * Execute the permutation and create the destination table. + * + * @returns A Promise that resolves to the new Table instance + * @example + * ```ts + * const permutationTable = await builder.execute(); + * console.log(`Created table: ${permutationTable.name}`); + * ``` + */ + async execute(): Promise
{ + const nativeTable: NativeTable = await this.inner.execute(); + return new LocalTable(nativeTable); + } +} + +/** + * Create a permutation builder for the given table. + * + * @param table - The source table to create a permutation from + * @param destTableName - The name for the destination permutation table + * @returns A PermutationBuilder instance + * @example + * ```ts + * const builder = permutationBuilder(sourceTable, "training_data") + * .splitRandom({ ratios: [0.8, 0.2], seed: 42 }) + * .shuffle({ seed: 123 }); + * + * const trainingTable = await builder.execute(); + * ``` + */ +export function permutationBuilder( + table: Table, + destTableName: string, +): PermutationBuilder { + // Extract the inner native table from the TypeScript wrapper + const localTable = table as LocalTable; + // Access inner through type assertion since it's private + const nativeBuilder = nativePermutationBuilder( + // biome-ignore lint/suspicious/noExplicitAny: need access to private variable + (localTable as any).inner, + destTableName, + ); + return new PermutationBuilder(nativeBuilder); +} diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index e11e9278..df1898e2 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -12,6 +12,7 @@ mod header; mod index; mod iterator; pub mod merge; +pub mod permutation; mod query; pub mod remote; mod rerankers; diff --git a/nodejs/src/permutation.rs b/nodejs/src/permutation.rs new file mode 100644 index 00000000..706a1de7 --- /dev/null +++ b/nodejs/src/permutation.rs @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::{Arc, Mutex}; + +use crate::{error::NapiErrorExt, table::Table}; +use lancedb::dataloader::{ + permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, + split::{SplitSizes, SplitStrategy}, +}; +use napi_derive::napi; + +#[napi(object)] +pub struct SplitRandomOptions { + pub ratios: Option>, + pub counts: Option>, + pub fixed: Option, + pub seed: Option, +} + +#[napi(object)] +pub struct SplitHashOptions { + pub columns: Vec, + pub split_weights: Vec, + pub discard_weight: Option, +} + +#[napi(object)] +pub struct SplitSequentialOptions { + pub ratios: Option>, + pub counts: Option>, + pub fixed: Option, +} + +#[napi(object)] +pub struct ShuffleOptions { + pub seed: Option, + pub clump_size: Option, +} + +pub struct PermutationBuilderState { + pub builder: Option, + pub dest_table_name: String, +} + +#[napi] +pub struct PermutationBuilder { + state: Arc>, +} + +impl PermutationBuilder { + pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self { + Self { + state: Arc::new(Mutex::new(PermutationBuilderState { + builder: Some(builder), + dest_table_name, + })), + } + } +} + +impl PermutationBuilder { + fn modify( + &self, + func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder, + ) -> napi::Result { + let mut state = self.state.lock().unwrap(); + let builder = state + .builder + .take() + .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?; + state.builder = Some(func(builder)); + Ok(Self { + state: self.state.clone(), + }) + } +} + +#[napi] +impl PermutationBuilder { + /// Configure random splits + #[napi] + pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result { + // Check that exactly one split type is provided + let split_args_count = [ + options.ratios.is_some(), + options.counts.is_some(), + options.fixed.is_some(), + ] + .iter() + .filter(|&&x| x) + .count(); + + if 
split_args_count != 1 { + return Err(napi::Error::from_reason( + "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + )); + } + + let sizes = if let Some(ratios) = options.ratios { + SplitSizes::Percentages(ratios) + } else if let Some(counts) = options.counts { + SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect()) + } else if let Some(fixed) = options.fixed { + SplitSizes::Fixed(fixed as u64) + } else { + unreachable!("One of the split arguments must be provided"); + }; + + let seed = options.seed.map(|s| s as u64); + + self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes })) + } + + /// Configure hash-based splits + #[napi] + pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result { + let split_weights = options + .split_weights + .into_iter() + .map(|w| w as u64) + .collect(); + let discard_weight = options.discard_weight.unwrap_or(0) as u64; + + self.modify(|builder| { + builder.with_split_strategy(SplitStrategy::Hash { + columns: options.columns, + split_weights, + discard_weight, + }) + }) + } + + /// Configure sequential splits + #[napi] + pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result { + // Check that exactly one split type is provided + let split_args_count = [ + options.ratios.is_some(), + options.counts.is_some(), + options.fixed.is_some(), + ] + .iter() + .filter(|&&x| x) + .count(); + + if split_args_count != 1 { + return Err(napi::Error::from_reason( + "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + )); + } + + let sizes = if let Some(ratios) = options.ratios { + SplitSizes::Percentages(ratios) + } else if let Some(counts) = options.counts { + SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect()) + } else if let Some(fixed) = options.fixed { + SplitSizes::Fixed(fixed as u64) + } else { + unreachable!("One of the split arguments must be provided"); + }; + + self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes })) + } + + /// Configure calculated splits + #[napi] + pub fn split_calculated(&self, calculation: String) -> napi::Result { + self.modify(|builder| { + builder.with_split_strategy(SplitStrategy::Calculated { calculation }) + }) + } + + /// Configure shuffling + #[napi] + pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result { + let seed = options.seed.map(|s| s as u64); + let clump_size = options.clump_size.map(|c| c as u64); + + self.modify(|builder| { + builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size }) + }) + } + + /// Configure filtering + #[napi] + pub fn filter(&self, filter: String) -> napi::Result { + self.modify(|builder| builder.with_filter(filter)) + } + + /// Execute the permutation builder and create the table + #[napi] + pub async fn execute(&self) -> napi::Result
{ + let (builder, dest_table_name) = { + let mut state = self.state.lock().unwrap(); + let builder = state + .builder + .take() + .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?; + + let dest_table_name = std::mem::take(&mut state.dest_table_name); + (builder, dest_table_name) + }; + + let table = builder.build(&dest_table_name).await.default_error()?; + Ok(Table::new(table)) + } +} + +/// Create a permutation builder for the given table +#[napi] +pub fn permutation_builder( + table: &crate::table::Table, + dest_table_name: String, +) -> napi::Result { + use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder; + + let inner_table = table.inner_ref()?.clone(); + let inner_builder = LancePermutationBuilder::new(inner_table); + + Ok(PermutationBuilder::new(inner_builder, dest_table_name)) +} diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 1272b95c..b1f037fe 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -26,7 +26,7 @@ pub struct Table { } impl Table { - fn inner_ref(&self) -> napi::Result<&LanceDbTable> { + pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> { self.inner .as_ref() .ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name))) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 68900a92..378c4a09 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -296,3 +296,34 @@ class AlterColumnsResult: class DropColumnsResult: version: int + +class AsyncPermutationBuilder: + def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ... + def split_random( + self, + *, + ratios: Optional[List[float]] = None, + counts: Optional[List[int]] = None, + fixed: Optional[int] = None, + seed: Optional[int] = None, + ) -> "AsyncPermutationBuilder": ... + def split_hash( + self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0 + ) -> "AsyncPermutationBuilder": ... + def split_sequential( + self, + *, + ratios: Optional[List[float]] = None, + counts: Optional[List[int]] = None, + fixed: Optional[int] = None, + ) -> "AsyncPermutationBuilder": ... + def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ... + def shuffle( + self, seed: Optional[int], clump_size: Optional[int] + ) -> "AsyncPermutationBuilder": ... + def filter(self, filter: str) -> "AsyncPermutationBuilder": ... + async def execute(self) -> Table: ... + +def async_permutation_builder( + table: Table, dest_table_name: str +) -> AsyncPermutationBuilder: ... 
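
For orientation, a minimal sketch of how the stubs above are expected to be driven through the synchronous wrapper added in `lancedb/permutation.py` further down in this patch. It assumes an existing `LanceTable` named `tbl` with a numeric `value` column, as in the new Python tests; the destination table name and filter are illustrative only.

```python
from lancedb.permutation import permutation_builder

# Chain configuration calls, then materialize the permutation as a new table.
perm_tbl = (
    permutation_builder(tbl, "training_permutation")
    .filter("value > 50")                      # keep only matching rows
    .split_random(ratios=[0.8, 0.2], seed=42)  # reproducible 80/20 split
    .shuffle(seed=123)                         # reproducible row order
    .execute()
)

# The resulting table carries row_id and split_id columns, as exercised in
# the tests added below.
data = perm_tbl.search(None).to_arrow().to_pydict()
print(perm_tbl.count_rows(), set(data["split_id"]))
```

The synchronous wrapper forwards the configuration calls directly to `AsyncPermutationBuilder` and only runs `execute()` on the background event loop.
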
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 5d83c89e..dac5b1cd 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import abstractmethod +from datetime import timedelta from pathlib import Path import sys from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union @@ -40,7 +41,6 @@ import deprecation if TYPE_CHECKING: import pyarrow as pa from .pydantic import LanceModel - from datetime import timedelta from ._lancedb import Connection as LanceDbConnection from .common import DATA, URI @@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection): read_consistency_interval: Optional[timedelta] = None, storage_options: Optional[Dict[str, str]] = None, session: Optional[Session] = None, + _inner: Optional[LanceDbConnection] = None, ): + if _inner is not None: + self._conn = _inner + return + if not isinstance(uri, Path): scheme = get_uri_scheme(uri) is_local = isinstance(uri, Path) or scheme == "file" @@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection): uri = Path(uri) uri = uri.expanduser().absolute() Path(uri).mkdir(parents=True, exist_ok=True) - self._uri = str(uri) - self._entered = False - self.read_consistency_interval = read_consistency_interval - self.storage_options = storage_options - self.session = session if read_consistency_interval is not None: read_consistency_interval_secs = read_consistency_interval.total_seconds() @@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection): session, ) + # TODO: It would be nice if we didn't store self.storage_options but it is + # currently used by the LanceTable.to_lance method. This doesn't _really_ + # work because some paths like LanceDBConnection.from_inner will lose the + # storage_options. Also, this class really shouldn't be holding any state + # beyond _conn. 
+ self.storage_options = storage_options self._conn = AsyncConnection(LOOP.run(do_connect())) + @property + def read_consistency_interval(self) -> Optional[timedelta]: + return LOOP.run(self._conn.get_read_consistency_interval()) + + @property + def session(self) -> Optional[Session]: + return self._conn.session + + @property + def uri(self) -> str: + return self._conn.uri + + @classmethod + def from_inner(cls, inner: LanceDbConnection): + return cls(None, _inner=inner) + def __repr__(self) -> str: - val = f"{self.__class__.__name__}(uri={self._uri!r}" + val = f"{self.__class__.__name__}(uri={self._conn.uri!r}" if self.read_consistency_interval is not None: val += f", read_consistency_interval={repr(self.read_consistency_interval)}" val += ")" @@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection): conn = AsyncConnection(await lancedb_connect(self.uri)) return await conn.table_names(start_after=start_after, limit=limit) + @property + def _inner(self) -> LanceDbConnection: + return self._conn._inner + @override def list_namespaces( self, @@ -856,6 +882,13 @@ class AsyncConnection(object): def uri(self) -> str: return self._inner.uri + async def get_read_consistency_interval(self) -> Optional[timedelta]: + interval_secs = await self._inner.get_read_consistency_interval() + if interval_secs is not None: + return timedelta(seconds=interval_secs) + else: + return None + async def list_namespaces( self, namespace: List[str] = [], diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py new file mode 100644 index 00000000..bd8aa610 --- /dev/null +++ b/python/python/lancedb/permutation.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from ._lancedb import async_permutation_builder +from .table import LanceTable +from .background_loop import LOOP +from typing import Optional + + +class PermutationBuilder: + def __init__(self, table: LanceTable, dest_table_name: str): + self._async = async_permutation_builder(table, dest_table_name) + + def select(self, projections: dict[str, str]) -> "PermutationBuilder": + self._async.select(projections) + return self + + def split_random( + self, + *, + ratios: Optional[list[float]] = None, + counts: Optional[list[int]] = None, + fixed: Optional[int] = None, + seed: Optional[int] = None, + ) -> "PermutationBuilder": + self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed) + return self + + def split_hash( + self, + columns: list[str], + split_weights: list[int], + *, + discard_weight: Optional[int] = None, + ) -> "PermutationBuilder": + self._async.split_hash(columns, split_weights, discard_weight=discard_weight) + return self + + def split_sequential( + self, + *, + ratios: Optional[list[float]] = None, + counts: Optional[list[int]] = None, + fixed: Optional[int] = None, + ) -> "PermutationBuilder": + self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed) + return self + + def split_calculated(self, calculation: str) -> "PermutationBuilder": + self._async.split_calculated(calculation) + return self + + def shuffle( + self, *, seed: Optional[int] = None, clump_size: Optional[int] = None + ) -> "PermutationBuilder": + self._async.shuffle(seed=seed, clump_size=clump_size) + return self + + def filter(self, filter: str) -> "PermutationBuilder": + self._async.filter(filter) + return self + + def execute(self) -> LanceTable: + async def do_execute(): + inner_tbl = await self._async.execute() + return 
LanceTable.from_inner(inner_tbl) + + return LOOP.run(do_execute()) + + +def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder: + return PermutationBuilder(table, dest_table_name) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 7ee0bf01..6749064d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -74,6 +74,7 @@ from .index import lang_mapping if TYPE_CHECKING: + from .db import LanceDBConnection from ._lancedb import ( Table as LanceDBTable, OptimizeStats, @@ -88,7 +89,6 @@ if TYPE_CHECKING: MergeResult, UpdateResult, ) - from .db import LanceDBConnection from .index import IndexConfig import pandas import PIL @@ -1707,22 +1707,38 @@ class LanceTable(Table): namespace: List[str] = [], storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + _async: AsyncTable = None, ): self._conn = connection self._namespace = namespace - self._table = LOOP.run( - connection._conn.open_table( - name, - namespace=namespace, - storage_options=storage_options, - index_cache_size=index_cache_size, + if _async is not None: + self._table = _async + else: + self._table = LOOP.run( + connection._conn.open_table( + name, + namespace=namespace, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) ) - ) @property def name(self) -> str: return self._table.name + @classmethod + def from_inner(cls, tbl: LanceDBTable): + from .db import LanceDBConnection + + async_tbl = AsyncTable(tbl) + conn = LanceDBConnection.from_inner(tbl.database()) + return cls( + conn, + async_tbl.name, + _async=async_tbl, + ) + @classmethod def open(cls, db, name, *, namespace: List[str] = [], **kwargs): tbl = cls(db, name, namespace=namespace, **kwargs) @@ -2756,6 +2772,10 @@ class LanceTable(Table): self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) ) + @property + def _inner(self) -> LanceDBTable: + return self._table._inner + @deprecation.deprecated( deprecated_in="0.21.0", current_version=__version__, diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py new file mode 100644 index 00000000..95cd21c0 --- /dev/null +++ b/python/python/tests/test_permutation.py @@ -0,0 +1,496 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import pyarrow as pa +import pytest + +from lancedb.permutation import permutation_builder + + +def test_split_random_ratios(mem_db): + """Test random splitting with ratios.""" + tbl = mem_db.create_table( + "test_table", pa.table({"x": range(100), "y": range(100)}) + ) + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_random(ratios=[0.3, 0.7]) + .execute() + ) + + # Check that the table was created and has data + assert permutation_tbl.count_rows() == 100 + + # Check that split_id column exists and has correct values + data = permutation_tbl.search(None).to_arrow().to_pydict() + split_ids = data["split_id"] + assert set(split_ids) == {0, 1} + + # Check approximate split sizes (allowing for rounding) + split_0_count = split_ids.count(0) + split_1_count = split_ids.count(1) + assert 25 <= split_0_count <= 35 # ~30% ± tolerance + assert 65 <= split_1_count <= 75 # ~70% ± tolerance + + +def test_split_random_counts(mem_db): + """Test random splitting with absolute counts.""" + tbl = mem_db.create_table( + "test_table", pa.table({"x": range(100), "y": range(100)}) + ) + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") 
+ .split_random(counts=[20, 30]) + .execute() + ) + + # Check that we have exactly the requested counts + assert permutation_tbl.count_rows() == 50 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + split_ids = data["split_id"] + assert split_ids.count(0) == 20 + assert split_ids.count(1) == 30 + + +def test_split_random_fixed(mem_db): + """Test random splitting with fixed number of splits.""" + tbl = mem_db.create_table( + "test_table", pa.table({"x": range(100), "y": range(100)}) + ) + permutation_tbl = ( + permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute() + ) + + # Check that we have 4 splits with 25 rows each + assert permutation_tbl.count_rows() == 100 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + split_ids = data["split_id"] + assert set(split_ids) == {0, 1, 2, 3} + + for split_id in range(4): + assert split_ids.count(split_id) == 25 + + +def test_split_random_with_seed(mem_db): + """Test that seeded random splits are reproducible.""" + tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)})) + + # Create two identical permutations with same seed + perm1 = ( + permutation_builder(tbl, "perm1") + .split_random(ratios=[0.6, 0.4], seed=42) + .execute() + ) + + perm2 = ( + permutation_builder(tbl, "perm2") + .split_random(ratios=[0.6, 0.4], seed=42) + .execute() + ) + + # Results should be identical + data1 = perm1.search(None).to_arrow().to_pydict() + data2 = perm2.search(None).to_arrow().to_pydict() + + assert data1["row_id"] == data2["row_id"] + assert data1["split_id"] == data2["split_id"] + + +def test_split_hash(mem_db): + """Test hash-based splitting.""" + tbl = mem_db.create_table( + "test_table", + pa.table( + { + "id": range(100), + "category": (["A", "B", "C"] * 34)[:100], # Repeating pattern + "value": range(100), + } + ), + ) + + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_hash(["category"], [1, 1], discard_weight=0) + .execute() + ) + + # Should have all 100 rows (no discard) + assert permutation_tbl.count_rows() == 100 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + split_ids = data["split_id"] + assert set(split_ids) == {0, 1} + + # Verify that each split has roughly 50 rows (allowing for hash variance) + split_0_count = split_ids.count(0) + split_1_count = split_ids.count(1) + assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution + assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution + + # Hash splits should be deterministic - same category should go to same split + # Let's verify by creating another permutation and checking consistency + perm2 = ( + permutation_builder(tbl, "test_permutation2") + .split_hash(["category"], [1, 1], discard_weight=0) + .execute() + ) + + data2 = perm2.search(None).to_arrow().to_pydict() + assert data["split_id"] == data2["split_id"] # Should be identical + + +def test_split_hash_with_discard(mem_db): + """Test hash-based splitting with discard weight.""" + tbl = mem_db.create_table( + "test_table", + pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}), + ) + + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50% + .execute() + ) + + # Should have fewer than 100 rows due to discard + row_count = permutation_tbl.count_rows() + assert row_count < 100 + assert row_count > 0 # But not empty + + +def test_split_sequential(mem_db): + """Test sequential 
splitting.""" + tbl = mem_db.create_table( + "test_table", pa.table({"x": range(100), "y": range(100)}) + ) + + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_sequential(counts=[30, 40]) + .execute() + ) + + assert permutation_tbl.count_rows() == 70 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + split_ids = data["split_id"] + + # Sequential should maintain order + assert row_ids == sorted(row_ids) + + # First 30 should be split 0, next 40 should be split 1 + assert split_ids[:30] == [0] * 30 + assert split_ids[30:] == [1] * 40 + + +def test_split_calculated(mem_db): + """Test calculated splitting.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(100), "value": range(100)}) + ) + + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_calculated("id % 3") # Split based on id modulo 3 + .execute() + ) + + assert permutation_tbl.count_rows() == 100 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + split_ids = data["split_id"] + + # Verify the calculation: each row's split_id should equal row_id % 3 + for i, (row_id, split_id) in enumerate(zip(row_ids, split_ids)): + assert split_id == row_id % 3 + + +def test_split_error_cases(mem_db): + """Test error handling for invalid split parameters.""" + tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)})) + + # Test split_random with no parameters + with pytest.raises(Exception): + permutation_builder(tbl, "error1").split_random().execute() + + # Test split_random with multiple parameters + with pytest.raises(Exception): + permutation_builder(tbl, "error2").split_random( + ratios=[0.5, 0.5], counts=[5, 5] + ).execute() + + # Test split_sequential with no parameters + with pytest.raises(Exception): + permutation_builder(tbl, "error3").split_sequential().execute() + + # Test split_sequential with multiple parameters + with pytest.raises(Exception): + permutation_builder(tbl, "error4").split_sequential( + ratios=[0.5, 0.5], fixed=2 + ).execute() + + +def test_shuffle_no_seed(mem_db): + """Test shuffling without a seed.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(100), "value": range(100)}) + ) + + # Create a permutation with shuffling (no seed) + permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute() + + assert permutation_tbl.count_rows() == 100 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + + # Row IDs should not be in sequential order due to shuffling + # This is probabilistic but with 100 rows, it's extremely unlikely they'd stay + # in order + assert row_ids != list(range(100)) + + +def test_shuffle_with_seed(mem_db): + """Test that shuffling with a seed is reproducible.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(50), "value": range(50)}) + ) + + # Create two identical permutations with same shuffle seed + perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute() + + perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute() + + # Results should be identical due to same seed + data1 = perm1.search(None).to_arrow().to_pydict() + data2 = perm2.search(None).to_arrow().to_pydict() + + assert data1["row_id"] == data2["row_id"] + assert data1["split_id"] == data2["split_id"] + + +def test_shuffle_with_clump_size(mem_db): + """Test shuffling with clump size.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": 
range(100), "value": range(100)}) + ) + + # Create a permutation with shuffling using clumps + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .shuffle(clump_size=10) # 10-row clumps + .execute() + ) + + assert permutation_tbl.count_rows() == 100 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + + for i in range(10): + start = row_ids[i * 10] + assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10)) + + +def test_shuffle_different_seeds(mem_db): + """Test that different seeds produce different shuffle orders.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(50), "value": range(50)}) + ) + + # Create two permutations with different shuffle seeds + perm1 = ( + permutation_builder(tbl, "perm1") + .split_random(fixed=2) + .shuffle(seed=42) + .execute() + ) + + perm2 = ( + permutation_builder(tbl, "perm2") + .split_random(fixed=2) + .shuffle(seed=123) + .execute() + ) + + # Results should be different due to different seeds + data1 = perm1.search(None).to_arrow().to_pydict() + data2 = perm2.search(None).to_arrow().to_pydict() + + # Row order should be different + assert data1["row_id"] != data2["row_id"] + + +def test_shuffle_combined_with_splits(mem_db): + """Test shuffling combined with different split strategies.""" + tbl = mem_db.create_table( + "test_table", + pa.table( + { + "id": range(100), + "category": (["A", "B", "C"] * 34)[:100], + "value": range(100), + } + ), + ) + + # Test shuffle with random splits + perm_random = ( + permutation_builder(tbl, "perm_random") + .split_random(ratios=[0.6, 0.4], seed=42) + .shuffle(seed=123, clump_size=None) + .execute() + ) + + # Test shuffle with hash splits + perm_hash = ( + permutation_builder(tbl, "perm_hash") + .split_hash(["category"], [1, 1], discard_weight=0) + .shuffle(seed=456, clump_size=5) + .execute() + ) + + # Test shuffle with sequential splits + perm_sequential = ( + permutation_builder(tbl, "perm_sequential") + .split_sequential(counts=[40, 35]) + .shuffle(seed=789, clump_size=None) + .execute() + ) + + # Verify all permutations work and have expected properties + assert perm_random.count_rows() == 100 + assert perm_hash.count_rows() == 100 + assert perm_sequential.count_rows() == 75 + + # Verify shuffle affected the order + data_random = perm_random.search(None).to_arrow().to_pydict() + data_sequential = perm_sequential.search(None).to_arrow().to_pydict() + + assert data_random["row_id"] != list(range(100)) + assert data_sequential["row_id"] != list(range(75)) + + +def test_no_shuffle_maintains_order(mem_db): + """Test that not calling shuffle maintains the original order.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(50), "value": range(50)}) + ) + + # Create permutation without shuffle (should maintain some order) + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .split_sequential(counts=[25, 25]) # Sequential maintains order + .execute() + ) + + assert permutation_tbl.count_rows() == 50 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + + # With sequential splits and no shuffle, should maintain order + assert row_ids == list(range(50)) + + +def test_filter_basic(mem_db): + """Test basic filtering functionality.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(100), "value": range(100, 200)}) + ) + + # Filter to only include rows where id < 50 + permutation_tbl = ( + permutation_builder(tbl, "test_permutation").filter("id < 
50").execute() + ) + + assert permutation_tbl.count_rows() == 50 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + + # All row_ids should be less than 50 + assert all(row_id < 50 for row_id in row_ids) + + +def test_filter_with_splits(mem_db): + """Test filtering combined with split strategies.""" + tbl = mem_db.create_table( + "test_table", + pa.table( + { + "id": range(100), + "category": (["A", "B", "C"] * 34)[:100], + "value": range(100), + } + ), + ) + + # Filter to only category A and B, then split + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .filter("category IN ('A', 'B')") + .split_random(ratios=[0.5, 0.5]) + .execute() + ) + + # Should have fewer than 100 rows due to filtering + row_count = permutation_tbl.count_rows() + assert row_count == 67 + + data = permutation_tbl.search(None).to_arrow().to_pydict() + categories = data["category"] + + # All categories should be A or B + assert all(cat in ["A", "B"] for cat in categories) + + +def test_filter_with_shuffle(mem_db): + """Test filtering combined with shuffling.""" + tbl = mem_db.create_table( + "test_table", + pa.table( + { + "id": range(100), + "category": (["A", "B", "C", "D"] * 25)[:100], + "value": range(100), + } + ), + ) + + # Filter and shuffle + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .filter("category IN ('A', 'C')") + .shuffle(seed=42) + .execute() + ) + + row_count = permutation_tbl.count_rows() + assert row_count == 50 # Should have 50 rows (A and C categories) + + data = permutation_tbl.search(None).to_arrow().to_pydict() + row_ids = data["row_id"] + + assert row_ids != sorted(row_ids) + + +def test_filter_empty_result(mem_db): + """Test filtering that results in empty set.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(10), "value": range(10)}) + ) + + # Filter that matches nothing + permutation_tbl = ( + permutation_builder(tbl, "test_permutation") + .filter("value > 100") # No values > 100 in our data + .execute() + ) + + assert permutation_tbl.count_rows() == 0 diff --git a/python/src/connection.rs b/python/src/connection.rs index 13782dca..67553074 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -4,7 +4,10 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow}; -use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode}; +use lancedb::{ + connection::Connection as LanceConnection, + database::{CreateTableMode, ReadConsistency}, +}; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python, @@ -23,7 +26,7 @@ impl Connection { Self { inner: Some(inner) } } - fn get_inner(&self) -> PyResult<&LanceConnection> { + pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> { self.inner .as_ref() .ok_or_else(|| PyRuntimeError::new_err("Connection is closed")) @@ -63,6 +66,18 @@ impl Connection { self.get_inner().map(|inner| inner.uri().to_string()) } + #[pyo3(signature = ())] + pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult> { + let inner = self_.get_inner()?.clone(); + future_into_py(self_.py(), async move { + Ok(match inner.read_consistency().await.infer_error()? 
{ + ReadConsistency::Manual => None, + ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()), + ReadConsistency::Strong => Some(0.0_f64), + }) + }) + } + #[pyo3(signature = (namespace=vec![], start_after=None, limit=None))] pub fn table_names( self_: PyRef<'_, Self>, diff --git a/python/src/lib.rs b/python/src/lib.rs index 636c176f..6f20c99e 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -5,6 +5,7 @@ use arrow::RecordBatchStream; use connection::{connect, Connection}; use env_logger::Env; use index::IndexConfig; +use permutation::PyAsyncPermutationBuilder; use pyo3::{ pymodule, types::{PyModule, PyModuleMethods}, @@ -22,6 +23,7 @@ pub mod connection; pub mod error; pub mod header; pub mod index; +pub mod permutation; pub mod query; pub mod session; pub mod table; @@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; + m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?; m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) diff --git a/python/src/permutation.rs b/python/src/permutation.rs new file mode 100644 index 00000000..38da3f82 --- /dev/null +++ b/python/src/permutation.rs @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::{Arc, Mutex}; + +use crate::{error::PythonErrorExt, table::Table}; +use lancedb::dataloader::{ + permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, + split::{SplitSizes, SplitStrategy}, +}; +use pyo3::{ + exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut, + PyResult, +}; +use pyo3_async_runtimes::tokio::future_into_py; + +/// Create a permutation builder for the given table +#[pyo3::pyfunction] +pub fn async_permutation_builder( + table: Bound<'_, PyAny>, + dest_table_name: String, +) -> PyResult { + let table = table.getattr("_inner")?.downcast_into::
()?; + let inner_table = table.borrow().inner_ref()?.clone(); + let inner_builder = LancePermutationBuilder::new(inner_table); + + Ok(PyAsyncPermutationBuilder { + state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState { + builder: Some(inner_builder), + dest_table_name, + })), + }) +} + +struct PyAsyncPermutationBuilderState { + builder: Option, + dest_table_name: String, +} + +#[pyclass(name = "AsyncPermutationBuilder")] +pub struct PyAsyncPermutationBuilder { + state: Arc>, +} + +impl PyAsyncPermutationBuilder { + fn modify( + &self, + func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder, + ) -> PyResult { + let mut state = self.state.lock().unwrap(); + let builder = state + .builder + .take() + .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; + state.builder = Some(func(builder)); + Ok(Self { + state: self.state.clone(), + }) + } +} + +#[pymethods] +impl PyAsyncPermutationBuilder { + #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))] + pub fn split_random( + slf: PyRefMut<'_, Self>, + ratios: Option>, + counts: Option>, + fixed: Option, + seed: Option, + ) -> PyResult { + // Check that exactly one split type is provided + let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] + .iter() + .filter(|&&x| x) + .count(); + + if split_args_count != 1 { + return Err(pyo3::exceptions::PyValueError::new_err( + "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + )); + } + + let sizes = if let Some(ratios) = ratios { + SplitSizes::Percentages(ratios) + } else if let Some(counts) = counts { + SplitSizes::Counts(counts) + } else if let Some(fixed) = fixed { + SplitSizes::Fixed(fixed) + } else { + unreachable!("One of the split arguments must be provided"); + }; + + slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes })) + } + + #[pyo3(signature = (columns, split_weights, *, discard_weight=0))] + pub fn split_hash( + slf: PyRefMut<'_, Self>, + columns: Vec, + split_weights: Vec, + discard_weight: u64, + ) -> PyResult { + slf.modify(|builder| { + builder.with_split_strategy(SplitStrategy::Hash { + columns, + split_weights, + discard_weight, + }) + }) + } + + #[pyo3(signature = (*, ratios=None, counts=None, fixed=None))] + pub fn split_sequential( + slf: PyRefMut<'_, Self>, + ratios: Option>, + counts: Option>, + fixed: Option, + ) -> PyResult { + // Check that exactly one split type is provided + let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] + .iter() + .filter(|&&x| x) + .count(); + + if split_args_count != 1 { + return Err(pyo3::exceptions::PyValueError::new_err( + "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + )); + } + + let sizes = if let Some(ratios) = ratios { + SplitSizes::Percentages(ratios) + } else if let Some(counts) = counts { + SplitSizes::Counts(counts) + } else if let Some(fixed) = fixed { + SplitSizes::Fixed(fixed) + } else { + unreachable!("One of the split arguments must be provided"); + }; + + slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes })) + } + + pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult { + slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation })) + } + + pub fn shuffle( + slf: PyRefMut<'_, Self>, + seed: Option, + clump_size: Option, + ) -> PyResult { + slf.modify(|builder| { + builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size }) + }) + } + + pub fn 
filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult { + slf.modify(|builder| builder.with_filter(filter)) + } + + pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult> { + let mut state = slf.state.lock().unwrap(); + let builder = state + .builder + .take() + .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; + + let dest_table_name = std::mem::take(&mut state.dest_table_name); + + future_into_py(slf.py(), async move { + let table = builder.build(&dest_table_name).await.infer_error()?; + Ok(Table::new(table)) + }) + } +} diff --git a/python/src/table.rs b/python/src/table.rs index f9f7f995..2097909a 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ + connection::Connection, error::PythonErrorExt, index::{extract_index_params, IndexConfig}, query::{Query, TakeQuery}, @@ -249,7 +250,7 @@ impl Table { } impl Table { - fn inner_ref(&self) -> PyResult<&LanceDbTable> { + pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> { self.inner .as_ref() .ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name))) @@ -272,6 +273,13 @@ impl Table { self.inner.take(); } + pub fn database(&self) -> PyResult { + let inner = self.inner_ref()?.clone(); + let inner_connection = + lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone()); + Ok(Connection::new(inner_connection)) + } + pub fn schema(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 97934a34..aee78336 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ahash = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-data = { workspace = true } @@ -24,12 +25,16 @@ datafusion-common.workspace = true datafusion-execution.workspace = true datafusion-expr.workspace = true datafusion-physical-plan.workspace = true +datafusion.workspace = true object_store = { workspace = true } snafu = { workspace = true } half = { workspace = true } lazy_static.workspace = true lance = { workspace = true } +lance-core = { workspace = true } lance-datafusion.workspace = true +lance-datagen = { workspace = true } +lance-file = { workspace = true } lance-io = { workspace = true } lance-index = { workspace = true } lance-table = { workspace = true } @@ -46,11 +51,13 @@ bytes = "1" futures.workspace = true num-traits.workspace = true url.workspace = true +rand.workspace = true regex.workspace = true serde = { version = "^1" } serde_json = { version = "1" } async-openai = { version = "0.20.0", optional = true } serde_with = { version = "3.8.1" } +tempfile = "3.5.0" aws-sdk-bedrockruntime = { version = "1.27.0", optional = true } # For remote feature reqwest = { version = "0.12.0", default-features = false, features = [ @@ -61,9 +68,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [ "macos-system-configuration", "stream", ], optional = true } -rand = { version = "0.9", features = ["small_rng"], optional = true } http = { version = "1", optional = true } # Matching what is in reqwest -uuid = { version = "1.7.0", features = ["v4"], optional = true } +uuid = { version = "1.7.0", features = ["v4"] } polars-arrow = { version = ">=0.37,<0.40.0", optional = 
true } polars = { version = ">=0.37,<0.40.0", optional = true } hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [ @@ -84,7 +90,6 @@ bytemuck_derive.workspace = true [dev-dependencies] anyhow = "1" tempfile = "3.5.0" -rand = { version = "0.9", features = ["small_rng"] } random_word = { version = "0.4.3", features = ["en"] } uuid = { version = "1.7.0", features = ["v4"] } walkdir = "2" @@ -96,6 +101,7 @@ aws-smithy-runtime = { version = "1.9.1" } datafusion.workspace = true http-body = "1" # Matching reqwest rstest = "0.23.0" +test-log = "0.2" [features] @@ -105,7 +111,7 @@ oss = ["lance/oss", "lance-io/oss"] gcs = ["lance/gcp", "lance-io/gcp"] azure = ["lance/azure", "lance-io/azure"] dynamodb = ["lance/dynamodb", "aws"] -remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"] +remote = ["dep:reqwest", "dep:http"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] bedrock = ["dep:aws-sdk-bedrockruntime"] diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index 133f5d1f..df6e51a3 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -7,6 +7,7 @@ pub use arrow_schema; use datafusion_common::DataFusionError; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use futures::{Stream, StreamExt, TryStreamExt}; +use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount}; #[cfg(feature = "polars")] use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame}; @@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream { } } +pub trait LanceDbDatagenExt { + fn into_ldb_stream( + self, + batch_size: RowCount, + num_batches: BatchCount, + ) -> SendableRecordBatchStream; +} + +impl LanceDbDatagenExt for BatchGeneratorBuilder { + fn into_ldb_stream( + self, + batch_size: RowCount, + num_batches: BatchCount, + ) -> SendableRecordBatchStream { + let (stream, schema) = self.into_reader_stream(batch_size, num_batches); + let stream = stream.map_err(|err| Error::Arrow { source: err }); + Box::pin(SimpleRecordBatchStream::new(stream, schema)) + } +} + #[cfg(feature = "polars")] /// An iterator of record batches formed from a Polars DataFrame. pub struct PolarsDataFrameRecordBatchReader { diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 66e24161..c8e3496d 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -19,7 +19,7 @@ use crate::database::listing::{ use crate::database::{ CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest, - OpenTableRequest, TableNamesRequest, + OpenTableRequest, ReadConsistency, TableNamesRequest, }; use crate::embeddings::{ EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings, @@ -152,6 +152,7 @@ impl CreateTableBuilder { let request = self.into_request()?; Ok(Table::new_with_embedding_registry( parent.create_table(request).await?, + parent, embedding_registry, )) } @@ -211,9 +212,9 @@ impl CreateTableBuilder { /// Execute the create table operation pub async fn execute(self) -> Result
{ - Ok(Table::new( - self.parent.clone().create_table(self.request).await?, - )) + let parent = self.parent.clone(); + let table = parent.create_table(self.request).await?; + Ok(Table::new(table, parent)) } } @@ -462,8 +463,10 @@ impl OpenTableBuilder { /// Open the table pub async fn execute(self) -> Result
{ + let table = self.parent.open_table(self.request).await?; Ok(Table::new_with_embedding_registry( - self.parent.clone().open_table(self.request).await?, + table, + self.parent, self.embedding_registry, )) } @@ -519,16 +522,15 @@ impl CloneTableBuilder { /// Execute the clone operation pub async fn execute(self) -> Result
{ - Ok(Table::new( - self.parent.clone().clone_table(self.request).await?, - )) + let parent = self.parent.clone(); + let table = parent.clone_table(self.request).await?; + Ok(Table::new(table, parent)) } } /// A connection to LanceDB #[derive(Clone)] pub struct Connection { - uri: String, internal: Arc, embedding_registry: Arc, } @@ -540,9 +542,19 @@ impl std::fmt::Display for Connection { } impl Connection { + pub fn new( + internal: Arc, + embedding_registry: Arc, + ) -> Self { + Self { + internal, + embedding_registry, + } + } + /// Get the URI of the connection pub fn uri(&self) -> &str { - self.uri.as_str() + self.internal.uri() } /// Get access to the underlying database @@ -675,6 +687,11 @@ impl Connection { .await } + /// Get the read consistency of the connection + pub async fn read_consistency(&self) -> Result { + self.internal.read_consistency().await + } + /// Drop a table in the database. /// /// # Arguments @@ -973,7 +990,6 @@ impl ConnectBuilder { )?); Ok(Connection { internal, - uri: self.request.uri, embedding_registry: self .embedding_registry .unwrap_or_else(|| Arc::new(MemoryRegistry::new())), @@ -996,7 +1012,6 @@ impl ConnectBuilder { let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?); Ok(Connection { internal, - uri: self.request.uri, embedding_registry: self .embedding_registry .unwrap_or_else(|| Arc::new(MemoryRegistry::new())), @@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder { Ok(Connection { internal, - uri: format!("namespace://{}", self.ns_impl), embedding_registry: self .embedding_registry .unwrap_or_else(|| Arc::new(MemoryRegistry::new())), @@ -1139,7 +1153,6 @@ mod test_utils { let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler)); Self { internal, - uri: "db://test".to_string(), embedding_registry: Arc::new(MemoryRegistry::new()), } } @@ -1156,7 +1169,6 @@ mod test_utils { )); Self { internal, - uri: "db://test".to_string(), embedding_registry: Arc::new(MemoryRegistry::new()), } } @@ -1187,7 +1199,7 @@ mod tests { #[tokio::test] async fn test_connect() { let tc = new_test_connection().await.unwrap(); - assert_eq!(tc.connection.uri, tc.uri); + assert_eq!(tc.connection.uri(), tc.uri); } #[cfg(not(windows))] @@ -1208,7 +1220,7 @@ mod tests { .await .unwrap(); - assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string()); + assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string()); } #[tokio::test] diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs index e5233920..7816518e 100644 --- a/rust/lancedb/src/database.rs +++ b/rust/lancedb/src/database.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use arrow_array::RecordBatchReader; use async_trait::async_trait; @@ -213,6 +214,20 @@ impl CloneTableRequest { } } +/// How long until a change is reflected from one Table instance to another +/// +/// Tables are always internally consistent. If a write method is called on +/// a table instance it will be immediately visible in that same table instance. +pub enum ReadConsistency { + /// Changes will not be automatically propagated until the checkout_latest + /// method is called on the target table + Manual, + /// Changes will be propagated automatically within the given duration + Eventual(Duration), + /// Changes are immediately visible in target tables + Strong, +} + /// The `Database` trait defines the interface for database implementations. /// /// A database is responsible for managing tables and their metadata. 
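As a quick illustration (not part of this diff), a caller might inspect the new read consistency value exposed by `Connection::read_consistency()` roughly as follows; the helper name and printed messages are placeholders:

```rust
use lancedb::{connection::Connection, database::ReadConsistency, Result};

// Illustrative only: report how quickly writes made through other Table
// instances become visible on tables opened through `conn`.
async fn describe_consistency(conn: &Connection) -> Result<()> {
    match conn.read_consistency().await? {
        // Changes only show up after checkout_latest() on the target table
        ReadConsistency::Manual => println!("manual"),
        // Changes show up within the given interval
        ReadConsistency::Eventual(interval) => println!("eventual ({interval:?})"),
        // Changes are visible immediately
        ReadConsistency::Strong => println!("strong"),
    }
    Ok(())
}
```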
@@ -220,6 +235,10 @@ impl CloneTableRequest { pub trait Database: Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static { + /// Get the uri of the database + fn uri(&self) -> &str; + /// Get the read consistency of the database + async fn read_consistency(&self) -> Result; /// List immediate child namespace names in the given namespace async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result>; /// Create a new namespace diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs index 5aa7ac7b..bc1af880 100644 --- a/rust/lancedb/src/database/listing.rs +++ b/rust/lancedb/src/database/listing.rs @@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem; use snafu::ResultExt; use crate::connection::ConnectRequest; +use crate::database::ReadConsistency; use crate::error::{CreateDirSnafu, Error, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; use crate::table::NativeTable; @@ -598,6 +599,22 @@ impl Database for ListingDatabase { Ok(Vec::new()) } + fn uri(&self) -> &str { + &self.uri + } + + async fn read_consistency(&self) -> Result { + if let Some(read_consistency_inverval) = self.read_consistency_interval { + if read_consistency_inverval.is_zero() { + Ok(ReadConsistency::Strong) + } else { + Ok(ReadConsistency::Eventual(read_consistency_inverval)) + } + } else { + Ok(ReadConsistency::Manual) + } + } + async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> { Err(Error::NotSupported { message: "Namespace operations are not supported for listing database".into(), @@ -1249,7 +1266,8 @@ mod tests { ) .unwrap(); - let source_table_obj = Table::new(source_table.clone()); + let db = Arc::new(db); + let source_table_obj = Table::new(source_table.clone(), db.clone()); source_table_obj .add(Box::new(arrow_array::RecordBatchIterator::new( vec![Ok(batch2)], @@ -1320,7 +1338,8 @@ mod tests { .unwrap(); // Create a tag for the current version - let source_table_obj = Table::new(source_table.clone()); + let db = Arc::new(db); + let source_table_obj = Table::new(source_table.clone(), db.clone()); let mut tags = source_table_obj.tags().await.unwrap(); tags.create("v1.0", source_table.version().await.unwrap()) .await @@ -1336,7 +1355,7 @@ mod tests { ) .unwrap(); - let source_table_obj = Table::new(source_table.clone()); + let source_table_obj = Table::new(source_table.clone(), db.clone()); source_table_obj .add(Box::new(arrow_array::RecordBatchIterator::new( vec![Ok(batch2)], @@ -1432,7 +1451,8 @@ mod tests { ) .unwrap(); - let cloned_table_obj = Table::new(cloned_table.clone()); + let db = Arc::new(db); + let cloned_table_obj = Table::new(cloned_table.clone(), db.clone()); cloned_table_obj .add(Box::new(arrow_array::RecordBatchIterator::new( vec![Ok(batch_clone)], @@ -1452,7 +1472,7 @@ mod tests { ) .unwrap(); - let source_table_obj = Table::new(source_table.clone()); + let source_table_obj = Table::new(source_table.clone(), db); source_table_obj .add(Box::new(arrow_array::RecordBatchIterator::new( vec![Ok(batch_source)], @@ -1495,6 +1515,7 @@ mod tests { .unwrap(); // Add more data to create new versions + let db = Arc::new(db); for i in 0..3 { let batch = RecordBatch::try_new( schema.clone(), @@ -1502,7 +1523,7 @@ mod tests { ) .unwrap(); - let source_table_obj = Table::new(source_table.clone()); + let source_table_obj = Table::new(source_table.clone(), db.clone()); source_table_obj .add(Box::new(arrow_array::RecordBatchIterator::new( vec![Ok(batch)], diff --git 
a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 97f8be39..37d928f8 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -16,9 +16,9 @@ use lance_namespace::{ LanceNamespace, }; -use crate::connection::ConnectRequest; use crate::database::listing::ListingDatabase; use crate::error::{Error, Result}; +use crate::{connection::ConnectRequest, database::ReadConsistency}; use super::{ BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest, @@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase { read_consistency_interval: Option, // Optional session for object stores and caching session: Option>, + // database URI + uri: String, } impl LanceNamespaceDatabase { @@ -57,6 +59,7 @@ impl LanceNamespaceDatabase { storage_options, read_consistency_interval, session, + uri: format!("namespace://{}", ns_impl), }) } @@ -130,6 +133,22 @@ impl std::fmt::Display for LanceNamespaceDatabase { #[async_trait] impl Database for LanceNamespaceDatabase { + fn uri(&self) -> &str { + &self.uri + } + + async fn read_consistency(&self) -> Result { + if let Some(read_consistency_inverval) = self.read_consistency_interval { + if read_consistency_inverval.is_zero() { + Ok(ReadConsistency::Strong) + } else { + Ok(ReadConsistency::Eventual(read_consistency_inverval)) + } + } else { + Ok(ReadConsistency::Manual) + } + } + async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result> { let ns_request = ListNamespacesRequest { id: if request.namespace.is_empty() { diff --git a/rust/lancedb/src/dataloader.rs b/rust/lancedb/src/dataloader.rs new file mode 100644 index 00000000..cbb7f037 --- /dev/null +++ b/rust/lancedb/src/dataloader.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +pub mod permutation; +pub mod shuffle; +pub mod split; +pub mod util; diff --git a/rust/lancedb/src/dataloader/permutation.rs b/rust/lancedb/src/dataloader/permutation.rs new file mode 100644 index 00000000..09a39d93 --- /dev/null +++ b/rust/lancedb/src/dataloader/permutation.rs @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table. +//! +//! A permutation view can apply a filter, divide the data into splits, and shuffle the data. +//! The permutation table only stores the split ids and row ids. It is not a materialized copy of +//! the underlying data and can be very lightweight. +//! +//! Building a permutation table should be fairly quick and memory efficient, even for billions or +//! trillions of rows. 
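Before the implementation, a rough usage sketch (illustrative only, mirroring the test at the bottom of this file); the filter, split ratios, seed, and table name below are placeholders:

```rust
use lancedb::{
    dataloader::{
        permutation::{PermutationBuilder, ShuffleStrategy},
        split::{SplitSizes, SplitStrategy},
    },
    Result, Table,
};

// Illustrative only: filter, split 80/20, shuffle, and materialize the
// permutation as a new lightweight table of (row_id, split_id) pairs.
async fn build_train_test_permutation(base_table: Table) -> Result<Table> {
    PermutationBuilder::new(base_table)
        .with_filter("label IS NOT NULL".to_string())
        .with_split_strategy(SplitStrategy::Random {
            seed: Some(42),
            sizes: SplitSizes::Percentages(vec![0.8, 0.2]),
        })
        .with_shuffle_strategy(ShuffleStrategy::Random {
            seed: Some(42),
            clump_size: None,
        })
        .build("train_test_permutation")
        .await
}
```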
+ +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder}; +use datafusion_expr::col; +use futures::TryStreamExt; +use lance_datafusion::exec::SessionContextExt; + +use crate::{ + arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream}, + dataloader::{ + shuffle::{Shuffler, ShufflerConfig}, + split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN}, + util::{rename_column, TemporaryDirectory}, + }, + query::{ExecutableQuery, QueryBase}, + Connection, Error, Result, Table, +}; + +/// Configuration for creating a permutation table +#[derive(Debug, Default)] +pub struct PermutationConfig { + /// Splitting configuration + pub split_strategy: SplitStrategy, + /// Shuffle strategy + pub shuffle_strategy: ShuffleStrategy, + /// Optional filter to apply to the base table + pub filter: Option, + /// Directory to use for temporary files + pub temp_dir: TemporaryDirectory, +} + +/// Strategy for shuffling the data. +#[derive(Debug, Clone)] +pub enum ShuffleStrategy { + /// The data is randomly shuffled + /// + /// A seed can be provided to make the shuffle deterministic. + /// + /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows. + /// This decreases the overall randomization but can improve I/O performance when reading from + /// cloud storage. + /// + /// For example, a clump size of 16 will means we will shuffle blocks of 16 contiguous rows. This + /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence + /// the performance of the model. Note: shuffling within clumps can still be done at read time but + /// this will only provide a local shuffle and not a global shuffle. + Random { + seed: Option, + clump_size: Option, + }, + /// The data is not shuffled + /// + /// This is useful for debugging and testing. + None, +} + +impl Default for ShuffleStrategy { + fn default() -> Self { + Self::None + } +} + +/// Builder for creating a permutation table. +/// +/// A permutation table is a table that stores split assignments and a shuffled order of rows. This +/// can be used to create a +pub struct PermutationBuilder { + config: PermutationConfig, + base_table: Table, +} + +impl PermutationBuilder { + pub fn new(base_table: Table) -> Self { + Self { + config: PermutationConfig::default(), + base_table, + } + } + + /// Configures the strategy for assigning rows to splits. + /// + /// For example, it is common to create a test/train split of the data. Splits can also be used + /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can + /// create a single split with 10% of the data. + /// + /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across + /// multiple processes and multiple nodes. + /// + /// The default is a single split that contains all rows. + pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self { + self.config.split_strategy = split_strategy; + self + } + + /// Configures the strategy for shuffling the data. + /// + /// The default is to shuffle the data randomly at row-level granularity (no shard size) and + /// with a random seed. + pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self { + self.config.shuffle_strategy = shuffle_strategy; + self + } + + /// Configures a filter to apply to the base table. 
+ /// + /// Only rows matching the filter will be included in the permutation. + pub fn with_filter(mut self, filter: String) -> Self { + self.config.filter = Some(filter); + self + } + + /// Configures the directory to use for temporary files. + /// + /// The default is to use the operating system's default temporary directory. + pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self { + self.config.temp_dir = temp_dir; + self + } + + async fn sort_by_split_id( + &self, + data: SendableRecordBatchStream, + ) -> Result { + let ctx = SessionContext::new_with_config_rt( + SessionConfig::default(), + RuntimeEnvBuilder::new() + .with_memory_limit(100 * 1024 * 1024, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default() + .with_mode(self.config.temp_dir.to_disk_manager_mode()), + ) + .build_arc() + .unwrap(), + ); + let df = ctx + .read_one_shot(data.into_df_stream()) + .map_err(|e| Error::Other { + message: format!("Failed to setup sort by split id: {}", e), + source: Some(e.into()), + })?; + let df_stream = df + .sort_by(vec![col(SPLIT_ID_COLUMN)]) + .map_err(|e| Error::Other { + message: format!("Failed to plan sort by split id: {}", e), + source: Some(e.into()), + })? + .execute_stream() + .await + .map_err(|e| Error::Other { + message: format!("Failed to sort by split id: {}", e), + source: Some(e.into()), + })?; + + let schema = df_stream.schema(); + let stream = df_stream.map_err(|e| Error::Other { + message: format!("Failed to execute sort by split id: {}", e), + source: Some(e.into()), + }); + Ok(Box::pin(SimpleRecordBatchStream { schema, stream })) + } + + /// Builds the permutation table and stores it in the given database. + pub async fn build(self, dest_table_name: &str) -> Result
{ + // First pass, apply filter and load row ids + let mut rows = self.base_table.query().with_row_id(); + + if let Some(filter) = &self.config.filter { + rows = rows.only_if(filter); + } + + let splitter = Splitter::new( + self.config.temp_dir.clone(), + self.config.split_strategy.clone(), + ); + + let mut needs_sort = !splitter.orders_by_split_id(); + + // Might need to load additional columns to calculate splits (e.g. hash columns or calculated + // split id) + rows = splitter.project(rows); + + let num_rows = self + .base_table + .count_rows(self.config.filter.clone()) + .await? as u64; + + // Apply splits + let rows = rows.execute().await?; + let split_data = splitter.apply(rows, num_rows).await?; + + // Shuffle data if requested + let shuffled = match self.config.shuffle_strategy { + ShuffleStrategy::None => split_data, + ShuffleStrategy::Random { seed, clump_size } => { + let shuffler = Shuffler::new(ShufflerConfig { + seed, + clump_size, + temp_dir: self.config.temp_dir.clone(), + max_rows_per_file: 10 * 1024 * 1024, + }); + shuffler.shuffle(split_data, num_rows).await? + } + }; + + // We want the final permutation to be sorted by the split id. If we shuffled or if + // the split was not assigned sequentially then we need to sort the data. + needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None); + + let sorted = if needs_sort { + self.sort_by_split_id(shuffled).await? + } else { + shuffled + }; + + // Rename _rowid to row_id + let renamed = rename_column(sorted, "_rowid", "row_id")?; + + // Create permutation table + let conn = Connection::new( + self.base_table.database().clone(), + self.base_table.embedding_registry().clone(), + ); + conn.create_table_streaming(dest_table_name, renamed) + .execute() + .await + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Int32Type; + use lance_datagen::{BatchCount, RowCount}; + + use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes}; + + use super::*; + + #[tokio::test] + async fn test_permutation_builder() { + let temp_dir = tempfile::tempdir().unwrap(); + + let db = connect(temp_dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + + let initial_data = lance_datagen::gen_batch() + .col("some_value", lance_datagen::array::step::()) + .into_ldb_stream(RowCount::from(100), BatchCount::from(10)); + let data_table = db + .create_table_streaming("mytbl", initial_data) + .execute() + .await + .unwrap(); + + let permutation_table = PermutationBuilder::new(data_table) + .with_filter("some_value > 57".to_string()) + .with_split_strategy(SplitStrategy::Random { + seed: Some(42), + sizes: SplitSizes::Percentages(vec![0.05, 0.30]), + }) + .build("permutation") + .await + .unwrap(); + + // Potentially brittle seed-dependent values below + assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330); + assert_eq!( + permutation_table + .count_rows(Some("split_id = 0".to_string())) + .await + .unwrap(), + 47 + ); + assert_eq!( + permutation_table + .count_rows(Some("split_id = 1".to_string())) + .await + .unwrap(), + 283 + ); + } +} diff --git a/rust/lancedb/src/dataloader/shuffle.rs b/rust/lancedb/src/dataloader/shuffle.rs new file mode 100644 index 00000000..e06affc7 --- /dev/null +++ b/rust/lancedb/src/dataloader/shuffle.rs @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::{Arc, Mutex}; + +use arrow::compute::concat_batches; +use arrow_array::{RecordBatch, UInt64Array}; +use futures::{StreamExt, 
TryStreamExt}; +use lance::io::ObjectStore; +use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt}; +use lance_encoding::decoder::DecoderPlugins; +use lance_file::v2::{ + reader::{FileReader, FileReaderOptions}, + writer::{FileWriter, FileWriterOptions}, +}; +use lance_index::scalar::IndexReader; +use lance_io::{ + scheduler::{ScanScheduler, SchedulerConfig}, + utils::CachedFileSize, +}; +use rand::{seq::SliceRandom, Rng, RngCore}; + +use crate::{ + arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}, + dataloader::util::{non_crypto_rng, TemporaryDirectory}, + Error, Result, +}; + +#[derive(Debug, Clone)] +pub struct ShufflerConfig { + /// An optional seed to make the shuffle deterministic + pub seed: Option, + /// The maximum number of rows to write to a single file + /// + /// The shuffler will need to hold at least this many rows in memory. Setting this value + /// extremely large could cause the shuffler to use a lot of memory (depending on row size). + /// + /// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file + /// writers in memory. Each of these will consume some amount of data for column write buffers. + /// So setting this value too small could _also_ cause the shuffler to use a lot of memory and + /// open file handles. + pub max_rows_per_file: u64, + /// The temporary directory to use for writing files + pub temp_dir: TemporaryDirectory, + /// The size of the clumps to shuffle within + /// + /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows. + /// This decreases the overall randomization but can improve I/O performance when reading from + /// cloud storage. + pub clump_size: Option, +} + +impl Default for ShufflerConfig { + fn default() -> Self { + Self { + max_rows_per_file: 1024 * 1024, + seed: Option::default(), + temp_dir: TemporaryDirectory::default(), + clump_size: None, + } + } +} + +/// A shuffler that can shuffle a stream of record batches +/// +/// To do this the stream is consumed and written to temporary files. A new stream is returned +/// which returns the shuffled data from the temporary files. +/// +/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not +/// write any files and will instead perform an in-memory shuffle. +/// +/// The number of rows in the input stream must be known in advance. +pub struct Shuffler { + config: ShufflerConfig, + id: String, +} + +impl Shuffler { + pub fn new(config: ShufflerConfig) -> Self { + let id = uuid::Uuid::new_v4().to_string(); + Self { config, id } + } + + /// Shuffles a single batch of data in memory + fn shuffle_batch( + batch: &RecordBatch, + rng: &mut dyn RngCore, + clump_size: u64, + ) -> Result { + let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size); + let mut indices = (0..num_clumps).collect::>(); + indices.shuffle(rng); + let indices = if clump_size == 1 { + UInt64Array::from(indices) + } else { + UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| { + if clump_index == num_clumps - 1 { + clump_index * clump_size..batch.num_rows() as u64 + } else { + clump_index * clump_size..(clump_index + 1) * clump_size + } + })) + }; + Ok(arrow::compute::take_record_batch(batch, &indices)?) 
+ } + + async fn in_memory_shuffle( + &self, + data: SendableRecordBatchStream, + mut rng: Box, + ) -> Result { + let schema = data.schema(); + let batches = data.try_collect::>().await?; + let batch = concat_batches(&schema, &batches)?; + let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?; + log::debug!("Shuffle job {}: in-memory shuffle complete", self.id); + Ok(Box::pin(SimpleRecordBatchStream::new( + futures::stream::once(async move { Ok(shuffled) }), + schema, + ))) + } + + async fn do_shuffle( + &self, + mut data: SendableRecordBatchStream, + num_rows: u64, + mut rng: Box, + ) -> Result { + let num_files = num_rows.div_ceil(self.config.max_rows_per_file); + + let temp_dir = self.config.temp_dir.create_temp_dir()?; + let tmp_dir = temp_dir.path().to_path_buf(); + + let clump_size = self.config.clump_size.unwrap_or(1); + if clump_size == 0 { + return Err(Error::InvalidInput { + message: "clump size must be greater than 0".to_string(), + }); + } + + let object_store = ObjectStore::local(); + let arrow_schema = data.schema(); + let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?; + + // Create file writers + let mut file_writers = Vec::with_capacity(num_files as usize); + for file_index in 0..num_files { + let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id)); + let path = + object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other { + message: format!("Failed to create temporary file: {}", err), + source: None, + })?; + let object_writer = object_store.create(&path).await?; + let writer = + FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?; + file_writers.push(writer); + } + + let mut num_rows_seen = 0; + + // Randomly distribute clumps to files + while let Some(batch) = data.try_next().await? { + num_rows_seen += batch.num_rows() as u64; + let is_last = num_rows_seen == num_rows; + if num_rows_seen > num_rows { + return Err(Error::Runtime { + message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen), + }); + } + // This is kind of an annoying limitation but if we allow runt clumps from batches then + // clumps will get unaligned and we will mess up the clumps when we do the in-memory + // shuffle step. If this is a problem we can probably figure out a better way to do this. 
+ if !is_last && batch.num_rows() as u64 % clump_size != 0 { + return Err(Error::Runtime { + message: format!( + "Expected batch size ({}) to be divisible by clump size ({})", + batch.num_rows(), + clump_size + ), + }); + } + let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size); + let mut batch_offsets_for_files = + vec![Vec::::with_capacity(batch.num_rows()); num_files as usize]; + // Partition the batch randomly and write to the appropriate accumulator + for clump_offset in 0..num_clumps { + let clump_start = clump_offset * clump_size; + let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start); + let clump_end = clump_start + num_rows_in_clump; + let file_index = rng.random_range(0..num_files); + batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end); + } + for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() { + if batch_offsets.is_empty() { + continue; + } + let indices = UInt64Array::from(batch_offsets); + let partition = arrow::compute::take_record_batch(&batch, &indices)?; + file_writers[file_index].write_batch(&partition).await?; + } + } + + // Finish writing files + for (file_idx, mut writer) in file_writers.into_iter().enumerate() { + let num_written = writer.finish().await?; + log::debug!( + "Shuffle job {}: wrote {} rows to file {}", + self.id, + num_written, + file_idx + ); + } + + let scheduler_config = SchedulerConfig::max_bandwidth(&object_store); + let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config); + let job_id = self.id.clone(); + let rng = Arc::new(Mutex::new(rng)); + + // Second pass, read each file as a single batch and shuffle + let stream = futures::stream::iter(0..num_files) + .then(move |file_index| { + let scan_scheduler = scan_scheduler.clone(); + let rng = rng.clone(); + let tmp_dir = tmp_dir.clone(); + let job_id = job_id.clone(); + async move { + let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id)); + let path = object_store::path::Path::from_absolute_path(path).unwrap(); + let file_scheduler = scan_scheduler + .open_file(&path, &CachedFileSize::unknown()) + .await?; + let reader = FileReader::try_open( + file_scheduler, + None, + Arc::::default(), + &LanceCache::no_cache(), + FileReaderOptions::default(), + ) + .await?; + // Need to read the entire file in a single batch for in-memory shuffling + let batch = reader.read_record_batch(0, reader.num_rows()).await?; + let mut rng = rng.lock().unwrap(); + Self::shuffle_batch(&batch, &mut rng, clump_size) + } + }) + .finally(move || drop(temp_dir)) + .boxed(); + + Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema))) + } + + pub async fn shuffle( + self, + data: SendableRecordBatchStream, + num_rows: u64, + ) -> Result { + log::debug!( + "Shuffle job {}: shuffling {} rows and {} columns", + self.id, + num_rows, + data.schema().fields.len() + ); + let rng = non_crypto_rng(&self.config.seed); + + if num_rows < self.config.max_rows_per_file { + return self.in_memory_shuffle(data, rng).await; + } + + self.do_shuffle(data, num_rows, rng).await + } +} + +#[cfg(test)] +mod tests { + use crate::arrow::LanceDbDatagenExt; + + use super::*; + use arrow::{array::AsArray, datatypes::Int32Type}; + use datafusion::prelude::SessionContext; + use datafusion_expr::col; + use futures::TryStreamExt; + use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed}; + use rand::{rngs::SmallRng, SeedableRng}; + + fn test_gen() -> BatchGeneratorBuilder { + 
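+        // Deterministic test data (seed 42): a sequential Int32 "id" column and a random 10-byte utf8 "name" column, shared by the shuffle tests below.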
lance_datagen::gen_batch() + .with_seed(Seed::from(42)) + .col("id", lance_datagen::array::step::()) + .col( + "name", + lance_datagen::array::rand_utf8(ByteCount::from(10), false), + ) + } + + fn create_test_batch(size: RowCount) -> RecordBatch { + test_gen().into_batch_rows(size).unwrap() + } + + fn create_test_stream( + num_batches: BatchCount, + batch_size: RowCount, + ) -> SendableRecordBatchStream { + test_gen().into_ldb_stream(batch_size, num_batches) + } + + #[test] + fn test_shuffle_batch_deterministic() { + let batch = create_test_batch(RowCount::from(10)); + let mut rng1 = SmallRng::seed_from_u64(42); + let mut rng2 = SmallRng::seed_from_u64(42); + + let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap(); + let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap(); + + // Same seed should produce same shuffle + assert_eq!(shuffled1, shuffled2); + } + + #[test] + fn test_shuffle_with_clumps() { + let batch = create_test_batch(RowCount::from(10)); + let mut rng = SmallRng::seed_from_u64(42); + let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap(); + let values = shuffled.column(0).as_primitive::(); + + let mut iter = values.into_iter().map(|o| o.unwrap()); + let mut frag_seen = false; + let mut clumps_seen = 0; + while let Some(first) = iter.next() { + // 9 is the last value and not a full clump + if first != 9 { + // Otherwise we should have a full clump + let second = iter.next().unwrap(); + let third = iter.next().unwrap(); + assert_eq!(first + 1, second); + assert_eq!(first + 2, third); + clumps_seen += 1; + } else { + frag_seen = true; + } + } + assert_eq!(clumps_seen, 3); + assert!(frag_seen); + } + + async fn sort_batch(batch: RecordBatch) -> RecordBatch { + let ctx = SessionContext::new(); + let df = ctx.read_batch(batch).unwrap(); + let sorted = df.sort_by(vec![col("id")]).unwrap(); + let batches = sorted.collect().await.unwrap(); + let schema = batches[0].schema(); + concat_batches(&schema, &batches).unwrap() + } + + #[tokio::test] + async fn test_shuffle_batch_preserves_data() { + let batch = create_test_batch(RowCount::from(100)); + let mut rng = SmallRng::seed_from_u64(42); + + let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap(); + + assert_ne!(shuffled, batch); + + let sorted = sort_batch(shuffled).await; + + assert_eq!(sorted, batch); + } + + #[test] + fn test_shuffle_batch_empty() { + let batch = create_test_batch(RowCount::from(0)); + let mut rng = SmallRng::seed_from_u64(42); + + let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap(); + assert_eq!(shuffled.num_rows(), 0); + } + + #[tokio::test] + async fn test_in_memory_shuffle() { + let config = ShufflerConfig { + temp_dir: TemporaryDirectory::None, + ..Default::default() + }; + let shuffler = Shuffler::new(config); + + let stream = create_test_stream(BatchCount::from(5), RowCount::from(20)); + + let result_stream = shuffler.shuffle(stream, 100).await.unwrap(); + let result_batches: Vec = result_stream.try_collect().await.unwrap(); + + assert_eq!(result_batches.len(), 1); + let result_batch = result_batches.into_iter().next().unwrap(); + + let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20)) + .try_collect::>() + .await + .unwrap(); + let schema = unshuffled_batches[0].schema(); + let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap(); + + let sorted = sort_batch(result_batch).await; + + assert_eq!(unshuffled_batch, sorted); + } + + #[tokio::test] + async fn test_external_shuffle() { + 
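+        // 5000 input rows with max_rows_per_file = 100 forces the file-backed path: the shuffler writes 50 temporary files and reads each back as a single batch, so we expect 50 output batches that together contain the same rows as the input.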
let config = ShufflerConfig { + max_rows_per_file: 100, + ..Default::default() + }; + let shuffler = Shuffler::new(config); + + let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000)); + + let result_stream = shuffler.shuffle(stream, 5000).await.unwrap(); + let result_batches: Vec = result_stream.try_collect().await.unwrap(); + + let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000)) + .try_collect::>() + .await + .unwrap(); + let schema = unshuffled_batches[0].schema(); + let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap(); + + assert_eq!(result_batches.len(), 50); + let result_batch = concat_batches(&schema, &result_batches).unwrap(); + + let sorted = sort_batch(result_batch).await; + + assert_eq!(unshuffled_batch, sorted); + } + + #[test_log::test(tokio::test)] + async fn test_external_clump_shuffle() { + let config = ShufflerConfig { + max_rows_per_file: 100, + clump_size: Some(30), + ..Default::default() + }; + let shuffler = Shuffler::new(config); + + // Batch size (900) must be multiple of clump size (30) + let stream = create_test_stream(BatchCount::from(5), RowCount::from(900)); + let schema = stream.schema(); + + // Remove 10 rows from the last batch to simulate ending with partial clump + let mut batches = stream.try_collect::>().await.unwrap(); + let last_index = batches.len() - 1; + let sliced_last = batches[last_index].slice(0, 890); + batches[last_index] = sliced_last; + + let stream = Box::pin(SimpleRecordBatchStream::new( + futures::stream::iter(batches).map(Ok).boxed(), + schema.clone(), + )); + + let result_stream = shuffler.shuffle(stream, 4490).await.unwrap(); + let result_batches: Vec = result_stream.try_collect().await.unwrap(); + let result_batch = concat_batches(&schema, &result_batches).unwrap(); + + let ids = result_batch.column(0).as_primitive::(); + let mut iter = ids.into_iter().map(|o| o.unwrap()); + while let Some(first) = iter.next() { + let rows_left_in_clump = if first == 4470 { 19 } else { 29 }; + let mut expected_next = first + 1; + for _ in 0..rows_left_in_clump { + let next = iter.next().unwrap(); + assert_eq!(next, expected_next); + expected_next += 1; + } + } + } +} diff --git a/rust/lancedb/src/dataloader/split.rs b/rust/lancedb/src/dataloader/split.rs new file mode 100644 index 00000000..fd9abfb8 --- /dev/null +++ b/rust/lancedb/src/dataloader/split.rs @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::{ + iter, + sync::{ + atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, + Arc, + }, +}; + +use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::hash_utils::create_hashes; +use futures::{StreamExt, TryStreamExt}; +use lance::arrow::SchemaExt; + +use crate::{ + arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}, + dataloader::{ + shuffle::{Shuffler, ShufflerConfig}, + util::TemporaryDirectory, + }, + query::{Query, QueryBase, Select}, + Error, Result, +}; + +pub const SPLIT_ID_COLUMN: &str = "split_id"; + +/// Strategy for assigning rows to splits +#[derive(Debug, Clone)] +pub enum SplitStrategy { + /// All rows will have split id 0 + NoSplit, + /// Rows will be randomly assigned to splits + /// + /// A seed can be provided to make the assignment deterministic. + Random { + seed: Option, + sizes: SplitSizes, + }, + /// Rows will be assigned to splits based on the values in the specified columns. 
+ /// + /// This will ensure rows are always assigned to the same split if the given columns do not change. + /// + /// The `split_weights` are used to determine the approximate number of rows in each split. This + /// controls how we divide up the u64 hash space. However, it does not guarantee any particular division + /// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split + /// regardless of the weights. + /// + /// The `discard_weight` controls what percentage of rows should be throw away. For example, if you want your + /// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set + /// split_weights to [1, 2] and discard weight to 17 (or you could set split_weights to [5, 10] and discard_weight + /// to 85). If you set discard_weight to 0 then all rows will be assigned to a split. + Hash { + columns: Vec, + split_weights: Vec, + discard_weight: u64, + }, + /// Rows will be assigned to splits sequentially. + /// + /// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc. + /// + /// This is mainly useful for debugging and testing. + Sequential { sizes: SplitSizes }, + /// Rows will be assigned to splits based on a calculation of one or more columns. + /// + /// This is useful when the splits already exist in the base table. + /// + /// The provided `calculation` should be an SQL statement that returns an integer value between + /// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration). + /// + /// If this strategy is used then the counts/percentages in the SplitSizes are ignored. + Calculated { calculation: String }, +} + +// The default is not to split the data +// +// All data will be assigned to a single split. +impl Default for SplitStrategy { + fn default() -> Self { + Self::NoSplit + } +} + +impl SplitStrategy { + pub fn validate(&self, num_rows: u64) -> Result<()> { + match self { + Self::NoSplit => Ok(()), + Self::Random { sizes, .. } => sizes.validate(num_rows), + Self::Hash { + split_weights, + columns, + .. + } => { + if columns.is_empty() { + return Err(Error::InvalidInput { + message: "Hash strategy requires at least one column".to_string(), + }); + } + if split_weights.is_empty() { + return Err(Error::InvalidInput { + message: "Hash strategy requires at least one split weight".to_string(), + }); + } + if split_weights.iter().any(|w| *w == 0) { + return Err(Error::InvalidInput { + message: "Split weights must be greater than 0".to_string(), + }); + } + Ok(()) + } + Self::Sequential { sizes } => sizes.validate(num_rows), + Self::Calculated { .. 
} => Ok(()), + } + } +} + +pub struct Splitter { + temp_dir: TemporaryDirectory, + strategy: SplitStrategy, +} + +impl Splitter { + pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self { + Self { temp_dir, strategy } + } + + fn sequential_split_id( + num_rows: u64, + split_sizes: &[u64], + split_index: &AtomicUsize, + counter_in_split: &AtomicU64, + exhausted: &AtomicBool, + ) -> UInt64Array { + let mut split_ids = Vec::::with_capacity(num_rows as usize); + + while split_ids.len() < num_rows as usize { + let split_id = split_index.load(Ordering::Relaxed); + let counter = counter_in_split.load(Ordering::Relaxed); + + let split_size = split_sizes[split_id]; + let remaining_in_split = split_size - counter; + + let remaining_in_batch = num_rows - split_ids.len() as u64; + + let mut done = false; + let rows_to_add = if remaining_in_batch < remaining_in_split { + counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed); + remaining_in_batch + } else { + split_index.fetch_add(1, Ordering::Relaxed); + counter_in_split.store(0, Ordering::Relaxed); + if split_id == split_sizes.len() - 1 { + exhausted.store(true, Ordering::Relaxed); + done = true; + } + remaining_in_split + }; + + split_ids.extend(iter::repeat_n(split_id as u64, rows_to_add as usize)); + if done { + // Quit early if we've run out of splits + break; + } + } + + UInt64Array::from(split_ids) + } + + async fn apply_sequential( + &self, + source: SendableRecordBatchStream, + num_rows: u64, + sizes: &SplitSizes, + ) -> Result { + let split_sizes = sizes.to_counts(num_rows); + + let split_index = AtomicUsize::new(0); + let counter_in_split = AtomicU64::new(0); + let exhausted = AtomicBool::new(false); + + let schema = source.schema(); + + let new_schema = Arc::new(schema.try_with_column(Field::new( + SPLIT_ID_COLUMN, + DataType::UInt64, + false, + ))?); + + let new_schema_clone = new_schema.clone(); + let stream = source.filter_map(move |batch| { + let batch = match batch { + Ok(batch) => batch, + Err(e) => { + return std::future::ready(Some(Err(e))); + } + }; + + if exhausted.load(Ordering::Relaxed) { + return std::future::ready(None); + } + + let split_ids = Self::sequential_split_id( + batch.num_rows() as u64, + &split_sizes, + &split_index, + &counter_in_split, + &exhausted, + ); + + let mut arrays = batch.columns().to_vec(); + // This can happen if we exhaust all splits in the middle of a batch + if split_ids.len() < batch.num_rows() { + arrays = arrays + .iter() + .map(|arr| arr.slice(0, split_ids.len())) + .collect(); + } + arrays.push(Arc::new(split_ids)); + + std::future::ready(Some(Ok( + RecordBatch::try_new(new_schema.clone(), arrays).unwrap() + ))) + }); + + Ok(Box::pin(SimpleRecordBatchStream::new( + stream, + new_schema_clone, + ))) + } + + fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array { + let arrays = batch + .columns() + .iter() + // Don't hash the last column which should always be the row id + .take(batch.columns().len() - 1) + .cloned() + .collect::>(); + let mut hashes = vec![0; batch.num_rows()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + create_hashes(&arrays, &random_state, &mut hashes).unwrap(); + // As an example, let's assume the weights are 1, 2. Our total weight is 3. + // + // Our thresholds are [1, 3] + // Our modulo output will be 0, 1, or 2. 
+        //
+        // thresholds.binary_search(0) => Err(0) => 0
+        // thresholds.binary_search(1) => Ok(0) => 1
+        // thresholds.binary_search(2) => Err(1) => 1
+        let split_ids = hashes
+            .iter()
+            .map(|h| {
+                let h = h % total_weight;
+                let split_id = match thresholds.binary_search(&h) {
+                    Ok(i) => (i + 1) as u64,
+                    Err(i) => i as u64,
+                };
+                if split_id == thresholds.len() as u64 {
+                    // If we're at the last threshold then we discard the row (indicated by setting
+                    // the split_id to null)
+                    None
+                } else {
+                    Some(split_id)
+                }
+            })
+            .collect::<Vec<_>>();
+        UInt64Array::from(split_ids)
+    }
+
+    async fn apply_hash(
+        &self,
+        source: SendableRecordBatchStream,
+        weights: &[u64],
+        discard_weight: u64,
+    ) -> Result<SendableRecordBatchStream> {
+        let row_id_index = source.schema().fields.len() - 1;
+        let new_schema = Arc::new(Schema::new(vec![
+            source.schema().field(row_id_index).clone(),
+            Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
+        ]));
+
+        let total_weight = weights.iter().sum::<u64>() + discard_weight;
+        // Thresholds are the cumulative sum of the weights
+        let mut offset = 0;
+        let thresholds = weights
+            .iter()
+            .map(|w| {
+                let value = offset + w;
+                offset = value;
+                value
+            })
+            .collect::<Vec<_>>();
+
+        let new_schema_clone = new_schema.clone();
+        let stream = source.map_ok(move |batch| {
+            let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
+
+            if split_ids.null_count() > 0 {
+                let is_valid = split_ids.nulls().unwrap().inner();
+                let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
+                let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
+                let row_ids = batch.column(row_id_index);
+                let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
+                RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
+            } else {
+                RecordBatch::try_new(
+                    new_schema.clone(),
+                    vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
+                )
+                .unwrap()
+            }
+        });
+
+        Ok(Box::pin(SimpleRecordBatchStream::new(
+            stream,
+            new_schema_clone,
+        )))
+    }
+
+    pub async fn apply(
+        &self,
+        source: SendableRecordBatchStream,
+        num_rows: u64,
+    ) -> Result<SendableRecordBatchStream> {
+        self.strategy.validate(num_rows)?;
+
+        match &self.strategy {
+            // For consistency, even if no-split, we still give a split id column of all 0s
+            SplitStrategy::NoSplit => {
+                self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
+                    .await
+            }
+            SplitStrategy::Random { seed, sizes } => {
+                let shuffler = Shuffler::new(ShufflerConfig {
+                    seed: *seed,
+                    // In this case we are only shuffling row ids so we can use a large max_rows_per_file
+                    max_rows_per_file: 10 * 1024 * 1024,
+                    temp_dir: self.temp_dir.clone(),
+                    clump_size: None,
+                });
+
+                let shuffled = shuffler.shuffle(source, num_rows).await?;
+
+                self.apply_sequential(shuffled, num_rows, sizes).await
+            }
+            SplitStrategy::Sequential { sizes } => {
+                self.apply_sequential(source, num_rows, sizes).await
+            }
+            // Nothing to do, split is calculated in projection
+            SplitStrategy::Calculated { .. } => Ok(source),
+            SplitStrategy::Hash {
+                split_weights,
+                discard_weight,
+                ..
+            } => {
+                self.apply_hash(source, split_weights, *discard_weight)
+                    .await
+            }
+        }
+    }
+
+    pub fn project(&self, query: Query) -> Query {
+        match &self.strategy {
+            SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
+                SPLIT_ID_COLUMN.to_string(),
+                calculation.clone(),
+            )])),
+            SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
+            _ => query,
+        }
+    }
+
+    pub fn orders_by_split_id(&self) -> bool {
+        match &self.strategy {
+            SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => true,
+            SplitStrategy::NoSplit
+            | SplitStrategy::Sequential { .. }
+            // It may be strange but for random we shuffle and then assign splits so the result is
+            // sorted by split id
+            | SplitStrategy::Random { .. } => false,
+        }
+    }
+}
+
+/// Split configuration - either percentages or absolute counts
+///
+/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
+/// the remaining rows will not be included in the permutation.
+///
+/// The default implementation assigns all rows to a single split.
+#[derive(Debug, Clone)]
+pub enum SplitSizes {
+    /// Percentage splits (must sum to <= 1.0)
+    ///
+    /// The number of rows in each split is the nearest integer to the percentage multiplied by
+    /// the total number of rows.
+    Percentages(Vec<f64>),
+    /// Absolute row counts per split
+    ///
+    /// If the dataset doesn't contain enough matching rows to fill all splits then an error
+    /// will be raised.
+    Counts(Vec<u64>),
+    /// Divides data into a fixed number of splits
+    ///
+    /// Will divide the data evenly.
+    ///
+    /// If the number of rows is not divisible by the number of splits then the rows per split
+    /// is rounded down.
+    Fixed(u64),
+}
+
+impl Default for SplitSizes {
+    fn default() -> Self {
+        Self::Percentages(vec![1.0])
+    }
+}
+
+impl SplitSizes {
+    pub fn validate(&self, num_rows: u64) -> Result<()> {
+        match self {
+            Self::Percentages(percentages) => {
+                for percentage in percentages {
+                    if *percentage < 0.0 || *percentage > 1.0 {
+                        return Err(Error::InvalidInput {
+                            message: "Split percentages must be between 0.0 and 1.0".to_string(),
+                        });
+                    }
+                    if percentage * (num_rows as f64) < 1.0 {
+                        return Err(Error::InvalidInput {
+                            message: format!(
+                                "One of the splits has {}% of {} rows which rounds to 0 rows",
+                                percentage * 100.0,
+                                num_rows
+                            ),
+                        });
+                    }
+                }
+                if percentages.iter().sum::<f64>() > 1.0 {
+                    return Err(Error::InvalidInput {
+                        message: "Split percentages must sum to 1.0 or less".to_string(),
+                    });
+                }
+            }
+            Self::Counts(counts) => {
+                if counts.iter().sum::<u64>() > num_rows {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "Split counts specified {} rows but only {} are available",
+                            counts.iter().sum::<u64>(),
+                            num_rows
+                        ),
+                    });
+                }
+                if counts.iter().any(|c| *c == 0) {
+                    return Err(Error::InvalidInput {
+                        message: "Split counts must be greater than 0".to_string(),
+                    });
+                }
+            }
+            Self::Fixed(num_splits) => {
+                if *num_splits > num_rows {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
+                            *num_splits, num_rows
+                        ),
+                    });
+                }
+                if (num_rows / num_splits) == 0 {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
+                            *num_splits, num_rows
+                        ),
+                    });
+                }
+            }
+        }
+        Ok(())
+    }
+
+    pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
+        let sizes = match self {
+            Self::Percentages(percentages) => {
+                let mut percentage_sum = 0.0_f64;
+                let mut counts = percentages
+                    .iter()
+                    .map(|p| {
+                        let count = (p * (num_rows as f64)).round() as u64;
+                        percentage_sum += p;
+                        count
+                    })
+                    .collect::<Vec<_>>();
+                let sum = counts.iter().sum::<u64>();
+
+                let is_basically_one =
+                    (num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
+
+                // If the sum of percentages is close to 1.0 then rounding errors can add up
+                // to more or less than num_rows
+                //
+                // Drop items from buckets until we have the correct number of rows
+                let mut excess = sum as i64 - num_rows as i64;
+                let mut drop_idx = 0;
+                while excess > 0 {
+                    if counts[drop_idx] > 0 {
+                        counts[drop_idx] -= 1;
+                        excess -= 1;
+                    }
+                    drop_idx += 1;
+                    if drop_idx == counts.len() {
+                        drop_idx = 0;
+                    }
+                }
+
+                // On the other hand, if the percentages sum to ~1.0 then we also shouldn't _lose_
+                // rows due to rounding errors
+                let mut add_idx = 0;
+                while is_basically_one && excess < 0 {
+                    counts[add_idx] += 1;
+                    add_idx += 1;
+                    excess += 1;
+                    if add_idx == counts.len() {
+                        add_idx = 0;
+                    }
+                }
+
+                counts
+            }
+            Self::Counts(counts) => counts.clone(),
+            Self::Fixed(num_splits) => {
+                let rows_per_split = num_rows / *num_splits;
+                vec![rows_per_split; *num_splits as usize]
+            }
+        };
+
+        assert!(sizes.iter().sum::<u64>() <= num_rows);
+
+        sizes
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::arrow::LanceDbDatagenExt;
+
+    use super::*;
+    use arrow::{
+        array::AsArray,
+        compute::concat_batches,
+        datatypes::{Int32Type, UInt64Type},
+    };
+    use arrow_array::Int32Array;
+    use futures::TryStreamExt;
+    use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
+    use std::sync::Arc;
+
+    const ID_COLUMN: &str = "id";
+
+    #[test]
+    fn test_split_sizes_percentages_validation() {
+        // Valid percentages
+        let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
+        assert!(sizes.validate(100).is_ok());
+
+        // Sum > 1.0
+        let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
+        assert!(sizes.validate(100).is_err());
+
+        // Negative percentage
+        let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
+        assert!(sizes.validate(100).is_err());
+
+        // Percentage > 1.0
+        let sizes = SplitSizes::Percentages(vec![1.5]);
+        assert!(sizes.validate(100).is_err());
+
+        // Percentage rounds to 0 rows
+        let sizes = SplitSizes::Percentages(vec![0.001]);
+        assert!(sizes.validate(100).is_err());
+    }
+
+    #[test]
+    fn test_split_sizes_counts_validation() {
+        // Valid counts
+        let sizes = SplitSizes::Counts(vec![30, 70]);
+        assert!(sizes.validate(100).is_ok());
+
+        // Sum > num_rows
+        let sizes = SplitSizes::Counts(vec![60, 50]);
+        assert!(sizes.validate(100).is_err());
+
+        // Counts are 0
+        let sizes = SplitSizes::Counts(vec![0, 100]);
+        assert!(sizes.validate(100).is_err());
+    }
+
+    #[test]
+    fn test_split_sizes_fixed_validation() {
+        // Valid fixed splits
+        let sizes = SplitSizes::Fixed(5);
+        assert!(sizes.validate(100).is_ok());
+
+        // More splits than rows
+        let sizes = SplitSizes::Fixed(150);
+        assert!(sizes.validate(100).is_err());
+    }
+
+    #[test]
+    fn test_split_sizes_to_sizes_percentages() {
+        let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
+        let result = sizes.to_counts(100);
+        assert_eq!(result, vec![30, 70]);
+
+        // Test rounding
+        let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
+        let result = sizes.to_counts(70);
+        assert_eq!(result, vec![21, 29]);
+    }
+
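    // Illustrative sketch (editorial addition, not part of the patch above): when the
    // ratios sum to roughly 1.0, the correction loop in `to_counts` tops buckets back up
    // so no rows are lost to rounding. Three thirds of 100 rows naively round to
    // [33, 33, 33]; the first bucket is then incremented so all 100 rows are covered.
    #[test]
    fn example_to_counts_rounding_correction() {
        let thirds = SplitSizes::Percentages(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]);
        assert_eq!(thirds.to_counts(100), vec![34, 33, 33]);
    }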
+    #[test]
+    fn test_split_sizes_to_sizes_fixed() {
+        let sizes = SplitSizes::Fixed(4);
+        let result = sizes.to_counts(100);
+        assert_eq!(result, vec![25, 25, 25, 25]);
+
+        // Test with remainder
+        let sizes = SplitSizes::Fixed(3);
+        let result = sizes.to_counts(10);
+        assert_eq!(result, vec![3, 3, 3]);
+    }
+
+    fn test_data() -> SendableRecordBatchStream {
+        lance_datagen::gen_batch()
+            .with_seed(Seed::from(42))
+            .col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
+            .into_ldb_stream(RowCount::from(10), BatchCount::from(5))
+    }
+
+    async fn verify_splitter(
+        splitter: Splitter,
+        data: SendableRecordBatchStream,
+        num_rows: u64,
+        expected_split_sizes: &[u64],
+        row_ids_in_order: bool,
+    ) {
+        let split_batches = splitter
+            .apply(data, num_rows)
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let schema = split_batches[0].schema();
+        let split_batch = concat_batches(&schema, &split_batches).unwrap();
+
+        let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
+
+        assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
+        let mut expected = Vec::with_capacity(total_split_sizes as usize);
+        for (i, size) in expected_split_sizes.iter().enumerate() {
+            expected.extend(iter::repeat_n(i as u64, *size as usize));
+        }
+        let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
+
+        assert_eq!(&expected, split_batch.column(1));
+
+        let expected_row_ids =
+            Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn Array>;
+        if row_ids_in_order {
+            assert_eq!(&expected_row_ids, split_batch.column(0));
+        } else {
+            assert_ne!(&expected_row_ids, split_batch.column(0));
+        }
+    }
+
+    #[tokio::test]
+    async fn test_fixed_sequential_split() {
+        let splitter = Splitter::new(
+            // Sequential splitting doesn't need a temp dir
+            TemporaryDirectory::None,
+            SplitStrategy::Sequential {
+                sizes: SplitSizes::Fixed(3),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
+    }
+
+    #[tokio::test]
+    async fn test_fixed_random_split() {
+        let splitter = Splitter::new(
+            TemporaryDirectory::None,
+            SplitStrategy::Random {
+                seed: Some(42),
+                sizes: SplitSizes::Fixed(3),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
+    }
+
+    #[tokio::test]
+    async fn test_counts_sequential_split() {
+        let splitter = Splitter::new(
+            // Sequential splitting doesn't need a temp dir
+            TemporaryDirectory::None,
+            SplitStrategy::Sequential {
+                sizes: SplitSizes::Counts(vec![5, 15, 10]),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
+    }
+
+    #[tokio::test]
+    async fn test_counts_random_split() {
+        let splitter = Splitter::new(
+            TemporaryDirectory::None,
+            SplitStrategy::Random {
+                seed: Some(42),
+                sizes: SplitSizes::Counts(vec![5, 15, 10]),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
+    }
+
+    #[tokio::test]
+    async fn test_percentages_sequential_split() {
+        let splitter = Splitter::new(
+            // Sequential splitting doesn't need a temp dir
+            TemporaryDirectory::None,
+            SplitStrategy::Sequential {
+                sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
+    }
+
+    #[tokio::test]
+    async fn test_percentages_random_split() {
+        let splitter = Splitter::new(
+            TemporaryDirectory::None,
+            SplitStrategy::Random {
+                seed: Some(42),
+                sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
+            },
+        );
+
+        verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
+    }
+
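    // Illustrative sketch (editorial addition, not part of the patch above): per the
    // `SplitStrategy::NoSplit` arm of `Splitter::apply`, a no-split permutation still
    // emits a split id column, with every row assigned to split 0 in source order.
    // Assumes `NoSplit` passes validation for any row count.
    #[tokio::test]
    async fn example_no_split_assigns_everything_to_split_zero() {
        let splitter = Splitter::new(TemporaryDirectory::None, SplitStrategy::NoSplit);
        verify_splitter(splitter, test_data(), 50, &[50], true).await;
    }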
+    #[tokio::test]
+    async fn test_hash_split() {
+        let data = lance_datagen::gen_batch()
+            .with_seed(Seed::from(42))
+            .col(
+                "hash1",
+                lance_datagen::array::rand_utf8(ByteCount::from(10), false),
+            )
+            .col("hash2", lance_datagen::array::step::<UInt64Type>())
+            .col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
+            .into_ldb_stream(RowCount::from(10), BatchCount::from(5));
+
+        let splitter = Splitter::new(
+            TemporaryDirectory::None,
+            SplitStrategy::Hash {
+                columns: vec!["hash1".to_string(), "hash2".to_string()],
+                split_weights: vec![1, 2],
+                discard_weight: 1,
+            },
+        );
+
+        let split_batches = splitter
+            .apply(data, 10)
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let schema = split_batches[0].schema();
+        let split_batch = concat_batches(&schema, &split_batches).unwrap();
+
+        // These assertions are all based on fixed seed in data generation but they match
+        // up roughly to what we expect (25% discarded, 25% in split 0, 50% in split 1)
+
+        // 14 rows (28%) are discarded because discard_weight is 1
+        assert_eq!(split_batch.num_rows(), 36);
+        assert_eq!(split_batch.num_columns(), 2);
+
+        let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
+        let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
+        let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
+
+        assert_eq!(num_in_split_0, 11); // 22%
+        assert_eq!(num_in_split_1, 25); // 50%
+    }
+}
diff --git a/rust/lancedb/src/dataloader/util.rs b/rust/lancedb/src/dataloader/util.rs
new file mode 100644
index 00000000..9e457112
--- /dev/null
+++ b/rust/lancedb/src/dataloader/util.rs
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::{path::PathBuf, sync::Arc};
+
+use arrow_array::RecordBatch;
+use arrow_schema::{Fields, Schema};
+use datafusion_execution::disk_manager::DiskManagerMode;
+use futures::TryStreamExt;
+use rand::{rngs::SmallRng, RngCore, SeedableRng};
+use tempfile::TempDir;
+
+use crate::{
+    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+    Error, Result,
+};
+
+/// Directory to use for temporary files
+#[derive(Debug, Clone, Default)]
+pub enum TemporaryDirectory {
+    /// Use the operating system's default temporary directory (e.g. /tmp)
+    #[default]
+    OsDefault,
+    /// Use the specified directory (must be an absolute path)
+    Specific(PathBuf),
+    /// If spilling is required, then error out
+    None,
+}
+
+impl TemporaryDirectory {
+    pub fn create_temp_dir(&self) -> Result<TempDir> {
+        match self {
+            Self::OsDefault => tempfile::tempdir(),
+            Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
+            Self::None => {
+                return Err(Error::Runtime {
+                    message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
+                });
+            }
+        }
+        .map_err(|err| Error::Other {
+            message: "Failed to create temporary directory".to_string(),
+            source: Some(err.into()),
+        })
+    }
+
+    pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
+        match self {
+            Self::OsDefault => DiskManagerMode::OsTmpDirectory,
+            Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
+            Self::None => DiskManagerMode::Disabled,
+        }
+    }
+}
+
+pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore> {
+    Box::new(
+        seed.as_ref()
+            .map(|seed| SmallRng::seed_from_u64(*seed))
+            .unwrap_or_else(SmallRng::from_os_rng),
+    )
+}
+
+pub fn rename_column(
+    stream: SendableRecordBatchStream,
+    old_name: &str,
+    new_name: &str,
+) -> Result<SendableRecordBatchStream> {
+    let schema = stream.schema();
+    let field_index = schema.index_of(old_name)?;
+
+    let new_fields = schema
+        .fields
+        .iter()
+        .cloned()
+        .enumerate()
+        .map(|(idx, f)| {
+            if idx == field_index {
+                Arc::new(f.as_ref().clone().with_name(new_name))
+            } else {
+                f
+            }
+        })
+        .collect::<Fields>();
+    let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
+    let new_schema_clone = new_schema.clone();
+
+    let renamed_stream = stream.and_then(move |batch| {
+        let renamed_batch =
+            RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
+        std::future::ready(renamed_batch)
+    });
+
+    Ok(Box::pin(SimpleRecordBatchStream::new(
+        renamed_stream,
+        new_schema_clone,
+    )))
+}
diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 9637cf39..75159e04 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -194,6 +194,7 @@ pub mod arrow;
 pub mod connection;
 pub mod data;
 pub mod database;
+pub mod dataloader;
 pub mod embeddings;
 pub mod error;
 pub mod index;
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index e69fd54e..2f56a592 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
 use crate::database::{
     CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
     CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
-    OpenTableRequest, TableNamesRequest,
+    OpenTableRequest, ReadConsistency, TableNamesRequest,
 };
 use crate::error::Result;
 use crate::table::BaseTable;
@@ -189,6 +189,7 @@ struct ListTablesResponse {
 pub struct RemoteDatabase<S: HttpSend = Sender> {
     client: RestfulLanceDbClient<S>,
     table_cache: Cache<String, Arc<RemoteTable<S>>>,
+    uri: String,
 }
 
 impl<S: HttpSend> RemoteDatabase<S> {
@@ -217,6 +218,7 @@ impl<S: HttpSend> RemoteDatabase<S> {
         Ok(Self {
             client,
             table_cache,
+            uri: uri.to_owned(),
         })
     }
 }
@@ -238,6 +240,7 @@ mod test_utils {
         Self {
             client,
             table_cache: Cache::new(0),
+            uri: "http://localhost".to_string(),
         }
     }
 
@@ -250,6 +253,7 @@ mod test_utils {
         Self {
             client,
             table_cache: Cache::new(0),
+            uri: "http://localhost".to_string(),
         }
     }
 }
@@ -315,6 +319,17 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
 
 #[async_trait]
 impl<S: HttpSend> Database for RemoteDatabase<S> {
+    fn uri(&self) -> &str {
+        &self.uri
+    }
+
+    async fn read_consistency(&self) -> Result<ReadConsistency> {
+        Err(Error::NotSupported {
+            message: "Getting the read consistency of a remote database is not yet supported"
+                .to_string(),
+        })
+    }
+
     async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
         let mut req = if !request.namespace.is_empty() {
             let namespace_id =
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index 60c601f1..d632644c 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -50,6 +50,7 @@ use std::sync::Arc;
 
 use crate::arrow::IntoArrow;
 use crate::connection::NoData;
+use crate::database::Database;
 use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
 use crate::error::{Error, Result};
 use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
@@ -611,9 +612,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
 /// A Table is a collection of strong typed Rows.
 ///
 /// The type of the each row is defined in Apache Arrow [Schema].
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Table {
     inner: Arc<dyn BaseTable>,
+    database: Arc<dyn Database>,
     embedding_registry: Arc<dyn EmbeddingRegistry>,
 }
 
@@ -631,11 +633,13 @@ mod test_utils {
     {
         let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
             name.into(),
-            handler,
+            handler.clone(),
             None,
         ));
+        let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
         Self {
             inner,
+            database,
             // Registry is unused.
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
@@ -651,11 +655,13 @@ mod test_utils {
     {
         let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
             name.into(),
-            handler,
+            handler.clone(),
             Some(version),
         ));
+        let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
         Self {
             inner,
+            database,
             // Registry is unused.
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
@@ -670,9 +676,10 @@ impl std::fmt::Display for Table {
 }
 
 impl Table {
-    pub fn new(inner: Arc<dyn BaseTable>) -> Self {
+    pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
         Self {
             inner,
+            database,
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
     }
@@ -681,12 +688,22 @@ impl Table {
         &self.inner
     }
 
+    pub fn database(&self) -> &Arc<dyn Database> {
+        &self.database
+    }
+
+    pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
+        &self.embedding_registry
+    }
+
     pub(crate) fn new_with_embedding_registry(
         inner: Arc<dyn BaseTable>,
+        database: Arc<dyn Database>,
         embedding_registry: Arc<dyn EmbeddingRegistry>,
     ) -> Self {
         Self {
             inner,
+            database,
             embedding_registry,
         }
     }
@@ -1416,12 +1433,6 @@ impl Tags for NativeTags {
     }
 }
 
-impl From<NativeTable> for Table {
-    fn from(table: NativeTable) -> Self {
-        Self::new(Arc::new(table))
-    }
-}
-
 pub trait NativeTableExt {
     /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
     fn as_native(&self) -> Option<&NativeTable>;