mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -64,6 +64,36 @@ builder.filter("age > 18 AND status = 'active'");
|
|||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
|
### persist()
|
||||||
|
|
||||||
|
```ts
|
||||||
|
persist(connection, tableName): PermutationBuilder
|
||||||
|
```
|
||||||
|
|
||||||
|
Configure the permutation to be persisted.
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
|
||||||
|
* **connection**: [`Connection`](Connection.md)
|
||||||
|
The connection to persist the permutation to
|
||||||
|
|
||||||
|
* **tableName**: `string`
|
||||||
|
The name of the table to create
|
||||||
|
|
||||||
|
#### Returns
|
||||||
|
|
||||||
|
[`PermutationBuilder`](PermutationBuilder.md)
|
||||||
|
|
||||||
|
A new PermutationBuilder instance
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
|
||||||
|
```ts
|
||||||
|
builder.persist(connection, "permutation_table");
|
||||||
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
### shuffle()
|
### shuffle()
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
@@ -98,15 +128,15 @@ builder.shuffle({ seed: 42, clumpSize: 10 });
|
|||||||
### splitCalculated()
|
### splitCalculated()
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
splitCalculated(calculation): PermutationBuilder
|
splitCalculated(options): PermutationBuilder
|
||||||
```
|
```
|
||||||
|
|
||||||
Configure calculated splits for the permutation.
|
Configure calculated splits for the permutation.
|
||||||
|
|
||||||
#### Parameters
|
#### Parameters
|
||||||
|
|
||||||
* **calculation**: `string`
|
* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
|
||||||
SQL expression for calculating splits
|
Configuration for calculated splitting
|
||||||
|
|
||||||
#### Returns
|
#### Returns
|
||||||
|
|
||||||
|
|||||||
@@ -77,6 +77,7 @@
|
|||||||
- [RemovalStats](interfaces/RemovalStats.md)
|
- [RemovalStats](interfaces/RemovalStats.md)
|
||||||
- [RetryConfig](interfaces/RetryConfig.md)
|
- [RetryConfig](interfaces/RetryConfig.md)
|
||||||
- [ShuffleOptions](interfaces/ShuffleOptions.md)
|
- [ShuffleOptions](interfaces/ShuffleOptions.md)
|
||||||
|
- [SplitCalculatedOptions](interfaces/SplitCalculatedOptions.md)
|
||||||
- [SplitHashOptions](interfaces/SplitHashOptions.md)
|
- [SplitHashOptions](interfaces/SplitHashOptions.md)
|
||||||
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
|
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
|
||||||
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
|
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
|
||||||
|
|||||||
23
docs/src/js/interfaces/SplitCalculatedOptions.md
Normal file
23
docs/src/js/interfaces/SplitCalculatedOptions.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
[@lancedb/lancedb](../globals.md) / SplitCalculatedOptions
|
||||||
|
|
||||||
|
# Interface: SplitCalculatedOptions
|
||||||
|
|
||||||
|
## Properties
|
||||||
|
|
||||||
|
### calculation
|
||||||
|
|
||||||
|
```ts
|
||||||
|
calculation: string;
|
||||||
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
### splitNames?
|
||||||
|
|
||||||
|
```ts
|
||||||
|
optional splitNames: string[];
|
||||||
|
```
|
||||||
@@ -24,6 +24,14 @@ optional discardWeight: number;
|
|||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
|
### splitNames?
|
||||||
|
|
||||||
|
```ts
|
||||||
|
optional splitNames: string[];
|
||||||
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
### splitWeights
|
### splitWeights
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
|
|||||||
@@ -37,3 +37,11 @@ optional ratios: number[];
|
|||||||
```ts
|
```ts
|
||||||
optional seed: number;
|
optional seed: number;
|
||||||
```
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
### splitNames?
|
||||||
|
|
||||||
|
```ts
|
||||||
|
optional splitNames: string[];
|
||||||
|
```
|
||||||
|
|||||||
@@ -29,3 +29,11 @@ optional fixed: number;
|
|||||||
```ts
|
```ts
|
||||||
optional ratios: number[];
|
optional ratios: number[];
|
||||||
```
|
```
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
### splitNames?
|
||||||
|
|
||||||
|
```ts
|
||||||
|
optional splitNames: string[];
|
||||||
|
```
|
||||||
|
|||||||
@@ -138,7 +138,9 @@ describe("PermutationBuilder", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("should create permutation with calculated splits", async () => {
|
test("should create permutation with calculated splits", async () => {
|
||||||
const builder = permutationBuilder(table).splitCalculated("id % 2");
|
const builder = permutationBuilder(table).splitCalculated({
|
||||||
|
calculation: "id % 2",
|
||||||
|
});
|
||||||
|
|
||||||
const permutationTable = await builder.execute();
|
const permutationTable = await builder.execute();
|
||||||
const rowCount = await permutationTable.countRows();
|
const rowCount = await permutationTable.countRows();
|
||||||
@@ -224,4 +226,146 @@ describe("PermutationBuilder", () => {
|
|||||||
// Should throw error on second execution
|
// Should throw error on second execution
|
||||||
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
|
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("should accept custom split names with random splits", async () => {
|
||||||
|
const builder = permutationBuilder(table).splitRandom({
|
||||||
|
ratios: [0.3, 0.7],
|
||||||
|
seed: 42,
|
||||||
|
splitNames: ["train", "test"],
|
||||||
|
});
|
||||||
|
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
const rowCount = await permutationTable.countRows();
|
||||||
|
expect(rowCount).toBe(10);
|
||||||
|
|
||||||
|
// Split names are provided but split_id is still numeric (0, 1, etc.)
|
||||||
|
// The names are metadata that can be used by higher-level APIs
|
||||||
|
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBeGreaterThan(0);
|
||||||
|
expect(split1Count).toBeGreaterThan(0);
|
||||||
|
expect(split0Count + split1Count).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should accept custom split names with hash splits", async () => {
|
||||||
|
const builder = permutationBuilder(table).splitHash({
|
||||||
|
columns: ["id"],
|
||||||
|
splitWeights: [50, 50],
|
||||||
|
discardWeight: 0,
|
||||||
|
splitNames: ["set_a", "set_b"],
|
||||||
|
});
|
||||||
|
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
const rowCount = await permutationTable.countRows();
|
||||||
|
expect(rowCount).toBe(10);
|
||||||
|
|
||||||
|
// Split names are provided but split_id is still numeric
|
||||||
|
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBeGreaterThan(0);
|
||||||
|
expect(split1Count).toBeGreaterThan(0);
|
||||||
|
expect(split0Count + split1Count).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should accept custom split names with sequential splits", async () => {
|
||||||
|
const builder = permutationBuilder(table).splitSequential({
|
||||||
|
ratios: [0.5, 0.5],
|
||||||
|
splitNames: ["first", "second"],
|
||||||
|
});
|
||||||
|
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
const rowCount = await permutationTable.countRows();
|
||||||
|
expect(rowCount).toBe(10);
|
||||||
|
|
||||||
|
// Split names are provided but split_id is still numeric
|
||||||
|
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBe(5);
|
||||||
|
expect(split1Count).toBe(5);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should accept custom split names with calculated splits", async () => {
|
||||||
|
const builder = permutationBuilder(table).splitCalculated({
|
||||||
|
calculation: "id % 2",
|
||||||
|
splitNames: ["even", "odd"],
|
||||||
|
});
|
||||||
|
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
const rowCount = await permutationTable.countRows();
|
||||||
|
expect(rowCount).toBe(10);
|
||||||
|
|
||||||
|
// Split names are provided but split_id is still numeric
|
||||||
|
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBeGreaterThan(0);
|
||||||
|
expect(split1Count).toBeGreaterThan(0);
|
||||||
|
expect(split0Count + split1Count).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should persist permutation to a new table", async () => {
|
||||||
|
const db = await connect(tmpDir.name);
|
||||||
|
const builder = permutationBuilder(table)
|
||||||
|
.splitRandom({
|
||||||
|
ratios: [0.7, 0.3],
|
||||||
|
seed: 42,
|
||||||
|
splitNames: ["train", "validation"],
|
||||||
|
})
|
||||||
|
.persist(db, "my_permutation");
|
||||||
|
|
||||||
|
// Execute the builder which will persist the table
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
|
||||||
|
// Verify the persisted table exists and can be opened
|
||||||
|
const persistedTable = await db.openTable("my_permutation");
|
||||||
|
expect(persistedTable).toBeDefined();
|
||||||
|
|
||||||
|
// Verify the persisted table has the correct number of rows
|
||||||
|
const rowCount = await persistedTable.countRows();
|
||||||
|
expect(rowCount).toBe(10);
|
||||||
|
|
||||||
|
// Verify splits exist (numeric split_id values)
|
||||||
|
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBeGreaterThan(0);
|
||||||
|
expect(split1Count).toBeGreaterThan(0);
|
||||||
|
expect(split0Count + split1Count).toBe(10);
|
||||||
|
|
||||||
|
// Verify the table returned by execute is the same as the persisted one
|
||||||
|
const executedRowCount = await permutationTable.countRows();
|
||||||
|
expect(executedRowCount).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("should persist permutation with multiple operations", async () => {
|
||||||
|
const db = await connect(tmpDir.name);
|
||||||
|
const builder = permutationBuilder(table)
|
||||||
|
.filter("value > 30")
|
||||||
|
.splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
|
||||||
|
.shuffle({ seed: 456 })
|
||||||
|
.persist(db, "filtered_permutation");
|
||||||
|
|
||||||
|
// Execute the builder
|
||||||
|
const permutationTable = await builder.execute();
|
||||||
|
|
||||||
|
// Verify the persisted table
|
||||||
|
const persistedTable = await db.openTable("filtered_permutation");
|
||||||
|
const rowCount = await persistedTable.countRows();
|
||||||
|
expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100
|
||||||
|
|
||||||
|
// Verify splits exist (numeric split_id values)
|
||||||
|
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||||
|
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||||
|
|
||||||
|
expect(split0Count).toBeGreaterThan(0);
|
||||||
|
expect(split1Count).toBeGreaterThan(0);
|
||||||
|
expect(split0Count + split1Count).toBe(7);
|
||||||
|
|
||||||
|
// Verify the executed table matches
|
||||||
|
const executedRowCount = await permutationTable.countRows();
|
||||||
|
expect(executedRowCount).toBe(7);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ export {
|
|||||||
DeleteResult,
|
DeleteResult,
|
||||||
DropColumnsResult,
|
DropColumnsResult,
|
||||||
UpdateResult,
|
UpdateResult,
|
||||||
|
SplitCalculatedOptions,
|
||||||
SplitRandomOptions,
|
SplitRandomOptions,
|
||||||
SplitHashOptions,
|
SplitHashOptions,
|
||||||
SplitSequentialOptions,
|
SplitSequentialOptions,
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
import { Connection, LocalConnection } from "./connection.js";
|
||||||
import {
|
import {
|
||||||
PermutationBuilder as NativePermutationBuilder,
|
PermutationBuilder as NativePermutationBuilder,
|
||||||
Table as NativeTable,
|
Table as NativeTable,
|
||||||
ShuffleOptions,
|
ShuffleOptions,
|
||||||
|
SplitCalculatedOptions,
|
||||||
SplitHashOptions,
|
SplitHashOptions,
|
||||||
SplitRandomOptions,
|
SplitRandomOptions,
|
||||||
SplitSequentialOptions,
|
SplitSequentialOptions,
|
||||||
@@ -29,6 +31,23 @@ export class PermutationBuilder {
|
|||||||
this.inner = inner;
|
this.inner = inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the permutation to be persisted.
|
||||||
|
*
|
||||||
|
* @param connection - The connection to persist the permutation to
|
||||||
|
* @param tableName - The name of the table to create
|
||||||
|
* @returns A new PermutationBuilder instance
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* builder.persist(connection, "permutation_table");
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
persist(connection: Connection, tableName: string): PermutationBuilder {
|
||||||
|
const localConnection = connection as LocalConnection;
|
||||||
|
const newInner = this.inner.persist(localConnection.inner, tableName);
|
||||||
|
return new PermutationBuilder(newInner);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configure random splits for the permutation.
|
* Configure random splits for the permutation.
|
||||||
*
|
*
|
||||||
@@ -95,15 +114,15 @@ export class PermutationBuilder {
|
|||||||
/**
|
/**
|
||||||
* Configure calculated splits for the permutation.
|
* Configure calculated splits for the permutation.
|
||||||
*
|
*
|
||||||
* @param calculation - SQL expression for calculating splits
|
* @param options - Configuration for calculated splitting
|
||||||
* @returns A new PermutationBuilder instance
|
* @returns A new PermutationBuilder instance
|
||||||
* @example
|
* @example
|
||||||
* ```ts
|
* ```ts
|
||||||
* builder.splitCalculated("user_id % 3");
|
* builder.splitCalculated("user_id % 3");
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
splitCalculated(calculation: string): PermutationBuilder {
|
splitCalculated(options: SplitCalculatedOptions): PermutationBuilder {
|
||||||
const newInner = this.inner.splitCalculated(calculation);
|
const newInner = this.inner.splitCalculated(options);
|
||||||
return new PermutationBuilder(newInner);
|
return new PermutationBuilder(newInner);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use lancedb::database::CreateTableMode;
|
use lancedb::database::{CreateTableMode, Database};
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
||||||
use napi_derive::*;
|
use napi_derive::*;
|
||||||
|
|
||||||
@@ -41,6 +41,10 @@ impl Connection {
|
|||||||
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
|
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
|
||||||
|
Ok(self.get_inner()?.database().clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
|
|||||||
pub counts: Option<Vec<i64>>,
|
pub counts: Option<Vec<i64>>,
|
||||||
pub fixed: Option<i64>,
|
pub fixed: Option<i64>,
|
||||||
pub seed: Option<i64>,
|
pub seed: Option<i64>,
|
||||||
|
pub split_names: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi(object)]
|
#[napi(object)]
|
||||||
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
|
|||||||
pub columns: Vec<String>,
|
pub columns: Vec<String>,
|
||||||
pub split_weights: Vec<i64>,
|
pub split_weights: Vec<i64>,
|
||||||
pub discard_weight: Option<i64>,
|
pub discard_weight: Option<i64>,
|
||||||
|
pub split_names: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi(object)]
|
#[napi(object)]
|
||||||
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
|
|||||||
pub ratios: Option<Vec<f64>>,
|
pub ratios: Option<Vec<f64>>,
|
||||||
pub counts: Option<Vec<i64>>,
|
pub counts: Option<Vec<i64>>,
|
||||||
pub fixed: Option<i64>,
|
pub fixed: Option<i64>,
|
||||||
|
pub split_names: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi(object)]
|
||||||
|
pub struct SplitCalculatedOptions {
|
||||||
|
pub calculation: String,
|
||||||
|
pub split_names: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi(object)]
|
#[napi(object)]
|
||||||
@@ -76,6 +85,16 @@ impl PermutationBuilder {
|
|||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
impl PermutationBuilder {
|
impl PermutationBuilder {
|
||||||
|
#[napi]
|
||||||
|
pub fn persist(
|
||||||
|
&self,
|
||||||
|
connection: &crate::connection::Connection,
|
||||||
|
table_name: String,
|
||||||
|
) -> napi::Result<Self> {
|
||||||
|
let database = connection.database()?;
|
||||||
|
self.modify(|builder| builder.persist(database, table_name))
|
||||||
|
}
|
||||||
|
|
||||||
/// Configure random splits
|
/// Configure random splits
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
|
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
|
||||||
@@ -107,7 +126,12 @@ impl PermutationBuilder {
|
|||||||
|
|
||||||
let seed = options.seed.map(|s| s as u64);
|
let seed = options.seed.map(|s| s as u64);
|
||||||
|
|
||||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
self.modify(|builder| {
|
||||||
|
builder.with_split_strategy(
|
||||||
|
SplitStrategy::Random { seed, sizes },
|
||||||
|
options.split_names.clone(),
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure hash-based splits
|
/// Configure hash-based splits
|
||||||
@@ -120,12 +144,15 @@ impl PermutationBuilder {
|
|||||||
.collect();
|
.collect();
|
||||||
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
|
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
|
||||||
|
|
||||||
self.modify(|builder| {
|
self.modify(move |builder| {
|
||||||
builder.with_split_strategy(SplitStrategy::Hash {
|
builder.with_split_strategy(
|
||||||
columns: options.columns,
|
SplitStrategy::Hash {
|
||||||
split_weights,
|
columns: options.columns,
|
||||||
discard_weight,
|
split_weights,
|
||||||
})
|
discard_weight,
|
||||||
|
},
|
||||||
|
options.split_names,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,14 +185,21 @@ impl PermutationBuilder {
|
|||||||
unreachable!("One of the split arguments must be provided");
|
unreachable!("One of the split arguments must be provided");
|
||||||
};
|
};
|
||||||
|
|
||||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
self.modify(move |builder| {
|
||||||
|
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure calculated splits
|
/// Configure calculated splits
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
|
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
|
||||||
self.modify(|builder| {
|
self.modify(move |builder| {
|
||||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
|
builder.with_split_strategy(
|
||||||
|
SplitStrategy::Calculated {
|
||||||
|
calculation: options.calculation,
|
||||||
|
},
|
||||||
|
options.split_names,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
|||||||
from .remote import ClientConfig
|
from .remote import ClientConfig
|
||||||
from .remote.db import RemoteDBConnection
|
from .remote.db import RemoteDBConnection
|
||||||
from .schema import vector
|
from .schema import vector
|
||||||
from .table import AsyncTable
|
from .table import AsyncTable, Table
|
||||||
from ._lancedb import Session
|
from ._lancedb import Session
|
||||||
from .namespace import connect_namespace, LanceNamespaceDBConnection
|
from .namespace import connect_namespace, LanceNamespaceDBConnection
|
||||||
|
|
||||||
@@ -233,6 +233,7 @@ __all__ = [
|
|||||||
"LanceNamespaceDBConnection",
|
"LanceNamespaceDBConnection",
|
||||||
"RemoteDBConnection",
|
"RemoteDBConnection",
|
||||||
"Session",
|
"Session",
|
||||||
|
"Table",
|
||||||
"__version__",
|
"__version__",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -340,3 +340,6 @@ def async_permutation_builder(
|
|||||||
table: Table, dest_table_name: str
|
table: Table, dest_table_name: str
|
||||||
) -> AsyncPermutationBuilder: ...
|
) -> AsyncPermutationBuilder: ...
|
||||||
def fts_query_to_json(query: Any) -> str: ...
|
def fts_query_to_json(query: Any) -> str: ...
|
||||||
|
|
||||||
|
class PermutationReader:
|
||||||
|
def __init__(self, base_table: Table, permutation_table: Table): ...
|
||||||
|
|||||||
@@ -1,18 +1,63 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
from ._lancedb import async_permutation_builder
|
from deprecation import deprecated
|
||||||
|
from lancedb import AsyncConnection, DBConnection
|
||||||
|
import pyarrow as pa
|
||||||
|
import json
|
||||||
|
|
||||||
|
from ._lancedb import async_permutation_builder, PermutationReader
|
||||||
from .table import LanceTable
|
from .table import LanceTable
|
||||||
from .background_loop import LOOP
|
from .background_loop import LOOP
|
||||||
from typing import Optional
|
from .util import batch_to_tensor
|
||||||
|
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
|
||||||
|
|
||||||
|
|
||||||
class PermutationBuilder:
|
class PermutationBuilder:
|
||||||
|
"""
|
||||||
|
A utility for creating a "permutation table" which is a table that defines an
|
||||||
|
ordering on a base table.
|
||||||
|
|
||||||
|
The permutation table does not store the actual data. It only stores row
|
||||||
|
ids and split ids to define the ordering. The [Permutation] class can be used to
|
||||||
|
read the data from the base table in the order defined by the permutation table.
|
||||||
|
|
||||||
|
Permutations can split, shuffle, and filter the data in the base table.
|
||||||
|
|
||||||
|
A filter limits the rows that are included in the permutation.
|
||||||
|
Splits divide the data into subsets (for example, a test/train split, or K
|
||||||
|
different splits for cross-validation).
|
||||||
|
Shuffling randomizes the order of the rows in the permutation.
|
||||||
|
|
||||||
|
Splits can optionally be named. If names are provided it will enable them to
|
||||||
|
be referenced by name in the future. If names are not provided then they can only
|
||||||
|
be referenced by their ordinal index. There is no requirement to name every split.
|
||||||
|
|
||||||
|
By default, the permutation will be stored in memory and will be lost when the
|
||||||
|
program exits. To persist the permutation (for very large datasets or to share
|
||||||
|
the permutation across multiple workers) use the [persist](#persist) method to
|
||||||
|
create a permanent table.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, table: LanceTable):
|
def __init__(self, table: LanceTable):
|
||||||
|
"""
|
||||||
|
Creates a new permutation builder for the given table.
|
||||||
|
|
||||||
|
By default, the permutation builder will create a single split that contains all
|
||||||
|
rows in the same order as the base table.
|
||||||
|
"""
|
||||||
self._async = async_permutation_builder(table)
|
self._async = async_permutation_builder(table)
|
||||||
|
|
||||||
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
|
def persist(
|
||||||
self._async.select(projections)
|
self, database: Union[DBConnection, AsyncConnection], table_name: str
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Persist the permutation to the given database.
|
||||||
|
"""
|
||||||
|
self._async.persist(database, table_name)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def split_random(
|
def split_random(
|
||||||
@@ -22,8 +67,38 @@ class PermutationBuilder:
|
|||||||
counts: Optional[list[int]] = None,
|
counts: Optional[list[int]] = None,
|
||||||
fixed: Optional[int] = None,
|
fixed: Optional[int] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
) -> "PermutationBuilder":
|
) -> "PermutationBuilder":
|
||||||
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
|
"""
|
||||||
|
Configure random splits for the permutation.
|
||||||
|
|
||||||
|
One of ratios, counts, or fixed must be provided.
|
||||||
|
|
||||||
|
If ratios are provided, they will be used to determine the relative size of each
|
||||||
|
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||||
|
30% of the rows and the second split will contain 70% of the rows.
|
||||||
|
|
||||||
|
If counts are provided, they will be used to determine the absolute number of
|
||||||
|
rows in each split. For example, if counts are [100, 200] then the first split
|
||||||
|
will contain 100 rows and the second split will contain 200 rows.
|
||||||
|
|
||||||
|
If fixed is provided, it will be used to determine the number of splits.
|
||||||
|
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||||
|
splits.
|
||||||
|
|
||||||
|
Rows will be randomly assigned to splits. The optional seed can be provided to
|
||||||
|
make the assignment deterministic.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_random(
|
||||||
|
ratios=ratios,
|
||||||
|
counts=counts,
|
||||||
|
fixed=fixed,
|
||||||
|
seed=seed,
|
||||||
|
split_names=split_names,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def split_hash(
|
def split_hash(
|
||||||
@@ -32,8 +107,33 @@ class PermutationBuilder:
|
|||||||
split_weights: list[int],
|
split_weights: list[int],
|
||||||
*,
|
*,
|
||||||
discard_weight: Optional[int] = None,
|
discard_weight: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
) -> "PermutationBuilder":
|
) -> "PermutationBuilder":
|
||||||
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
|
"""
|
||||||
|
Configure hash-based splits for the permutation.
|
||||||
|
|
||||||
|
First, a hash will be calculated over the specified columns. The splits weights
|
||||||
|
are then used to determine how many rows to assign to each split. For example,
|
||||||
|
if split weights are [1, 2] then the first split will contain 1/3 of the rows
|
||||||
|
and the second split will contain 2/3 of the rows.
|
||||||
|
|
||||||
|
The optional discard weight can be provided to determine what percentage of rows
|
||||||
|
should be discarded. For example, if split weights are [1, 2] and discard
|
||||||
|
weight is 1 then 25% of the rows will be discarded.
|
||||||
|
|
||||||
|
Hash-based splits are useful if you want the split to be more or less random but
|
||||||
|
you don't want the split assignments to change if rows are added or removed
|
||||||
|
from the table.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_hash(
|
||||||
|
columns,
|
||||||
|
split_weights,
|
||||||
|
discard_weight=discard_weight,
|
||||||
|
split_names=split_names,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def split_sequential(
|
def split_sequential(
|
||||||
@@ -42,25 +142,85 @@ class PermutationBuilder:
|
|||||||
ratios: Optional[list[float]] = None,
|
ratios: Optional[list[float]] = None,
|
||||||
counts: Optional[list[int]] = None,
|
counts: Optional[list[int]] = None,
|
||||||
fixed: Optional[int] = None,
|
fixed: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
) -> "PermutationBuilder":
|
) -> "PermutationBuilder":
|
||||||
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
|
"""
|
||||||
|
Configure sequential splits for the permutation.
|
||||||
|
|
||||||
|
One of ratios, counts, or fixed must be provided.
|
||||||
|
|
||||||
|
If ratios are provided, they will be used to determine the relative size of each
|
||||||
|
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||||
|
30% of the rows and the second split will contain 70% of the rows.
|
||||||
|
|
||||||
|
If counts are provided, they will be used to determine the absolute number of
|
||||||
|
rows in each split. For example, if counts are [100, 200] then the first split
|
||||||
|
will contain 100 rows and the second split will contain 200 rows.
|
||||||
|
|
||||||
|
If fixed is provided, it will be used to determine the number of splits.
|
||||||
|
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||||
|
splits.
|
||||||
|
|
||||||
|
Rows will be assigned to splits sequentially. The first N1 rows are assigned to
|
||||||
|
split 1, the next N2 rows are assigned to split 2, etc.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_sequential(
|
||||||
|
ratios=ratios, counts=counts, fixed=fixed, split_names=split_names
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def split_calculated(self, calculation: str) -> "PermutationBuilder":
|
def split_calculated(
|
||||||
self._async.split_calculated(calculation)
|
self, calculation: str, split_names: Optional[list[str]] = None
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Use pre-calculated splits for the permutation.
|
||||||
|
|
||||||
|
The calculation should be an SQL statement that returns an integer value between
|
||||||
|
0 and the number of splits - 1. For example, if you have 3 splits then the
|
||||||
|
calculation should return 0 for the first split, 1 for the second split, and 2
|
||||||
|
for the third split.
|
||||||
|
|
||||||
|
This can be used to implement any kind of user-defined split strategy.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_calculated(calculation, split_names=split_names)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def shuffle(
|
def shuffle(
|
||||||
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
||||||
) -> "PermutationBuilder":
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Randomly shuffle the rows in the permutation.
|
||||||
|
|
||||||
|
An optional seed can be provided to make the shuffle deterministic.
|
||||||
|
|
||||||
|
If a clump size is provided, then data will be shuffled as small "clumps"
|
||||||
|
of contiguous rows. This allows for a balance between randomization and
|
||||||
|
I/O performance. It can be useful when reading from cloud storage.
|
||||||
|
"""
|
||||||
self._async.shuffle(seed=seed, clump_size=clump_size)
|
self._async.shuffle(seed=seed, clump_size=clump_size)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def filter(self, filter: str) -> "PermutationBuilder":
|
def filter(self, filter: str) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Configure a filter for the permutation.
|
||||||
|
|
||||||
|
The filter should be an SQL statement that returns a boolean value for each row.
|
||||||
|
Only rows where the filter is true will be included in the permutation.
|
||||||
|
"""
|
||||||
self._async.filter(filter)
|
self._async.filter(filter)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def execute(self) -> LanceTable:
|
def execute(self) -> LanceTable:
|
||||||
|
"""
|
||||||
|
Execute the configuration and create the permutation table.
|
||||||
|
"""
|
||||||
|
|
||||||
async def do_execute():
|
async def do_execute():
|
||||||
inner_tbl = await self._async.execute()
|
inner_tbl = await self._async.execute()
|
||||||
return LanceTable.from_inner(inner_tbl)
|
return LanceTable.from_inner(inner_tbl)
|
||||||
@@ -70,3 +230,592 @@ class PermutationBuilder:
|
|||||||
|
|
||||||
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
||||||
return PermutationBuilder(table)
|
return PermutationBuilder(table)
|
||||||
|
|
||||||
|
|
||||||
|
class Permutations:
|
||||||
|
"""
|
||||||
|
A collection of permutations indexed by name or ordinal index.
|
||||||
|
|
||||||
|
Splits are defined when the permutation is created. Splits can always be referenced
|
||||||
|
by their ordinal index. If names were provided when the permutation was created
|
||||||
|
then they can also be referenced by name.
|
||||||
|
|
||||||
|
Each permutation or "split" is a view of a portion of the base table. For more
|
||||||
|
details see [Permutation].
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
base_table: LanceTable
|
||||||
|
The base table that the permutations are based on.
|
||||||
|
permutation_table: LanceTable
|
||||||
|
The permutation table that defines the splits.
|
||||||
|
split_names: list[str]
|
||||||
|
The names of the splits.
|
||||||
|
split_dict: dict[str, int]
|
||||||
|
A dictionary mapping split names to their ordinal index.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> # Initial data
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("memory:///")
|
||||||
|
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(1000)])
|
||||||
|
>>> # Create a permutation
|
||||||
|
>>> perm_tbl = (
|
||||||
|
... permutation_builder(tbl)
|
||||||
|
... .split_random(ratios=[0.95, 0.05], split_names=["train", "test"])
|
||||||
|
... .shuffle()
|
||||||
|
... .execute()
|
||||||
|
... )
|
||||||
|
>>> # Read the permutations
|
||||||
|
>>> permutations = Permutations(tbl, perm_tbl)
|
||||||
|
>>> permutations["train"]
|
||||||
|
<lancedb.permutation.Permutation ...>
|
||||||
|
>>> permutations[0]
|
||||||
|
<lancedb.permutation.Permutation ...>
|
||||||
|
>>> permutations.split_names
|
||||||
|
['train', 'test']
|
||||||
|
>>> permutations.split_dict
|
||||||
|
{'train': 0, 'test': 1}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_table: LanceTable, permutation_table: LanceTable):
|
||||||
|
self.base_table = base_table
|
||||||
|
self.permutation_table = permutation_table
|
||||||
|
|
||||||
|
if permutation_table.schema.metadata is not None:
|
||||||
|
split_names = permutation_table.schema.metadata.get(
|
||||||
|
b"split_names", None
|
||||||
|
).decode("utf-8")
|
||||||
|
if split_names is not None:
|
||||||
|
self.split_names = json.loads(split_names)
|
||||||
|
self.split_dict = {
|
||||||
|
name: idx for idx, name in enumerate(self.split_names)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# No split names are defined in the permutation table
|
||||||
|
self.split_names = []
|
||||||
|
self.split_dict = {}
|
||||||
|
else:
|
||||||
|
# No metadata is defined in the permutation table
|
||||||
|
self.split_names = []
|
||||||
|
self.split_dict = {}
|
||||||
|
|
||||||
|
def get_by_name(self, name: str) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Get a permutation by name.
|
||||||
|
|
||||||
|
If no split named `name` is found then an error will be raised.
|
||||||
|
"""
|
||||||
|
idx = self.split_dict.get(name, None)
|
||||||
|
if idx is None:
|
||||||
|
raise ValueError(f"No split named `{name}` found")
|
||||||
|
return self.get_by_index(idx)
|
||||||
|
|
||||||
|
def get_by_index(self, index: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Get a permutation by index.
|
||||||
|
"""
|
||||||
|
return Permutation.from_tables(self.base_table, self.permutation_table, index)
|
||||||
|
|
||||||
|
def __getitem__(self, name: Union[str, int]) -> "Permutation":
|
||||||
|
if isinstance(name, str):
|
||||||
|
return self.get_by_name(name)
|
||||||
|
elif isinstance(name, int):
|
||||||
|
return self.get_by_index(name)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid split name or index: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class Transforms:
|
||||||
|
"""
|
||||||
|
Namespace for common transformation functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
|
||||||
|
return batch.to_pydict()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2arrow(batch: pa.RecordBatch) -> pa.RecordBatch:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2numpy(batch: pa.RecordBatch) -> "np.ndarray":
|
||||||
|
return batch.to_pandas().to_numpy()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2pandas(batch: pa.RecordBatch) -> "pd.DataFrame":
|
||||||
|
return batch.to_pandas()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2polars() -> "pl.DataFrame":
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
def impl(batch: pa.RecordBatch) -> pl.DataFrame:
|
||||||
|
return pl.from_arrow(batch)
|
||||||
|
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
# HuggingFace uses 10 which is pretty small
|
||||||
|
DEFAULT_BATCH_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
class Permutation:
|
||||||
|
"""
|
||||||
|
A Permutation is a view of a dataset that can be used as input to model training
|
||||||
|
and evaluation.
|
||||||
|
|
||||||
|
A Permutation fulfills the pytorch Dataset contract and is loosely modeled after the
|
||||||
|
huggingface Dataset so it should be easy to use with existing code.
|
||||||
|
|
||||||
|
A permutation is not a "materialized view" or copy of the underlying data. It is
|
||||||
|
calculated on the fly from the base table. As a result, it is truly "lazy" and does
|
||||||
|
not require materializing the entire dataset in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reader: PermutationReader,
|
||||||
|
selection: dict[str, str],
|
||||||
|
batch_size: int,
|
||||||
|
transform_fn: Callable[pa.RecordBatch, Any],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||||
|
"""
|
||||||
|
assert reader is not None, "reader is required"
|
||||||
|
assert selection is not None, "selection is required"
|
||||||
|
self.reader = reader
|
||||||
|
self.selection = selection
|
||||||
|
self.transform_fn = transform_fn
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given selection
|
||||||
|
|
||||||
|
Does not validation of the selection and it replaces it entirely. This is not
|
||||||
|
intended for public use.
|
||||||
|
"""
|
||||||
|
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given reader
|
||||||
|
|
||||||
|
This is an internal method and should not be used directly.
|
||||||
|
"""
|
||||||
|
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given batch size
|
||||||
|
"""
|
||||||
|
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def identity(cls, table: LanceTable) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates an identity permutation for the given table.
|
||||||
|
"""
|
||||||
|
return Permutation.from_tables(table, None, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tables(
|
||||||
|
cls,
|
||||||
|
base_table: LanceTable,
|
||||||
|
permutation_table: Optional[LanceTable] = None,
|
||||||
|
split: Optional[Union[str, int]] = None,
|
||||||
|
) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a permutation from the given base table and permutation table.
|
||||||
|
|
||||||
|
A permutation table identifies which rows, and in what order, the data should
|
||||||
|
be read from the base table. For more details see the [PermutationBuilder]
|
||||||
|
class.
|
||||||
|
|
||||||
|
If no permutation table is provided, then the identity permutation will be
|
||||||
|
created. An identity permutation is a permutation that reads all rows in the
|
||||||
|
base table in the order they are stored.
|
||||||
|
|
||||||
|
The split parameter identifies which split to use. If no split is provided
|
||||||
|
then the first split will be used.
|
||||||
|
"""
|
||||||
|
assert base_table is not None, "base_table is required"
|
||||||
|
if split is not None:
|
||||||
|
if permutation_table is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no permutation table is provided"
|
||||||
|
)
|
||||||
|
if isinstance(split, str):
|
||||||
|
if permutation_table.schema.metadata is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no split names are defined in the permutation table"
|
||||||
|
)
|
||||||
|
split_names = permutation_table.schema.metadata.get(
|
||||||
|
b"split_names", None
|
||||||
|
).decode("utf-8")
|
||||||
|
if split_names is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no split names are defined in the permutation table"
|
||||||
|
)
|
||||||
|
split_names = json.loads(split_names)
|
||||||
|
try:
|
||||||
|
split = split_names.index(split)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
f" because split `{split}` is not defined in the "
|
||||||
|
"permutation table"
|
||||||
|
)
|
||||||
|
elif isinstance(split, int):
|
||||||
|
split = split
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid split: {split}")
|
||||||
|
else:
|
||||||
|
split = 0
|
||||||
|
|
||||||
|
async def do_from_tables():
|
||||||
|
reader = await PermutationReader.from_tables(
|
||||||
|
base_table, permutation_table, split
|
||||||
|
)
|
||||||
|
schema = await reader.output_schema(None)
|
||||||
|
initial_selection = {name: name for name in schema.names}
|
||||||
|
return cls(
|
||||||
|
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||||
|
)
|
||||||
|
|
||||||
|
return LOOP.run(do_from_tables())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def schema(self) -> pa.Schema:
|
||||||
|
async def do_output_schema():
|
||||||
|
return await self.reader.output_schema(self.selection)
|
||||||
|
|
||||||
|
return LOOP.run(do_output_schema())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_columns(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of columns in the permutation
|
||||||
|
"""
|
||||||
|
return len(self.schema)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_rows(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of rows in the permutation
|
||||||
|
"""
|
||||||
|
return self.reader.count_rows()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def column_names(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
The names of the columns in the permutation
|
||||||
|
"""
|
||||||
|
return self.schema.names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
The shape of the permutation
|
||||||
|
|
||||||
|
This will return self.num_rows, self.num_columns
|
||||||
|
"""
|
||||||
|
return self.num_rows, self.num_columns
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of rows in the permutation
|
||||||
|
|
||||||
|
This is an alias for [num_rows][lancedb.permutation.Permutation.num_rows]
|
||||||
|
"""
|
||||||
|
return self.num_rows
|
||||||
|
|
||||||
|
def unique(self, _column: str) -> list[Any]:
|
||||||
|
"""
|
||||||
|
Get the unique values in the given column
|
||||||
|
"""
|
||||||
|
raise Exception("unique is not yet implemented")
|
||||||
|
|
||||||
|
def flatten(self) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Flatten the permutation
|
||||||
|
|
||||||
|
Each column with a struct type will be flattened into multiple columns.
|
||||||
|
|
||||||
|
This flattening operation happens at read time as a post-processing step
|
||||||
|
so this call is cheap and no data is copied or modified in the underlying
|
||||||
|
dataset.
|
||||||
|
"""
|
||||||
|
raise Exception("flatten is not yet implemented")
|
||||||
|
|
||||||
|
def remove_columns(self, columns: list[str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Remove the given columns from the permutation
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
which columns are visible from this permutation. Also, this does not introduce
|
||||||
|
a post-processing step. Instead, we simply do not read those columns in the
|
||||||
|
first place.
|
||||||
|
|
||||||
|
If any of the provided columns does not exist in the current permutation then it
|
||||||
|
will be ignored (no error is raised for missing columns)
|
||||||
|
|
||||||
|
Returns a new permutation with the given columns removed. This does not modify
|
||||||
|
self.
|
||||||
|
"""
|
||||||
|
assert columns is not None, "columns is required"
|
||||||
|
|
||||||
|
new_selection = {
|
||||||
|
name: value for name, value in self.selection.items() if name not in columns
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(new_selection) == 0:
|
||||||
|
raise ValueError("Cannot remove all columns")
|
||||||
|
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def rename_column(self, old_name: str, new_name: str) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Rename a column in the permutation
|
||||||
|
|
||||||
|
If there is no column named old_name then an error will be raised
|
||||||
|
If there is already a column named new_name then an error will be raised
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
the name of the column that is visible from this permutation. This is a
|
||||||
|
post-processing step but done at the batch level and so it is very cheap.
|
||||||
|
No data will be copied.
|
||||||
|
"""
|
||||||
|
assert old_name is not None, "old_name is required"
|
||||||
|
assert new_name is not None, "new_name is required"
|
||||||
|
if old_name not in self.selection:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot rename column `{old_name}` because it does not exist"
|
||||||
|
)
|
||||||
|
if new_name in self.selection:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot rename column `{old_name}` to `{new_name}` because a column "
|
||||||
|
"with that name already exists"
|
||||||
|
)
|
||||||
|
new_selection = self.selection.copy()
|
||||||
|
new_selection[new_name] = new_selection[old_name]
|
||||||
|
del new_selection[old_name]
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def rename_columns(self, column_map: dict[str, str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Rename the given columns in the permutation
|
||||||
|
|
||||||
|
If any of the columns do not exist then an error will be raised
|
||||||
|
If any of the new names already exist then an error will be raised
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
the name of the column that is visible from this permutation. This is a
|
||||||
|
post-processing step but done at the batch level and so it is very cheap.
|
||||||
|
No data will be copied.
|
||||||
|
"""
|
||||||
|
assert column_map is not None, "column_map is required"
|
||||||
|
|
||||||
|
new_permutation = self
|
||||||
|
for old_name, new_name in column_map.items():
|
||||||
|
new_permutation = new_permutation.rename_column(old_name, new_name)
|
||||||
|
return new_permutation
|
||||||
|
|
||||||
|
def select_columns(self, columns: list[str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Select the given columns from the permutation
|
||||||
|
|
||||||
|
This method refines the current selection, potentially removing columns. It
|
||||||
|
will not add back columns that were previously removed.
|
||||||
|
|
||||||
|
If any of the columns do not exist then an error will be raised
|
||||||
|
|
||||||
|
This does not introduce a post-processing step. It simply reduces the amount
|
||||||
|
of data we read.
|
||||||
|
"""
|
||||||
|
assert columns is not None, "columns is required"
|
||||||
|
if len(columns) == 0:
|
||||||
|
raise ValueError("Must select at least one column")
|
||||||
|
|
||||||
|
new_selection = {}
|
||||||
|
for name in columns:
|
||||||
|
value = self.selection.get(name, None)
|
||||||
|
if value is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot select column `{name}` because it does not exist"
|
||||||
|
)
|
||||||
|
new_selection[name] = value
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Iterate over the permutation
|
||||||
|
"""
|
||||||
|
return self.iter(self.batch_size, skip_last_batch=True)
|
||||||
|
|
||||||
|
def iter(
|
||||||
|
self, batch_size: int, skip_last_batch: bool = False
|
||||||
|
) -> Iterator[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Iterate over the permutation in batches
|
||||||
|
|
||||||
|
If skip_last_batch is True, the last batch will be skipped if it is not a
|
||||||
|
multiple of batch_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_iter():
|
||||||
|
return await self.reader.read(self.selection, batch_size=batch_size)
|
||||||
|
|
||||||
|
async_iter = LOOP.run(get_iter())
|
||||||
|
|
||||||
|
async def get_next():
|
||||||
|
return await async_iter.__anext__()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
batch = LOOP.run(get_next())
|
||||||
|
if batch.num_rows == batch_size or not skip_last_batch:
|
||||||
|
yield self.transform_fn(batch)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
def with_format(
|
||||||
|
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
|
||||||
|
) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Set the format for batches
|
||||||
|
|
||||||
|
If this method is not called, the "python" format will be used.
|
||||||
|
|
||||||
|
The format can be one of:
|
||||||
|
- "numpy" - the batch will be a dict of numpy arrays (one per column)
|
||||||
|
- "python" - the batch will be a dict of lists (one per column)
|
||||||
|
- "pandas" - the batch will be a pandas DataFrame
|
||||||
|
- "arrow" - the batch will be a pyarrow RecordBatch
|
||||||
|
- "torch" - the batch will be a two dimensional torch tensor
|
||||||
|
- "polars" - the batch will be a polars DataFrame
|
||||||
|
|
||||||
|
Conversion may or may not involve a data copy. Lance uses Arrow internally
|
||||||
|
and so it is able to zero-copy to the arrow and polars.
|
||||||
|
|
||||||
|
Conversion to torch will be zero-copy but will only support a subset of data
|
||||||
|
types (numeric types).
|
||||||
|
|
||||||
|
Conversion to numpy and/or pandas will typically be zero-copy for numeric
|
||||||
|
types. Conversion of strings, lists, and structs will require creating python
|
||||||
|
objects and this is not zero-copy.
|
||||||
|
|
||||||
|
For custom formatting, use [with_transform](#with_transform) which overrides
|
||||||
|
this method.
|
||||||
|
"""
|
||||||
|
assert format is not None, "format is required"
|
||||||
|
if format == "python":
|
||||||
|
return self.with_transform(Transforms.arrow2python)
|
||||||
|
elif format == "numpy":
|
||||||
|
return self.with_transform(Transforms.arrow2numpy)
|
||||||
|
elif format == "pandas":
|
||||||
|
return self.with_transform(Transforms.arrow2pandas)
|
||||||
|
elif format == "arrow":
|
||||||
|
return self.with_transform(Transforms.arrow2arrow)
|
||||||
|
elif format == "torch":
|
||||||
|
return self.with_transform(batch_to_tensor)
|
||||||
|
elif format == "polars":
|
||||||
|
return self.with_transform(Transforms.arrow2polars())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid format: {format}")
|
||||||
|
|
||||||
|
def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Set a custom transform for the permutation
|
||||||
|
|
||||||
|
The transform is a callable that will be invoked with each record batch. The
|
||||||
|
return value will be used as the batch for iteration.
|
||||||
|
|
||||||
|
Note: transforms are not invoked in parallel. This method is not a good place
|
||||||
|
for expensive operations such as image decoding.
|
||||||
|
"""
|
||||||
|
assert transform is not None, "transform is required"
|
||||||
|
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Any:
|
||||||
|
"""
|
||||||
|
Return a single row from the permutation
|
||||||
|
|
||||||
|
The output will always be a python dictionary regardless of the format.
|
||||||
|
|
||||||
|
This method is mostly useful for debugging and exploration. For actual
|
||||||
|
processing use [iter](#iter) or a torch data loader to perform batched
|
||||||
|
processing.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecated(details="Use with_skip instead")
|
||||||
|
def skip(self, skip: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Skip the first `skip` rows of the permutation
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_skip](#with_skip) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_skip(skip)
|
||||||
|
|
||||||
|
def with_skip(self, skip: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Skip the first `skip` rows of the permutation
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def do_with_skip():
|
||||||
|
reader = await self.reader.with_offset(skip)
|
||||||
|
return self._with_reader(reader)
|
||||||
|
|
||||||
|
return LOOP.run(do_with_skip())
|
||||||
|
|
||||||
|
@deprecated(details="Use with_take instead")
|
||||||
|
def take(self, limit: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Limit the permutation to `limit` rows (following any `skip`)
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_take](#with_take) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_take(limit)
|
||||||
|
|
||||||
|
def with_take(self, limit: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Limit the permutation to `limit` rows (following any `skip`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def do_with_take():
|
||||||
|
reader = await self.reader.with_limit(limit)
|
||||||
|
return self._with_reader(reader)
|
||||||
|
|
||||||
|
return LOOP.run(do_with_take())
|
||||||
|
|
||||||
|
@deprecated(details="Use with_repeat instead")
|
||||||
|
def repeat(self, times: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Repeat the permutation `times` times
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_repeat](#with_repeat) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_repeat(times)
|
||||||
|
|
||||||
|
def with_repeat(self, times: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Repeat the permutation `times` times
|
||||||
|
"""
|
||||||
|
raise Exception("with_repeat is not yet implemented")
|
||||||
|
|||||||
@@ -366,3 +366,56 @@ def add_note(base_exception: BaseException, note: str):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot add note to exception")
|
raise ValueError("Cannot add note to exception")
|
||||||
|
|
||||||
|
|
||||||
|
def tbl_to_tensor(tbl: pa.Table):
|
||||||
|
"""
|
||||||
|
Convert a PyArrow Table to a PyTorch Tensor.
|
||||||
|
|
||||||
|
Each column is converted to a tensor (using zero-copy via DLPack)
|
||||||
|
and the columns are then stacked into a single tensor.
|
||||||
|
|
||||||
|
Fails if torch is not installed.
|
||||||
|
Fails if any column is more than one chunk.
|
||||||
|
Fails if a column's data type is not supported by PyTorch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tbl : pa.Table or pa.RecordBatch
|
||||||
|
The table or record batch to convert to a tensor.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor: The tensor containing the columns of the table.
|
||||||
|
"""
|
||||||
|
torch = attempt_import_or_raise("torch", "torch")
|
||||||
|
|
||||||
|
def to_tensor(col: pa.ChunkedArray):
|
||||||
|
if col.num_chunks > 1:
|
||||||
|
raise Exception("Single batch was too large to fit into a one-chunk table")
|
||||||
|
return torch.from_dlpack(col.chunk(0))
|
||||||
|
|
||||||
|
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
||||||
|
|
||||||
|
|
||||||
|
def batch_to_tensor(batch: pa.RecordBatch):
|
||||||
|
"""
|
||||||
|
Convert a PyArrow RecordBatch to a PyTorch Tensor.
|
||||||
|
|
||||||
|
Each column is converted to a tensor (using zero-copy via DLPack)
|
||||||
|
and the columns are then stacked into a single tensor.
|
||||||
|
|
||||||
|
Fails if torch is not installed.
|
||||||
|
Fails if a column's data type is not supported by PyTorch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch : pa.RecordBatch
|
||||||
|
The record batch to convert to a tensor.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor: The tensor containing the columns of the record batch.
|
||||||
|
"""
|
||||||
|
torch = attempt_import_or_raise("torch", "torch")
|
||||||
|
return torch.stack([torch.from_dlpack(col) for col in batch.columns])
|
||||||
|
|||||||
@@ -2,9 +2,26 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
import math
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lancedb.permutation import permutation_builder
|
from lancedb import DBConnection, Table, connect
|
||||||
|
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||||
|
|
||||||
|
|
||||||
|
def test_permutation_persistence(tmp_path):
|
||||||
|
db = connect(tmp_path)
|
||||||
|
tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
|
||||||
|
)
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
re_open = db.open_table("test_permutation")
|
||||||
|
assert re_open.count_rows() == 100
|
||||||
|
|
||||||
|
assert permutation_tbl.to_arrow() == re_open.to_arrow()
|
||||||
|
|
||||||
|
|
||||||
def test_split_random_ratios(mem_db):
|
def test_split_random_ratios(mem_db):
|
||||||
@@ -195,21 +212,33 @@ def test_split_error_cases(mem_db):
|
|||||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
||||||
|
|
||||||
# Test split_random with no parameters
|
# Test split_random with no parameters
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
permutation_builder(tbl).split_random().execute()
|
permutation_builder(tbl).split_random().execute()
|
||||||
|
|
||||||
# Test split_random with multiple parameters
|
# Test split_random with multiple parameters
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
permutation_builder(tbl).split_random(
|
permutation_builder(tbl).split_random(
|
||||||
ratios=[0.5, 0.5], counts=[5, 5]
|
ratios=[0.5, 0.5], counts=[5, 5]
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
# Test split_sequential with no parameters
|
# Test split_sequential with no parameters
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
permutation_builder(tbl).split_sequential().execute()
|
permutation_builder(tbl).split_sequential().execute()
|
||||||
|
|
||||||
# Test split_sequential with multiple parameters
|
# Test split_sequential with multiple parameters
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
||||||
|
|
||||||
|
|
||||||
@@ -460,3 +489,455 @@ def test_filter_empty_result(mem_db):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert permutation_tbl.count_rows() == 0
|
assert permutation_tbl.count_rows() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mem_db() -> DBConnection:
|
||||||
|
return connect("memory:///")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def some_table(mem_db: DBConnection) -> Table:
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"id": range(1000),
|
||||||
|
"value": range(1000),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return mem_db.create_table("some_table", data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_split_names(some_table: Table):
|
||||||
|
perm_tbl = (
|
||||||
|
permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
|
||||||
|
)
|
||||||
|
permutations = Permutations(some_table, perm_tbl)
|
||||||
|
assert permutations.split_names == []
|
||||||
|
assert permutations.split_dict == {}
|
||||||
|
assert permutations[0].num_rows == 500
|
||||||
|
assert permutations[1].num_rows == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def some_perm_table(some_table: Table) -> Table:
|
||||||
|
return (
|
||||||
|
permutation_builder(some_table)
|
||||||
|
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
|
||||||
|
.shuffle(seed=42)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nonexistent_split(some_table: Table, some_perm_table: Table):
|
||||||
|
# Reference by name and name does not exist
|
||||||
|
with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
|
||||||
|
Permutation.from_tables(some_table, some_perm_table, "nonexistent")
|
||||||
|
|
||||||
|
# Reference by ordinal and there are no rows
|
||||||
|
with pytest.raises(ValueError, match="No rows found"):
|
||||||
|
Permutation.from_tables(some_table, some_perm_table, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_permutations(some_table: Table, some_perm_table: Table):
|
||||||
|
permutations = Permutations(some_table, some_perm_table)
|
||||||
|
assert permutations.split_names == ["train", "test"]
|
||||||
|
assert permutations.split_dict == {"train": 0, "test": 1}
|
||||||
|
assert permutations["train"].num_rows == 950
|
||||||
|
assert permutations[0].num_rows == 950
|
||||||
|
assert permutations["test"].num_rows == 50
|
||||||
|
assert permutations[1].num_rows == 50
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No split named `nonexistent` found"):
|
||||||
|
permutations["nonexistent"]
|
||||||
|
with pytest.raises(ValueError, match="No rows found"):
|
||||||
|
permutations[5]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
|
||||||
|
return Permutation.from_tables(some_table, some_perm_table)
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_rows(some_permutation: Permutation):
|
||||||
|
assert some_permutation.num_rows == 950
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_columns(some_permutation: Permutation):
|
||||||
|
assert some_permutation.num_columns == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_column_names(some_permutation: Permutation):
|
||||||
|
assert some_permutation.column_names == ["id", "value"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shape(some_permutation: Permutation):
|
||||||
|
assert some_permutation.shape == (950, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema(some_permutation: Permutation):
|
||||||
|
assert some_permutation.schema == pa.schema(
|
||||||
|
[("id", pa.int64()), ("value", pa.int64())]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_limit_offset(some_permutation: Permutation):
|
||||||
|
assert some_permutation.with_take(100).num_rows == 100
|
||||||
|
assert some_permutation.with_skip(100).num_rows == 850
|
||||||
|
assert some_permutation.with_take(100).with_skip(100).num_rows == 100
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
some_permutation.with_take(1000000).num_rows
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
some_permutation.with_skip(1000000).num_rows
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
some_permutation.with_take(500).with_skip(500).num_rows
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
some_permutation.with_skip(500).with_take(500).num_rows
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_columns(some_permutation: Permutation):
|
||||||
|
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
|
||||||
|
[("id", pa.int64())]
|
||||||
|
)
|
||||||
|
# Should not modify the original permutation
|
||||||
|
assert some_permutation.schema.names == ["id", "value"]
|
||||||
|
# Cannot remove all columns
|
||||||
|
with pytest.raises(ValueError, match="Cannot remove all columns"):
|
||||||
|
some_permutation.remove_columns(["id", "value"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_column(some_permutation: Permutation):
|
||||||
|
assert some_permutation.rename_column("value", "new_value").schema == pa.schema(
|
||||||
|
[("id", pa.int64()), ("new_value", pa.int64())]
|
||||||
|
)
|
||||||
|
# Should not modify the original permutation
|
||||||
|
assert some_permutation.schema.names == ["id", "value"]
|
||||||
|
# Cannot rename to an existing column
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="a column with that name already exists",
|
||||||
|
):
|
||||||
|
some_permutation.rename_column("value", "id")
|
||||||
|
# Cannot rename a non-existent column
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="does not exist",
|
||||||
|
):
|
||||||
|
some_permutation.rename_column("non_existent", "new_value")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_columns(some_permutation: Permutation):
|
||||||
|
assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema(
|
||||||
|
[("id", pa.int64()), ("new_value", pa.int64())]
|
||||||
|
)
|
||||||
|
# Should not modify the original permutation
|
||||||
|
assert some_permutation.schema.names == ["id", "value"]
|
||||||
|
# Cannot rename to an existing column
|
||||||
|
with pytest.raises(ValueError, match="a column with that name already exists"):
|
||||||
|
some_permutation.rename_columns({"value": "id"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_columns(some_permutation: Permutation):
|
||||||
|
assert some_permutation.select_columns(["id"]).schema == pa.schema(
|
||||||
|
[("id", pa.int64())]
|
||||||
|
)
|
||||||
|
# Should not modify the original permutation
|
||||||
|
assert some_permutation.schema.names == ["id", "value"]
|
||||||
|
# Cannot select a non-existent column
|
||||||
|
with pytest.raises(ValueError, match="does not exist"):
|
||||||
|
some_permutation.select_columns(["non_existent"])
|
||||||
|
# Empty selection is not allowed
|
||||||
|
with pytest.raises(ValueError, match="select at least one column"):
|
||||||
|
some_permutation.select_columns([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_basic(some_permutation: Permutation):
|
||||||
|
"""Test basic iteration with custom batch size."""
|
||||||
|
batch_size = 100
|
||||||
|
batches = list(some_permutation.iter(batch_size, skip_last_batch=False))
|
||||||
|
|
||||||
|
# Check that we got the expected number of batches
|
||||||
|
expected_batches = (950 + batch_size - 1) // batch_size # ceiling division
|
||||||
|
assert len(batches) == expected_batches
|
||||||
|
|
||||||
|
# Check that all batches are dicts (default python format)
|
||||||
|
assert all(isinstance(batch, dict) for batch in batches)
|
||||||
|
|
||||||
|
# Check that batches have the correct structure
|
||||||
|
for batch in batches:
|
||||||
|
assert "id" in batch
|
||||||
|
assert "value" in batch
|
||||||
|
assert isinstance(batch["id"], list)
|
||||||
|
assert isinstance(batch["value"], list)
|
||||||
|
|
||||||
|
# Check that all batches except the last have the correct size
|
||||||
|
for batch in batches[:-1]:
|
||||||
|
assert len(batch["id"]) == batch_size
|
||||||
|
assert len(batch["value"]) == batch_size
|
||||||
|
|
||||||
|
# Last batch might be smaller
|
||||||
|
assert len(batches[-1]["id"]) <= batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_skip_last_batch(some_permutation: Permutation):
|
||||||
|
"""Test iteration with skip_last_batch=True."""
|
||||||
|
batch_size = 300
|
||||||
|
batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
|
||||||
|
batches_without_skip = list(
|
||||||
|
some_permutation.iter(batch_size, skip_last_batch=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# With skip_last_batch=True, we should have fewer batches if the last one is partial
|
||||||
|
num_full_batches = 950 // batch_size
|
||||||
|
assert len(batches_with_skip) == num_full_batches
|
||||||
|
|
||||||
|
# Without skip_last_batch, we should have one more batch if there's a remainder
|
||||||
|
if 950 % batch_size != 0:
|
||||||
|
assert len(batches_without_skip) == num_full_batches + 1
|
||||||
|
# Last batch should be smaller
|
||||||
|
assert len(batches_without_skip[-1]["id"]) == 950 % batch_size
|
||||||
|
|
||||||
|
# All batches with skip_last_batch should be full size
|
||||||
|
for batch in batches_with_skip:
|
||||||
|
assert len(batch["id"]) == batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_different_batch_sizes(some_permutation: Permutation):
|
||||||
|
"""Test iteration with different batch sizes."""
|
||||||
|
|
||||||
|
# Test with small batch size
|
||||||
|
small_batches = list(some_permutation.iter(100, skip_last_batch=False))
|
||||||
|
assert len(small_batches) == 10 # ceiling(950 / 100)
|
||||||
|
|
||||||
|
# Test with large batch size
|
||||||
|
large_batches = list(some_permutation.iter(400, skip_last_batch=False))
|
||||||
|
assert len(large_batches) == 3 # ceiling(950 / 400)
|
||||||
|
|
||||||
|
# Test with batch size equal to total rows
|
||||||
|
single_batch = list(some_permutation.iter(950, skip_last_batch=False))
|
||||||
|
assert len(single_batch) == 1
|
||||||
|
assert len(single_batch[0]["id"]) == 950
|
||||||
|
|
||||||
|
# Test with batch size larger than total rows
|
||||||
|
oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
|
||||||
|
assert len(oversized_batch) == 1
|
||||||
|
assert len(oversized_batch[0]["id"]) == 950
|
||||||
|
|
||||||
|
|
||||||
|
def test_dunder_iter(some_permutation: Permutation):
|
||||||
|
"""Test the __iter__ method."""
|
||||||
|
# __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True
|
||||||
|
batches = list(some_permutation)
|
||||||
|
|
||||||
|
# With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches
|
||||||
|
assert len(batches) == 9 # ceiling(950 / 100)
|
||||||
|
|
||||||
|
# All batches should be full size
|
||||||
|
for batch in batches:
|
||||||
|
assert len(batch["id"]) == 100
|
||||||
|
assert len(batch["value"]) == 100
|
||||||
|
|
||||||
|
some_permutation = some_permutation.with_batch_size(400)
|
||||||
|
batches = list(some_permutation)
|
||||||
|
assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True
|
||||||
|
for batch in batches:
|
||||||
|
assert len(batch["id"]) == 400
|
||||||
|
assert len(batch["value"]) == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_with_different_formats(some_permutation: Permutation):
|
||||||
|
"""Test iteration with different output formats."""
|
||||||
|
batch_size = 100
|
||||||
|
|
||||||
|
# Test with arrow format
|
||||||
|
arrow_perm = some_permutation.with_format("arrow")
|
||||||
|
arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False))
|
||||||
|
assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches)
|
||||||
|
|
||||||
|
# Test with python format (default)
|
||||||
|
python_perm = some_permutation.with_format("python")
|
||||||
|
python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
|
||||||
|
assert all(isinstance(batch, dict) for batch in python_batches)
|
||||||
|
|
||||||
|
# Test with pandas format
|
||||||
|
pandas_perm = some_permutation.with_format("pandas")
|
||||||
|
pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False))
|
||||||
|
# Import pandas to check the type
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_with_column_selection(some_permutation: Permutation):
|
||||||
|
"""Test iteration after column selection."""
|
||||||
|
# Select only the id column
|
||||||
|
id_only = some_permutation.select_columns(["id"])
|
||||||
|
batches = list(id_only.iter(100, skip_last_batch=False))
|
||||||
|
|
||||||
|
# Check that batches only contain the id column
|
||||||
|
for batch in batches:
|
||||||
|
assert "id" in batch
|
||||||
|
assert "value" not in batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_with_column_rename(some_permutation: Permutation):
|
||||||
|
"""Test iteration after renaming columns."""
|
||||||
|
renamed = some_permutation.rename_column("value", "data")
|
||||||
|
batches = list(renamed.iter(100, skip_last_batch=False))
|
||||||
|
|
||||||
|
# Check that batches have the renamed column
|
||||||
|
for batch in batches:
|
||||||
|
assert "id" in batch
|
||||||
|
assert "data" in batch
|
||||||
|
assert "value" not in batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_with_limit_offset(some_permutation: Permutation):
|
||||||
|
"""Test iteration with limit and offset."""
|
||||||
|
# Test with offset
|
||||||
|
offset_perm = some_permutation.with_skip(100)
|
||||||
|
offset_batches = list(offset_perm.iter(100, skip_last_batch=False))
|
||||||
|
# Should have 850 rows (950 - 100)
|
||||||
|
expected_batches = math.ceil(850 / 100)
|
||||||
|
assert len(offset_batches) == expected_batches
|
||||||
|
|
||||||
|
# Test with limit
|
||||||
|
limit_perm = some_permutation.with_take(500)
|
||||||
|
limit_batches = list(limit_perm.iter(100, skip_last_batch=False))
|
||||||
|
# Should have 5 batches (500 / 100)
|
||||||
|
assert len(limit_batches) == 5
|
||||||
|
|
||||||
|
no_skip = some_permutation.iter(101, skip_last_batch=False)
|
||||||
|
row_100 = next(no_skip)["id"][100]
|
||||||
|
|
||||||
|
# Test with both limit and offset
|
||||||
|
limited_perm = some_permutation.with_skip(100).with_take(300)
|
||||||
|
limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
|
||||||
|
# Should have 3 batches (300 / 100)
|
||||||
|
assert len(limited_batches) == 3
|
||||||
|
assert limited_batches[0]["id"][0] == row_100
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_empty_permutation(mem_db):
|
||||||
|
"""Test iteration over an empty permutation."""
|
||||||
|
# Create a table and filter it to be empty
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||||
|
)
|
||||||
|
permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
|
||||||
|
with pytest.raises(ValueError, match="No rows found"):
|
||||||
|
Permutation.from_tables(tbl, permutation_tbl)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_single_row(mem_db):
|
||||||
|
"""Test iteration over a permutation with a single row."""
|
||||||
|
tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
|
||||||
|
permutation_tbl = permutation_builder(tbl).execute()
|
||||||
|
perm = Permutation.from_tables(tbl, permutation_tbl)
|
||||||
|
|
||||||
|
# With skip_last_batch=False, should get one batch
|
||||||
|
batches = list(perm.iter(10, skip_last_batch=False))
|
||||||
|
assert len(batches) == 1
|
||||||
|
assert len(batches[0]["id"]) == 1
|
||||||
|
|
||||||
|
# With skip_last_batch=True, should skip the single row (since it's < batch_size)
|
||||||
|
batches_skip = list(perm.iter(10, skip_last_batch=True))
|
||||||
|
assert len(batches_skip) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity_permutation(mem_db):
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||||
|
)
|
||||||
|
permutation = Permutation.identity(tbl)
|
||||||
|
|
||||||
|
assert permutation.num_rows == 10
|
||||||
|
assert permutation.num_columns == 2
|
||||||
|
|
||||||
|
batches = list(permutation.iter(10, skip_last_batch=False))
|
||||||
|
assert len(batches) == 1
|
||||||
|
assert len(batches[0]["id"]) == 10
|
||||||
|
assert len(batches[0]["value"]) == 10
|
||||||
|
|
||||||
|
permutation = permutation.remove_columns(["value"])
|
||||||
|
assert permutation.num_columns == 1
|
||||||
|
assert permutation.schema == pa.schema([("id", pa.int64())])
|
||||||
|
assert permutation.column_names == ["id"]
|
||||||
|
assert permutation.shape == (10, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_fn(mem_db):
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||||
|
)
|
||||||
|
permutation = Permutation.identity(tbl)
|
||||||
|
|
||||||
|
np_result = list(permutation.with_format("numpy").iter(10, skip_last_batch=False))[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
assert np_result.shape == (10, 2)
|
||||||
|
assert np_result.dtype == np.int64
|
||||||
|
assert isinstance(np_result, np.ndarray)
|
||||||
|
|
||||||
|
pd_result = list(permutation.with_format("pandas").iter(10, skip_last_batch=False))[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
assert pd_result.shape == (10, 2)
|
||||||
|
assert pd_result.dtypes.tolist() == [np.int64, np.int64]
|
||||||
|
assert isinstance(pd_result, pd.DataFrame)
|
||||||
|
|
||||||
|
pl_result = list(permutation.with_format("polars").iter(10, skip_last_batch=False))[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
assert pl_result.shape == (10, 2)
|
||||||
|
assert pl_result.dtypes == [pl.Int64, pl.Int64]
|
||||||
|
assert isinstance(pl_result, pl.DataFrame)
|
||||||
|
|
||||||
|
py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
assert len(py_result) == 2
|
||||||
|
assert len(py_result["id"]) == 10
|
||||||
|
assert len(py_result["value"]) == 10
|
||||||
|
assert isinstance(py_result, dict)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_result = list(
|
||||||
|
permutation.with_format("torch").iter(10, skip_last_batch=False)
|
||||||
|
)[0]
|
||||||
|
assert torch_result.shape == (2, 10)
|
||||||
|
assert torch_result.dtype == torch.int64
|
||||||
|
assert isinstance(torch_result, torch.Tensor)
|
||||||
|
except ImportError:
|
||||||
|
# Skip check if torch is not installed
|
||||||
|
pass
|
||||||
|
|
||||||
|
arrow_result = list(
|
||||||
|
permutation.with_format("arrow").iter(10, skip_last_batch=False)
|
||||||
|
)[0]
|
||||||
|
assert arrow_result.shape == (10, 2)
|
||||||
|
assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
|
||||||
|
assert isinstance(arrow_result, pa.RecordBatch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_transform(mem_db):
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||||
|
)
|
||||||
|
permutation = Permutation.identity(tbl)
|
||||||
|
|
||||||
|
def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
|
||||||
|
return batch.select(["id"])
|
||||||
|
|
||||||
|
transformed = permutation.with_transform(transform)
|
||||||
|
batches = list(transformed.iter(10, skip_last_batch=False))
|
||||||
|
assert len(batches) == 1
|
||||||
|
batch = batches[0]
|
||||||
|
|
||||||
|
assert batch == pa.record_batch([range(10)], ["id"])
|
||||||
|
|||||||
@@ -3,19 +3,11 @@
|
|||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pytest
|
import pytest
|
||||||
|
from lancedb.util import tbl_to_tensor
|
||||||
|
|
||||||
torch = pytest.importorskip("torch")
|
torch = pytest.importorskip("torch")
|
||||||
|
|
||||||
|
|
||||||
def tbl_to_tensor(tbl):
|
|
||||||
def to_tensor(col: pa.ChunkedArray):
|
|
||||||
if col.num_chunks > 1:
|
|
||||||
raise Exception("Single batch was too large to fit into a one-chunk table")
|
|
||||||
return torch.from_dlpack(col.chunk(0))
|
|
||||||
|
|
||||||
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
|
||||||
|
|
||||||
|
|
||||||
def test_table_dataloader(mem_db):
|
def test_table_dataloader(mem_db):
|
||||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
|
|||||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||||
use lancedb::{
|
use lancedb::{
|
||||||
connection::Connection as LanceConnection,
|
connection::Connection as LanceConnection,
|
||||||
database::{CreateTableMode, ReadConsistency},
|
database::{CreateTableMode, Database, ReadConsistency},
|
||||||
};
|
};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
exceptions::{PyRuntimeError, PyValueError},
|
exceptions::{PyRuntimeError, PyValueError},
|
||||||
@@ -42,6 +42,10 @@ impl Connection {
|
|||||||
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn database(&self) -> PyResult<Arc<dyn Database>> {
|
||||||
|
Ok(self.get_inner()?.database().clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use arrow::RecordBatchStream;
|
|||||||
use connection::{connect, Connection};
|
use connection::{connect, Connection};
|
||||||
use env_logger::Env;
|
use env_logger::Env;
|
||||||
use index::IndexConfig;
|
use index::IndexConfig;
|
||||||
use permutation::PyAsyncPermutationBuilder;
|
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
pymodule,
|
pymodule,
|
||||||
types::{PyModule, PyModuleMethods},
|
types::{PyModule, PyModuleMethods},
|
||||||
@@ -52,6 +52,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|||||||
m.add_class::<DropColumnsResult>()?;
|
m.add_class::<DropColumnsResult>()?;
|
||||||
m.add_class::<UpdateResult>()?;
|
m.add_class::<UpdateResult>()?;
|
||||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||||
|
m.add_class::<PyPermutationReader>()?;
|
||||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||||
|
|||||||
@@ -3,14 +3,23 @@
|
|||||||
|
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use crate::{error::PythonErrorExt, table::Table};
|
use crate::{
|
||||||
use lancedb::dataloader::{
|
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
};
|
||||||
permutation::split::{SplitSizes, SplitStrategy},
|
use arrow::pyarrow::ToPyArrow;
|
||||||
|
use lancedb::{
|
||||||
|
dataloader::permutation::{
|
||||||
|
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||||
|
reader::PermutationReader,
|
||||||
|
split::{SplitSizes, SplitStrategy},
|
||||||
|
},
|
||||||
|
query::Select,
|
||||||
};
|
};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
exceptions::PyRuntimeError,
|
||||||
PyResult,
|
pyclass, pymethods,
|
||||||
|
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||||
|
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
||||||
};
|
};
|
||||||
use pyo3_async_runtimes::tokio::future_into_py;
|
use pyo3_async_runtimes::tokio::future_into_py;
|
||||||
|
|
||||||
@@ -56,13 +65,32 @@ impl PyAsyncPermutationBuilder {
|
|||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyAsyncPermutationBuilder {
|
impl PyAsyncPermutationBuilder {
|
||||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
|
#[pyo3(signature = (database, table_name))]
|
||||||
|
pub fn persist(
|
||||||
|
slf: PyRefMut<'_, Self>,
|
||||||
|
database: Bound<'_, PyAny>,
|
||||||
|
table_name: String,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
let conn = if database.hasattr("_conn")? {
|
||||||
|
database
|
||||||
|
.getattr("_conn")?
|
||||||
|
.getattr("_inner")?
|
||||||
|
.downcast_into::<Connection>()?
|
||||||
|
} else {
|
||||||
|
database.getattr("_inner")?.downcast_into::<Connection>()?
|
||||||
|
};
|
||||||
|
let database = conn.borrow().database()?;
|
||||||
|
slf.modify(|builder| builder.persist(database, table_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
|
||||||
pub fn split_random(
|
pub fn split_random(
|
||||||
slf: PyRefMut<'_, Self>,
|
slf: PyRefMut<'_, Self>,
|
||||||
ratios: Option<Vec<f64>>,
|
ratios: Option<Vec<f64>>,
|
||||||
counts: Option<Vec<u64>>,
|
counts: Option<Vec<u64>>,
|
||||||
fixed: Option<u64>,
|
fixed: Option<u64>,
|
||||||
seed: Option<u64>,
|
seed: Option<u64>,
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
// Check that exactly one split type is provided
|
// Check that exactly one split type is provided
|
||||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||||
@@ -86,31 +114,38 @@ impl PyAsyncPermutationBuilder {
|
|||||||
unreachable!("One of the split arguments must be provided");
|
unreachable!("One of the split arguments must be provided");
|
||||||
};
|
};
|
||||||
|
|
||||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
slf.modify(|builder| {
|
||||||
|
builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
|
#[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
|
||||||
pub fn split_hash(
|
pub fn split_hash(
|
||||||
slf: PyRefMut<'_, Self>,
|
slf: PyRefMut<'_, Self>,
|
||||||
columns: Vec<String>,
|
columns: Vec<String>,
|
||||||
split_weights: Vec<u64>,
|
split_weights: Vec<u64>,
|
||||||
discard_weight: u64,
|
discard_weight: u64,
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
slf.modify(|builder| {
|
slf.modify(|builder| {
|
||||||
builder.with_split_strategy(SplitStrategy::Hash {
|
builder.with_split_strategy(
|
||||||
columns,
|
SplitStrategy::Hash {
|
||||||
split_weights,
|
columns,
|
||||||
discard_weight,
|
split_weights,
|
||||||
})
|
discard_weight,
|
||||||
|
},
|
||||||
|
split_names,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
|
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
|
||||||
pub fn split_sequential(
|
pub fn split_sequential(
|
||||||
slf: PyRefMut<'_, Self>,
|
slf: PyRefMut<'_, Self>,
|
||||||
ratios: Option<Vec<f64>>,
|
ratios: Option<Vec<f64>>,
|
||||||
counts: Option<Vec<u64>>,
|
counts: Option<Vec<u64>>,
|
||||||
fixed: Option<u64>,
|
fixed: Option<u64>,
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
// Check that exactly one split type is provided
|
// Check that exactly one split type is provided
|
||||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||||
@@ -134,11 +169,19 @@ impl PyAsyncPermutationBuilder {
|
|||||||
unreachable!("One of the split arguments must be provided");
|
unreachable!("One of the split arguments must be provided");
|
||||||
};
|
};
|
||||||
|
|
||||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
slf.modify(|builder| {
|
||||||
|
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
|
pub fn split_calculated(
|
||||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
|
slf: PyRefMut<'_, Self>,
|
||||||
|
calculation: String,
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
slf.modify(|builder| {
|
||||||
|
builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shuffle(
|
pub fn shuffle(
|
||||||
@@ -168,3 +211,121 @@ impl PyAsyncPermutationBuilder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyclass(name = "PermutationReader")]
|
||||||
|
pub struct PyPermutationReader {
|
||||||
|
reader: Arc<PermutationReader>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PyPermutationReader {
|
||||||
|
fn from_reader(reader: PermutationReader) -> Self {
|
||||||
|
Self {
|
||||||
|
reader: Arc::new(reader),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
|
||||||
|
let Some(selection) = selection else {
|
||||||
|
return Ok(Select::All);
|
||||||
|
};
|
||||||
|
let selection = selection.downcast_into::<PyDict>()?;
|
||||||
|
let selection = selection
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| {
|
||||||
|
let key = key.extract::<String>()?;
|
||||||
|
let value = value.extract::<String>()?;
|
||||||
|
Ok((key, value))
|
||||||
|
})
|
||||||
|
.collect::<PyResult<Vec<_>>>()?;
|
||||||
|
Ok(Select::dynamic(&selection))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl PyPermutationReader {
|
||||||
|
#[classmethod]
|
||||||
|
pub fn from_tables<'py>(
|
||||||
|
cls: &Bound<'py, PyType>,
|
||||||
|
base_table: Bound<'py, PyAny>,
|
||||||
|
permutation_table: Option<Bound<'py, PyAny>>,
|
||||||
|
split: u64,
|
||||||
|
) -> PyResult<Bound<'py, PyAny>> {
|
||||||
|
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||||
|
let permutation_table = permutation_table
|
||||||
|
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
|
||||||
|
let permutation_table = permutation_table
|
||||||
|
.map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
future_into_py(cls.py(), async move {
|
||||||
|
let reader = if let Some(permutation_table) = permutation_table {
|
||||||
|
PermutationReader::try_from_tables(base_table, permutation_table, split)
|
||||||
|
.await
|
||||||
|
.infer_error()?
|
||||||
|
} else {
|
||||||
|
PermutationReader::identity(base_table).await
|
||||||
|
};
|
||||||
|
Ok(Self::from_reader(reader))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (selection=None))]
|
||||||
|
pub fn output_schema<'py>(
|
||||||
|
slf: PyRef<'py, Self>,
|
||||||
|
selection: Option<Bound<'py, PyAny>>,
|
||||||
|
) -> PyResult<Bound<'py, PyAny>> {
|
||||||
|
let selection = Self::parse_selection(selection)?;
|
||||||
|
let reader = slf.reader.clone();
|
||||||
|
future_into_py(slf.py(), async move {
|
||||||
|
let schema = reader.output_schema(selection).await.infer_error()?;
|
||||||
|
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = ())]
|
||||||
|
pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
|
||||||
|
slf.reader.count_rows()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (offset))]
|
||||||
|
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||||
|
let reader = slf.reader.as_ref().clone();
|
||||||
|
future_into_py(slf.py(), async move {
|
||||||
|
let reader = reader.with_offset(offset).await.infer_error()?;
|
||||||
|
Ok(Self::from_reader(reader))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (limit))]
|
||||||
|
pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||||
|
let reader = slf.reader.as_ref().clone();
|
||||||
|
future_into_py(slf.py(), async move {
|
||||||
|
let reader = reader.with_limit(limit).await.infer_error()?;
|
||||||
|
Ok(Self::from_reader(reader))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (selection=None, *, batch_size=None))]
|
||||||
|
pub fn read<'py>(
|
||||||
|
slf: PyRef<'py, Self>,
|
||||||
|
selection: Option<Bound<'py, PyAny>>,
|
||||||
|
batch_size: Option<u32>,
|
||||||
|
) -> PyResult<Bound<'py, PyAny>> {
|
||||||
|
let selection = Self::parse_selection(selection)?;
|
||||||
|
let reader = slf.reader.clone();
|
||||||
|
let batch_size = batch_size.unwrap_or(1024);
|
||||||
|
future_into_py(slf.py(), async move {
|
||||||
|
use lancedb::query::QueryExecutionOptions;
|
||||||
|
let mut execution_options = QueryExecutionOptions::default();
|
||||||
|
execution_options.max_batch_length = batch_size;
|
||||||
|
let stream = reader
|
||||||
|
.read(selection, execution_options)
|
||||||
|
.await
|
||||||
|
.infer_error()?;
|
||||||
|
Ok(RecordBatchStream::new(stream))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||||
@@ -25,6 +25,8 @@ use crate::{
|
|||||||
|
|
||||||
pub const SRC_ROW_ID_COL: &str = "row_id";
|
pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||||
|
|
||||||
|
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
||||||
|
|
||||||
/// Where to store the permutation table
|
/// Where to store the permutation table
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
enum PermutationDestination {
|
enum PermutationDestination {
|
||||||
@@ -40,6 +42,8 @@ enum PermutationDestination {
|
|||||||
pub struct PermutationConfig {
|
pub struct PermutationConfig {
|
||||||
/// Splitting configuration
|
/// Splitting configuration
|
||||||
split_strategy: SplitStrategy,
|
split_strategy: SplitStrategy,
|
||||||
|
/// Optional names for the splits
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
/// Shuffle strategy
|
/// Shuffle strategy
|
||||||
shuffle_strategy: ShuffleStrategy,
|
shuffle_strategy: ShuffleStrategy,
|
||||||
/// Optional filter to apply to the base table
|
/// Optional filter to apply to the base table
|
||||||
@@ -112,8 +116,16 @@ impl PermutationBuilder {
|
|||||||
/// multiple processes and multiple nodes.
|
/// multiple processes and multiple nodes.
|
||||||
///
|
///
|
||||||
/// The default is a single split that contains all rows.
|
/// The default is a single split that contains all rows.
|
||||||
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
|
///
|
||||||
|
/// An optional list of names can be provided for the splits. This is for convenience and the names
|
||||||
|
/// will be stored in the permutation table's config metadata.
|
||||||
|
pub fn with_split_strategy(
|
||||||
|
mut self,
|
||||||
|
split_strategy: SplitStrategy,
|
||||||
|
split_names: Option<Vec<String>>,
|
||||||
|
) -> Self {
|
||||||
self.config.split_strategy = split_strategy;
|
self.config.split_strategy = split_strategy;
|
||||||
|
self.config.split_names = split_names;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +205,30 @@ impl PermutationBuilder {
|
|||||||
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_split_names(
|
||||||
|
data: SendableRecordBatchStream,
|
||||||
|
split_names: &[String],
|
||||||
|
) -> Result<SendableRecordBatchStream> {
|
||||||
|
let schema = data
|
||||||
|
.schema()
|
||||||
|
.as_ref()
|
||||||
|
.clone()
|
||||||
|
.with_metadata(HashMap::from([(
|
||||||
|
SPLIT_NAMES_CONFIG_KEY.to_string(),
|
||||||
|
serde_json::to_string(split_names).map_err(|e| Error::Other {
|
||||||
|
message: format!("Failed to serialize split names: {}", e),
|
||||||
|
source: Some(e.into()),
|
||||||
|
})?,
|
||||||
|
)]));
|
||||||
|
let schema = Arc::new(schema);
|
||||||
|
let schema_clone = schema.clone();
|
||||||
|
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
|
||||||
|
Ok(Box::pin(SimpleRecordBatchStream {
|
||||||
|
schema: schema_clone,
|
||||||
|
stream,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// Builds the permutation table and stores it in the given database.
|
/// Builds the permutation table and stores it in the given database.
|
||||||
pub async fn build(self) -> Result<Table> {
|
pub async fn build(self) -> Result<Table> {
|
||||||
// First pass, apply filter and load row ids
|
// First pass, apply filter and load row ids
|
||||||
@@ -249,6 +285,12 @@ impl PermutationBuilder {
|
|||||||
// Rename _rowid to row_id
|
// Rename _rowid to row_id
|
||||||
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
|
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
|
||||||
|
|
||||||
|
let streaming_data = if let Some(split_names) = &self.config.split_names {
|
||||||
|
Self::add_split_names(renamed, split_names)?
|
||||||
|
} else {
|
||||||
|
renamed
|
||||||
|
};
|
||||||
|
|
||||||
let (name, database) = match &self.config.destination {
|
let (name, database) = match &self.config.destination {
|
||||||
PermutationDestination::Permanent(database, table_name) => {
|
PermutationDestination::Permanent(database, table_name) => {
|
||||||
(table_name.as_str(), database.clone())
|
(table_name.as_str(), database.clone())
|
||||||
@@ -259,10 +301,13 @@ impl PermutationBuilder {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let create_table_request =
|
let create_table_request = CreateTableRequest::new(
|
||||||
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
|
name.to_string(),
|
||||||
|
CreateTableData::StreamingData(streaming_data),
|
||||||
|
);
|
||||||
|
|
||||||
let table = database.create_table(create_table_request).await?;
|
let table = database.create_table(create_table_request).await?;
|
||||||
|
|
||||||
Ok(Table::new(table, database))
|
Ok(Table::new(table, database))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -296,10 +341,13 @@ mod tests {
|
|||||||
|
|
||||||
let permutation_table = PermutationBuilder::new(data_table.clone())
|
let permutation_table = PermutationBuilder::new(data_table.clone())
|
||||||
.with_filter("some_value > 57".to_string())
|
.with_filter("some_value > 57".to_string())
|
||||||
.with_split_strategy(SplitStrategy::Random {
|
.with_split_strategy(
|
||||||
seed: Some(42),
|
SplitStrategy::Random {
|
||||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
seed: Some(42),
|
||||||
})
|
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
.build()
|
.build()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|||||||
@@ -11,14 +11,19 @@ use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
|
|||||||
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
|
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
|
||||||
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
|
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
|
use crate::query::{
|
||||||
use crate::table::{AnyQuery, BaseTable};
|
ExecutableQuery, QueryBase, QueryExecutionOptions, QueryFilter, QueryRequest, Select,
|
||||||
use crate::Result;
|
};
|
||||||
|
use crate::table::{AnyQuery, BaseTable, Filter};
|
||||||
|
use crate::{Result, Table};
|
||||||
use arrow::array::AsArray;
|
use arrow::array::AsArray;
|
||||||
|
use arrow::compute::concat_batches;
|
||||||
use arrow::datatypes::UInt64Type;
|
use arrow::datatypes::UInt64Type;
|
||||||
use arrow_array::{RecordBatch, UInt64Array};
|
use arrow_array::{RecordBatch, UInt64Array};
|
||||||
|
use arrow_schema::SchemaRef;
|
||||||
use futures::{StreamExt, TryStreamExt};
|
use futures::{StreamExt, TryStreamExt};
|
||||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||||
|
use lance::io::RecordBatchStream;
|
||||||
use lance_arrow::RecordBatchExt;
|
use lance_arrow::RecordBatchExt;
|
||||||
use lance_core::error::LanceOptionExt;
|
use lance_core::error::LanceOptionExt;
|
||||||
use lance_core::ROW_ID;
|
use lance_core::ROW_ID;
|
||||||
@@ -26,43 +31,140 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Reads a permutation of a source table based on row IDs stored in a separate table
|
/// Reads a permutation of a source table based on row IDs stored in a separate table
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct PermutationReader {
|
pub struct PermutationReader {
|
||||||
base_table: Arc<dyn BaseTable>,
|
base_table: Arc<dyn BaseTable>,
|
||||||
permutation_table: Arc<dyn BaseTable>,
|
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||||
|
offset: Option<u64>,
|
||||||
|
limit: Option<u64>,
|
||||||
|
available_rows: u64,
|
||||||
|
split: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for PermutationReader {
|
impl std::fmt::Debug for PermutationReader {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"PermutationReader(base={}, permutation={})",
|
"PermutationReader(base={}, permutation={}, split={}, offset={:?}, limit={:?})",
|
||||||
self.base_table.name(),
|
self.base_table.name(),
|
||||||
self.permutation_table.name(),
|
self.permutation_table
|
||||||
|
.as_ref()
|
||||||
|
.map(|t| t.name())
|
||||||
|
.unwrap_or("--"),
|
||||||
|
self.split,
|
||||||
|
self.offset,
|
||||||
|
self.limit,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PermutationReader {
|
impl PermutationReader {
|
||||||
/// Create a new PermutationReader
|
/// Create a new PermutationReader
|
||||||
pub async fn try_new(
|
pub async fn inner_new(
|
||||||
base_table: Arc<dyn BaseTable>,
|
base_table: Arc<dyn BaseTable>,
|
||||||
permutation_table: Arc<dyn BaseTable>,
|
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||||
|
split: u64,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let schema = permutation_table.schema().await?;
|
let mut slf = Self {
|
||||||
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
|
||||||
return Err(Error::InvalidInput {
|
|
||||||
message: "Permutation table must contain a column named row_id".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
|
||||||
return Err(Error::InvalidInput {
|
|
||||||
message: "Permutation table must contain a column named split_id".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Ok(Self {
|
|
||||||
base_table,
|
base_table,
|
||||||
permutation_table,
|
permutation_table,
|
||||||
})
|
offset: None,
|
||||||
|
limit: None,
|
||||||
|
available_rows: 0,
|
||||||
|
split,
|
||||||
|
};
|
||||||
|
slf.validate().await?;
|
||||||
|
// Calculate the number of available rows
|
||||||
|
slf.available_rows = slf.verify_limit_offset(None, None).await?;
|
||||||
|
if slf.available_rows == 0 {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: "No rows found in the permutation table for the given split".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(slf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn try_from_tables(
|
||||||
|
base_table: Arc<dyn BaseTable>,
|
||||||
|
permutation_table: Arc<dyn BaseTable>,
|
||||||
|
split: u64,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::inner_new(base_table, Some(permutation_table), split).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn identity(base_table: Arc<dyn BaseTable>) -> Self {
|
||||||
|
Self::inner_new(base_table, None, 0).await.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates the limit and offset and returns the number of rows that will be read
|
||||||
|
fn validate_limit_offset(
|
||||||
|
limit: Option<u64>,
|
||||||
|
offset: Option<u64>,
|
||||||
|
available_rows: u64,
|
||||||
|
) -> Result<u64> {
|
||||||
|
match (limit, offset) {
|
||||||
|
(Some(limit), Some(offset)) => {
|
||||||
|
if offset + limit > available_rows {
|
||||||
|
Err(Error::InvalidInput {
|
||||||
|
message: "Offset + limit is greater than the number of rows in the permutation table"
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, Some(offset)) => {
|
||||||
|
if offset > available_rows {
|
||||||
|
Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"Offset is greater than the number of rows in the permutation table"
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(available_rows - offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Some(limit), None) => {
|
||||||
|
if limit > available_rows {
|
||||||
|
Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"Limit is greater than the number of rows in the permutation table"
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, None) => Ok(available_rows),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn verify_limit_offset(&self, limit: Option<u64>, offset: Option<u64>) -> Result<u64> {
|
||||||
|
let available_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||||
|
permutation_table
|
||||||
|
.count_rows(Some(Filter::Sql(format!(
|
||||||
|
"{} = {}",
|
||||||
|
SPLIT_ID_COLUMN, self.split
|
||||||
|
))))
|
||||||
|
.await? as u64
|
||||||
|
} else {
|
||||||
|
self.base_table.count_rows(None).await? as u64
|
||||||
|
};
|
||||||
|
Self::validate_limit_offset(limit, offset, available_rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn with_offset(mut self, offset: u64) -> Result<Self> {
|
||||||
|
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
|
||||||
|
self.offset = Some(offset);
|
||||||
|
self.available_rows = available_rows;
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn with_limit(mut self, limit: u64) -> Result<Self> {
|
||||||
|
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
|
||||||
|
self.available_rows = available_rows;
|
||||||
|
self.limit = Some(limit);
|
||||||
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
|
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
|
||||||
@@ -103,7 +205,7 @@ impl PermutationReader {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut data = base_table
|
let data = base_table
|
||||||
.query(
|
.query(
|
||||||
&AnyQuery::Query(base_query),
|
&AnyQuery::Query(base_query),
|
||||||
QueryExecutionOptions {
|
QueryExecutionOptions {
|
||||||
@@ -112,25 +214,29 @@ impl PermutationReader {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
let schema = data.schema();
|
||||||
|
|
||||||
let Some(batch) = data.try_next().await? else {
|
let batches = data.try_collect::<Vec<_>>().await?;
|
||||||
|
|
||||||
|
if batches.is_empty() {
|
||||||
return Err(Error::InvalidInput {
|
return Err(Error::InvalidInput {
|
||||||
message: "Base table returned no batches".to_string(),
|
message: "Base table returned no batches".to_string(),
|
||||||
});
|
});
|
||||||
};
|
|
||||||
if data.try_next().await?.is_some() {
|
|
||||||
return Err(Error::InvalidInput {
|
|
||||||
message: "Base table returned more than one batch".to_string(),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if batch.num_rows() != num_rows {
|
if batches.iter().map(|b| b.num_rows()).sum::<usize>() != num_rows {
|
||||||
return Err(Error::InvalidInput {
|
return Err(Error::InvalidInput {
|
||||||
message: "Base table returned different number of rows than the number of row IDs"
|
message: "Base table returned different number of rows than the number of row IDs"
|
||||||
.to_string(),
|
.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let batch = if batches.len() == 1 {
|
||||||
|
batches.into_iter().next().unwrap()
|
||||||
|
} else {
|
||||||
|
concat_batches(&schema, &batches)?
|
||||||
|
};
|
||||||
|
|
||||||
// There is no guarantee the result order will match the order provided
|
// There is no guarantee the result order will match the order provided
|
||||||
// so may need to restore order
|
// so may need to restore order
|
||||||
let actual_row_ids = batch
|
let actual_row_ids = batch
|
||||||
@@ -230,26 +336,75 @@ impl PermutationReader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn read_split(
|
async fn validate(&self) -> Result<()> {
|
||||||
|
if let Some(permutation_table) = &self.permutation_table {
|
||||||
|
let schema = permutation_table.schema().await?;
|
||||||
|
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: "Permutation table must contain a column named row_id".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: "Permutation table must contain a column named split_id".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let avail_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||||
|
permutation_table.count_rows(None).await? as u64
|
||||||
|
} else {
|
||||||
|
self.base_table.count_rows(None).await? as u64
|
||||||
|
};
|
||||||
|
Self::validate_limit_offset(self.limit, self.offset, avail_rows)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn read(
|
||||||
&self,
|
&self,
|
||||||
split: u64,
|
|
||||||
selection: Select,
|
selection: Select,
|
||||||
execution_options: QueryExecutionOptions,
|
execution_options: QueryExecutionOptions,
|
||||||
) -> Result<SendableRecordBatchStream> {
|
) -> Result<SendableRecordBatchStream> {
|
||||||
let row_ids = self
|
// Note: this relies on the row ids query here being returned in consistent order
|
||||||
.permutation_table
|
let row_ids = if let Some(permutation_table) = &self.permutation_table {
|
||||||
.query(
|
permutation_table
|
||||||
&AnyQuery::Query(QueryRequest {
|
.query(
|
||||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
&AnyQuery::Query(QueryRequest {
|
||||||
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
|
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||||
..Default::default()
|
filter: Some(QueryFilter::Sql(format!(
|
||||||
}),
|
"{} = {}",
|
||||||
execution_options,
|
SPLIT_ID_COLUMN, self.split
|
||||||
)
|
))),
|
||||||
.await?;
|
offset: self.offset.map(|o| o as usize),
|
||||||
|
limit: self.limit.map(|l| l as usize),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
execution_options,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
self.base_table
|
||||||
|
.query(
|
||||||
|
&AnyQuery::Query(QueryRequest {
|
||||||
|
select: Select::Columns(vec![ROW_ID.to_string()]),
|
||||||
|
offset: self.offset.map(|o| o as usize),
|
||||||
|
limit: self.limit.map(|l| l as usize),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
execution_options,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
|
||||||
|
let table = Table::from(self.base_table.clone());
|
||||||
|
table.query().select(selection).output_schema().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count_rows(&self) -> u64 {
|
||||||
|
self.available_rows
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -321,17 +476,17 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
||||||
|
|
||||||
let reader = PermutationReader::try_new(
|
let reader = PermutationReader::try_from_tables(
|
||||||
base_table.base_table().clone(),
|
base_table.base_table().clone(),
|
||||||
row_ids_table.base_table().clone(),
|
row_ids_table.base_table().clone(),
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Read split 0
|
// Read split 0
|
||||||
let mut stream = reader
|
let mut stream = reader
|
||||||
.read_split(
|
.read(
|
||||||
0,
|
|
||||||
Select::All,
|
Select::All,
|
||||||
QueryExecutionOptions {
|
QueryExecutionOptions {
|
||||||
max_batch_length: 3,
|
max_batch_length: 3,
|
||||||
@@ -366,9 +521,16 @@ mod tests {
|
|||||||
assert!(stream.try_next().await.unwrap().is_none());
|
assert!(stream.try_next().await.unwrap().is_none());
|
||||||
|
|
||||||
// Read split 1
|
// Read split 1
|
||||||
|
let reader = PermutationReader::try_from_tables(
|
||||||
|
base_table.base_table().clone(),
|
||||||
|
row_ids_table.base_table().clone(),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut stream = reader
|
let mut stream = reader
|
||||||
.read_split(
|
.read(
|
||||||
1,
|
|
||||||
Select::All,
|
Select::All,
|
||||||
QueryExecutionOptions {
|
QueryExecutionOptions {
|
||||||
max_batch_length: 3,
|
max_batch_length: 3,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ pub(crate) const DEFAULT_TOP_K: usize = 10;
|
|||||||
/// Which columns should be retrieved from the database
|
/// Which columns should be retrieved from the database
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Select {
|
pub enum Select {
|
||||||
/// Select all columns
|
/// Select all non-system columns
|
||||||
///
|
///
|
||||||
/// Warning: This will always be slower than selecting only the columns you need.
|
/// Warning: This will always be slower than selecting only the columns you need.
|
||||||
All,
|
All,
|
||||||
|
|||||||
@@ -620,7 +620,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Table {
|
pub struct Table {
|
||||||
inner: Arc<dyn BaseTable>,
|
inner: Arc<dyn BaseTable>,
|
||||||
database: Arc<dyn Database>,
|
database: Option<Arc<dyn Database>>,
|
||||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -644,7 +644,7 @@ mod test_utils {
|
|||||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
database,
|
database: Some(database),
|
||||||
// Registry is unused.
|
// Registry is unused.
|
||||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||||
}
|
}
|
||||||
@@ -666,7 +666,7 @@ mod test_utils {
|
|||||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
database,
|
database: Some(database),
|
||||||
// Registry is unused.
|
// Registry is unused.
|
||||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||||
}
|
}
|
||||||
@@ -680,11 +680,21 @@ impl std::fmt::Display for Table {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Arc<dyn BaseTable>> for Table {
|
||||||
|
fn from(inner: Arc<dyn BaseTable>) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
database: None,
|
||||||
|
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Table {
|
impl Table {
|
||||||
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
|
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
database,
|
database: Some(database),
|
||||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -694,7 +704,7 @@ impl Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn database(&self) -> &Arc<dyn Database> {
|
pub fn database(&self) -> &Arc<dyn Database> {
|
||||||
&self.database
|
self.database.as_ref().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
|
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
|
||||||
@@ -708,7 +718,7 @@ impl Table {
|
|||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
inner,
|
inner,
|
||||||
database,
|
database: Some(database),
|
||||||
embedding_registry,
|
embedding_registry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user