diff --git a/docs/src/js/classes/PermutationBuilder.md b/docs/src/js/classes/PermutationBuilder.md index aa2437c9..b293929b 100644 --- a/docs/src/js/classes/PermutationBuilder.md +++ b/docs/src/js/classes/PermutationBuilder.md @@ -64,6 +64,36 @@ builder.filter("age > 18 AND status = 'active'"); *** +### persist() + +```ts +persist(connection, tableName): PermutationBuilder +``` + +Configure the permutation to be persisted. + +#### Parameters + +* **connection**: [`Connection`](Connection.md) + The connection to persist the permutation to + +* **tableName**: `string` + The name of the table to create + +#### Returns + +[`PermutationBuilder`](PermutationBuilder.md) + +A new PermutationBuilder instance + +#### Example + +```ts +builder.persist(connection, "permutation_table"); +``` + +*** + ### shuffle() ```ts @@ -98,15 +128,15 @@ builder.shuffle({ seed: 42, clumpSize: 10 }); ### splitCalculated() ```ts -splitCalculated(calculation): PermutationBuilder +splitCalculated(options): PermutationBuilder ``` Configure calculated splits for the permutation. #### Parameters -* **calculation**: `string` - SQL expression for calculating splits +* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md) + Configuration for calculated splitting #### Returns diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index b3d61023..988f35dc 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -77,6 +77,7 @@ - [RemovalStats](interfaces/RemovalStats.md) - [RetryConfig](interfaces/RetryConfig.md) - [ShuffleOptions](interfaces/ShuffleOptions.md) +- [SplitCalculatedOptions](interfaces/SplitCalculatedOptions.md) - [SplitHashOptions](interfaces/SplitHashOptions.md) - [SplitRandomOptions](interfaces/SplitRandomOptions.md) - [SplitSequentialOptions](interfaces/SplitSequentialOptions.md) diff --git a/docs/src/js/interfaces/SplitCalculatedOptions.md b/docs/src/js/interfaces/SplitCalculatedOptions.md new file mode 100644 index 00000000..5132d648 --- /dev/null +++ b/docs/src/js/interfaces/SplitCalculatedOptions.md @@ -0,0 +1,23 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / SplitCalculatedOptions + +# Interface: SplitCalculatedOptions + +## Properties + +### calculation + +```ts +calculation: string; +``` + +*** + +### splitNames? + +```ts +optional splitNames: string[]; +``` diff --git a/docs/src/js/interfaces/SplitHashOptions.md b/docs/src/js/interfaces/SplitHashOptions.md index 53cbae8e..b0ea571c 100644 --- a/docs/src/js/interfaces/SplitHashOptions.md +++ b/docs/src/js/interfaces/SplitHashOptions.md @@ -24,6 +24,14 @@ optional discardWeight: number; *** +### splitNames? + +```ts +optional splitNames: string[]; +``` + +*** + ### splitWeights ```ts diff --git a/docs/src/js/interfaces/SplitRandomOptions.md b/docs/src/js/interfaces/SplitRandomOptions.md index 66430b6c..f2607b98 100644 --- a/docs/src/js/interfaces/SplitRandomOptions.md +++ b/docs/src/js/interfaces/SplitRandomOptions.md @@ -37,3 +37,11 @@ optional ratios: number[]; ```ts optional seed: number; ``` + +*** + +### splitNames? + +```ts +optional splitNames: string[]; +``` diff --git a/docs/src/js/interfaces/SplitSequentialOptions.md b/docs/src/js/interfaces/SplitSequentialOptions.md index 6397c191..db5cda61 100644 --- a/docs/src/js/interfaces/SplitSequentialOptions.md +++ b/docs/src/js/interfaces/SplitSequentialOptions.md @@ -29,3 +29,11 @@ optional fixed: number; ```ts optional ratios: number[]; ``` + +*** + +### splitNames? 
+ +```ts +optional splitNames: string[]; +``` diff --git a/nodejs/__test__/permutation.test.ts b/nodejs/__test__/permutation.test.ts index 7a6db57b..31f5a3d0 100644 --- a/nodejs/__test__/permutation.test.ts +++ b/nodejs/__test__/permutation.test.ts @@ -138,7 +138,9 @@ describe("PermutationBuilder", () => { }); test("should create permutation with calculated splits", async () => { - const builder = permutationBuilder(table).splitCalculated("id % 2"); + const builder = permutationBuilder(table).splitCalculated({ + calculation: "id % 2", + }); const permutationTable = await builder.execute(); const rowCount = await permutationTable.countRows(); @@ -224,4 +226,146 @@ describe("PermutationBuilder", () => { // Should throw error on second execution await expect(builder.execute()).rejects.toThrow("Builder already consumed"); }); + + test("should accept custom split names with random splits", async () => { + const builder = permutationBuilder(table).splitRandom({ + ratios: [0.3, 0.7], + seed: 42, + splitNames: ["train", "test"], + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Split names are provided but split_id is still numeric (0, 1, etc.) + // The names are metadata that can be used by higher-level APIs + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should accept custom split names with hash splits", async () => { + const builder = permutationBuilder(table).splitHash({ + columns: ["id"], + splitWeights: [50, 50], + discardWeight: 0, + splitNames: ["set_a", "set_b"], + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Split names are provided but split_id is still numeric + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should accept custom split names with sequential splits", async () => { + const builder = permutationBuilder(table).splitSequential({ + ratios: [0.5, 0.5], + splitNames: ["first", "second"], + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Split names are provided but split_id is still numeric + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBe(5); + expect(split1Count).toBe(5); + }); + + test("should accept custom split names with calculated splits", async () => { + const builder = permutationBuilder(table).splitCalculated({ + calculation: "id % 2", + splitNames: ["even", "odd"], + }); + + const permutationTable = await builder.execute(); + const rowCount = await permutationTable.countRows(); + expect(rowCount).toBe(10); + + // Split names are provided but split_id is still numeric + const split0Count = await permutationTable.countRows("split_id = 0"); + const split1Count = await permutationTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + 
expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + }); + + test("should persist permutation to a new table", async () => { + const db = await connect(tmpDir.name); + const builder = permutationBuilder(table) + .splitRandom({ + ratios: [0.7, 0.3], + seed: 42, + splitNames: ["train", "validation"], + }) + .persist(db, "my_permutation"); + + // Execute the builder which will persist the table + const permutationTable = await builder.execute(); + + // Verify the persisted table exists and can be opened + const persistedTable = await db.openTable("my_permutation"); + expect(persistedTable).toBeDefined(); + + // Verify the persisted table has the correct number of rows + const rowCount = await persistedTable.countRows(); + expect(rowCount).toBe(10); + + // Verify splits exist (numeric split_id values) + const split0Count = await persistedTable.countRows("split_id = 0"); + const split1Count = await persistedTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(10); + + // Verify the table returned by execute is the same as the persisted one + const executedRowCount = await permutationTable.countRows(); + expect(executedRowCount).toBe(10); + }); + + test("should persist permutation with multiple operations", async () => { + const db = await connect(tmpDir.name); + const builder = permutationBuilder(table) + .filter("value > 30") + .splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] }) + .shuffle({ seed: 456 }) + .persist(db, "filtered_permutation"); + + // Execute the builder + const permutationTable = await builder.execute(); + + // Verify the persisted table + const persistedTable = await db.openTable("filtered_permutation"); + const rowCount = await persistedTable.countRows(); + expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100 + + // Verify splits exist (numeric split_id values) + const split0Count = await persistedTable.countRows("split_id = 0"); + const split1Count = await persistedTable.countRows("split_id = 1"); + + expect(split0Count).toBeGreaterThan(0); + expect(split1Count).toBeGreaterThan(0); + expect(split0Count + split1Count).toBe(7); + + // Verify the executed table matches + const executedRowCount = await permutationTable.countRows(); + expect(executedRowCount).toBe(7); + }); }); diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index 3ef1a76f..c39ff817 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -43,6 +43,7 @@ export { DeleteResult, DropColumnsResult, UpdateResult, + SplitCalculatedOptions, SplitRandomOptions, SplitHashOptions, SplitSequentialOptions, diff --git a/nodejs/lancedb/permutation.ts b/nodejs/lancedb/permutation.ts index 8fb8b508..ce1c73c5 100644 --- a/nodejs/lancedb/permutation.ts +++ b/nodejs/lancedb/permutation.ts @@ -1,10 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +import { Connection, LocalConnection } from "./connection.js"; import { PermutationBuilder as NativePermutationBuilder, Table as NativeTable, ShuffleOptions, + SplitCalculatedOptions, SplitHashOptions, SplitRandomOptions, SplitSequentialOptions, @@ -29,6 +31,23 @@ export class PermutationBuilder { this.inner = inner; } + /** + * Configure the permutation to be persisted. 
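+   *
+   * When `execute()` is called, the permutation table is written to the given
+   * connection under `tableName`, so it can later be reopened with
+   * `connection.openTable(tableName)`.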
+ * + * @param connection - The connection to persist the permutation to + * @param tableName - The name of the table to create + * @returns A new PermutationBuilder instance + * @example + * ```ts + * builder.persist(connection, "permutation_table"); + * ``` + */ + persist(connection: Connection, tableName: string): PermutationBuilder { + const localConnection = connection as LocalConnection; + const newInner = this.inner.persist(localConnection.inner, tableName); + return new PermutationBuilder(newInner); + } + /** * Configure random splits for the permutation. * @@ -95,15 +114,15 @@ export class PermutationBuilder { /** * Configure calculated splits for the permutation. * - * @param calculation - SQL expression for calculating splits + * @param options - Configuration for calculated splitting * @returns A new PermutationBuilder instance * @example * ```ts * builder.splitCalculated("user_id % 3"); * ``` */ - splitCalculated(calculation: string): PermutationBuilder { - const newInner = this.inner.splitCalculated(calculation); + splitCalculated(options: SplitCalculatedOptions): PermutationBuilder { + const newInner = this.inner.splitCalculated(options); return new PermutationBuilder(newInner); } diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index c3c019d7..a117ab41 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::sync::Arc; -use lancedb::database::CreateTableMode; +use lancedb::database::{CreateTableMode, Database}; use napi::bindgen_prelude::*; use napi_derive::*; @@ -41,6 +41,10 @@ impl Connection { _ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))), } } + + pub fn database(&self) -> napi::Result> { + Ok(self.get_inner()?.database().clone()) + } } #[napi] diff --git a/nodejs/src/permutation.rs b/nodejs/src/permutation.rs index c569020b..43e2e1e8 100644 --- a/nodejs/src/permutation.rs +++ b/nodejs/src/permutation.rs @@ -16,6 +16,7 @@ pub struct SplitRandomOptions { pub counts: Option>, pub fixed: Option, pub seed: Option, + pub split_names: Option>, } #[napi(object)] @@ -23,6 +24,7 @@ pub struct SplitHashOptions { pub columns: Vec, pub split_weights: Vec, pub discard_weight: Option, + pub split_names: Option>, } #[napi(object)] @@ -30,6 +32,13 @@ pub struct SplitSequentialOptions { pub ratios: Option>, pub counts: Option>, pub fixed: Option, + pub split_names: Option>, +} + +#[napi(object)] +pub struct SplitCalculatedOptions { + pub calculation: String, + pub split_names: Option>, } #[napi(object)] @@ -76,6 +85,16 @@ impl PermutationBuilder { #[napi] impl PermutationBuilder { + #[napi] + pub fn persist( + &self, + connection: &crate::connection::Connection, + table_name: String, + ) -> napi::Result { + let database = connection.database()?; + self.modify(|builder| builder.persist(database, table_name)) + } + /// Configure random splits #[napi] pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result { @@ -107,7 +126,12 @@ impl PermutationBuilder { let seed = options.seed.map(|s| s as u64); - self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes })) + self.modify(|builder| { + builder.with_split_strategy( + SplitStrategy::Random { seed, sizes }, + options.split_names.clone(), + ) + }) } /// Configure hash-based splits @@ -120,12 +144,15 @@ impl PermutationBuilder { .collect(); let discard_weight = options.discard_weight.unwrap_or(0) as u64; - self.modify(|builder| { - builder.with_split_strategy(SplitStrategy::Hash { - columns: 
options.columns, - split_weights, - discard_weight, - }) + self.modify(move |builder| { + builder.with_split_strategy( + SplitStrategy::Hash { + columns: options.columns, + split_weights, + discard_weight, + }, + options.split_names, + ) }) } @@ -158,14 +185,21 @@ impl PermutationBuilder { unreachable!("One of the split arguments must be provided"); }; - self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes })) + self.modify(move |builder| { + builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names) + }) } /// Configure calculated splits #[napi] - pub fn split_calculated(&self, calculation: String) -> napi::Result { - self.modify(|builder| { - builder.with_split_strategy(SplitStrategy::Calculated { calculation }) + pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result { + self.modify(move |builder| { + builder.with_split_strategy( + SplitStrategy::Calculated { + calculation: options.calculation, + }, + options.split_names, + ) }) } diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 7f15be8a..a29f54a8 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection from .remote import ClientConfig from .remote.db import RemoteDBConnection from .schema import vector -from .table import AsyncTable +from .table import AsyncTable, Table from ._lancedb import Session from .namespace import connect_namespace, LanceNamespaceDBConnection @@ -233,6 +233,7 @@ __all__ = [ "LanceNamespaceDBConnection", "RemoteDBConnection", "Session", + "Table", "__version__", ] diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 6e4d7033..54843a76 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -340,3 +340,6 @@ def async_permutation_builder( table: Table, dest_table_name: str ) -> AsyncPermutationBuilder: ... def fts_query_to_json(query: Any) -> str: ... + +class PermutationReader: + def __init__(self, base_table: Table, permutation_table: Table): ... diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index bafaa0eb..2ed0d021 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -1,18 +1,63 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors -from ._lancedb import async_permutation_builder +from deprecation import deprecated +from lancedb import AsyncConnection, DBConnection +import pyarrow as pa +import json + +from ._lancedb import async_permutation_builder, PermutationReader from .table import LanceTable from .background_loop import LOOP -from typing import Optional +from .util import batch_to_tensor +from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from lancedb.dependencies import pandas as pd, numpy as np, polars as pl class PermutationBuilder: + """ + A utility for creating a "permutation table" which is a table that defines an + ordering on a base table. + + The permutation table does not store the actual data. It only stores row + ids and split ids to define the ordering. The [Permutation] class can be used to + read the data from the base table in the order defined by the permutation table. + + Permutations can split, shuffle, and filter the data in the base table. 
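+
+    For example, a builder that combines all three (an illustrative sketch;
+    ``tbl`` is assumed to be an existing table, and each operation is described
+    below)::
+
+        perm_tbl = (
+            permutation_builder(tbl)
+            .filter("value > 10")
+            .split_random(ratios=[0.8, 0.2], split_names=["train", "test"])
+            .shuffle(seed=42)
+            .execute()
+        )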
+ + A filter limits the rows that are included in the permutation. + Splits divide the data into subsets (for example, a test/train split, or K + different splits for cross-validation). + Shuffling randomizes the order of the rows in the permutation. + + Splits can optionally be named. If names are provided it will enable them to + be referenced by name in the future. If names are not provided then they can only + be referenced by their ordinal index. There is no requirement to name every split. + + By default, the permutation will be stored in memory and will be lost when the + program exits. To persist the permutation (for very large datasets or to share + the permutation across multiple workers) use the [persist](#persist) method to + create a permanent table. + """ + def __init__(self, table: LanceTable): + """ + Creates a new permutation builder for the given table. + + By default, the permutation builder will create a single split that contains all + rows in the same order as the base table. + """ self._async = async_permutation_builder(table) - def select(self, projections: dict[str, str]) -> "PermutationBuilder": - self._async.select(projections) + def persist( + self, database: Union[DBConnection, AsyncConnection], table_name: str + ) -> "PermutationBuilder": + """ + Persist the permutation to the given database. + """ + self._async.persist(database, table_name) return self def split_random( @@ -22,8 +67,38 @@ class PermutationBuilder: counts: Optional[list[int]] = None, fixed: Optional[int] = None, seed: Optional[int] = None, + split_names: Optional[list[str]] = None, ) -> "PermutationBuilder": - self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed) + """ + Configure random splits for the permutation. + + One of ratios, counts, or fixed must be provided. + + If ratios are provided, they will be used to determine the relative size of each + split. For example, if ratios are [0.3, 0.7] then the first split will contain + 30% of the rows and the second split will contain 70% of the rows. + + If counts are provided, they will be used to determine the absolute number of + rows in each split. For example, if counts are [100, 200] then the first split + will contain 100 rows and the second split will contain 200 rows. + + If fixed is provided, it will be used to determine the number of splits. + For example, if fixed is 3 then the permutation will be split evenly into 3 + splits. + + Rows will be randomly assigned to splits. The optional seed can be provided to + make the assignment deterministic. + + The optional split_names can be provided to name the splits. If not provided, + the splits can only be referenced by their index. + """ + self._async.split_random( + ratios=ratios, + counts=counts, + fixed=fixed, + seed=seed, + split_names=split_names, + ) return self def split_hash( @@ -32,8 +107,33 @@ class PermutationBuilder: split_weights: list[int], *, discard_weight: Optional[int] = None, + split_names: Optional[list[str]] = None, ) -> "PermutationBuilder": - self._async.split_hash(columns, split_weights, discard_weight=discard_weight) + """ + Configure hash-based splits for the permutation. + + First, a hash will be calculated over the specified columns. The splits weights + are then used to determine how many rows to assign to each split. For example, + if split weights are [1, 2] then the first split will contain 1/3 of the rows + and the second split will contain 2/3 of the rows. 
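+
+        For example (an illustrative sketch; ``user_id`` is an assumed column)::
+
+            builder.split_hash(["user_id"], [1, 2], split_names=["train", "test"])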
+ + The optional discard weight can be provided to determine what percentage of rows + should be discarded. For example, if split weights are [1, 2] and discard + weight is 1 then 25% of the rows will be discarded. + + Hash-based splits are useful if you want the split to be more or less random but + you don't want the split assignments to change if rows are added or removed + from the table. + + The optional split_names can be provided to name the splits. If not provided, + the splits can only be referenced by their index. + """ + self._async.split_hash( + columns, + split_weights, + discard_weight=discard_weight, + split_names=split_names, + ) return self def split_sequential( @@ -42,25 +142,85 @@ class PermutationBuilder: ratios: Optional[list[float]] = None, counts: Optional[list[int]] = None, fixed: Optional[int] = None, + split_names: Optional[list[str]] = None, ) -> "PermutationBuilder": - self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed) + """ + Configure sequential splits for the permutation. + + One of ratios, counts, or fixed must be provided. + + If ratios are provided, they will be used to determine the relative size of each + split. For example, if ratios are [0.3, 0.7] then the first split will contain + 30% of the rows and the second split will contain 70% of the rows. + + If counts are provided, they will be used to determine the absolute number of + rows in each split. For example, if counts are [100, 200] then the first split + will contain 100 rows and the second split will contain 200 rows. + + If fixed is provided, it will be used to determine the number of splits. + For example, if fixed is 3 then the permutation will be split evenly into 3 + splits. + + Rows will be assigned to splits sequentially. The first N1 rows are assigned to + split 1, the next N2 rows are assigned to split 2, etc. + + The optional split_names can be provided to name the splits. If not provided, + the splits can only be referenced by their index. + """ + self._async.split_sequential( + ratios=ratios, counts=counts, fixed=fixed, split_names=split_names + ) return self - def split_calculated(self, calculation: str) -> "PermutationBuilder": - self._async.split_calculated(calculation) + def split_calculated( + self, calculation: str, split_names: Optional[list[str]] = None + ) -> "PermutationBuilder": + """ + Use pre-calculated splits for the permutation. + + The calculation should be an SQL statement that returns an integer value between + 0 and the number of splits - 1. For example, if you have 3 splits then the + calculation should return 0 for the first split, 1 for the second split, and 2 + for the third split. + + This can be used to implement any kind of user-defined split strategy. + + The optional split_names can be provided to name the splits. If not provided, + the splits can only be referenced by their index. + """ + self._async.split_calculated(calculation, split_names=split_names) return self def shuffle( self, *, seed: Optional[int] = None, clump_size: Optional[int] = None ) -> "PermutationBuilder": + """ + Randomly shuffle the rows in the permutation. + + An optional seed can be provided to make the shuffle deterministic. + + If a clump size is provided, then data will be shuffled as small "clumps" + of contiguous rows. This allows for a balance between randomization and + I/O performance. It can be useful when reading from cloud storage. 
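+
+        For example (an illustrative sketch)::
+
+            builder.shuffle(seed=42, clump_size=32)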
+ """ self._async.shuffle(seed=seed, clump_size=clump_size) return self def filter(self, filter: str) -> "PermutationBuilder": + """ + Configure a filter for the permutation. + + The filter should be an SQL statement that returns a boolean value for each row. + Only rows where the filter is true will be included in the permutation. + """ self._async.filter(filter) return self def execute(self) -> LanceTable: + """ + Execute the configuration and create the permutation table. + """ + async def do_execute(): inner_tbl = await self._async.execute() return LanceTable.from_inner(inner_tbl) @@ -70,3 +230,592 @@ class PermutationBuilder: def permutation_builder(table: LanceTable) -> PermutationBuilder: return PermutationBuilder(table) + + +class Permutations: + """ + A collection of permutations indexed by name or ordinal index. + + Splits are defined when the permutation is created. Splits can always be referenced + by their ordinal index. If names were provided when the permutation was created + then they can also be referenced by name. + + Each permutation or "split" is a view of a portion of the base table. For more + details see [Permutation]. + + Attributes + ---------- + base_table: LanceTable + The base table that the permutations are based on. + permutation_table: LanceTable + The permutation table that defines the splits. + split_names: list[str] + The names of the splits. + split_dict: dict[str, int] + A dictionary mapping split names to their ordinal index. + + Examples + -------- + >>> # Initial data + >>> import lancedb + >>> db = lancedb.connect("memory:///") + >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(1000)]) + >>> # Create a permutation + >>> perm_tbl = ( + ... permutation_builder(tbl) + ... .split_random(ratios=[0.95, 0.05], split_names=["train", "test"]) + ... .shuffle() + ... .execute() + ... ) + >>> # Read the permutations + >>> permutations = Permutations(tbl, perm_tbl) + >>> permutations["train"] + + >>> permutations[0] + + >>> permutations.split_names + ['train', 'test'] + >>> permutations.split_dict + {'train': 0, 'test': 1} + """ + + def __init__(self, base_table: LanceTable, permutation_table: LanceTable): + self.base_table = base_table + self.permutation_table = permutation_table + + if permutation_table.schema.metadata is not None: + split_names = permutation_table.schema.metadata.get( + b"split_names", None + ).decode("utf-8") + if split_names is not None: + self.split_names = json.loads(split_names) + self.split_dict = { + name: idx for idx, name in enumerate(self.split_names) + } + else: + # No split names are defined in the permutation table + self.split_names = [] + self.split_dict = {} + else: + # No metadata is defined in the permutation table + self.split_names = [] + self.split_dict = {} + + def get_by_name(self, name: str) -> "Permutation": + """ + Get a permutation by name. + + If no split named `name` is found then an error will be raised. + """ + idx = self.split_dict.get(name, None) + if idx is None: + raise ValueError(f"No split named `{name}` found") + return self.get_by_index(idx) + + def get_by_index(self, index: int) -> "Permutation": + """ + Get a permutation by index. 
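+
+        The index is the 0-based ordinal position of the split and works even
+        when no split names were defined.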
+ """ + return Permutation.from_tables(self.base_table, self.permutation_table, index) + + def __getitem__(self, name: Union[str, int]) -> "Permutation": + if isinstance(name, str): + return self.get_by_name(name) + elif isinstance(name, int): + return self.get_by_index(name) + else: + raise TypeError(f"Invalid split name or index: {name}") + + +class Transforms: + """ + Namespace for common transformation functions + """ + + @staticmethod + def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]: + return batch.to_pydict() + + @staticmethod + def arrow2arrow(batch: pa.RecordBatch) -> pa.RecordBatch: + return batch + + @staticmethod + def arrow2numpy(batch: pa.RecordBatch) -> "np.ndarray": + return batch.to_pandas().to_numpy() + + @staticmethod + def arrow2pandas(batch: pa.RecordBatch) -> "pd.DataFrame": + return batch.to_pandas() + + @staticmethod + def arrow2polars() -> "pl.DataFrame": + import polars as pl + + def impl(batch: pa.RecordBatch) -> pl.DataFrame: + return pl.from_arrow(batch) + + return impl + + +# HuggingFace uses 10 which is pretty small +DEFAULT_BATCH_SIZE = 100 + + +class Permutation: + """ + A Permutation is a view of a dataset that can be used as input to model training + and evaluation. + + A Permutation fulfills the pytorch Dataset contract and is loosely modeled after the + huggingface Dataset so it should be easy to use with existing code. + + A permutation is not a "materialized view" or copy of the underlying data. It is + calculated on the fly from the base table. As a result, it is truly "lazy" and does + not require materializing the entire dataset in memory. + """ + + def __init__( + self, + reader: PermutationReader, + selection: dict[str, str], + batch_size: int, + transform_fn: Callable[pa.RecordBatch, Any], + ): + """ + Internal constructor. Use [from_tables](#from_tables) instead. + """ + assert reader is not None, "reader is required" + assert selection is not None, "selection is required" + self.reader = reader + self.selection = selection + self.transform_fn = transform_fn + self.batch_size = batch_size + + def _with_selection(self, selection: dict[str, str]) -> "Permutation": + """ + Creates a new permutation with the given selection + + Does not validation of the selection and it replaces it entirely. This is not + intended for public use. + """ + return Permutation(self.reader, selection, self.batch_size, self.transform_fn) + + def _with_reader(self, reader: PermutationReader) -> "Permutation": + """ + Creates a new permutation with the given reader + + This is an internal method and should not be used directly. + """ + return Permutation(reader, self.selection, self.batch_size, self.transform_fn) + + def with_batch_size(self, batch_size: int) -> "Permutation": + """ + Creates a new permutation with the given batch size + """ + return Permutation(self.reader, self.selection, batch_size, self.transform_fn) + + @classmethod + def identity(cls, table: LanceTable) -> "Permutation": + """ + Creates an identity permutation for the given table. + """ + return Permutation.from_tables(table, None, None) + + @classmethod + def from_tables( + cls, + base_table: LanceTable, + permutation_table: Optional[LanceTable] = None, + split: Optional[Union[str, int]] = None, + ) -> "Permutation": + """ + Creates a permutation from the given base table and permutation table. + + A permutation table identifies which rows, and in what order, the data should + be read from the base table. For more details see the [PermutationBuilder] + class. 
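+
+        For example (an illustrative sketch; assumes ``tbl`` and ``perm_tbl``
+        were created as shown in [PermutationBuilder])::
+
+            train = Permutation.from_tables(tbl, perm_tbl, "train")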
+ + If no permutation table is provided, then the identity permutation will be + created. An identity permutation is a permutation that reads all rows in the + base table in the order they are stored. + + The split parameter identifies which split to use. If no split is provided + then the first split will be used. + """ + assert base_table is not None, "base_table is required" + if split is not None: + if permutation_table is None: + raise ValueError( + "Cannot create a permutation on split `{split}`" + " because no permutation table is provided" + ) + if isinstance(split, str): + if permutation_table.schema.metadata is None: + raise ValueError( + f"Cannot create a permutation on split `{split}`" + " because no split names are defined in the permutation table" + ) + split_names = permutation_table.schema.metadata.get( + b"split_names", None + ).decode("utf-8") + if split_names is None: + raise ValueError( + f"Cannot create a permutation on split `{split}`" + " because no split names are defined in the permutation table" + ) + split_names = json.loads(split_names) + try: + split = split_names.index(split) + except ValueError: + raise ValueError( + f"Cannot create a permutation on split `{split}`" + f" because split `{split}` is not defined in the " + "permutation table" + ) + elif isinstance(split, int): + split = split + else: + raise TypeError(f"Invalid split: {split}") + else: + split = 0 + + async def do_from_tables(): + reader = await PermutationReader.from_tables( + base_table, permutation_table, split + ) + schema = await reader.output_schema(None) + initial_selection = {name: name for name in schema.names} + return cls( + reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python + ) + + return LOOP.run(do_from_tables()) + + @property + def schema(self) -> pa.Schema: + async def do_output_schema(): + return await self.reader.output_schema(self.selection) + + return LOOP.run(do_output_schema()) + + @property + def num_columns(self) -> int: + """ + The number of columns in the permutation + """ + return len(self.schema) + + @property + def num_rows(self) -> int: + """ + The number of rows in the permutation + """ + return self.reader.count_rows() + + @property + def column_names(self) -> list[str]: + """ + The names of the columns in the permutation + """ + return self.schema.names + + @property + def shape(self) -> tuple[int, int]: + """ + The shape of the permutation + + This will return self.num_rows, self.num_columns + """ + return self.num_rows, self.num_columns + + def __len__(self) -> int: + """ + The number of rows in the permutation + + This is an alias for [num_rows][lancedb.permutation.Permutation.num_rows] + """ + return self.num_rows + + def unique(self, _column: str) -> list[Any]: + """ + Get the unique values in the given column + """ + raise Exception("unique is not yet implemented") + + def flatten(self) -> "Permutation": + """ + Flatten the permutation + + Each column with a struct type will be flattened into multiple columns. + + This flattening operation happens at read time as a post-processing step + so this call is cheap and no data is copied or modified in the underlying + dataset. + """ + raise Exception("flatten is not yet implemented") + + def remove_columns(self, columns: list[str]) -> "Permutation": + """ + Remove the given columns from the permutation + + Note: this does not actually modify the underlying dataset. It only changes + which columns are visible from this permutation. Also, this does not introduce + a post-processing step. 
Instead, we simply do not read those columns in the + first place. + + If any of the provided columns does not exist in the current permutation then it + will be ignored (no error is raised for missing columns) + + Returns a new permutation with the given columns removed. This does not modify + self. + """ + assert columns is not None, "columns is required" + + new_selection = { + name: value for name, value in self.selection.items() if name not in columns + } + + if len(new_selection) == 0: + raise ValueError("Cannot remove all columns") + + return self._with_selection(new_selection) + + def rename_column(self, old_name: str, new_name: str) -> "Permutation": + """ + Rename a column in the permutation + + If there is no column named old_name then an error will be raised + If there is already a column named new_name then an error will be raised + + Note: this does not actually modify the underlying dataset. It only changes + the name of the column that is visible from this permutation. This is a + post-processing step but done at the batch level and so it is very cheap. + No data will be copied. + """ + assert old_name is not None, "old_name is required" + assert new_name is not None, "new_name is required" + if old_name not in self.selection: + raise ValueError( + f"Cannot rename column `{old_name}` because it does not exist" + ) + if new_name in self.selection: + raise ValueError( + f"Cannot rename column `{old_name}` to `{new_name}` because a column " + "with that name already exists" + ) + new_selection = self.selection.copy() + new_selection[new_name] = new_selection[old_name] + del new_selection[old_name] + return self._with_selection(new_selection) + + def rename_columns(self, column_map: dict[str, str]) -> "Permutation": + """ + Rename the given columns in the permutation + + If any of the columns do not exist then an error will be raised + If any of the new names already exist then an error will be raised + + Note: this does not actually modify the underlying dataset. It only changes + the name of the column that is visible from this permutation. This is a + post-processing step but done at the batch level and so it is very cheap. + No data will be copied. + """ + assert column_map is not None, "column_map is required" + + new_permutation = self + for old_name, new_name in column_map.items(): + new_permutation = new_permutation.rename_column(old_name, new_name) + return new_permutation + + def select_columns(self, columns: list[str]) -> "Permutation": + """ + Select the given columns from the permutation + + This method refines the current selection, potentially removing columns. It + will not add back columns that were previously removed. + + If any of the columns do not exist then an error will be raised + + This does not introduce a post-processing step. It simply reduces the amount + of data we read. 
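+
+        For example (an illustrative sketch; ``perm`` is an existing
+        permutation)::
+
+            ids_only = perm.select_columns(["id"])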
+ """ + assert columns is not None, "columns is required" + if len(columns) == 0: + raise ValueError("Must select at least one column") + + new_selection = {} + for name in columns: + value = self.selection.get(name, None) + if value is None: + raise ValueError( + f"Cannot select column `{name}` because it does not exist" + ) + new_selection[name] = value + return self._with_selection(new_selection) + + def __iter__(self) -> Iterator[dict[str, Any]]: + """ + Iterate over the permutation + """ + return self.iter(self.batch_size, skip_last_batch=True) + + def iter( + self, batch_size: int, skip_last_batch: bool = False + ) -> Iterator[dict[str, Any]]: + """ + Iterate over the permutation in batches + + If skip_last_batch is True, the last batch will be skipped if it is not a + multiple of batch_size. + """ + + async def get_iter(): + return await self.reader.read(self.selection, batch_size=batch_size) + + async_iter = LOOP.run(get_iter()) + + async def get_next(): + return await async_iter.__anext__() + + try: + while True: + batch = LOOP.run(get_next()) + if batch.num_rows == batch_size or not skip_last_batch: + yield self.transform_fn(batch) + except StopAsyncIteration: + return + + def with_format( + self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"] + ) -> "Permutation": + """ + Set the format for batches + + If this method is not called, the "python" format will be used. + + The format can be one of: + - "numpy" - the batch will be a dict of numpy arrays (one per column) + - "python" - the batch will be a dict of lists (one per column) + - "pandas" - the batch will be a pandas DataFrame + - "arrow" - the batch will be a pyarrow RecordBatch + - "torch" - the batch will be a two dimensional torch tensor + - "polars" - the batch will be a polars DataFrame + + Conversion may or may not involve a data copy. Lance uses Arrow internally + and so it is able to zero-copy to the arrow and polars. + + Conversion to torch will be zero-copy but will only support a subset of data + types (numeric types). + + Conversion to numpy and/or pandas will typically be zero-copy for numeric + types. Conversion of strings, lists, and structs will require creating python + objects and this is not zero-copy. + + For custom formatting, use [with_transform](#with_transform) which overrides + this method. + """ + assert format is not None, "format is required" + if format == "python": + return self.with_transform(Transforms.arrow2python) + elif format == "numpy": + return self.with_transform(Transforms.arrow2numpy) + elif format == "pandas": + return self.with_transform(Transforms.arrow2pandas) + elif format == "arrow": + return self.with_transform(Transforms.arrow2arrow) + elif format == "torch": + return self.with_transform(batch_to_tensor) + elif format == "polars": + return self.with_transform(Transforms.arrow2polars()) + else: + raise ValueError(f"Invalid format: {format}") + + def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation": + """ + Set a custom transform for the permutation + + The transform is a callable that will be invoked with each record batch. The + return value will be used as the batch for iteration. + + Note: transforms are not invoked in parallel. This method is not a good place + for expensive operations such as image decoding. 
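+
+        For example (an illustrative sketch that keeps only the ``id`` column)::
+
+            ids_only = perm.with_transform(lambda batch: batch.select(["id"]))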
+ """ + assert transform is not None, "transform is required" + return Permutation(self.reader, self.selection, self.batch_size, transform) + + def __getitem__(self, index: int) -> Any: + """ + Return a single row from the permutation + + The output will always be a python dictionary regardless of the format. + + This method is mostly useful for debugging and exploration. For actual + processing use [iter](#iter) or a torch data loader to perform batched + processing. + """ + pass + + @deprecated(details="Use with_skip instead") + def skip(self, skip: int) -> "Permutation": + """ + Skip the first `skip` rows of the permutation + + Note: this method returns a new permutation and does not modify `self` + It is provided for compatibility with the huggingface Dataset API. + + Use [with_skip](#with_skip) instead to avoid confusion. + """ + return self.with_skip(skip) + + def with_skip(self, skip: int) -> "Permutation": + """ + Skip the first `skip` rows of the permutation + """ + + async def do_with_skip(): + reader = await self.reader.with_offset(skip) + return self._with_reader(reader) + + return LOOP.run(do_with_skip()) + + @deprecated(details="Use with_take instead") + def take(self, limit: int) -> "Permutation": + """ + Limit the permutation to `limit` rows (following any `skip`) + + Note: this method returns a new permutation and does not modify `self` + It is provided for compatibility with the huggingface Dataset API. + + Use [with_take](#with_take) instead to avoid confusion. + """ + return self.with_take(limit) + + def with_take(self, limit: int) -> "Permutation": + """ + Limit the permutation to `limit` rows (following any `skip`) + """ + + async def do_with_take(): + reader = await self.reader.with_limit(limit) + return self._with_reader(reader) + + return LOOP.run(do_with_take()) + + @deprecated(details="Use with_repeat instead") + def repeat(self, times: int) -> "Permutation": + """ + Repeat the permutation `times` times + + Note: this method returns a new permutation and does not modify `self` + It is provided for compatibility with the huggingface Dataset API. + + Use [with_repeat](#with_repeat) instead to avoid confusion. + """ + return self.with_repeat(times) + + def with_repeat(self, times: int) -> "Permutation": + """ + Repeat the permutation `times` times + """ + raise Exception("with_repeat is not yet implemented") diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 64573f6b..8084cbd1 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -366,3 +366,56 @@ def add_note(base_exception: BaseException, note: str): ) else: raise ValueError("Cannot add note to exception") + + +def tbl_to_tensor(tbl: pa.Table): + """ + Convert a PyArrow Table to a PyTorch Tensor. + + Each column is converted to a tensor (using zero-copy via DLPack) + and the columns are then stacked into a single tensor. + + Fails if torch is not installed. + Fails if any column is more than one chunk. + Fails if a column's data type is not supported by PyTorch. + + Parameters + ---------- + tbl : pa.Table or pa.RecordBatch + The table or record batch to convert to a tensor. + + Returns + ------- + torch.Tensor: The tensor containing the columns of the table. 
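+
+    Note that columns form the first dimension: the resulting tensor has shape
+    (num_columns, num_rows).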
+ """ + torch = attempt_import_or_raise("torch", "torch") + + def to_tensor(col: pa.ChunkedArray): + if col.num_chunks > 1: + raise Exception("Single batch was too large to fit into a one-chunk table") + return torch.from_dlpack(col.chunk(0)) + + return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)]) + + +def batch_to_tensor(batch: pa.RecordBatch): + """ + Convert a PyArrow RecordBatch to a PyTorch Tensor. + + Each column is converted to a tensor (using zero-copy via DLPack) + and the columns are then stacked into a single tensor. + + Fails if torch is not installed. + Fails if a column's data type is not supported by PyTorch. + + Parameters + ---------- + batch : pa.RecordBatch + The record batch to convert to a tensor. + + Returns + ------- + torch.Tensor: The tensor containing the columns of the record batch. + """ + torch = attempt_import_or_raise("torch", "torch") + return torch.stack([torch.from_dlpack(col) for col in batch.columns]) diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 7fbf2cc4..fa74b273 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -2,9 +2,26 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import pyarrow as pa +import math import pytest -from lancedb.permutation import permutation_builder +from lancedb import DBConnection, Table, connect +from lancedb.permutation import Permutation, Permutations, permutation_builder + + +def test_permutation_persistence(tmp_path): + db = connect(tmp_path) + tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)})) + + permutation_tbl = ( + permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute() + ) + assert permutation_tbl.count_rows() == 100 + + re_open = db.open_table("test_permutation") + assert re_open.count_rows() == 100 + + assert permutation_tbl.to_arrow() == re_open.to_arrow() def test_split_random_ratios(mem_db): @@ -195,21 +212,33 @@ def test_split_error_cases(mem_db): tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)})) # Test split_random with no parameters - with pytest.raises(Exception): + with pytest.raises( + ValueError, + match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + ): permutation_builder(tbl).split_random().execute() # Test split_random with multiple parameters - with pytest.raises(Exception): + with pytest.raises( + ValueError, + match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + ): permutation_builder(tbl).split_random( ratios=[0.5, 0.5], counts=[5, 5] ).execute() # Test split_sequential with no parameters - with pytest.raises(Exception): + with pytest.raises( + ValueError, + match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + ): permutation_builder(tbl).split_sequential().execute() # Test split_sequential with multiple parameters - with pytest.raises(Exception): + with pytest.raises( + ValueError, + match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided", + ): permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute() @@ -460,3 +489,455 @@ def test_filter_empty_result(mem_db): ) assert permutation_tbl.count_rows() == 0 + + +@pytest.fixture +def mem_db() -> DBConnection: + return connect("memory:///") + + +@pytest.fixture +def some_table(mem_db: DBConnection) -> Table: + data = pa.table( + { + "id": range(1000), + "value": range(1000), + } + ) + return mem_db.create_table("some_table", data) + 
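+
+# An illustrative sketch, not part of the original change set: the identity
+# permutation over `some_table` should expose every row and column unchanged.
+def test_identity_over_fixture(some_table: Table):
+    perm = Permutation.identity(some_table)
+    assert perm.num_rows == 1000
+    assert perm.column_names == ["id", "value"]
+    assert perm.shape == (1000, 2)
+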
+ +def test_no_split_names(some_table: Table): + perm_tbl = ( + permutation_builder(some_table).split_sequential(counts=[500, 500]).execute() + ) + permutations = Permutations(some_table, perm_tbl) + assert permutations.split_names == [] + assert permutations.split_dict == {} + assert permutations[0].num_rows == 500 + assert permutations[1].num_rows == 500 + + +@pytest.fixture +def some_perm_table(some_table: Table) -> Table: + return ( + permutation_builder(some_table) + .split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"]) + .shuffle(seed=42) + .execute() + ) + + +def test_nonexistent_split(some_table: Table, some_perm_table: Table): + # Reference by name and name does not exist + with pytest.raises(ValueError, match="split `nonexistent` is not defined"): + Permutation.from_tables(some_table, some_perm_table, "nonexistent") + + # Reference by ordinal and there are no rows + with pytest.raises(ValueError, match="No rows found"): + Permutation.from_tables(some_table, some_perm_table, 5) + + +def test_permutations(some_table: Table, some_perm_table: Table): + permutations = Permutations(some_table, some_perm_table) + assert permutations.split_names == ["train", "test"] + assert permutations.split_dict == {"train": 0, "test": 1} + assert permutations["train"].num_rows == 950 + assert permutations[0].num_rows == 950 + assert permutations["test"].num_rows == 50 + assert permutations[1].num_rows == 50 + + with pytest.raises(ValueError, match="No split named `nonexistent` found"): + permutations["nonexistent"] + with pytest.raises(ValueError, match="No rows found"): + permutations[5] + + +@pytest.fixture +def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation: + return Permutation.from_tables(some_table, some_perm_table) + + +def test_num_rows(some_permutation: Permutation): + assert some_permutation.num_rows == 950 + + +def test_num_columns(some_permutation: Permutation): + assert some_permutation.num_columns == 2 + + +def test_column_names(some_permutation: Permutation): + assert some_permutation.column_names == ["id", "value"] + + +def test_shape(some_permutation: Permutation): + assert some_permutation.shape == (950, 2) + + +def test_schema(some_permutation: Permutation): + assert some_permutation.schema == pa.schema( + [("id", pa.int64()), ("value", pa.int64())] + ) + + +def test_limit_offset(some_permutation: Permutation): + assert some_permutation.with_take(100).num_rows == 100 + assert some_permutation.with_skip(100).num_rows == 850 + assert some_permutation.with_take(100).with_skip(100).num_rows == 100 + + with pytest.raises(Exception): + some_permutation.with_take(1000000).num_rows + with pytest.raises(Exception): + some_permutation.with_skip(1000000).num_rows + with pytest.raises(Exception): + some_permutation.with_take(500).with_skip(500).num_rows + with pytest.raises(Exception): + some_permutation.with_skip(500).with_take(500).num_rows + + +def test_remove_columns(some_permutation: Permutation): + assert some_permutation.remove_columns(["value"]).schema == pa.schema( + [("id", pa.int64())] + ) + # Should not modify the original permutation + assert some_permutation.schema.names == ["id", "value"] + # Cannot remove all columns + with pytest.raises(ValueError, match="Cannot remove all columns"): + some_permutation.remove_columns(["id", "value"]) + + +def test_rename_column(some_permutation: Permutation): + assert some_permutation.rename_column("value", "new_value").schema == pa.schema( + [("id", pa.int64()), ("new_value", pa.int64())] + ) + 
# Should not modify the original permutation + assert some_permutation.schema.names == ["id", "value"] + # Cannot rename to an existing column + with pytest.raises( + ValueError, + match="a column with that name already exists", + ): + some_permutation.rename_column("value", "id") + # Cannot rename a non-existent column + with pytest.raises( + ValueError, + match="does not exist", + ): + some_permutation.rename_column("non_existent", "new_value") + + +def test_rename_columns(some_permutation: Permutation): + assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema( + [("id", pa.int64()), ("new_value", pa.int64())] + ) + # Should not modify the original permutation + assert some_permutation.schema.names == ["id", "value"] + # Cannot rename to an existing column + with pytest.raises(ValueError, match="a column with that name already exists"): + some_permutation.rename_columns({"value": "id"}) + + +def test_select_columns(some_permutation: Permutation): + assert some_permutation.select_columns(["id"]).schema == pa.schema( + [("id", pa.int64())] + ) + # Should not modify the original permutation + assert some_permutation.schema.names == ["id", "value"] + # Cannot select a non-existent column + with pytest.raises(ValueError, match="does not exist"): + some_permutation.select_columns(["non_existent"]) + # Empty selection is not allowed + with pytest.raises(ValueError, match="select at least one column"): + some_permutation.select_columns([]) + + +def test_iter_basic(some_permutation: Permutation): + """Test basic iteration with custom batch size.""" + batch_size = 100 + batches = list(some_permutation.iter(batch_size, skip_last_batch=False)) + + # Check that we got the expected number of batches + expected_batches = (950 + batch_size - 1) // batch_size # ceiling division + assert len(batches) == expected_batches + + # Check that all batches are dicts (default python format) + assert all(isinstance(batch, dict) for batch in batches) + + # Check that batches have the correct structure + for batch in batches: + assert "id" in batch + assert "value" in batch + assert isinstance(batch["id"], list) + assert isinstance(batch["value"], list) + + # Check that all batches except the last have the correct size + for batch in batches[:-1]: + assert len(batch["id"]) == batch_size + assert len(batch["value"]) == batch_size + + # Last batch might be smaller + assert len(batches[-1]["id"]) <= batch_size + + +def test_iter_skip_last_batch(some_permutation: Permutation): + """Test iteration with skip_last_batch=True.""" + batch_size = 300 + batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True)) + batches_without_skip = list( + some_permutation.iter(batch_size, skip_last_batch=False) + ) + + # With skip_last_batch=True, we should have fewer batches if the last one is partial + num_full_batches = 950 // batch_size + assert len(batches_with_skip) == num_full_batches + + # Without skip_last_batch, we should have one more batch if there's a remainder + if 950 % batch_size != 0: + assert len(batches_without_skip) == num_full_batches + 1 + # Last batch should be smaller + assert len(batches_without_skip[-1]["id"]) == 950 % batch_size + + # All batches with skip_last_batch should be full size + for batch in batches_with_skip: + assert len(batch["id"]) == batch_size + + +def test_iter_different_batch_sizes(some_permutation: Permutation): + """Test iteration with different batch sizes.""" + + # Test with small batch size + small_batches = list(some_permutation.iter(100, 
skip_last_batch=False)) + assert len(small_batches) == 10 # ceiling(950 / 100) + + # Test with large batch size + large_batches = list(some_permutation.iter(400, skip_last_batch=False)) + assert len(large_batches) == 3 # ceiling(950 / 400) + + # Test with batch size equal to total rows + single_batch = list(some_permutation.iter(950, skip_last_batch=False)) + assert len(single_batch) == 1 + assert len(single_batch[0]["id"]) == 950 + + # Test with batch size larger than total rows + oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False)) + assert len(oversized_batch) == 1 + assert len(oversized_batch[0]["id"]) == 950 + + +def test_dunder_iter(some_permutation: Permutation): + """Test the __iter__ method.""" + # __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True + batches = list(some_permutation) + + # With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches + assert len(batches) == 9 # ceiling(950 / 100) + + # All batches should be full size + for batch in batches: + assert len(batch["id"]) == 100 + assert len(batch["value"]) == 100 + + some_permutation = some_permutation.with_batch_size(400) + batches = list(some_permutation) + assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True + for batch in batches: + assert len(batch["id"]) == 400 + assert len(batch["value"]) == 400 + + +def test_iter_with_different_formats(some_permutation: Permutation): + """Test iteration with different output formats.""" + batch_size = 100 + + # Test with arrow format + arrow_perm = some_permutation.with_format("arrow") + arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False)) + assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches) + + # Test with python format (default) + python_perm = some_permutation.with_format("python") + python_batches = list(python_perm.iter(batch_size, skip_last_batch=False)) + assert all(isinstance(batch, dict) for batch in python_batches) + + # Test with pandas format + pandas_perm = some_permutation.with_format("pandas") + pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False)) + # Import pandas to check the type + import pandas as pd + + assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches) + + +def test_iter_with_column_selection(some_permutation: Permutation): + """Test iteration after column selection.""" + # Select only the id column + id_only = some_permutation.select_columns(["id"]) + batches = list(id_only.iter(100, skip_last_batch=False)) + + # Check that batches only contain the id column + for batch in batches: + assert "id" in batch + assert "value" not in batch + + +def test_iter_with_column_rename(some_permutation: Permutation): + """Test iteration after renaming columns.""" + renamed = some_permutation.rename_column("value", "data") + batches = list(renamed.iter(100, skip_last_batch=False)) + + # Check that batches have the renamed column + for batch in batches: + assert "id" in batch + assert "data" in batch + assert "value" not in batch + + +def test_iter_with_limit_offset(some_permutation: Permutation): + """Test iteration with limit and offset.""" + # Test with offset + offset_perm = some_permutation.with_skip(100) + offset_batches = list(offset_perm.iter(100, skip_last_batch=False)) + # Should have 850 rows (950 - 100) + expected_batches = math.ceil(850 / 100) + assert len(offset_batches) == expected_batches + + # Test with limit + limit_perm = some_permutation.with_take(500) + limit_batches = list(limit_perm.iter(100, 
+    # Should have 5 batches (500 / 100)
+    assert len(limit_batches) == 5
+
+    # Capture the row at index 100 so we can verify the offset below
+    no_skip = some_permutation.iter(101, skip_last_batch=False)
+    row_100 = next(no_skip)["id"][100]
+
+    # Test with both limit and offset
+    limited_perm = some_permutation.with_skip(100).with_take(300)
+    limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
+    # Should have 3 batches (300 / 100)
+    assert len(limited_batches) == 3
+    assert limited_batches[0]["id"][0] == row_100
+
+
+def test_iter_empty_permutation(mem_db):
+    """Test iteration over an empty permutation."""
+    # Create a table and filter it to be empty
+    tbl = mem_db.create_table(
+        "test_table", pa.table({"id": range(10), "value": range(10)})
+    )
+    permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
+    with pytest.raises(ValueError, match="No rows found"):
+        Permutation.from_tables(tbl, permutation_tbl)
+
+
+def test_iter_single_row(mem_db):
+    """Test iteration over a permutation with a single row."""
+    tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
+    permutation_tbl = permutation_builder(tbl).execute()
+    perm = Permutation.from_tables(tbl, permutation_tbl)
+
+    # With skip_last_batch=False, should get one batch
+    batches = list(perm.iter(10, skip_last_batch=False))
+    assert len(batches) == 1
+    assert len(batches[0]["id"]) == 1
+
+    # With skip_last_batch=True, should skip the single row (since it's < batch_size)
+    batches_skip = list(perm.iter(10, skip_last_batch=True))
+    assert len(batches_skip) == 0
+
+
+def test_identity_permutation(mem_db):
+    tbl = mem_db.create_table(
+        "test_table", pa.table({"id": range(10), "value": range(10)})
+    )
+    permutation = Permutation.identity(tbl)
+
+    assert permutation.num_rows == 10
+    assert permutation.num_columns == 2
+
+    batches = list(permutation.iter(10, skip_last_batch=False))
+    assert len(batches) == 1
+    assert len(batches[0]["id"]) == 10
+    assert len(batches[0]["value"]) == 10
+
+    permutation = permutation.remove_columns(["value"])
+    assert permutation.num_columns == 1
+    assert permutation.schema == pa.schema([("id", pa.int64())])
+    assert permutation.column_names == ["id"]
+    assert permutation.shape == (10, 1)
+
+
+def test_transform_fn(mem_db):
+    import numpy as np
+    import pandas as pd
+    import polars as pl
+
+    tbl = mem_db.create_table(
+        "test_table", pa.table({"id": range(10), "value": range(10)})
+    )
+    permutation = Permutation.identity(tbl)
+
+    np_result = list(permutation.with_format("numpy").iter(10, skip_last_batch=False))[
+        0
+    ]
+    assert np_result.shape == (10, 2)
+    assert np_result.dtype == np.int64
+    assert isinstance(np_result, np.ndarray)
+
+    pd_result = list(permutation.with_format("pandas").iter(10, skip_last_batch=False))[
+        0
+    ]
+    assert pd_result.shape == (10, 2)
+    assert pd_result.dtypes.tolist() == [np.int64, np.int64]
+    assert isinstance(pd_result, pd.DataFrame)
+
+    pl_result = list(permutation.with_format("polars").iter(10, skip_last_batch=False))[
+        0
+    ]
+    assert pl_result.shape == (10, 2)
+    assert pl_result.dtypes == [pl.Int64, pl.Int64]
+    assert isinstance(pl_result, pl.DataFrame)
+
+    py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
+        0
+    ]
+    assert len(py_result) == 2
+    assert len(py_result["id"]) == 10
+    assert len(py_result["value"]) == 10
+    assert isinstance(py_result, dict)
+
+    try:
+        import torch
+
+        torch_result = list(
+            permutation.with_format("torch").iter(10, skip_last_batch=False)
+        )[0]
+        assert torch_result.shape == (2, 10)
+        assert torch_result.dtype == torch.int64
+        assert isinstance(torch_result, torch.Tensor)
+    except ImportError:
+        # Skip check if torch is not installed
+        pass
+
+    arrow_result = list(
+        permutation.with_format("arrow").iter(10, skip_last_batch=False)
+    )[0]
+    assert arrow_result.shape == (10, 2)
+    assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
+    assert isinstance(arrow_result, pa.RecordBatch)
+
+
+def test_custom_transform(mem_db):
+    tbl = mem_db.create_table(
+        "test_table", pa.table({"id": range(10), "value": range(10)})
+    )
+    permutation = Permutation.identity(tbl)
+
+    def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
+        return batch.select(["id"])
+
+    transformed = permutation.with_transform(transform)
+    batches = list(transformed.iter(10, skip_last_batch=False))
+    assert len(batches) == 1
+    batch = batches[0]
+
+    assert batch == pa.record_batch([range(10)], ["id"])
diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py
index 26c3ef5f..b883fdf9 100644
--- a/python/python/tests/test_torch.py
+++ b/python/python/tests/test_torch.py
@@ -3,19 +3,11 @@
 import pyarrow as pa
 import pytest
+from lancedb.util import tbl_to_tensor

 torch = pytest.importorskip("torch")


-def tbl_to_tensor(tbl):
-    def to_tensor(col: pa.ChunkedArray):
-        if col.num_chunks > 1:
-            raise Exception("Single batch was too large to fit into a one-chunk table")
-        return torch.from_dlpack(col.chunk(0))
-
-    return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
-
-
 def test_table_dataloader(mem_db):
     table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
     dataloader = torch.utils.data.DataLoader(
diff --git a/python/src/connection.rs b/python/src/connection.rs
index 67553074..148ab100 100644
--- a/python/src/connection.rs
+++ b/python/src/connection.rs
@@ -6,7 +6,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
 use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
 use lancedb::{
     connection::Connection as LanceConnection,
-    database::{CreateTableMode, ReadConsistency},
+    database::{CreateTableMode, Database, ReadConsistency},
 };
 use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
@@ -42,6 +42,10 @@ impl Connection {
             _ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
         }
     }
+
+    pub fn database(&self) -> PyResult<Arc<dyn Database>> {
+        Ok(self.get_inner()?.database().clone())
+    }
 }

 #[pymethods]
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 70f72773..8c4cedb1 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -5,7 +5,7 @@ use arrow::RecordBatchStream;
 use connection::{connect, Connection};
 use env_logger::Env;
 use index::IndexConfig;
-use permutation::PyAsyncPermutationBuilder;
+use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
 use pyo3::{
     pymodule,
     types::{PyModule, PyModuleMethods},
@@ -52,6 +52,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<PyPermutationReader>()?;
     m.add_function(wrap_pyfunction!(connect, m)?)?;
     m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
     m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
diff --git a/python/src/permutation.rs b/python/src/permutation.rs
index a8d6b4ee..9a9d10e2 100644
--- a/python/src/permutation.rs
+++ b/python/src/permutation.rs
@@ -3,14 +3,23 @@
 use std::sync::{Arc, Mutex};

-use crate::{error::PythonErrorExt, table::Table};
-use lancedb::dataloader::{
-    permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
-    permutation::split::{SplitSizes, SplitStrategy},
+use crate::{
+    arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
+};
+use arrow::pyarrow::ToPyArrow;
+use lancedb::{
+    dataloader::permutation::{
+        builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
+        reader::PermutationReader,
+        split::{SplitSizes, SplitStrategy},
+    },
+    query::Select,
 };
 use pyo3::{
-    exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
-    PyResult,
+    exceptions::PyRuntimeError,
+    pyclass, pymethods,
+    types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
+    Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
@@ -56,13 +65,32 @@ impl PyAsyncPermutationBuilder {
 #[pymethods]
 impl PyAsyncPermutationBuilder {
-    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
+    #[pyo3(signature = (database, table_name))]
+    pub fn persist(
+        slf: PyRefMut<'_, Self>,
+        database: Bound<'_, PyAny>,
+        table_name: String,
+    ) -> PyResult<Self> {
+        let conn = if database.hasattr("_conn")? {
+            database
+                .getattr("_conn")?
+                .getattr("_inner")?
+                .downcast_into::<Connection>()?
+        } else {
+            database.getattr("_inner")?.downcast_into::<Connection>()?
+        };
+        let database = conn.borrow().database()?;
+        slf.modify(|builder| builder.persist(database, table_name))
+    }
+
+    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
     pub fn split_random(
         slf: PyRefMut<'_, Self>,
         ratios: Option<Vec<f64>>,
         counts: Option<Vec<u64>>,
         fixed: Option<u64>,
         seed: Option<u64>,
+        split_names: Option<Vec<String>>,
     ) -> PyResult<Self> {
         // Check that exactly one split type is provided
         let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
             .iter()
@@ -86,31 +114,38 @@ impl PyAsyncPermutationBuilder {
             unreachable!("One of the split arguments must be provided");
         };

-        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
+        slf.modify(|builder| {
+            builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
+        })
     }

-    #[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
+    #[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
     pub fn split_hash(
         slf: PyRefMut<'_, Self>,
         columns: Vec<String>,
         split_weights: Vec<u64>,
         discard_weight: u64,
+        split_names: Option<Vec<String>>,
     ) -> PyResult<Self> {
         slf.modify(|builder| {
-            builder.with_split_strategy(SplitStrategy::Hash {
-                columns,
-                split_weights,
-                discard_weight,
-            })
+            builder.with_split_strategy(
+                SplitStrategy::Hash {
+                    columns,
+                    split_weights,
+                    discard_weight,
+                },
+                split_names,
+            )
         })
     }

-    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
+    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
     pub fn split_sequential(
         slf: PyRefMut<'_, Self>,
         ratios: Option<Vec<f64>>,
         counts: Option<Vec<u64>>,
         fixed: Option<u64>,
+        split_names: Option<Vec<String>>,
     ) -> PyResult<Self> {
         // Check that exactly one split type is provided
         let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
@@ -134,11 +169,19 @@ impl PyAsyncPermutationBuilder {
             unreachable!("One of the split arguments must be provided");
         };

-        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
+        slf.modify(|builder| {
+            builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
+        })
     }

-    pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
-        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
+    pub fn split_calculated(
+        slf: PyRefMut<'_, Self>,
+        calculation: String,
+        split_names: Option<Vec<String>>,
+    ) -> PyResult<Self> {
+        slf.modify(|builder| {
+            builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
+        })
     }

     pub fn shuffle(
@@ -168,3 +211,121 @@
         })
     }
 }
+
+#[pyclass(name = "PermutationReader")]
+pub struct PyPermutationReader {
+    reader: Arc<PermutationReader>,
+}
+
+impl PyPermutationReader {
+    fn from_reader(reader: PermutationReader) -> Self {
+        Self {
+            reader: Arc::new(reader),
+        }
+    }
+
+    fn parse_selection(selection: Option>) -> PyResult