feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)

This commit is contained in:
Weston Pace
2025-11-06 16:15:33 -08:00
committed by GitHub
parent 6ddd271627
commit aeac9c7644
24 changed files with 2071 additions and 126 deletions

View File

@@ -64,6 +64,36 @@ builder.filter("age > 18 AND status = 'active'");
***
### persist()
```ts
persist(connection, tableName): PermutationBuilder
```
Configure the permutation to be persisted.
#### Parameters
* **connection**: [`Connection`](Connection.md)
The connection to persist the permutation to
* **tableName**: `string`
The name of the table to create
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.persist(connection, "permutation_table");
```
***
### shuffle()
```ts
@@ -98,15 +128,15 @@ builder.shuffle({ seed: 42, clumpSize: 10 });
### splitCalculated()
```ts
splitCalculated(calculation): PermutationBuilder
splitCalculated(options): PermutationBuilder
```
Configure calculated splits for the permutation.
#### Parameters
* **calculation**: `string`
SQL expression for calculating splits
* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
Configuration for calculated splitting
#### Returns

View File

@@ -77,6 +77,7 @@
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ShuffleOptions](interfaces/ShuffleOptions.md)
- [SplitCalculatedOptions](interfaces/SplitCalculatedOptions.md)
- [SplitHashOptions](interfaces/SplitHashOptions.md)
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)

View File

@@ -0,0 +1,23 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitCalculatedOptions
# Interface: SplitCalculatedOptions
## Properties
### calculation
```ts
calculation: string;
```
***
### splitNames?
```ts
optional splitNames: string[];
```

View File

@@ -24,6 +24,14 @@ optional discardWeight: number;
***
### splitNames?
```ts
optional splitNames: string[];
```
***
### splitWeights
```ts

View File

@@ -37,3 +37,11 @@ optional ratios: number[];
```ts
optional seed: number;
```
***
### splitNames?
```ts
optional splitNames: string[];
```

View File

@@ -29,3 +29,11 @@ optional fixed: number;
```ts
optional ratios: number[];
```
***
### splitNames?
```ts
optional splitNames: string[];
```

View File

@@ -138,7 +138,9 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated("id % 2");
const builder = permutationBuilder(table).splitCalculated({
calculation: "id % 2",
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -224,4 +226,146 @@ describe("PermutationBuilder", () => {
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
test("should accept custom split names with random splits", async () => {
const builder = permutationBuilder(table).splitRandom({
ratios: [0.3, 0.7],
seed: 42,
splitNames: ["train", "test"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric (0, 1, etc.)
// The names are metadata that can be used by higher-level APIs
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should accept custom split names with hash splits", async () => {
const builder = permutationBuilder(table).splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
splitNames: ["set_a", "set_b"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should accept custom split names with sequential splits", async () => {
const builder = permutationBuilder(table).splitSequential({
ratios: [0.5, 0.5],
splitNames: ["first", "second"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should accept custom split names with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated({
calculation: "id % 2",
splitNames: ["even", "odd"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should persist permutation to a new table", async () => {
const db = await connect(tmpDir.name);
const builder = permutationBuilder(table)
.splitRandom({
ratios: [0.7, 0.3],
seed: 42,
splitNames: ["train", "validation"],
})
.persist(db, "my_permutation");
// Execute the builder which will persist the table
const permutationTable = await builder.execute();
// Verify the persisted table exists and can be opened
const persistedTable = await db.openTable("my_permutation");
expect(persistedTable).toBeDefined();
// Verify the persisted table has the correct number of rows
const rowCount = await persistedTable.countRows();
expect(rowCount).toBe(10);
// Verify splits exist (numeric split_id values)
const split0Count = await persistedTable.countRows("split_id = 0");
const split1Count = await persistedTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
// Verify the table returned by execute is the same as the persisted one
const executedRowCount = await permutationTable.countRows();
expect(executedRowCount).toBe(10);
});
test("should persist permutation with multiple operations", async () => {
const db = await connect(tmpDir.name);
const builder = permutationBuilder(table)
.filter("value > 30")
.splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
.shuffle({ seed: 456 })
.persist(db, "filtered_permutation");
// Execute the builder
const permutationTable = await builder.execute();
// Verify the persisted table
const persistedTable = await db.openTable("filtered_permutation");
const rowCount = await persistedTable.countRows();
expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100
// Verify splits exist (numeric split_id values)
const split0Count = await persistedTable.countRows("split_id = 0");
const split1Count = await persistedTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(7);
// Verify the executed table matches
const executedRowCount = await permutationTable.countRows();
expect(executedRowCount).toBe(7);
});
});

View File

@@ -43,6 +43,7 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitCalculatedOptions,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,

View File

@@ -1,10 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { Connection, LocalConnection } from "./connection.js";
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitCalculatedOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
@@ -29,6 +31,23 @@ export class PermutationBuilder {
this.inner = inner;
}
/**
* Configure the permutation to be persisted.
*
* @param connection - The connection to persist the permutation to
* @param tableName - The name of the table to create
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.persist(connection, "permutation_table");
* ```
*/
persist(connection: Connection, tableName: string): PermutationBuilder {
const localConnection = connection as LocalConnection;
const newInner = this.inner.persist(localConnection.inner, tableName);
return new PermutationBuilder(newInner);
}
/**
* Configure random splits for the permutation.
*
@@ -95,15 +114,15 @@ export class PermutationBuilder {
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @param options - Configuration for calculated splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
splitCalculated(options: SplitCalculatedOptions): PermutationBuilder {
const newInner = this.inner.splitCalculated(options);
return new PermutationBuilder(newInner);
}

View File

@@ -4,7 +4,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use lancedb::database::CreateTableMode;
use lancedb::database::{CreateTableMode, Database};
use napi::bindgen_prelude::*;
use napi_derive::*;
@@ -41,6 +41,10 @@ impl Connection {
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
}
}
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
Ok(self.get_inner()?.database().clone())
}
}
#[napi]

View File

@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
pub struct SplitCalculatedOptions {
pub calculation: String,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -76,6 +85,16 @@ impl PermutationBuilder {
#[napi]
impl PermutationBuilder {
#[napi]
pub fn persist(
&self,
connection: &crate::connection::Connection,
table_name: String,
) -> napi::Result<Self> {
let database = connection.database()?;
self.modify(|builder| builder.persist(database, table_name))
}
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
@@ -107,7 +126,12 @@ impl PermutationBuilder {
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
self.modify(|builder| {
builder.with_split_strategy(
SplitStrategy::Random { seed, sizes },
options.split_names.clone(),
)
})
}
/// Configure hash-based splits
@@ -120,12 +144,15 @@ impl PermutationBuilder {
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
},
options.split_names,
)
})
}
@@ -158,14 +185,21 @@ impl PermutationBuilder {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
self.modify(move |builder| {
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
})
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Calculated {
calculation: options.calculation,
},
options.split_names,
)
})
}

View File

@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote import ClientConfig
from .remote.db import RemoteDBConnection
from .schema import vector
from .table import AsyncTable
from .table import AsyncTable, Table
from ._lancedb import Session
from .namespace import connect_namespace, LanceNamespaceDBConnection
@@ -233,6 +233,7 @@ __all__ = [
"LanceNamespaceDBConnection",
"RemoteDBConnection",
"Session",
"Table",
"__version__",
]

View File

@@ -340,3 +340,6 @@ def async_permutation_builder(
table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...
def fts_query_to_json(query: Any) -> str: ...
class PermutationReader:
def __init__(self, base_table: Table, permutation_table: Table): ...

View File

@@ -1,18 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from ._lancedb import async_permutation_builder
from deprecation import deprecated
from lancedb import AsyncConnection, DBConnection
import pyarrow as pa
import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from typing import Optional
from .util import batch_to_tensor
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
class PermutationBuilder:
"""
A utility for creating a "permutation table" which is a table that defines an
ordering on a base table.
The permutation table does not store the actual data. It only stores row
ids and split ids to define the ordering. The [Permutation] class can be used to
read the data from the base table in the order defined by the permutation table.
Permutations can split, shuffle, and filter the data in the base table.
A filter limits the rows that are included in the permutation.
Splits divide the data into subsets (for example, a test/train split, or K
different splits for cross-validation).
Shuffling randomizes the order of the rows in the permutation.
Splits can optionally be named. If names are provided it will enable them to
be referenced by name in the future. If names are not provided then they can only
be referenced by their ordinal index. There is no requirement to name every split.
By default, the permutation will be stored in memory and will be lost when the
program exits. To persist the permutation (for very large datasets or to share
the permutation across multiple workers) use the [persist](#persist) method to
create a permanent table.
"""
def __init__(self, table: LanceTable):
"""
Creates a new permutation builder for the given table.
By default, the permutation builder will create a single split that contains all
rows in the same order as the base table.
"""
self._async = async_permutation_builder(table)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
def persist(
self, database: Union[DBConnection, AsyncConnection], table_name: str
) -> "PermutationBuilder":
"""
Persist the permutation to the given database.
"""
self._async.persist(database, table_name)
return self
def split_random(
@@ -22,8 +67,38 @@ class PermutationBuilder:
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
split_names: Optional[list[str]] = None,
) -> "PermutationBuilder":
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
"""
Configure random splits for the permutation.
One of ratios, counts, or fixed must be provided.
If ratios are provided, they will be used to determine the relative size of each
split. For example, if ratios are [0.3, 0.7] then the first split will contain
30% of the rows and the second split will contain 70% of the rows.
If counts are provided, they will be used to determine the absolute number of
rows in each split. For example, if counts are [100, 200] then the first split
will contain 100 rows and the second split will contain 200 rows.
If fixed is provided, it will be used to determine the number of splits.
For example, if fixed is 3 then the permutation will be split evenly into 3
splits.
Rows will be randomly assigned to splits. The optional seed can be provided to
make the assignment deterministic.
The optional split_names can be provided to name the splits. If not provided,
the splits can only be referenced by their index.
"""
self._async.split_random(
ratios=ratios,
counts=counts,
fixed=fixed,
seed=seed,
split_names=split_names,
)
return self
def split_hash(
@@ -32,8 +107,33 @@ class PermutationBuilder:
split_weights: list[int],
*,
discard_weight: Optional[int] = None,
split_names: Optional[list[str]] = None,
) -> "PermutationBuilder":
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
"""
Configure hash-based splits for the permutation.
First, a hash will be calculated over the specified columns. The splits weights
are then used to determine how many rows to assign to each split. For example,
if split weights are [1, 2] then the first split will contain 1/3 of the rows
and the second split will contain 2/3 of the rows.
The optional discard weight can be provided to determine what percentage of rows
should be discarded. For example, if split weights are [1, 2] and discard
weight is 1 then 25% of the rows will be discarded.
Hash-based splits are useful if you want the split to be more or less random but
you don't want the split assignments to change if rows are added or removed
from the table.
The optional split_names can be provided to name the splits. If not provided,
the splits can only be referenced by their index.
"""
self._async.split_hash(
columns,
split_weights,
discard_weight=discard_weight,
split_names=split_names,
)
return self
def split_sequential(
@@ -42,25 +142,85 @@ class PermutationBuilder:
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
split_names: Optional[list[str]] = None,
) -> "PermutationBuilder":
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
"""
Configure sequential splits for the permutation.
One of ratios, counts, or fixed must be provided.
If ratios are provided, they will be used to determine the relative size of each
split. For example, if ratios are [0.3, 0.7] then the first split will contain
30% of the rows and the second split will contain 70% of the rows.
If counts are provided, they will be used to determine the absolute number of
rows in each split. For example, if counts are [100, 200] then the first split
will contain 100 rows and the second split will contain 200 rows.
If fixed is provided, it will be used to determine the number of splits.
For example, if fixed is 3 then the permutation will be split evenly into 3
splits.
Rows will be assigned to splits sequentially. The first N1 rows are assigned to
split 1, the next N2 rows are assigned to split 2, etc.
The optional split_names can be provided to name the splits. If not provided,
the splits can only be referenced by their index.
"""
self._async.split_sequential(
ratios=ratios, counts=counts, fixed=fixed, split_names=split_names
)
return self
def split_calculated(self, calculation: str) -> "PermutationBuilder":
self._async.split_calculated(calculation)
def split_calculated(
self, calculation: str, split_names: Optional[list[str]] = None
) -> "PermutationBuilder":
"""
Use pre-calculated splits for the permutation.
The calculation should be an SQL statement that returns an integer value between
0 and the number of splits - 1. For example, if you have 3 splits then the
calculation should return 0 for the first split, 1 for the second split, and 2
for the third split.
This can be used to implement any kind of user-defined split strategy.
The optional split_names can be provided to name the splits. If not provided,
the splits can only be referenced by their index.
"""
self._async.split_calculated(calculation, split_names=split_names)
return self
def shuffle(
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
) -> "PermutationBuilder":
"""
Randomly shuffle the rows in the permutation.
An optional seed can be provided to make the shuffle deterministic.
If a clump size is provided, then data will be shuffled as small "clumps"
of contiguous rows. This allows for a balance between randomization and
I/O performance. It can be useful when reading from cloud storage.
"""
self._async.shuffle(seed=seed, clump_size=clump_size)
return self
def filter(self, filter: str) -> "PermutationBuilder":
"""
Configure a filter for the permutation.
The filter should be an SQL statement that returns a boolean value for each row.
Only rows where the filter is true will be included in the permutation.
"""
self._async.filter(filter)
return self
def execute(self) -> LanceTable:
"""
Execute the configuration and create the permutation table.
"""
async def do_execute():
inner_tbl = await self._async.execute()
return LanceTable.from_inner(inner_tbl)
@@ -70,3 +230,592 @@ class PermutationBuilder:
def permutation_builder(table: LanceTable) -> PermutationBuilder:
return PermutationBuilder(table)
class Permutations:
"""
A collection of permutations indexed by name or ordinal index.
Splits are defined when the permutation is created. Splits can always be referenced
by their ordinal index. If names were provided when the permutation was created
then they can also be referenced by name.
Each permutation or "split" is a view of a portion of the base table. For more
details see [Permutation].
Attributes
----------
base_table: LanceTable
The base table that the permutations are based on.
permutation_table: LanceTable
The permutation table that defines the splits.
split_names: list[str]
The names of the splits.
split_dict: dict[str, int]
A dictionary mapping split names to their ordinal index.
Examples
--------
>>> # Initial data
>>> import lancedb
>>> db = lancedb.connect("memory:///")
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(1000)])
>>> # Create a permutation
>>> perm_tbl = (
... permutation_builder(tbl)
... .split_random(ratios=[0.95, 0.05], split_names=["train", "test"])
... .shuffle()
... .execute()
... )
>>> # Read the permutations
>>> permutations = Permutations(tbl, perm_tbl)
>>> permutations["train"]
<lancedb.permutation.Permutation ...>
>>> permutations[0]
<lancedb.permutation.Permutation ...>
>>> permutations.split_names
['train', 'test']
>>> permutations.split_dict
{'train': 0, 'test': 1}
"""
def __init__(self, base_table: LanceTable, permutation_table: LanceTable):
self.base_table = base_table
self.permutation_table = permutation_table
if permutation_table.schema.metadata is not None:
split_names = permutation_table.schema.metadata.get(
b"split_names", None
).decode("utf-8")
if split_names is not None:
self.split_names = json.loads(split_names)
self.split_dict = {
name: idx for idx, name in enumerate(self.split_names)
}
else:
# No split names are defined in the permutation table
self.split_names = []
self.split_dict = {}
else:
# No metadata is defined in the permutation table
self.split_names = []
self.split_dict = {}
def get_by_name(self, name: str) -> "Permutation":
"""
Get a permutation by name.
If no split named `name` is found then an error will be raised.
"""
idx = self.split_dict.get(name, None)
if idx is None:
raise ValueError(f"No split named `{name}` found")
return self.get_by_index(idx)
def get_by_index(self, index: int) -> "Permutation":
"""
Get a permutation by index.
"""
return Permutation.from_tables(self.base_table, self.permutation_table, index)
def __getitem__(self, name: Union[str, int]) -> "Permutation":
if isinstance(name, str):
return self.get_by_name(name)
elif isinstance(name, int):
return self.get_by_index(name)
else:
raise TypeError(f"Invalid split name or index: {name}")
class Transforms:
"""
Namespace for common transformation functions
"""
@staticmethod
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
return batch.to_pydict()
@staticmethod
def arrow2arrow(batch: pa.RecordBatch) -> pa.RecordBatch:
return batch
@staticmethod
def arrow2numpy(batch: pa.RecordBatch) -> "np.ndarray":
return batch.to_pandas().to_numpy()
@staticmethod
def arrow2pandas(batch: pa.RecordBatch) -> "pd.DataFrame":
return batch.to_pandas()
@staticmethod
def arrow2polars() -> "pl.DataFrame":
import polars as pl
def impl(batch: pa.RecordBatch) -> pl.DataFrame:
return pl.from_arrow(batch)
return impl
# HuggingFace uses 10 which is pretty small
DEFAULT_BATCH_SIZE = 100
class Permutation:
"""
A Permutation is a view of a dataset that can be used as input to model training
and evaluation.
A Permutation fulfills the pytorch Dataset contract and is loosely modeled after the
huggingface Dataset so it should be easy to use with existing code.
A permutation is not a "materialized view" or copy of the underlying data. It is
calculated on the fly from the base table. As a result, it is truly "lazy" and does
not require materializing the entire dataset in memory.
"""
def __init__(
self,
reader: PermutationReader,
selection: dict[str, str],
batch_size: int,
transform_fn: Callable[pa.RecordBatch, Any],
):
"""
Internal constructor. Use [from_tables](#from_tables) instead.
"""
assert reader is not None, "reader is required"
assert selection is not None, "selection is required"
self.reader = reader
self.selection = selection
self.transform_fn = transform_fn
self.batch_size = batch_size
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
"""
Creates a new permutation with the given selection
Does not validation of the selection and it replaces it entirely. This is not
intended for public use.
"""
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
def _with_reader(self, reader: PermutationReader) -> "Permutation":
"""
Creates a new permutation with the given reader
This is an internal method and should not be used directly.
"""
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
def with_batch_size(self, batch_size: int) -> "Permutation":
"""
Creates a new permutation with the given batch size
"""
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
@classmethod
def identity(cls, table: LanceTable) -> "Permutation":
"""
Creates an identity permutation for the given table.
"""
return Permutation.from_tables(table, None, None)
@classmethod
def from_tables(
cls,
base_table: LanceTable,
permutation_table: Optional[LanceTable] = None,
split: Optional[Union[str, int]] = None,
) -> "Permutation":
"""
Creates a permutation from the given base table and permutation table.
A permutation table identifies which rows, and in what order, the data should
be read from the base table. For more details see the [PermutationBuilder]
class.
If no permutation table is provided, then the identity permutation will be
created. An identity permutation is a permutation that reads all rows in the
base table in the order they are stored.
The split parameter identifies which split to use. If no split is provided
then the first split will be used.
"""
assert base_table is not None, "base_table is required"
if split is not None:
if permutation_table is None:
raise ValueError(
"Cannot create a permutation on split `{split}`"
" because no permutation table is provided"
)
if isinstance(split, str):
if permutation_table.schema.metadata is None:
raise ValueError(
f"Cannot create a permutation on split `{split}`"
" because no split names are defined in the permutation table"
)
split_names = permutation_table.schema.metadata.get(
b"split_names", None
).decode("utf-8")
if split_names is None:
raise ValueError(
f"Cannot create a permutation on split `{split}`"
" because no split names are defined in the permutation table"
)
split_names = json.loads(split_names)
try:
split = split_names.index(split)
except ValueError:
raise ValueError(
f"Cannot create a permutation on split `{split}`"
f" because split `{split}` is not defined in the "
"permutation table"
)
elif isinstance(split, int):
split = split
else:
raise TypeError(f"Invalid split: {split}")
else:
split = 0
async def do_from_tables():
reader = await PermutationReader.from_tables(
base_table, permutation_table, split
)
schema = await reader.output_schema(None)
initial_selection = {name: name for name in schema.names}
return cls(
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
)
return LOOP.run(do_from_tables())
@property
def schema(self) -> pa.Schema:
async def do_output_schema():
return await self.reader.output_schema(self.selection)
return LOOP.run(do_output_schema())
@property
def num_columns(self) -> int:
"""
The number of columns in the permutation
"""
return len(self.schema)
@property
def num_rows(self) -> int:
"""
The number of rows in the permutation
"""
return self.reader.count_rows()
@property
def column_names(self) -> list[str]:
"""
The names of the columns in the permutation
"""
return self.schema.names
@property
def shape(self) -> tuple[int, int]:
"""
The shape of the permutation
This will return self.num_rows, self.num_columns
"""
return self.num_rows, self.num_columns
def __len__(self) -> int:
"""
The number of rows in the permutation
This is an alias for [num_rows][lancedb.permutation.Permutation.num_rows]
"""
return self.num_rows
def unique(self, _column: str) -> list[Any]:
"""
Get the unique values in the given column
"""
raise Exception("unique is not yet implemented")
def flatten(self) -> "Permutation":
"""
Flatten the permutation
Each column with a struct type will be flattened into multiple columns.
This flattening operation happens at read time as a post-processing step
so this call is cheap and no data is copied or modified in the underlying
dataset.
"""
raise Exception("flatten is not yet implemented")
def remove_columns(self, columns: list[str]) -> "Permutation":
"""
Remove the given columns from the permutation
Note: this does not actually modify the underlying dataset. It only changes
which columns are visible from this permutation. Also, this does not introduce
a post-processing step. Instead, we simply do not read those columns in the
first place.
If any of the provided columns does not exist in the current permutation then it
will be ignored (no error is raised for missing columns)
Returns a new permutation with the given columns removed. This does not modify
self.
"""
assert columns is not None, "columns is required"
new_selection = {
name: value for name, value in self.selection.items() if name not in columns
}
if len(new_selection) == 0:
raise ValueError("Cannot remove all columns")
return self._with_selection(new_selection)
def rename_column(self, old_name: str, new_name: str) -> "Permutation":
"""
Rename a column in the permutation
If there is no column named old_name then an error will be raised
If there is already a column named new_name then an error will be raised
Note: this does not actually modify the underlying dataset. It only changes
the name of the column that is visible from this permutation. This is a
post-processing step but done at the batch level and so it is very cheap.
No data will be copied.
"""
assert old_name is not None, "old_name is required"
assert new_name is not None, "new_name is required"
if old_name not in self.selection:
raise ValueError(
f"Cannot rename column `{old_name}` because it does not exist"
)
if new_name in self.selection:
raise ValueError(
f"Cannot rename column `{old_name}` to `{new_name}` because a column "
"with that name already exists"
)
new_selection = self.selection.copy()
new_selection[new_name] = new_selection[old_name]
del new_selection[old_name]
return self._with_selection(new_selection)
def rename_columns(self, column_map: dict[str, str]) -> "Permutation":
"""
Rename the given columns in the permutation
If any of the columns do not exist then an error will be raised
If any of the new names already exist then an error will be raised
Note: this does not actually modify the underlying dataset. It only changes
the name of the column that is visible from this permutation. This is a
post-processing step but done at the batch level and so it is very cheap.
No data will be copied.
"""
assert column_map is not None, "column_map is required"
new_permutation = self
for old_name, new_name in column_map.items():
new_permutation = new_permutation.rename_column(old_name, new_name)
return new_permutation
def select_columns(self, columns: list[str]) -> "Permutation":
"""
Select the given columns from the permutation
This method refines the current selection, potentially removing columns. It
will not add back columns that were previously removed.
If any of the columns do not exist then an error will be raised
This does not introduce a post-processing step. It simply reduces the amount
of data we read.
"""
assert columns is not None, "columns is required"
if len(columns) == 0:
raise ValueError("Must select at least one column")
new_selection = {}
for name in columns:
value = self.selection.get(name, None)
if value is None:
raise ValueError(
f"Cannot select column `{name}` because it does not exist"
)
new_selection[name] = value
return self._with_selection(new_selection)
def __iter__(self) -> Iterator[dict[str, Any]]:
"""
Iterate over the permutation
"""
return self.iter(self.batch_size, skip_last_batch=True)
def iter(
self, batch_size: int, skip_last_batch: bool = False
) -> Iterator[dict[str, Any]]:
"""
Iterate over the permutation in batches
If skip_last_batch is True, the last batch will be skipped if it is not a
multiple of batch_size.
"""
async def get_iter():
return await self.reader.read(self.selection, batch_size=batch_size)
async_iter = LOOP.run(get_iter())
async def get_next():
return await async_iter.__anext__()
try:
while True:
batch = LOOP.run(get_next())
if batch.num_rows == batch_size or not skip_last_batch:
yield self.transform_fn(batch)
except StopAsyncIteration:
return
def with_format(
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
) -> "Permutation":
"""
Set the format for batches
If this method is not called, the "python" format will be used.
The format can be one of:
- "numpy" - the batch will be a dict of numpy arrays (one per column)
- "python" - the batch will be a dict of lists (one per column)
- "pandas" - the batch will be a pandas DataFrame
- "arrow" - the batch will be a pyarrow RecordBatch
- "torch" - the batch will be a two dimensional torch tensor
- "polars" - the batch will be a polars DataFrame
Conversion may or may not involve a data copy. Lance uses Arrow internally
and so it is able to zero-copy to the arrow and polars.
Conversion to torch will be zero-copy but will only support a subset of data
types (numeric types).
Conversion to numpy and/or pandas will typically be zero-copy for numeric
types. Conversion of strings, lists, and structs will require creating python
objects and this is not zero-copy.
For custom formatting, use [with_transform](#with_transform) which overrides
this method.
"""
assert format is not None, "format is required"
if format == "python":
return self.with_transform(Transforms.arrow2python)
elif format == "numpy":
return self.with_transform(Transforms.arrow2numpy)
elif format == "pandas":
return self.with_transform(Transforms.arrow2pandas)
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor)
elif format == "polars":
return self.with_transform(Transforms.arrow2polars())
else:
raise ValueError(f"Invalid format: {format}")
def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation":
"""
Set a custom transform for the permutation
The transform is a callable that will be invoked with each record batch. The
return value will be used as the batch for iteration.
Note: transforms are not invoked in parallel. This method is not a good place
for expensive operations such as image decoding.
"""
assert transform is not None, "transform is required"
return Permutation(self.reader, self.selection, self.batch_size, transform)
def __getitem__(self, index: int) -> Any:
"""
Return a single row from the permutation
The output will always be a python dictionary regardless of the format.
This method is mostly useful for debugging and exploration. For actual
processing use [iter](#iter) or a torch data loader to perform batched
processing.
"""
pass
@deprecated(details="Use with_skip instead")
def skip(self, skip: int) -> "Permutation":
"""
Skip the first `skip` rows of the permutation
Note: this method returns a new permutation and does not modify `self`
It is provided for compatibility with the huggingface Dataset API.
Use [with_skip](#with_skip) instead to avoid confusion.
"""
return self.with_skip(skip)
def with_skip(self, skip: int) -> "Permutation":
"""
Skip the first `skip` rows of the permutation
"""
async def do_with_skip():
reader = await self.reader.with_offset(skip)
return self._with_reader(reader)
return LOOP.run(do_with_skip())
@deprecated(details="Use with_take instead")
def take(self, limit: int) -> "Permutation":
"""
Limit the permutation to `limit` rows (following any `skip`)
Note: this method returns a new permutation and does not modify `self`
It is provided for compatibility with the huggingface Dataset API.
Use [with_take](#with_take) instead to avoid confusion.
"""
return self.with_take(limit)
def with_take(self, limit: int) -> "Permutation":
"""
Limit the permutation to `limit` rows (following any `skip`)
"""
async def do_with_take():
reader = await self.reader.with_limit(limit)
return self._with_reader(reader)
return LOOP.run(do_with_take())
@deprecated(details="Use with_repeat instead")
def repeat(self, times: int) -> "Permutation":
"""
Repeat the permutation `times` times
Note: this method returns a new permutation and does not modify `self`
It is provided for compatibility with the huggingface Dataset API.
Use [with_repeat](#with_repeat) instead to avoid confusion.
"""
return self.with_repeat(times)
def with_repeat(self, times: int) -> "Permutation":
"""
Repeat the permutation `times` times
"""
raise Exception("with_repeat is not yet implemented")

View File

@@ -366,3 +366,56 @@ def add_note(base_exception: BaseException, note: str):
)
else:
raise ValueError("Cannot add note to exception")
def tbl_to_tensor(tbl: pa.Table):
"""
Convert a PyArrow Table to a PyTorch Tensor.
Each column is converted to a tensor (using zero-copy via DLPack)
and the columns are then stacked into a single tensor.
Fails if torch is not installed.
Fails if any column is more than one chunk.
Fails if a column's data type is not supported by PyTorch.
Parameters
----------
tbl : pa.Table or pa.RecordBatch
The table or record batch to convert to a tensor.
Returns
-------
torch.Tensor: The tensor containing the columns of the table.
"""
torch = attempt_import_or_raise("torch", "torch")
def to_tensor(col: pa.ChunkedArray):
if col.num_chunks > 1:
raise Exception("Single batch was too large to fit into a one-chunk table")
return torch.from_dlpack(col.chunk(0))
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
def batch_to_tensor(batch: pa.RecordBatch):
"""
Convert a PyArrow RecordBatch to a PyTorch Tensor.
Each column is converted to a tensor (using zero-copy via DLPack)
and the columns are then stacked into a single tensor.
Fails if torch is not installed.
Fails if a column's data type is not supported by PyTorch.
Parameters
----------
batch : pa.RecordBatch
The record batch to convert to a tensor.
Returns
-------
torch.Tensor: The tensor containing the columns of the record batch.
"""
torch = attempt_import_or_raise("torch", "torch")
return torch.stack([torch.from_dlpack(col) for col in batch.columns])

View File

@@ -2,9 +2,26 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import math
import pytest
from lancedb.permutation import permutation_builder
from lancedb import DBConnection, Table, connect
from lancedb.permutation import Permutation, Permutations, permutation_builder
def test_permutation_persistence(tmp_path):
db = connect(tmp_path)
tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
permutation_tbl = (
permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
)
assert permutation_tbl.count_rows() == 100
re_open = db.open_table("test_permutation")
assert re_open.count_rows() == 100
assert permutation_tbl.to_arrow() == re_open.to_arrow()
def test_split_random_ratios(mem_db):
@@ -195,21 +212,33 @@ def test_split_error_cases(mem_db):
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
# Test split_random with no parameters
with pytest.raises(Exception):
with pytest.raises(
ValueError,
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
):
permutation_builder(tbl).split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
with pytest.raises(
ValueError,
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
):
permutation_builder(tbl).split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
with pytest.raises(
ValueError,
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
):
permutation_builder(tbl).split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
with pytest.raises(
ValueError,
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
):
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
@@ -460,3 +489,455 @@ def test_filter_empty_result(mem_db):
)
assert permutation_tbl.count_rows() == 0
@pytest.fixture
def mem_db() -> DBConnection:
return connect("memory:///")
@pytest.fixture
def some_table(mem_db: DBConnection) -> Table:
data = pa.table(
{
"id": range(1000),
"value": range(1000),
}
)
return mem_db.create_table("some_table", data)
def test_no_split_names(some_table: Table):
perm_tbl = (
permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
)
permutations = Permutations(some_table, perm_tbl)
assert permutations.split_names == []
assert permutations.split_dict == {}
assert permutations[0].num_rows == 500
assert permutations[1].num_rows == 500
@pytest.fixture
def some_perm_table(some_table: Table) -> Table:
return (
permutation_builder(some_table)
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
.shuffle(seed=42)
.execute()
)
def test_nonexistent_split(some_table: Table, some_perm_table: Table):
# Reference by name and name does not exist
with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
Permutation.from_tables(some_table, some_perm_table, "nonexistent")
# Reference by ordinal and there are no rows
with pytest.raises(ValueError, match="No rows found"):
Permutation.from_tables(some_table, some_perm_table, 5)
def test_permutations(some_table: Table, some_perm_table: Table):
permutations = Permutations(some_table, some_perm_table)
assert permutations.split_names == ["train", "test"]
assert permutations.split_dict == {"train": 0, "test": 1}
assert permutations["train"].num_rows == 950
assert permutations[0].num_rows == 950
assert permutations["test"].num_rows == 50
assert permutations[1].num_rows == 50
with pytest.raises(ValueError, match="No split named `nonexistent` found"):
permutations["nonexistent"]
with pytest.raises(ValueError, match="No rows found"):
permutations[5]
@pytest.fixture
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
return Permutation.from_tables(some_table, some_perm_table)
def test_num_rows(some_permutation: Permutation):
assert some_permutation.num_rows == 950
def test_num_columns(some_permutation: Permutation):
assert some_permutation.num_columns == 2
def test_column_names(some_permutation: Permutation):
assert some_permutation.column_names == ["id", "value"]
def test_shape(some_permutation: Permutation):
assert some_permutation.shape == (950, 2)
def test_schema(some_permutation: Permutation):
assert some_permutation.schema == pa.schema(
[("id", pa.int64()), ("value", pa.int64())]
)
def test_limit_offset(some_permutation: Permutation):
assert some_permutation.with_take(100).num_rows == 100
assert some_permutation.with_skip(100).num_rows == 850
assert some_permutation.with_take(100).with_skip(100).num_rows == 100
with pytest.raises(Exception):
some_permutation.with_take(1000000).num_rows
with pytest.raises(Exception):
some_permutation.with_skip(1000000).num_rows
with pytest.raises(Exception):
some_permutation.with_take(500).with_skip(500).num_rows
with pytest.raises(Exception):
some_permutation.with_skip(500).with_take(500).num_rows
def test_remove_columns(some_permutation: Permutation):
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
[("id", pa.int64())]
)
# Should not modify the original permutation
assert some_permutation.schema.names == ["id", "value"]
# Cannot remove all columns
with pytest.raises(ValueError, match="Cannot remove all columns"):
some_permutation.remove_columns(["id", "value"])
def test_rename_column(some_permutation: Permutation):
assert some_permutation.rename_column("value", "new_value").schema == pa.schema(
[("id", pa.int64()), ("new_value", pa.int64())]
)
# Should not modify the original permutation
assert some_permutation.schema.names == ["id", "value"]
# Cannot rename to an existing column
with pytest.raises(
ValueError,
match="a column with that name already exists",
):
some_permutation.rename_column("value", "id")
# Cannot rename a non-existent column
with pytest.raises(
ValueError,
match="does not exist",
):
some_permutation.rename_column("non_existent", "new_value")
def test_rename_columns(some_permutation: Permutation):
assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema(
[("id", pa.int64()), ("new_value", pa.int64())]
)
# Should not modify the original permutation
assert some_permutation.schema.names == ["id", "value"]
# Cannot rename to an existing column
with pytest.raises(ValueError, match="a column with that name already exists"):
some_permutation.rename_columns({"value": "id"})
def test_select_columns(some_permutation: Permutation):
assert some_permutation.select_columns(["id"]).schema == pa.schema(
[("id", pa.int64())]
)
# Should not modify the original permutation
assert some_permutation.schema.names == ["id", "value"]
# Cannot select a non-existent column
with pytest.raises(ValueError, match="does not exist"):
some_permutation.select_columns(["non_existent"])
# Empty selection is not allowed
with pytest.raises(ValueError, match="select at least one column"):
some_permutation.select_columns([])
def test_iter_basic(some_permutation: Permutation):
"""Test basic iteration with custom batch size."""
batch_size = 100
batches = list(some_permutation.iter(batch_size, skip_last_batch=False))
# Check that we got the expected number of batches
expected_batches = (950 + batch_size - 1) // batch_size # ceiling division
assert len(batches) == expected_batches
# Check that all batches are dicts (default python format)
assert all(isinstance(batch, dict) for batch in batches)
# Check that batches have the correct structure
for batch in batches:
assert "id" in batch
assert "value" in batch
assert isinstance(batch["id"], list)
assert isinstance(batch["value"], list)
# Check that all batches except the last have the correct size
for batch in batches[:-1]:
assert len(batch["id"]) == batch_size
assert len(batch["value"]) == batch_size
# Last batch might be smaller
assert len(batches[-1]["id"]) <= batch_size
def test_iter_skip_last_batch(some_permutation: Permutation):
"""Test iteration with skip_last_batch=True."""
batch_size = 300
batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
batches_without_skip = list(
some_permutation.iter(batch_size, skip_last_batch=False)
)
# With skip_last_batch=True, we should have fewer batches if the last one is partial
num_full_batches = 950 // batch_size
assert len(batches_with_skip) == num_full_batches
# Without skip_last_batch, we should have one more batch if there's a remainder
if 950 % batch_size != 0:
assert len(batches_without_skip) == num_full_batches + 1
# Last batch should be smaller
assert len(batches_without_skip[-1]["id"]) == 950 % batch_size
# All batches with skip_last_batch should be full size
for batch in batches_with_skip:
assert len(batch["id"]) == batch_size
def test_iter_different_batch_sizes(some_permutation: Permutation):
"""Test iteration with different batch sizes."""
# Test with small batch size
small_batches = list(some_permutation.iter(100, skip_last_batch=False))
assert len(small_batches) == 10 # ceiling(950 / 100)
# Test with large batch size
large_batches = list(some_permutation.iter(400, skip_last_batch=False))
assert len(large_batches) == 3 # ceiling(950 / 400)
# Test with batch size equal to total rows
single_batch = list(some_permutation.iter(950, skip_last_batch=False))
assert len(single_batch) == 1
assert len(single_batch[0]["id"]) == 950
# Test with batch size larger than total rows
oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
assert len(oversized_batch) == 1
assert len(oversized_batch[0]["id"]) == 950
def test_dunder_iter(some_permutation: Permutation):
"""Test the __iter__ method."""
# __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True
batches = list(some_permutation)
# With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches
assert len(batches) == 9 # ceiling(950 / 100)
# All batches should be full size
for batch in batches:
assert len(batch["id"]) == 100
assert len(batch["value"]) == 100
some_permutation = some_permutation.with_batch_size(400)
batches = list(some_permutation)
assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True
for batch in batches:
assert len(batch["id"]) == 400
assert len(batch["value"]) == 400
def test_iter_with_different_formats(some_permutation: Permutation):
"""Test iteration with different output formats."""
batch_size = 100
# Test with arrow format
arrow_perm = some_permutation.with_format("arrow")
arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False))
assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches)
# Test with python format (default)
python_perm = some_permutation.with_format("python")
python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
assert all(isinstance(batch, dict) for batch in python_batches)
# Test with pandas format
pandas_perm = some_permutation.with_format("pandas")
pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False))
# Import pandas to check the type
import pandas as pd
assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches)
def test_iter_with_column_selection(some_permutation: Permutation):
"""Test iteration after column selection."""
# Select only the id column
id_only = some_permutation.select_columns(["id"])
batches = list(id_only.iter(100, skip_last_batch=False))
# Check that batches only contain the id column
for batch in batches:
assert "id" in batch
assert "value" not in batch
def test_iter_with_column_rename(some_permutation: Permutation):
"""Test iteration after renaming columns."""
renamed = some_permutation.rename_column("value", "data")
batches = list(renamed.iter(100, skip_last_batch=False))
# Check that batches have the renamed column
for batch in batches:
assert "id" in batch
assert "data" in batch
assert "value" not in batch
def test_iter_with_limit_offset(some_permutation: Permutation):
"""Test iteration with limit and offset."""
# Test with offset
offset_perm = some_permutation.with_skip(100)
offset_batches = list(offset_perm.iter(100, skip_last_batch=False))
# Should have 850 rows (950 - 100)
expected_batches = math.ceil(850 / 100)
assert len(offset_batches) == expected_batches
# Test with limit
limit_perm = some_permutation.with_take(500)
limit_batches = list(limit_perm.iter(100, skip_last_batch=False))
# Should have 5 batches (500 / 100)
assert len(limit_batches) == 5
no_skip = some_permutation.iter(101, skip_last_batch=False)
row_100 = next(no_skip)["id"][100]
# Test with both limit and offset
limited_perm = some_permutation.with_skip(100).with_take(300)
limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
# Should have 3 batches (300 / 100)
assert len(limited_batches) == 3
assert limited_batches[0]["id"][0] == row_100
def test_iter_empty_permutation(mem_db):
"""Test iteration over an empty permutation."""
# Create a table and filter it to be empty
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
with pytest.raises(ValueError, match="No rows found"):
Permutation.from_tables(tbl, permutation_tbl)
def test_iter_single_row(mem_db):
"""Test iteration over a permutation with a single row."""
tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
permutation_tbl = permutation_builder(tbl).execute()
perm = Permutation.from_tables(tbl, permutation_tbl)
# With skip_last_batch=False, should get one batch
batches = list(perm.iter(10, skip_last_batch=False))
assert len(batches) == 1
assert len(batches[0]["id"]) == 1
# With skip_last_batch=True, should skip the single row (since it's < batch_size)
batches_skip = list(perm.iter(10, skip_last_batch=True))
assert len(batches_skip) == 0
def test_identity_permutation(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
permutation = Permutation.identity(tbl)
assert permutation.num_rows == 10
assert permutation.num_columns == 2
batches = list(permutation.iter(10, skip_last_batch=False))
assert len(batches) == 1
assert len(batches[0]["id"]) == 10
assert len(batches[0]["value"]) == 10
permutation = permutation.remove_columns(["value"])
assert permutation.num_columns == 1
assert permutation.schema == pa.schema([("id", pa.int64())])
assert permutation.column_names == ["id"]
assert permutation.shape == (10, 1)
def test_transform_fn(mem_db):
import numpy as np
import pandas as pd
import polars as pl
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
permutation = Permutation.identity(tbl)
np_result = list(permutation.with_format("numpy").iter(10, skip_last_batch=False))[
0
]
assert np_result.shape == (10, 2)
assert np_result.dtype == np.int64
assert isinstance(np_result, np.ndarray)
pd_result = list(permutation.with_format("pandas").iter(10, skip_last_batch=False))[
0
]
assert pd_result.shape == (10, 2)
assert pd_result.dtypes.tolist() == [np.int64, np.int64]
assert isinstance(pd_result, pd.DataFrame)
pl_result = list(permutation.with_format("polars").iter(10, skip_last_batch=False))[
0
]
assert pl_result.shape == (10, 2)
assert pl_result.dtypes == [pl.Int64, pl.Int64]
assert isinstance(pl_result, pl.DataFrame)
py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
0
]
assert len(py_result) == 2
assert len(py_result["id"]) == 10
assert len(py_result["value"]) == 10
assert isinstance(py_result, dict)
try:
import torch
torch_result = list(
permutation.with_format("torch").iter(10, skip_last_batch=False)
)[0]
assert torch_result.shape == (2, 10)
assert torch_result.dtype == torch.int64
assert isinstance(torch_result, torch.Tensor)
except ImportError:
# Skip check if torch is not installed
pass
arrow_result = list(
permutation.with_format("arrow").iter(10, skip_last_batch=False)
)[0]
assert arrow_result.shape == (10, 2)
assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
assert isinstance(arrow_result, pa.RecordBatch)
def test_custom_transform(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
permutation = Permutation.identity(tbl)
def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
return batch.select(["id"])
transformed = permutation.with_transform(transform)
batches = list(transformed.iter(10, skip_last_batch=False))
assert len(batches) == 1
batch = batches[0]
assert batch == pa.record_batch([range(10)], ["id"])

View File

@@ -3,19 +3,11 @@
import pyarrow as pa
import pytest
from lancedb.util import tbl_to_tensor
torch = pytest.importorskip("torch")
def tbl_to_tensor(tbl):
def to_tensor(col: pa.ChunkedArray):
if col.num_chunks > 1:
raise Exception("Single batch was too large to fit into a one-chunk table")
return torch.from_dlpack(col.chunk(0))
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
def test_table_dataloader(mem_db):
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
dataloader = torch.utils.data.DataLoader(

View File

@@ -6,7 +6,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
database::{CreateTableMode, Database, ReadConsistency},
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
@@ -42,6 +42,10 @@ impl Connection {
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
}
}
pub fn database(&self) -> PyResult<Arc<dyn Database>> {
Ok(self.get_inner()?.database().clone())
}
}
#[pymethods]

View File

@@ -5,7 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -52,6 +52,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_class::<PyPermutationReader>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;

View File

@@ -3,14 +3,23 @@
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
use crate::{
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
};
use arrow::pyarrow::ToPyArrow;
use lancedb::{
dataloader::permutation::{
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
reader::PermutationReader,
split::{SplitSizes, SplitStrategy},
},
query::Select,
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
exceptions::PyRuntimeError,
pyclass, pymethods,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -56,13 +65,32 @@ impl PyAsyncPermutationBuilder {
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
#[pyo3(signature = (database, table_name))]
pub fn persist(
slf: PyRefMut<'_, Self>,
database: Bound<'_, PyAny>,
table_name: String,
) -> PyResult<Self> {
let conn = if database.hasattr("_conn")? {
database
.getattr("_conn")?
.getattr("_inner")?
.downcast_into::<Connection>()?
} else {
database.getattr("_inner")?.downcast_into::<Connection>()?
};
let database = conn.borrow().database()?;
slf.modify(|builder| builder.persist(database, table_name))
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
split_names: Option<Vec<String>>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
@@ -86,31 +114,38 @@ impl PyAsyncPermutationBuilder {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
})
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
#[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
split_names: Option<Vec<String>>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
builder.with_split_strategy(
SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
},
split_names,
)
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
split_names: Option<Vec<String>>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
@@ -134,11 +169,19 @@ impl PyAsyncPermutationBuilder {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
})
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
pub fn split_calculated(
slf: PyRefMut<'_, Self>,
calculation: String,
split_names: Option<Vec<String>>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
})
}
pub fn shuffle(
@@ -168,3 +211,121 @@ impl PyAsyncPermutationBuilder {
})
}
}
#[pyclass(name = "PermutationReader")]
pub struct PyPermutationReader {
reader: Arc<PermutationReader>,
}
impl PyPermutationReader {
fn from_reader(reader: PermutationReader) -> Self {
Self {
reader: Arc::new(reader),
}
}
fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
let Some(selection) = selection else {
return Ok(Select::All);
};
let selection = selection.downcast_into::<PyDict>()?;
let selection = selection
.iter()
.map(|(key, value)| {
let key = key.extract::<String>()?;
let value = value.extract::<String>()?;
Ok((key, value))
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Select::dynamic(&selection))
}
}
#[pymethods]
impl PyPermutationReader {
#[classmethod]
pub fn from_tables<'py>(
cls: &Bound<'py, PyType>,
base_table: Bound<'py, PyAny>,
permutation_table: Option<Bound<'py, PyAny>>,
split: u64,
) -> PyResult<Bound<'py, PyAny>> {
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
let permutation_table = permutation_table
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
.transpose()?;
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
let permutation_table = permutation_table
.map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
.transpose()?;
future_into_py(cls.py(), async move {
let reader = if let Some(permutation_table) = permutation_table {
PermutationReader::try_from_tables(base_table, permutation_table, split)
.await
.infer_error()?
} else {
PermutationReader::identity(base_table).await
};
Ok(Self::from_reader(reader))
})
}
#[pyo3(signature = (selection=None))]
pub fn output_schema<'py>(
slf: PyRef<'py, Self>,
selection: Option<Bound<'py, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let selection = Self::parse_selection(selection)?;
let reader = slf.reader.clone();
future_into_py(slf.py(), async move {
let schema = reader.output_schema(selection).await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = ())]
pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
slf.reader.count_rows()
}
#[pyo3(signature = (offset))]
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
let reader = slf.reader.as_ref().clone();
future_into_py(slf.py(), async move {
let reader = reader.with_offset(offset).await.infer_error()?;
Ok(Self::from_reader(reader))
})
}
#[pyo3(signature = (limit))]
pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
let reader = slf.reader.as_ref().clone();
future_into_py(slf.py(), async move {
let reader = reader.with_limit(limit).await.infer_error()?;
Ok(Self::from_reader(reader))
})
}
#[pyo3(signature = (selection=None, *, batch_size=None))]
pub fn read<'py>(
slf: PyRef<'py, Self>,
selection: Option<Bound<'py, PyAny>>,
batch_size: Option<u32>,
) -> PyResult<Bound<'py, PyAny>> {
let selection = Self::parse_selection(selection)?;
let reader = slf.reader.clone();
let batch_size = batch_size.unwrap_or(1024);
future_into_py(slf.py(), async move {
use lancedb::query::QueryExecutionOptions;
let mut execution_options = QueryExecutionOptions::default();
execution_options.max_batch_length = batch_size;
let stream = reader
.read(selection, execution_options)
.await
.infer_error()?;
Ok(RecordBatchStream::new(stream))
})
}
}

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
@@ -25,6 +25,8 @@ use crate::{
pub const SRC_ROW_ID_COL: &str = "row_id";
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
@@ -40,6 +42,8 @@ enum PermutationDestination {
pub struct PermutationConfig {
/// Splitting configuration
split_strategy: SplitStrategy,
/// Optional names for the splits
split_names: Option<Vec<String>>,
/// Shuffle strategy
shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
@@ -112,8 +116,16 @@ impl PermutationBuilder {
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
///
/// An optional list of names can be provided for the splits. This is for convenience and the names
/// will be stored in the permutation table's config metadata.
pub fn with_split_strategy(
mut self,
split_strategy: SplitStrategy,
split_names: Option<Vec<String>>,
) -> Self {
self.config.split_strategy = split_strategy;
self.config.split_names = split_names;
self
}
@@ -193,6 +205,30 @@ impl PermutationBuilder {
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
fn add_split_names(
data: SendableRecordBatchStream,
split_names: &[String],
) -> Result<SendableRecordBatchStream> {
let schema = data
.schema()
.as_ref()
.clone()
.with_metadata(HashMap::from([(
SPLIT_NAMES_CONFIG_KEY.to_string(),
serde_json::to_string(split_names).map_err(|e| Error::Other {
message: format!("Failed to serialize split names: {}", e),
source: Some(e.into()),
})?,
)]));
let schema = Arc::new(schema);
let schema_clone = schema.clone();
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
Ok(Box::pin(SimpleRecordBatchStream {
schema: schema_clone,
stream,
}))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
@@ -249,6 +285,12 @@ impl PermutationBuilder {
// Rename _rowid to row_id
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
let streaming_data = if let Some(split_names) = &self.config.split_names {
Self::add_split_names(renamed, split_names)?
} else {
renamed
};
let (name, database) = match &self.config.destination {
PermutationDestination::Permanent(database, table_name) => {
(table_name.as_str(), database.clone())
@@ -259,10 +301,13 @@ impl PermutationBuilder {
}
};
let create_table_request =
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
let create_table_request = CreateTableRequest::new(
name.to_string(),
CreateTableData::StreamingData(streaming_data),
);
let table = database.create_table(create_table_request).await?;
Ok(Table::new(table, database))
}
}
@@ -296,10 +341,13 @@ mod tests {
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
.with_split_strategy(
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
},
None,
)
.build()
.await
.unwrap();

View File

@@ -11,14 +11,19 @@ use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
use crate::error::Error;
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
use crate::table::{AnyQuery, BaseTable};
use crate::Result;
use crate::query::{
ExecutableQuery, QueryBase, QueryExecutionOptions, QueryFilter, QueryRequest, Select,
};
use crate::table::{AnyQuery, BaseTable, Filter};
use crate::{Result, Table};
use arrow::array::AsArray;
use arrow::compute::concat_batches;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use futures::{StreamExt, TryStreamExt};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::io::RecordBatchStream;
use lance_arrow::RecordBatchExt;
use lance_core::error::LanceOptionExt;
use lance_core::ROW_ID;
@@ -26,43 +31,140 @@ use std::collections::HashMap;
use std::sync::Arc;
/// Reads a permutation of a source table based on row IDs stored in a separate table
#[derive(Clone)]
pub struct PermutationReader {
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
permutation_table: Option<Arc<dyn BaseTable>>,
offset: Option<u64>,
limit: Option<u64>,
available_rows: u64,
split: u64,
}
impl std::fmt::Debug for PermutationReader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PermutationReader(base={}, permutation={})",
"PermutationReader(base={}, permutation={}, split={}, offset={:?}, limit={:?})",
self.base_table.name(),
self.permutation_table.name(),
self.permutation_table
.as_ref()
.map(|t| t.name())
.unwrap_or("--"),
self.split,
self.offset,
self.limit,
)
}
}
impl PermutationReader {
/// Create a new PermutationReader
pub async fn try_new(
pub async fn inner_new(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
permutation_table: Option<Arc<dyn BaseTable>>,
split: u64,
) -> Result<Self> {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
Ok(Self {
let mut slf = Self {
base_table,
permutation_table,
offset: None,
limit: None,
available_rows: 0,
split,
};
slf.validate().await?;
// Calculate the number of available rows
slf.available_rows = slf.verify_limit_offset(None, None).await?;
if slf.available_rows == 0 {
return Err(Error::InvalidInput {
message: "No rows found in the permutation table for the given split".to_string(),
});
}
Ok(slf)
}
pub async fn try_from_tables(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
split: u64,
) -> Result<Self> {
Self::inner_new(base_table, Some(permutation_table), split).await
}
pub async fn identity(base_table: Arc<dyn BaseTable>) -> Self {
Self::inner_new(base_table, None, 0).await.unwrap()
}
/// Validates the limit and offset and returns the number of rows that will be read
fn validate_limit_offset(
limit: Option<u64>,
offset: Option<u64>,
available_rows: u64,
) -> Result<u64> {
match (limit, offset) {
(Some(limit), Some(offset)) => {
if offset + limit > available_rows {
Err(Error::InvalidInput {
message: "Offset + limit is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(limit)
}
}
(None, Some(offset)) => {
if offset > available_rows {
Err(Error::InvalidInput {
message:
"Offset is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(available_rows - offset)
}
}
(Some(limit), None) => {
if limit > available_rows {
Err(Error::InvalidInput {
message:
"Limit is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(limit)
}
}
(None, None) => Ok(available_rows),
}
}
async fn verify_limit_offset(&self, limit: Option<u64>, offset: Option<u64>) -> Result<u64> {
let available_rows = if let Some(permutation_table) = &self.permutation_table {
permutation_table
.count_rows(Some(Filter::Sql(format!(
"{} = {}",
SPLIT_ID_COLUMN, self.split
))))
.await? as u64
} else {
self.base_table.count_rows(None).await? as u64
};
Self::validate_limit_offset(limit, offset, available_rows)
}
pub async fn with_offset(mut self, offset: u64) -> Result<Self> {
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
self.offset = Some(offset);
self.available_rows = available_rows;
Ok(self)
}
pub async fn with_limit(mut self, limit: u64) -> Result<Self> {
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
self.available_rows = available_rows;
self.limit = Some(limit);
Ok(self)
}
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
@@ -103,7 +205,7 @@ impl PermutationReader {
..Default::default()
};
let mut data = base_table
let data = base_table
.query(
&AnyQuery::Query(base_query),
QueryExecutionOptions {
@@ -112,25 +214,29 @@ impl PermutationReader {
},
)
.await?;
let schema = data.schema();
let Some(batch) = data.try_next().await? else {
let batches = data.try_collect::<Vec<_>>().await?;
if batches.is_empty() {
return Err(Error::InvalidInput {
message: "Base table returned no batches".to_string(),
});
};
if data.try_next().await?.is_some() {
return Err(Error::InvalidInput {
message: "Base table returned more than one batch".to_string(),
});
}
if batch.num_rows() != num_rows {
if batches.iter().map(|b| b.num_rows()).sum::<usize>() != num_rows {
return Err(Error::InvalidInput {
message: "Base table returned different number of rows than the number of row IDs"
.to_string(),
});
}
let batch = if batches.len() == 1 {
batches.into_iter().next().unwrap()
} else {
concat_batches(&schema, &batches)?
};
// There is no guarantee the result order will match the order provided
// so may need to restore order
let actual_row_ids = batch
@@ -230,26 +336,75 @@ impl PermutationReader {
}
}
pub async fn read_split(
async fn validate(&self) -> Result<()> {
if let Some(permutation_table) = &self.permutation_table {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
}
let avail_rows = if let Some(permutation_table) = &self.permutation_table {
permutation_table.count_rows(None).await? as u64
} else {
self.base_table.count_rows(None).await? as u64
};
Self::validate_limit_offset(self.limit, self.offset, avail_rows)?;
Ok(())
}
pub async fn read(
&self,
split: u64,
selection: Select,
execution_options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let row_ids = self
.permutation_table
// Note: this relies on the row ids query here being returned in consistent order
let row_ids = if let Some(permutation_table) = &self.permutation_table {
permutation_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
filter: Some(QueryFilter::Sql(format!(
"{} = {}",
SPLIT_ID_COLUMN, self.split
))),
offset: self.offset.map(|o| o as usize),
limit: self.limit.map(|l| l as usize),
..Default::default()
}),
execution_options,
)
.await?;
.await?
} else {
self.base_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![ROW_ID.to_string()]),
offset: self.offset.map(|o| o as usize),
limit: self.limit.map(|l| l as usize),
..Default::default()
}),
execution_options,
)
.await?
};
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
let table = Table::from(self.base_table.clone());
table.query().select(selection).output_schema().await
}
pub fn count_rows(&self) -> u64 {
self.available_rows
}
}
#[cfg(test)]
@@ -321,17 +476,17 @@ mod tests {
.unwrap();
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
let reader = PermutationReader::try_new(
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// Read split 0
let mut stream = reader
.read_split(
0,
.read(
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
@@ -366,9 +521,16 @@ mod tests {
assert!(stream.try_next().await.unwrap().is_none());
// Read split 1
let mut stream = reader
.read_split(
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
1,
)
.await
.unwrap();
let mut stream = reader
.read(
Select::All,
QueryExecutionOptions {
max_batch_length: 3,

View File

@@ -34,7 +34,7 @@ pub(crate) const DEFAULT_TOP_K: usize = 10;
/// Which columns should be retrieved from the database
#[derive(Debug, Clone)]
pub enum Select {
/// Select all columns
/// Select all non-system columns
///
/// Warning: This will always be slower than selecting only the columns you need.
All,

View File

@@ -620,7 +620,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[derive(Clone, Debug)]
pub struct Table {
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
database: Option<Arc<dyn Database>>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -644,7 +644,7 @@ mod test_utils {
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
database: Some(database),
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -666,7 +666,7 @@ mod test_utils {
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
database: Some(database),
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -680,11 +680,21 @@ impl std::fmt::Display for Table {
}
}
impl From<Arc<dyn BaseTable>> for Table {
fn from(inner: Arc<dyn BaseTable>) -> Self {
Self {
inner,
database: None,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
}
impl Table {
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
Self {
inner,
database,
database: Some(database),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -694,7 +704,7 @@ impl Table {
}
pub fn database(&self) -> &Arc<dyn Database> {
&self.database
self.database.as_ref().unwrap()
}
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
@@ -708,7 +718,7 @@ impl Table {
) -> Self {
Self {
inner,
database,
database: Some(database),
embedding_registry,
}
}